In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from torchvision import transforms
from PIL import Image
import os
import torch.nn.functional as F
from tqdm import tqdm
import torchvision
import random
import wandb


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Step 1: Prepare your dataset
class SketchDataset(Dataset):
    """Labelled training images combined with sketches from an unpaired pool.

    Each CSV row holds an image file stem in column 0 (``'.jpg'`` is appended)
    and the label values in the remaining columns. Sketches are read from
    ``sketch_root_dir`` as ``sketch_<k>.png`` with ``k`` starting at 1; sample
    ``idx`` uses sketch ``idx % num_sketches``, so the sketch pool is cycled
    when there are more images than sketches.

    Args:
        csv_file: Path of the label CSV.
        image_root_dir: Directory containing the ``<stem>.jpg`` images.
        sketch_root_dir: Directory containing the ``sketch_<k>.png`` files.
        transform: Optional callable applied separately to the image and to
            the sketch (random transforms therefore draw independently).
        num_sketches: Number of sketch files available for cyclic indexing.
            Defaults to 3594, the value previously hard-coded here.

    Raises:
        ValueError: If ``num_sketches`` is not positive.
    """

    def __init__(self, csv_file, image_root_dir, sketch_root_dir, transform=None,
                 num_sketches=3594):
        if num_sketches <= 0:
            # A zero pool would otherwise surface later as a ZeroDivisionError
            # inside __getitem__ (idx % 0).
            raise ValueError("num_sketches must be a positive integer")

        self.data_frame = pd.read_csv(csv_file)
        self.image_root_dir = image_root_dir
        self.sketch_root_dir = sketch_root_dir
        self.transform = transform
        self.num_sketches = num_sketches
        self.num_samples = len(self.data_frame)

    def __len__(self):
        """Return the number of CSV rows (one sample per labelled image)."""
        return self.num_samples

    def _label_at(self, row):
        """Return the label columns of ``row`` as a float32 tensor."""
        # Go through NumPy explicitly: handing a pandas Series to torch.tensor
        # relies on positional Series[int] access, which pandas has deprecated.
        values = self.data_frame.iloc[row, 1:].to_numpy(dtype="float32")
        return torch.tensor(values, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(label, sketch, image, img_name, rand_label)`` for ``idx``.

        ``rand_label`` is the label of a uniformly chosen row (possibly the
        same row); the training cell feeds it to the generator.
        """
        img_name = os.path.join(self.image_root_dir, self.data_frame.iloc[idx, 0] + '.jpg')
        sketch_idx = idx % self.num_sketches  # Cyclic indexing for sketches
        sketch_name = os.path.join(self.sketch_root_dir, f"sketch_{sketch_idx + 1}.png")

        # convert() materialises the pixels, so the file handles can be closed
        # immediately instead of lingering until garbage collection.
        with Image.open(img_name) as img_file:
            image = img_file.convert('RGB')
        with Image.open(sketch_name) as sketch_file:
            sketch = sketch_file.convert('RGB')

        label = self._label_at(idx)
        rand_label = self._label_at(random.randint(0, self.num_samples - 1))

        if self.transform:
            image = self.transform(image)
            sketch = self.transform(sketch)

        return label, sketch, image, img_name, rand_label

# transform = transforms.Compose([
#     transforms.Resize((256,256)),
#     transforms.ToTensor(),
# ])

# Augmentation and normalisation steps, applied in this order.
_augmentation_steps = [
    transforms.Resize(size=(256, 256)),
    transforms.RandomHorizontalFlip(),          # random left/right mirror
    transforms.RandomRotation(degrees=15),      # rotate by up to 15 degrees
    transforms.RandomResizedCrop(               # random crop, resized back to 256
        size=256,
        scale=(0.8, 1.0),
        ratio=(0.75, 1.333),
    ),
    transforms.ToTensor(),
    transforms.Normalize(                       # map each channel to [-1, 1]
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
    ),
]
transform = transforms.Compose(_augmentation_steps)


# Dataset location. Set the SKETCH_DATA_ROOT environment variable to point the
# notebook at another copy of the data; the default is the original location.
DATA_ROOT = os.environ.get('SKETCH_DATA_ROOT', '/home/cvlab/Karan/A_3/Dataset_A4')
BATCH_SIZE = 64

train_dataset = SketchDataset(
    csv_file=os.path.join(DATA_ROOT, 'Train_labels.csv'),
    image_root_dir=os.path.join(DATA_ROOT, 'Train_data'),
    sketch_root_dir=os.path.join(DATA_ROOT, 'Unpaired_sketch'),
    transform=transform,
)
# Reshuffle the samples every epoch.
dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
class Discriminator(nn.Module):
    """Discriminator with a real/fake head and an auxiliary class head.

    Six 3x3 convolution stages (strides 2,1,2,1,2,1) feed two linear heads.
    With the default ``img_size=256`` the final feature map is 29x29x512,
    which matches the layer sizes that were previously hard-coded.

    Args:
        in_channels: Channels of the input tensor. Defaults to 10, the value
            the first convolution was previously hard-wired to (the argument
            used to be ignored).
        img_size: Height/width of the square input; used to size the linear
            heads.
        num_classes: Number of outputs of the auxiliary classifier.

    Raises:
        ValueError: If ``img_size`` is too small for the six stages.

    ``forward`` returns raw logits ``(realfake, classes)`` with shapes
    ``(B,)`` and ``(B, num_classes)``. The training cell of this notebook
    applies ``BCEWithLogitsLoss`` and ``LogSoftmax`` to these outputs, so
    applying sigmoid/softmax here as well would squash them twice. Use
    ``self.sigmoid(realfake)`` / ``self.softmax(classes)`` to obtain
    probabilities.
    """

    # (out_channels, stride, padding, use_batchnorm) for conv1..conv6.
    _STAGES = (
        (16, 2, 1, False),
        (32, 1, 0, True),
        (64, 2, 1, True),
        (128, 1, 0, True),
        (256, 2, 1, True),
        (512, 1, 0, True),
    )

    def __init__(self, in_channels=10, img_size=256, num_classes=7):
        super(Discriminator, self).__init__()

        def stage(c_in, c_out, stride, padding, use_norm):
            # Layer order matches the original Sequentials so state_dict keys
            # (conv1.0.weight, conv2.1.weight, ...) are unchanged.
            layers = [nn.Conv2d(c_in, c_out, 3, stride, padding, bias=False)]
            if use_norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Dropout(0.5, inplace=False))
            return nn.Sequential(*layers)

        channels = in_channels
        spatial = img_size
        stages = []
        for c_out, stride, padding, use_norm in self._STAGES:
            stages.append(stage(channels, c_out, stride, padding, use_norm))
            # Standard convolution output size for a 3x3 kernel.
            spatial = (spatial + 2 * padding - 3) // stride + 1
            channels = c_out
        if spatial < 1:
            raise ValueError(f"img_size={img_size} is too small for this discriminator")

        (self.conv1, self.conv2, self.conv3,
         self.conv4, self.conv5, self.conv6) = stages

        flat_features = spatial * spatial * channels  # 29*29*512 for 256x256 inputs
        # discriminator fc
        self.fc_dis = nn.Linear(flat_features, 1)
        # aux-classifier fc
        self.fc_aux = nn.Linear(flat_features, num_classes)
        # Not used by forward(); kept so callers can turn logits into probabilities.
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return ``(realfake_logits, class_logits)`` for a batch of inputs."""
        features = input
        for conv_stage in (self.conv1, self.conv2, self.conv3,
                           self.conv4, self.conv5, self.conv6):
            features = conv_stage(features)
        # flatten(1) keeps the batch dimension intact; a wrong input size then
        # fails loudly in the linear layers instead of reshaping the batch.
        flat6 = features.flatten(1)
        realfake = self.fc_dis(flat6).squeeze(1)
        classes = self.fc_aux(flat6)
        return realfake, classes

In [None]:
class Block(nn.Module):
    """One resampling stage: 4x4 stride-2 conv, batch norm, activation.

    ``down=True`` uses a reflect-padded ``Conv2d`` (halves the spatial size);
    ``down=False`` uses a ``ConvTranspose2d`` (doubles it). ``act="relu"``
    selects ``ReLU``; any other value selects ``LeakyReLU(0.2)``. Dropout
    (p=0.5) is applied to the output only when ``use_dropout`` is true.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()

        if down:
            resample = nn.Conv2d(
                in_channels, out_channels,
                kernel_size=4, stride=2, padding=1,
                bias=False, padding_mode="reflect",
            )
        else:
            resample = nn.ConvTranspose2d(
                in_channels, out_channels,
                kernel_size=4, stride=2, padding=1,
                bias=False,
            )

        normalise = nn.BatchNorm2d(out_channels)

        if act == "relu":
            activation = nn.ReLU()
        else:
            activation = nn.LeakyReLU(0.2)

        self.conv = nn.Sequential(resample, normalise, activation)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        """Run the stage, applying dropout afterwards if it is enabled."""
        out = self.conv(x)
        if self.use_dropout:
            return self.dropout(out)
        return out


class Generator(nn.Module):
    """U-Net style encoder/decoder producing a 3-channel, tanh-bounded image.

    The encoder halves the spatial size eight times (``initial_down``,
    ``down1``..``down6``, ``bottleneck``); the decoder doubles it eight times
    (``up1``..``up7``, ``final_up``). From ``up2`` on, each decoder stage
    receives the previous decoder output concatenated on the channel axis with
    the encoder output of matching resolution. The first three decoder stages
    use dropout.
    """

    def __init__(self, in_channels=10, features=64):
        super().__init__()
        f = features

        def encoder(c_in, c_out):
            return Block(c_in, c_out, down=True, act="leaky", use_dropout=False)

        def decoder(c_in, c_out, dropout):
            return Block(c_in, c_out, down=False, act="relu", use_dropout=dropout)

        # ---- encoder ----
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = encoder(f, f * 2)
        self.down2 = encoder(f * 2, f * 4)
        self.down3 = encoder(f * 4, f * 8)
        self.down4 = encoder(f * 8, f * 8)
        self.down5 = encoder(f * 8, f * 8)
        self.down6 = encoder(f * 8, f * 8)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f * 8, f * 8, 4, 2, 1),
            nn.ReLU(),
        )

        # ---- decoder (input widths are doubled by the skip concatenation) ----
        self.up1 = decoder(f * 8, f * 8, True)
        self.up2 = decoder(f * 8 * 2, f * 8, True)
        self.up3 = decoder(f * 8 * 2, f * 8, True)
        self.up4 = decoder(f * 8 * 2, f * 8, False)
        self.up5 = decoder(f * 8 * 2, f * 4, False)
        self.up6 = decoder(f * 4 * 2, f * 2, False)
        self.up7 = decoder(f * 2 * 2, f, False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f * 2, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, then decode while consuming the skips deepest-first."""
        skips = [self.initial_down(x)]
        for down_stage in (self.down1, self.down2, self.down3,
                           self.down4, self.down5, self.down6):
            skips.append(down_stage(skips[-1]))

        out = self.up1(self.bottleneck(skips[-1]))

        for up_stage in (self.up2, self.up3, self.up4, self.up5,
                         self.up6, self.up7, self.final_up):
            out = up_stage(torch.cat([out, skips.pop()], 1))
        return out

In [None]:
# Networks. Both are built for 10 input channels; the training cell feeds them
# a 3-channel image concatenated with 7 label maps.
disc = Discriminator(in_channels=10).to(device)
gen = Generator(in_channels=10, features=64).to(device)

# Both optimizers use the same Adam settings.
_adam_settings = dict(lr=2e-4, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), **_adam_settings)
opt_gen = optim.Adam(gen.parameters(), **_adam_settings)

