In [2]:
!pip install -q gdown
!gdown --id 1vAuTUfkeRX045AMhUPcguxwov2dqAVtJ --output my_dataset.zip

import gzip
import shutil

with gzip.open('my_dataset.zip', 'rb') as f_in:
    with open('my_dataset', 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)

import tarfile

with tarfile.open("my_dataset", "r:") as tar:
    tar.extractall("my_dataset_extracted")

Failed to retrieve file url:

	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses.
	Check FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=1vAuTUfkeRX045AMhUPcguxwov2dqAVtJ

but Gdown can't. Please check connections and permissions.


FileNotFoundError: [Errno 2] No such file or directory: 'my_dataset.zip'

In [None]:
# Cell 1: Install required packages
!pip install gigagan-pytorch==0.2.20 fvcore torch torchvision validators



In [None]:
import os
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from gigagan_pytorch import GigaGAN
import logging
import json
import time
import torchvision.utils as vutils
from fvcore.nn import parameter_count
import torch.nn.utils.prune as prune
from torch.cuda.amp import autocast

In [None]:
UNCONDITIONAL = True
IMAGE_SIZE = 256
BATCH_SIZE = 4
TRAINING_STEPS = 1000
FINETUNE_STEPS = 100
DATA_PATH = "my_dataset_extracted/GigaGAN_cond_imagenet256"
SAVE_DIR = "gigagan_pruned_results"
MODEL_FOLDER = "gigagan_pruning_checkpoints"
RESULTS_FOLDER = "gigagan_pruning_results"

os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs("profiling_results", exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler("gigagan_training.log"), logging.StreamHandler()]
)
logger = logging.getLogger("gigagan")

In [None]:
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

def load_imagenet_dataset(path):
    dataset = datasets.ImageFolder(root=path, transform=transform)
    logger.info(f"Loaded dataset with {len(dataset)} samples from {path}")
    return dataset

def collate_images_only(batch):
    images, _ = zip(*batch)
    return torch.stack(images)

In [None]:
def profile_dataloader(dataloader, num_batches=10):
    logger.info(f"Profiling dataloader for {num_batches} batches...")
    batch_times = []
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
        start_memory = torch.cuda.memory_allocated() / (1024 * 1024)
    else:
        start_memory = 0

    start = time.time()
    for i, images in enumerate(dataloader):
        if i == num_batches:
            break
    total_time = time.time() - start
    avg_time = total_time / num_batches
    peak_memory = torch.cuda.max_memory_allocated() / (1024 * 1024) if use_cuda else 0

    stats = {
        "avg_batch_time_sec": avg_time,
        "peak_memory_mb": peak_memory,
        "start_memory_mb": start_memory
    }

    with open("profiling_results/dataloader_stats.json", "w") as f:
        json.dump(stats, f, indent=2)

    logger.info(f"Dataloader avg time: {avg_time:.4f}s, peak memory: {peak_memory:.2f} MB")
    return stats

In [None]:
def setup_model():
    logger.info("Setting up GigaGAN (Unconditional)...")
    gan = GigaGAN(
        train_upsampler=True,
        generator=dict(
            dim=32,
            style_network=dict(dim=64, depth=4),
            image_size=IMAGE_SIZE,
            input_image_size=64,
            unconditional=UNCONDITIONAL
        ),
        discriminator=dict(
            dim_capacity=16,
            dim_max=512,
            image_size=IMAGE_SIZE,
            multiscale_input_resolutions=(128,),
            num_skip_layers_excite=4,
            unconditional=UNCONDITIONAL
        ),
        learning_rate=1e-5,
        amp=True,
        model_folder=MODEL_FOLDER,
        results_folder=RESULTS_FOLDER
    ).cuda()

    total_params = sum(p.numel() for p in gan.parameters())
    trainable_params = sum(p.numel() for p in gan.parameters() if p.requires_grad)
    model_size_mb = sum(p.numel() * p.element_size() for p in gan.parameters()) / 1024**2

    logger.info(f"Model initialized: {total_params:,} params ({model_size_mb:.2f} MB)")
    return gan

In [None]:
def train_model(gan, dataloader, steps, save_name="model-final.ckpt"):
    gan.set_dataloader(dataloader)
    logger.info(f"Training for {steps} steps (AMP enabled)")
    with autocast():
        gan(steps=steps, grad_accum_every=8)
    save_path = os.path.join(MODEL_FOLDER, save_name)
    gan.save(save_path)
    logger.info(f"Model saved to {save_path}")


