In [None]:
# Importing necessary libraries
from os import path
import torch
import torchvision.datasets as dset
from torch.autograd import Variable
import torch.nn as nn
import  torch.optim as optim
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from torchvision import transforms
from PIL import Image
import torchvision.utils as vutils
from IPython.display import clear_output
import datetime
from torch.utils.data import Dataset, DataLoader

In [None]:
# Show the installed PyTorch version (the bare expression is the cell's displayed output)
torch.__version__

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): later cells call `.cuda()` directly instead of using `device`,
# so the notebook still requires a GPU — consider `.to(device)` throughout.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# Root directories for the training and test images (relative to the notebook)
training_dir, test_dir = "./train", "./test"

In [None]:
# Index every image under `training_dir`; the paired dataset below reads the
# (filepath, class_index) entries this exposes through `.imgs`.
folder_dataset = dset.ImageFolder(training_dir)

In [None]:
# Making custom pytorch dataset class
scale = 255  # max uint8 pixel value; pixels are mapped from [0, scale] to [-1, 1]


class get_Dataset(Dataset):
    """Eagerly load (object, target) image pairs from an ImageFolder into memory.

    The first half of ``imageFolderDataset.imgs`` is used as targets and the
    second half as objects: pair ``i`` is ``(imgs[i + len], imgs[i])`` with
    ``len = len(imgs) // 2``. This assumes the two halves line up one-to-one
    (e.g. two equally sized class folders with matching file order) — TODO
    confirm against the data layout.

    Every image is resized to 256x256, converted BGR->RGB, laid out as
    ``(3, W, H)`` (width before height, via ``transpose(2, 1, 0)``) and scaled
    to ``[-1, 1]`` as float32 tensors stored in ``self.object`` / ``self.target``.
    """

    def __init__(self, imageFolderDataset):
        self.imageFolderDataset = imageFolderDataset
        imgs = self.imageFolderDataset.imgs
        self.len = len(imgs) // 2
        # Stage as uint8 (8x smaller than a float64 buffer). uint8 -> float64 is
        # exact, so the normalised tensors are identical to float64 staging.
        objects = np.empty((self.len, 3, 256, 256), dtype=np.uint8)
        targets = np.empty((self.len, 3, 256, 256), dtype=np.uint8)
        for i in range(self.len):
            # Objects come from the second half of the file list. Offsetting by
            # the half-length (instead of a hard-coded count) works for any size.
            objects[i] = self._load_image(imgs[i + self.len][0])
            targets[i] = self._load_image(imgs[i][0])
        # Normalization between -1 to 1
        self.object = torch.from_numpy((objects / (scale / 2)) - 1).float()
        self.target = torch.from_numpy((targets / (scale / 2)) - 1).float()

    @staticmethod
    def _load_image(filepath):
        """Read one image as a uint8 RGB array of shape (3, 256, 256) in (C, W, H) order."""
        bgr = cv2.imread(filepath)
        if bgr is None:
            # cv2.imread signals a missing/unreadable file by returning None
            raise FileNotFoundError("could not read image: %s" % filepath)
        rgb = cv2.cvtColor(cv2.resize(bgr, (256, 256)), cv2.COLOR_BGR2RGB)
        return rgb.transpose(2, 1, 0)

    def __getitem__(self, index):
        """Return the ``index``-th ``(object, target)`` pair of float32 tensors."""
        return self.object[index], self.target[index]

    def __len__(self):
        return self.len

In [None]:
# Materialise every (object, target) pair in memory up front
train_dataset = get_Dataset(folder_dataset)

In [None]:
# Shuffled mini-batches of 32 (object, target) pairs for training
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)

