In [1]:
# Cell 1: Install required packages
!pip install datasets huggingface_hub torch torchvision validators transformers fvcore

Defaulting to user installation because normal site-packages is not writeable


In [2]:
#Colab path
#!pip install /content/gigagan_pytorch-0.3.9-py3-none-any.whl

#Jupyter path
#!pip install gigagan_pytorch-0.4.2-py3-none-any.whl

!pip install gigagan-pytorch==0.2.20

Defaulting to user installation because normal site-packages is not writeable


In [3]:
!pip install future-annotations

Defaulting to user installation because normal site-packages is not writeable


In [4]:
!python --version

Python 3.9.18


In [5]:
# Cell 2: Import libraries
import os
import torch
from datasets import load_dataset
from __future__ import annotations
from gigagan_pytorch import GigaGAN
from huggingface_hub import login
from PIL import Image
import requests
from io import BytesIO
#import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
import validators
import gc
import time
import json
from pathlib import Path

# Import profiling tools
import torch.profiler as profiler
from torch.profiler import profile, record_function, ProfilerActivity
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count

# Create profiling directory
os.makedirs("profiling_results", exist_ok=True)

# Setup logging
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('gigagan_training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('gigagan')

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
login(token="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")

In [7]:
# Cell 4: Setup image transformation
UNCONDITIONAL = True
IMAGE_SIZE = 256

# Cell 2: Setup image transformation
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [8]:
def fetch_image(url, caption=None):
    """Fetch an image from a URL and apply transformations.
    Returns just the image if unconditional, or (image, caption) if conditional.
    """
    try:
        if not validators.url(url):
            return None if UNCONDITIONAL else (None, None)
        
        response = requests.get(url, timeout=5)
        if response.status_code != 200:
            return None if UNCONDITIONAL else (None, None)
            
        img = Image.open(BytesIO(response.content)).convert('RGB')
        if min(img.size) < 64:  # Filter out tiny images
            return None if UNCONDITIONAL else (None, None)
            
        transformed_img = transform(img)
        
        if UNCONDITIONAL:
            return transformed_img
        else:
            return transformed_img, caption
    except Exception as e:
        logger.error(f"Error fetching image: {e}")
        return None if UNCONDITIONAL else (None, None)

In [9]:
def build_dataset(num_samples=1000):
    """Build dataset with progress tracking for either conditional or unconditional GAN"""
    logger.info(f"Building {'unconditional' if UNCONDITIONAL else 'conditional'} dataset with {num_samples} samples...")

    dataset = load_dataset("phiyodr/coco2017", split="train", streaming=True)
    stream_iter = iter(dataset)
    
    # Will hold either just images or (image, caption) pairs depending on UNCONDITIONAL flag
    samples = []

    # Start timing dataset creation
    start_time = time.time()

    while len(samples) < num_samples:
        if len(samples) % 100 == 0:
            logger.info(f"Collected {len(samples)}/{num_samples} samples...")

        try:
            sample = next(stream_iter)
            url = sample.get("coco_url")

            # Extract caption (only needed if conditional)
            caption = None
            if not UNCONDITIONAL:
                if 'captions' in sample and isinstance(sample['captions'], list) and len(sample['captions']) > 0:
                    caption = sample['captions'][0]
                elif 'caption' in sample:
                    caption = sample['caption']
                    if isinstance(caption, list) and len(caption) > 0:
                        caption = caption[0]
                else:
                    caption = "A photo"

            # Fetch image (and caption if conditional)
            result = fetch_image(url, caption)
            
            if result is not None:
                samples.append(result)
                
        except StopIteration:
            logger.warning("Dataset exhausted")
            break
        except Exception as e:
            logger.error(f"Error processing sample: {e}")
            continue

    dataset_time = time.time() - start_time
    logger.info(f"Dataset built in {dataset_time:.2f}s with {len(samples)} samples")

    # Save dataset statistics
    dataset_stats = {
        "mode": "unconditional" if UNCONDITIONAL else "conditional",
        "num_samples": len(samples),
        "build_time_seconds": dataset_time,
        "avg_time_per_sample": dataset_time / len(samples) if samples else 0
    }

    with open("profiling_results/dataset_stats.json", "w") as f:
        json.dump(dataset_stats, f, indent=2)

    return samples

In [10]:
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, samples):
        """Initialize dataset with samples
        
        Args:
            samples: Either list of image tensors (unconditional) or list of (image, caption) pairs (conditional)
        """
        self.samples = samples
        self.unconditional = UNCONDITIONAL

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if self.unconditional:
            return self.samples[idx]  # Just return the image tensor
        else:
            img, caption = self.samples[idx]
            return img, caption  # Return image and caption pair

