In [32]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
from torch import nn, optim
from torchvision.utils import save_image


In [22]:
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Two 3x3 conv layers wrapped in an identity skip connection.

    Layout of ``self.block``: reflect-pad(1) -> conv3x3 -> InstanceNorm -> ReLU ->
    reflect-pad(1) -> conv3x3 -> InstanceNorm. ``forward`` returns ``x + block(x)``,
    so the channel count ``dim`` and the spatial size are preserved.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        layers = []
        for stage in range(2):
            # Reflection padding keeps H and W unchanged through the unpadded 3x3 conv.
            layers.extend([
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, 3),
                nn.InstanceNorm2d(dim),
            ])
            if stage == 0:
                # Only the first conv is followed by an activation.
                layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout of ``self.model``:
      reflect-pad(3) + 7x7 conv (F channels) + InstanceNorm + ReLU,
      two 3x3 stride-2 convs (F -> 2F -> 4F), each + InstanceNorm + ReLU,
      ``n_residuals`` ResidualBlock(4F),
      two 3x3 stride-2 transposed convs (4F -> 2F -> F), each + InstanceNorm + ReLU,
      reflect-pad(3) + 7x7 conv (F -> out_channels) + Tanh.

    When H and W are divisible by 4 the output has the same spatial size as the input.
    Output values lie in [-1, 1] because of the final Tanh.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output image.
        n_residuals: number of residual blocks at the bottleneck.
        base_features: width F of the first conv. The default of 64 reproduces the
            original layer shapes and ordering, so previously saved state_dicts still load.
    """

    def __init__(self, in_channels=3, out_channels=3, n_residuals=9, base_features=64):
        super().__init__()
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, base_features, kernel_size=7),
            nn.InstanceNorm2d(base_features),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage doubles the channels and halves H and W.
        in_features = base_features
        for _ in range(2):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks keep the bottleneck width (4 * base_features).
        model += [ResidualBlock(in_features) for _ in range(n_residuals)]

        # Upsampling: each stage halves the channels and doubles H and W
        # (output_padding=1 makes the transposed conv output exactly 2x its input).
        for _ in range(2):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer. in_features is back to base_features here; it used to be a
        # hard-coded 64 that silently had to match the first conv.
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, out_channels, kernel_size=7),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


In [33]:
class CycleGANDataset(Dataset):
    """Unpaired, single-domain image dataset for CycleGAN training.

    Lists the image files in ``os.path.join(root_dir, domain)`` once at construction
    (sorted, so the order is deterministic). Each item is one RGB image passed
    through ``transform``.

    Args:
        root_dir: path to the dataset (from kagglehub).
        domain: sub-folder to read: 'trainA' (summer), 'trainB' (winter), 'testA', 'testB'.
        transform: callable applied to each PIL image. If omitted, the default resizes
            to 256x256, converts to a tensor and normalizes with mean 0.5 / std 0.5.
    """

    # File extensions treated as images. Anything else in the folder (hidden files
    # such as .DS_Store or ._foo.jpg, metadata, sub-directories) is skipped here;
    # otherwise PIL would fail on it later inside __getitem__.
    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")

    def __init__(self, root_dir, domain, transform=None):
        self.dir = os.path.join(root_dir, domain)
        self.files = sorted(
            name
            for name in os.listdir(self.dir)
            if not name.startswith(".")
            and name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(self.dir, name))
        )
        self.transform = transform if transform else transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img_path = os.path.join(self.dir, self.files[index])
        # The context manager closes the file handle deterministically.
        # convert() returns a new, fully loaded image, so it remains valid after close.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        return self.transform(image)

In [23]:
class Discriminator(nn.Module):
    """Fully-convolutional discriminator that returns a spatial grid of real/fake scores.

    Four 4x4 stride-2 conv stages (64, 128, 256, 512 channels) each end in
    LeakyReLU(0.2); every stage except the first also has InstanceNorm. A final
    4x4 stride-1 conv maps to one channel. There is no output activation, so the raw
    scores feed the MSE loss directly. A 256x256 input yields a 1x15x15 score map.

    Args:
        in_channels: channels of the input image.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, 2, 1))  # halves H and W
            if stage > 0:
                # The first stage is deliberately left un-normalized.
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(widths[-1], 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores


In [24]:
!pip install -q kaggle

In [25]:
from google.colab import files

# Upload kaggle.json (the Kaggle API token).
# Colab stores a re-upload under a de-duplicated name such as "kaggle (1).json".
# A plain `cp kaggle.json` afterwards would therefore copy a stale token.
# files.upload() returns {original filename: file bytes}, so write those bytes
# to ./kaggle.json explicitly.
# Assigning the result (instead of letting the cell display it) also keeps the
# token out of the notebook output.
uploaded = files.upload()
if "kaggle.json" not in uploaded:
    raise FileNotFoundError(f"Expected an upload named kaggle.json, got: {sorted(uploaded)}")
with open("kaggle.json", "wb") as token_file:
    token_file.write(uploaded["kaggle.json"])

Saving kaggle.json to kaggle (1).json


In [26]:
import shutil
from pathlib import Path

# Copy the uploaded API token to ~/.kaggle/kaggle.json.
# This uses pure Python instead of !mkdir / !cp / !chmod: a missing kaggle.json
# now raises here, instead of printing a shell error while Run All carries on.
kaggle_dir = Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)
kaggle_token = kaggle_dir / "kaggle.json"
shutil.copyfile("kaggle.json", kaggle_token)
kaggle_token.chmod(0o600)  # Kaggle refuses if permissions are too open


