## 1. Init


In [None]:
from google.colab import drive
import os
import shutil
import subprocess

# Mount Google Drive so the project folder is reachable from this runtime.
drive.mount('/content/drive')

# Project root on Drive and the names of its sub-folders.
PROJECT_DIR = '/content/drive/MyDrive/Text2SVG'
CODE_FOLDER, DATA_FOLDER, TEST_FOLDER = 'code', 'data', 'test'

# Full paths used by the cells below: cloned code, SVG data, test outputs.
CODE_PATH = '/'.join((PROJECT_DIR, CODE_FOLDER))
DATA_PATH = '/'.join((PROJECT_DIR, DATA_FOLDER))
TEST_PATH = '/'.join((PROJECT_DIR, TEST_FOLDER))


## 2. Clone Code Repository


In [None]:
# Clean existing code directory and clone fresh repository
import sys

# On a fresh Drive the project folder may not exist yet; without this the
# os.chdir below raises FileNotFoundError before anything is cloned.
os.makedirs(PROJECT_DIR, exist_ok=True)

# Remove any previous checkout so `git clone` starts from an empty target.
if os.path.exists(CODE_PATH):
    shutil.rmtree(CODE_PATH)

# Clone relative to the project folder so the checkout lands in CODE_PATH.
os.chdir(PROJECT_DIR)
result = subprocess.run(
    ['git', 'clone', 'https://github.com/huanbasara/dual-branch-vae.git', CODE_FOLDER],
    capture_output=True, text=True,
)

if result.returncode != 0:
    # RuntimeError is still an Exception, so existing handlers keep working.
    raise RuntimeError(f"Failed to clone repository: {result.stderr}")

print(f"✅ Code repository cloned successfully to {CODE_PATH}")

# Display latest commit information. The working directory is left at
# CODE_PATH for the cells that follow.
os.chdir(CODE_PATH)
commit_info = subprocess.run(
    ['git', 'log', '-1', '--pretty=format:%H|%ci|%s'],
    capture_output=True, text=True,
)

# maxsplit=2 keeps any '|' inside the commit subject intact; the length check
# avoids a ValueError on unexpected output.
commit_fields = commit_info.stdout.strip().split('|', 2)
if commit_info.returncode == 0 and len(commit_fields) == 3:
    hash_code, commit_time, commit_msg = commit_fields
    print(f"📦 Latest commit:")
    print(f"   Hash: {hash_code[:8]}")
    print(f"   Time: {commit_time}")
    print(f"   Message: {commit_msg}")
else:
    print("⚠️ Could not get commit info")

# Add code path to Python sys.path so we can import our modules
if CODE_PATH not in sys.path:
    sys.path.insert(0, CODE_PATH)
    print(f"✅ Added {CODE_PATH} to Python path")
else:
    print(f"✅ {CODE_PATH} already in Python path")

## 3. Install Dependencies

In [None]:
# Install the packages requested by this notebook into the current kernel
# (`%pip` is an IPython line magic, so this cell only runs inside a notebook).
%pip install torch torchvision numpy scipy pandas scikit-learn matplotlib pillow svglib svgpathtools
%pip install kornia opencv-python cairosvg pyyaml easydict tqdm
# Install additional packages for SVG rendering support
# NOTE(review): svglib is already requested by the first %pip line; the repeat looks redundant.
%pip install svglib reportlab cssutils

# Quick module reload after code update
import sys


def purge_modules(prefixes):
    """Drop cached modules so the next import re-executes their source.

    Args:
        prefixes: iterable of dotted module names. Each named module and all
            of its submodules are removed from ``sys.modules``.

    Returns:
        Sorted list of the module names that were removed.
    """
    removed = []
    for base in prefixes:
        dotted = base + '.'
        # Match the module itself or a real submodule only. A bare
        # startswith(base) also hits unrelated modules that merely share the
        # prefix (e.g. 'data' would evict the stdlib 'dataclasses').
        stale = [m for m in sys.modules if m == base or m.startswith(dotted)]
        for name in stale:
            del sys.modules[name]
        removed.extend(stale)
    return sorted(removed)