In [None]:
def count_params(model, label):
    gen_params = sum(p.numel() for p in model.unwrapped_G.parameters())
    disc_params = sum(p.numel() for p in model.unwrapped_D.parameters())
    total = gen_params + disc_params
    logger.info(f"{label} Generator: {gen_params/1e6:.2f}M, Discriminator: {disc_params/1e6:.2f}M, Total: {total/1e6:.2f}M")

def prune_gigagan_model(model, amount=0.3, target="both"):
    logger.info(f"Pruning GigaGAN ({target}) with {amount*100:.0f}% sparsity")
    if target in ["generator", "both"]:
        for _, module in model.unwrapped_G.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                prune.ln_structured(module, name="weight", amount=amount, n=1, dim=0)
    if target in ["discriminator", "both"]:
        for _, module in model.unwrapped_D.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                prune.ln_structured(module, name="weight", amount=amount, n=1, dim=0)
    logger.info("Pruning complete.")

def remove_pruning_reparam(model, target="both"):
    logger.info("Stripping pruning reparametrizations...")
    for net in ([model.unwrapped_G] if target == "generator" else [model.unwrapped_D, model.unwrapped_G]):
        for _, module in net.named_modules():
            if isinstance(module, torch.nn.Conv2d) and hasattr(module, "weight_mask"):
                prune.remove(module, "weight")

In [None]:
def generate_and_save_images(gan, num_images=3):
    input_size = gan.unwrapped_G.input_image_size
    for i in range(num_images):
        noise = torch.randn(1, 3, input_size, input_size).cuda()
        with torch.no_grad(), autocast():
            img = gan.generate(lowres_image=noise)[0].cpu()
        img = torch.nan_to_num(img, nan=0.0).clamp(-1, 1)
        img = (img * 0.5 + 0.5)
        vutils.save_image(img, f"{SAVE_DIR}/generated_{i}.png")
        logger.info(f"Saved: {SAVE_DIR}/generated_{i}.png")

In [None]:
def count_nonzero_params(model, label):
    def count_nonzero(p):
        return (p != 0).sum().item()

    gen_nonzero = sum(count_nonzero(p) for p in model.unwrapped_G.parameters() if p.requires_grad)
    disc_nonzero = sum(count_nonzero(p) for p in model.unwrapped_D.parameters() if p.requires_grad)
    total = gen_nonzero + disc_nonzero

    logger.info(f"{label} Generator: {gen_nonzero/1e6:.2f}M nonzero, Discriminator: {disc_nonzero/1e6:.2f}M nonzero, Total: {total/1e6:.2f}M nonzero")
    return gen_nonzero, disc_nonzero

In [None]:
def train_gigagan():
    dataset = load_imagenet_dataset(DATA_PATH)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=collate_images_only)
    profile_dataloader(dataloader)
    gan = setup_model()
    count_params(gan, "Initial (Unpruned)")
    train_model(gan, dataloader, steps=TRAINING_STEPS, save_name="model-unpruned.ckpt")

# ---------------------------- Phase 2: Prune + Fine-tune ----------------------------
def prune_and_finetune_gigagan():
    dataset = load_imagenet_dataset(DATA_PATH)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=collate_images_only)
    gan = setup_model()
    gan.load(os.path.join(MODEL_FOLDER, "model-unpruned.ckpt"))
    count_params(gan, "Before Pruning")
    prune_gigagan_model(gan, amount=0.3, target="both")
    remove_pruning_reparam(gan, target="both")
    count_params(gan, "After Pruning + Reparam Removal")
    gan.unwrapped_G.load_state_dict(gan.unwrapped_G.state_dict())  # safe sync for EMA
    train_model(gan, dataloader, steps=FINETUNE_STEPS, save_name="model-pruned.ckpt")
    generate_and_save_images(gan)