In [27]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("balraj98/summer2winter-yosemite")

print("Path to dataset files:", path)

# Fail fast if the layout differs from what the dataset cells expect.
# They read <path>/trainA (summer) and <path>/trainB (winter).
for _domain in ("trainA", "trainB"):
    assert os.path.isdir(os.path.join(path, _domain)), f"missing '{_domain}' folder under {path}"

Path to dataset files: /kaggle/input/summer2winter-yosemite


In [28]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Reproducibility: seed Python's and torch's RNGs. This affects weight init and
# DataLoader shuffling. (It does not force deterministic cuDNN kernels.)
import random

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

BATCH_SIZE = 1
IMG_SIZE = 256
LR = 2e-4          # Adam learning rate for generators and discriminators
EPOCHS = 100
LAMBDA_CYCLE = 10  # weight of the cycle-consistency (L1) loss
LAMBDA_ID = 5      # weight of the identity (L1) loss

from torchvision import transforms
# Resize to IMG_SIZE x IMG_SIZE, convert to a tensor in [0, 1], then map to [-1, 1]
# (matches the generator's Tanh output range).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [34]:
# One unpaired dataset per domain: trainA holds summer images, trainB holds winter images.
summer_dataset, winter_dataset = (
    CycleGANDataset(root_dir=path, domain=domain, transform=transform)
    for domain in ("trainA", "trainB")
)

In [35]:
# Independently shuffled loaders, one per domain.
summer_loader, winter_loader = (
    DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for dataset in (summer_dataset, winter_dataset)
)

In [None]:
G_s2w = Generator().to(DEVICE)  # Summer → Winter
G_w2s = Generator().to(DEVICE)  # Winter → Summer
D_s = Discriminator().to(DEVICE)  # scores summer images as real/fake
D_w = Discriminator().to(DEVICE)  # scores winter images as real/fake

# Losses: MSE on discriminator scores (least-squares GAN objective); L1 for cycle and identity
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# Optimizers: one Adam over both generators, one over both discriminators
opt_G = optim.Adam(list(G_s2w.parameters()) + list(G_w2s.parameters()), lr=LR, betas=(0.5, 0.999))
opt_D = optim.Adam(list(D_s.parameters()) + list(D_w.parameters()), lr=LR, betas=(0.5, 0.999))

# Labels (not used below: targets are built with ones_like / zeros_like)
real_label = 1.0
fake_label = 0.0

# save_image / torch.save do not create missing folders. Without this, the first
# checkpoint (epoch 10) fails with FileNotFoundError.
os.makedirs("results", exist_ok=True)
os.makedirs("weights", exist_ok=True)


