In [2]:
# Mount Google Drive into the Colab VM at /content/drive. Config.DATASET_PATH
# and Config.CHECKPOINT_DIR (defined further down) both point under this
# mount point, so this cell has to run before the training cell.
from google.colab import drive
drive.mount('/content/drive')

In [3]:
!pip install wandb -q

In [4]:
# Log in to Weights & Biases before any run is created; train() below calls
# wandb.init(), which presumably needs these credentials — TODO confirm.
import wandb
wandb.login()

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
import wandb
import matplotlib.pyplot as plt
from PIL import Image
import os
from pathlib import Path
import numpy as np
from tqdm import tqdm
import shutil
import re

In [None]:
# ==================== CONFIGURATION ====================
class Config:
    """Static settings for the training run, read as ``Config.<NAME>``."""

    # --- Locations (both under the mounted Google Drive) ---
    DATASET_PATH = '/content/drive/MyDrive/cs236/assignments/assignment1/datasets/'
    CHECKPOINT_DIR = '/content/drive/MyDrive/cs236/assignments/assignment1/checkpoints/'

    # --- Data / schedule ---
    IMG_SIZE = 256      # images are resized to IMG_SIZE x IMG_SIZE
    BATCH_SIZE = 4
    NUM_EPOCHS = 30
    SAVE_EVERY = 1      # checkpoint period, in epochs

    # --- Adam hyper-parameters (generators and discriminators) ---
    LR_G = 2e-4
    LR_D = 2e-4
    BETA1 = 0.5
    BETA2 = 0.999

    # --- Adversarial objective: 'hinge' or 'lsgan' ---
    LOSS_TYPE = 'hinge'

    # --- Compute device: GPU when one is visible to torch, else CPU ---
    if torch.cuda.is_available():
        DEVICE = 'cuda'
    else:
        DEVICE = 'cpu'

# Create the checkpoint directory as soon as this cell runs; exist_ok=True
# makes re-running the cell harmless when the directory already exists.
os.makedirs(Config.CHECKPOINT_DIR, exist_ok=True)


