# Code for training the Pix2Pix GAN

In [None]:
import torch
# CUDA sanity checks. None of the return values are stored, so when this is
# run as a notebook cell only the value of the last expression is displayed.
# Index of the currently selected CUDA device.
torch.cuda.current_device()
# Builds a device context-manager object for GPU 0; it is not entered with
# `with` here, so presumably it has no lasting effect -- TODO confirm.
torch.cuda.device(0)
# Number of CUDA devices visible to PyTorch.
torch.cuda.device_count()
# Human-readable name of GPU 0.
torch.cuda.get_device_name(0)
# Whether CUDA can be used at all (the value echoed by the cell).
torch.cuda.is_available()

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np
import scipy.io as sio
from torch.utils.data import DataLoader, TensorDataset
import torchvision.models as models
import cv2
from numpy import inf
import torch.nn.functional as F
from torchsummary import summary
import h5py
import time


In [None]:
# Load the training data and print its shape.
#
# The .mat file is read through h5py (so it is presumably stored in the
# HDF5-based MATLAB v7.3 format -- TODO confirm). Both datasets are copied
# into in-memory NumPy arrays inside the `with` block, which lets the file
# handle be closed right after the read instead of staying open for the rest
# of the session.
with h5py.File('ALL_TRAIN_1000.mat', 'r') as mat:
    data = np.array(mat['all_multi_all'])
    target = np.array(mat['all_envdb_norm_all'])

# Insert a singleton axis at position 1 of the label array; the training loop
# later concatenates the labels with a one-channel tensor along dim 1.
target = np.expand_dims(target, axis=1)

print('RAW data: ', data.shape)
print('Label images: ', target.shape)


In [None]:
# Training hyper-parameters.
EPOCH, BATCH_SIZE = 200, 8
LR = 1e-4  # learning rate shared by both Adam optimisers

# Wrap the NumPy arrays as tensors (torch.from_numpy shares memory with the
# source array) and pair each input with its label image.
torch_data, torch_delayed = (torch.from_numpy(array) for array in (data, target))
train_dataset = TensorDataset(torch_data, torch_delayed)

# Shuffled mini-batch iterator used by the training loop.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [None]:
def avg_real(real):
    """Average a tensor over dim 1 while keeping that dim with length 1.

    Args:
        real: tensor with at least two dimensions.

    Returns:
        A tensor shaped like ``real`` except that dim 1 has length 1; every
        entry is the sum over dim 1 divided by ``real.shape[1]``.
    """
    summed = real.sum(dim=1, keepdim=True)
    return summed / real.size(1)


In [None]:
# Pix2Pix GAN generator implementation

class Generator(nn.Module):
    """U-Net style generator for the Pix2Pix GAN.

    Layout: three encoder stages (64, 128 and 256 channels, each followed by
    2x2 max pooling), a 512-channel bottleneck, and decoder stages that
    concatenate the matching encoder feature map (skip connection) before
    convolving. One shared ``Dropout(p=0.5)`` is applied after every pooling
    step and after every skip concatenation. The last layer is a sigmoid.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels of the generated image.
    """

    @staticmethod
    def _double_conv(in_channels, out_channels, kernel_size, padding):
        """Return the layers of two consecutive Conv2d -> BatchNorm2d -> ReLU units."""
        return [
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
        ]

    def contracting_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        """Encoder stage: two Conv-BN-ReLU units mapping in_channels -> out_channels."""
        return torch.nn.Sequential(
            *self._double_conv(in_channels, out_channels, kernel_size, padding)
        )

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3, padding=1):
        """Decoder stage: two Conv-BN-ReLU units followed by a stride-2 transposed conv.

        The convolutions map in_channels -> mid_channel; the transposed
        convolution maps mid_channel -> out_channels and upsamples.
        """
        return torch.nn.Sequential(
            *self._double_conv(in_channels, mid_channel, kernel_size, padding),
            torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=padding, output_padding=1),
        )

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3, padding=1):
        """Output stage: two Conv-BN-ReLU units, a 3x3 conv to out_channels, then sigmoid.

        The last convolution always uses kernel_size=3 / padding=1, regardless
        of the ``kernel_size`` / ``padding`` arguments.
        """
        return torch.nn.Sequential(
            *self._double_conv(in_channels, mid_channel, kernel_size, padding),
            torch.nn.Conv2d(kernel_size=3, in_channels=mid_channel, out_channels=out_channels, padding=1),
            torch.nn.Sigmoid(),
        )

    def bottle_neck(self, kernel_size=3, padding=1):
        """Bottleneck: 256 -> 512 -> 512 convolutions, then upsample back to 256 channels."""
        # Structurally the same as a decoder stage with fixed channel counts.
        return self.expansive_block(256, 512, 256, kernel_size=kernel_size, padding=padding)

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Encode
        self.dropout = torch.nn.Dropout(p=0.5)
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck
        self.bottleneck = self.bottle_neck()
        # Decode (input channel counts are doubled by the skip concatenation)
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate a skip connection with an upsampled map along dim 1.

        Args:
            upsampled: decoder feature map, indexed as (N, C, H, W).
            bypass: encoder feature map to attach as the skip connection.
            crop: when True, ``bypass`` is centre-cropped (or zero-padded, if
                it is the smaller one) to the spatial size of ``upsampled``.

        Returns:
            ``torch.cat((upsampled, bypass), 1)``.

        Height and width are adjusted independently, and an odd size
        difference puts the extra row/column on the bottom/right edge, so the
        two maps always end up with matching spatial sizes. When both axes
        differ by the same even amount the result equals the previous
        symmetric crop.
        """
        if crop:
            diff_h = bypass.size(2) - upsampled.size(2)
            diff_w = bypass.size(3) - upsampled.size(3)
            top, left = diff_h // 2, diff_w // 2
            bottom, right = diff_h - top, diff_w - left
            # Negative padding removes rows/columns; F.pad orders the last
            # two dims as (left, right, top, bottom).
            bypass = F.pad(bypass, (-left, -right, -top, -bottom))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the sigmoid output."""
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.dropout(self.conv_maxpool1(encode_block1))
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.dropout(self.conv_maxpool2(encode_block2))
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.dropout(self.conv_maxpool3(encode_block3))
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode: attach the encoder map of the same resolution at each level
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        decode_block3 = self.dropout(decode_block3)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        decode_block2 = self.dropout(decode_block2)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        decode_block1 = self.dropout(decode_block1)
        return self.final_layer(decode_block1)
    
