In [None]:
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import functools
import torchvision.models as models
import torch.nn.functional as F
#from torchsummary import summary
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from skimage import io, transform
from torch.autograd import Variable
from torchvision.utils import save_image


In [None]:
import os

# Limit CUDA to device index 0. This only takes effect if it is set before
# CUDA is first initialised in this kernel.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})
# Image dimensions; not referenced by the other cells visible in this notebook.
IMG_WIDTH, IMG_HEIGHT = 256, 256


In [None]:
from transformations import *

# Training data. Train_Normalize is presumably provided by the star import
# above; the training loop unpacks every sample as an (input, target) pair.
DIR = 'datasets/dotted_noise_sketches_13_003/combined/'
n_gpus = 1
batch_size = n_gpus * 4  # four samples per GPU

train_transform = transforms.Compose([Train_Normalize()])
train_ds = ImageFolder(DIR, transform=train_transform)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

In [None]:
# Validation data. Note that DIR and batch_size rebind the names the training
# data cell used, so re-run that cell before rebuilding train_dl.
DIR = 'datasets/dotted_noise_sketches_13_003/combined1/'

batch_size = n_gpus * 20  # twenty samples per GPU

val_transform = transforms.Compose([Val_Normalize()])
val_ds = ImageFolder(DIR, transform=val_transform)
val_dl = DataLoader(val_ds, batch_size=batch_size)

In [None]:
def imshow(inputs, target, figsize=(10, 5)):
    """Plot the first input image of a batch beside its ground truth.

    Both batches are cast to uint8 and only element 0 of each is drawn. The
    leading axis of that element is moved to the end before plotting
    (presumably channels-first to channels-last -- TODO confirm).

    Args:
        inputs: batch of input images.
        target: batch of ground-truth images.
        figsize: size passed to ``plt.figure``.
    """
    panels = (
        ('Input Image', np.uint8(inputs)[0]),
        ('Ground Truth', np.uint8(target)[0]),
    )
    plt.figure(figsize=figsize)

    for slot, (caption, first_image) in enumerate(panels, start=1):
        # The grid has three columns although only two panels are drawn.
        plt.subplot(1, 3, slot)
        plt.title(caption)
        plt.axis('off')
        plt.imshow(np.rollaxis(first_image, 0, 3))
    plt.axis('off')
 
    #plt.imshow(image)    

def show_batch(dl):
    """Preview the first image pair of at most three batches from ``dl``.

    ``dl`` must yield ``((images_a, images_b), _)`` items; each pair is passed
    to ``imshow``.
    """
    for shown, ((images_a, images_b), _) in enumerate(dl, start=1):
        imshow(images_a, images_b)
        if shown == 3:
            break
#show_batch(val_dl)

