# Code for Training the CycleGAN

In [None]:
import torch

# Report the CUDA environment. Availability is checked first because the
# device queries below raise on a machine without a usable GPU; the later
# cells call .cuda() and therefore still require one.
if torch.cuda.is_available():
    print('CUDA device count: ', torch.cuda.device_count())
    print('Current device index: ', torch.cuda.current_device())
    print('Device 0 name: ', torch.cuda.get_device_name(0))
else:
    print('CUDA is not available; cells that call .cuda() will fail.')
torch.cuda.is_available()

import random
import time

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision as tv
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from numpy import inf
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
from torchsummary import summary


In [None]:
# Load the training data and print its shape.
# The file is opened in a `with` block so the HDF5 handle is closed as soon as
# both datasets have been copied into in-memory NumPy arrays.

with h5py.File('ALL_TRAIN_1000.mat', 'r') as mat:
    data = np.array(mat['all_multi_all'])
    target = np.array(mat['all_envdb_norm_all'])

# Insert a singleton axis at position 1 of the label array
# (presumably the channel axis expected by the networks - TODO confirm).
target = np.expand_dims(target, axis=1)

print('RAW data: ', data.shape)
print('Label images: ', target.shape)


In [None]:
# Training configuration

EPOCH = 200          # number of passes over train_loader
BATCH_SIZE = 4       # samples per mini-batch
LR = 0.0001          # learning rate handed to every Adam optimizer

# Wrap the two NumPy arrays loaded above as tensors and pair them sample by sample.
torch_data, torch_delayed = (torch.from_numpy(array) for array in (data, target))
train_dataset = TensorDataset(torch_data, torch_delayed)

