## 1. Init


In [None]:
from google.colab import drive
import os
import shutil
import subprocess
import sys

# Mount Google Drive so the project folders below are reachable.
drive.mount('/content/drive')

# Project layout on Drive: <PROJECT_DIR>/{code,data,test}.
CODE_FOLDER = 'code'
DATA_FOLDER = 'data'
TEST_FOLDER = 'test'
PROJECT_DIR = '/content/drive/MyDrive/Text2SVG'
CODE_PATH = f"{PROJECT_DIR}/{CODE_FOLDER}"
DATA_PATH = f"{PROJECT_DIR}/{DATA_FOLDER}"
TEST_PATH = f"{PROJECT_DIR}/{TEST_FOLDER}"

# The dataset test cells import project packages (`models`, `data`, `utils`)
# that live in the code folder; put it on the import path so those imports do
# not depend on the notebook's current working directory.
if CODE_PATH not in sys.path:
    sys.path.insert(0, CODE_PATH)

## 2. Clone Code Repository


In [None]:
# Test Google Drive Dataset: Save samples as SVG and PNG files (Fixed coordinates)
#
# Loads the VAE config, builds a one-file SVGDataset_GoogleDrive for the
# "aardvark" category and, for every path sample, writes into TEST_PATH:
#   - an SVG rebuilt from the de-normalized point coordinates, and
#   - a PNG of the sample's "path_img" raster (when the dataset provides one).
# CODE_PATH / DATA_PATH / TEST_PATH are defined in the Init cell.
import os
import re
import traceback

import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader
from PIL import Image

CANVAS_SIZE = 224  # dataset raster size and SVG canvas size used by this cell

# Captures the value of the first `d="..."` attribute. Requiring whitespace
# before `d` keeps it from matching inside other attributes such as `id="..."`.
_SVG_PATH_DATA_RE = re.compile(r'\sd="([^"]*)"')


def _load_vae_config(yaml_path):
    """Return a project ``_DefaultConfig`` with the YAML at *yaml_path* applied.

    Every top-level YAML key is set as an attribute on the config, then the
    derived fields ``img_latent_dim`` and ``vq_edim`` are recomputed.
    """
    from models.config import _DefaultConfig

    cfg = _DefaultConfig()
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config_data = yaml.safe_load(f) or {}  # an empty YAML file loads as None
    for key, value in config_data.items():
        setattr(cfg, key, value)
    cfg.img_latent_dim = int(cfg.d_img_model / 64.0)
    cfg.vq_edim = int(cfg.dim_z / cfg.vq_comb_num)
    return cfg


def _first_path_data(svg_file):
    """Return the first ``d="..."`` attribute value in *svg_file*, or None."""
    with open(svg_file, 'r', encoding='utf-8') as f:
        match = _SVG_PATH_DATA_RE.search(f.read())
    return match.group(1) if match else None


def _has_image_data(img):
    """Return True when *img* (tensor or array-like) holds at least one value."""
    if img is None:
        return False
    if isinstance(img, torch.Tensor):
        return img.numel() > 0
    return np.asarray(img).size > 0


def _tensor_to_pil(img):
    """Convert a raster (torch tensor or array-like) to a PIL image.

    Layout:
      * 3-D data is read as HWC. If the last axis is not a channel count
        (1/3/4) but the first axis is, the data is read as CHW and transposed.
        3 or 4 channels give RGB/RGBA; any other channel count keeps only
        channel 0 as a grayscale image.
      * Data of any other rank is squeezed first (e.g. ``(1, 1, H, W)``).

    Values: uint8 data is used as-is. Other dtypes are multiplied by 255 when
    their maximum is <= 1 (i.e. they look like [0, 1] intensities), then
    clipped to [0, 255] and cast to uint8.

    Raises:
        ValueError: if the data cannot be reduced to a 2-D or 3-D image.
    """
    if isinstance(img, torch.Tensor):
        # `.numpy()` alone fails for CUDA tensors and tensors requiring grad.
        img = img.detach().cpu().numpy()
    arr = np.asarray(img)
    if arr.ndim != 3:
        arr = np.squeeze(arr)
    if arr.ndim == 3:
        if arr.shape[2] not in (1, 3, 4) and arr.shape[0] in (1, 3, 4):
            arr = np.transpose(arr, (1, 2, 0))  # CHW -> HWC
        if arr.shape[2] not in (3, 4):
            arr = arr[:, :, 0]
    if arr.ndim not in (2, 3):
        raise ValueError(f"cannot convert array of shape {arr.shape} to an image")
    if arr.dtype != np.uint8:
        if arr.size and np.nanmax(arr) <= 1:
            arr = arr * 255
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    # PIL infers the mode from shape/dtype: (H, W) -> L, (H, W, 3) -> RGB,
    # (H, W, 4) -> RGBA; the `mode=` argument is deprecated in recent Pillow.
    return Image.fromarray(arr)