In [None]:
class DownSampleConv(nn.Module):
    """Encoder stage: strided Conv2d, optionally followed by BatchNorm2d and LeakyReLU(0.2).

    Paper details:
    - C64-C128-C256-C512-C512-C512-C512-C512
    - All convolutions are 4×4 spatial filters applied with stride 2
    - Convolutions in the encoder downsample by a factor of 2
    """

    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
        super().__init__()
        self.activation, self.batchnorm = activation, batchnorm

        self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)

        # Optional sub-modules exist only when enabled, so the state_dict keys
        # depend on the flags.
        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        if activation:
            self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Conv, then (if enabled) batch norm, then (if enabled) LeakyReLU."""
        out = self.conv(x)
        out = self.bn(out) if self.batchnorm else out
        return self.act(out) if self.activation else out

In [None]:
class UpSampleConv(nn.Module):
    """Decoder stage: ConvTranspose2d -> [BatchNorm2d] -> [Dropout2d(0.5)] -> [ReLU].

    Each bracketed step is enabled by the matching constructor flag. With the
    defaults (kernel=4, strides=2, padding=1) the spatial size is doubled.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        strides=2,
        padding=1,
        activation=True,
        batchnorm=True,
        dropout=False
    ):
        super().__init__()
        self.activation = activation
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

        # Optional sub-modules are only registered when their flag is set.
        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if activation:
            self.act = nn.ReLU(True)  # in-place ReLU

        if dropout:
            self.drop = nn.Dropout2d(0.5)  # only active in train() mode

    def forward(self, x):
        """Upsample ``x``; the result has ``out_channels`` channels."""
        x = self.deconv(x)
        if self.batchnorm:
            x = self.bn(x)

        if self.dropout:
            x = self.drop(x)
        # Fix: `activation` used to be accepted (and `self.act` built) but the
        # ReLU was never applied, leaving every decoder stage linear. ReLU has
        # no parameters, so existing state_dicts still load; weights trained
        # without it will, however, produce different outputs until retrained.
        if self.activation:
            x = self.act(x)
        return x

In [None]:
# Generator
class G(nn.Module):
    """U-Net-style generator: 8 stride-2 encoder stages, 7 decoder stages with
    skip concatenation, then a transposed conv + tanh back to 3 channels.

    Shape comments assume a (bs, 3, 256, 256) input, which gives a
    (bs, 3, 256, 256) output in [-1, 1].

    Paper details:
    - Encoder: C64-C128-C256-C512-C512-C512-C512-C512
    - All convolutions are 4×4 spatial filters applied with stride 2
    - Convolutions in the encoder downsample by a factor of 2
    - Decoder: CD512-CD1024-CD1024-C1024-C1024-C512 -C256-C128
    """

    def __init__(self):
        super().__init__()

        # encoder/downsample convs
        down_blocks = [
            DownSampleConv(3, 64, batchnorm=False),     # bs x 64 x 128 x 128
            DownSampleConv(64, 128),                    # bs x 128 x 64 x 64
            DownSampleConv(128, 256),                   # bs x 256 x 32 x 32
            DownSampleConv(256, 512),                   # bs x 512 x 16 x 16
            DownSampleConv(512, 512),                   # bs x 512 x 8 x 8
            DownSampleConv(512, 512),                   # bs x 512 x 4 x 4
            DownSampleConv(512, 512),                   # bs x 512 x 2 x 2
            DownSampleConv(512, 512, batchnorm=False),  # bs x 512 x 1 x 1
        ]

        # decoder/upsample convs; input widths include the concatenated skip
        up_blocks = [
            UpSampleConv(512, 512, dropout=True),   # bs x 512 x 2 x 2
            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 4 x 4
            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 8 x 8
            UpSampleConv(1024, 512),                # bs x 512 x 16 x 16
            UpSampleConv(1024, 256),                # bs x 256 x 32 x 32
            UpSampleConv(512, 128),                 # bs x 128 x 64 x 64
            UpSampleConv(256, 64),                  # bs x 64 x 128 x 128
        ]
        self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
        self.final_conv = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

        # Registered last: sub-module order is final_conv, tanh, encoders, decoders.
        self.encoders = nn.ModuleList(down_blocks)
        self.decoders = nn.ModuleList(up_blocks)

    def forward(self, x):
        """Encode, decode with skip connections, and squash the result with tanh."""
        features = []
        for encode in self.encoders:
            x = encode(x)
            features.append(x)

        # Skips are taken deepest-first, excluding the bottleneck output.
        # zip() stops after the six inner decoders, so the first encoder's
        # output (bs x 64 x 128 x 128) is never concatenated.
        for decode, skip in zip(self.decoders[:-1], features[-2::-1]):
            x = torch.cat((decode(x), skip), dim=1)

        x = self.decoders[-1](x)
        return self.tanh(self.final_conv(x))