def iterative_pruning_workflow(iterations=5, initial_prune=0.1):
    dataset = load_imagenet_dataset(DATA_PATH)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=collate_images_only)
    profile_dataloader(dataloader)

    gan = setup_model()
    train_model(gan, dataloader, steps=TRAINING_STEPS, save_name="model-unpruned.ckpt")
    gan.load(os.path.join(MODEL_FOLDER, "model-unpruned.ckpt"))

    cumulative_prune = 1.0
    for i in range(iterations):
        prune_step = 1 - initial_prune
        cumulative_prune *= prune_step
        logger.info(f"Iteration {i+1}: Pruning to cumulative {(1-cumulative_prune)*100:.1f}%")

        prune_gigagan_model(gan, amount=initial_prune, target="both")
        remove_pruning_reparam(gan, target="both")
        count_params(gan, f"Post-Prune Iteration {i+1}")

        gan.unwrapped_G.load_state_dict(gan.unwrapped_G.state_dict())
        train_model(gan, dataloader, steps=FINETUNE_STEPS, save_name=f"model-pruned-{i+1}.ckpt")

    generate_and_save_images(gan)

In [None]:
gan = setup_model()
# gan.load(os.path.join(MODEL_FOLDER, "model-unpruned.ckpt"))
# count_nonzero_params(gan, "After Training")



Generator: 43.71M
Discriminator: 30.74M




In [None]:
train_gigagan()

A100 GPU detected, using flash attention if input tensor is on cuda


Generator: 43.71M
Discriminator: 30.74M




  with autocast():
  self.gen = func(*args, **kwds)


G: -3.39 | MSG: -14.02 | VG: 0.00 | D: 5.51 | MSD: 15.46 | VD: 0.00 | GP: 0.00 | SSL: 14.49 | CL: 0.00 | MAL: 0.00


  2%|▏         | 21/1000 [00:58<46:30,  2.85s/it]

G: -281.25 | MSG: -203.91 | VG: 0.00 | D: 276.11 | MSD: 204.89 | VD: 0.00 | GP: 27868.06 | SSL: 13.34 | CL: 0.00 | MAL: 0.00


  4%|▍         | 41/1000 [01:52<45:38,  2.86s/it]

G: -481.83 | MSG: -144.95 | VG: 0.00 | D: 485.29 | MSD: 146.91 | VD: 0.00 | GP: 27182.68 | SSL: 16.27 | CL: 0.00 | MAL: 0.00


  6%|▌         | 61/1000 [02:46<44:53,  2.87s/it]

G: -700.80 | MSG: -125.20 | VG: 0.00 | D: 703.15 | MSD: 126.34 | VD: 0.00 | GP: 29783.50 | SSL: 12.87 | CL: 0.00 | MAL: 0.00


  8%|▊         | 81/1000 [03:40<43:45,  2.86s/it]

G: -1048.18 | MSG: -163.38 | VG: 0.00 | D: 1050.32 | MSD: 164.52 | VD: 0.00 | GP: 27097.93 | SSL: 11.86 | CL: 0.00 | MAL: 0.00


 10%|█         | 100/1000 [04:30<38:06,  2.54s/it]

G: -1085.39 | MSG: -155.58 | VG: 0.00 | D: 1084.97 | MSD: 156.82 | VD: 0.00 | GP: 28804.74 | SSL: 14.16 | CL: 0.00 | MAL: 0.00


 12%|█▏        | 121/1000 [05:30<41:39,  2.84s/it]

G: -1239.17 | MSG: -147.89 | VG: 0.00 | D: 1241.30 | MSD: 149.77 | VD: 0.00 | GP: 29105.22 | SSL: 13.67 | CL: 0.00 | MAL: 0.00


 14%|█▍        | 141/1000 [06:24<40:50,  2.85s/it]

G: -1427.19 | MSG: -151.59 | VG: 0.00 | D: 1428.10 | MSD: 153.31 | VD: 0.00 | GP: 29749.34 | SSL: 12.67 | CL: 0.00 | MAL: 0.00


 16%|█▌        | 161/1000 [07:18<40:47,  2.92s/it]

G: -1509.67 | MSG: -156.79 | VG: 0.00 | D: 1513.65 | MSD: 158.81 | VD: 0.00 | GP: 29453.61 | SSL: 15.30 | CL: 0.00 | MAL: 0.00


 18%|█▊        | 181/1000 [08:12<39:16,  2.88s/it]

G: -1623.84 | MSG: -161.05 | VG: 0.00 | D: 1628.78 | MSD: 162.49 | VD: 0.00 | GP: 29099.44 | SSL: 12.47 | CL: 0.00 | MAL: 0.00


 20%|██        | 200/1000 [09:03<34:23,  2.58s/it]