def _save_gdrive_samples(suffix, show_image_stats=False):
    """Save every dataset sample as ``sample_<n>_path_<idx>_<suffix>.{svg,png}``.

    Per-sample SVG/PNG failures are printed with a traceback and do not stop
    the loop; setup failures (imports, config, dataset) propagate.
    """
    from data.my_svg_dataset_pts import SVGDataset_GoogleDrive, Normalize
    from utils.test_utils import pts_to_pathObj, save_paths_svg
    import pydiffvg_lite as pydiffvg  # noqa: F401  (not referenced below; presumably an install check)

    cfg = _load_vae_config(f'{CODE_PATH}/configs/vae_config_cmd_10.yaml')
    print("✅ Config loaded:")
    print(f"   - use_model_fusion: {cfg.use_model_fusion}")
    print(f"   - max_pts_len_thresh: {cfg.max_pts_len_thresh}")
    print(f"   - img_size: {cfg.img_size}")

    # Google Drive dataset restricted to one aardvark file.
    gdrive_dataset = SVGDataset_GoogleDrive(
        data_path=DATA_PATH,
        h=CANVAS_SIZE, w=CANVAS_SIZE,
        fixed_length=cfg.max_pts_len_thresh,
        category="aardvark",
        file_list=["66543-200.svg"],
        transform=Normalize(CANVAS_SIZE, CANVAS_SIZE),
        use_model_fusion=cfg.use_model_fusion,
    )
    num_samples = len(gdrive_dataset)
    print(f"✅ Google Drive Dataset created: {num_samples} path samples")

    # Separate normalizer used only to map points back to pixel coordinates.
    normalizer = Normalize(CANVAS_SIZE, CANVAS_SIZE)
    os.makedirs(TEST_PATH, exist_ok=True)

    if num_samples == 0:
        print("❌ No samples found in dataset")
        return

    print(f"\n💾 Saving {num_samples} samples to {TEST_PATH}")

    # Styling is identical for every sample; build it once, all on one device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    fill_color = torch.tensor([0.2, 0.2, 0.8, 1.0], device=device)  # Blue fill
    stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0], device=device)  # Black stroke
    stroke_width = torch.tensor(2.0, device=device)

    # Index explicitly: a map-style Dataset is not guaranteed to be iterable.
    for i in range(num_samples):
        sample = gdrive_dataset[i]
        svg_path = sample["filepaths"]
        path_idx = sample["path_index"]
        normalized_points = sample["points"]  # normalized by the dataset's Normalize transform
        path_img = sample["path_img"]
        stem = f"sample_{i+1}_path_{path_idx}_{suffix}"

        print(f"\n📄 Processing sample {i+1}:")
        print(f"   Source: {os.path.basename(svg_path)}, path {path_idx}")
        print(f"   Normalized points: {normalized_points.shape}")
        print(f"   Coordinate range: x[{normalized_points[:, 0].min():.3f}, {normalized_points[:, 0].max():.3f}], y[{normalized_points[:, 1].min():.3f}, {normalized_points[:, 1].max():.3f}]")

        # Map normalized coordinates back to pixel space before building the SVG.
        denormalized_points = normalizer.inverse_transform(normalized_points)
        print(f"   Denormalized points: {denormalized_points.shape}")
        print(f"   Pixel coordinates: x[{denormalized_points[:, 0].min():.1f}, {denormalized_points[:, 0].max():.1f}], y[{denormalized_points[:, 1].min():.1f}, {denormalized_points[:, 1].max():.1f}]")

        # 1. Save reconstructed SVG from the pixel coordinates.
        try:
            path_obj = pts_to_pathObj(denormalized_points)
            svg_output_path = f"{TEST_PATH}/{stem}.svg"
            save_paths_svg(
                path_list=[path_obj],
                fill_color_list=[fill_color],
                stroke_width_list=[stroke_width],
                stroke_color_list=[stroke_color],
                svg_path_fp=svg_output_path,
                canvas_height=CANVAS_SIZE,
                canvas_width=CANVAS_SIZE,
            )
            print(f"   📁 SVG saved: {os.path.basename(svg_output_path)}")

            path_data = _first_path_data(svg_output_path)
            if path_data is None:
                print("   ⚠️  No path data (d attribute) found in generated SVG")
            else:
                print(f"   📝 Generated SVG path (first 80 chars): {path_data[:80]}...")
        except Exception as e:
            print(f"   ❌ SVG save failed: {e}")
            traceback.print_exc()

        # 2. Save the dataset's rendered path image.
        try:
            if not _has_image_data(path_img):
                print(f"   ⚠️  No image data available (use_model_fusion={cfg.use_model_fusion})")
                continue
            if show_image_stats and isinstance(path_img, torch.Tensor):
                print(f"   🔍 Path image tensor: shape={path_img.shape}, min={path_img.min():.3f}, max={path_img.max():.3f}, non_zero={torch.count_nonzero(path_img)}")
            img_pil = _tensor_to_pil(path_img)
            png_output_path = f"{TEST_PATH}/{stem}.png"
            img_pil.save(png_output_path)
            print(f"   📁 PNG saved: {os.path.basename(png_output_path)} ({img_pil.size})")
        except Exception as e:
            print(f"   ❌ PNG save failed: {e}")
            traceback.print_exc()

    print(f"\n✅ All samples saved to: {TEST_PATH}")
    print("📊 Output files:")
    own_suffixes = (f"_{suffix}.svg", f"_{suffix}.png")
    for name in sorted(os.listdir(TEST_PATH)):
        if name.startswith('sample_') and name.endswith(own_suffixes):
            print(f"   📄 {name}")


