## 1. Init


In [None]:
from google.colab import drive
import os
import shutil
import subprocess

# Mount Google Drive so the project folder referenced below is reachable.
drive.mount('/content/drive')

# Project root on Drive and the sub-folders later cells read from / write to.
PROJECT_DIR = '/content/drive/MyDrive/Text2SVG'
CODE_FOLDER, DATA_FOLDER, TEST_FOLDER = 'code', 'data', 'test'

# Absolute paths: "<PROJECT_DIR>/<folder>".
CODE_PATH = '/'.join((PROJECT_DIR, CODE_FOLDER))
DATA_PATH = '/'.join((PROJECT_DIR, DATA_FOLDER))
TEST_PATH = '/'.join((PROJECT_DIR, TEST_FOLDER))


## 2. Clone Code Repository


In [None]:
# Remove any previous checkout and clone a fresh copy of the repository.
if os.path.exists(CODE_PATH):
    shutil.rmtree(CODE_PATH)

os.chdir(PROJECT_DIR)
result = subprocess.run(
    ['git', 'clone', 'https://github.com/huanbasara/dual-branch-vae.git', CODE_FOLDER],
    capture_output=True,
    text=True,
)

if result.returncode != 0:
    # RuntimeError (a subclass of Exception) rather than a bare Exception,
    # so existing `except Exception` handlers keep working.
    raise RuntimeError(f"Failed to clone repository: {result.stderr}")

print(f"✅ Code repository cloned successfully to {CODE_PATH}")

# Report the checked-out commit. The working directory is left at CODE_PATH.
os.chdir(CODE_PATH)
commit_info = subprocess.run(
    ['git', 'log', '-1', '--pretty=format:%H|%ci|%s'],
    capture_output=True,
    text=True,
)

# maxsplit=2 keeps any '|' inside the commit subject intact; validate the
# field count before unpacking instead of risking a ValueError.
commit_fields = commit_info.stdout.strip().split('|', 2)
if commit_info.returncode == 0 and len(commit_fields) == 3:
    hash_code, commit_time, commit_msg = commit_fields
    print("📦 Latest commit:")
    print(f"   Hash: {hash_code[:8]}")
    print(f"   Time: {commit_time}")
    print(f"   Message: {commit_msg}")
else:
    print("⚠️ Could not get commit info")

# Add code path to Python sys.path so we can import our modules
import sys
if CODE_PATH not in sys.path:
    sys.path.insert(0, CODE_PATH)
    print(f"✅ Added {CODE_PATH} to Python path")
else:
    print(f"✅ {CODE_PATH} already in Python path")


## 3. Install Dependencies

In [None]:
%pip install torch torchvision numpy scipy pandas scikit-learn matplotlib pillow svglib svgpathtools
%pip install kornia opencv-python cairosvg pyyaml easydict tqdm
# Install additional packages for SVG rendering support
%pip install svglib reportlab cssutils

In [None]:
# Make the cloned repository importable: put CODE_PATH at the front of the
# module search path unless an earlier cell already did so.
import sys

_already_on_path = CODE_PATH in sys.path
if _already_on_path:
    print(f"✅ {CODE_PATH} already in Python path")
else:
    sys.path.insert(0, CODE_PATH)
    print(f"✅ Added {CODE_PATH} to Python path")


In [None]:
# Smoke test: check that the pydiffvg_lite module from the cloned repo imports.
# Any failure is reported rather than raised, so the notebook keeps running.
try:
    import pydiffvg_lite
except Exception as e:
    print(f"❌ pydiffvg_lite import failed: {e}")
else:
    print("✅ pydiffvg_lite import successful!")


In [None]:
# Parse a fixed sample SVG (the aardvark icon) with pydiffvg_lite and print a
# short summary of the resulting scene. Errors are reported, not raised.
svg_file = f"{DATA_PATH}/aardvark/66543-200.svg"
print(f"Using test SVG: {svg_file}")

if not os.path.exists(svg_file):
    print(f"❌ SVG file not found: {svg_file}")