def collate_fn(batch):
    """Collate function that handles both conditional and unconditional batches"""
    if UNCONDITIONAL:
        # For unconditional, batch is just a list of image tensors
        return torch.stack(batch)
    else:
        # For conditional, batch is a list of (image, caption) tuples
        images = []
        captions = []

        for img, caption in batch:
            images.append(img)
            captions.append(caption)

        images = torch.stack(images)
        return images, captions

In [11]:
# Cell 8: Dataloader profiling function
def profile_dataloader(dataloader, num_batches=10):
    """Profile dataloader performance"""
    logger.info(f"Profiling dataloader for {num_batches} batches...")

    batch_times = []
    batch_memory = []

    # Reset memory stats
    torch.cuda.reset_peak_memory_stats()
    start_memory = torch.cuda.memory_allocated() / (1024 * 1024)  # MB

    # Time batches
    for i, (images, captions) in enumerate(dataloader):
        if i == 0:
            # First batch may include initialization overhead
            batch_start = time.time()
            continue

        batch_end = time.time()
        batch_time = batch_end - batch_start
        batch_times.append(batch_time)

        # Track memory
        current_memory = torch.cuda.memory_allocated() / (1024 * 1024)  # MB
        batch_memory.append(current_memory)

        logger.info(f"Batch {i}: loaded in {batch_time:.4f}s, memory: {current_memory:.2f} MB")

        batch_start = time.time()

        if i >= num_batches:
            break

    # Calculate stats
    avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0
    peak_memory = torch.cuda.max_memory_allocated() / (1024 * 1024)  # MB

    # Record results
    dataloader_stats = {
        "avg_batch_time_seconds": avg_batch_time,
        "batches_per_second": 1 / avg_batch_time if avg_batch_time > 0 else 0,
        "starting_memory_mb": start_memory,
        "peak_memory_mb": peak_memory,
        "memory_increase_mb": peak_memory - start_memory
    }

    logger.info(f"Dataloader avg time: {avg_batch_time:.4f}s per batch")
    logger.info(f"Dataloader peak memory: {peak_memory:.2f} MB")

    with open("profiling_results/dataloader_stats.json", "w") as f:
        json.dump(dataloader_stats, f, indent=2)

    return dataloader_stats

In [12]:
def setup_model():
    """Set up GigaGAN model with the new configuration"""
    logger.info(f"Setting up {'unconditional' if UNCONDITIONAL else 'conditional'} GigaGAN model...")

    # Create the model with updated configuration
    '''
    gan = GigaGAN(
        generator=dict(
            dim_capacity=16,
            style_network=dict(
                dim=128,
                depth=4
            ),
            image_size=IMAGE_SIZE,
            dim_max=512,
            num_skip_layers_excite=4,
            unconditional=UNCONDITIONAL
        ),
        discriminator=dict(
            dim_capacity=32,
            dim_max=512,
            image_size=IMAGE_SIZE,
            num_skip_layers_excite=4,
            unconditional=UNCONDITIONAL
        ),
        learning_rate = 1e-6,
        accelerate_kwargs = {'gradient_accumulation_steps': 8},
        diff_augment=dict(
        prob=0.5,
        horizontal_flip=True,
        horizontal_flip_prob=0.5
        ),
        apply_gradient_penalty_every=16,  # Less frequent gradient penalty
        multiscale_divergence_loss_weight=0.15,  # Increased from default
        discr_aux_recon_loss_weight=0.8,
        amp=True  # Enable mixed precision
    ).cuda()

    gan.load('gigagan-models/model-5.ckpt')
    '''
    gan = GigaGAN(
        train_upsampler = True,
        generator = dict(
            dim = 32,                  # Use dim instead of dim_capacity for UnetUpsampler
            style_network = dict(
                dim = 64,
                depth = 4
            ),
            image_size = IMAGE_SIZE,          # Output resolution
            input_image_size = 64,     # Input resolution to upsample from
            unconditional = True
        ),
        discriminator = dict(
            dim_capacity = 16,
            dim_max = 512,
            image_size = IMAGE_SIZE,          # Match with generator output size
            multiscale_input_resolutions = (128,),  # Intermediate resolution
            num_skip_layers_excite = 4,
            unconditional = True
        ),
        learning_rate=1e-5,
        model_folder='./gigagan-modified-upsampler-models',
        results_folder='./gigagan-modified-upsampler-results',
        amp = True
    ).cuda()
    gan.load('gigagan-modified-upsampler-models/model-11.ckpt')

    # Profile model architecture
    gen_params = sum(p.numel() for p in gan.unwrapped_G.parameters())
    disc_params = sum(p.numel() for p in gan.unwrapped_D.parameters())
    total_params = sum(p.numel() for p in gan.parameters())
    trainable_params = sum(p.numel() for p in gan.parameters() if p.requires_grad)

    # Estimate model size
    param_size_bytes = sum(p.numel() * p.element_size() for p in gan.parameters())
    buffer_size_bytes = sum(b.numel() * b.element_size() for b in gan.buffers())
    model_size_mb = (param_size_bytes + buffer_size_bytes) / (1024 * 1024)

    # Record architecture stats
    architecture_stats = {
        "mode": "unconditional" if UNCONDITIONAL else "conditional",
        "generator_parameters": gen_params,
        "discriminator_parameters": disc_params,
        "total_parameters": total_params,
        "trainable_parameters": trainable_params,
        "model_size_mb": model_size_mb,
        "generator_percentage": gen_params / total_params * 100,
        "discriminator_percentage": disc_params / total_params * 100
    }

    logger.info(f"Model architecture: {gen_params:,} generator params, {disc_params:,} discriminator params")
    logger.info(f"Total parameters: {total_params:,} ({trainable_params:,} trainable)")
    logger.info(f"Model size: {model_size_mb:.2f} MB")

    with open("profiling_results/model_architecture.json", "w") as f:
        json.dump(architecture_stats, f, indent=2)

    return gan