# ==================== DATASET ====================
class MonetPhotoDataset(Dataset):
    """Dataset that yields one Monet image and one photo per index.

    Both directories are scanned (non-recursively) for ``*.jpg`` files and the
    resulting paths are sorted. The dataset length is the size of the larger
    collection; the smaller one is cycled via modulo indexing, so index ``i``
    returns ``monet[i % n_monet]`` together with ``photo[i % n_photo]``.

    Args:
        monet_dir: Directory containing the Monet ``.jpg`` files.
        photo_dir: Directory containing the photo ``.jpg`` files.
        transform: Optional callable applied independently to each PIL image.

    Raises:
        FileNotFoundError: If either directory contains no ``*.jpg`` file.
            Without this check an empty directory would only surface later as
            a ``ZeroDivisionError`` from the modulo in ``__getitem__``.
    """

    def __init__(self, monet_dir, photo_dir, transform=None):
        self.monet_paths = sorted(Path(monet_dir).glob('*.jpg'))
        self.photo_paths = sorted(Path(photo_dir).glob('*.jpg'))
        self.transform = transform

        print(f"Found {len(self.monet_paths)} Monet images")
        print(f"Found {len(self.photo_paths)} Photo images")

        # Fail fast with an actionable message instead of crashing on
        # `idx % 0` inside a DataLoader worker.
        for label, directory, paths in (
            ('Monet', monet_dir, self.monet_paths),
            ('Photo', photo_dir, self.photo_paths),
        ):
            if not paths:
                raise FileNotFoundError(
                    f"No {label} '*.jpg' images found in {directory}"
                )

    def __len__(self):
        """Return the size of the larger of the two image collections."""
        return max(len(self.monet_paths), len(self.photo_paths))

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(monet_img, photo_img)`` for ``idx``, transformed if configured."""
        monet_img = self._load_rgb(self.monet_paths[idx % len(self.monet_paths)])
        photo_img = self._load_rgb(self.photo_paths[idx % len(self.photo_paths)])

        if self.transform:
            monet_img = self.transform(monet_img)
            photo_img = self.transform(photo_img)

        return monet_img, photo_img


# ==================== GENERATOR ====================
class ResidualBlock(nn.Module):
    """Residual unit: two (reflect-pad, 3x3 conv, instance-norm) stages plus a skip.

    A ReLU sits between the two stages. Input and output have the same number
    of channels, and every unpadded 3x3 convolution is preceded by one pixel
    of reflection padding, so the result can be added to the block's input.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            # stage 1
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            # stage 2 (no activation: the skip connection is added right after)
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the two-stage branch."""
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, 2 stride-2 encoders, residual trunk,
    2 stride-2 transposed-conv decoders and a 7x7 Tanh head.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        features: Width of the stem; each encoder stage doubles it and each
            decoder stage halves it again.
        n_residual: Number of ``ResidualBlock`` modules in the trunk.

    Sub-modules are registered as ``initial``, ``down_blocks``, ``res_blocks``,
    ``up_blocks`` and ``output``, in that order.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64, n_residual=9):
        super().__init__()

        # Stem: keeps the spatial size thanks to 3 pixels of reflection padding.
        stem_layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, features, kernel_size=7),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
        ]
        self.initial = nn.Sequential(*stem_layers)

        # Encoder: two stride-2 convolutions, doubling the width each time.
        width = features
        encoder_stages = []
        for _ in range(2):
            doubled = width * 2
            encoder_stages.append(nn.Sequential(
                nn.Conv2d(width, doubled, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ))
            width = doubled
        self.down_blocks = nn.ModuleList(encoder_stages)

        # Trunk: residual blocks at the bottleneck width.
        trunk = [ResidualBlock(width) for _ in range(n_residual)]
        self.res_blocks = nn.Sequential(*trunk)

        # Decoder: two stride-2 transposed convolutions, halving the width.
        decoder_stages = []
        for _ in range(2):
            halved = width // 2
            decoder_stages.append(nn.Sequential(
                nn.ConvTranspose2d(width, halved, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(halved),
                nn.ReLU(inplace=True),
            ))
            width = halved
        self.up_blocks = nn.ModuleList(decoder_stages)

        # Head: project to the output channels and squash with Tanh.
        head_layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(width, out_channels, kernel_size=7),
            nn.Tanh(),
        ]
        self.output = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run stem, encoder, trunk, decoder and head in sequence."""
        h = self.initial(x)
        for stage in self.down_blocks:
            h = stage(h)
        h = self.res_blocks(h)
        for stage in self.up_blocks:
            h = stage(h)
        return self.output(h)