In [None]:
#weights initializiation
def _weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [None]:
# Creating the two generators on the GPU. Both are warm-started from the SAME
# pretrained checkpoint, so they begin with identical weights.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
netG1 = G().float().cuda()
netG1.load_state_dict(torch.load("./models/gen_2dat_256.pt"))
netG2 = G().float().cuda()
netG2.load_state_dict(torch.load("./models/gen_2dat_256.pt"))
#netG1.apply(_weights_init)
# Clear any output produced so far in this cell
clear_output()
# Last expression: switches netG1 to eval mode and displays its module summary.
# NOTE(review): only netG1 is put in eval() here; netG2 is left as constructed.
netG1.eval()
#netG(Variable(input)).shape

In [None]:
# Discriminator
class D(nn.Module):
    """Conditional patch discriminator over a channel-concatenated (x, y) pair.

    ``x`` and ``y`` together must have 6 channels (e.g. two RGB images). Four
    stride-2 stages shrink the spatial size by 16 (256 -> 16), then a 1x1 conv
    followed by a sigmoid yields one score per patch. The output is already a
    probability in (0, 1), not a logit.
    """

    def __init__(self):
        super().__init__()
        self.d1 = DownSampleConv(6, 64, batchnorm=False)
        self.d2 = DownSampleConv(64, 128)
        self.d3 = DownSampleConv(128, 256)
        self.d4 = DownSampleConv(256, 512)
        self.final = nn.Conv2d(512, 1, kernel_size=1)
        self.sig = nn.Sigmoid()

    def forward(self, x, y):
        """Score the pair; returns a (bs, 1, H/16, W/16) map of probabilities."""
        h = torch.cat([x, y], dim=1)
        for stage in (self.d1, self.d2, self.d3, self.d4, self.final):
            h = stage(h)
        return self.sig(h)

In [None]:
# Creating the discriminator on the GPU and warm-starting it from a pretrained checkpoint.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
netD = D().float().cuda()
netD.load_state_dict(torch.load("./models/dis_2dat_256.pt"))
# netD.apply(_weights_init)
# Clear any output produced so far in this cell
clear_output()
# Last expression: switches netD to eval mode and displays its module summary
netD.eval()
# NOTE(review): stale example — D.forward takes (x, y) as two arguments, not one concatenated tensor.
#netD(torch.cat((Variable(input), Variable(label)) , dim=1))

In [None]:
# D ends in nn.Sigmoid and already returns probabilities, so the adversarial
# criterion must take probabilities. BCEWithLogitsLoss would apply a second
# sigmoid, squashing D's output into (0.5, 0.73) and crippling both GAN losses.
criterion = nn.BCELoss()  # binary cross-entropy between D's probabilities and 0/1 targets
L1loss = nn.L1Loss()      # pixel-wise reconstruction loss for the generators
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))    # optimizer of the discriminator
optimizerG1 = optim.Adam(netG1.parameters(), lr=0.0002, betas=(0.5, 0.999))  # optimizer of the first generator
optimizerG2 = optim.Adam(netG2.parameters(), lr=0.0002, betas=(0.5, 0.999))  # optimizer of the second generator

In [None]:
# Directory that receives the periodic sample images and LOG.txt written during
# training. Existing files are NOT cleared; the directory is created if missing,
# since writing images or the log into a non-existent directory would fail.
sketch_result = "./results9/"
os.makedirs(sketch_result, exist_ok=True)