def _discriminator_loss(disc, real_img, fake_img):
    """Least-squares discriminator loss, averaged over its real and fake terms.

    Pushes ``disc(real_img)`` towards 1 and ``disc(fake_img)`` towards 0.
    ``fake_img`` is detached, so this loss never backpropagates into a generator.
    """
    pred_real = disc(real_img)
    pred_fake = disc(fake_img.detach())
    return (
        criterion_GAN(pred_real, torch.ones_like(pred_real)) +
        criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
    ) * 0.5


# Training Loop
for epoch in range(EPOCHS):
    # zip() stops at the shorter loader, so each epoch sees
    # min(len(summer_loader), len(winter_loader)) batch pairs.
    for i, (summer_img, winter_img) in enumerate(zip(summer_loader, winter_loader)):
        summer_img = summer_img.to(DEVICE)
        winter_img = winter_img.to(DEVICE)

        # ------------------
        #  Train Generators
        # ------------------
        # Freeze the discriminators: loss_G.backward() then skips computing D gradients
        # that the generator step never uses. Gradients still flow through D to the fakes.
        D_s.requires_grad_(False)
        D_w.requires_grad_(False)

        fake_winter = G_s2w(summer_img)
        fake_summer = G_w2s(winter_img)

        D_w_fake = D_w(fake_winter)
        D_s_fake = D_s(fake_summer)

        # GAN losses: generators want their fakes scored as real (target 1)
        loss_GAN_s2w = criterion_GAN(D_w_fake, torch.ones_like(D_w_fake))
        loss_GAN_w2s = criterion_GAN(D_s_fake, torch.ones_like(D_s_fake))

        # Identity losses: a generator fed an image already in its target domain should leave it unchanged
        id_summer = G_w2s(summer_img)
        id_winter = G_s2w(winter_img)
        loss_identity = criterion_identity(id_summer, summer_img) + criterion_identity(id_winter, winter_img)

        # Cycle losses: translating there and back should reconstruct the input
        cycle_summer = G_w2s(fake_winter)
        cycle_winter = G_s2w(fake_summer)
        loss_cycle = criterion_cycle(cycle_summer, summer_img) + criterion_cycle(cycle_winter, winter_img)

        # Total Generator loss
        loss_G = (
            loss_GAN_s2w + loss_GAN_w2s
            + LAMBDA_CYCLE * loss_cycle
            + LAMBDA_ID * loss_identity
        )

        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

        D_s.requires_grad_(True)
        D_w.requires_grad_(True)

        # --------------------------------
        #  Train Discriminators W and S
        # --------------------------------
        loss_D_w = _discriminator_loss(D_w, winter_img, fake_winter)
        loss_D_s = _discriminator_loss(D_s, summer_img, fake_summer)

        # opt_D owns both discriminators' parameters. A single zero_grad/backward/step
        # updates each of them exactly once. With two separate opt_D.step() calls,
        # the second step would move the first discriminator again via Adam momentum
        # whenever zero_grad() leaves zero-filled grads (set_to_none=False, the default
        # before PyTorch 2.0).
        opt_D.zero_grad()
        (loss_D_w + loss_D_s).backward()
        opt_D.step()

        if i % 200 == 0:
            print(
                f"[Epoch {epoch}/{EPOCHS}] [Batch {i}] "
                f"[D_S: {loss_D_s.item():.4f}, D_W: {loss_D_w.item():.4f}] "
                f"[G: {loss_G.item():.4f}]"
            )

    # Save samples (the last batch of the epoch, mapped from [-1, 1] back to [0, 1])
    # and the generator weights
    if (epoch + 1) % 10 == 0:
        save_image(fake_winter * 0.5 + 0.5, f"results/fake_winter_epoch{epoch+1}.png")
        save_image(fake_summer * 0.5 + 0.5, f"results/fake_summer_epoch{epoch+1}.png")

        torch.save(G_s2w.state_dict(), f"weights/G_s2w_epoch{epoch+1}.pth")
        torch.save(G_w2s.state_dict(), f"weights/G_w2s_epoch{epoch+1}.pth")

print("✅ Training complete.")