# Generator (128 input channels -> 1 output channel) on the GPU, together with
# its pixel-wise L1 loss, Adam optimiser and step-wise learning-rate decay
# (multiplied by 0.3 every 50 scheduler steps).
model_g = Generator(in_channel=128, out_channel=1)
model_g = model_g.cuda()
criterion_g = nn.L1Loss().cuda()
optimizer_g = torch.optim.Adam(
    model_g.parameters(),
    lr=LR,
    betas=(0.5, 0.999),
    weight_decay=1e-5,
)
scheduler_g = torch.optim.lr_scheduler.StepLR(optimizer_g, step_size=50, gamma=0.3)


In [None]:
# Pix2Pix GAN [image] discriminator implementation

class Discriminator(nn.Module):
    """Image-level discriminator operating on 2-channel inputs.

    A stack of seven 4x4 convolutions (padding 1): six strided ones that
    shrink the feature map, then a stride-1 convolution down to one channel,
    followed by a sigmoid. For a 2x128x256 input (the size passed to
    ``summary`` in this file) the output collapses to one value per sample.
    """

    def __init__(self):
        super().__init__()
        negative_slope = 0.2
        # (in_channels, out_channels, stride) of the batch-normalised stages.
        stages = (
            (16, 64, (2, 2)),
            (64, 128, (2, 2)),
            (128, 256, (2, 4)),
            (256, 512, (2, 2)),
            (512, 512, (2, 2)),
        )

        # The first stage has no batch normalisation.
        layers = [
            nn.Conv2d(2, 16, 4, stride=(2, 2), padding=1),
            nn.LeakyReLU(negative_slope, inplace=True),
        ]
        for c_in, c_out, stride in stages:
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(negative_slope, inplace=True))
        # Final projection to a single-channel score map squashed by a sigmoid.
        layers.append(nn.Conv2d(512, 1, 4, stride=1, padding=1))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score map produced for ``x``."""
        return self.main(x)

# Image discriminator on the GPU with its binary cross-entropy loss, Adam
# optimiser and step-wise learning-rate decay (x0.3 every 50 scheduler steps).
model_d = Discriminator()
model_d = model_d.cuda()
criterion_d = nn.BCELoss().cuda()
optimizer_d = torch.optim.Adam(
    model_d.parameters(),
    lr=LR,
    betas=(0.5, 0.999),
    weight_decay=1e-5,
)
scheduler_d = torch.optim.lr_scheduler.StepLR(optimizer_d, step_size=50, gamma=0.3)

# Layer-by-layer summary of the image discriminator for a 2x128x256 input.
summary_d = summary(model_d, input_data=(2, 128, 256))


In [None]:
# Pix2Pix GAN [patch] discriminator implementation

class Patch_Discriminator(nn.Module):
    """Patch-level discriminator operating on 2-channel inputs.

    Seven 4x4 convolutions (padding 1) with a mix of stride-1 and stride-2
    steps, ending in a one-channel sigmoid score map. For a 2x128x256 input
    (the size passed to ``summary`` in this file) the map is 13x13, i.e. one
    score per patch rather than one per image.
    """

    def __init__(self):
        super().__init__()
        negative_slope = 0.2
        # (in_channels, out_channels, stride) of the batch-normalised stages.
        stages = (
            (16, 64, (1, 2)),
            (64, 128, (2, 2)),
            (128, 256, (1, 1)),
            (256, 512, (2, 2)),
            (512, 512, (1, 1)),
        )

        # The first stage has no batch normalisation.
        layers = [
            nn.Conv2d(2, 16, 4, stride=(2, 2), padding=1),
            nn.LeakyReLU(negative_slope, inplace=True),
        ]
        for c_in, c_out, stride in stages:
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(negative_slope, inplace=True))
        # Final projection to a single-channel score map squashed by a sigmoid.
        layers.append(nn.Conv2d(512, 1, 4, stride=(1, 1), padding=1))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-patch sigmoid score map produced for ``x``."""
        return self.main(x)