# One gradient scaler per optimizer for mixed-precision training.
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

# Losses: real/fake (on logits), pixel-wise L1, and class (on log-probabilities).
bce = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
aux_criterion = nn.NLLLoss()

In [None]:
def compute_acc(preds, labels):
    """Return the top-1 classification accuracy as a percentage.

    Args:
        preds: Tensor of per-class scores with shape ``(N, num_classes)``;
            only the argmax along dim 1 is used.
        labels: Tensor of ``N`` target class indices.

    Returns:
        A Python float in ``[0, 100]``. An empty batch yields ``0.0`` rather
        than raising ``ZeroDivisionError``.
    """
    total = len(labels)
    if total == 0:
        return 0.0
    predicted = preds.argmax(dim=1)  # index of the highest score per sample
    correct = predicted.eq(labels).sum().item()
    return correct / total * 100.0

In [None]:
def gradient_penalty(critic, labels, real, fake, img_size, device="cuda"):
    """Gradient penalty on random interpolations between real and fake images.

    Each sample is mixed as ``real * a + fake * (1 - a)`` with one uniform
    ``a`` per sample. The mix is concatenated with 7 label maps and passed to
    ``critic``. The penalty is ``mean((||grad||_2 - 1) ** 2)``, where ``grad``
    is the gradient of the summed critic outputs with respect to the mixed
    images (all outputs contribute when the critic returns several).

    Args:
        critic: Callable taking the ``(B, C + 7, img_size, img_size)`` input
            and returning a tensor or a sequence of tensors.
        labels: ``(B, 7)`` tensor; cast to ``long`` and used as embedding
            indices (the training cell passes its float label rows here).
        real: ``(B, C, img_size, img_size)`` batch of real images.
        fake: Batch of generated images with the same shape as ``real``.
        img_size: Spatial size of the square images.
        device: Device on which the mixing weights and label maps are created.

    Returns:
        Scalar tensor, differentiable w.r.t. the critic parameters.

    NOTE(review): the label maps come from a freshly initialised
    ``nn.Embedding`` on every call, so they differ from the maps used for the
    other discriminator losses — confirm whether a shared embedding is intended.
    """
    batch_size = real.shape[0]

    # One mixing weight per sample; broadcasting spreads it over C, H and W.
    alpha = torch.rand((batch_size, 1, 1, 1)).to(device)

    # Detach both inputs and make the mix a fresh leaf: the penalty must be
    # differentiated w.r.t. the interpolated images only. This also works when
    # `fake` was already detached and keeps the penalty from back-propagating
    # into the generator.
    interpolated_images = (
        real.detach() * alpha + fake.detach() * (1 - alpha)
    ).requires_grad_(True)

    embed_disc = nn.Embedding(7, img_size * img_size).to(device)
    embedding_disc = embed_disc(labels.long())
    embedding_disc = embedding_disc.view(labels.shape[0], 7, img_size, img_size)
    critic_input = torch.cat([interpolated_images, embedding_disc], dim=1)

    # Calculate critic scores; accept a single tensor as well as a sequence.
    mixed_scores = critic(critic_input)
    if torch.is_tensor(mixed_scores):
        mixed_scores = (mixed_scores,)

    # Take the gradient of the scores with respect to the images.
    grad_outputs = [torch.ones_like(score) for score in mixed_scores]
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=grad_outputs,
        create_graph=True,  # the penalty itself is back-propagated later
        retain_graph=True,
    )[0]

    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