# ==================== DISCRIMINATOR ====================
class Discriminator(nn.Module):
    """Convolutional discriminator producing a grid of real/fake scores.

    Four stride-2 4x4 convolutions widen the features from ``features`` to
    ``features * 8`` (instance norm on all but the first, LeakyReLU(0.2) after
    each), followed by a final 4x4 convolution down to a single channel. No
    activation is applied to the output.

    Args:
        in_channels: Channels of the input image.
        features: Width of the first convolution.
    """

    def __init__(self, in_channels=3, features=64):
        super().__init__()

        stage_widths = [features, features * 2, features * 4, features * 8]

        layers = []
        prev_width = in_channels
        for stage_idx, width in enumerate(stage_widths):
            layers.append(nn.Conv2d(prev_width, width, kernel_size=4, stride=2, padding=1))
            # The first stage is left un-normalised.
            if stage_idx > 0:
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            prev_width = width

        # Single-channel score map.
        layers.append(nn.Conv2d(prev_width, 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (unactivated) score map for ``x``."""
        scores = self.model(x)
        return scores


# ==================== LOSSES ====================
class LSGANLoss(nn.Module):
    """Least-squares adversarial objective (MSE against 1 = real, 0 = fake)."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def discriminator_loss(self, real_pred, fake_pred):
        """Average of MSE(real_pred, 1) and MSE(fake_pred, 0)."""
        target_real = torch.ones_like(real_pred)
        target_fake = torch.zeros_like(fake_pred)
        loss_on_real = self.mse(real_pred, target_real)
        loss_on_fake = self.mse(fake_pred, target_fake)
        return (loss_on_real + loss_on_fake) * 0.5

    def generator_loss(self, fake_pred):
        """MSE between the discriminator's scores on fakes and the 'real' label 1."""
        wanted = torch.ones_like(fake_pred)
        return self.mse(fake_pred, wanted)

class HingeLoss(nn.Module):
    """Hinge adversarial objective on raw discriminator scores."""

    def __init__(self):
        super().__init__()

    def discriminator_loss(self, real_pred, fake_pred):
        """Average of mean(relu(1 - real_pred)) and mean(relu(1 + fake_pred))."""
        margin_real = torch.relu(1.0 - real_pred).mean()
        margin_fake = torch.relu(1.0 + fake_pred).mean()
        return (margin_real + margin_fake) * 0.5

    def generator_loss(self, fake_pred):
        """Negated mean discriminator score on generated samples."""
        mean_score = fake_pred.mean()
        return -mean_score


def get_adversarial_loss(loss_type):
    """Return the adversarial loss module selected by ``loss_type``.

    Args:
        loss_type: ``'lsgan'`` for ``LSGANLoss`` or ``'hinge'`` for ``HingeLoss``.

    Returns:
        A freshly constructed loss module exposing ``discriminator_loss`` and
        ``generator_loss``.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
            Previously an unknown name fell through and returned ``None``,
            which only failed later when a loss method was called on it.
    """
    if loss_type == 'lsgan':
        return LSGANLoss()
    if loss_type == 'hinge':
        return HingeLoss()
    raise ValueError(
        f"Unknown loss_type {loss_type!r}; expected 'lsgan' or 'hinge'"
    )

def cycle_loss(real_img, reconstructed_img):
    """Mean absolute error between an image and its round-trip reconstruction."""
    return nn.functional.l1_loss(real_img, reconstructed_img)

def identity_loss(real_img, same_img):
    """Mean absolute error between an image and the generator's output for it."""
    return nn.functional.l1_loss(real_img, same_img)


# ==================== TRAINING FUNCTIONS ====================
def save_checkpoint(epoch, G_M2P, G_P2M, D_M, D_P, opt_G, opt_D, loss_type, checkpoint_dir):
    """Write all model/optimizer state to ``<checkpoint_dir>/<loss_type>_epoch_<epoch>.pth``.

    The payload is first written to a sibling ``.tmp`` file and then moved
    onto the final name, so a failed write does not leave a partial file
    under the final name. Failures are reported with ``print`` and are not
    re-raised; a leftover temporary file is removed.

    Args:
        epoch: Epoch number stored in the payload and used in the file name.
        G_M2P, G_P2M: Generators whose ``state_dict()`` is stored.
        D_M, D_P: Discriminators whose ``state_dict()`` is stored.
        opt_G, opt_D: Optimizers whose ``state_dict()`` is stored.
        loss_type: Stored in the payload and used as the file-name prefix.
        checkpoint_dir: Destination directory.
    """
    destination = os.path.join(checkpoint_dir, f'{loss_type}_epoch_{epoch}.pth')
    staging = destination + '.tmp'

    # Key order: epoch, the six state dicts, then loss_type.
    stateful_objects = (
        ('G_M2P_state', G_M2P),
        ('G_P2M_state', G_P2M),
        ('D_M_state', D_M),
        ('D_P_state', D_P),
        ('opt_G_state', opt_G),
        ('opt_D_state', opt_D),
    )
    payload = {'epoch': epoch}
    for key, obj in stateful_objects:
        payload[key] = obj.state_dict()
    payload['loss_type'] = loss_type

    try:
        torch.save(payload, staging)
        shutil.move(staging, destination)
        print(f"Checkpoint saved: {destination}")
    except Exception as e:
        print(f"Failed to save checkpoint: {e}")
        # Best-effort cleanup of the partially written temporary file.
        if os.path.exists(staging):
            os.remove(staging)


def load_checkpoint(checkpoint_path, G_M2P, G_P2M, D_M, D_P, opt_G, opt_D):
    """Restore models and optimizers in place from a checkpoint file.

    The file is loaded with ``map_location='cpu'`` and must contain the keys
    ``'G_M2P_state'``, ``'G_P2M_state'``, ``'D_M_state'``, ``'D_P_state'``,
    ``'opt_G_state'``, ``'opt_D_state'`` and ``'epoch'``.

    Args:
        checkpoint_path: Path of the checkpoint file.
        G_M2P, G_P2M, D_M, D_P: Modules receiving the stored state dicts.
        opt_G, opt_D: Optimizers receiving the stored state dicts.

    Returns:
        The ``'epoch'`` value stored in the checkpoint.

    Raises:
        RuntimeError: If reading the file or applying any state dict fails.
            The original exception is attached as ``__cause__``. Targets
            restored before the failure keep their newly loaded state.
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        restore_plan = (
            (G_M2P, 'G_M2P_state'),
            (G_P2M, 'G_P2M_state'),
            (D_M, 'D_M_state'),
            (D_P, 'D_P_state'),
            (opt_G, 'opt_G_state'),
            (opt_D, 'opt_D_state'),
        )
        for target, key in restore_plan:
            target.load_state_dict(checkpoint[key])
        return checkpoint['epoch']
    except Exception as e:
        # Chain explicitly so the root cause is reported as the direct cause.
        raise RuntimeError(f"Failed to load checkpoint: {e}") from e


def find_valid_checkpoint(checkpoint_dir, loss_type):
    """Return the path of the newest usable checkpoint for ``loss_type``.

    Files named ``<loss_type>_epoch_<N>.pth`` in ``checkpoint_dir`` are tried
    from the highest ``N`` downwards. A file is accepted when it can be loaded
    and contains every key that ``load_checkpoint`` reads, so that a truncated
    or incomplete newest file does not hide an older, complete one.

    Args:
        checkpoint_dir: Directory to scan (non-recursively).
        loss_type: File-name prefix selecting which run's checkpoints to use.

    Returns:
        The accepted checkpoint path as ``str``, or ``None`` if none qualifies.
    """
    required_keys = (
        'epoch',
        'G_M2P_state', 'G_P2M_state',
        'D_M_state', 'D_P_state',
        'opt_G_state', 'opt_D_state',
    )

    candidates = []
    for ckpt_path in Path(checkpoint_dir).glob(f'{loss_type}_epoch_*.pth'):
        match = re.search(r'epoch_(\d+)\.pth', ckpt_path.name)
        if match:
            candidates.append((int(match.group(1)), ckpt_path))

    # Numeric sort so that e.g. epoch 10 ranks above epoch 9.
    candidates.sort(key=lambda item: item[0], reverse=True)

    for _, ckpt_path in candidates:
        try:
            checkpoint = torch.load(str(ckpt_path), map_location='cpu')
            if all(key in checkpoint for key in required_keys):
                return str(ckpt_path)
        except Exception:
            # Only ordinary errors mean "unreadable file"; KeyboardInterrupt
            # and SystemExit must propagate rather than be swallowed.
            print(f"Skipping corrupted checkpoint: {ckpt_path}")

    return None


def visualize_results(real_monet, fake_photo, real_photo, fake_monet, epoch):
    """Build a 2x4 matplotlib figure of real and generated samples.

    Top row (titled), all taken from batch index 0: real Monet, generated
    photo, real photo, generated Monet. Bottom row (untitled), drawn only when
    the batch holds more than one sample: column ``k`` shows sample ``k`` of,
    in order, ``fake_photo``, ``fake_monet``, ``real_monet``, ``real_photo``,
    for ``k < min(4, batch size)``.

    Images are mapped back to ``[0, 1]`` with ``x * 0.5 + 0.5`` and clamped
    before display.

    Args:
        real_monet, fake_photo, real_photo, fake_monet: Image batches indexed
            as ``batch[i] -> (C, H, W)``. The batch size is taken from
            ``real_monet``; the others are assumed to be at least as long —
            TODO confirm against callers.
        epoch: Currently unused; kept so existing callers keep working.

    Returns:
        The created ``matplotlib`` figure. The caller is responsible for
        closing it.
    """
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))

    def to_image(x):
        # detach() lets this work on tensors that still require grad.
        return (x * 0.5 + 0.5).clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()

    top_row = [
        (real_monet[0], "Real Monet"),
        (fake_photo[0], "Fake Photo (M→P)"),
        (real_photo[0], "Real Photo"),
        (fake_monet[0], "Fake Monet (P→M)"),
    ]
    for col, (img, title) in enumerate(top_row):
        axes[0, col].imshow(to_image(img))
        axes[0, col].set_title(title)

    batch_size = len(real_monet)
    if batch_size > 1:
        # Column k of the bottom row shows sample k of the k-th source below.
        bottom_sources = (fake_photo, fake_monet, real_monet, real_photo)
        for col in range(min(4, batch_size)):
            axes[1, col].imshow(to_image(bottom_sources[col][col]))

    # Turn off every axis, including panels a small batch left empty, so the
    # figure never shows blank framed boxes with ticks.
    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    return fig


# ==================== MAIN TRAINING ====================
def _run_training():
    """Build data, models and optimizers, then run the CycleGAN epoch loop.

    Expects an active wandb run (metrics and sample figures are logged to it).
    Resumes from the newest usable checkpoint for ``Config.LOSS_TYPE`` when
    one exists, otherwise starts at epoch 0.
    """
    # Weights of the cycle-consistency and identity L1 terms in the generator
    # objective.
    lambda_cycle = 10.0
    lambda_identity = 5.0

    # Resize, convert to tensor and normalise each channel with mean/std 0.5.
    transform = transforms.Compose([
        transforms.Resize((Config.IMG_SIZE, Config.IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    dataset = MonetPhotoDataset(
        monet_dir=os.path.join(Config.DATASET_PATH, 'monet_jpg'),
        photo_dir=os.path.join(Config.DATASET_PATH, 'photo_jpg'),
        transform=transform
    )

    dataloader = DataLoader(dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=2)

    G_M2P = Generator().to(Config.DEVICE)
    G_P2M = Generator().to(Config.DEVICE)
    D_M = Discriminator().to(Config.DEVICE)
    D_P = Discriminator().to(Config.DEVICE)

    # One optimizer for both generators, one for both discriminators.
    opt_G = optim.Adam(
        list(G_M2P.parameters()) + list(G_P2M.parameters()),
        lr=Config.LR_G, betas=(Config.BETA1, Config.BETA2)
    )
    opt_D = optim.Adam(
        list(D_M.parameters()) + list(D_P.parameters()),
        lr=Config.LR_D, betas=(Config.BETA1, Config.BETA2)
    )

    adv_loss = get_adversarial_loss(Config.LOSS_TYPE)

    start_epoch = 0
    valid_checkpoint = find_valid_checkpoint(Config.CHECKPOINT_DIR, Config.LOSS_TYPE)

    if valid_checkpoint:
        print(f"Found valid checkpoint: {valid_checkpoint}")
        try:
            # Checkpoints store the 0-based index of the finished epoch.
            start_epoch = load_checkpoint(valid_checkpoint, G_M2P, G_P2M, D_M, D_P, opt_G, opt_D) + 1
            print(f"Resuming from epoch {start_epoch}")
        except Exception as e:
            print(f"Failed to load checkpoint: {e}")
            print("Starting from scratch...")
            start_epoch = 0
    else:
        print("No valid checkpoint found. Starting from scratch...")

    for epoch in range(start_epoch, Config.NUM_EPOCHS):
        G_M2P.train()
        G_P2M.train()
        D_M.train()
        D_P.train()

        epoch_losses = {'D_loss': 0, 'G_loss': 0, 'cycle_loss': 0, 'identity_loss': 0, 'adv_loss': 0}

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{Config.NUM_EPOCHS}")
        for batch_idx, (real_monet, real_photo) in enumerate(pbar):
            real_monet = real_monet.to(Config.DEVICE)
            real_photo = real_photo.to(Config.DEVICE)

            # ==================== Train Discriminators ====================
            opt_D.zero_grad()

            # Fakes are detached here so the discriminator step does not
            # back-propagate into the generators.
            fake_monet = G_P2M(real_photo)
            pred_real_monet = D_M(real_monet)
            pred_fake_monet = D_M(fake_monet.detach())
            D_M_loss = adv_loss.discriminator_loss(pred_real_monet, pred_fake_monet)

            fake_photo = G_M2P(real_monet)
            pred_real_photo = D_P(real_photo)
            pred_fake_photo = D_P(fake_photo.detach())
            D_P_loss = adv_loss.discriminator_loss(pred_real_photo, pred_fake_photo)

            D_loss = D_M_loss + D_P_loss
            D_loss.backward()
            opt_D.step()

            # ==================== Train Generators ====================
            opt_G.zero_grad()

            # Re-score the same fakes with the just-updated discriminators.
            pred_fake_monet = D_M(fake_monet)
            pred_fake_photo = D_P(fake_photo)
            G_M2P_adv_loss = adv_loss.generator_loss(pred_fake_photo)
            G_P2M_adv_loss = adv_loss.generator_loss(pred_fake_monet)
            total_adv_loss = G_M2P_adv_loss + G_P2M_adv_loss

            reconstructed_monet = G_P2M(fake_photo)
            reconstructed_photo = G_M2P(fake_monet)
            cycle_M = cycle_loss(real_monet, reconstructed_monet)
            cycle_P = cycle_loss(real_photo, reconstructed_photo)
            total_cycle_loss = (cycle_M + cycle_P) * lambda_cycle

            # Identity term: a generator fed an image already in its target
            # domain is compared against that same image.
            identity_monet = G_P2M(real_monet)
            identity_photo = G_M2P(real_photo)
            identity_M = identity_loss(real_monet, identity_monet)
            identity_P = identity_loss(real_photo, identity_photo)
            total_identity_loss = (identity_M + identity_P) * lambda_identity

            G_loss = total_adv_loss + total_cycle_loss + total_identity_loss
            G_loss.backward()
            opt_G.step()

            epoch_losses['D_loss'] += D_loss.item()
            epoch_losses['G_loss'] += G_loss.item()
            epoch_losses['cycle_loss'] += total_cycle_loss.item()
            epoch_losses['identity_loss'] += total_identity_loss.item()
            epoch_losses['adv_loss'] += total_adv_loss.item()

            pbar.set_postfix({
                'D': f"{D_loss.item():.4f}",
                'G': f"{G_loss.item():.4f}",
                'Cyc': f"{total_cycle_loss.item():.4f}"
            })

        # Convert the running sums into per-batch means.
        for key in epoch_losses:
            epoch_losses[key] /= len(dataloader)

        wandb.log({
            'epoch': epoch + 1,
            **epoch_losses
        })

        # Sample figure built from the last batch of the epoch.
        with torch.no_grad():
            G_M2P.eval()
            G_P2M.eval()
            fake_photo = G_M2P(real_monet)
            fake_monet = G_P2M(real_photo)
            fig = visualize_results(real_monet, fake_photo, real_photo, fake_monet, epoch)
            wandb.log({"generated_images": wandb.Image(fig)})
            plt.close(fig)

        if (epoch + 1) % Config.SAVE_EVERY == 0:
            save_checkpoint(epoch, G_M2P, G_P2M, D_M, D_P, opt_G, opt_D, Config.LOSS_TYPE, Config.CHECKPOINT_DIR)

        print(f"Epoch {epoch+1}: D_loss={epoch_losses['D_loss']:.4f}, "
              f"G_loss={epoch_losses['G_loss']:.4f}, "
              f"Cycle={epoch_losses['cycle_loss']:.4f}")

    # Final save, but only if this call actually trained something. When the
    # resumed checkpoint is already at or past NUM_EPOCHS the loop is skipped,
    # and saving would write those weights under the NUM_EPOCHS - 1 file name.
    if start_epoch < Config.NUM_EPOCHS:
        save_checkpoint(Config.NUM_EPOCHS - 1, G_M2P, G_P2M, D_M, D_P, opt_G, opt_D, Config.LOSS_TYPE, Config.CHECKPOINT_DIR)


def train():
    """Train the CycleGAN configured by ``Config`` inside a wandb run.

    Starts a wandb run, delegates the actual work to ``_run_training`` and
    always closes the run: normally on success, or with a non-zero exit code
    when training raises or is interrupted, after which the exception is
    re-raised.
    """
    wandb.init(
        project="cs236-gan-assignment",
        name=f"cyclegan-{Config.LOSS_TYPE}",
        config={
            'loss_type': Config.LOSS_TYPE,
            'img_size': Config.IMG_SIZE,
            'batch_size': Config.BATCH_SIZE,
            'num_epochs': Config.NUM_EPOCHS,
            'lr_g': Config.LR_G,
            'lr_d': Config.LR_D,
            'beta1': Config.BETA1,
            'beta2': Config.BETA2
        }
    )

    try:
        _run_training()
    except BaseException:
        # Includes KeyboardInterrupt (kernel interrupt): close the run as
        # failed instead of leaving it open, then let the error surface.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()
    print("Training complete!")

In [7]:
# ==================== RUN ====================
# Print the key run settings, then start (or resume) training.
if __name__ == "__main__":
    run_summary = (
        ("Using device", Config.DEVICE),
        ("Loss type", Config.LOSS_TYPE),
        ("Batch size", Config.BATCH_SIZE),
    )
    for label, value in run_summary:
        print(f"{label}: {value}")

    train()