print("💾 Saving Google Drive Dataset samples with corrected coordinates...")

try:
    _save_gdrive_samples("fixed")
except Exception as e:
    print(f"❌ Sample saving failed: {e}")
    traceback.print_exc()

## 3. Install Dependencies

In [None]:
%pip install torch torchvision numpy scipy pandas scikit-learn matplotlib pillow svglib svgpathtools
%pip install kornia opencv-python cairosvg pyyaml easydict tqdm
# Install additional packages for SVG rendering support
%pip install svglib reportlab cssutils

In [None]:
# Test Google Drive Dataset: Save samples as SVG and PNG files (Fixed coordinates and consistency)
#
# Same pipeline as the "fixed" variant, but files are suffixed "_consistent"
# and the path-image tensor statistics are printed before each PNG is saved.
# For every path sample this writes into TEST_PATH:
#   - an SVG rebuilt from the de-normalized point coordinates, and
#   - a PNG of the sample's "path_img" raster (when the dataset provides one).
# CODE_PATH / DATA_PATH / TEST_PATH are defined in the Init cell.
import os
import re
import traceback

import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader
from PIL import Image

CANVAS_SIZE = 224  # dataset raster size and SVG canvas size used by this cell

# Captures the value of the first `d="..."` attribute. Requiring whitespace
# before `d` keeps it from matching inside other attributes such as `id="..."`.
_SVG_PATH_DATA_RE = re.compile(r'\sd="([^"]*)"')


def _load_vae_config(yaml_path):
    """Return a project ``_DefaultConfig`` with the YAML at *yaml_path* applied.

    Every top-level YAML key is set as an attribute on the config, then the
    derived fields ``img_latent_dim`` and ``vq_edim`` are recomputed.
    """
    from models.config import _DefaultConfig

    cfg = _DefaultConfig()
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config_data = yaml.safe_load(f) or {}  # an empty YAML file loads as None
    for key, value in config_data.items():
        setattr(cfg, key, value)
    cfg.img_latent_dim = int(cfg.d_img_model / 64.0)
    cfg.vq_edim = int(cfg.dim_z / cfg.vq_comb_num)
    return cfg