print("🔄 Reloading modules...")

# Clear custom modules from cache
modules_to_clear = [
    'pydiffvg_lite',
    'models',
    'data',
    'utils',
    'pydiffvg_lite.path_utils',
    'data.my_svg_dataset_pts',
]

purge_modules(modules_to_clear)

# Re-import key modules
# These imports run after the sys.modules entries were deleted above, so the
# modules are loaded again from the current sources on sys.path.
# Note: pydiffvg_lite is bound under the alias `pydiffvg`; this import does
# not define the bare name `pydiffvg_lite`.
import pydiffvg_lite as pydiffvg
from models.config import _DefaultConfig  
from models.model_pts_vae import SVGTransformer
from data.my_svg_dataset_pts import Normalize, SVGDataset_nopadding, SVGDataset_GoogleDrive
from utils.test_utils import pts_to_pathObj, save_paths_svg

print("✅ Modules reloaded!")

In [None]:
# Test specific SVG file: aardvark
# Parses one SVG from DATA_PATH with svg_to_scene and prints the canvas size,
# the shape/group counts and the first shape's points.
import traceback

svg_file = f"{DATA_PATH}/aardvark/66543-200.svg"
print(f"Using test SVG: {svg_file}")

if os.path.exists(svg_file):
    try:
        # Test svg_to_scene function.
        # The library is imported as `import pydiffvg_lite as pydiffvg`, so the
        # alias must be used: the bare name `pydiffvg_lite` is never bound and
        # raised a NameError that the except below silently reported.
        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(svg_file)

        print(f"✅ SVG parsed successfully:")
        print(f"   Canvas: {canvas_width} x {canvas_height}")
        print(f"   Shapes: {len(shapes)}")
        print(f"   Groups: {len(shape_groups)}")

        # Get first shape's coordinate points
        if len(shapes) > 0:
            first_shape = shapes[0]
            print(f"   Shape type: {type(first_shape).__name__}")
            # Only shapes exposing a `points` attribute are inspected further.
            if hasattr(first_shape, 'points'):
                points = first_shape.points
                print(f"   Points: {points.shape}")
                print(f"   First points: {points[:5]}")

        print("✅ SVG to coordinates conversion successful!")

    except Exception as e:
        # Smoke test: report the failure with a traceback instead of aborting
        # the notebook run.
        print(f"❌ SVG processing failed: {e}")
        traceback.print_exc()
else:
    print(f"❌ SVG file not found: {svg_file}")


In [None]:
# Test VAE Model: Generate both Image and SVG outputs (Fixed)
# Runs the first sample of a one-file dataset through SVGTransformer and
# writes the input image, the image-branch reconstruction, the point-branch
# SVG and a coordinate dump into TEST_PATH.
import torch
import yaml
import os
import traceback
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
from models.config import _DefaultConfig
from models.model_pts_vae import SVGTransformer
from data.my_svg_dataset_pts import Normalize, SVGDataset_nopadding
from utils.test_utils import pts_to_pathObj, save_paths_svg
import pydiffvg_lite as pydiffvg


def _tensor_to_pil(img_tensor, rescale=False):
    """Convert a float image tensor to an 8-bit PIL image.

    Args:
        img_tensor: 2-D (H, W) or 3-D channel-first (C, H, W) tensor.
        rescale: when True the values are min-max stretched to 0-255;
            otherwise they are multiplied by 255 and clipped to 0-255.

    Returns:
        A PIL image built from the converted uint8 array.

    Raises:
        ValueError: if the tensor is neither 2-D nor 3-D.
    """
    pixels = img_tensor.detach().cpu().numpy()
    if pixels.ndim == 3:
        # Single channel -> (H, W); otherwise channel-first -> channel-last.
        pixels = pixels[0] if pixels.shape[0] == 1 else pixels.transpose(1, 2, 0)
    elif pixels.ndim != 2:
        raise ValueError(f"Expected a 2-D or 3-D image tensor, got shape {tuple(pixels.shape)}")

    if rescale:
        low = pixels.min()
        span = pixels.max() - low
        # A constant image has span 0; dividing would fill the array with NaNs.
        pixels = (pixels - low) / span * 255 if span > 0 else np.zeros_like(pixels)
    else:
        # Clip before the uint8 cast so out-of-range values cannot wrap around.
        pixels = (pixels * 255).clip(0, 255)
    return Image.fromarray(pixels.astype('uint8'))


