# Coherent Semantic Attention (CSA) for Image Inpainting – PyTorch Implementation

Student: Klaudia Palak

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

# Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive so that the project code, the dataset archives and the
# checkpoints directory used by later cells (all under /content/drive/...) are reachable.
# NOTE(review): later cells mix '/content/drive/My Drive/...' and '/content/drive/MyDrive/...' —
# confirm both spellings resolve in the current Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import sys

# Make the project's own packages (e.g. `utils`, `models`, imported in later cells) importable.
# The membership check keeps the cell idempotent: re-running it does not pile up duplicate
# entries at the front of sys.path.
CSA_CODE_DIR = "/content/drive/My Drive/Magisterka/Coherent_Semantic_Attention_code"
if CSA_CODE_DIR not in sys.path:
    sys.path.insert(0, CSA_CODE_DIR)

In [None]:
from IPython.display import clear_output 
# Unzip file - if you use Google GPU
!unrar e "/content/drive/My Drive/Magisterka/Coherent_Semantic_Attention_code/mask_dataset.rar" /content/mask/
!unzip "/content/drive/My Drive/Magisterka/Coherent_Semantic_Attention_code/paris_train_original.zip" -d /content/train_dataset/
clear_output()

# Libraries

In [None]:
import time
import torch
from utils.data_load import DataLoad
import os
import torchvision
from torch.utils import data
from torchvision.utils import save_image, make_grid
import torchvision.transforms as transforms

import pandas as pd

import plotly
import plotly.express as px
import plotly.graph_objects as go

# Settings

In [None]:
class Opion():
    """Training configuration for the CSA inpainting model (stands in for a CLI options parser).

    To resume training from saved checkpoints instead of starting fresh, set:
        which_epoch    -- label of the last completed epoch whose models are in `checkpoints_dir` (e.g. '10')
        continue_train -- True
        epoch_count    -- the next epoch to train (e.g. 11)
    """

    def __init__(self):

        self.dataroot = r'/content/train_dataset/paris_train_original'  # image dataroot
        self.maskroot = r'/content/mask'  # mask dataroot
        self.batchSize = 1  # Must be 1: the training loop keeps only mask[0][0] of each batch
        self.fineSize = 256  # image size
        self.input_nc = 3  # input channel size for first stage
        self.input_nc_g = 6  # input channel size for second stage
        self.output_nc = 3  # output channel size
        self.ngf = 64  # inner channel
        self.ndf = 64  # inner channel
        self.which_model_netD = 'basic'  # patch discriminator
        self.which_model_netF = 'feature'  # feature patch discriminator
        self.which_model_netG = 'unet_csa'  # second stage network
        self.which_model_netP = 'unet_256'  # first stage network
        self.triple_weight = 1
        self.name = 'CSA_inpainting'
        self.n_layers_D = '3'  # network depth; NOTE(review): stored as a string, confirm how the model code consumes it
        self.gpu_ids = [0]
        self.model = 'csa_net'
        self.checkpoints_dir = r'/content/drive/MyDrive/Magisterka/Coherent_Semantic_Attention_code/checkpoints'  # checkpoints folder
        self.norm = 'instance'
        self.fixed_mask = 1
        self.use_dropout = False
        self.init_type = 'normal'
        self.mask_type = 'random'  # or 'center'
        self.lambda_A = 100
        self.threshold = 5 / 16.0
        self.stride = 1
        self.shift_sz = 1  # size of feature patch
        self.mask_thred = 1
        self.bottleneck = 512
        self.gp_lambda = 10.0
        self.ncritic = 5
        self.constrain = 'MSE'
        self.strength = 1
        self.init_gain = 0.02
        self.cosis = 1
        self.gan_type = 'lsgan'
        self.gan_weight = 0.2  # weight of the GAN loss in the total generator loss
        self.ssim_weight = 100  # weight of the SSIM loss in the total generator loss
        self.lorentzian_weight = 10  # weight of the Lorentzian loss in the total generator loss
        self.overlap = 4
        self.skip = 0
        self.display_freq = 100  # save an input/ground-truth/output snapshot every N training examples
        self.print_freq = 2  # print the current losses every N training examples
        self.save_latest_freq = 5
        self.save_epoch_freq = 1  # save model checkpoints every N epochs
        # Fresh training by default. With which_epoch == '' there is no checkpoint label to load,
        # so continue_train must stay False unless resuming (see the class docstring).
        self.continue_train = False
        self.epoch_count = 1  # first epoch number of this run
        self.phase = 'train'
        self.which_epoch = ''  # checkpoint label to load when continue_train is True
        self.niter = 20  # epochs at the initial learning rate
        self.niter_decay = 100  # additional epochs during which the learning rate decays
        self.beta1 = 0.5
        self.lr = 0.0002
        self.lr_policy = 'lambda'
        self.lr_decay_iters = 50
        self.isTrain = True
        self.ssim_loss = True  # True / False: whether to additionally use the SSIM loss
        self.lorentzian_loss = False  # True / False: whether to additionally use the Lorentzian loss
        self.l1_weight = 0.1  # weight of the L1 term; the SSIM term uses (1 - l1_weight), so SSIM and L1 weights sum to 1

        # Parameters for continuing training of an already trained network:
        # self.which_epoch = '10'  # label of the last trained epoch whose models are in the checkpoints folder
        # self.continue_train = True
        # self.epoch_count = 11  # number of the next epoch from which training should continue