In [None]:
max_epoch = 200
alpha = 100  # weight of the two L1 reconstruction terms relative to the adversarial term
for epoch in range(max_epoch):  # We iterate over max_epoch epochs.
    netG1.train()
    netG2.train()
    netD.train()

    for i, data in enumerate(train_loader, 0):  # We iterate over the batches of the dataset.
        obj1, label1 = data
        obj = obj1.cuda()
        label = label1.cuda()

        # ---- 1st step: update the discriminator --------------------------------
        netD.zero_grad()

        # Real pair (object, ground-truth target) should be scored as real (1).
        output = netD(obj, label)
        target = torch.ones_like(output)  # one label per patch, shaped like D's output
        errD_color = criterion(output, target) * 0.5
        errD_color.backward()

        # Fake pair (object, G2(G1(object))) should be scored as fake (0).
        fake1 = netG1(obj)
        fake2 = netG2(fake1)
        # detach(): the discriminator update must not backpropagate into the
        # generators. It also leaves the generator graph intact for the
        # generator step below, so retain_graph=True is not needed anywhere.
        output = netD(obj, fake2.detach())
        errD_fake = criterion(output, torch.zeros_like(output)) * 0.5
        errD_fake.backward()
        errD = errD_color + errD_fake  # for logging only
        optimizerD.step()

        # ---- 2nd step: update both generators ----------------------------------
        netG1.zero_grad()
        netG2.zero_grad()
        # NOTE(review): the adversarial term scores fake1 (G1's output), while the
        # discriminator above is trained on fake2 (G2's output). G2 therefore only
        # receives the L1 signal — confirm this is intended.
        output = netD(obj, fake1)
        errG1 = criterion(output, torch.ones_like(output))
        errG1_L1 = L1loss(fake1, label)
        errG2_L1 = L1loss(fake2, label)
        errG = errG1 + alpha * errG1_L1 + alpha * errG2_L1
        errG.backward()
        optimizerG1.step()
        optimizerG2.step()

        log_line = '[%d/%d] [%d/%d] Loss_D: %.20f Loss_G: %.20f' % (
            epoch, max_epoch, i, len(train_loader), errD.item(), errG.item() / alpha)
        print(log_line)

        if i % 70 == 0 and i != 0:
            # Save the current batch: input, both generator outputs and target.
            # NOTE: '%2d' pads single-digit epochs with a space in the file name.
            snapshots = (("object", obj), ("fromGenerator1", fake1),
                         ("fromGenerator2", fake2), ("target", label))
            for suffix, images in snapshots:
                vutils.save_image(images.detach().cpu(),
                                  '%sepoch_%2d_%03d_%s.png' % (sketch_result, epoch, i, suffix),
                                  normalize=True)
            # Append the loss line to the log; `with` closes the file even on error.
            with open(sketch_result + "LOG.txt", "a+") as f:
                f.write(log_line + '\n')
            clear_output()

    # Periodic checkpoints (epochs 75 and 150 with max_epoch = 200), written once
    # the epoch's batches are done. Kept outside the `i % 70` logging branch so
    # they are saved even when an epoch has 70 or fewer batches.
    if epoch % 75 == 0 and epoch != 0:
        torch.save(netD.state_dict(), "./models/dis_s2f_" + str(epoch) + "_final.pt")
        print("Discriminator Saved Successfully")
        torch.save(netG1.state_dict(), "./models/gen1_s2f_" + str(epoch) + "_final.pt")
        print("Generator Saved Successfully")
        torch.save(netG2.state_dict(), "./models/gen2_s2f_" + str(epoch) + "_final.pt")
        print("Generator Saved Successfully")
 

In [None]:
# Saving the final weights of the discriminator and both generators.
# `.cuda()` moves the module to the GPU and returns it; the modules were
# created with `.cuda()` already.
torch.save(netD.cuda().state_dict(), "./models/dis_s2f_final.pt")
print("Discriminator Saved Successfully")
torch.save(netG1.cuda().state_dict(), "./models/gen1_s2f_final.pt")
print("Generator Saved Successfully")
torch.save(netG2.cuda().state_dict(), "./models/gen2_s2f_final.pt")
print("Generator Saved Successfully")