In [13]:
def train_model(gan, dataloader, steps=100, grad_accum_every=8):
    """Train the GigaGAN model"""
    logger.info(f"Training {'unconditional' if UNCONDITIONAL else 'conditional'} GigaGAN for {steps} steps...")
    
    # Set dataloader
    gan.set_dataloader(dataloader)
    
    # Train for specified steps
    gan(steps=steps, grad_accum_every=grad_accum_every)
    
    logger.info(f"Training completed for {steps} steps")
    
    return gan

In [14]:
def generate_images(gan, batch_size=4, captions=None):
    """Generate images using the trained model
    
    For unconditional GAN, only batch_size is needed
    For conditional GAN, both batch_size and captions are needed
    """
    if UNCONDITIONAL:
        logger.info(f"Generating {batch_size} images unconditionally")
        with torch.no_grad():
            input_size = gan.unwrapped_G.input_image_size
            lowres = torch.randn(1, 3, input_size, input_size).cuda()
            images = gan.generate(lowres_image=lowres)
    else:
        if captions is None or len(captions) < batch_size:
            logger.error("Captions must be provided for conditional generation")
            return None
            
        logger.info(f"Generating {batch_size} images with captions")
        with torch.no_grad():
            images = gan.generate(batch_size=batch_size, texts=captions[:batch_size])
    
    return images


In [15]:
def save_images(images, save_dir="gigagan-256-results"):
    """Save generated images with proper denormalization"""
    os.makedirs(save_dir, exist_ok=True)

    # Create a transform to convert tensor to PIL image
    to_pil = transforms.ToPILImage()

    for i, img in enumerate(images):
        # Properly denormalize from [-1, 1] to [0, 1]
        img = torch.nan_to_num(img, nan=0.0)
        img = (img.clamp(-1, 1) * 0.5 + 0.5)

        # Convert to PIL and save
        pil_img = to_pil(img.cpu())
        filename = f"generated_image_{i}.png"
        filepath = os.path.join(save_dir, filename)

        logger.info(f"Saving image to {filepath}")
        pil_img.save(filepath)

In [16]:
def run_gigagan_workflow(unconditional=True, num_samples=1000, training_steps=100):
    """Run the complete GigaGAN workflow with the specified mode"""
    global UNCONDITIONAL
    UNCONDITIONAL = unconditional
    
    logger.info(f"Starting GigaGAN workflow in {'unconditional' if UNCONDITIONAL else 'conditional'} mode")
    
    # 1. Build dataset
    samples = build_dataset(num_samples=num_samples)
    
    # 2. Create dataset and dataloader
    dataset = ImageDataset(samples)
    dataloader = DataLoader(
        dataset,
        batch_size=4,  # Use small batch size due to memory constraints
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4
    )
    
    # 3. Setup model
    gan = setup_model()
    
    
    # 5. Train model
    gan = train_model(gan, dataloader, steps=training_steps, grad_accum_every=8)

    
    # 7. Generate and save images
    if UNCONDITIONAL:
        images = generate_images(gan, batch_size=4)
    else:
        # Sample captions for conditional generation
        test_captions = [
            "A dog running in a park",
            "A beautiful sunset over the ocean",
            "A cat sleeping on a sofa",
            "A mountain landscape with snow"
        ]
        images = generate_images(gan, batch_size=4, captions=test_captions)
    
    if images is not None:
        save_images(images)
    
    logger.info("GigaGAN workflow completed successfully")

