In [1]:

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt
%matplotlib inline
from torch import optim
import os
import numpy as np
import time
import itertools

from PIL import Image
from copy import deepcopy
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [6]:
# MNIST sanity-check data. Leftover from an earlier GAN experiment; the CycleGAN cells below
# build their own dataset and do not use `train_ds`.
data_path = './GANdata'
os.makedirs(data_path, exist_ok=True)
to_unit_range = transforms.Normalize([0.5], [0.5])  # [0, 1] tensors -> [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), to_unit_range])
train_ds = datasets.MNIST(data_path, train=True, transform=transform, download=True)

In [7]:
# NOTE: `train_dataloader` is rebound later by the CycleGAN dataloader cell.
train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=32)

In [8]:
# Sanity check: print the shapes of a single (images, labels) batch.
for x, y in itertools.islice(train_dataloader, 1):
    print(x.shape, y.shape)

torch.Size([32, 1, 28, 28]) torch.Size([32])


In [2]:
# Some blocks

# Some blocks

class ResBlock2d(nn.Module):
    """Residual block computing ``x + F(x)``, where ``F`` is conv -> norm -> ReLU -> [dropout] -> conv.

    A normalization + ReLU separates the two convolutions. Without a nonlinearity between
    them, two stacked convolutions compose into a single linear map, and the residual
    branch loses most of its capacity.

    Args:
        channels: input and output channel count (they must match for the residual sum).
        kernel_size, stride, padding: passed to both convolutions. The spatial size must be
            preserved (e.g. kernel 3, stride 1, padding 1) or the residual sum will fail.
        use_dropout: dropout probability applied between the convolutions; 0 disables it.
        use_bias: whether the convolutions carry a bias term.
        norm_type: normalization layer class applied after the first convolution.
    """
    def __init__(self, channels, kernel_size, stride, use_dropout, use_bias, padding, norm_type=nn.InstanceNorm2d):
        super(ResBlock2d, self).__init__()
        layers = [
            nn.Conv2d(channels, channels, kernel_size, stride, bias=use_bias, padding=padding),
            norm_type(channels),
            nn.ReLU(True),
        ]
        if use_dropout > 0:
            layers.append(nn.Dropout(use_dropout))
        layers.append(nn.Conv2d(channels, channels, kernel_size, stride, bias=use_bias, padding=padding))

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)


class ResNorm2d(nn.Module):
    """A ``ResBlock2d`` followed by a normalization layer and an activation.

    ``activation_type`` is called as ``activation_type(activation_value)`` for ``nn.ReLU``
    (so the default ``True`` means ``inplace=True``). Any other activation is called as
    ``activation_type(activation_value, inplace=True)``, e.g. LeakyReLU with a negative slope.
    """
    def __init__(self, channels, kernel_size, stride, use_dropout=0, padding=0, use_bias=False, norm_type=nn.InstanceNorm2d, activation_type=nn.ReLU, activation_value=True):
        super().__init__()
        activation_kwargs = {} if activation_type == nn.ReLU else {'inplace': True}

        self.res = ResBlock2d(channels, kernel_size, stride, use_dropout, use_bias, padding)
        self.norm = norm_type(channels)
        self.activation = activation_type(activation_value, **activation_kwargs)
        self.model = nn.Sequential(self.res, self.norm, self.activation)

    def forward(self, x):
        return self.model(x)
        

class ConvNorm2d(nn.Module):
    """Conv2d -> normalization -> activation, exposed both as attributes and as ``self.model``.

    ``use_dropout`` is accepted for signature compatibility but not used by this layer.
    ``activation_type`` is called as ``activation_type(activation_value)`` for ``nn.ReLU``;
    any other activation gets ``inplace=True`` as well.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0, use_dropout=0, use_bias=False, norm_type=nn.InstanceNorm2d, activation_type=nn.ReLU, activation_value=True):
        super().__init__()
        activation_kwargs = {} if activation_type == nn.ReLU else {'inplace': True}

        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=use_bias, padding=padding)
        self.norm = norm_type(out_channels)
        self.activation = activation_type(activation_value, **activation_kwargs)
        self.model = nn.Sequential(self.conv2d, self.norm, self.activation)

    def forward(self, x):
        return self.model(x)
    

class UpConvNorm2d(nn.Module):
    """ConvTranspose2d -> normalization -> activation, exposed both as attributes and as ``self.model``.

    ``use_dropout`` is accepted for signature compatibility but not used by this layer.
    ``activation_type`` is called as ``activation_type(activation_value)`` for ``nn.ReLU``;
    any other activation gets ``inplace=True`` as well.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0, output_padding=0, use_dropout=0, use_bias=False, norm_type=nn.InstanceNorm2d, activation_type=nn.ReLU, activation_value=True):
        super().__init__()
        activation_kwargs = {} if activation_type == nn.ReLU else {'inplace': True}

        self.upconv2d = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride,
            bias=use_bias, padding=padding, output_padding=output_padding,
        )
        self.norm = norm_type(out_channels)
        self.activation = activation_type(activation_value, **activation_kwargs)
        self.model = nn.Sequential(self.upconv2d, self.norm, self.activation)

    def forward(self, x):
        return self.model(x)
        