def _first_path_data(svg_file):
    """Return the first ``d="..."`` attribute value in *svg_file*, or None."""
    with open(svg_file, 'r', encoding='utf-8') as f:
        match = _SVG_PATH_DATA_RE.search(f.read())
    return match.group(1) if match else None


def _has_image_data(img):
    """Return True when *img* (tensor or array-like) holds at least one value."""
    if img is None:
        return False
    if isinstance(img, torch.Tensor):
        return img.numel() > 0
    return np.asarray(img).size > 0


def _tensor_to_pil(img):
    """Convert a raster (torch tensor or array-like) to a PIL image.

    Layout:
      * 3-D data is read as HWC. If the last axis is not a channel count
        (1/3/4) but the first axis is, the data is read as CHW and transposed.
        3 or 4 channels give RGB/RGBA; any other channel count keeps only
        channel 0 as a grayscale image.
      * Data of any other rank is squeezed first (e.g. ``(1, 1, H, W)``).

    Values: uint8 data is used as-is. Other dtypes are multiplied by 255 when
    their maximum is <= 1 (i.e. they look like [0, 1] intensities), then
    clipped to [0, 255] and cast to uint8.

    Raises:
        ValueError: if the data cannot be reduced to a 2-D or 3-D image.
    """
    if isinstance(img, torch.Tensor):
        # `.numpy()` alone fails for CUDA tensors and tensors requiring grad.
        img = img.detach().cpu().numpy()
    arr = np.asarray(img)
    if arr.ndim != 3:
        arr = np.squeeze(arr)
    if arr.ndim == 3:
        if arr.shape[2] not in (1, 3, 4) and arr.shape[0] in (1, 3, 4):
            arr = np.transpose(arr, (1, 2, 0))  # CHW -> HWC
        if arr.shape[2] not in (3, 4):
            arr = arr[:, :, 0]
    if arr.ndim not in (2, 3):
        raise ValueError(f"cannot convert array of shape {arr.shape} to an image")
    if arr.dtype != np.uint8:
        if arr.size and np.nanmax(arr) <= 1:
            arr = arr * 255
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    # PIL infers the mode from shape/dtype: (H, W) -> L, (H, W, 3) -> RGB,
    # (H, W, 4) -> RGBA; the `mode=` argument is deprecated in recent Pillow.
    return Image.fromarray(arr)