else:
    try:
        # svg_to_scene is unpacked as (canvas_width, canvas_height, shapes, shape_groups).
        canvas_width, canvas_height, shapes, shape_groups = pydiffvg_lite.svg_to_scene(svg_file)

        print("✅ SVG parsed successfully:")
        print(f"   Canvas: {canvas_width} x {canvas_height}")
        print(f"   Shapes: {len(shapes)}")
        print(f"   Groups: {len(shape_groups)}")

        # Peek at the first shape and, when it exposes them, its points.
        if len(shapes):
            first_shape = shapes[0]
            print(f"   Shape type: {type(first_shape).__name__}")
            if hasattr(first_shape, 'points'):
                points = first_shape.points
                print(f"   Points: {points.shape}")
                print(f"   First points: {points[:5]}")

        print("✅ SVG to coordinates conversion successful!")

    except Exception as e:
        import traceback

        print(f"❌ SVG processing failed: {e}")
        traceback.print_exc()


In [None]:
# Find an SVG anywhere under DATA_PATH, parse it with pydiffvg_lite and print
# a summary of the scene (canvas size, shape/group counts, first shape's points).
import os

print(f"Looking for SVG files in: {DATA_PATH}")

# os.walk yields entries in arbitrary filesystem order. Sorting the
# sub-directory list in place (which steers the top-down walk) and the file
# names makes "the first SVG" reproducible between runs. The extension check
# is case-insensitive so ".SVG" files are found as well.
svg_file = None
for root, dirs, files in os.walk(DATA_PATH):
    dirs.sort()
    svg_names = sorted(name for name in files if name.lower().endswith('.svg'))
    if svg_names:
        svg_file = os.path.join(root, svg_names[0])
        print(f"Found SVG file: {svg_file}")
        break

