In [None]:
import os
import cv2
import sys
import matplotlib.pyplot as plt
import glob
import itertools
import random
import numpy as np
import torch
import time
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
from model import *

In [None]:
# ---- Training configuration ----
dataname = 'ukiless'          # dataset name; also used for the weights/<dataname>/ folder
roots = 'data/' + dataname    # dataset root, expected to contain train/A and train/B
batch_size = 3
lr = 2e-4                     # base Adam learning rate
num_epoch = 20                # total number of training epochs
decay_epochs = 9              # LR stays constant until this epoch, then decays linearly

## Helper: delete grayscale images from the horse2zebra dataset

In [None]:
def remove_images_from_folder(folder):
    """Delete grayscale (2-D) images from ``folder`` and return the image width.

    Every entry of ``folder`` is read with ``plt.imread``. Files whose decoded
    array is not 3-D (single-channel / grayscale) are removed from disk.

    Args:
        folder: directory that is expected to contain only image files.

    Returns:
        ``shape[1]`` (the width) of the last image read. The caller uses it as a
        square crop size, so this assumes all images share one size -- TODO confirm.

    Raises:
        ValueError: if ``folder`` contains no files.
    """
    img_size = None
    for filename in os.listdir(folder):
        path = os.path.join(folder, filename)
        img = plt.imread(path)
        if img.ndim != 3:  # no channel axis -> grayscale image, drop it
            print(path, img.shape)
            os.remove(path)
            print('delete W&B image', path)
        img_size = img.shape[1]
    if img_size is None:
        # Without this check an empty folder surfaced as UnboundLocalError.
        raise ValueError(f"no images found in {folder!r}")
    return img_size
# NOTE(review): this cleans horse2zebra/train/B and takes the crop size from it,
# while training reads from `roots`. Confirm both datasets share one image size,
# or point `direc` at the training data.
direc = 'data/horse2zebra/train/B'
img_shape = remove_images_from_folder(direc)

## Create dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Unpaired two-domain image dataset read from ``<root>/<mode>/A`` and ``<root>/<mode>/B``.

    The base class is named explicitly as ``torch.utils.data.Dataset`` because
    this class rebinds the name ``Dataset``. Re-running the cell would otherwise
    make the new class inherit from its own previous definition.

    Args:
        root: dataset root directory.
        transform: callable applied to each RGB PIL image. If None, the PIL
            image itself is returned.
        unaligned: if True, the B image is drawn at random rather than by index.
        mode: sub-folder to read, e.g. "train".

    Raises:
        ValueError: if the A or B folder contains no files.
    """

    def __init__(self, root, transform=None, unaligned=False, mode="train"):
        # "*.*" matches every file with an extension; sorting gives a stable order.
        self.f_A = sorted(glob.glob(os.path.join(root, f"{mode}/A") + "/*.*"))
        self.f_B = sorted(glob.glob(os.path.join(root, f"{mode}/B") + "/*.*"))
        if not self.f_A or not self.f_B:
            # Fail early: an empty domain would otherwise cause a modulo-by-zero in __getitem__.
            raise ValueError(
                f"no images found under {os.path.join(root, mode)!r} "
                f"(A: {len(self.f_A)}, B: {len(self.f_B)})")
        self.transform = transform
        self.unaligned = unaligned

    def _load(self, path):
        """Open ``path``, convert it to RGB, close the file, and apply ``self.transform``."""
        with Image.open(path) as raw:
            # Grayscale or RGBA files become 3-channel, matching the 3-channel Normalize.
            rgb = raw.convert("RGB")
        return self.transform(rgb) if self.transform is not None else rgb

    def __getitem__(self, index):
        A = self._load(self.f_A[index % len(self.f_A)])
        if self.unaligned:
            # randint is inclusive at both ends, hence the -1; any B image can be picked.
            B = self._load(self.f_B[random.randint(0, len(self.f_B) - 1)])
        else:
            B = self._load(self.f_B[index % len(self.f_B)])
        return {"A": A, "B": B}  # one sample from each domain

    def __len__(self):
        # The larger domain defines an epoch; the smaller one wraps around via modulo.
        return max(len(self.f_A), len(self.f_B))

In [None]:
# Training-time augmentation:
# - resize the shorter side to 1.12 * img_shape;
# - random-crop back to img_shape x img_shape;
# - random horizontal flip;
# - scale to [0, 1] (ToTensor), then to [-1, 1] per RGB channel (Normalize).
dataset = Dataset(
    root=roots,
    transform=transforms.Compose([
        transforms.Resize(int(img_shape * 1.12), transforms.InterpolationMode.BICUBIC),
        transforms.RandomCrop(img_shape),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
    unaligned=True,
)
# Pinned memory only speeds up host->GPU copies; requesting it without CUDA just emits a warning.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, pin_memory=torch.cuda.is_available())

In [None]:
# Draw one seed per run and print it, so any run can be reproduced by
# hard-coding the printed value here. Seed every RNG the notebook may touch.
seed = random.randint(1, 10000)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # also seeds all CUDA devices
print(f"seed: {seed}")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
print(device)

In [None]:
def weights_init(m):
    """Custom weight initialisation, applied as ``net.apply(weights_init)``.

    * Layers whose class name contains "Conv": weight ~ N(0, 0.02); bias is
      left as PyTorch initialised it.
    * Layers whose class name contains "BatchNorm": bias = 0, weight ~ N(1, 0.02).

    Modules whose name matches but that have no ``weight`` tensor are skipped
    instead of raising. Examples are a custom container such as ``ConvBlock``,
    or ``BatchNorm2d(affine=False)``, whose weight is None.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if not isinstance(weight, torch.Tensor):
        return
    if "Conv" in classname:
        torch.nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if getattr(m, "bias", None) is not None:
            torch.nn.init.zeros_(m.bias)
        torch.nn.init.normal_(weight, 1.0, 0.02)  # mean 1.0, std 0.02
    