# Dataset

In [None]:
opt = Opion()
image_size = (opt.fineSize, opt.fineSize)

# Masks are only resized and turned into tensors (no augmentation, no normalization).
transform_mask = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

# Images get a random horizontal flip for augmentation, are resized, and are normalized with
# mean = std = 0.5 per channel, mapping ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

dataset_train = DataLoad(opt.dataroot, opt.maskroot, transform, transform_mask)
iterator_train = data.DataLoader(dataset_train, batch_size=opt.batchSize, shuffle=True)
print(len(dataset_train))

# Model

In [None]:
from models.model import create_model

In [None]:
# Instantiate the inpainting model from the options above.
# NOTE(review): create_model presumably dispatches on opt.model ('csa_net') and, when
# opt.continue_train is True, loads checkpoints for opt.which_epoch — confirm in models/model.py.
model = create_model(opt)

# Training

In [None]:
# Training examples seen across all epochs; drives the display / print frequencies.
# (Previously misspelled `stotal_steps`, which made `total_steps += ...` below raise NameError.)
total_steps = 0

# Dictionary for evaluation metrics: one value is appended per training step.
evaluation_metrics = {"loss_G_GAN": [], "loss_G_L1": [], "loss_G_SSIM": [], "loss_G_Lorentzian": [],
                      "loss_D": [], "loss_D_Lorentzian": [], "loss_F": [], "loss_D_overall": [], "loss_G_overall": []}

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):

    epoch_start_time = time.time()
    epoch_iter = 0  # training examples seen in the current epoch

    for image, mask in iterator_train:
        image = image.cuda()
        # Keep only the first sample's first mask channel and restore it to a 4-D (1, 1, H, W)
        # byte tensor (assuming the loader yields masks as (B, C, H, W) — TODO confirm).
        # Because only mask[0][0] is kept, opt.batchSize must stay at 1.
        mask = mask.cuda()
        mask = mask[0][0]
        mask = torch.unsqueeze(mask, 0)
        mask = torch.unsqueeze(mask, 1)
        mask = mask.byte()

        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        model.set_input(image, mask)  # it not only sets the input data with mask, but also sets the latent mask.
        model.set_gt_latent()
        model.optimize_parameters()

        if total_steps % opt.display_freq == 0:
            # real_A = input, real_B = ground truth, fake_B = output.
            real_A, real_B, fake_B = model.get_current_visuals()
            # Shift from the normalized [-1, 1] range back to [0, 1] for saving.
            pic = (torch.cat([real_A, real_B, fake_B], dim=0) + 1) / 2.0
            # "<k>of<N>" reports progress within the current epoch; the global step counter
            # exceeds len(dataset_train) after the first epoch, so it is not used here.
            torchvision.utils.save_image(pic, '%s/Epoch_(%d)_(%dof%d).jpg' % (
                opt.checkpoints_dir, epoch, epoch_iter, len(dataset_train)), nrow=2)

        # Record every step so each loss curve has one point per training example;
        # only the console output is throttled by opt.print_freq.
        errors = model.get_current_errors()
        evaluation_metrics["loss_G_GAN"].append(errors["G_GAN"])
        evaluation_metrics["loss_G_L1"].append(errors["G_L1"])
        evaluation_metrics["loss_D"].append(errors["D"])
        if opt.lorentzian_loss:
            evaluation_metrics["loss_G_Lorentzian"].append(errors["G_Lorentzian"])
            evaluation_metrics["loss_D_Lorentzian"].append(errors["D_Lorentzian"])
        if opt.ssim_loss:
            evaluation_metrics["loss_G_SSIM"].append(errors["G_SSIM"])
        # Overall losses for F, D and G.
        evaluation_metrics["loss_F"].append(errors["F"])
        evaluation_metrics["loss_D_overall"].append(errors["D_overall"])
        evaluation_metrics["loss_G_overall"].append(errors["G_overall"])
        if total_steps % opt.print_freq == 0:
            print(errors)

    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

    model.update_learning_rate()