if svg_file:
    try:
        # svg_to_scene is unpacked as (canvas_width, canvas_height, shapes, shape_groups).
        canvas_width, canvas_height, shapes, shape_groups = pydiffvg_lite.svg_to_scene(svg_file)

        print("✅ SVG parsed successfully:")
        print(f"   Canvas size: {canvas_width} x {canvas_height}")
        print(f"   Number of shapes: {len(shapes)}")
        print(f"   Number of shape groups: {len(shape_groups)}")

        # Get first shape's coordinate points
        if len(shapes) > 0:
            first_shape = shapes[0]
            print(f"   First shape type: {type(first_shape).__name__}")
            if hasattr(first_shape, 'points'):
                points = first_shape.points
                print(f"   Points shape: {points.shape}")
                print(f"   First few points: {points[:5]}")

        print("✅ SVG to coordinates conversion successful!")

    except Exception as e:
        print(f"❌ SVG processing failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print("❌ No SVG files found in data folder")


In [None]:
# Smoke-test the VAE end to end: SVG -> dataset tensors -> one forward pass
# (eval mode, no gradients; nothing is trained here) -> inspect the outputs
# and save the dataset's rasterized input path image as a PNG.
print("🧠 Testing VAE model with our SVG...")

# Define test SVG file
svg_file = f"{DATA_PATH}/aardvark/66543-200.svg"


def _tensor_to_pil(img_tensor):
    """Convert a single image tensor to a ``PIL.Image``.

    Accepts a channel-first ``(C, H, W)`` tensor or a bare ``(H, W)`` map.
    Values are presumably in [0, 1] (TODO confirm against
    ``SVGDataset_nopadding``). They are scaled to 0-255 and clipped, so
    out-of-range pixels saturate instead of wrapping around in the uint8 cast.
    """
    import numpy as np
    from PIL import Image

    arr = img_tensor.detach().cpu().numpy()
    if arr.ndim == 3 and arr.shape[0] == 1:
        # Drop only the channel axis; squeeze() would also collapse an H or W of 1.
        arr = arr[0]
    elif arr.ndim == 3:
        arr = arr.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    arr = np.clip(arr * 255, 0, 255).astype(np.uint8)
    # Pillow infers "L" for 2-D and "RGB" for (H, W, 3) uint8 arrays.
    return Image.fromarray(arr)


if os.path.exists(svg_file):
    try:
        # 1. Load VAE model config and create model
        import sys
        import torch
        import yaml
        from torch.utils.data import DataLoader

        # Make the deepsvg package importable; guarded so re-running the cell
        # does not append the same entry again.
        deepsvg_path = f"{CODE_PATH}/deepsvg"
        if deepsvg_path not in sys.path:
            sys.path.append(deepsvg_path)

        from deepsvg.model.config import _DefaultConfig
        from deepsvg.model.model_pts_vae import SVGTransformer
        from deepsvg.my_svg_dataset_pts import Normalize, SVGDataset_nopadding

        # Load config: YAML values override the defaults attribute by attribute.
        cfg = _DefaultConfig()
        yaml_path = f"{CODE_PATH}/configs/vae_config_cmd_10.yaml"
        with open(yaml_path, 'r', encoding='utf-8') as f:
            config_data = yaml.safe_load(f) or {}  # an empty file loads as None
        for key, value in config_data.items():
            setattr(cfg, key, value)
        cfg.img_latent_dim = int(cfg.d_img_model / 64.0)
        cfg.vq_edim = int(cfg.dim_z / cfg.vq_comb_num)

        # Create the model and put it in inference mode before any forward pass.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = SVGTransformer(cfg).to(device)
        model.eval()
        print(f"✅ Model created: {type(model).__name__}")

        # 2. Convert SVG to model input (like get_z_from_circle)
        dataset_h, dataset_w = 224, 224
        max_pts_len_thresh = cfg.max_pts_len_thresh

        # Build a one-file dataset around our SVG.
        svg_dir = os.path.dirname(svg_file)
        svg_filename = os.path.basename(svg_file)

        dataset = SVGDataset_nopadding(
            directory=svg_dir,
            h=dataset_h, w=dataset_w,
            fixed_length=max_pts_len_thresh,
            file_list=[svg_filename],
            img_dir=svg_dir,
            transform=Normalize(dataset_w, dataset_h),
            use_model_fusion=cfg.use_model_fusion,
        )

        loader = DataLoader(dataset, batch_size=1, shuffle=False)
        print(f"✅ Dataset created with {len(dataset)} samples")

        # 3. Take the first (only) batch. Fail loudly instead of reporting
        # success when the dataset produced nothing.
        batch_data = next(iter(loader), None)
        if batch_data is None:
            raise RuntimeError(f"Dataset produced no batches for {svg_file}")

        # Extract inputs (like in get_z_from_circle)
        path_imgs = batch_data["path_img"].to(device)
        # "cubics" is unpacked as a 4-D tensor; only its batch size is used.
        bat_s, _, _, _ = batch_data["cubics"].shape
        # Flatten to (batch, -1, 2) coordinate pairs, then add an axis at dim 1.
        cubics_batch_fl = batch_data["cubics"].view(bat_s, -1, 2)
        data_pts = cubics_batch_fl.to(device)
        data_input = data_pts.unsqueeze(1)

        print("✅ Data prepared:")
        print(f"   Path images: {path_imgs.shape}")
        print(f"   Data points: {data_input.shape}")

        # 4. Forward pass (encoding + decoding) without gradient tracking
        with torch.no_grad():
            output = model(args_enc=data_input, args_dec=data_input, ref_img=path_imgs)

        print("✅ Model forward pass completed")
        print(f"   Output keys: {list(output.keys())}")

        # 5. Inspect results and save the input path image
        if "args_logits" in output:
            generated_pts = output["args_logits"]
            print(f"   Generated points: {generated_pts.shape}")

        if path_imgs.shape[0] > 0:
            os.makedirs(TEST_PATH, exist_ok=True)
            # NOTE: this saves the dataset's *input* path image (first in the
            # batch), not a decoded model output. The file name is kept
            # unchanged so existing references to it still work.
            output_path = f"{TEST_PATH}/vae_output.png"
            _tensor_to_pil(path_imgs[0]).save(output_path)
            print(f"📁 Input path image saved to: {output_path}")

        print("✅ VAE model test completed successfully!")

    except Exception as e:
        print(f"❌ VAE model test failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"❌ SVG file not found: {svg_file}")