# Install & Import Libraries

In [None]:
!pip install torchvision

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import os
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import numpy as np

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
device

In [None]:
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/edges2shoes.tar.gz
!tar -xvf edges2shoes.tar.gz

# Dataset Class

In [3]:
class Edges2ShoesDataset(Dataset):
    """Paired image dataset for pix2pix-style training.

    Every image file in ``root_dir`` stores the input on its left half and the
    target on its right half; ``__getitem__`` splits it down the middle.

    Args:
        root_dir: Directory containing the side-by-side images.
        transform: Optional callable applied separately to each half. It is
            called twice per sample, so any random augmentation inside it
            would be drawn independently for input and target.
    """

    # Lower-case file suffixes treated as images; other entries are ignored.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # os.listdir order is arbitrary and platform-dependent: sort so an index
        # always maps to the same file, and skip non-image entries (hidden
        # files, sub-directories) that would otherwise crash __getitem__.
        self.images = sorted(
            name
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(input_image, target_image)`` for sample ``idx``."""
        img_path = os.path.join(self.root_dir, self.images[idx])
        # The context manager closes the file handle; convert() returns a fully
        # loaded copy that remains valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        w, h = image.size
        # Left half -> input, right half -> target.
        input_image = image.crop((0, 0, w // 2, h))
        target_image = image.crop((w // 2, 0, w, h))

        if self.transform:
            input_image = self.transform(input_image)
            target_image = self.transform(target_image)

        return input_image, target_image


# Resize each half to 256x256, convert to a [0, 1] tensor, then shift/scale to
# [-1, 1]. The single-value mean/std tuples are broadcast by torchvision across
# all three RGB channels.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

dataset = Edges2ShoesDataset(root_dir="/content/edges2shoes/train", transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

In [None]:
import os

# Peek at the extracted archive to confirm which splits (e.g. train/val) exist.
dataset_root_contents = os.listdir(path="/content/edges2shoes")
dataset_root_contents

# U-Net Generator

In [5]:
def down_block(in_channels, out_channels, normalize=True):
    """Build one encoder stage that halves the spatial resolution.

    Layout: Conv2d(kernel 4, stride 2, padding 1, no bias) -> optional
    BatchNorm2d -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        normalize: Insert BatchNorm2d after the convolution when True.

    Returns:
        nn.Sequential with the layers in the order listed above.
    """
    conv = nn.Conv2d(
        in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
    )
    norm = [nn.BatchNorm2d(out_channels)] if normalize else []
    return nn.Sequential(conv, *norm, nn.LeakyReLU(negative_slope=0.2))

def up_block(in_channels, out_channels, dropout=False):
    """Build one decoder stage that doubles the spatial resolution.

    Layout: ConvTranspose2d(kernel 4, stride 2, padding 1, no bias) ->
    BatchNorm2d -> ReLU -> optional Dropout(0.5).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
        dropout: Append Dropout(p=0.5) at the end when True.

    Returns:
        nn.Sequential with the layers in the order listed above.
    """
    tail = [nn.Dropout(p=0.5)] if dropout else []
    return nn.Sequential(
        nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        *tail,
    )


class UNetGenerator(nn.Module):
    """Four-level U-Net mapping a 3-channel image to a 3-channel image.

    The encoder halves the resolution four times (64 -> 128 -> 256 -> 512
    channels). The decoder upsamples three times, concatenating the mirrored
    encoder activation after each step. A final transposed convolution plus
    Tanh restores full resolution with outputs in [-1, 1]. Input height and
    width must be divisible by 16 for the skip concatenations to line up.
    """

    def __init__(self):
        super().__init__()

        # Encoder (module creation order kept so initialisation is unchanged).
        self.down1 = down_block(3, 64, normalize=False)
        self.down2 = down_block(64, 128)
        self.down3 = down_block(128, 256)
        self.down4 = down_block(256, 512)

        # Decoder: in_channels of up2/up3 include the concatenated skip channels.
        self.up1 = up_block(512, 256)
        self.up2 = up_block(256 + 256, 128)
        self.up3 = up_block(128 + 128, 64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(64 + 64, 3, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return the generated image for input batch ``x``."""
        # Encoder: keep each activation for the skip connections.
        e1 = self.down1(x)
        e2 = self.down2(e1)
        e3 = self.down3(e2)
        h = self.down4(e3)

        # Decoder: upsample, then append the matching encoder feature map.
        for up, skip in ((self.up1, e3), (self.up2, e2), (self.up3, e1)):
            h = torch.cat([up(h), skip], dim=1)

        return self.final(h)

# PatchGAN Discriminator

In [6]:
class PatchDiscriminator(nn.Module):
    """Conditional PatchGAN discriminator.

    Scores an (input, candidate) image pair by concatenating them along the
    channel axis (3 + 3 = 6 channels). It returns a single-channel map of raw
    logits, one per receptive-field patch; no sigmoid is applied.
    """

    def __init__(self):
        super().__init__()

        # First stage has no normalisation.
        layers = [nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1),
                  nn.LeakyReLU(negative_slope=0.2)]
        # Two strided Conv -> BatchNorm -> LeakyReLU stages.
        for c_in, c_out in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(negative_slope=0.2),
            ]
        # Stride-1 projection to one logit per patch.
        layers.append(nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1))

        # Same Sequential indices as before, so saved state_dict keys still match.
        self.model = nn.Sequential(*layers)

    def forward(self, input_img, target_img):
        """Return patch logits for the channel-wise concatenated image pair."""
        pair = torch.cat([input_img, target_img], dim=1)
        return self.model(pair)

# Initialize Models

In [7]:
# Generator is built before the discriminator (same RNG consumption order).
G, D = UNetGenerator().to(device), PatchDiscriminator().to(device)

# D outputs raw logits, so the adversarial loss uses the logits form of BCE.
criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

# Identical Adam settings for both networks.
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = optim.Adam(G.parameters(), **adam_settings)
optimizer_D = optim.Adam(D.parameters(), **adam_settings)

# Training the model

In [None]:
epochs = 30
# Weight of the L1 reconstruction term relative to the adversarial term.
lambda_L1 = 100

# Per-iteration loss history (consumed by the loss-curve cell).
G_losses = []
D_losses = []

for epoch in range(epochs):
    # Inference cells call G.eval(); switch both networks back to training mode
    # so re-running this cell trains with batch statistics in BatchNorm instead
    # of frozen running statistics.
    G.train()
    D.train()
    epoch_start = len(G_losses)

    for input_img, target_img in dataloader:
        input_img = input_img.to(device)
        target_img = target_img.to(device)

        # ---- Discriminator step: real pairs -> 1, generated pairs -> 0 ----
        optimizer_D.zero_grad()

        fake_img = G(input_img)

        real_pred = D(input_img, target_img)
        # detach() stops this step from back-propagating into the generator.
        fake_pred = D(input_img, fake_img.detach())

        real_loss = criterion_GAN(real_pred, torch.ones_like(real_pred))
        fake_loss = criterion_GAN(fake_pred, torch.zeros_like(fake_pred))

        D_loss = (real_loss + fake_loss) * 0.5
        D_loss.backward()
        optimizer_D.step()

        # ---- Generator step: fool D while staying close to the target ----
        optimizer_G.zero_grad()

        # No detach: gradients must flow through D back into G. The gradients
        # this also leaves on D's parameters are cleared by optimizer_D.zero_grad()
        # at the start of the next iteration.
        fake_pred = D(input_img, fake_img)

        GAN_loss = criterion_GAN(fake_pred, torch.ones_like(fake_pred))
        L1_loss = criterion_L1(fake_img, target_img) * lambda_L1

        G_loss = GAN_loss + L1_loss
        G_loss.backward()
        optimizer_G.step()

        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())

    epoch_G = G_losses[epoch_start:]
    epoch_D = D_losses[epoch_start:]
    if not epoch_G:
        # Otherwise the report below would reference losses that were never computed.
        raise RuntimeError("dataloader produced no batches; check the dataset directory")

    # Report the epoch mean rather than the (noisy) final batch.
    print(
        f"Epoch [{epoch+1}/{epochs}]  "
        f"D_loss (avg): {sum(epoch_D) / len(epoch_D):.4f}  "
        f"G_loss (avg): {sum(epoch_G) / len(epoch_G):.4f}"
    )

# Visualization

In [None]:
def show_images(input_img, fake_img, target_img, index=0):
    """Plot one input / generated / target triple side by side.

    Args:
        input_img: Batched image tensor, channels-first per sample.
        fake_img: Batched generator output, same layout.
        target_img: Batched ground-truth tensor, same layout.
        index: Which sample of the batch to display (default 0, the first).

    Values are expected in [-1, 1]: the dataset's Normalize(0.5, 0.5) and the
    generator's Tanh both produce that range. Tensors may live on any device
    and may require grad.
    """
    panels = (("Input", input_img), ("Generated", fake_img), ("Target", target_img))

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, batch) in zip(axs, panels):
        img = batch[index].detach().cpu().permute(1, 2, 0).numpy()
        # Map [-1, 1] back to [0, 1]. Clip explicitly because rounding can push
        # values slightly out of range, and imshow warns and clips float RGB
        # data outside [0, 1].
        ax.imshow(np.clip((img + 1) / 2, 0.0, 1.0))
        ax.set_title(title)
        ax.axis("off")

    plt.show()


# Show one training batch's first pair next to the generator's prediction.
input_img, target_img = next(iter(dataloader))
input_img = input_img.to(device)

# Inference only: eval() stops BatchNorm from updating its running statistics on
# this batch (they are saved with the model later), and no_grad() avoids building
# an autograd graph. The previous mode is restored afterwards.
was_training = G.training
G.eval()
with torch.no_grad():
    fake_img = G(input_img)
G.train(was_training)

show_images(input_img, fake_img, target_img)

In [None]:
import torchvision.utils as vutils

# Run the generator on one batch in inference mode (eval + no autograd).
input_img, target_img = next(iter(dataloader))
input_img = input_img.to(device)

G.eval()
with torch.no_grad():
    fake_img = G(input_img)

# Write the whole batch as one grid image, mapping Tanh output [-1, 1] to [0, 1].
os.makedirs("saved_images", exist_ok=True)
rescaled = (fake_img + 1) / 2
vutils.save_image(rescaled, "saved_images/generated_sample.png")

print("Generated image saved!")

In [None]:
import matplotlib.pyplot as plt

# Plot the per-iteration losses recorded during training.
plt.figure(figsize=(10,5))
plt.plot(G_losses, label="Generator Loss")
plt.plot(D_losses, label="Discriminator Loss")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.title("Training Loss Curve")

# Create the output directory here too: savefig raises FileNotFoundError if this
# cell runs before the cell that creates it.
os.makedirs("saved_images", exist_ok=True)
plt.savefig("saved_images/loss_curve.png")
plt.show()

print("Loss curve saved!")

In [None]:
# Persist only the learned weights (state dicts), generator first.
os.makedirs("saved_models", exist_ok=True)

checkpoints = (
    (G, "saved_models/generator.pth"),
    (D, "saved_models/discriminator.pth"),
)
for net, path in checkpoints:
    torch.save(net.state_dict(), path)

print("Models saved successfully!")

# Markdown (Comparison Section)

# Performance Comparison

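Baseline CNN (trained with a pixel-wise loss only; not trained in this notebook):
- Produces blurry images, because a pixel-wise loss is minimized by averaging over all plausible outputs
- Optimizes only a pixel-wise L1 / MSE loss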

Pix2Pix:
- Uses adversarial + L1 loss
- Produces sharper images
- Better texture and realism

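Conclusion:
Adding an adversarial loss to the L1 loss is expected to give perceptually sharper, more realistic outputs than an L1/MSE-only baseline. This notebook shows Pix2Pix results only qualitatively; it does not train the baseline or compute a quantitative metric.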