def _save_gdrive_samples(suffix, show_image_stats=False):
    """Save every dataset sample as ``sample_<n>_path_<idx>_<suffix>.{svg,png}``.

    Per-sample SVG/PNG failures are printed with a traceback and do not stop
    the loop; setup failures (imports, config, dataset) propagate.
    """
    from data.my_svg_dataset_pts import SVGDataset_GoogleDrive, Normalize
    from utils.test_utils import pts_to_pathObj, save_paths_svg
    import pydiffvg_lite as pydiffvg  # noqa: F401  (not referenced below; presumably an install check)

    cfg = _load_vae_config(f'{CODE_PATH}/configs/vae_config_cmd_10.yaml')
    print("✅ Config loaded:")
    print(f"   - use_model_fusion: {cfg.use_model_fusion}")
    print(f"   - max_pts_len_thresh: {cfg.max_pts_len_thresh}")
    print(f"   - img_size: {cfg.img_size}")

    # Google Drive dataset restricted to one aardvark file.
    gdrive_dataset = SVGDataset_GoogleDrive(
        data_path=DATA_PATH,
        h=CANVAS_SIZE, w=CANVAS_SIZE,
        fixed_length=cfg.max_pts_len_thresh,
        category="aardvark",
        file_list=["66543-200.svg"],
        transform=Normalize(CANVAS_SIZE, CANVAS_SIZE),
        use_model_fusion=cfg.use_model_fusion,
    )
    num_samples = len(gdrive_dataset)
    print(f"✅ Google Drive Dataset created: {num_samples} path samples")

    # Separate normalizer used only to map points back to pixel coordinates.
    normalizer = Normalize(CANVAS_SIZE, CANVAS_SIZE)
    os.makedirs(TEST_PATH, exist_ok=True)

    if num_samples == 0:
        print("❌ No samples found in dataset")
        return

    print(f"\n💾 Saving {num_samples} samples to {TEST_PATH}")

    # Styling is identical for every sample; build it once, all on one device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    fill_color = torch.tensor([0.2, 0.2, 0.8, 1.0], device=device)  # Blue fill
    stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0], device=device)  # Black stroke
    stroke_width = torch.tensor(2.0, device=device)

    # Index explicitly: a map-style Dataset is not guaranteed to be iterable.
    for i in range(num_samples):
        sample = gdrive_dataset[i]
        svg_path = sample["filepaths"]
        path_idx = sample["path_index"]
        normalized_points = sample["points"]  # normalized by the dataset's Normalize transform
        path_img = sample["path_img"]
        stem = f"sample_{i+1}_path_{path_idx}_{suffix}"

        print(f"\n📄 Processing sample {i+1}:")
        print(f"   Source: {os.path.basename(svg_path)}, path {path_idx}")
        print(f"   Normalized points: {normalized_points.shape}")
        print(f"   Coordinate range: x[{normalized_points[:, 0].min():.3f}, {normalized_points[:, 0].max():.3f}], y[{normalized_points[:, 1].min():.3f}, {normalized_points[:, 1].max():.3f}]")

        # Map normalized coordinates back to pixel space before building the SVG.
        denormalized_points = normalizer.inverse_transform(normalized_points)
        print(f"   Denormalized points: {denormalized_points.shape}")
        print(f"   Pixel coordinates: x[{denormalized_points[:, 0].min():.1f}, {denormalized_points[:, 0].max():.1f}], y[{denormalized_points[:, 1].min():.1f}, {denormalized_points[:, 1].max():.1f}]")

        # 1. Save reconstructed SVG from the pixel coordinates.
        try:
            path_obj = pts_to_pathObj(denormalized_points)
            svg_output_path = f"{TEST_PATH}/{stem}.svg"
            save_paths_svg(
                path_list=[path_obj],
                fill_color_list=[fill_color],
                stroke_width_list=[stroke_width],
                stroke_color_list=[stroke_color],
                svg_path_fp=svg_output_path,
                canvas_height=CANVAS_SIZE,
                canvas_width=CANVAS_SIZE,
            )
            print(f"   📁 SVG saved: {os.path.basename(svg_output_path)}")

            path_data = _first_path_data(svg_output_path)
            if path_data is None:
                print("   ⚠️  No path data (d attribute) found in generated SVG")
            else:
                print(f"   📝 Generated SVG path (first 80 chars): {path_data[:80]}...")
        except Exception as e:
            print(f"   ❌ SVG save failed: {e}")
            traceback.print_exc()

        # 2. Save the dataset's rendered path image.
        try:
            if not _has_image_data(path_img):
                print(f"   ⚠️  No image data available (use_model_fusion={cfg.use_model_fusion})")
                continue
            if show_image_stats and isinstance(path_img, torch.Tensor):
                print(f"   🔍 Path image tensor: shape={path_img.shape}, min={path_img.min():.3f}, max={path_img.max():.3f}, non_zero={torch.count_nonzero(path_img)}")
            img_pil = _tensor_to_pil(path_img)
            png_output_path = f"{TEST_PATH}/{stem}.png"
            img_pil.save(png_output_path)
            print(f"   📁 PNG saved: {os.path.basename(png_output_path)} ({img_pil.size})")
        except Exception as e:
            print(f"   ❌ PNG save failed: {e}")
            traceback.print_exc()

    print(f"\n✅ All samples saved to: {TEST_PATH}")
    print("📊 Output files:")
    own_suffixes = (f"_{suffix}.svg", f"_{suffix}.png")
    for name in sorted(os.listdir(TEST_PATH)):
        if name.startswith('sample_') and name.endswith(own_suffixes):
            print(f"   📄 {name}")


print("💾 Saving Google Drive Dataset samples with corrected coordinates and consistent rendering...")

try:
    _save_gdrive_samples("consistent", show_image_stats=True)
except Exception as e:
    print(f"❌ Sample saving failed: {e}")
    traceback.print_exc()