In [None]:
def weights_init(net, init_type='normal', scaling=0.02):
    """Initialise the conv and batch-norm parameters of ``net`` in place.

    Every sub-module of ``net`` (and ``net`` itself) is visited through
    ``net.apply``:

    * modules whose class name contains ``'Conv'`` get weights drawn from
      N(0, scaling);
    * modules whose class name contains ``'BatchNorm2d'`` get weights drawn
      from N(1, scaling) and a zero bias.

    Parameters that are ``None`` (e.g. batch norm built with ``affine=False``)
    are skipped instead of raising.

    Args:
        net: the ``nn.Module`` to initialise.
        init_type: label used only in the printed message; the distribution
            is always normal.
        scaling: standard deviation of the normal distribution.
    """

    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and 'Conv' in classname:
            torch.nn.init.normal_(weight.data, 0.0, scaling)
        elif 'BatchNorm2d' in classname:
            if weight is not None:
                torch.nn.init.normal_(weight.data, 1.0, scaling)
            if bias is not None:
                torch.nn.init.constant_(bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
def get_norm_layer(norm_type='batch'):
    """Return a constructor for the normalisation layer used by the networks.

    Args:
        norm_type: ``'batch'`` (default) for ``nn.BatchNorm2d`` with learnable
            affine parameters and running statistics, or ``'instance'`` for
            ``nn.InstanceNorm2d`` without either.

    Returns:
        A ``functools.partial`` that builds the layer when called with the
        number of features.

    Raises:
        NotImplementedError: if ``norm_type`` is not a supported name.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

In [None]:
# Run on CUDA when it is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
# Last expression of the cell: shows how many CUDA devices are visible.
torch.cuda.device_count()


In [None]:
from generator import *

# Build the generator. The positional arguments are presumably
# (input channels, output channels, base filter count) -- confirm in generator.py.
norm_layer = get_norm_layer()
generator = UnetGenerator(3, 3, 64, norm_layer=norm_layer, use_dropout=False)
# weights_init() already visits every sub-module through net.apply(), so it is
# called once on the whole model. Passing it to generator.apply() would
# re-initialise each layer and print the message once per ancestor module.
weights_init(generator)
# Wrap for multi-GPU use and place the model on `device` explicitly, so the
# parameters sit on the same device the training loop moves its batches to.
generator = torch.nn.DataParallel(generator).to(device)  # multi-GPUs

In [None]:
# if continue training
# Restores the generator from the epoch-25 checkpoint; map_location loads the
# stored tensors onto the currently selected device.
gen_checkpoint_path = 'weights_pix2pix/oweights/generator_epoch_25.pth'
gen_state_dict = torch.load(gen_checkpoint_path, map_location=device)
generator.load_state_dict(gen_state_dict)

In [None]:
from discriminator import *

# Build the discriminator. Its first argument (6) matches the channel-wise
# concatenation of input and output images done in the training loop, assuming
# 3-channel images. `norm_layer` comes from the generator cell.
discriminator = Discriminator(6, 64, n_layers=3, norm_layer=norm_layer)
# weights_init() already visits every sub-module through net.apply(), so it is
# called once on the whole model rather than via discriminator.apply().
weights_init(discriminator)
# Wrap for multi-GPU use and place the model on `device` explicitly.
discriminator = torch.nn.DataParallel(discriminator).to(device)

In [None]:
# if continue training
# Restores the discriminator from the epoch-25 checkpoint; map_location loads
# the stored tensors onto the currently selected device.
disc_checkpoint_path = 'weights_pix2pix/oweights/discriminator_epoch_25.pth'
dic_state_dict = torch.load(disc_checkpoint_path, map_location=device)
discriminator.load_state_dict(dic_state_dict)

In [None]:
# Shared loss objects. BCELoss operates on probabilities, so the discriminator
# output is expected to lie in [0, 1] already -- confirm in discriminator.py.
adversarial_loss = nn.BCELoss()
l1_loss = nn.L1Loss()


def generator_loss(generated_image, target_img, G, real_target, l1_lambda=100):
    """Total generator loss: adversarial term plus weighted L1 reconstruction.

    Args:
        generated_image: generator output for the batch.
        target_img: ground-truth images for the batch.
        G: discriminator output for the (input, generated) pair.
        real_target: tensor of "real" labels with the same shape as ``G``.
        l1_lambda: weight of the L1 term (default 100, the value previously
            hard-coded).

    Returns:
        Scalar tensor ``BCE(G, real_target) + l1_lambda * L1(generated, target)``.
    """
    gen_loss = adversarial_loss(G, real_target)
    l1_l = l1_loss(generated_image, target_img)
    return gen_loss + (l1_lambda * l1_l)


def discriminator_loss(output, label):
    """Binary cross-entropy between discriminator ``output`` and ``label``."""
    return adversarial_loss(output, label)
# Both networks use Adam with the same learning rate and betas.
learning_rate = 2e-4
adam_betas = (0.5, 0.999)
G_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=adam_betas)
D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=adam_betas)

# Multiplies the generator learning rate by 0.5 when the value passed to
# g_scheduler.step() stops decreasing (patience of 5 steps).
g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    G_optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    verbose=True,
)


In [None]:
import time

num_epochs = 100
# Checkpoints are written after every epoch; create the folder up front so a
# missing directory cannot fail the save after a full epoch of training.
os.makedirs('weights_pix2pix/oweights', exist_ok=True)

# Per-epoch mean losses (plain floats), consumed by the plotting cell.
D_loss_plot, G_loss_plot = [], []
for epoch in range(1, num_epochs+1):

    start = time.time()

    # Per-batch losses are stored as Python floats (.item()) so the autograd
    # graph of each batch can be freed instead of being kept all epoch.
    D_loss_list, G_loss_list = [], []

    for (input_img, target_img), _ in train_dl:
        input_img = input_img.to(device)
        target_img = target_img.to(device)
        generated_image = generator(input_img)

        # ---- Discriminator step -------------------------------------------
        D_optimizer.zero_grad()
        disc_inp_fake = torch.cat((input_img, generated_image), 1)
        # detach(): the discriminator update must not back-propagate into
        # the generator.
        D_fake = discriminator(disc_inp_fake.detach())

        # Label maps take their shape, dtype and device from the discriminator
        # output rather than assuming a fixed 30x30 patch grid.
        real_target = torch.ones_like(D_fake)
        fake_target = torch.zeros_like(D_fake)

        D_fake_loss = discriminator_loss(D_fake, fake_target)

        disc_inp_real = torch.cat((input_img, target_img), 1)
        D_real = discriminator(disc_inp_real)
        D_real_loss = discriminator_loss(D_real, real_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_total_loss.backward()
        D_optimizer.step()
        D_loss_list.append(D_total_loss.item())

        # ---- Generator step -----------------------------------------------
        G_optimizer.zero_grad()
        fake_gen = torch.cat((input_img, generated_image), 1)
        # Gradients also reach the discriminator parameters here; they are
        # cleared by D_optimizer.zero_grad() at the start of the next batch.
        G = discriminator(fake_gen)
        G_loss = generator_loss(generated_image, target_img, G, real_target)
        G_loss.backward()
        G_optimizer.step()
        G_loss_list.append(G_loss.item())

    # Epoch means (NaN if the loader produced no batches).
    epoch_d_loss = torch.mean(torch.FloatTensor(D_loss_list)).item()
    epoch_g_loss = torch.mean(torch.FloatTensor(G_loss_list)).item()

    # Step the plateau scheduler on the epoch mean rather than on whatever the
    # last batch happened to produce.
    g_scheduler.step(epoch_g_loss)
    end = time.time()
    print('time spent for epoch:{}'.format(end-start))
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            (epoch), num_epochs, epoch_d_loss, epoch_g_loss))

    D_loss_plot.append(epoch_d_loss)
    G_loss_plot.append(epoch_g_loss)

    # NOTE(review): epoch numbering restarts at 1 even when resuming from a
    # checkpoint, so earlier files with the same number are overwritten.
    torch.save(generator.state_dict(), 'weights_pix2pix/oweights/generator_epoch_%d.pth' % (epoch))
    torch.save(discriminator.state_dict(), 'weights_pix2pix/oweights/discriminator_epoch_%d.pth' % (epoch))


In [None]:
import matplotlib.pyplot as plt

# Per-epoch discriminator and generator losses on one set of axes.
loss_fig, loss_ax = plt.subplots(figsize=(10, 7))
loss_ax.plot(D_loss_plot, color='orange', label='d_loss')
loss_ax.plot(G_loss_plot, color='red', label='gloss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
#plt.savefig('outputs_vae/loss.jpg')
plt.show()

In [None]:
def drawing_figure(image, size=(800, 800), threshold=40, min_area=25000):
    """Extract the large contours of ``image`` and rasterise them.

    The image is resized to ``size``, binarised at ``threshold`` and its
    contours are extracted. Contours whose area is below ``min_area`` are
    dropped, and the points of the remaining ones are drawn into a blank
    canvas.

    Args:
        image: single-channel image array accepted by ``cv2.threshold``
            (callers in this notebook pass uint8).
        size: ``(width, height)`` passed to ``cv2.resize``.
        threshold: binarisation threshold passed to ``cv2.threshold``.
        min_area: minimum ``cv2.contourArea`` for a contour to be kept.

    Returns:
        ``(draw_img, contours)``: a float array with the resized image's shape
        holding 255 on every kept contour point and 0 elsewhere, and the list
        of kept contours (copies).
    """
    gray = cv2.resize(image, size)
    _, threshed = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    # Index [-2] picks the contour list whether findContours returns
    # (contours, hierarchy) or (image, contours, hierarchy).
    cnts = cv2.findContours(threshed, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]
    sc = [c.copy() for c in cnts if cv2.contourArea(c) >= min_area]

    draw_img = np.zeros(gray.shape)
    for s in sc:
        # Contour points are stored as (x, y); the canvas is indexed [row, col].
        draw_img[s[:, 0, 1], s[:, 0, 0]] = 255
    return draw_img, sc


In [None]:
# First validation batch; img_pair[0] is the (input, target) pair.
img_pair = next(iter(val_dl))

# Inference only: no_grad() avoids building an autograd graph for the batch.
with torch.no_grad():
    predictions = generator(img_pair[0][0])

# Rescale with (x + 1) / 2, which maps [-1, 1] onto [0, 1] -- this assumes the
# generator output lies in [-1, 1]; confirm in generator.py.
preds = [(img + 1) / 2 for img in predictions]

In [None]:
from torchvision.utils import make_grid

figsize=(20, 10)
plt.figure(figsize=figsize)
# Show the rescaled predictions (`preds`) rather than the raw generator output,
# so this grid uses the same value range as the real-image grids, which are
# also built from (x + 1) / 2 rescaled tensors.
img_grid = make_grid(preds, nrow=10, padding=5,pad_value=1)
plt.imshow(np.transpose(img_grid.detach().cpu().numpy(), (1, 2, 0)),interpolation='nearest',cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# Target images of the validation batch, rescaled with (x + 1) / 2.
real_images = [(img + 1) / 2 for img in img_pair[0][1]]

In [None]:
from torchvision.utils import make_grid

# Grid of the rescaled target images, ten per row.
figsize = (20, 10)
plt.figure(figsize=figsize)
img_grid = make_grid(real_images, nrow=10, padding=10, pad_value=1)
# Move the leading axis of the grid to the end for matplotlib.
grid_hwc = img_grid.detach().cpu().numpy().transpose(1, 2, 0)
plt.imshow(grid_hwc, interpolation='nearest', cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# Input images of the validation batch, rescaled with (x + 1) / 2.
real_noised_images = [(img + 1) / 2 for img in img_pair[0][0]]


In [None]:
from torchvision.utils import make_grid

# Grid of the first ten rescaled input images in a single row.
figsize = (20, 10)
plt.figure(figsize=figsize)
img_grid = make_grid(real_noised_images[:10], nrow=10, padding=10, pad_value=1)
# Move the leading axis of the grid to the end for matplotlib.
grid_hwc = img_grid.detach().cpu().numpy().transpose(1, 2, 0)
plt.imshow(grid_hwc, interpolation='nearest', cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# Overlay the contours found in each target image on the matching input image.
imgs_drawn = []
for i, clean_img in enumerate(real_images[:10]):
    # First channel of the input image, upscaled to the 800x800 canvas that
    # drawing_figure() uses for its contours.
    noise_dotted_for_drawing = cv2.resize(real_noised_images[i][0].numpy(), dsize=(800, 800))
    # First channel of the target image as an 8-bit array for thresholding.
    t = clean_img[0].detach().cpu().numpy() * 255
    t = t.astype('uint8')
    contoured_img, hh = drawing_figure(t)
    im2 = cv2.drawContours(noise_dotted_for_drawing, hh, -1, (255, 255, 0), 8)
    # Add a leading axis so make_grid can treat the result as one image.
    im2 = torch.tensor(im2).unsqueeze(0)
    imgs_drawn.append(im2)

In [None]:
# Grid of the input images with target-image contours drawn on top.
figsize = (20, 10)
plt.figure(figsize=figsize)
img_grid1 = make_grid(imgs_drawn, nrow=10, padding=0)
grid_hwc = img_grid1.detach().cpu().numpy().transpose(1, 2, 0)
plt.imshow(grid_hwc, cmap='gray')
plt.axis('off')

In [None]:
# Overlay the contours found in each prediction on the matching input image.
imgs_drawn = []

for i, pred in enumerate(preds[:10]):
    # First channel of the input image, upscaled to the 800x800 canvas that
    # drawing_figure() uses for its contours.
    noise_dotted_for_drawing = cv2.resize(real_noised_images[i][0].numpy(), dsize=(800,800))
    # Use the (x + 1) / 2 rescaled predictions, as the real-image overlay does,
    # so that * 255 yields 8-bit intensities. Clip before the uint8 cast so
    # out-of-range values cannot wrap around.
    t = np.clip(pred[0].detach().cpu().numpy()*255, 0, 255)
    t = t.astype('uint8')
    contoured_img, hh = drawing_figure(t)
    im2 = cv2.drawContours(noise_dotted_for_drawing, hh, -1, (255, 255, 0), 8)
    # Add a leading axis so make_grid can treat the result as one image.
    im2 = torch.tensor(im2).unsqueeze(0)
    imgs_drawn.append(im2)


In [None]:
# Grid of the input images with predicted contours drawn on top.
figsize = (20, 10)
fig = plt.figure(figsize=figsize)
img_grid1 = make_grid(imgs_drawn, nrow=10, padding=0)
grid_hwc = img_grid1.detach().cpu().numpy().transpose(1, 2, 0)
ax = plt.imshow(grid_hwc, cmap='gray')
plt.axis('off')