In [1]:
from torchvision import transforms

# Resize to 256x256 and normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # Converts to [0,1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Now in [-1, 1]
])

In [2]:
from torch.utils.data import Dataset
from PIL import Image
import os

class ImageDataset(Dataset):
    def __init__(self, root_A, root_B, transform=None):
        self.transform = transform
        self.files_A = sorted([os.path.join(root_A, f) for f in os.listdir(root_A) if f.endswith(".jpg")])
        self.files_B = sorted([os.path.join(root_B, f) for f in os.listdir(root_B) if f.endswith(".jpg")])

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))

    def __getitem__(self, idx):
        img_A = Image.open(self.files_A[idx]).convert("RGB")
        img_B = Image.open(self.files_B[idx]).convert("RGB")

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}


In [3]:
from torch.utils.data import DataLoader

dataset = ImageDataset("/kaggle/input/monet2photo/trainA", "/kaggle/input/monet2photo/trainB", transform=transform)


In [4]:
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [5]:
for i, batch in enumerate(dataloader):
    print(batch["A"].shape, batch["B"].shape)
    break

torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 256, 256])


generator

In [6]:
# models/generator.py
import torch.nn as nn

class ResnetBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3),
            nn.InstanceNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3),
            nn.InstanceNorm2d(dim),
        )

    def forward(self, x):
        return x + self.block(x)

class ResnetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_blocks=9):
        super().__init__()
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        # Downsample
        in_features = 64
        for _ in range(2):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # ResBlocks
        for _ in range(n_blocks):
            model += [ResnetBlock(in_features)]

        # Upsample
        for _ in range(2):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, out_channels, 7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


discriminator

In [7]:
# models/discriminator.py
import torch.nn as nn

class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        def block(in_feat, out_feat, norm=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if norm:
                layers.append(nn.InstanceNorm2d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(in_channels, 64, norm=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, x):
        return self.model(x)


loss

In [8]:
adv_criterion = nn.MSELoss()  # For LSGAN (better stability)

Cycle Consistency Loss

In [9]:
cycle_criterion = nn.L1Loss()

identity loss

In [10]:
identity_criterion = nn.L1Loss()

training loop

In [11]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_epochs = 10

In [12]:
G_A2B = ResnetGenerator().to(device)
G_B2A = ResnetGenerator().to(device)
D_A = PatchDiscriminator().to(device)
D_B = PatchDiscriminator().to(device)


In [13]:
import torch.optim as optim

# Combine parameters of both generators
optimizer_G = optim.Adam(
    list(G_A2B.parameters()) + list(G_B2A.parameters()), 
    lr=0.0002, 
    betas=(0.5, 0.999)
)

# Optimizers for discriminators
optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))


In [14]:
lambda_identity = 5.0
lambda_cycle = 10.0

In [19]:
import matplotlib.pyplot as plt

# Store loss history
G_losses = []
D_A_losses = []
D_B_losses = []
epoch_times = []

In [None]:
import time

