In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from torchvision import transforms
from PIL import Image
import os
import torch.nn.functional as F
from tqdm import tqdm
import torchvision
import random
import wandb

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class SketchDataset(Dataset):
    """Paired photo / sketch dataset indexed by a label CSV.

    Column 0 of the CSV is used as the file stem of each sample; every
    remaining column is read as one entry of the sample's label vector.
    For a stem ``s`` the photo is ``<image_root_dir>/s.jpg`` and the sketch
    is ``<sketch_root_dir>/s_segmentation.png``.

    Args:
        csv_file: Path of the label CSV.
        image_root_dir: Directory holding the ``.jpg`` photos.
        sketch_root_dir: Directory holding the ``_segmentation.png`` sketches.
        transform: Optional callable applied to both the photo and the sketch.
    """

    def __init__(self, csv_file, image_root_dir, sketch_root_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.image_root_dir = image_root_dir
        self.sketch_root_dir = sketch_root_dir
        self.transform = transform
        self.num_samples = len(self.data_frame)
        # Kept for backward compatibility: there is one sketch per CSV row.
        self.num_sketches = self.num_samples

    def __len__(self):
        """Return the number of rows in the label CSV."""
        return self.num_samples

    def _label_tensor(self, row_idx):
        """Return the label columns of row ``row_idx`` as a float32 tensor.

        The row slice is an object-dtype pandas Series (the frame mixes a
        string column with numeric ones). Handing that Series straight to
        ``torch.tensor`` makes torch walk it element by element through
        integer ``Series.__getitem__``, which is slow and relies on a
        deprecated positional fallback. Converting to a float32 NumPy array
        first is explicit and fails loudly on non-numeric labels.
        """
        values = self.data_frame.iloc[row_idx, 1:].to_numpy(dtype="float32")
        return torch.tensor(values)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(label, sketch, image, img_name, rand_label)`` where
            ``label`` belongs to row ``idx``, ``rand_label`` is the label
            vector of a uniformly drawn row (possibly the same row),
            ``sketch`` / ``image`` are the RGB images (transformed when a
            transform was given) and ``img_name`` is the photo's path.
        """
        stem = self.data_frame.iloc[idx, 0]
        img_name = os.path.join(self.image_root_dir, stem + '.jpg')
        sketch_name = os.path.join(self.sketch_root_dir, stem + '_segmentation' + '.png')

        image = Image.open(img_name).convert('RGB')
        sketch = Image.open(sketch_name).convert('RGB')

        label = self._label_tensor(idx)

        # Label of an arbitrary row; randint is inclusive on both ends.
        rand_idx = random.randint(0, self.num_samples - 1)
        rand_label = self._label_tensor(rand_idx)

        if self.transform:
            image = self.transform(image)
            sketch = self.transform(sketch)

        return label, sketch, image, img_name, rand_label

# Shared preprocessing for photos and sketches: resize to 256x256, then
# convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

# Modify paths as needed
train_dataset = SketchDataset(
    csv_file='/home/cvlab/Karan/A_3/Dataset_A4/Train_labels.csv',
    image_root_dir='/home/cvlab/Karan/A_3/Dataset_A4/Train_data',
    sketch_root_dir='/home/cvlab/Karan/A_3/Dataset_A4/Paired_train_sketches',
    transform=transform,
)

# Shuffled mini-batches of 64 samples.
dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Function to display a batch of images
def show_images(images, titles=None):
    """Show ``images`` side by side in a single row.

    Args:
        images: Sequence of channel-first images; each one is transposed to
            channel-last with axes ``(1, 2, 0)`` before being drawn.
        titles: Optional sequence with one title per image.
    """
    count = len(images)
    plt.figure(figsize=(12, 6))
    for slot, picture in enumerate(images, start=1):
        plt.subplot(1, count, slot)
        plt.imshow(np.transpose(picture, (1, 2, 0)))
        plt.axis('off')
        if titles:
            plt.title(titles[slot - 1])
    plt.show()

# Draw five distinct dataset indices at random
selected_samples = np.random.choice(len(train_dataset), size=5, replace=False)

# Load each chosen sample, report its path and label, then show the photo
# next to its sketch
for idx in selected_samples:
    label, sketch, image, img_name, rand = train_dataset[idx]
    print("Image Name:", img_name)
    print("Label:", label)
    show_images([image, sketch], titles=['Image', 'Sketch'])


In [None]:
class CNNBlock(nn.Module):
    """4x4 convolution (reflect padding, no bias) -> BatchNorm -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=False,
            padding_mode="reflect",
        )
        normalisation = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        # Attribute name and layer order are kept so state_dict keys stay
        # "conv.0.*", "conv.1.*".
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        return self.conv(x)


