In [1]:
import pandas as pd
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2
from PIL import Image
from torch import nn
from tqdm import tqdm
from torcheval.metrics import Mean
import matplotlib.pyplot as plt
from torch.nn.utils.clip_grad import clip_grad_norm_
import matplotlib.pyplot as plt

In [2]:
train_df = pd.read_csv('train/metadata.csv')
test_df = pd.read_csv('test/metadata.csv')

In [10]:
colors = {
    'blue': 0,
    'brown':1,
    'red': 2,
    'yellow': 3,
    'green': 4
}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
bs = 5
lamb = 100

In [4]:
class AshrafiSet(Dataset):

    def __init__(self, data, phase):
        self.X = list()
        self.gt = list()
        self.con = list()
        to_tensor = v2.ToTensor()
        for i, row in data.iterrows():
            if i % 25 == 0:
               X = Image.open(f'{phase}/inputs/{row["input"]}')
            y_image = Image.open(f'{phase}/targets/{row["target"]}')
            hair, shirt = colors[row['hair']], colors[row['shirt']]

            self.con.append(torch.LongTensor([hair, shirt]))
            self.X.append(to_tensor(X))
            self.gt.append(to_tensor(y_image))

    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, ind):
        return self.X[ind], self.gt[ind], self.con[ind]

In [5]:
train_set = AshrafiSet(train_df, 'train')
test_set = AshrafiSet(test_df, 'test')



In [6]:
train_loader = DataLoader(train_set, batch_size=bs, shuffle=True)
test_loader = DataLoader(test_set, batch_size=10)

# NeuralNetwork


In [7]:
class Block(nn.Module):

    def __init__(self, in_f, out_f, dropout=0.5):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_f, out_f, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_f),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Dropout(dropout),
            nn.Conv2d(out_f, out_f, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_f),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.layers(x)


class UBadNet(nn.Module):

    def __init__(self, in_channels):
        super().__init__()
        # down sample, decoder phase
        self.block1 = Block(in_channels, 64)
        self.block2 = Block(64, 128)
        self.block3 = Block(128, 256)
        self.block4 = Block(256, 512)
        self.block5 = Block(512, 1024)
        self.block6 = Block(1024, 1024)

        # Going N_block for concated tensors on decoder
        self.n_bloc_1 = Block(1024, 512)
        self.n_bloc_2 = Block(512, 256)
        self.n_bloc_3 = Block(256, 128)
        self.n_bloc_4 = Block(128, 64)

        # up sample
        self.upc_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2, )
        self.upc_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upc_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upc_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

        self.last_layer = nn.Conv2d(64, 3, kernel_size=3, padding=1)

        self.mp = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.LeakyReLU(negative_slope=0.2)
        self.tanh = nn.Tanh()

    def forward(self, x):

        x_1 = self.block1(x)  
        x_m = self.mp(x_1)  
        x_2 = self.block2(x_m) 
        x_m = self.mp(x_2)  
        x_3 = self.block3(x_m) 
        x_m = self.mp(x_3) 
        x_4 = self.block4(x_m)  
        x_m = self.mp(x_4)  
        x_5 = self.block5(x_m) 

        y_temp = self.relu(self.dropout(self.upc_1(x_5)))  
        y = self.n_bloc_1(torch.cat([y_temp, x_4], dim=1))  
        y_temp = self.relu(self.dropout(self.upc_2(y)))  
        y = self.n_bloc_2(torch.cat([y_temp, x_3], dim=1))  
        y_temp = self.relu(self.dropout(self.upc_3(y)))  
        y = self.n_bloc_3(torch.cat([y_temp, x_2], dim=1))  
        y_temp = self.relu(self.dropout(self.upc_4(y)))  
        y = self.n_bloc_4(torch.cat([y_temp, x_1], dim=1))  

        y = self.last_layer(y)  
        y = self.tanh(y)
        return y


class Generator(nn.Module):

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(len(colors), 128)
        self.unet = UBadNet(33)
        self.down_sampler = nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=4)
        self.up_sampler = nn.Sequential(
            nn.ConvTranspose2d(3, 3, kernel_size=2, stride=2),
            nn.ConvTranspose2d(3, 3, kernel_size=2, stride=2)
            )

    def forward(self, x, conditions=None):

        conditions = self.embedding(conditions)
        labels = conditions.unsqueeze(1).repeat(1, 1, 64, 1)
        x = self.down_sampler(x)
        x = torch.cat([x, labels], dim=1)
        output = self.unet(x)
        output = self.up_sampler(output)
        return output

    