In [None]:
NUM_EPOCHS = 100
L1_LAMBDA = 100
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10

# Set directory to save images
save_dir = "ACGAN_Unpaired"
os.makedirs(save_dir, exist_ok=True)
step = 0

# Lists to store losses for plotting
train_losses_D = []
train_losses_G = []

gen.train()
disc.train()

best_loss = float('inf')
best_model_path = os.path.join(save_dir, "best_model.pth")

img_size = 256
NUM_CLASSES = 7

# Label -> spatial-map lookups used to condition both networks. They are built
# ONCE so the same label yields the same maps in every batch; re-creating them
# per batch gives fresh random maps each time. They are not registered with
# either optimizer, so they are frozen to avoid accumulating unused gradients.
embed_disc = nn.Embedding(NUM_CLASSES, img_size * img_size).to(device)
embed_gen = nn.Embedding(NUM_CLASSES, img_size * img_size).to(device)
embed_disc.weight.requires_grad_(False)
embed_gen.weight.requires_grad_(False)

for epoch in range(NUM_EPOCHS):

    # Running sums over the epoch
    epoch_train_loss_D = 0
    epoch_train_loss_G = 0
    epoch_train_loss_A = 0

    for batch_idx, (labels, sketch, real, name, rand_labels) in enumerate(tqdm(dataloader)):

        real = real.to(device)
        noise = sketch.to(device)  # the sketch plays the role of the generator input
        labels = labels.to(device)
        rand_labels = rand_labels.to(device).long()
        batch_size = labels.shape[0]
        labels1 = labels.argmax(dim=1)  # class index targets for the auxiliary loss

        with torch.cuda.amp.autocast():
            # Real labels condition the discriminator input ...
            embedding_disc = embed_disc(labels.long()).view(
                batch_size, NUM_CLASSES, img_size, img_size
            )
            real_disc = torch.cat([real, embedding_disc], dim=1)

            # ... while randomly drawn labels condition the generator input.
            upsampled_embedding_gen = embed_gen(rand_labels).view(
                batch_size, NUM_CLASSES, img_size, img_size
            )
            real_gen = torch.cat([noise, upsampled_embedding_gen], dim=1)

        # ---------------- Train Discriminator ----------------
        for _ in range(CRITIC_ITERATIONS):
            with torch.cuda.amp.autocast():
                y_fake = gen(real_gen)  # Generate fake images
                y_fake_conc = torch.cat([y_fake, embedding_disc], dim=1)

                # Real images. The losses below treat the discriminator
                # outputs as raw logits (BCEWithLogitsLoss / log-softmax).
                D_real, D_labels = disc(real_disc)
                dis_errD_real = bce(D_real, torch.ones_like(D_real))
                D_log_probs = F.log_softmax(D_labels, dim=1)
                aux_errD_real = aux_criterion(D_log_probs, labels1)
                errD_real = dis_errD_real + aux_errD_real

                # compute the current classification accuracy
                accuracy = compute_acc(D_log_probs, labels1)

                # Fake images (detached: this step must not update the generator).
                D_fake, D_fake_label = disc(y_fake_conc.detach())
                dis_errD_fake = bce(D_fake, torch.zeros_like(D_fake))
                D_log_probs_fake = F.log_softmax(D_fake_label, dim=1)
                aux_errD_fake = aux_criterion(D_log_probs_fake, labels1)
                errD_fake = dis_errD_fake + aux_errD_fake

                gp = gradient_penalty(disc, labels, real, y_fake, img_size, device=device)
                errD = (errD_real + errD_fake) / 2 + LAMBDA_GP * gp

            # Single scaled backward pass for the whole discriminator loss.
            # errD_real is NOT back-propagated separately: it is already part
            # of errD, and an extra unscaled backward would add its gradient a
            # second time at a different scale.
            disc.zero_grad()
            # retain_graph=True keeps the generator graph behind y_fake alive:
            # the penalty term may back-propagate through it, and the generator
            # step below re-uses y_fake_conc from the last critic iteration.
            d_scaler.scale(errD).backward(retain_graph=True)
            d_scaler.step(opt_disc)
            d_scaler.update()

        # ---------------- Train Generator ----------------
        with torch.cuda.amp.autocast():
            D_fake, D_gen_label = disc(y_fake_conc)
            dis_errG = bce(D_fake, torch.ones_like(D_fake))

            # NOTE(review): the generator is conditioned on rand_labels, but
            # the class target here is the real label — confirm this is intended.
            D_log_probs_1 = F.log_softmax(D_gen_label, dim=1)
            aux_errG = aux_criterion(D_log_probs_1, labels1)

            # NOTE(review): L1 against `real` although sketches are cycled
            # independently of the images — confirm this is intended.
            L1 = l1_loss(y_fake, real) * L1_LAMBDA
            errG = dis_errG + aux_errG + L1

        opt_gen.zero_grad()
        g_scaler.scale(errG).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if batch_idx % 50 == 0:
            with torch.no_grad():
                fake = gen(real_gen)

                # Rows: real images, generated images, input sketches
                img_grid = torch.cat((real[:8], fake[:8], noise[:8]), dim=0)

                # Make grid for visualization
                img_grid = torchvision.utils.make_grid(img_grid, nrow=8, normalize=True)

                # Save the generated fake images locally
                torchvision.utils.save_image(img_grid, os.path.join(save_dir, f"epoch_{epoch}_batch_{batch_idx}.png"))

            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {errD:.4f}, loss G: {errG:.4f}, label accuracy: {accuracy:.4f}"
            )

            step += 1

        # Accumulate losses (values from the last critic iteration of the batch)
        epoch_train_loss_D += errD.item()
        epoch_train_loss_G += errG.item()
        epoch_train_loss_A += accuracy

    # Calculate average training losses for the epoch
    avg_train_loss_D = epoch_train_loss_D / len(dataloader)
    avg_train_loss_G = epoch_train_loss_G / len(dataloader)

    # Store the average training losses for plotting
    train_losses_D.append(avg_train_loss_D)
    train_losses_G.append(avg_train_loss_G)

    # Keep the generator weights with the lowest average generator training loss
    if avg_train_loss_G < best_loss:
        best_loss = avg_train_loss_G
        torch.save(gen.state_dict(), best_model_path)





In [None]:
import matplotlib.pyplot as plt

# Draw the per-epoch discriminator and generator losses on one set of axes and
# write the figure into the run's output directory.
fig, ax = plt.subplots(figsize=(10, 5))
for series, legend_name in (
    (train_losses_D, 'Train D loss'),
    (train_losses_G, 'Train G loss'),
):
    ax.plot(series, label=legend_name)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Curve')
ax.legend()
fig.savefig(os.path.join(save_dir, "training_curve.png"))
plt.close(fig)