In [None]:
# For unconditional training:
run_gigagan_workflow(unconditional=True, num_samples=2000, training_steps=5000)

# For conditional training:
# run_gigagan_workflow(unconditional=False, num_samples=2000, training_steps=100)

INFO:gigagan:Starting GigaGAN workflow in unconditional mode
INFO:gigagan:Building unconditional dataset with 2000 samples...
INFO:gigagan:Collected 0/2000 samples...
INFO:gigagan:Collected 100/2000 samples...
INFO:gigagan:Collected 200/2000 samples...
INFO:gigagan:Collected 300/2000 samples...
INFO:gigagan:Collected 400/2000 samples...
INFO:gigagan:Collected 500/2000 samples...
INFO:gigagan:Collected 600/2000 samples...
INFO:gigagan:Collected 700/2000 samples...
INFO:gigagan:Collected 800/2000 samples...
INFO:gigagan:Collected 900/2000 samples...
INFO:gigagan:Collected 1000/2000 samples...
INFO:gigagan:Collected 1100/2000 samples...
INFO:gigagan:Collected 1200/2000 samples...
INFO:gigagan:Collected 1300/2000 samples...
INFO:gigagan:Collected 1400/2000 samples...
INFO:gigagan:Collected 1500/2000 samples...
INFO:gigagan:Collected 1600/2000 samples...
INFO:gigagan:Collected 1700/2000 samples...
INFO:gigagan:Collected 1800/2000 samples...
INFO:gigagan:Collected 1900/2000 samples...
INFO:g

A100 GPU detected, using flash attention if input tensor is on cuda


Generator: 43.71M
Discriminator: 30.74M




INFO:gigagan:Model architecture: 43,712,753 generator params, 30,737,601 discriminator params
INFO:gigagan:Total parameters: 118,163,107 (74,450,354 trainable)
INFO:gigagan:Model size: 450.76 MB
INFO:gigagan:Training unconditional GigaGAN for 5000 steps...
  self.gen = func(*args, **kwds)


G: 1.47 | MSG: -0.01 | VG: 0.00 | D: 1.90 | MSD: 2.00 | VD: 0.00 | GP: 0.24 | SSL: 0.08 | CL: 0.00 | MAL: 0.00


11021it [01:19,  3.04s/it]

G: 1.73 | MSG: -0.00 | VG: 0.00 | D: 0.98 | MSD: 2.00 | VD: 0.00 | GP: 0.29 | SSL: 0.10 | CL: 0.00 | MAL: 0.00


11041it [02:15,  2.95s/it]

G: 1.09 | MSG: 0.00 | VG: 0.00 | D: 1.16 | MSD: 2.00 | VD: 0.00 | GP: 0.18 | SSL: 0.08 | CL: 0.00 | MAL: 0.00


11061it [03:11,  2.96s/it]

G: 1.80 | MSG: -0.01 | VG: 0.00 | D: 1.13 | MSD: 2.00 | VD: 0.00 | GP: 0.17 | SSL: 0.10 | CL: 0.00 | MAL: 0.00


11081it [04:07,  2.96s/it]

G: 0.94 | MSG: -0.00 | VG: 0.00 | D: 0.84 | MSD: 2.00 | VD: 0.00 | GP: 0.18 | SSL: 0.09 | CL: 0.00 | MAL: 0.00


11101it [05:03,  2.94s/it]

G: 1.39 | MSG: -0.00 | VG: 0.00 | D: 0.93 | MSD: 2.00 | VD: 0.00 | GP: 0.20 | SSL: 0.07 | CL: 0.00 | MAL: 0.00


11121it [05:59,  2.96s/it]

G: 0.82 | MSG: -0.02 | VG: 0.00 | D: 1.42 | MSD: 2.00 | VD: 0.00 | GP: 0.21 | SSL: 0.10 | CL: 0.00 | MAL: 0.00


11141it [06:54,  2.91s/it]

G: 2.00 | MSG: -0.01 | VG: 0.00 | D: 0.72 | MSD: 2.00 | VD: 0.00 | GP: 0.18 | SSL: 0.10 | CL: 0.00 | MAL: 0.00


11144it [07:01,  2.60s/it]