class Discriminator(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(len(colors), embedding_dim=128)
        block1 = Block(33, 64)
        block2 = Block(64, 128)
        self.mp = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
        block3 = Block(128, 256)
        block4 = Block(256, 512)
        block5 = Block(512, 256)
        block6 = Block(256, 64)
        self.blocks = nn.ModuleList([block1, block2, block3, block4, block5, block6])
        self.down_sampler = nn.Conv2d(6, 32, kernel_size=3, padding=1, stride=4)
        self.head = nn.Linear(256, 1)

        
    def forward(self, x, y, conditions):
        x = torch.cat([x, y], dim=1)
        x = self.down_sampler(x)
        conditions = self.embedding(conditions)
        conditions = conditions.unsqueeze(1).repeat(1, 1, 64, 1)

        x = torch.cat([conditions, x], dim=1)
        for module in self.blocks:
            x = module(x)
            x = self.mp(x)
            
        x = x.flatten(1)
        y = self.head(x)
        return y

In [8]:
# gen = Generator().to(device)
# disc = Discriminator().to(device)
gen = torch.load('gen.pt')
disc = torch.load('disc.pt')
gen_optimizer = torch.optim.Adam(gen.parameters())
disc_optimizer = torch.optim.Adam(disc.parameters())
l1_loss = nn.L1Loss()
disc_loss = nn.BCEWithLogitsLoss()

# Training

In [9]:
def train_one_epoch(disc, gen, loader, d_loss, l1_fn, lamb, g_optimizer, d_optimizer):
    mean_d = Mean().to(device)
    mean_g = Mean().to(device)
    with tqdm(loader, unit='batch') as tep:
        for x, y, cond in tep:
            x, y, cond = x.to(device), y.to(device), cond.to(device)
            valid = torch.ones([bs, 1]).to(device)
            fake = torch.zeros([bs, 1]).to(device)
            for i in range(5):
                loss_disc_real = d_loss(disc(x, y, cond), valid)
                gen_outputs = gen(x, cond)

                loss_disc_fake = d_loss(disc(x, gen_outputs, cond), fake)

                disc_loss = loss_disc_fake + loss_disc_real
                disc_loss.backward()
                d_optimizer.step()
                d_optimizer.zero_grad()

            g_optimizer.zero_grad()

            gen_outputs = gen(x, cond)

            gen_loss = d_loss(disc(gen_outputs, y, cond), valid)
            l1_loss = l1_fn(gen_outputs, y)

            total_gen_loss = gen_loss + (lamb * l1_loss)

            total_gen_loss.backward()
            clip_grad_norm_(gen.parameters(), 0.5)
            g_optimizer.step()
            mean_d.update(disc_loss)
            mean_g.update(total_gen_loss)
            
            tep.set_postfix(loss_gan=mean_g.compute().item(), loss_disc=mean_d.compute().item())
        return disc, gen, mean_g.compute().item(), mean_d.compute().item()

In [11]:
for i in range(1000):
        disc, gen, loss_g, loss_d = train_one_epoch(disc, 
                                                    gen, 
                                                    train_loader, 
                                                    disc_loss,
                                                    l1_loss, 
                                                    lamb, 
                                                    gen_optimizer, 
                                                    disc_optimizer)
        print()

100%|██████| 20/20 [00:07<00:00,  2.59batch/s, loss_disc=1.19e-5, loss_gan=4.15]





100%|██████| 20/20 [00:07<00:00,  2.83batch/s, loss_disc=5.72e-8, loss_gan=3.64]





100%|██████| 20/20 [00:07<00:00,  2.85batch/s, loss_disc=2.34e-7, loss_gan=3.49]





100%|██████| 20/20 [00:06<00:00,  2.87batch/s, loss_disc=3.93e-8, loss_gan=3.47]





 95%|█████▋| 19/20 [00:07<00:00,  2.68batch/s, loss_disc=1.25e-9, loss_gan=3.46]


KeyboardInterrupt: 

In [None]:
def visualize(x, y):
    fig = plt.figure(figsize=(15, 4))
    
    plt.subplot(1, 3, 1)
    plt.title('generated Image')
    plt.imshow(x.permute(1, 2, 0))
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.title('real Image')

    plt.imshow(y.permute(1, 2, 0))
    plt.axis("off")