print("🧠 Testing VAE Model: Image + SVG Generation (Fixed)...")

# Use the project's test SVG file
svg_dir = f"{CODE_PATH}/data/vae_dataset"
svg_filename = "circle_10.svg"
svg_file = f"{svg_dir}/{svg_filename}"
print(f"Using test SVG: {svg_file}")

if not os.path.exists(svg_file):
    print(f"❌ SVG file not found: {svg_file}")
else:
    try:
        # 1. Load VAE model config: YAML values override the defaults.
        cfg = _DefaultConfig()
        yaml_path = f"{CODE_PATH}/configs/vae_config_cmd_10.yaml"

        with open(yaml_path, 'r') as f:
            config_data = yaml.safe_load(f)

        # safe_load returns None for an empty file.
        for key, value in (config_data or {}).items():
            setattr(cfg, key, value)

        # Derived settings computed from the loaded values.
        cfg.img_latent_dim = int(cfg.d_img_model / 64.0)
        cfg.vq_edim = int(cfg.dim_z / cfg.vq_comb_num)

        print(f"✅ Config loaded:")
        print(f"   use_model_fusion: {cfg.use_model_fusion}")
        print(f"   img_size: {cfg.img_size}")

        # 2. Create VAE model
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = SVGTransformer(cfg)
        model = model.to(device)
        model.eval()
        print(f"✅ Model created: {type(model).__name__} on {device}")

        # 3. Prepare dataset
        dataset_h, dataset_w = 224, 224

        os.makedirs(TEST_PATH, exist_ok=True)

        dataset = SVGDataset_nopadding(
            directory=svg_dir,
            h=dataset_h, w=dataset_w,
            fixed_length=cfg.max_pts_len_thresh,
            file_list=[svg_filename],
            img_dir=svg_dir,
            transform=Normalize(dataset_w, dataset_h),
            use_model_fusion=cfg.use_model_fusion
        )

        loader = DataLoader(dataset, batch_size=1, shuffle=False)
        print(f"✅ Dataset created: {len(dataset)} samples")

        # 4. Run model prediction on the first batch only
        batch_data = next(iter(loader), None)
        if batch_data is None:
            # Without this guard the summary below would hit unbound names.
            raise RuntimeError("Dataset produced no batches; nothing to run the model on")

        # Extract model inputs: flatten the cubics into a point list per
        # sample, then add a singleton dimension at index 1.
        path_imgs = batch_data["path_img"].to(device)
        bat_s, _, _, _ = batch_data["cubics"].shape
        cubics_batch_fl = batch_data["cubics"].view(bat_s, -1, 2)
        data_pts = cubics_batch_fl.to(device)
        data_input = data_pts.unsqueeze(1)

        print(f"✅ Input data prepared:")
        print(f"   Path images: {path_imgs.shape}")
        print(f"   Data points: {data_input.shape}")

        # Forward pass
        with torch.no_grad():
            output = model(args_enc=data_input, args_dec=data_input, ref_img=path_imgs)

        print(f"✅ Model prediction completed")
        print(f"   Output keys: {list(output.keys())}")

        # Initialize all path variables
        input_path = f"{TEST_PATH}/input_image.png"
        rec_img_path = f"{TEST_PATH}/reconstructed_image.png"
        svg_output_path = f"{TEST_PATH}/reconstructed_shape.svg"
        coords_path = f"{TEST_PATH}/coordinates.txt"

        # (label, path) of every file actually written, for the summary.
        generated_files = []

        # 5. Save Input Image (reference)
        if path_imgs.numel() > 0:
            _tensor_to_pil(path_imgs[0]).save(input_path)
            print(f"📁 Input image saved: {input_path}")
            generated_files.append(("🖼️  Input image", input_path))

        # 6. Save Reconstructed Image from Image Decoder
        if "rec_img" in output:
            rec_img_tensor = output["rec_img"][0].cpu()
            print(f"   Reconstructed image shape: {rec_img_tensor.shape}")

            if rec_img_tensor.dim() in (2, 3):
                # Decoder output range is not assumed; stretch it to 0-255.
                _tensor_to_pil(rec_img_tensor, rescale=True).save(rec_img_path)
                print(f"📁 Reconstructed image saved: {rec_img_path}")
                generated_files.append(("🖼️  Reconstructed image", rec_img_path))
            else:
                print(f"⚠️ Unexpected reconstructed image shape, skipping save: {tuple(rec_img_tensor.shape)}")

        # 7. Convert Coordinates to SVG and Save
        if "args_logits" in output:
            generated_pts_batch = output["args_logits"]
            print(f"   Generated coordinates shape: {generated_pts_batch.shape}")

            # Get the first sample coordinates
            generated_coords = generated_pts_batch[0].squeeze().cpu()
            print(f"   Processed coordinates shape: {generated_coords.shape}")

            # Convert coordinates to path object
            try:
                # Keep a 2-D tensor as is; flatten anything else to [N, 2].
                if generated_coords.dim() == 2:
                    convert_points = generated_coords
                else:
                    convert_points = generated_coords.reshape(-1, 2)

                print(f"   Convert points shape: {convert_points.shape}")

                # NOTE(review): the dataset applies Normalize(...) to its
                # points, so these coordinates are presumably in normalized
                # space, while the canvas below is dataset_w x dataset_h.
                # The Google Drive cell calls Normalize.inverse_transform
                # before building paths - confirm whether that is needed here.
                path_obj = pts_to_pathObj(convert_points)
                print(f"✅ Path object created with {len(path_obj.points)} points")

                # Create colors for the SVG
                fill_color = torch.tensor([0.0, 0.0, 0.0, 1.0], device=device)
                stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0], device=device)

                # Save as SVG file
                save_paths_svg(
                    path_list=[path_obj],
                    fill_color_list=[fill_color],
                    stroke_width_list=[torch.tensor(1.0)],
                    stroke_color_list=[stroke_color],
                    svg_path_fp=svg_output_path,
                    canvas_height=dataset_h,
                    canvas_width=dataset_w
                )

                print(f"📁 Reconstructed SVG saved: {svg_output_path}")
                generated_files.append(("📄 Reconstructed SVG", svg_output_path))

                # Save coordinates as text for inspection (first 20 points)
                with open(coords_path, 'w') as f:
                    f.write("VAE Generated Coordinates\n")
                    f.write("=" * 30 + "\n")
                    f.write(f"Shape: {generated_pts_batch.shape}\n")
                    f.write(f"Processed shape: {convert_points.shape}\n")
                    f.write("\nCoordinate Points:\n")
                    for i, pt in enumerate(convert_points[:20]):
                        f.write(f"Point {i:2d}: [{pt[0]:8.3f}, {pt[1]:8.3f}]\n")
                    if len(convert_points) > 20:
                        f.write(f"... and {len(convert_points) - 20} more points\n")

                print(f"📁 Coordinates saved: {coords_path}")
                generated_files.append(("📊 Coordinates data", coords_path))

            except Exception as e:
                # The SVG branch is best-effort; the image outputs still count.
                print(f"❌ SVG conversion failed: {e}")
                traceback.print_exc()

        print("✅ VAE Dual Output Test Completed!")
        print("📊 Generated Files:")
        # Only files that were really written are listed.
        for label, written_path in generated_files:
            print(f"   {label}: {written_path}")

        print("\n🎯 VAE dual-branch outputs successfully generated!")

    except Exception as e:
        print(f"❌ VAE model test failed: {e}")
        traceback.print_exc()