# Patch discriminator on the GPU with its binary cross-entropy loss, Adam
# optimiser and step-wise learning-rate decay (x0.3 every 50 scheduler steps).
# These assignments rebind model_d / criterion_d / optimizer_d / scheduler_d.
model_d = Patch_Discriminator()
model_d = model_d.cuda()
criterion_d = nn.BCELoss().cuda()
optimizer_d = torch.optim.Adam(
    model_d.parameters(),
    lr=LR,
    betas=(0.5, 0.999),
    weight_decay=1e-5,
)
scheduler_d = torch.optim.lr_scheduler.StepLR(optimizer_d, step_size=50, gamma=0.3)

# Layer-by-layer summary of the patch discriminator for a 2x128x256 input.
summary_d = summary(model_d, input_data=(2, 128, 256))


In [None]:
# Train the Pix2Pix GAN model

# Alternating GAN training: for every mini-batch the discriminator is updated
# on a real pair and a generated pair, then the generator is updated with an
# adversarial term plus a weighted L1 term. Both LR schedulers step once per
# epoch. `num` is the number of complete training runs; per-epoch losses and
# the wall-clock time of each run are recorded column-wise.
num = 1
loss_d_count = np.zeros((EPOCH, num))
loss_g_count = np.zeros((EPOCH, num))
time_count = np.zeros((1, num))

for i in range(num):

    time_start = time.time()
    real_label = 1
    fake_label = 0
    # Lists to keep track of progress (not written to inside this loop)
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    for epoch in range(EPOCH):
        for step, (a, b) in enumerate(train_loader):

            # Move the raw input to the GPU once and reuse it for both the
            # channel-averaged conditioning image and the generator input.
            raw_a = a.float().cuda()
            real_a = avg_real(raw_a)
            real_b = b.float().cuda()
            fake_b = model_g(raw_a)

            #################################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ################################
            model_d.zero_grad()

            real_cat = torch.cat((real_a, real_b), 1)
            # Forward pass real batch through D
            output = model_d(real_cat).view(-1)
            # Build the target from D's own output so it always has the right
            # length, dtype and device. This removes the hard-coded 13*13
            # patch count and the integer-typed label that BCELoss rejects.
            label = torch.full_like(output, real_label)
            # Calculate loss on all-real batch
            loss_d_real = criterion_d(output, label)
            D_x = output.mean().item()

            fake_cat = torch.cat((real_a, fake_b), 1)
            # Classify all fake batch with D; detach() keeps this pass from
            # back-propagating into the generator.
            output = model_d(fake_cat.detach()).view(-1)
            label = torch.full_like(output, fake_label)
            # Calculate D's loss on the all-fake batch
            loss_d_fake = criterion_d(output, label)
            D_G_z1 = output.mean().item()

            # Average the all-real and all-fake losses
            loss_d = (loss_d_real + loss_d_fake) * 0.5
            # Update D
            loss_d.backward()
            optimizer_d.step()

            #################################
            # (2) Update G network: maximize log(D(G(z)))
            ################################
            model_g.zero_grad()
            output = model_d(fake_cat).view(-1)
            # The generator wants D to label its output as real.
            label = torch.full_like(output, real_label)

            # Calculate G's loss based on this output
            loss_gd = criterion_d(output, label)
            # L1 is taken over the concatenated pairs; their first channel
            # (real_a) is identical, so only the generated channel contributes
            # a non-zero difference.
            loss_gg = criterion_g(fake_cat, real_cat)
            loss_g = loss_gd + 100 * loss_gg

            # Calculate gradients for G
            loss_g.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizer_g.step()

        scheduler_d.step()
        scheduler_g.step()
        # Only the losses of the last mini-batch of the epoch are recorded.
        loss_d_count[epoch, i] = loss_d.item()
        loss_g_count[epoch, i] = loss_g.item()

        print('Epoch: [%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\t'
                      % (epoch, EPOCH, loss_d.item(), loss_g.item()))

    time_end = time.time()
    time_count[0, i] = time_end - time_start
    print('time cost', time_count[0, i], 's')

    # One point per epoch for each network.
    plt.figure(figsize=(6, 3))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(loss_g_count[:, i], label="G")
    plt.plot(loss_d_count[:, i], label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.grid()
    plt.legend()
    plt.show()
    

In [None]:
# Persist only the trained generator's parameters (its state_dict), not the
# whole module object.
generator_state = model_g.state_dict()
torch.save(generator_state, 'Pix2PixGAN_Patch_G_dict.pth')