class Discriminator(nn.Module):
    """Convolutional discriminator scoring an image together with a condition.

    ``forward`` concatenates the image ``x`` and the conditioning maps ``y``
    along the channel axis, so ``in_channels`` must equal the combined
    channel count of both inputs. The output is a single-channel map of raw
    scores (no sigmoid is applied here).

    Args:
        in_channels: Channel count of ``cat([x, y], dim=1)``.
        features: Widths of the convolutional stages. The first width is used
            by the initial (un-normalised) convolution; each later width adds
            one ``CNNBlock``. Must contain at least one entry.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels=10, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; accept any sequence.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        previous = features[0]
        last_position = len(features) - 1
        for position, width in enumerate(features[1:], start=1):
            # Only the final stage uses stride 1. Decide by position rather
            # than by value: comparing `width == features[-1]` would also
            # give stride 1 to any earlier stage that happens to share the
            # last stage's width (e.g. features=[64, 512, 512]).
            stride = 1 if position == last_position else 2
            layers.append(CNNBlock(previous, width, stride=stride))
            previous = width

        # Project to one score per spatial location.
        layers.append(
            nn.Conv2d(
                previous, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score ``x`` conditioned on ``y``.

        Args:
            x: Image batch.
            y: Conditioning maps with the same batch and spatial size as ``x``.

        Returns:
            Single-channel score map produced by the final convolution.
        """
        joined = torch.cat([x, y], dim=1)
        return self.model(self.initial(joined))

