In [None]:
import os
import torch
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path

In [None]:
# Mount Google Drive so the dataset archives under MyDrive are reachable from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Inspect the augmented-data folder on Drive; the data_aug_sketch.zip / data_aug_photo.zip archives extracted below live here.
!ls "/content/drive/MyDrive/Applied CV Project/Sketch to Image/datasets/sketchy_dataset/augmented_data/"

In [None]:
# Extract the sketch archive from Drive into ./sketch_dataset/.
# -o overwrites existing files without asking: -q alone does not suppress unzip's
# interactive "replace?" prompt, which would hang the cell on a re-run in the same runtime.
!unzip -q -o '/content/drive/MyDrive/Applied CV Project/Sketch to Image/datasets/sketchy_dataset/augmented_data/data_aug_sketch.zip' -d sketch_dataset/

In [None]:
# Check the extracted layout; the sketch images are expected under sketch_dataset/data_aug_sketch.
!ls sketch_dataset

In [None]:
# Folder of sketch images; presumably the top-level folder inside data_aug_sketch.zip (see the unzip cell above).
sketch_dir = "sketch_dataset/data_aug_sketch"

In [None]:
# Extract the photo archive from Drive into ./photo_dataset/.
# -o overwrites existing files without asking: -q alone does not suppress unzip's
# interactive "replace?" prompt, which would hang the cell on a re-run in the same runtime.
!unzip -q -o '/content/drive/MyDrive/Applied CV Project/Sketch to Image/datasets/sketchy_dataset/augmented_data/data_aug_photo.zip' -d photo_dataset/

In [None]:
# Folder of target photos; presumably the top-level folder inside data_aug_photo.zip (see the unzip cell above).
photo_dir = "photo_dataset/data_aug_photo"