In [3]:
# Generator

class Generator(nn.Module):
    """ResNet-style generator: 7x7 stem -> 2 stride-2 downsamplings -> residual blocks -> 2 upsamplings -> 7x7 head -> tanh.

    Args:
        ngf: channel width after the stem; doubled by each downsampling step.
        nc: number of output channels.
        img_size: ``(C, H, W)``. ``C`` is the input channel count, and only ``W`` of 128
            (6 residual blocks) or 256 (9 residual blocks) is accepted.
        norm_type: normalization class; convolutions get a bias only for ``nn.InstanceNorm2d``.
        use_dropout: dropout probability forwarded to every sub-block.

    Raises:
        AttributeError: if ``img_size[-1]`` is neither 128 nor 256.
    """
    def __init__(self, ngf, nc, img_size=(3, 128, 128), norm_type=nn.InstanceNorm2d, use_dropout=0):
        super().__init__()
        self.img_size = img_size  # 1 x 28 x 28 for MNIST
        self.ngf = ngf
        self.nc = nc
        self.use_bias = (norm_type == nn.InstanceNorm2d)
        self.norm_type = norm_type

        width = img_size[-1]
        if width not in (128, 256):
            raise AttributeError
        n_res_blocks = 6 if width == 128 else 9
        self.get_resnet(in_channels=self.img_size[0], down=2, mid=n_res_blocks, use_dropout=use_dropout)

    def get_resnet(self, in_channels, down, mid, use_dropout):
        """Build ``self.model`` with ``down`` down/up-sampling stages and ``mid`` residual blocks."""
        shared = dict(use_dropout=use_dropout, use_bias=self.use_bias, norm_type=self.norm_type)
        channels = self.ngf

        layers = [
            nn.ReflectionPad2d(3),  # paper
            ConvNorm2d(in_channels, channels, 7, 1, padding=0, **shared),
        ]
        # downsampling: halve the resolution, double the channels
        for _ in range(down):
            layers.append(ConvNorm2d(channels, channels * 2, 3, 2, padding=1, **shared))
            channels *= 2

        # residual blocks at the bottleneck resolution
        # TODO: reflection padding inside the residual blocks (they currently zero-pad)
        layers.extend(ResNorm2d(channels, 3, 1, padding=1, **shared) for _ in range(mid))

        # upsampling: double the resolution, halve the channels
        for _ in range(down):
            layers.append(UpConvNorm2d(channels, channels // 2, 3, 2, padding=1, output_padding=1, **shared))
            channels //= 2

        layers += [
            nn.ReflectionPad2d(3),  # paper
            nn.Conv2d(channels, self.nc, 7, 1, bias=self.use_bias, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [4]:
# Smoke test: a 1x3x128x128 input should come back as a 1x3x128x128 image.
# (The per-layer shapes in the saved output came from debug prints that are now commented out.)
with torch.no_grad():
    test_generator = Generator(64, 3, (3, 128, 128)).to(device)
    test_noise = torch.randn(1, 3, 128, 128).to(device)
    output = test_generator(test_noise)
print(output.shape)

torch.Size([1, 3, 134, 134])
torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 128, 64, 64])
torch.Size([1, 3, 128, 128])


In [4]:
class Discriminator(nn.Module):
    """Patch-based discriminator: a stack of 4x4 convolutions producing a grid of real/fake scores.

    For a 3x128x128 input the output is 1x14x14; each score covers one overlapping input
    patch rather than the whole image.

    Args:
        ngf: channel width of the first layer (doubled by each stride-2 layer).
        nc: accepted for interface symmetry with ``Generator``; not used here.
        img_size: ``(C, H, W)``; only ``C`` (input channels) is used.
        norm_type: normalization class for every layer except the first.
        use_dropout: accepted for interface compatibility; not used.
    """
    def __init__(self, ngf, nc, img_size=(3, 128, 128), norm_type=nn.InstanceNorm2d, use_dropout=0):
        super(Discriminator, self).__init__()
        self.img_size = img_size
        self.ngf = ngf
        self.nc = nc
        self.use_bias = (norm_type == nn.InstanceNorm2d)
        self.norm_type = norm_type

        self.get_model(img_size[0], 4, use_dropout)

    def get_model(self, in_channels, n, use_dropout):
        """Build ``self.model``: ``n - 1`` stride-2 layers, one stride-1 layer, then a 1-channel score conv."""
        layers = []
        out_channels = self.ngf
        for layer_idx in range(n - 1):
            is_first = layer_idx == 0
            layers.append(ConvNorm2d(
                in_channels, out_channels, 4, 2, padding=1,
                # The CycleGAN paper does not normalize the first C64 layer. Without a norm
                # after it, that conv needs its own bias.
                use_bias=True if is_first else self.use_bias,
                norm_type=nn.Identity if is_first else self.norm_type,
                activation_type=nn.LeakyReLU, activation_value=0.2,
            ))
            in_channels = out_channels
            out_channels *= 2

        layers.append(ConvNorm2d(in_channels, in_channels, 4, 1, use_bias=self.use_bias, padding=1, norm_type=self.norm_type, activation_type=nn.LeakyReLU, activation_value=0.2))

        layers.append(nn.Conv2d(in_channels, 1, 4, 1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [31]:
# Smoke test: score a random 1x3x128x128 input; the saved output shows a 1x1x14x14 patch grid.
# (The intermediate shapes in the saved output came from debug prints that are now commented out.)
with torch.no_grad():
    test_discriminator = Discriminator(64, 3, (3, 128, 128)).to(device)
    test_noise = torch.randn(1, 3, 128, 128).to(device)
    output = test_discriminator(test_noise)
print(output.shape)

torch.Size([1, 3, 128, 128])
torch.Size([1, 64, 64, 64])
torch.Size([1, 128, 32, 32])
torch.Size([1, 256, 16, 16])
torch.Size([1, 1, 14, 14])
torch.Size([1, 1, 14, 14])


In [5]:
class CycleGANCustomDataset(torch.utils.data.Dataset):
    """Unpaired two-domain image dataset read from ``<dataset_path>/<mode>A/`` and ``<dataset_path>/<mode>B/``.

    Each item is ``{'A': tensor, 'B': tensor}``. Images are converted to RGB, resized with
    bicubic interpolation so the shorter side equals ``res``, and normalized to [-1, 1].

    Args:
        dataset_path: root directory containing the ``<mode>A`` / ``<mode>B`` folders.
        mode: split prefix, e.g. ``'train'`` or ``'test'``.
        res: target length of the shorter image side.
        serial_batches: if True, pair A[i] with B[i % len_B]; otherwise draw B uniformly at random.
    """
    def __init__(self, dataset_path, mode='train', res=128, serial_batches=False):
        super(CycleGANCustomDataset, self).__init__()
        self.path_A = dataset_path + "/" + mode + "A/"
        self.path_B = dataset_path + "/" + mode + "B/"

        # NOTE(review): every directory entry is assumed to be an image file.
        self.img_list_A = sorted(os.listdir(self.path_A))
        self.img_list_B = sorted(os.listdir(self.path_B))

        self.len_A = len(self.img_list_A)
        self.len_B = len(self.img_list_B)
        self.res = res
        self.serial_batches = serial_batches
        # Resize the PIL image (antialiased) before converting it to a tensor.
        self.transform = transforms.Compose([
            transforms.Resize(res, transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return max(self.len_A, self.len_B)

    def _sample_index_b(self, index):
        """Pick the domain-B index paired with item ``index``."""
        if self.serial_batches:
            return index % self.len_B
        # torch.randint's upper bound is exclusive, so ``len_B`` makes every B image reachable.
        return int(torch.randint(0, self.len_B, (1,)).item())

    def _load(self, folder, filename):
        """Open an image, convert it to RGB, and apply the transform, closing the file afterwards."""
        with Image.open(os.path.join(folder, filename)) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        img_A = self._load(self.path_A, self.img_list_A[index % self.len_A])
        img_B = self._load(self.path_B, self.img_list_B[self._sample_index_b(index)])
        return {'A': img_A, 'B': img_B}

In [6]:
class CycleGAN():
    """Trainer for unpaired image-to-image translation (CycleGAN).

    Networks:
        G_A: translates domain A -> domain B; its outputs are scored by D_B.
        G_B: translates domain B -> domain A; its outputs are scored by D_A.

    Args:
        generator: generator class, instantiated as ``generator(ngf, nc, img_size, use_dropout=0.5)``.
        discriminator: discriminator class, instantiated as ``discriminator(ndf, nc, img_size)``.
        ngf: base channel width passed to the generators.
        ndf: base channel width passed to the discriminators.
        nc: number of image channels passed to every network.
        img_size: ``(C, H, W)`` tuple forwarded to every network.
        device: device that the networks, and the batches during training, are moved to.
        lambda_cyc: weight of the cycle-consistency loss.
        lambda_id: identity-loss weight relative to ``lambda_cyc``. The effective weight is
            ``lambda_cyc * lambda_id``, matching the paper's 0.5 * lambda. 0 disables the identity loss.
        gan_loss_type: loss class for the adversarial terms (``nn.MSELoss`` gives a least-squares GAN).
    """

    def __init__(self, generator, discriminator, ngf, ndf, nc, img_size, device, lambda_cyc=10, lambda_id=0.5, gan_loss_type=nn.MSELoss):
        self.ngf = ngf
        self.ndf = ndf
        self.nc = nc
        self.img_size = img_size
        self.device = device
        self.lambda_cyc = lambda_cyc
        self.lambda_id = lambda_id
        self.G_A = generator(ngf, nc, img_size, use_dropout=0.5).to(device)
        self.G_B = generator(ngf, nc, img_size, use_dropout=0.5).to(device)
        self.D_A = discriminator(ndf, nc, img_size).to(device)
        self.D_B = discriminator(ndf, nc, img_size).to(device)
        self.gan_loss_type = gan_loss_type

        self.initialize_network(init_type='normal', init_gain=0.02)

    def loss_GAN(self, prediction, target):
        """Adversarial loss between a discriminator ``prediction`` and a target of 1s (real) or 0s (fake)."""
        return self.gan_loss_type()(prediction, target)

    def loss_cycle(self, f_g_x, x, g_f_y, y):
        """L1 cycle-consistency loss: reconstructions ``f_g_x`` ~ ``x`` and ``g_f_y`` ~ ``y``."""
        return nn.L1Loss()(f_g_x, x) + nn.L1Loss()(g_f_y, y)

    def loss_id(self, g_y, y, f_x, x):
        """L1 identity loss: a generator fed an image already in its output domain should return it unchanged.

        ``g_y`` is G_A applied to a domain-B image ``y``; ``f_x`` is G_B applied to a domain-A image ``x``.
        """
        return nn.L1Loss()(g_y, y) + nn.L1Loss()(f_x, x)

    def set_optimizer(self, lr, beta1):
        """Create one Adam optimizer shared by both generators and one shared by both discriminators."""
        self.opt_G = optim.Adam(itertools.chain(self.G_A.parameters(), self.G_B.parameters()), lr=lr, betas=(beta1, 0.999))
        self.opt_D = optim.Adam(itertools.chain(self.D_A.parameters(), self.D_B.parameters()), lr=lr, betas=(beta1, 0.999))

    def get_scheduler(self, optimizer, num_epochs, lr_args):
        """Build an LR scheduler for ``optimizer`` according to ``lr_args['lr_policy']``.

        Policies:
            'linear': constant for ``lr_args['n_epochs']`` epochs (falls back to ``num_epochs``),
                then linear decay over ``lr_args['n_epochs_decay']`` epochs. ``lr_args['epoch_count']``
                is the 1-based starting epoch.
            'step': multiply by 0.1 every ``lr_args['lr_decay_iters']`` epochs.
            'plateau': ReduceLROnPlateau on a per-epoch loss (stepped with a metric in ``train``).
            'cosine': cosine annealing over ``num_epochs``.

        Raises:
            NotImplementedError: for an unknown policy.
        """
        lr_policy = lr_args['lr_policy']
        if lr_policy == 'linear':
            n_constant_epochs = lr_args.get('n_epochs', num_epochs)

            def lambda_rule(epoch):
                return 1.0 - max(0, epoch + lr_args['epoch_count'] - n_constant_epochs) / float(lr_args['n_epochs_decay'] + 1)
            return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
        if lr_policy == 'step':
            return optim.lr_scheduler.StepLR(optimizer, step_size=lr_args['lr_decay_iters'], gamma=0.1)
        if lr_policy == 'plateau':
            return optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
        if lr_policy == 'cosine':
            return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0)
        raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy)

    def initialize_network(self, init_type='normal', init_gain=0.02):
        """Re-initialize the weights of all four networks in place.

        Conv*/Linear weights use ``init_type`` ('normal', 'xavier', 'kaiming', 'orthogonal')
        with ``init_gain``, and their biases are zeroed. BatchNorm weights are drawn from
        N(1, init_gain) and their biases are zeroed.

        Raises:
            NotImplementedError: for an unknown ``init_type`` (when a Conv/Linear layer is visited).
        """
        def initialize_model(module):
            classname = module.__class__.__name__
            weight = getattr(module, 'weight', None)
            if weight is None:
                # Wrapper modules (e.g. ConvNorm2d), parameter-free layers, norms with affine=False.
                return
            bias = getattr(module, 'bias', None)
            if 'Conv' in classname or 'Linear' in classname:
                if init_type == 'normal':
                    nn.init.normal_(weight.data, 0.0, init_gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(weight.data, gain=init_gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(weight.data, gain=init_gain)
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if bias is not None:
                    nn.init.constant_(bias.data, 0)
            elif 'BatchNorm' in classname:
                # BatchNorm weights are per-channel scales: centre them on 1.
                nn.init.normal_(weight.data, 1.0, init_gain)
                if bias is not None:
                    nn.init.constant_(bias.data, 0)

        for net in (self.D_A, self.D_B, self.G_A, self.G_B):
            net.apply(initialize_model)

    @staticmethod
    def _set_requires_grad(nets, flag):
        """Enable or disable gradients for every parameter of ``nets``."""
        for net in nets:
            for param in net.parameters():
                param.requires_grad = flag

    @staticmethod
    def _step_scheduler(scheduler, metric):
        """Advance ``scheduler`` one epoch; ReduceLROnPlateau additionally needs the monitored metric."""
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(metric)
        else:
            scheduler.step()

    def _discriminator_loss(self, discriminator, real, fake):
        """Discriminator loss: ``real`` images -> 1, ``fake`` images -> 0, averaged over the two terms."""
        pred_real = discriminator(real)
        pred_fake = discriminator(fake)
        return (self.loss_GAN(pred_real, torch.ones_like(pred_real))
                + self.loss_GAN(pred_fake, torch.zeros_like(pred_fake))) / 2

    def train(self, dataloader, num_epochs, lr_args):
        """Train both generators and both discriminators for ``num_epochs`` epochs.

        Args:
            dataloader: iterable of dicts holding domain-A and domain-B batches under 'A' and 'B'.
            num_epochs: number of passes over ``dataloader``.
            lr_args: dict with 'lr' and 'lr_policy', plus the policy-specific keys read by
                ``get_scheduler``.

        Returns:
            dict with lists 'loss_G' and 'loss_D', sampled every 100 iterations.
        """
        log_dict = {'loss_G': [], 'loss_D': []}
        log_every = 100
        cnt = 0  # global iteration counter, used to throttle logging

        self.set_optimizer(lr_args['lr'], 0.5)
        scheduler_D = self.get_scheduler(self.opt_D, num_epochs, lr_args)
        scheduler_G = self.get_scheduler(self.opt_G, num_epochs, lr_args)
        discriminators = (self.D_A, self.D_B)

        for epoch in range(num_epochs):
            epoch_loss_G = 0.0
            epoch_loss_D = 0.0
            n_batches = 0
            for data in dataloader:
                # TODO: add a collate_fn so the dataloader yields A and B separately.
                x_batch_A = data['A'].to(self.device)
                x_batch_B = data['B'].to(self.device)

                # ---- Generator update (discriminators frozen) ----
                self._set_requires_grad(discriminators, False)
                self.opt_G.zero_grad()

                G_A_output = self.G_A(x_batch_A)        # fake B
                B_G_A_output = self.G_B(G_A_output)     # reconstructed A
                G_B_output = self.G_B(x_batch_B)        # fake A
                A_G_B_output = self.G_A(G_B_output)     # reconstructed B
                D_B_G_A_output = self.D_B(G_A_output)   # D_B's score for fake B
                D_A_G_B_output = self.D_A(G_B_output)   # D_A's score for fake A

                # Generators are rewarded when the discriminators call their fakes "real" (1).
                g1 = self.loss_GAN(D_B_G_A_output, torch.ones_like(D_B_G_A_output))
                g2 = self.loss_GAN(D_A_G_B_output, torch.ones_like(D_A_G_B_output))
                loss_GAN = g1 + g2
                loss_cycle = self.lambda_cyc * self.loss_cycle(B_G_A_output, x_batch_A, A_G_B_output, x_batch_B)
                if self.lambda_id > 0:
                    # Identity: each generator should leave images of its *output* domain unchanged.
                    idt_B = self.G_A(x_batch_B)
                    idt_A = self.G_B(x_batch_A)
                    loss_id = self.lambda_cyc * self.lambda_id * self.loss_id(idt_B, x_batch_B, idt_A, x_batch_A)
                else:
                    loss_id = torch.zeros((), device=self.device)
                loss_G = loss_GAN + loss_cycle + loss_id

                loss_G.backward()
                self.opt_G.step()

                # ---- Discriminator update (fakes detached so no gradient reaches the generators) ----
                self._set_requires_grad(discriminators, True)
                self.opt_D.zero_grad()
                loss_GAN_D_A = self._discriminator_loss(self.D_A, x_batch_A, G_B_output.detach())
                loss_GAN_D_A.backward()
                loss_GAN_D_B = self._discriminator_loss(self.D_B, x_batch_B, G_A_output.detach())
                loss_GAN_D_B.backward()
                self.opt_D.step()

                # Accumulate detached tensors to avoid a device sync on every iteration.
                epoch_loss_G = epoch_loss_G + loss_G.detach()
                epoch_loss_D = epoch_loss_D + loss_GAN_D_A.detach() + loss_GAN_D_B.detach()
                n_batches += 1

                if cnt % log_every == 0:
                    loss_D_value = loss_GAN_D_A.item() + loss_GAN_D_B.item()
                    log_dict['loss_G'].append(loss_G.item())
                    log_dict['loss_D'].append(loss_D_value)
                    print(f'epoch {epoch}, loss_GAN {g1.item()} {g2.item()}, loss_cycle {loss_cycle.item()}, loss_id {loss_id.item()}, loss_D {loss_D_value}')
                cnt += 1

            denom = max(n_batches, 1)
            self._step_scheduler(scheduler_D, float(epoch_loss_D) / denom)
            self._step_scheduler(scheduler_G, float(epoch_loss_G) / denom)

        return log_dict

        
                    

                    
                    

In [7]:
# dataloader

# Unpaired maps dataset (domains A and B), resized to 256 px, one image per batch.
train_dataset = CycleGANCustomDataset(dataset_path="../data/maps", mode='train', res=256, serial_batches=False)
print(train_dataset.len_A, train_dataset.len_B)

train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=1, num_workers=2)

1096 1096


In [None]:
model = CycleGAN(Generator, Discriminator, 64, 64, 3, (3, 256, 256), lambda_cyc=10, lambda_id=0.5, gan_loss_type=nn.MSELoss, device=device)
# Intended schedule: constant LR for 100 epochs, then decay over 100 epochs.
lr_args = {
    'lr': 0.0002,
    'lr_policy': 'linear',
    'lr_decay_iters': 50,   # only read by the 'step' policy
    'epoch_count': 1,
    'n_epochs': 100,        # epochs at the initial LR
    'n_epochs_decay': 100,  # epochs over which the LR decays
}

# Total epochs = constant phase + decay phase (200). Running more would train past the end of the schedule.
num_epochs = lr_args['n_epochs'] + lr_args['n_epochs_decay']
model.train(train_dataloader, num_epochs, lr_args)

In [13]:
# Save both translation directions. G_A maps A -> B and D_A scores domain-A images;
# G_B maps B -> A and D_B scores domain-B images.
save_dir = "CycleGAN/models"
os.makedirs(save_dir, exist_ok=True)
for suffix, gen_net, disc_net in (("A", model.G_A, model.D_A), ("B", model.G_B, model.D_B)):
    torch.save(gen_net.state_dict(), os.path.join(save_dir, f"cyclegan_generator_{suffix}.pt"))
    torch.save(disc_net.state_dict(), os.path.join(save_dir, f"cyclegan_discriminator_{suffix}.pt"))
                    