# Visualization of training losses

In [None]:
# Check mean values for evaluation metrics.
def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN when the list is empty.

    Returning NaN instead of dividing by zero keeps this cell usable when the training
    loop recorded nothing (e.g. it was interrupted before the first step).
    """
    return sum(values) / len(values) if values else float('nan')


loss_G_GAN = _mean(evaluation_metrics["loss_G_GAN"])
print("Mean value for G GAN loss is: {}".format(loss_G_GAN))
loss_G_L1 = _mean(evaluation_metrics["loss_G_L1"])
print("Mean value for G L1 loss is: {}".format(loss_G_L1))
loss_D = _mean(evaluation_metrics["loss_D"])
print("Mean value for D loss is: {}".format(loss_D))

# Optional losses are only reported when they were enabled (and therefore recorded).
if opt.lorentzian_loss:
    loss_G_Lorentzian = _mean(evaluation_metrics["loss_G_Lorentzian"])
    print("Mean value for G Lorentzian loss is: {}".format(loss_G_Lorentzian))
    loss_D_Lorentzian = _mean(evaluation_metrics["loss_D_Lorentzian"])
    print("Mean value for D Lorentzian loss is: {}".format(loss_D_Lorentzian))
if opt.ssim_loss:
    loss_G_SSIM = _mean(evaluation_metrics["loss_G_SSIM"])
    print("Mean value for G SSIM loss is: {}".format(loss_G_SSIM))

loss_F = _mean(evaluation_metrics["loss_F"])
print("Mean value for F overall loss is: {}".format(loss_F))
loss_D_overall = _mean(evaluation_metrics["loss_D_overall"])
print("Mean value for D overall loss is: {}".format(loss_D_overall))
loss_G_overall = _mean(evaluation_metrics["loss_G_overall"])
print("Mean value for G overall loss is: {}".format(loss_G_overall))

In [None]:
def save_path(save_dir, loss_name):
    """Return the PNG file path for a loss plot, creating ``save_dir`` if necessary.

    Args:
        save_dir: Directory the plot will be written to; created (with parents) if missing.
            An empty string means the current working directory.
        loss_name: Plot name; the file is called ``<loss_name>_plot.png``.

    Returns:
        ``os.path.join(save_dir, '<loss_name>_plot.png')``.
    """
    # exist_ok replaces the separate existence check, which could race with another creator.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Explicit 1-tuple so a non-string loss_name is never unpacked as format arguments.
    save_filename = '%s_plot.png' % (loss_name,)
    return os.path.join(save_dir, save_filename)

In [None]:
def loss_plot(evaluation_metrics_list, loss_name):
    """Display an interactive Plotly line chart of one recorded loss curve.

    Args:
        evaluation_metrics_list: Loss values in recording order (one per training step).
        loss_name: Used as the trace name, the figure title and the y-axis label.

    Returns None on purpose: the figure is shown via ``fig.show()``. Returning it would make
    Jupyter display it a second time when this is the last call in a cell.
    """
    # 1-based x positions, one per recorded value.
    counter = list(range(1, len(evaluation_metrics_list) + 1))
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=counter, y=evaluation_metrics_list, mode='lines', name=loss_name))

    fig.update_layout(
        width=1000,
        height=500,
        title=loss_name,
        xaxis_title="Number of training examples seen",
        yaxis_title=loss_name,
    )
    fig.show()
    # To also save a static copy (plotly's write_image needs the `kaleido` package), use e.g.:
    # fig.write_image(save_path(opt.checkpoints_dir, loss_name))

In [None]:
# (metric key, plot title) pairs, plotted in this order; optional losses only when enabled.
loss_plot_specs = [
    ("loss_G_GAN", "Generator Adversarial Loss"),
    ("loss_G_L1", "Generator L1 Loss"),
    ("loss_D", "Patch Discriminator Adversarial Loss"),
]
if opt.lorentzian_loss:
    loss_plot_specs += [
        ("loss_G_Lorentzian", "Generator Lorentzian Loss"),
        ("loss_D_Lorentzian", "Patch Discriminator Lorentzian Loss"),
    ]
if opt.ssim_loss:
    loss_plot_specs.append(("loss_G_SSIM", "Generator SSIM Loss"))
loss_plot_specs += [
    ("loss_F", "Feature Discriminator Overall Loss"),
    ("loss_D_overall", "Patch Discriminator Overall Loss"),
    ("loss_G_overall", "Generator Overall Loss"),
]

for metric_key, plot_title in loss_plot_specs:
    loss_plot(evaluation_metrics[metric_key], plot_title)