In [None]:
def show_image(tensor_image, title=None):
    """Plot a single normalised image tensor with matplotlib.

    Args:
        tensor_image: a (C, H, W) tensor with C == 1 (sketch) or C == 3 (photo),
            normalised with mean 0.5 / std 0.5 as in the dataset transform; the
            normalisation is undone here. It may live on any device and may
            require grad.
        title: optional title shown above the image.
    """
    # Undo Normalize((0.5,), (0.5,)) and clamp, because values can fall outside
    # [-1, 1] (e.g. an unbounded model output) and imshow wants floats in [0, 1].
    image = (tensor_image.detach().cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
    np_image = image.numpy().transpose(1, 2, 0)
    if np_image.shape[2] == 1:
        # imshow rejects (H, W, 1) arrays: drop the channel axis and render the
        # single channel as grayscale on a fixed [0, 1] scale (no auto-contrast).
        plt.imshow(np_image[:, :, 0], cmap='gray', vmin=0.0, vmax=1.0)
    else:
        plt.imshow(np_image)
    if title:
        plt.title(title)
    plt.axis('off')
    plt.show()

In [None]:
class SketchToImageDataset(Dataset):
    """Paired (sketch, photo) dataset built from two flat image folders.

    Pairs are formed by position after sorting each folder's file names, so the
    i-th sketch is matched with the i-th photo. NOTE(review): this assumes both
    folders use matching names/ordering; confirm for the augmented Sketchy data.

    Args:
        sketch_dir: folder of sketch images (loaded as single-channel 'L').
        real_dir: folder of target photos (loaded as 'RGB').
        transform: optional callable applied to both images of a pair.
        max_images: cap on the number of files taken from each folder
            (None keeps all of them).
    """

    def __init__(self, sketch_dir, real_dir, transform=None, max_images=10000):
        self.sketch_dir = Path(sketch_dir)
        self.real_dir = Path(real_dir)
        self.transform = transform

        self.sketch_filenames = self._list_images(self.sketch_dir)[:max_images]
        self.real_filenames = self._list_images(self.real_dir)[:max_images]

        if len(self.sketch_filenames) != len(self.real_filenames):
            import warnings

            # Unequal counts mean positional pairing is suspect; only the common
            # prefix is used (see __len__).
            warnings.warn(
                f"{len(self.sketch_filenames)} sketches vs "
                f"{len(self.real_filenames)} photos; using the first "
                f"{len(self)} pairs"
            )

    @staticmethod
    def _list_images(directory):
        """Return the sorted names of regular, non-hidden files directly in `directory`."""
        # Skips subdirectories and dot-files such as .ipynb_checkpoints / .DS_Store,
        # which Image.open cannot read.
        return sorted(
            entry.name
            for entry in directory.iterdir()
            if entry.is_file() and not entry.name.startswith('.')
        )

    @staticmethod
    def _load(path, mode):
        """Open the image at `path`, convert it to `mode`, and close the file handle."""
        with Image.open(path) as img:
            # convert() returns an in-memory copy, so it outlives the closed file.
            return img.convert(mode)

    def __len__(self):
        # Only indices that have both a sketch and a photo are valid.
        return min(len(self.sketch_filenames), len(self.real_filenames))

    def __getitem__(self, index):
        sketch_image = self._load(self.sketch_dir / self.sketch_filenames[index], 'L')
        real_image = self._load(self.real_dir / self.real_filenames[index], 'RGB')

        if self.transform:
            sketch_image = self.transform(sketch_image)
            real_image = self.transform(real_image)

        return sketch_image, real_image

In [None]:
# Shared preprocessing for sketches and photos: resize to 256x256, convert to a
# tensor in [0, 1], then map to [-1, 1] via (x - 0.5) / 0.5. The one-element
# mean/std tuples are applied to every channel, so the same pipeline serves the
# 1-channel sketches and the 3-channel photos.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

dataset = SketchToImageDataset(sketch_dir, photo_dir, transform=transform)

# 16 (sketch, photo) pairs per batch, reshuffled every epoch.
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

In [None]:
class Generator(nn.Module):
    """U-Net-style encoder/decoder that maps a sketch to an image.

    Each encoder stage halves the spatial size (4x4 conv, stride 2). Each decoder
    stage doubles it (4x4 transposed conv, stride 2) and is concatenated with the
    encoder feature map of the same resolution (skip connection).

    Args:
        in_channels: channels of the input sketch (1 for grayscale).
        out_channels: channels of the generated image (3 for RGB).
        features: channel widths of the encoder stages, shallow to deep.

    Input height and width must be divisible by 2 ** len(features) so every
    decoder output lines up with its skip connection. The output has the input's
    spatial size and values in [-1, 1] (tanh), matching targets normalised with
    Normalize((0.5,), (0.5,)).

    Raises:
        ValueError: if `features` is empty, or (in forward) if the input size
            breaks the skip-connection alignment.
    """

    def __init__(self, in_channels=1, out_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.encoder = nn.ModuleList()
        current_channels = in_channels
        for feature in features:
            self.encoder.append(
                nn.Sequential(
                    nn.Conv2d(current_channels, feature, kernel_size=4, stride=2, padding=1),
                    nn.BatchNorm2d(feature),
                    nn.LeakyReLU(0.2),
                )
            )
            current_channels = feature

        self.decoder = nn.ModuleList()
        reversed_features = features[::-1]
        for i in range(len(reversed_features) - 1):
            # Stage 0 takes the bottleneck directly; later stages take the previous
            # output concatenated with a skip connection, i.e. twice the channels.
            decoder_in = reversed_features[i] * 2 if i > 0 else reversed_features[i]
            self.decoder.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        decoder_in,
                        reversed_features[i + 1],
                        kernel_size=4,
                        stride=2,
                        padding=1,
                    ),
                    nn.BatchNorm2d(reversed_features[i + 1]),
                    nn.ReLU(),
                )
            )

        # With at least one decoder stage, the last decoder output has been
        # concatenated with the shallowest encoder features (2 * features[0]
        # channels). With a single encoder stage the bottleneck feeds in directly.
        final_in = features[0] * 2 if len(features) > 1 else features[0]
        self.final_transpose = nn.Sequential(
            nn.ConvTranspose2d(final_in, features[0], kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features[0]),
            nn.ReLU(),
        )

        self.final_layer = nn.Conv2d(features[0], out_channels, kernel_size=1)
        # Squash to [-1, 1], the range of the normalised target images. It is
        # parameter-free, so existing state_dict keys are unchanged.
        self.output_activation = nn.Tanh()

    def forward(self, x):
        skip_connections = []
        for layer in self.encoder:
            x = layer(x)
            skip_connections.append(x)

        # The deepest encoder output is the decoder's input, not a skip. Pair each
        # decoder stage with the remaining encoder outputs, deepest first.
        for layer, skip_feature in zip(self.decoder, reversed(skip_connections[:-1])):
            x = layer(x)
            if x.shape[2:] != skip_feature.shape[2:]:
                raise ValueError(
                    f"decoder output {tuple(x.shape[2:])} does not match skip "
                    f"connection {tuple(skip_feature.shape[2:])}; input height and "
                    f"width must be divisible by {2 ** len(self.encoder)}"
                )
            x = torch.cat([x, skip_feature], dim=1)

        x = self.final_transpose(x)
        return self.output_activation(self.final_layer(x))

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator that scores (sketch, image) pairs patch by patch.

    The input is a sketch and an image concatenated along the channel axis
    (1 + 3 = 4 channels by default). Each stage in `features` applies a 4x4,
    stride-2 convolution; all but the first are followed by BatchNorm, and every
    stage ends with LeakyReLU(0.2). A final 4x4 convolution (stride 1, no padding)
    and a Sigmoid produce a map of probabilities in [0, 1], one per patch
    (13x13 for a 256x256 input with the default four stages).

    Args:
        in_channels: channels of the concatenated input.
        features: channel widths of the downsampling stages, any length >= 1.

    Raises:
        ValueError: if `features` is empty.
    """

    def __init__(self, in_channels=4, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # First stage: no normalisation on the raw input.
        layers = [
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Remaining stages, e.g. 64 -> 128 -> 256 -> 512. The order matches the
        # original hand-written Sequential, so default state_dict indices are unchanged.
        for prev_channels, next_channels in zip(features[:-1], features[1:]):
            layers += [
                nn.Conv2d(prev_channels, next_channels, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(next_channels),
                nn.LeakyReLU(0.2),
            ]
        layers += [
            nn.Conv2d(features[-1], 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),  # outputs probabilities for nn.BCELoss
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


In [None]:
# Seed the global torch RNG before building anything. It drives the weight
# initialisation below and, because data_loader was built without its own
# generator, the per-epoch shuffle order as well.
SEED = 42
torch.manual_seed(SEED)

# Initialize models
generator = Generator()          # 1-channel sketch -> 3-channel image
discriminator = Discriminator()  # scores concatenated (sketch, image) pairs: 1 + 3 = 4 channels

# Optimizers: Adam with lr=2e-4, beta1=0.5 (the usual DCGAN/pix2pix settings)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Loss functions
adversarial_loss = nn.BCELoss()     # the discriminator already ends in a Sigmoid
reconstruction_loss = nn.L1Loss()   # pixel-wise L1 between generated and real image

num_epochs = 75

In [None]:
# Training loop (conditional GAN: adversarial loss + weighted L1 reconstruction).
# Train on the GPU when one is available; on CPU this loop is very slow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Module.to() converts parameters in place, so the optimizers created in the
# previous cell keep tracking the same parameter tensors.
generator.to(device)
discriminator.to(device)
generator.train()
discriminator.train()

lambda_recon = 100  # weight of the L1 term relative to the adversarial term

for epoch in range(num_epochs):
    running_loss_D = 0.0
    running_loss_G = 0.0
    for sketch, real_image in data_loader:
        sketch = sketch.to(device)
        real_image = real_image.to(device)

        # Generate once per batch. D trains on a detached copy so its backward pass
        # stops at the generator; G's update below reuses the same forward pass.
        fake_image = generator(sketch)

        # ---- Discriminator: real pairs -> 1, generated pairs -> 0 ----
        optimizer_D.zero_grad()
        real_output = discriminator(torch.cat([sketch, real_image], dim=1))
        loss_real = adversarial_loss(real_output, torch.ones_like(real_output))
        fake_output = discriminator(torch.cat([sketch, fake_image.detach()], dim=1))
        loss_fake = adversarial_loss(fake_output, torch.zeros_like(fake_output))
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        # ---- Generator: fool the updated D and stay close to the target (L1) ----
        optimizer_G.zero_grad()
        fake_output = discriminator(torch.cat([sketch, fake_image], dim=1))
        loss_adv = adversarial_loss(fake_output, torch.ones_like(fake_output))
        loss_rec = reconstruction_loss(fake_image, real_image)
        loss_G = loss_adv + lambda_recon * loss_rec
        loss_G.backward()
        optimizer_G.step()

        running_loss_D += loss_D.item()
        running_loss_G += loss_G.item()

    n_batches = max(len(data_loader), 1)
    print(
        f"epoch {epoch + 1}/{num_epochs}  "
        f"loss_D={running_loss_D / n_batches:.4f}  loss_G={running_loss_G / n_batches:.4f}"
    )

# NOTE(review): nothing is persisted after training; consider torch.save() of the
# generator/discriminator state_dicts to Drive so a Colab disconnect does not lose the run.