for epoch in range(num_epochs):
    epoch_start_time = time.time()  # Track start time of epoch
    total_loss_G = 0
    total_loss_D_A = 0
    total_loss_D_B = 0
    total_batches = len(dataloader)

    for batch_idx, batch in enumerate(dataloader):
        batch_start_time = time.time()  # Track start time of batch

        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)

        # === Train Generators ===
        optimizer_G.zero_grad()

        # Identity loss
        same_B = G_A2B(real_B)
        same_A = G_B2A(real_A)
        loss_id_A = identity_criterion(same_A, real_A) * lambda_identity
        loss_id_B = identity_criterion(same_B, real_B) * lambda_identity

        # GAN loss
        fake_B = G_A2B(real_A)
        pred_fake_B = D_B(fake_B)
        loss_GAN_A2B = adv_criterion(pred_fake_B, torch.ones_like(pred_fake_B))

        fake_A = G_B2A(real_B)
        pred_fake_A = D_A(fake_A)
        loss_GAN_B2A = adv_criterion(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle loss
        rec_A = G_B2A(fake_B)
        rec_B = G_A2B(fake_A)
        loss_cycle_A = cycle_criterion(rec_A, real_A) * lambda_cycle
        loss_cycle_B = cycle_criterion(rec_B, real_B) * lambda_cycle

        # Total generator loss
        loss_G = loss_id_A + loss_id_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_A + loss_cycle_B
        loss_G.backward()
        optimizer_G.step()

        # === Train Discriminators ===
        optimizer_D_A.zero_grad()
        pred_real_A = D_A(real_A)
        pred_fake_A = D_A(fake_A.detach())
        loss_D_real_A = adv_criterion(pred_real_A, torch.ones_like(pred_real_A))
        loss_D_fake_A = adv_criterion(pred_fake_A, torch.zeros_like(pred_fake_A))
        loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        optimizer_D_B.zero_grad()
        pred_real_B = D_B(real_B)
        pred_fake_B = D_B(fake_B.detach())
        loss_D_real_B = adv_criterion(pred_real_B, torch.ones_like(pred_real_B))
        loss_D_fake_B = adv_criterion(pred_fake_B, torch.zeros_like(pred_fake_B))
        loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        # Accumulate losses
        total_loss_G += loss_G.item()
        total_loss_D_A += loss_D_A.item()
        total_loss_D_B += loss_D_B.item()

        # Track time per batch
        batch_time = time.time() - batch_start_time
        remaining_batches = total_batches - (batch_idx + 1)
        estimated_time_left = batch_time * remaining_batches
        print(f"Batch {batch_idx+1}/{total_batches} completed in {batch_time:.2f} seconds. Estimated time left: {estimated_time_left/60:.2f} minutes.", end='\r')

    # After all batches in an epoch, print the total results
    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time
    G_losses.append(total_loss_G / total_batches)
    D_A_losses.append(total_loss_D_A / total_batches)
    D_B_losses.append(total_loss_D_B / total_batches)
    epoch_times.append(epoch_duration / 60)  # Save time in minutes

    print(f"\nEpoch [{epoch+1}/{num_epochs}] completed in {epoch_duration/60:.2f} minutes.")
    print(f"  Generator Loss: {total_loss_G/total_batches:.4f}")
    print(f"  Discriminator A Loss: {total_loss_D_A/total_batches:.4f}")
    print(f"  Discriminator B Loss: {total_loss_D_B/total_batches:.4f}")


Batch 1072/1072 completed in 0.31 seconds. Estimated time left: 0.00 minutes.
Epoch [1/10] completed in 5.71 minutes.
  Generator Loss: 6.6729
  Discriminator A Loss: 0.1203
  Discriminator B Loss: 0.1759
Batch 1072/1072 completed in 0.31 seconds. Estimated time left: 0.00 minutes.
Epoch [2/10] completed in 5.70 minutes.
  Generator Loss: 6.5882
  Discriminator A Loss: 0.1115
  Discriminator B Loss: 0.1728
Batch 1072/1072 completed in 0.31 seconds. Estimated time left: 0.00 minutes.
Epoch [3/10] completed in 5.69 minutes.
  Generator Loss: 6.5082
  Discriminator A Loss: 0.1058
  Discriminator B Loss: 0.1726
Batch 1072/1072 completed in 0.31 seconds. Estimated time left: 0.00 minutes.
Epoch [4/10] completed in 5.72 minutes.
  Generator Loss: 6.4351
  Discriminator A Loss: 0.0991
  Discriminator B Loss: 0.1721
Batch 857/1072 completed in 0.31 seconds. Estimated time left: 1.12 minutes.

save models

In [17]:
torch.save(G_A2B.state_dict(), "G_A2B.pth")
torch.save(G_B2A.state_dict(), "G_B2A.pth")
torch.save(D_A.state_dict(), "D_A.pth")
torch.save(D_B.state_dict(), "D_B.pth")

plotting

In [None]:
epochs = range(1, num_epochs + 1)

plt.figure(figsize=(12, 6))

# Plot Generator loss
plt.plot(epochs, G_losses, label='Generator Loss', marker='o')

# Plot Discriminator losses
plt.plot(epochs, D_A_losses, label='Discriminator A Loss', marker='x')
plt.plot(epochs, D_B_losses, label='Discriminator B Loss', marker='s')

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Losses Over Epochs')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
epochs = range(1, num_epochs + 1)

plt.figure(figsize=(12, 6))
plt.plot(epochs, epoch_times, label='Epoch Duration (min)', marker='D', color='purple')
plt.xlabel('Epoch')
plt.ylabel('Time (minutes)')
plt.title('Running Time per Epoch')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

evaluation

In [16]:
with torch.no_grad():
    for batch in test_dataloader:
        real_A = batch["A"].to(device)
        fake_B = G_A2B(real_A)
        # Save or plot fake_B

NameError: name 'test_dataloader' is not defined

save generators and so

In [None]:
# Save the models after training
torch.save(G_A2B.state_dict(), '/kaggle/working/G_A2B.pth')  # Save Generator A to B
torch.save(G_B2A.state_dict(), '/kaggle/working/G_B2A.pth')  # Save Generator B to A
torch.save(D_A.state_dict(), '/kaggle/working/D_A.pth')      # Save Discriminator A
torch.save(D_B.state_dict(), '/kaggle/working/D_B.pth')      # Save Discriminator B