In [None]:
class Block(nn.Module):
    """Resampling stage: (transposed) 4x4 stride-2 conv -> BatchNorm -> activation.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the stage.
        down: ``True`` uses ``nn.Conv2d`` with reflect padding; ``False``
            uses ``nn.ConvTranspose2d``.
        act: ``"relu"`` selects ``nn.ReLU``; any other value selects
            ``nn.LeakyReLU(0.2)``.
        use_dropout: When true, ``nn.Dropout(0.5)`` is applied to the output.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()
        if down:
            resample = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
                padding_mode="reflect",
            )
        else:
            resample = nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            )
        normalisation = nn.BatchNorm2d(out_channels)
        if act == "relu":
            activation = nn.ReLU()
        else:
            activation = nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(resample, normalisation, activation)

        self.use_dropout = use_dropout
        # The dropout module is always created; forward() decides whether to use it.
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        """Run the stage on ``x``, applying dropout only when enabled."""
        out = self.conv(x)
        if self.use_dropout:
            return self.dropout(out)
        return out


class Generator(nn.Module):
    """Encoder-decoder generator with skip connections and a 3-channel Tanh output.

    Eight stride-2 stages (``initial_down``, ``down1``..``down6``,
    ``bottleneck``) shrink the input; eight stride-2 transposed stages
    (``up1``..``up7``, ``final_up``) grow it back. From ``up2`` onward each
    decoder stage receives the previous decoder output concatenated with the
    encoder output of matching resolution. The first three decoder stages
    use dropout.

    Args:
        in_channels: Channels of the input tensor.
        features: Base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_channels=10, features=64): # 3 earlier
        super().__init__()
        width = features

        def encoder(c_in, c_out):
            # Down-sampling stage: LeakyReLU, never dropout.
            return Block(c_in, c_out, down=True, act="leaky", use_dropout=False)

        def decoder(c_in, c_out, dropout):
            # Up-sampling stage: ReLU, dropout on request.
            return Block(c_in, c_out, down=False, act="relu", use_dropout=dropout)

        # Attribute names and creation order are unchanged so state_dict
        # keys and parameter order stay the same.
        self.initial_down = nn.Sequential(
            nn.Conv2d(
                in_channels, width, kernel_size=4, stride=2, padding=1, padding_mode="reflect"
            ),
            nn.LeakyReLU(0.2),
        )
        self.down1 = encoder(width, width * 2)
        self.down2 = encoder(width * 2, width * 4)
        self.down3 = encoder(width * 4, width * 8)
        self.down4 = encoder(width * 8, width * 8)
        self.down5 = encoder(width * 8, width * 8)
        self.down6 = encoder(width * 8, width * 8)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(width * 8, width * 8, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )

        # Input widths from up2 on are doubled by the concatenated skip tensor.
        self.up1 = decoder(width * 8, width * 8, dropout=True)
        self.up2 = decoder(width * 8 * 2, width * 8, dropout=True)
        self.up3 = decoder(width * 8 * 2, width * 8, dropout=True)
        self.up4 = decoder(width * 8 * 2, width * 8, dropout=False)
        self.up5 = decoder(width * 8 * 2, width * 4, dropout=False)
        self.up6 = decoder(width * 4 * 2, width * 2, dropout=False)
        self.up7 = decoder(width * 2 * 2, width, dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(width * 2, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map ``x`` to a 3-channel tensor with values in ``[-1, 1]`` (Tanh)."""
        # Encoder: keep every intermediate output for the skip connections.
        skips = [self.initial_down(x)]
        encoder_stages = (
            self.down1, self.down2, self.down3, self.down4, self.down5, self.down6,
        )
        for stage in encoder_stages:
            skips.append(stage(skips[-1]))

        out = self.up1(self.bottleneck(skips[-1]))

        # Decoder: pair each stage with the deepest skip not yet consumed.
        decoder_stages = (
            self.up2, self.up3, self.up4, self.up5, self.up6, self.up7, self.final_up,
        )
        for stage in decoder_stages:
            out = stage(torch.cat([out, skips.pop()], 1))
        return out

In [None]:
# Networks: both are built for 10 input channels.
disc = Discriminator(in_channels=10).to(device)
gen = Generator(in_channels=10, features=64).to(device)

# One Adam optimiser per network, with identical hyper-parameters.
opt_disc = optim.Adam(params=disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_gen = optim.Adam(params=gen.parameters(), lr=2e-4, betas=(0.5, 0.999))

# One mixed-precision gradient scaler per optimiser.
d_scaler = torch.cuda.amp.GradScaler()
g_scaler = torch.cuda.amp.GradScaler()

# Adversarial loss on raw scores, plus an L1 reconstruction loss.
bce = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

In [None]:
NUM_EPOCHS = 100
L1_LAMBDA = 100

# Set directory to save images
save_dir = "CGAN_paired"
os.makedirs(save_dir, exist_ok=True)
step = 0

# Lists to store losses for plotting
train_losses_D = []
train_losses_G = []

gen.train()
disc.train()

best_loss = float('inf')
best_model_path = os.path.join(save_dir, "best_model.pth")
# The label embeddings are part of the trained model, so they are saved next
# to the generator weights (separate file: best_model.pth keeps its format).
best_embed_path = os.path.join(save_dir, "best_label_embeddings.pth")

img_size = 256

# Label-conditioning embeddings.
# They are created ONCE, before the loop. Building fresh nn.Embedding modules
# inside the batch loop would hand the networks a different random
# conditioning signal every step, and no optimiser would ever train it.
# Each entry of the 7-long label vector is used as an index into the table
# (values must lie in [0, 6]; presumably 0/1 indicators - verify against the CSV).
embed_disc = nn.Embedding(7, img_size * img_size).to(device)
embed_gen = nn.Embedding(7, 1).to(device)
# Register the embeddings with the existing optimisers so they are trained
# with the same hyper-parameters as the networks they feed.
# NOTE: re-running this cell adds another parameter group each time; rebuild
# the optimisers first when restarting training.
opt_disc.add_param_group({"params": list(embed_disc.parameters())})
opt_gen.add_param_group({"params": list(embed_gen.parameters())})

for epoch in range(NUM_EPOCHS):

    # Initialize variables to accumulate losses over the epoch
    epoch_train_loss_D = 0
    epoch_train_loss_G = 0

    for batch_idx, (labels, sketch,real,name,rand_labels) in enumerate(tqdm(dataloader)):

        real = real.to(device)
        noise = sketch.to(device)
        labels = labels.to(device).long()
        rand_labels = rand_labels.to(device).long()

        # ---- Train Discriminator ----
        # Only forward passes and loss computation run under autocast;
        # backward() and optimizer steps are kept outside of it.
        with torch.cuda.amp.autocast():
            # Real labels -> one img_size x img_size map per label entry.
            embedding_disc = embed_disc(labels)
            embedding_disc = embedding_disc.view(labels.shape[0], 7, img_size, img_size)

            # Random labels -> one scalar per label entry, broadcast to a
            # full-resolution map and stacked onto the sketch.
            embedding_gen = embed_gen(rand_labels)
            embedding_gen = embedding_gen.unsqueeze(3)
            upsampled_embedding_gen = F.interpolate(embedding_gen, size=(img_size,img_size), mode='nearest')
            real_gen = torch.cat([noise, upsampled_embedding_gen], dim=1)

            y_fake = gen(real_gen) # Generate fake images

            D_real = disc(real, embedding_disc) # Discriminator with real images
            D_real_loss = bce(D_real, torch.ones_like(D_real))

            # detach(): the discriminator update must not reach the generator.
            D_fake = disc(y_fake.detach(), embedding_disc) # Discriminator with fake images
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))

            D_loss = (D_real_loss + D_fake_loss) / 2

        # Zero through the optimiser so the embedding's gradients are cleared too.
        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Train Generator ----
        with torch.cuda.amp.autocast():
            # The conditioning maps are detached: the generator step must not
            # update the discriminator's embedding, and the graph freed by the
            # discriminator backward pass is not needed again (so no
            # retain_graph=True).
            D_fake = disc(y_fake, embedding_disc.detach())
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, real) * L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if batch_idx % 50 == 0:
            with torch.no_grad():
                fake = gen(real_gen)

                # Rows of the grid: real photos, generated images, input sketches
                img_grid = torch.cat((real[:8], fake[:8], noise[:8]), dim=0)

                # Make grid for visualization
                img_grid = torchvision.utils.make_grid(img_grid, nrow=8, normalize=True)

                # Save the generated fake images locally
                torchvision.utils.save_image(img_grid, os.path.join(save_dir, f"epoch_{epoch}_batch_{batch_idx}.png"))

            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {D_loss:.4f}, loss G: {G_loss:.4f}"
            )

            step += 1

        # Accumulate losses
        epoch_train_loss_D += D_loss.item()
        epoch_train_loss_G += G_loss.item()

    # Calculate average training losses for the epoch
    avg_train_loss_D = epoch_train_loss_D / len(dataloader)
    avg_train_loss_G = epoch_train_loss_G / len(dataloader)

    # Store the average training losses for plotting
    train_losses_D.append(avg_train_loss_D)
    train_losses_G.append(avg_train_loss_G)


    # Save the model if it has the best (lowest) generator loss on the training set
    if avg_train_loss_G < best_loss:
        best_loss = avg_train_loss_G
        torch.save(gen.state_dict(), best_model_path)
        torch.save(
            {
                "embed_gen": embed_gen.state_dict(),
                "embed_disc": embed_disc.state_dict(),
            },
            best_embed_path,
        )





In [None]:
import matplotlib.pyplot as plt

# Draw both per-epoch loss histories on one axes and write the figure to disk
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses_D, label='Train D loss')
ax.plot(train_losses_G, label='Train G loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Curve')
ax.legend()
fig.savefig(os.path.join(save_dir, "training_curve.png"))
# Release the figure once it has been written
plt.close(fig)