# CycleGAN Experiment 1: Interactive Training with Visualizations

This notebook implements the CycleGAN training loop interactively, allowing for data visualization, train/val/test splitting, and real-time monitoring of generated images.

**Features:**
- Saves checkpoints to Google Drive every epoch.
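- **Auto-Resuming:** Checks for existing checkpoints in Drive and resumes training from the latest one if found. Only the model weights are restored (the learning-rate schedulers are fast-forwarded to match); the optimizers start fresh on every resume.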

In [None]:
# Mount Google Drive
# REPO_PATH and CHECKPOINT_DIR in later cells both live under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# GitHub Configuration & Setup
# Switches the working directory to the cloned repository on Drive and sets the
# global git identity. Later cells use relative paths ('requirements.txt', 'src',
# 'data/...'), so they depend on the directory change made here.
import os

# --- CONFIGURATION ---
# Please double-check this path matches where you cloned the repo in your Drive
REPO_PATH = '/content/drive/MyDrive/Kaggle_GANS_I-m-Something-of-a-Painter-Myself_Competition'

try:
    # 1. Change to the repository directory
    if os.path.exists(REPO_PATH):
        %cd "$REPO_PATH"
        print(f"Changed directory to {REPO_PATH}")
    else:
        # Best-effort: only warn; the working directory is left unchanged and
        # the remaining steps below still run.
        print(f"Warning: Could not find repository at {REPO_PATH}.")
        print("Please update 'REPO_PATH' variable to point to your cloned repository folder.")
        # Attempting to list MyDrive to help user find the folder
        print("Listing folders in MyDrive to help identify path:")
        !ls -d /content/drive/MyDrive/*/

    # 2. Configure Git
    # Runs whether or not the directory change above succeeded. `--global`
    # writes the identity to the user-level git config rather than the repo's.
    user_name = "konstantine25b"
    mail = "konstantine25b@gmail.com"
    # repo_url is only referenced by the commented-out `git remote set-url` line below.
    repo_url = "https://github.com/konstantine25b/Kaggle_GANS_I-m-Something-of-a-Painter-Myself_Competition.git"

    !git config --global user.name "{user_name}"
    !git config --global user.email "{mail}"
    
    # 3. Set Remote URL
    # Only run this if we are in a git repo
    # (checked relative to the current working directory, i.e. after the %cd above)
    if os.path.isdir(".git"):
        # Uncomment the next line if you need to set the remote URL with a token manually
        # !git remote set-url origin "{repo_url}"
        print("Git configured successfully.")
    else:
        print("Current directory is not a git repository. Skipping remote setup.")
        
# NOTE(review): `!` shell commands presumably report failures in their output
# rather than raising, so this handler likely only covers the Python/magic
# lines above — confirm before relying on it to detect git failures.
except Exception as e:
    print(f"Error setting up GitHub: {e}")

In [None]:
# Install requirements
!pip install -r requirements.txt

In [None]:
import sys
import os
import glob
import re
import random
import time
import itertools
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.utils as vutils

# Add src to path so we can import modules.
# 'src' is relative to the current working directory. The membership check keeps
# re-runs of this cell from appending duplicate entries to sys.path.
if 'src' not in sys.path:
    sys.path.append('src')

from models.generator.resnet_gan import ResNetGenerator
from models.discriminator.patch_gan import PatchDiscriminator
from utils.dataset import ImageDataset, get_transforms
from utils.helpers import ReplayBuffer, weights_init_normal

## 1. Data Loading and Splitting

We will load the file paths and split them into Train, Validation, and Test sets.

In [None]:
# Configuration
# Note: Adjust these paths if your data is located elsewhere (e.g., in Drive or downloaded)

# --- Data locations (relative to the current working directory) ---
MONET_PATH = 'data/monet_jpg'
PHOTO_PATH = 'data/photo_jpg'

# --- Checkpoint folder on Drive; created up front if it does not exist ---
CHECKPOINT_DIR = '/content/drive/MyDrive/MonetGAN_Checkpoints_Exp1'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Training hyper-parameters ---
BATCH_SIZE = 1
N_EPOCHS = 5
LR = 0.0002

# --- Compute device: GPU when CUDA is available, CPU otherwise ---
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

In [None]:
# Get file paths
def _list_image_files(folder):
    """Return the sorted paths matching '*.*' directly inside `folder`."""
    pattern = os.path.join(folder, "*.*")
    return sorted(glob.glob(pattern))

monet_files = _list_image_files(MONET_PATH)
photo_files = _list_image_files(PHOTO_PATH)

print(f"Total Monet images: {len(monet_files)}")
print(f"Total Photo images: {len(photo_files)}")

# Shuffle files for splitting
# Sorting first and seeding the global RNG makes the shuffled order repeatable;
# Monet is shuffled before Photo, both in place.
random.seed(42)
for file_list in (monet_files, photo_files):
    random.shuffle(file_list)

# Split function
def split_data(files, train_frac=0.8, val_frac=0.1):
    """Cut `files` into consecutive train / validation / test slices.

    Args:
        files: Sliceable sequence (e.g. a list of paths). It is not shuffled
            here; the slices follow the order the caller provides.
        train_frac: Fraction of items assigned to the leading (train) slice.
        val_frac: Fraction of items assigned to the middle (validation) slice.

    Returns:
        A `(train, val, test)` tuple. Both cut points are truncated to
        integers, and the test slice receives whatever remains.
    """
    total = len(files)
    first_cut = int(total * train_frac)
    second_cut = int(total * (train_frac + val_frac))

    pieces = (
        slice(None, first_cut),
        slice(first_cut, second_cut),
        slice(second_cut, None),
    )
    train, val, test = (files[piece] for piece in pieces)
    return train, val, test

# Split each domain independently with the default fractions of split_data.
monet_splits = split_data(monet_files)
photo_splits = split_data(photo_files)

monet_train, monet_val, monet_test = monet_splits
photo_train, photo_val, photo_test = photo_splits

# Report the resulting split sizes for a quick sanity check.
print(f"Monet Splits - Train: {len(monet_train)}, Val: {len(monet_val)}, Test: {len(monet_test)}")
print(f"Photo Splits - Train: {len(photo_train)}, Val: {len(photo_val)}, Test: {len(photo_test)}")

In [None]:
# Create Datasets and DataLoaders
# The same transform pipeline is shared by the train, validation and test datasets.
transforms_ = get_transforms()


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook-wide batch size."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=2)


# Datasets are built first (train, val, test), then their loaders.
train_dataset, val_dataset, test_dataset = (
    ImageDataset(monet_paths, photo_paths, transform=transforms_)
    for monet_paths, photo_paths in (
        (monet_train, photo_train),
        (monet_val, photo_val),
        (monet_test, photo_test),
    )
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

## 2. Visualize Data

Let's look at some samples from our training loader to verify the data is loading correctly.

In [None]:
def denormalize(tensor):
    """Rescale values with `x * 0.5 + 0.5` for display.

    Maps -1 -> 0 and 1 -> 1. Presumably this undoes a mean=0.5 / std=0.5
    normalization applied by the dataset transforms — confirm against
    `get_transforms`.
    """
    scale = shift = 0.5
    rescaled = tensor * scale
    return rescaled + shift

# Pull one batch from the training loader and show its first Monet/photo pair.
batch = next(iter(train_loader))
real_monet = batch['monet']
real_photo = batch['photo']

fig, (ax_monet, ax_photo) = plt.subplots(1, 2, figsize=(10, 5))

panels = (
    (ax_monet, real_monet[0], "Real Monet"),
    (ax_photo, real_photo[0], "Real Photo"),
)
for ax, image, title in panels:
    # permute(1, 2, 0) moves the first dimension last before handing the array to imshow.
    ax.imshow(denormalize(image).permute(1, 2, 0).cpu().numpy())
    ax.set_title(title)
    ax.axis('off')

plt.show()

## 3. Model Initialization & Resume Logic

We initialize the models and check if there are any saved checkpoints in Google Drive to resume from.

In [None]:
# Initialize Generator and Discriminator
# Two generators and two discriminators, created in the order
# G_Monet, G_Photo, D_Monet, D_Photo and moved to DEVICE.
G_Monet, G_Photo = (ResNetGenerator().to(DEVICE) for _ in range(2))
D_Monet, D_Photo = (PatchDiscriminator().to(DEVICE) for _ in range(2))

# Apply weights
for network in (G_Monet, G_Photo, D_Monet, D_Photo):
    network.apply(weights_init_normal)

# Loss functions
criterion_GAN = nn.MSELoss()       # adversarial term (squared error against 1/0 targets)
criterion_cycle = nn.L1Loss()      # cycle-reconstruction term
criterion_identity = nn.L1Loss()   # identity term

# Optimizers
# One Adam optimizer drives both generators; each discriminator has its own.
adam_settings = dict(lr=LR, betas=(0.5, 0.999))
generator_params = itertools.chain(G_Monet.parameters(), G_Photo.parameters())

optimizer_G = torch.optim.Adam(generator_params, **adam_settings)
optimizer_D_Monet = torch.optim.Adam(D_Monet.parameters(), **adam_settings)
optimizer_D_Photo = torch.optim.Adam(D_Photo.parameters(), **adam_settings)

# Learning Rate Schedulers
# The LR stays at its base value through epoch LR_DECAY_START_EPOCH, then falls
# linearly. Runs with N_EPOCHS <= LR_DECAY_START_EPOCH therefore never decay.
LR_DECAY_START_EPOCH = 25   # last epoch index that still uses the full base LR
LR_SCHEDULE_EPOCHS = 50     # schedule length the linear decay is sized for


def linear_decay_factor(epoch):
    """Return the multiplicative LR factor for `epoch` (scheduler step count).

    Equals 1.0 up to LR_DECAY_START_EPOCH, then decreases linearly. It is
    clamped at 0.0 so that training past the end of the schedule cannot
    produce a negative learning rate.
    """
    decay_span = float(LR_SCHEDULE_EPOCHS - LR_DECAY_START_EPOCH + 1)
    factor = 1.0 - max(0, epoch - LR_DECAY_START_EPOCH) / decay_span
    return max(0.0, factor)


# All three optimizers follow the same schedule.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=linear_decay_factor)
lr_scheduler_D_Monet = torch.optim.lr_scheduler.LambdaLR(optimizer_D_Monet, lr_lambda=linear_decay_factor)
lr_scheduler_D_Photo = torch.optim.lr_scheduler.LambdaLR(optimizer_D_Photo, lr_lambda=linear_decay_factor)

In [None]:
# --- RESUME LOGIC ---
def find_latest_checkpoint(checkpoint_dir):
    """Return the highest epoch number among saved G_Monet checkpoints.

    Looks for files named ``G_Monet_epoch_<N>.pth`` directly inside
    `checkpoint_dir`.

    Args:
        checkpoint_dir: Directory to scan.

    Returns:
        The largest ``<N>`` found, or 0 when no such file exists.
    """
    # The dot is escaped and the match is anchored on the file name, so files
    # that merely match the glob (e.g. 'G_Monet_epoch_best.pth') are skipped
    # instead of crashing on a failed regex match.
    name_pattern = re.compile(r'G_Monet_epoch_(\d+)\.pth$')

    epochs = []
    for path in glob.glob(os.path.join(checkpoint_dir, "G_Monet_epoch_*.pth")):
        match = name_pattern.search(os.path.basename(path))
        if match:
            epochs.append(int(match.group(1)))

    return max(epochs, default=0)

# 0 means "no checkpoint"; otherwise this is the number of epochs already completed.
start_epoch = find_latest_checkpoint(CHECKPOINT_DIR)

if start_epoch <= 0:
    print("No checkpoint found. Starting from scratch.")
else:
    print(f"Found checkpoint! Resuming from epoch {start_epoch}...")

    # Restore the weights of all four networks saved for that epoch.
    weight_files = (
        (G_Monet, f'G_Monet_epoch_{start_epoch}.pth'),
        (G_Photo, f'G_Photo_epoch_{start_epoch}.pth'),
        (D_Monet, f'D_Monet_epoch_{start_epoch}.pth'),
        (D_Photo, f'D_Photo_epoch_{start_epoch}.pth'),
    )
    for network, filename in weight_files:
        state = torch.load(os.path.join(CHECKPOINT_DIR, filename), map_location=DEVICE)
        network.load_state_dict(state)

    # Fast forward schedulers
    # One step per completed epoch, so the LR matches where training left off.
    schedulers = (lr_scheduler_G, lr_scheduler_D_Monet, lr_scheduler_D_Photo)
    for _ in range(start_epoch):
        for scheduler in schedulers:
            scheduler.step()

## 4. Training Loop with Visualizations

We will train for the remaining epochs, visualizing the output on a fixed sample from the validation set every epoch.

In [None]:
# Get a fixed sample for visualization
# The first validation batch is reused every epoch so the images are comparable.
val_batch = next(iter(val_loader))
fixed_monet, fixed_photo = (val_batch[key].to(DEVICE) for key in ('monet', 'photo'))

# One replay buffer per generated domain, used when training the discriminators.
fake_monet_buffer, fake_photo_buffer = ReplayBuffer(), ReplayBuffer()

def show_generated_images(real_p, fake_m, real_m, fake_p):
    """Plot a 2x2 grid: real photo, generated Monet, real Monet, generated photo.

    Only element 0 of each argument is drawn. It is passed through
    `denormalize` and permuted with (1, 2, 0) before display.

    Args:
        real_p: Batch of real photos.
        fake_m: Batch of Monet-style images generated from `real_p`.
        real_m: Batch of real Monet paintings.
        fake_p: Batch of photo-style images generated from `real_m`.
    """
    panels = (
        ('Real Photo', real_p),
        ('Generated Monet', fake_m),
        ('Real Monet', real_m),
        ('Generated Photo', fake_p),
    )

    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    for ax, (title, image_batch) in zip(axes.flat, panels):
        first_image = denormalize(image_batch[0])
        ax.imshow(first_image.permute(1, 2, 0).cpu().detach().numpy())
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    plt.show()

print(f"Starting training from epoch {start_epoch} to {N_EPOCHS}...")

# Main CycleGAN loop. Each epoch: train on every batch, step the LR schedulers,
# visualize the fixed validation sample, then save the weights of all four
# networks to CHECKPOINT_DIR.
for epoch in range(start_epoch, N_EPOCHS):
    # --- Training ---
    # Switch back to train mode every epoch (visualization below uses eval mode).
    for network in (G_Monet, G_Photo, D_Monet, D_Photo):
        network.train()

    epoch_loss_G = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{N_EPOCHS}"):
        real_monet = batch['monet'].to(DEVICE)
        real_photo = batch['photo'].to(DEVICE)

        # ------------------
        #  Train Generators
        # ------------------
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image already in its target domain
        # should return it unchanged.
        loss_id_A = criterion_identity(G_Monet(real_monet), real_monet)
        loss_id_B = criterion_identity(G_Photo(real_photo), real_photo)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss: generators try to make the discriminators output "real" (ones).
        # Each discriminator output is computed once and reused for the target
        # shape, instead of running a second forward pass just to size the target.
        fake_monet = G_Monet(real_photo)
        pred_fake_monet = D_Monet(fake_monet)
        loss_GAN_AB = criterion_GAN(pred_fake_monet, torch.ones_like(pred_fake_monet))

        fake_photo = G_Photo(real_monet)
        pred_fake_photo = D_Photo(fake_photo)
        loss_GAN_BA = criterion_GAN(pred_fake_photo, torch.ones_like(pred_fake_photo))

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: translating to the other domain and back should recover the input.
        rec_photo = G_Photo(fake_monet)
        loss_cycle_A = criterion_cycle(rec_photo, real_photo)

        rec_monet = G_Monet(fake_photo)
        loss_cycle_B = criterion_cycle(rec_monet, real_monet)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss (cycle weighted x10, identity x5)
        loss_G = loss_GAN + (10.0 * loss_cycle) + (5.0 * loss_identity)
        loss_G.backward()
        optimizer_G.step()

        epoch_loss_G += loss_G.item()

        # -----------------------
        #  Train Discriminator A
        # -----------------------
        optimizer_D_Monet.zero_grad()
        pred_real = D_Monet(real_monet)
        loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
        # The fake comes from the replay buffer and is detached so no gradient
        # flows back into the generator during the discriminator update.
        fake_monet_ = fake_monet_buffer.push_and_pop(fake_monet)
        pred_fake = D_Monet(fake_monet_.detach())
        loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
        loss_D_Monet = (loss_real + loss_fake) / 2
        loss_D_Monet.backward()
        optimizer_D_Monet.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------
        optimizer_D_Photo.zero_grad()
        pred_real = D_Photo(real_photo)
        loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
        fake_photo_ = fake_photo_buffer.push_and_pop(fake_photo)
        pred_fake = D_Photo(fake_photo_.detach())
        loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
        loss_D_Photo = (loss_real + loss_fake) / 2
        loss_D_Photo.backward()
        optimizer_D_Photo.step()

    # --- Update Learning Rate ---
    lr_scheduler_G.step()
    lr_scheduler_D_Monet.step()
    lr_scheduler_D_Photo.step()

    print(f"Epoch {epoch+1} finished. Avg Generator Loss: {epoch_loss_G / len(train_loader):.4f}")

    # --- Validation / Visualization ---
    if (epoch + 1) % 1 == 0: # Visualize every epoch
        G_Monet.eval()
        G_Photo.eval()
        with torch.no_grad():
            fake_monet_vis = G_Monet(fixed_photo)
            fake_photo_vis = G_Photo(fixed_monet)
            show_generated_images(fixed_photo, fake_monet_vis, fixed_monet, fake_photo_vis)

    # --- Save Checkpoints ---
    # Save EVERY epoch now
    # Only the network weights are written; optimizer state is not checkpointed.
    torch.save(G_Monet.state_dict(), os.path.join(CHECKPOINT_DIR, f'G_Monet_epoch_{epoch+1}.pth'))
    torch.save(G_Photo.state_dict(), os.path.join(CHECKPOINT_DIR, f'G_Photo_epoch_{epoch+1}.pth'))
    torch.save(D_Monet.state_dict(), os.path.join(CHECKPOINT_DIR, f'D_Monet_epoch_{epoch+1}.pth'))
    torch.save(D_Photo.state_dict(), os.path.join(CHECKPOINT_DIR, f'D_Photo_epoch_{epoch+1}.pth'))
    print(f"Saved checkpoint for epoch {epoch+1} to {CHECKPOINT_DIR}")

## 5. Testing

Finally, let's run the generator on the test set and see some results.

In [None]:
# Inference only: put the photo -> Monet generator in eval mode.
G_Monet.eval()

# Show first 5 test images
# islice stops after five batches, or earlier if the test loader has fewer.
for batch in itertools.islice(test_loader, 5):
    real_photo = batch['photo'].to(DEVICE)
    with torch.no_grad():
        fake_monet = G_Monet(real_photo)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (real_photo, "Real Photo (Test)"),
        (fake_monet, "Generated Monet (Test)"),
    )
    for ax, (image_batch, title) in zip(axes, panels):
        # Draw only the first image of the batch.
        ax.imshow(denormalize(image_batch[0]).permute(1, 2, 0).cpu().numpy())
        ax.set_title(title)
        ax.axis('off')
    plt.show()