
# Assignment: CycleGAN — Unpaired **Style Transfer** (Monet ↔ Photo)

**Course:** Generative AI  
**Topic:** Unpaired Image Translation with CycleGAN  
**Estimated Time:** 8–12 hours (coding + experiments + report)

---

## Overview
You will implement and analyze a **CycleGAN** to translate images between **Monet paintings** and **real photos** using *unpaired* data. Focus on getting the **losses** right (LSGAN adversarial, identity, cycle consistency) and on interpreting training behavior.

### Learning Objectives
- Explain why **cycle consistency** enables unpaired translation.  
- Implement **LSGAN** adversarial, **identity**, and **cycle** losses.  
- Diagnose artifacts (color shifts, texture leakage) and stabilize training.



## 1) Setup & Utilities
Run the cell below to import packages and set up helper functions.


In [None]:

import os, glob, random
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import make_grid
from PIL import Image
import matplotlib.pyplot as plt

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

def show_tensor_images(image_tensor, num_images=8, size=(3,256,256)):
    """Visualize a grid of images from a [-1,1] normalized tensor."""
    image_tensor = (image_tensor + 1) / 2
    grid = make_grid(image_tensor.detach().cpu().view(-1,*size)[:num_images], nrow=4)
    plt.imshow(grid.permute(1,2,0).squeeze())
    plt.axis("off")
    plt.show()



## 2) Dataset (Monet2Photo)
Place the dataset in the following structure:

```
monet2photo/
  trainA/  # Monet
  trainB/  # Photo
  testA/
  testB/
```

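Each image is resized so that its shorter side is 286 pixels, randomly cropped to 256×256, and randomly flipped horizontally. `ToTensor` scales pixel values to [0, 1], and the dataset class then rescales them to **[-1, 1]** to match the generator's tanh output.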


In [None]:

class ImageDataset(Dataset):
    """Unpaired dataset loader for Monet2Photo with on-the-fly pairing."""
    def __init__(self, root, transform=None, mode='train'):
        self.transform = transform
        self.files_A = sorted(glob.glob(os.path.join(root, f'{mode}A/*.*')))
        self.files_B = sorted(glob.glob(os.path.join(root, f'{mode}B/*.*')))
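        if len(self.files_A) > len(self.files_B):
            # Keep the shorter list in A so new_perm() can pick a distinct B partner for each A image.
            # Caution: when this swap triggers, A and B are exchanged relative to the folder names.
            self.files_A, self.files_B = self.files_B, self.files_A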
        self.new_perm()
        assert len(self.files_A) > 0, "Dataset not found. Ensure monet2photo/trainA and trainB exist!"

    def new_perm(self):
        self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]

    def __getitem__(self, idx):
        # Convert to RGB on load so grayscale, palette, and RGBA files all yield 3 channels.
        A = self.transform(Image.open(self.files_A[idx % len(self.files_A)]).convert("RGB"))
        B = self.transform(Image.open(self.files_B[self.randperm[idx]]).convert("RGB"))
        if idx == len(self) - 1: self.new_perm()
        return (A-0.5)*2, (B-0.5)*2

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))

# Transforms & DataLoader
load_shape, target_shape = 286, 256
tfm = transforms.Compose([
    transforms.Resize(load_shape),
    transforms.RandomCrop(target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

dataset = ImageDataset("monet2photo", transform=tfm, mode="train")
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
print("Batches per epoch:", len(dataloader))



## 3) Models
We use the standard CycleGAN generator with **Residual Blocks** and a **PatchGAN** discriminator. You do not need to change these classes.


In [None]:

class ResidualBlock(nn.Module):
    """Two convs + InstanceNorm + skip addition."""
    def __init__(self, in_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, padding_mode='reflect')
        self.instancenorm = nn.InstanceNorm2d(in_ch)
        self.activation = nn.ReLU(inplace=True)
    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.instancenorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.instancenorm(x)
        return identity + x

class ContractingBlock(nn.Module):
    def __init__(self, in_ch, use_bn=True, kernel_size=3, activation='relu'):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, in_ch*2, kernel_size=kernel_size, padding=1, stride=2, padding_mode='reflect')
        self.use_bn = use_bn
        if use_bn:
            self.instancenorm = nn.InstanceNorm2d(in_ch*2)
        self.activation = nn.ReLU(inplace=True) if activation=='relu' else nn.LeakyReLU(0.2, inplace=True)
    def forward(self, x):
        x = self.conv(x)
        if self.use_bn: x = self.instancenorm(x)
        return self.activation(x)

class ExpandingBlock(nn.Module):
    def __init__(self, in_ch, use_bn=True):
        super().__init__()
        self.convT = nn.ConvTranspose2d(in_ch, in_ch//2, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.use_bn = use_bn
        if use_bn:
            self.instancenorm = nn.InstanceNorm2d(in_ch//2)
        self.activation = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.convT(x)
        if self.use_bn: x = self.instancenorm(x)
        return self.activation(x)

class FeatureMapBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=7, padding=3, padding_mode='reflect')
    def forward(self, x):
        return self.conv(x)

class Generator(nn.Module):
    """2 downs -> 9 res -> 2 ups; tanh output."""
    def __init__(self, in_ch, out_ch, base=64):
        super().__init__()
        self.upfeature = FeatureMapBlock(in_ch, base)
        self.c1 = ContractingBlock(base)
        self.c2 = ContractingBlock(base*2)
        ch = base*4
        self.res = nn.Sequential(*[ResidualBlock(ch) for _ in range(9)])
        self.e1 = ExpandingBlock(ch)
        self.e2 = ExpandingBlock(base*2)
        self.downfeature = FeatureMapBlock(base, out_ch)
        self.tanh = nn.Tanh()
    def forward(self, x):
        x0 = self.upfeature(x)
        x1 = self.c1(x0)
        x2 = self.c2(x1)
        xr = self.res(x2)
        xe1 = self.e1(xr)
        xe2 = self.e2(xe1)
        out = self.downfeature(xe2)
        return self.tanh(out)

class Discriminator(nn.Module):
    """PatchGAN: downsampling conv stack -> 1x1 conv to a map of per-patch scores (no sigmoid; LSGAN regresses them to 1/0)."""
    def __init__(self, in_ch, base=64):
        super().__init__()
        self.upfeature = FeatureMapBlock(in_ch, base)
        self.c1 = ContractingBlock(base, use_bn=False, kernel_size=4, activation='lrelu')
        self.c2 = ContractingBlock(base*2, kernel_size=4, activation='lrelu')
        self.c3 = ContractingBlock(base*4, kernel_size=4, activation='lrelu')
        self.final = nn.Conv2d(base*8, 1, kernel_size=1)
    def forward(self, x):
        x0 = self.upfeature(x)
        x1 = self.c1(x0)
        x2 = self.c2(x1)
        x3 = self.c3(x2)
        return self.final(x3)

# Instantiate models
gen_AB = Generator(3,3).to(device)  # Monet -> Photo
gen_BA = Generator(3,3).to(device)  # Photo -> Monet
disc_A = Discriminator(3).to(device) # Domain A (Monet)
disc_B = Discriminator(3).to(device) # Domain B (Photo)

# Losses & optimizers
adv_criterion   = nn.MSELoss()
recon_criterion = nn.L1Loss()
lr = 2e-4
gen_opt  = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()), lr=lr, betas=(0.5,0.999))
discA_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5,0.999))
discB_opt = torch.optim.Adam(disc_B.parameters(), lr=lr, betas=(0.5,0.999))

print("Models ready.")
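
It can help to trace the spatial sizes through both networks before training. The sketch below applies the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` to the kernel, stride, and padding values used in the classes above, assuming a 256×256 input. The helper names `conv_out` and `conv_transpose_out` are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding, output_padding):
    """Spatial output size of nn.ConvTranspose2d (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Discriminator: 7x7 stem, three 4x4 stride-2 blocks, 1x1 head.
d = conv_out(256, kernel=7, stride=1, padding=3)
for _ in range(3):
    d = conv_out(d, kernel=4, stride=2, padding=1)
d = conv_out(d, kernel=1, stride=1, padding=0)

# Generator: 7x7 stem, two 3x3 stride-2 downs, size-preserving residual
# blocks, two transposed-conv ups, 7x7 head.
g = conv_out(256, kernel=7, stride=1, padding=3)
for _ in range(2):
    g = conv_out(g, kernel=3, stride=2, padding=1)
bottleneck = g
for _ in range(2):
    g = conv_transpose_out(g, kernel=3, stride=2, padding=1, output_padding=1)
g = conv_out(g, kernel=7, stride=1, padding=3)

print("PatchGAN output map:", d, "x", d)
print("Generator bottleneck:", bottleneck, "| generator output:", g)
```

The discriminator therefore returns a grid of scores, not a single number, which is why the loss targets in the next section are built with `torch.ones_like` and `torch.zeros_like`.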



## 4) **Student TODOs** — Implement Losses

Complete the functions below. We use **LSGAN** objectives (MSE against a target of 1 for real and 0 for fake), an **identity L1** loss, and a **cycle L1** loss.
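
For a generic pair of domains $X$ and $Y$, these objectives can be written as follows. This is a sketch of the standard formulation, with notation chosen to match the helper signatures in this section ($G_{XY}$ for `gen_XY`, $D_X$ for `disc_X`). The expectations correspond to the mean reduction of `nn.MSELoss` and `nn.L1Loss`, so each term is averaged over pixels or patch scores, not summed.

$$
\begin{aligned}
\mathcal{L}_{D_X} &= \tfrac{1}{2}\Big(\mathbb{E}_{x}\big[(D_X(x) - 1)^2\big] + \mathbb{E}_{y}\big[D_X(G_{YX}(y))^2\big]\Big) \\
\mathcal{L}_{\text{adv}}(G_{XY}) &= \mathbb{E}_{x}\big[(D_Y(G_{XY}(x)) - 1)^2\big] \\
\mathcal{L}_{\text{id}}(G_{YX}) &= \mathbb{E}_{x}\big[\lVert G_{YX}(x) - x \rVert_1\big] \\
\mathcal{L}_{\text{cyc}}(X) &= \mathbb{E}_{x}\big[\lVert G_{YX}(G_{XY}(x)) - x \rVert_1\big] \\
\mathcal{L}_{G} &= \mathcal{L}_{\text{adv}}(G_{AB}) + \mathcal{L}_{\text{adv}}(G_{BA})
 + \lambda_{\text{id}}\big(\mathcal{L}_{\text{id}}(G_{BA}) + \mathcal{L}_{\text{id}}(G_{AB})\big)
 + \lambda_{\text{cyc}}\big(\mathcal{L}_{\text{cyc}}(A) + \mathcal{L}_{\text{cyc}}(B)\big)
\end{aligned}
$$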


In [None]:

# =============== STUDENT TODO 1: Discriminator Loss (LSGAN) ===============
def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    """Return LSGAN discriminator loss: real→1, fake→0 (average the two)."""
    # TODO: Implement using the discriminator outputs on real_X and fake_X (detach fake).
    disc_fake = disc_X(fake_X.detach())
    loss_fake = adv_criterion(disc_fake, torch.zeros_like(disc_fake))
    disc_real = disc_X(real_X)
    loss_real = adv_criterion(disc_real, torch.ones_like(disc_real))
    return (loss_fake + loss_real) / 2

# Quick sanity check (shapes)
with torch.no_grad():
    dummy = torch.randn(1,3,256,256).to(device)
    _ = get_disc_loss(dummy, dummy, disc_A, adv_criterion)
print("Discriminator loss: OK (basic shape check).")
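
The `detach()` call in the discriminator loss decides which network receives gradients. The sketch below reproduces the same loss structure with hypothetical one-layer stand-ins (`toy_gen`, `toy_disc`) and checks that the backward pass reaches only the discriminator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical one-layer stand-ins, only to keep the check fast.
toy_gen = nn.Conv2d(3, 3, kernel_size=1)
toy_disc = nn.Conv2d(3, 1, kernel_size=1)
mse = nn.MSELoss()

real = torch.randn(1, 3, 8, 8)
fake = toy_gen(torch.randn(1, 3, 8, 8))  # still attached to toy_gen's graph

# Same structure as get_disc_loss: the fake is detached before scoring.
pred_fake = toy_disc(fake.detach())
pred_real = toy_disc(real)
loss = (mse(pred_fake, torch.zeros_like(pred_fake)) +
        mse(pred_real, torch.ones_like(pred_real))) / 2
loss.backward()

gen_has_grad = toy_gen.weight.grad is not None
disc_has_grad = toy_disc.weight.grad is not None
print("generator grad:", gen_has_grad, "| discriminator grad:", disc_has_grad)
```

Without the detach, the discriminator update would also backpropagate through the generator, wasting compute and leaving stale gradients on its parameters.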


In [None]:

# =============== STUDENT TODO 2: Generator Adversarial Loss (LSGAN) ===============
def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    """Return (loss, fake_Y) where loss = MSE(disc_Y(gen_XY(real_X)) → 1)."""
    # TODO: Implement forward through generator then discriminator to target ones.
    fake_Y = gen_XY(real_X)
    disc_fake = disc_Y(fake_Y)  # score the fake once and reuse the result
    adv_loss = adv_criterion(disc_fake, torch.ones_like(disc_fake))
    return adv_loss, fake_Y

print("Generator adversarial loss: OK (compiles).")


In [None]:

# =============== STUDENT TODO 3: Identity Loss (L1) ===============
def get_identity_loss(real_X, gen_YX, identity_criterion):
    """Identity mapping: gen_YX outputs domain X, so a real X image should pass through unchanged (gen_YX(real_X) ≈ real_X)."""
    # TODO: Implement identity pass and L1 to original.
    identity_X = gen_YX(real_X)
    loss = identity_criterion(identity_X, real_X)
    return loss, identity_X

print("Identity loss: OK (compiles).")


In [None]:

# =============== STUDENT TODO 4: Cycle Consistency Loss (L1) ===============
def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    """Cycle: X->Y (fake_Y), then Y->X should reconstruct real_X."""
    # TODO: Implement cycle pass and L1 to original.
    cycle_X = gen_YX(fake_Y)
    loss = cycle_criterion(cycle_X, real_X)
    return loss, cycle_X

print("Cycle loss: OK (compiles).")
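
A toy case shows what the cycle loss measures. In the sketch below the "generators" are plain functions, not trained networks (all names are hypothetical): when the two mappings are exact inverses the loss is zero, and when they are not, the round trip leaves a residual.

```python
import torch
import torch.nn as nn

l1 = nn.L1Loss()
x = torch.linspace(-1, 1, steps=3 * 4 * 4).reshape(1, 3, 4, 4)

# Hypothetical toy "generators".
def negate(t):
    return -t          # its own inverse

def halve(t):
    return 0.5 * t     # not undone by negate

cycle_good = l1(negate(negate(x)), x)  # X -> Y -> X recovers x exactly
cycle_bad = l1(negate(halve(x)), x)    # the round trip returns -0.5 * x instead

print(float(cycle_good), float(cycle_bad))
```

The cycle term therefore pushes `gen_AB` and `gen_BA` toward being approximate inverses of each other, which is the constraint that stands in for paired supervision.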


In [None]:

# =============== STUDENT TODO 5: Total Generator Loss ===============
def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B,
                 adv_criterion, identity_criterion, cycle_criterion,
                 lambda_identity=0.1, lambda_cycle=10):
    """
    Total gen loss = adv_AB + adv_BA
                   + λ_id*(id_A + id_B)
                   + λ_cyc*(cyc_A + cyc_B)
    Return (total_loss, fake_A, fake_B).
    """
    # TODO: Combine all parts using helper functions above.
    adv_BA, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion)
    adv_AB, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion)

    id_A, _ = get_identity_loss(real_A, gen_BA, identity_criterion)
    id_B, _ = get_identity_loss(real_B, gen_AB, identity_criterion)

    cyc_A, _ = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cyc_B, _ = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)

    total = (adv_AB + adv_BA) + lambda_identity*(id_A + id_B) + lambda_cycle*(cyc_A + cyc_B)
    return total, fake_A, fake_B

print("Total generator loss: OK (compiles).")
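
To see how the default weights shape the total, the sketch below combines made-up loss values using the formula from the docstring above. The numbers are illustrative assumptions only; real values depend on the state of training.

```python
# Made-up loss values for illustration only.
adv_AB, adv_BA = 0.5, 0.5
id_A, id_B = 0.2, 0.2
cyc_A, cyc_B = 0.3, 0.3
lambda_identity, lambda_cycle = 0.1, 10

adv_term = adv_AB + adv_BA
id_term = lambda_identity * (id_A + id_B)
cyc_term = lambda_cycle * (cyc_A + cyc_B)
total = adv_term + id_term + cyc_term

for name, term in [("adversarial", adv_term), ("identity", id_term), ("cycle", cyc_term)]:
    print(f"{name:12s} {term:6.3f}  ({100 * term / total:4.1f}% of total)")
```

With these values the cycle term accounts for most of the total, so changes to `lambda_cycle` move the generator objective far more than comparable changes to `lambda_identity`.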



## 5) Training
The loop below updates both discriminators and then both generators on every batch. Start with 10–20 epochs to confirm that the model is learning, then train longer if you have GPU time.


In [None]:

def train(n_epochs=20, display_step=200, save_model=False):
    step = 0
    mean_gen, mean_disc = 0.0, 0.0
    for epoch in range(n_epochs):
        for real_A, real_B in dataloader:
            real_A = real_A.to(device); real_B = real_B.to(device)

            # ---- Update Discriminator A ----
            discA_opt.zero_grad(set_to_none=True)
            with torch.no_grad():
                fake_A = gen_BA(real_B)
            dA_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
            dA_loss.backward()
            discA_opt.step()

            # ---- Update Discriminator B ----
            discB_opt.zero_grad(set_to_none=True)
            with torch.no_grad():
                fake_B = gen_AB(real_A)
            dB_loss = get_disc_loss(real_B, fake_B, disc_B, adv_criterion)
            dB_loss.backward()
            discB_opt.step()

            # ---- Update Generators ----
            gen_opt.zero_grad(set_to_none=True)
            g_loss, fake_A, fake_B = get_gen_loss(
                real_A, real_B, gen_AB, gen_BA, disc_A, disc_B,
                adv_criterion, recon_criterion, recon_criterion,
                lambda_identity=0.1, lambda_cycle=10
            )
            g_loss.backward()
            gen_opt.step()

            mean_disc += (dA_loss.item()+dB_loss.item())/2 / display_step
            mean_gen  += g_loss.item()/display_step

            if (step + 1) % display_step == 0:
                # Report after each full window of display_step updates so the running means are
                # true averages; a check at step 0 would print one loss scaled by 1/display_step.
                print(f"Epoch {epoch} Step {step + 1} | G: {mean_gen:.4f} D: {mean_disc:.4f}")
                show_tensor_images(torch.cat([real_A, real_B], dim=0), size=(3,256,256))
                show_tensor_images(torch.cat([fake_B, fake_A], dim=0), size=(3,256,256))
                mean_gen = 0.0; mean_disc = 0.0
                if save_model:
                    torch.save({
                        'gen_AB': gen_AB.state_dict(), 'gen_BA': gen_BA.state_dict(),
                        'gen_opt': gen_opt.state_dict(),
                        'disc_A': disc_A.state_dict(), 'disc_A_opt': discA_opt.state_dict(),
                        'disc_B': disc_B.state_dict(), 'disc_B_opt': discB_opt.state_dict()
                    }, f"cycleGAN_{step + 1}.pth")
            step += 1

# Uncomment to train
# train(n_epochs=20, display_step=200, save_model=False)
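
If you train with `save_model=True`, a checkpoint can be restored by loading each state dict back into a freshly constructed object. The sketch below shows the round trip with a hypothetical one-layer stand-in and a temporary file; for a real run, substitute the actual models, optimizers, and the dictionary keys used in `train`.

```python
import os
import tempfile
import torch
import torch.nn as nn

# Hypothetical small stand-ins for one generator and its optimizer.
model = nn.Conv2d(3, 3, kernel_size=1)
opt = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_checkpoint.pth")  # hypothetical file name
    torch.save({"gen_AB": model.state_dict(), "gen_opt": opt.state_dict()}, path)

    # Restore into new objects with the same architecture and optimizer settings.
    restored = nn.Conv2d(3, 3, kernel_size=1)
    restored_opt = torch.optim.Adam(restored.parameters(), lr=2e-4, betas=(0.5, 0.999))
    ckpt = torch.load(path, map_location="cpu")
    restored.load_state_dict(ckpt["gen_AB"])
    restored_opt.load_state_dict(ckpt["gen_opt"])

same_weights = torch.equal(model.weight, restored.weight)
print("weights restored:", same_weights)
```

Restoring the optimizer state along with the weights keeps Adam's moment estimates, so a resumed run continues smoothly without re-warming them.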



## 6) Required Experiments (Ablations)
Perform **two** ablations:

1) **Cycle weight sweep:** `λ_cyc ∈ {5, 10, 20}` (keep `λ_id=0.1`).  
   - Qualitative: Compare structure retention vs. stylization strength.  
   - Quantitative proxy: mean **L1 cycle error** on 50 samples per setting.

2) **Identity on/off:** `λ_id ∈ {0.0, 0.1}` (keep `λ_cyc=10`).  
   - Comment on **color fidelity** and hue preservation.
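
The quantitative proxy in ablation 1 can be computed along the following lines. This is a minimal sketch: `mean_cycle_l1` is a hypothetical helper, and the random tensors and toy generators stand in for 50 held-out images and your trained `gen_AB` / `gen_BA`.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def mean_cycle_l1(gen_XY, gen_YX, images):
    """Average per-image L1 error of X -> Y -> X over an iterable of [-1, 1] batches."""
    errors = []
    for x in images:
        recon = gen_YX(gen_XY(x))
        # One error per image: mean absolute difference over channels and pixels.
        errors.append((recon - x).abs().mean(dim=(1, 2, 3)))
    return torch.cat(errors).mean().item()

# Hypothetical stand-ins so the sketch runs without trained models or the dataset.
torch.manual_seed(0)
samples = [torch.rand(1, 3, 16, 16) * 2 - 1 for _ in range(50)]
perfect = nn.Identity()

def lossy(t):
    return 0.9 * t  # shrinks the signal, so the round trip cannot recover x

err_perfect = mean_cycle_l1(perfect, perfect, samples)
err_lossy = mean_cycle_l1(lossy, lossy, samples)
print("identity generators:", err_perfect)
print("lossy round trip:   ", err_lossy)
```

Use the same 50 images for every setting of `λ_cyc` so that differences in the table reflect the models and not the sample.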

*(Optional)* Try a stabilizer (fake-image replay buffer, one-sided label smoothing, LR decay) and note the effect.
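
One possible shape for the fake-image replay buffer is sketched below. The class name `ReplayBuffer`, the method `push_and_pop`, the 50/50 swap rule, and the default size are assumptions for illustration, not part of the provided code.

```python
import random
import torch

class ReplayBuffer:
    """Keeps up to max_size past fakes and mixes them into discriminator inputs."""
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for img in batch.detach():
            img = img.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new fake and return it.
                self.data.append(img)
                out.append(img)
            elif random.random() < 0.5:
                # Swap: return a stored fake and keep the new one in its place.
                i = random.randrange(self.max_size)
                out.append(self.data[i].clone())
                self.data[i] = img
            else:
                # Return the new fake and leave the buffer unchanged.
                out.append(img)
        return torch.cat(out, dim=0)

random.seed(0)
buffer = ReplayBuffer(max_size=4)
for step in range(10):
    fake = torch.full((1, 3, 8, 8), float(step))  # hypothetical fake, tagged by step
    mixed = buffer.push_and_pop(fake)

print("buffer size:", len(buffer.data), "| returned shape:", tuple(mixed.shape))
```

In the training loop, this would wrap the fakes passed to the discriminators (one buffer per domain), for example `fake_A = buffer_A.push_and_pop(gen_BA(real_B))`, so each discriminator also sees samples from earlier generator states.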



## 7) Report Expectations (2–4 pages, professional)

**1. Method (≤1 page)**  
- Briefly describe CycleGAN (two generators, two discriminators, cycle & identity).  
- Include short equations for **LSGAN**, **identity**, and **cycle** losses.

**2. Experimental Setup**  
- Dataset (Monet vs Photo), image size, augmentations.  
- Hyperparameters: λ_cyc, λ_id, optimizer (Adam β₁=0.5, β₂=0.999), LR, hardware.

**3. Results**  
- Grids: `A→B` (Monet→Photo) and `B→A` (Photo→Monet) at two training points.  
- **Cycle reconstructions**: `A→B→A`, `B→A→B` (5–8 examples).  
- Loss curves (G and D).  
- L1 cycle error table for ablations.

**4. Discussion**  
- Effect of `λ_cyc` on faithfulness vs stylization.  
- Effect of `λ_id` on color preservation.  
- Artifacts and any stabilizers tried.

**5. Conclusion**  
- Key takeaways on unpaired translation and loss trade-offs; contrast with paired Pix2Pix.