G: -1676.50 | MSG: -153.83 | VG: 0.00 | D: 1679.79 | MSD: 155.65 | VD: 0.00 | GP: 27498.29 | SSL: 11.69 | CL: 0.00 | MAL: 0.00


 22%|██▏       | 221/1000 [10:04<37:21,  2.88s/it]

G: -1793.17 | MSG: -155.31 | VG: 0.00 | D: 1796.67 | MSD: 156.88 | VD: 0.00 | GP: 25841.54 | SSL: 14.20 | CL: 0.00 | MAL: 0.00


 24%|██▍       | 241/1000 [10:58<36:38,  2.90s/it]

G: -1911.31 | MSG: -152.58 | VG: 0.00 | D: 1914.52 | MSD: 154.20 | VD: 0.00 | GP: 27447.73 | SSL: 12.04 | CL: 0.00 | MAL: 0.00


 26%|██▌       | 261/1000 [11:52<35:17,  2.87s/it]

G: -2006.81 | MSG: -154.91 | VG: 0.00 | D: 2010.58 | MSD: 156.56 | VD: 0.00 | GP: 29325.69 | SSL: 13.98 | CL: 0.00 | MAL: 0.00


 28%|██▊       | 281/1000 [12:46<34:34,  2.88s/it]

G: -2111.53 | MSG: -155.36 | VG: 0.00 | D: 2110.55 | MSD: 156.81 | VD: 0.00 | GP: 28528.46 | SSL: 11.09 | CL: 0.00 | MAL: 0.00


 30%|███       | 300/1000 [13:36<29:38,  2.54s/it]

G: -2206.39 | MSG: -161.36 | VG: 0.00 | D: 2211.28 | MSD: 163.08 | VD: 0.00 | GP: 29366.54 | SSL: 14.39 | CL: 0.00 | MAL: 0.00


 32%|███▏      | 321/1000 [14:37<32:15,  2.85s/it]

G: -2302.78 | MSG: -155.79 | VG: 0.00 | D: 2307.19 | MSD: 157.22 | VD: 0.00 | GP: 27709.11 | SSL: 13.59 | CL: 0.00 | MAL: 0.00


 34%|███▍      | 341/1000 [15:32<31:24,  2.86s/it]

G: -2426.56 | MSG: -157.76 | VG: 0.00 | D: 2427.71 | MSD: 159.29 | VD: 0.00 | GP: 29409.28 | SSL: 16.86 | CL: 0.00 | MAL: 0.00


 36%|███▌      | 361/1000 [16:25<30:28,  2.86s/it]

G: -2475.14 | MSG: -162.75 | VG: 0.00 | D: 2478.20 | MSD: 164.32 | VD: 0.00 | GP: 28859.73 | SSL: 14.15 | CL: 0.00 | MAL: 0.00


 38%|███▊      | 381/1000 [17:19<29:45,  2.88s/it]

G: -2594.19 | MSG: -157.59 | VG: 0.00 | D: 2594.54 | MSD: 159.15 | VD: 0.00 | GP: 27293.61 | SSL: 15.54 | CL: 0.00 | MAL: 0.00


 40%|████      | 400/1000 [18:10<25:30,  2.55s/it]

G: -2675.23 | MSG: -152.27 | VG: 0.00 | D: 2679.59 | MSD: 153.92 | VD: 0.00 | GP: 31326.24 | SSL: 12.14 | CL: 0.00 | MAL: 0.00


 42%|████▏     | 421/1000 [19:11<27:43,  2.87s/it]

G: -2733.28 | MSG: -158.94 | VG: 0.00 | D: 2734.12 | MSD: 160.54 | VD: 0.00 | GP: 28503.82 | SSL: 13.24 | CL: 0.00 | MAL: 0.00


 44%|████▍     | 441/1000 [20:05<26:50,  2.88s/it]

G: -2839.30 | MSG: -154.89 | VG: 0.00 | D: 2840.56 | MSD: 156.51 | VD: 0.00 | GP: 28207.96 | SSL: 12.96 | CL: 0.00 | MAL: 0.00


 45%|████▍     | 449/1000 [20:29<25:11,  2.74s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 39.56 GiB of which 382.88 MiB is free. Process 51716 has 39.17 GiB memory in use. Of the allocated memory 35.77 GiB is allocated by PyTorch, and 2.89 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
prune_and_finetune_gigagan()
gan = setup_model()
gan.load(os.path.join(MODEL_FOLDER, "model-unpruned.ckpt"))
count_nonzero_params(gan, "After Training")