In [None]:
# Fresh generator instances for inference; trained weights are loaded in the next cell
generator_1 = G().float().cuda()
generator_2 = G().float().cuda()

In [None]:
# Load the final trained weights (torch.load unpickles — only load trusted files)
generator_1.load_state_dict(torch.load("./models/gen1_s2f_final.pt"))
# eval(): BatchNorm uses its running statistics and Dropout2d is disabled
generator_1.eval()
# NOTE(review): generator_2 is NOT switched to eval(), so at inference its
# BatchNorm layers use batch statistics (a batch of one image here) and its
# Dropout2d layers stay active, making the output stochastic. The asymmetry
# with generator_1 looks accidental — confirm.
# Last expression: the load_state_dict result is displayed as the cell output.
generator_2.load_state_dict(torch.load("./models/gen2_s2f_final.pt"))

In [None]:
# Sketch to translate. Named `test_image_path` rather than `path` so it does not
# shadow the `from os import path` import at the top of the notebook.
test_image_path = "./test/sketches/0001.jpg"
img = cv2.imread(test_image_path)
if img is None:  # cv2.imread returns None instead of raising on a bad path
    raise FileNotFoundError("could not read image: %s" % test_image_path)
img = cv2.resize(img, (256, 256))  # same 256x256 size used for the training pairs

In [None]:
# OpenCV loads images as BGR; reorder the channels to RGB (as done for the training data)
b,g,r = cv2.split(img)           # get b, g, r
rgb_img = cv2.merge([r,g,b]) 

In [None]:
# Preview the RGB input sketch (matplotlib.pyplot is already imported as `plt`)
plt.imshow(rgb_img)
plt.title("Input sketch")
plt.axis('off')
plt.show()

In [None]:
# (H, W, C) -> (C, W, H): the same transpose(2, 1, 0) layout used for the training data.
# NOTE(review): this rebinds rgb_img, so re-running the cell transposes it back
# to (H, W, C) — run it exactly once after the cell that builds rgb_img.
rgb_img = rgb_img.transpose(2,1,0)
rgb_img.shape

In [None]:
# Scale uint8 pixels to [-1, 1] exactly as in training, then add a batch dimension of 1.
img_test = torch.from_numpy(rgb_img / (255 / 2) - 1).float().unsqueeze(0)
# The generator expects (1, 3, 256, 256); any other shape usually means the
# transpose cell above was skipped or run more than once.
assert tuple(img_test.shape) == (1, 3, 256, 256), tuple(img_test.shape)

In [None]:
# Inference only: no_grad() skips building the autograd graph, saving GPU memory.
with torch.no_grad():
    fake1 = generator_1(img_test.to('cuda'))
    fake2 = generator_2(fake1)  # fake1 is already on the GPU

In [None]:
# Output of the second generator: (batch, 3, 256, 256) for a 256x256 input
fake2.shape

In [None]:
# First (only) image of the batch as a NumPy array, still in the (C, W, H) layout
fake_viz = fake2[0].cpu().detach().numpy()
fake_viz.shape

In [None]:
# (C, W, H) -> (H, W, C) for matplotlib.
# NOTE(review): rebinding fake_viz makes this cell non-idempotent — run it once.
fake_viz = fake_viz.transpose(2,1,0)
fake_viz.shape

In [None]:
# The generator ends in tanh, so pixels lie in [-1, 1], but imshow expects float
# RGB in [0, 1] and would clip every negative value. Map to [0, 1] for display only.
plt.imshow(np.clip((fake_viz + 1) / 2, 0, 1))
plt.axis('off')
plt.show()

In [None]:
# Back to a (C, H, W) tensor and save to disk. normalize=True (per the torchvision
# docs) rescales the image by its own min/max before writing.
fake_viz = np.transpose(fake_viz, [2,0,1])
fake_viz = torch.from_numpy(fake_viz)
vutils.save_image(fake_viz , 'testp3.png' , normalize=True)