In [None]:
# Two generators (A->B and B->A) and one discriminator per domain, built in
# the order G_A2B, G_B2A, D_A, D_B and moved to `device`.
netG_A2B, netG_B2A = Generator().to(device), Generator().to(device)
netD_A, netD_B = Discriminator().to(device), Discriminator().to(device)

# Run the custom initialisation over every sub-module of each network.
for _net in (netG_A2B, netG_B2A, netD_A, netD_B):
    _net.apply(weights_init)
print("init finish")

In [None]:
# L1 for the cycle-consistency and identity terms, MSE for the adversarial term.
cycle_loss = torch.nn.L1Loss().to(device)
iden_loss = torch.nn.L1Loss().to(device)
adv_loss = torch.nn.MSELoss().to(device)

# Both generators share one Adam optimiser; each discriminator gets its own.
adam_kwargs = dict(lr=lr, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(
    itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), **adam_kwargs)
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), **adam_kwargs)
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), **adam_kwargs)
print('Optimizer and loss')

In [None]:
class DecayLR:
    """LR multiplier for ``torch.optim.lr_scheduler.LambdaLR``: constant, then linear decay.

    The factor is 1.0 up to epoch ``decay_epochs``. It then decreases linearly
    and reaches 0.0 at epoch ``epochs``. It is clamped at 0.0 afterwards, so the
    learning rate never goes negative.

    Args:
        epochs: total number of training epochs.
        offset: added to the scheduler's epoch counter before computing the
            factor (presumably for resuming training -- confirm with callers).
        decay_epochs: epoch at which the linear decay begins; must be < ``epochs``.
    """

    def __init__(self, epochs, offset, decay_epochs):
        epoch_flag = epochs - decay_epochs
        assert (epoch_flag > 0), "decay must start before training ends (decay_epochs < epochs)"
        self.epochs = epochs
        self.offset = offset
        self.decay_epochs = decay_epochs

    def step(self, epoch):
        """Return the LR multiplier for scheduler epoch ``epoch`` (shifted by ``offset``)."""
        progress = max(0, epoch + self.offset - self.decay_epochs) / (
                self.epochs - self.decay_epochs)
        return max(0.0, 1.0 - progress)