# Shuffled mini-batch iterator consumed by the training loop.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [None]:
class ReplayBuffer():
    """Fixed-capacity pool of previously generated samples.

    Until the pool holds ``max_size`` samples every incoming sample is stored
    and returned unchanged. Once it is full, each incoming sample is, with
    probability 0.5, swapped with a randomly chosen stored sample (the stored
    one is returned, the new one takes its slot); otherwise the incoming
    sample is returned as-is and the pool is left untouched.

    Relies on the standard-library ``random`` module being imported at the
    top of the notebook.
    """

    def __init__(self, max_size=50):
        """Create an empty buffer holding at most ``max_size`` samples.

        Raises:
            AssertionError: if ``max_size`` is not positive.
        """
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the buffer and return a batch of the same size.

        Args:
            data: tensor whose first dimension indexes the samples of a batch.

        Returns:
            A tensor, detached from the autograd graph, made of one sample per
            input sample: either that sample itself or an older stored one.
        """
        to_return = []
        # detach() replaces the deprecated `.data` access; samples are stored
        # without autograd history.
        for element in data.detach():
            element = torch.unsqueeze(element, 0)   # restore the batch axis
            if len(self.data) < self.max_size:
                # Still filling up: keep the sample and pass it straight through.
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Hand back an old sample and let the new one take its place.
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        # The deprecated Variable wrapper is dropped; torch.cat already returns a tensor.
        return torch.cat(to_return)
    

In [None]:
# CycleGAN generator A implementation

class Generator_A(nn.Module):
    """U-Net-style encoder/decoder generator with skip connections.

    Three conv blocks, each followed by 2x2 max-pooling, form the encoder; a
    bottleneck and two expansive blocks upsample again with stride-2
    transposed convolutions, and a final block ends in a Sigmoid. Encoder
    feature maps are concatenated onto the decoder path at each scale.

    When the input height and width are multiples of 8 the output has the
    same spatial size as the input. Otherwise the encoder feature maps are
    centre-cropped to the (smaller) decoder size before concatenation, so the
    output is smaller than the input.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels of the output tensor.
    """

    def contracting_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        """Return two Conv-BatchNorm-ReLU stages mapping in_channels to out_channels."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3, padding=1):
        """Return two Conv-BatchNorm-ReLU stages followed by a stride-2 transposed convolution."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=padding),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel, padding=padding),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=padding, output_padding=1),
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3, padding=1):
        """Return the output head: two Conv-BatchNorm-ReLU stages, a 3x3 conv and a Sigmoid."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=padding),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel, padding=padding),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=3, in_channels=mid_channel, out_channels=out_channels, padding=1),
            # torch.nn.BatchNorm2d(out_channels),
            torch.nn.Sigmoid(),
        )
        return block

    def bottle_neck(self, kernel_size=3, padding=1):
        """Return the 256 -> 512 -> 512 conv bottleneck ending in a stride-2 transposed conv back to 256 channels."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=256, out_channels=512, padding=padding),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=512, out_channels=512, padding=padding),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=kernel_size, stride=2, padding=padding, output_padding=1),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(Generator_A, self).__init__()
        # Encode
        self.dropout = torch.nn.Dropout(p=0.5)   # shared by every dropout call in forward()
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck
        self.bottleneck = self.bottle_neck()
        # Decode (input channel counts are doubled by the skip concatenations)
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate a skip tensor onto an upsampled tensor along the channel axis.

        Args:
            upsampled: decoder tensor, laid out as (N, C, H, W).
            bypass: encoder tensor of the matching scale.
            crop: when True, resize ``bypass`` to the spatial size of
                ``upsampled`` first. Height and width are handled
                independently, and an odd size difference puts the extra
                row/column on the bottom/right side.

        Returns:
            ``torch.cat((upsampled, bypass), 1)``.
        """
        if crop:
            diff_h = bypass.size(2) - upsampled.size(2)
            diff_w = bypass.size(3) - upsampled.size(3)
            top, left = diff_h // 2, diff_w // 2
            bottom, right = diff_h - top, diff_w - left
            # F.pad orders the last two dims as (left, right, top, bottom);
            # negative amounts crop, positive amounts zero-pad.
            bypass = F.pad(bypass, (-left, -right, -top, -bottom))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; returns the Sigmoid-activated output."""
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_pool1 = self.dropout(encode_pool1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_pool2 = self.dropout(encode_pool2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        encode_pool3 = self.dropout(encode_pool3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        decode_block3 = self.dropout(decode_block3)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        decode_block2 = self.dropout(decode_block2)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        decode_block1 = self.dropout(decode_block1)
        final_layer = self.final_layer(decode_block1)
        return final_layer
    

In [None]:
# CycleGAN generator B implementation

class Generator_B(nn.Module):
    """U-Net-style encoder/decoder generator with skip connections.

    Uses the same layer layout as Generator_A: three conv blocks with 2x2
    max-pooling, a bottleneck, two expansive blocks with stride-2 transposed
    convolutions and a final block ending in a Sigmoid. Encoder feature maps
    are concatenated onto the decoder path at each scale.

    When the input height and width are multiples of 8 the output has the
    same spatial size as the input. Otherwise the encoder feature maps are
    centre-cropped to the (smaller) decoder size before concatenation, so the
    output is smaller than the input.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels of the output tensor.
    """

    def contracting_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        """Return two Conv-BatchNorm-ReLU stages mapping in_channels to out_channels."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3, padding=1):
        """Return two Conv-BatchNorm-ReLU stages followed by a stride-2 transposed convolution."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=padding),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel, padding=padding),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=padding, output_padding=1),
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3, padding=1):
        """Return the output head: two Conv-BatchNorm-ReLU stages, a 3x3 conv and a Sigmoid."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=padding),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel, padding=padding),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=3, in_channels=mid_channel, out_channels=out_channels, padding=1),
            # torch.nn.BatchNorm2d(out_channels),
            torch.nn.Sigmoid(),
        )
        return block

    def bottle_neck(self, kernel_size=3, padding=1):
        """Return the 256 -> 512 -> 512 conv bottleneck ending in a stride-2 transposed conv back to 256 channels."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=256, out_channels=512, padding=padding),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=512, out_channels=512, padding=padding),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=kernel_size, stride=2, padding=padding, output_padding=1),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(Generator_B, self).__init__()
        # Encode
        self.dropout = torch.nn.Dropout(p=0.5)   # shared by every dropout call in forward()
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck
        self.bottleneck = self.bottle_neck()
        # Decode (input channel counts are doubled by the skip concatenations)
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate a skip tensor onto an upsampled tensor along the channel axis.

        Args:
            upsampled: decoder tensor, laid out as (N, C, H, W).
            bypass: encoder tensor of the matching scale.
            crop: when True, resize ``bypass`` to the spatial size of
                ``upsampled`` first. Height and width are handled
                independently, and an odd size difference puts the extra
                row/column on the bottom/right side.

        Returns:
            ``torch.cat((upsampled, bypass), 1)``.
        """
        if crop:
            diff_h = bypass.size(2) - upsampled.size(2)
            diff_w = bypass.size(3) - upsampled.size(3)
            top, left = diff_h // 2, diff_w // 2
            bottom, right = diff_h - top, diff_w - left
            # F.pad orders the last two dims as (left, right, top, bottom);
            # negative amounts crop, positive amounts zero-pad.
            bypass = F.pad(bypass, (-left, -right, -top, -bottom))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; returns the Sigmoid-activated output."""
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_pool1 = self.dropout(encode_pool1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_pool2 = self.dropout(encode_pool2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        encode_pool3 = self.dropout(encode_pool3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        decode_block3 = self.dropout(decode_block3)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        decode_block2 = self.dropout(decode_block2)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        decode_block1 = self.dropout(decode_block1)
        final_layer = self.final_layer(decode_block1)

        return final_layer
    
# Put both generators on the GPU: G_A takes 128 channels to 1, G_B takes 1 channel back to 128.
model_g_a = Generator_A(in_channel=128, out_channel=1).cuda()
model_g_b = Generator_B(in_channel=1, out_channel=128).cuda()

# L1 criterion for the generators.
criterion_g = nn.L1Loss().cuda()

# One Adam optimizer per generator, configured identically.
optimizer_g_a, optimizer_g_b = (
    torch.optim.Adam(net.parameters(), lr=LR, betas=(0.5, 0.999), weight_decay=1e-5)
    for net in (model_g_a, model_g_b)
)

# Step decay: multiply the learning rate by 0.3 every 50 scheduler steps.
scheduler_g_a, scheduler_g_b = (
    torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.3)
    for optimizer in (optimizer_g_a, optimizer_g_b)
)


In [None]:
# CycleGAN [patch] discriminator A&B implementation

class Discriminator_A(nn.Module):
    """Convolutional patch discriminator for single-channel inputs.

    A stack of seven 4x4 convolutions (LeakyReLU(0.2) activations, BatchNorm
    on the five middle ones) ending in a Sigmoid, so the output is a
    one-channel map of scores in (0, 1) rather than a single scalar.
    """

    def __init__(self):
        super(Discriminator_A, self).__init__()

        # (in_channels, out_channels, stride) of the batch-normalised middle convolutions.
        hidden_specs = [
            (16, 64, (1, 2)),
            (64, 128, (2, 2)),
            (128, 256, (1, 1)),
            (256, 512, (2, 2)),
            (512, 512, (1, 1)),
        ]

        # Stem: no normalisation on the first convolution.
        layers = [
            nn.Conv2d(1, 16, 4, stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out, stride in hidden_specs:
            layers += [
                nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Head: collapse to one score channel and squash with a Sigmoid.
        layers += [
            nn.Conv2d(512, 1, 4, stride=(1, 1), padding=1),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map produced by the convolution stack for ``x``."""
        return self.main(x)

    