In [None]:
# Test Google Drive Dataset: Save samples as SVG and PNG files (Fixed coordinates)
# For every sample of one aardvark SVG this writes an SVG rebuilt from the
# denormalized points and, when image data is present, a PNG of path_img.
import torch
import yaml
import os
import re
import traceback
import numpy as np
from torch.utils.data import DataLoader
from PIL import Image

print("💾 Saving Google Drive Dataset samples with corrected coordinates...")

try:
    # Load VAE config independently
    from models.config import _DefaultConfig
    from data.my_svg_dataset_pts import SVGDataset_GoogleDrive, Normalize
    from utils.test_utils import pts_to_pathObj, save_paths_svg
    import pydiffvg_lite as pydiffvg

    cfg = _DefaultConfig()
    yaml_path = f'{CODE_PATH}/configs/vae_config_cmd_10.yaml'
    with open(yaml_path, 'r') as f:
        config_data = yaml.safe_load(f)
    # safe_load returns None for an empty file.
    for key, value in (config_data or {}).items():
        setattr(cfg, key, value)
    cfg.img_latent_dim = int(cfg.d_img_model / 64.0)
    cfg.vq_edim = int(cfg.dim_z / cfg.vq_comb_num)

    print(f"✅ Config loaded:")
    print(f"   - use_model_fusion: {cfg.use_model_fusion}")
    print(f"   - max_pts_len_thresh: {cfg.max_pts_len_thresh}")
    print(f"   - img_size: {cfg.img_size}")

    # One size shared by the dataset, the normalizer and the SVG canvas.
    dataset_h, dataset_w = 224, 224

    # Create Google Drive dataset for aardvark category
    gdrive_dataset = SVGDataset_GoogleDrive(
        data_path=DATA_PATH,
        h=dataset_h, w=dataset_w,
        fixed_length=cfg.max_pts_len_thresh,
        category="aardvark",
        file_list=["66543-200.svg"],
        transform=Normalize(dataset_w, dataset_h),
        use_model_fusion=cfg.use_model_fusion
    )

    print(f"✅ Google Drive Dataset created: {len(gdrive_dataset)} path samples")

    # Create normalizer for coordinate transformation
    normalizer = Normalize(dataset_w, dataset_h)

    # Create test directory
    os.makedirs(TEST_PATH, exist_ok=True)

    # Loop-invariant values, computed once instead of on every sample.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    fill_color = torch.tensor([0.2, 0.2, 0.8, 1.0], device=device)  # Blue fill
    stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0], device=device)  # Black stroke
    channel_counts = (1, 3, 4)

    if len(gdrive_dataset) > 0:
        print(f"\n💾 Saving {len(gdrive_dataset)} samples to {TEST_PATH}")

        for i in range(len(gdrive_dataset)):
            sample = gdrive_dataset[i]

            # Get sample info
            svg_path = sample["filepaths"]
            path_idx = sample["path_index"]
            normalized_points = sample["points"]  # These are normalized (0-1 range)
            cubics = sample["cubics"]
            path_img = sample["path_img"]

            print(f"\n📄 Processing sample {i+1}:")
            print(f"   Source: {os.path.basename(svg_path)}, path {path_idx}")
            print(f"   Normalized points: {normalized_points.shape}")
            print(f"   Coordinate range: x[{normalized_points[:, 0].min():.3f}, {normalized_points[:, 0].max():.3f}], y[{normalized_points[:, 1].min():.3f}, {normalized_points[:, 1].max():.3f}]")

            # 🔧 CRITICAL FIX: Apply inverse transform to get pixel coordinates
            denormalized_points = normalizer.inverse_transform(normalized_points)
            print(f"   Denormalized points: {denormalized_points.shape}")
            print(f"   Pixel coordinates: x[{denormalized_points[:, 0].min():.1f}, {denormalized_points[:, 0].max():.1f}], y[{denormalized_points[:, 1].min():.1f}, {denormalized_points[:, 1].max():.1f}]")

            # 1. Save reconstructed SVG from coordinates
            try:
                # Use denormalized points for SVG creation
                path_obj = pts_to_pathObj(denormalized_points)

                # Save as SVG file
                svg_output_path = f"{TEST_PATH}/sample_{i+1}_path_{path_idx}_fixed.svg"
                save_paths_svg(
                    path_list=[path_obj],
                    fill_color_list=[fill_color],
                    stroke_width_list=[torch.tensor(2.0)],
                    stroke_color_list=[stroke_color],
                    svg_path_fp=svg_output_path,
                    canvas_height=dataset_h,
                    canvas_width=dataset_w
                )

                print(f"   📁 SVG saved: {os.path.basename(svg_output_path)}")

                # Show first few coordinates in the generated SVG
                with open(svg_output_path, 'r') as f:
                    svg_content = f.read()
                # \b stops the pattern from matching the tail of `id="..."`,
                # and a missing attribute yields '' instead of a random slice.
                path_match = re.search(r'\bd="([^"]*)"', svg_content)
                path_data = path_match.group(1) if path_match else ""
                print(f"   📝 Generated SVG path (first 80 chars): {path_data[:80]}...")

            except Exception as e:
                # Best-effort per sample: report and continue with the PNG.
                print(f"   ❌ SVG save failed: {e}")
                traceback.print_exc()

            # 2. Save rendered image
            try:
                if len(path_img) > 0 and path_img.numel() > 0:
                    # detach/cpu so CUDA or grad-tracking tensors convert too.
                    img_array = path_img.detach().cpu().numpy()
                    if img_array.ndim == 3:
                        # Accept channel-first input as well: the VAE test
                        # cell treats path_img as (C, H, W), whereas this
                        # branch was written for (H, W, C).
                        if (img_array.shape[0] in channel_counts
                                and img_array.shape[2] not in channel_counts):
                            img_array = img_array.transpose(1, 2, 0)
                        if img_array.shape[2] != 3:
                            # Not RGB: keep the first channel only.
                            img_array = img_array[:, :, 0]
                    else:
                        # Handle other tensor formats
                        img_array = img_array.squeeze()

                    # Scale to 0-255 and clip before the uint8 cast.
                    img_array = (img_array * 255).clip(0, 255).astype('uint8')
                    # A 2-D uint8 array already maps to mode 'L'; the explicit
                    # `mode` argument is deprecated in recent Pillow releases.
                    img_pil = Image.fromarray(img_array)

                    # Save PNG file
                    png_output_path = f"{TEST_PATH}/sample_{i+1}_path_{path_idx}_fixed.png"
                    img_pil.save(png_output_path)

                    print(f"   📁 PNG saved: {os.path.basename(png_output_path)} ({img_pil.size})")

                else:
                    print(f"   ⚠️  No image data available (use_model_fusion=False)")

            except Exception as e:
                print(f"   ❌ PNG save failed: {e}")
                traceback.print_exc()

        print(f"\n✅ All samples saved to: {TEST_PATH}")
        print(f"📊 Output files:")

        # List saved files (this also picks up matching files left in
        # TEST_PATH by earlier runs).
        saved_files = [f for f in os.listdir(TEST_PATH) if f.startswith('sample_') and 'fixed' in f]
        for f in sorted(saved_files):
            print(f"   📄 {f}")

    else:
        print("❌ No samples found in dataset")

except Exception as e:
    print(f"❌ Sample saving failed: {e}")
    traceback.print_exc()