In [None]:
class ReplayBuffer:
    """Pool of previously generated images, used when training the discriminators.

    The pool fills up to ``max_size`` images. Once it is full, each incoming
    image has a 50% chance of being swapped with a random stored image. The
    stored image (cloned) is returned in its place and the new one is kept.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), "must >0."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push each image of batch ``data`` and return a same-sized batch of images."""
        out = []
        for sample in data.data:  # iterate over the batch dimension
            sample = sample.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Still filling: keep the image and hand it straight back.
                self.data.append(sample)
                out.append(sample)
                continue
            if random.uniform(0, 1) <= 0.5:
                # Keep the pool unchanged and return the fresh image.
                out.append(sample)
                continue
            # Swap: return an old image and store the new one in its slot.
            slot = random.randint(0, self.max_size - 1)
            out.append(self.data[slot].clone())
            self.data[slot] = sample
        return torch.cat(out)

In [None]:
# All three optimisers follow the same schedule: constant LR, then linear decay
# from epoch `decay_epochs` down to 0 at `num_epoch`.
lr_lambda = DecayLR(num_epoch, 0, decay_epochs).step
lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B = (
    torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
    for opt in (optimizer_G, optimizer_D_A, optimizer_D_B)
)

In [None]:
# Per-iteration loss histories, filled by the training loop and plotted at the end.
g_losses, d_losses, identity_losses, gan_losses, cycle_losses = [], [], [], [], []
# Replay buffers of past generated images, one per domain.
fake_A_buffer, fake_B_buffer = ReplayBuffer(), ReplayBuffer()

In [None]:
# Checkpoint folder, created up front so torch.save cannot fail on a missing
# directory after the whole training run.
os.makedirs(f"weights/{dataname}", exist_ok=True)

LAMBDA_IDENTITY = 5.0  # weight of the identity (L1) loss
LAMBDA_CYCLE = 10.0    # weight of the cycle-consistency (L1) loss
SAVE_EVERY = 10        # write an epoch-tagged checkpoint every SAVE_EVERY epochs


def save_weights(suffix=""):
    """Save all four networks' state dicts as weights/<dataname>/<net><suffix>.pth."""
    for net_name, net in (("netG_A2B", netG_A2B), ("netG_B2A", netG_B2A),
                          ("netD_A", netD_A), ("netD_B", netD_B)):
        torch.save(net.state_dict(), f"weights/{dataname}/{net_name}{suffix}.pth")


def discriminator_step(net_d, optimizer_d, real_images, fake_images, fake_buffer):
    """Run one update of ``net_d`` and return its loss (mean of the real and fake terms).

    Real images are regressed towards 1. Fake images, mixed with older fakes
    from ``fake_buffer``, are regressed towards 0. ``zero_grad`` runs first
    because the generator backward pass also left gradients on ``net_d``.
    """
    optimizer_d.zero_grad()
    real_output = net_d(real_images)
    err_real = adv_loss(real_output, torch.ones_like(real_output))
    fake_images = fake_buffer.push_and_pop(fake_images)
    fake_output = net_d(fake_images.detach())  # no gradient back into the generators
    err_fake = adv_loss(fake_output, torch.zeros_like(fake_output))
    err = (err_real + err_fake) / 2
    err.backward()
    optimizer_d.step()
    return err