class Discriminator_B(nn.Module):
    """Convolutional patch discriminator for 128-channel inputs.

    A stack of seven 4x4 convolutions (LeakyReLU(0.2) activations, BatchNorm
    on the five middle ones) ending in a Sigmoid, so the output is a
    one-channel map of scores in (0, 1) rather than a single scalar.
    """

    def __init__(self):
        super(Discriminator_B, self).__init__()

        # (in_channels, out_channels, stride) of the batch-normalised middle convolutions.
        hidden_specs = [
            (64, 128, (1, 2)),
            (128, 256, (2, 2)),
            (256, 256, (1, 1)),
            (256, 512, (2, 2)),
            (512, 512, (1, 1)),
        ]

        # Stem: no normalisation on the first convolution.
        layers = [
            nn.Conv2d(128, 64, 4, stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out, stride in hidden_specs:
            layers += [
                nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Head: collapse to one score channel and squash with a Sigmoid.
        layers += [
            nn.Conv2d(512, 1, 4, stride=(1, 1), padding=1),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map produced by the convolution stack for ``x``."""
        return self.main(x)
    
# Put both discriminators on the GPU.
model_d_a = Discriminator_A().cuda()
model_d_b = Discriminator_B().cuda()

# Binary cross-entropy on the discriminators' Sigmoid outputs.
criterion_d = nn.BCELoss().cuda()

# One Adam optimizer per discriminator, configured identically.
optimizer_d_a, optimizer_d_b = (
    torch.optim.Adam(net.parameters(), lr=LR, betas=(0.5, 0.999), weight_decay=1e-5)
    for net in (model_d_a, model_d_b)
)

# Step decay: multiply the learning rate by 0.3 every 50 scheduler steps.
scheduler_d_a, scheduler_d_b = (
    torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.3)
    for optimizer in (optimizer_d_a, optimizer_d_b)
)


In [None]:
# Train the CycleGAN model
#
# Naming used below:
#   model_g_a maps real_a -> fake_b, model_g_b maps real_b -> fake_a.
#   model_d_a scores domain-B images (real_b / fake_b),
#   model_d_b scores domain-A images (real_a / fake_a).
#
# NOTE: the networks, optimizers and schedulers are created once in earlier
# cells, so with num > 1 every run after the first continues from the
# previous run's weights instead of starting fresh.

CYCLE_LOSS_WEIGHT = 50   # multiplier on the L1 cycle-consistency term

num = 1   # number of training runs; run i writes into column i of the arrays below
loss_da_count = np.zeros((EPOCH, num))
loss_db_count = np.zeros((EPOCH, num))
loss_ga_count = np.zeros((EPOCH, num))
loss_gb_count = np.zeros((EPOCH, num))
time_count = np.zeros((1, num))

for i in range(num):

    time_start = time.time()
    real_label = 1
    fake_label = 0
    # Lists to keep track of progress
    img_list = []
    G_A_losses = []
    G_B_losses = []
    D_A_losses = []
    D_B_losses = []
    Cycle_A_losses = []
    Cycle_B_losses = []
    iters = 0

    # Created but not used: the push_and_pop calls further down are commented out.
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    for epoch in range(EPOCH):
        epoch_start_time = time.time()
        for step, (a, b) in enumerate(train_loader):

            real_a = a.float().cuda()
            real_b = b.float().cuda()

            ###################################
            #  Update Generator A network
            ##################################
            optimizer_g_a.zero_grad()

            # generate real A to fake B; D_A(G_A(A))
            fake_b = model_g_a(real_a)
            result_d_a = model_d_a(fake_b).view(-1)
            # Targets are built from the prediction itself, so they always have
            # the right length, a float dtype and the right device.
            loss_g_a = criterion_d(result_d_a, torch.full_like(result_d_a, real_label))
            # reconstruct fake B to rec A; G_B(G_A(A))
            recon_a = model_g_b(fake_b)
            cycleloss_a = criterion_g(recon_a, real_a) * CYCLE_LOSS_WEIGHT
            loss_ga = loss_g_a + cycleloss_a

            # Update G_A
            loss_ga.backward()
            optimizer_g_a.step()

            ###################################
            #  Update Generator B network
            ##################################
            # Clear G_B's gradients before its own backward pass: loss_ga.backward()
            # above also wrote into them (through recon_a), and without this call
            # they would keep accumulating over every iteration.
            optimizer_g_b.zero_grad()

            # generate real B to fake A; D_B(G_B(B))
            fake_a = model_g_b(real_b)
            result_d_b = model_d_b(fake_a).view(-1)
            loss_g_b = criterion_d(result_d_b, torch.full_like(result_d_b, real_label))
            # reconstruct fake A to rec B; G_A(G_B(B))
            recon_b = model_g_a(fake_a)
            cycleloss_b = criterion_g(recon_b, real_b) * CYCLE_LOSS_WEIGHT
            # loss_g = loss_g_a + loss_g_b + cycleloss_a + cycleloss_b
            loss_gb = loss_g_b + cycleloss_b

            # Update G_B
            loss_gb.backward()
            optimizer_g_b.step()

            ###################################
            #  Update Discriminator A network
            ##################################
            # zero_grad also discards what the generator backward passes left in D_A.
            optimizer_d_a.zero_grad()

            # train discriminator D_A
            real_d_a = model_d_a(real_b).view(-1)
            loss_d_a_real = criterion_d(real_d_a, torch.full_like(real_d_a, real_label))
            # fake_b = fake_A_buffer.push_and_pop(fake_b)
            # detach() keeps this loss from propagating into G_A.
            fake_d_a = model_d_a(fake_b.detach()).view(-1)
            loss_d_a_fake = criterion_d(fake_d_a, torch.full_like(fake_d_a, fake_label))
            loss_d_a = (loss_d_a_real + loss_d_a_fake) * 0.5

            # Update D_A
            loss_d_a.backward()
            optimizer_d_a.step()

            ###################################
            #  Update Discriminator B network
            ##################################
            optimizer_d_b.zero_grad()

            # train discriminator D_B
            # fake_a = fake_A_buffer.push_and_pop(fake_a)
            real_d_b = model_d_b(real_a).view(-1)
            loss_d_b_real = criterion_d(real_d_b, torch.full_like(real_d_b, real_label))
            fake_d_b = model_d_b(fake_a.detach()).view(-1)
            loss_d_b_fake = criterion_d(fake_d_b, torch.full_like(fake_d_b, fake_label))
            loss_d_b = (loss_d_b_real + loss_d_b_fake) * 0.5

            # Update D_B
            loss_d_b.backward()
            optimizer_d_b.step()

        # Learning-rate schedules advance once per epoch.
        scheduler_d_a.step()
        scheduler_g_a.step()
        scheduler_d_b.step()
        scheduler_g_b.step()

        # The recorded values are those of the LAST mini-batch of the epoch,
        # not an average over the epoch.
        loss_da_count[epoch, i] = loss_d_a.item()
        loss_db_count[epoch, i] = loss_d_b.item()
        loss_ga_count[epoch, i] = loss_g_a.item()
        loss_gb_count[epoch, i] = loss_g_b.item()

        print('Epoch: [%d/%d]\t Loss_DA: %.4f, Loss_DB: %.4f, Loss_GA: %.4f, Loss_GB: %.4f, Loss_CA: %.4f, Loss_CB: %.4f'
              % (epoch, EPOCH, loss_d_a.item(), loss_d_b.item(), loss_g_a.item(), loss_g_b.item(), cycleloss_a.item(), cycleloss_b.item()))

    time_end = time.time()
    time_count[0, i] = time_end - time_start
    print('time cost', time_count[0, i], 's')

    plt.figure(figsize=(6, 3))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(loss_ga_count[:, i], label="G_A")
    plt.plot(loss_gb_count[:, i], label="G_B")
    plt.plot(loss_da_count[:, i], label="D_A")
    plt.plot(loss_db_count[:, i], label="D_B")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.grid()
    plt.legend()
    plt.show()
    

In [None]:
# Write generator A's parameter dictionary (state_dict) to disk.
# Only model_g_a is saved here; the other three networks are not.

torch.save(obj=model_g_a.state_dict(), f='CycleGAN_Patch_GA_dict.pth')