for epoch in range(num_epoch):
    t1 = time.time()
    progress = tqdm(enumerate(dataloader), total=len(dataloader))
    for idx, imgs in progress:
        real_A = imgs["A"].to(device)
        real_B = imgs["B"].to(device)

        # ---- Generators G_A2B and G_B2A ----
        optimizer_G.zero_grad()

        # Identity loss: G_B2A(A) should reproduce A, and G_A2B(B) should reproduce B.
        identity_image_A = netG_B2A(real_A)
        loss_identity_A = iden_loss(identity_image_A, real_A) * LAMBDA_IDENTITY
        identity_image_B = netG_A2B(real_B)
        loss_identity_B = iden_loss(identity_image_B, real_B) * LAMBDA_IDENTITY

        # Adversarial loss: the generators want the discriminators to output "real" (1).
        # Labels are built with ones_like so they always match the discriminator output shape.
        fake_image_A = netG_B2A(real_B)
        fake_output_A = netD_A(fake_image_A)
        loss_GAN_B2A = adv_loss(fake_output_A, torch.ones_like(fake_output_A))
        fake_image_B = netG_A2B(real_A)
        fake_output_B = netD_B(fake_image_B)
        loss_GAN_A2B = adv_loss(fake_output_B, torch.ones_like(fake_output_B))

        # Cycle loss: A -> B -> A and B -> A -> B should reconstruct the input.
        recovered_image_A = netG_B2A(fake_image_B)
        loss_cycle_ABA = cycle_loss(recovered_image_A, real_A) * LAMBDA_CYCLE
        recovered_image_B = netG_A2B(fake_image_A)
        loss_cycle_BAB = cycle_loss(recovered_image_B, real_B) * LAMBDA_CYCLE

        errG = (loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A
                + loss_cycle_ABA + loss_cycle_BAB)
        errG.backward()
        optimizer_G.step()

        # ---- Discriminators D_A and D_B ----
        errD_A = discriminator_step(netD_A, optimizer_D_A, real_A, fake_image_A, fake_A_buffer)
        errD_B = discriminator_step(netD_B, optimizer_D_B, real_B, fake_image_B, fake_B_buffer)

        # ---- Bookkeeping ----
        loss_D = (errD_A + errD_B).item()
        loss_G = errG.item()
        loss_G_identity = (loss_identity_A + loss_identity_B).item()
        loss_G_GAN = (loss_GAN_A2B + loss_GAN_B2A).item()
        loss_G_cycle = (loss_cycle_ABA + loss_cycle_BAB).item()
        d_losses.append(loss_D)
        g_losses.append(loss_G)
        identity_losses.append(loss_G_identity)
        gan_losses.append(loss_G_GAN)
        cycle_losses.append(loss_G_cycle)
        progress.set_description(
            f"[{epoch}/{num_epoch - 1}][{idx}/{len(dataloader) - 1}] "
            f"Loss_D: {loss_D:.4f} "
            f"Loss_G: {loss_G:.4f} "
            f"Loss_G_identity: {loss_G_identity:.4f} "
            f"loss_G_GAN: {loss_G_GAN:.4f} "
            f"loss_G_cycle: {loss_G_cycle:.4f} "
            f"time: {time.time() - t1:.1f}s")

    # Periodic checkpoint (the former `epoch % 10 == 10` test could never be true).
    if (epoch + 1) % SAVE_EVERY == 0:
        save_weights(f"_epoch_{epoch}")

    # Advance the per-epoch learning-rate schedules.
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

# Persist the loss histories and the final weights.
np.save('g_losses.npy', g_losses)
np.save('d_losses.npy', d_losses)
np.save('identity_losses.npy', identity_losses)
np.save('gan_losses.npy', gan_losses)
np.save('cycle_losses.npy', cycle_losses)
save_weights()


In [None]:
# Export a copy of the final weights under weights/uki/. A dedicated name keeps
# the `dataset` object built in the data cell from being overwritten.
export_name = 'uki'
os.makedirs(f"weights/{export_name}", exist_ok=True)
torch.save(netG_A2B.state_dict(), f"weights/{export_name}/netG_A2B.pth")
torch.save(netG_B2A.state_dict(), f"weights/{export_name}/netG_B2A.pth")
torch.save(netD_A.state_dict(), f"weights/{export_name}/netD_A.pth")
torch.save(netD_B.state_dict(), f"weights/{export_name}/netD_B.pth")

fig, ax = plt.subplots()
ax.plot(g_losses)
ax.set(title="Total generator loss", xlabel="iteration", ylabel="loss")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(d_losses)
ax.set(title="Discriminator loss (D_A + D_B)", xlabel="iteration", ylabel="loss")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(identity_losses)
ax.set(title="Identity loss (A + B, weighted)", xlabel="iteration", ylabel="loss")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(gan_losses)
ax.set(title="Generator adversarial loss (A2B + B2A)", xlabel="iteration", ylabel="loss")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(cycle_losses)
ax.set(title="Cycle-consistency loss (ABA + BAB, weighted)", xlabel="iteration", ylabel="loss")
plt.show()