# Домашний проект на тему CycleGAN [DL School - spring 2020]

Источник вдохновения: https://github.com/hanyoseob/pytorch-CycleGAN

## Задачи:
1. Построить свою архитектуру CycleGAN
2. Выбрать собственную задачу для этого инструмента (для интереса)
3. Выбрать задачу для этого инструмента из решенных
4. Решить задачу, реализация которой уже есть, используя свою архитектуру и свой pipeline
5. Решить свою задачу по аналогии с решенной задачей на наработанных инструментах

# Часть 1. Реализация решенной задачи

В качестве задачи-исходника был выбран проект перерисовки яблок в апельсины. За основу реализации взят проект пользователя **hanyoseob** на GitHub. Проект масштабный, поэтому я постарался его максимально упростить и перенести в Jupyter Notebook.

В ходе работы:
1. Был скачан датасет apple2orange,
2. Созданы 2 сети со стандартными архитектурами из [статьи-первоисточника](https://arxiv.org/pdf/1703.10593.pdf): генератор и дискриминатор
3. Приведен в порядок код, который нужен для правильной работы функций тренировки и тестирования сети
4. Приведены в порядок функции **train** и **test** для правильной работы в Notebook
5. Проведено обучение сети на 90 эпохах
6. Проведен тест сети на тестовом наборе данных

## 1. Download and unzip dataset 

In [None]:
import os, sys
import requests
import zipfile

# Choose dataset to download (if dataset not already downloaded).
# Menu choice -> dataset archive name on the CycleGAN download server.
_DATASETS = {
    '1': 'apple2orange',
    '2': 'cezanne2photo',
    '3': 'horse2zebra',
    '4': 'monet2photo',
    '5': 'summer2winter_yosemite',
    '6': 'ukiyoe2photo',
    '7': 'vangogh2photo',
}

ds_id = input(
    'Choose dataset to download:\n'
    '(1) apple2orange\n(2) cezanne2photo\n(3) horse2zebra\n'
    '(4) monet2photo\n(5) summer2winter_yosemite\n'
    '(6) ukiyoe2photo\n(7) vangogh2photo\n')

ds_name = _DATASETS.get(ds_id.strip())
if ds_name is None:
    sys.exit('Incorrect dataset')

os.makedirs('datasets', exist_ok=True)
zip_path = 'datasets/%s.zip' % ds_name

if ds_name in os.listdir('datasets'):
    print('Dataset %s already downloaded' % ds_name)
elif not os.path.exists(zip_path):
    # Download only when neither the extracted folder nor the archive is present.
    url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/' + ds_name + '.zip'
    response = requests.get(url, allow_redirects=True, timeout=60)
    # Fail loudly instead of saving an HTML error page under a .zip name.
    response.raise_for_status()
    with open(zip_path, 'wb') as zip_file:
        zip_file.write(response.content)

if os.path.exists(zip_path):
    unzip_q = input('Want to unzip zip-file? (y/n)\n')
    if unzip_q.lower() == 'y':
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall('datasets')
        print('Dataset %s unzip successfully' % ds_name)

    remove_q = input('Want to remove zip-file? (y/n)\n')
    if remove_q.lower() == 'y':
        os.remove(zip_path)
        print('Dataset %s removed successfully' % ds_name)

print('Files in datasets-dir:', ' | '.join(os.listdir('datasets')))

Для правильной работы необходимо, чтобы датасеты двух классов были одинакового размера: класс `Dataset` сопоставляет картинки A и B по индексу. Удалим лишние картинки из большего класса.

In [None]:
def _balance_split(dir_a, dir_b):
    """Delete surplus .jpg files from the larger of two directories.

    Dataset pairs the i-th A image with the i-th B image, so both folders must
    hold the same number of images. Only `.jpg` files are counted (the same
    filter Dataset uses). Files are taken in sorted order, so the same surplus
    files are removed on every run.
    """
    files_a = sorted(f for f in os.listdir(dir_a) if f.endswith('.jpg'))
    files_b = sorted(f for f in os.listdir(dir_b) if f.endswith('.jpg'))
    surplus = len(files_a) - len(files_b)
    if surplus > 0:
        target_dir, extra = dir_a, files_a[:surplus]
    else:
        # surplus <= 0: B is larger (or both are equal and the slice is empty).
        target_dir, extra = dir_b, files_b[:-surplus]
    for fname in extra:
        os.remove(os.path.join(target_dir, fname))


for _split in ('train', 'test'):
    _balance_split('datasets/%s/%sA' % (ds_name, _split),
                   'datasets/%s/%sB' % (ds_name, _split))

## 2. Set parameters to train

In [None]:
name_data = ds_name
# 'A2B' uses the dataset's A/B folders as domains A/B; any other value swaps them.
DIRECTION = 'A2B'
#DIRECTION = 'B2A'

EPOCHS = 100
BATCH_SIZE = 5

XY_PIX = 256  # side length (px) of the random training crop

# Loss weights (see train()):
wgt_c_a = 10   # cycle-consistency weight for domain A (also scales identity loss A)
wgt_c_b = 10   # cycle-consistency weight for domain B (also scales identity loss B)
wgt_i = 0.5    # identity-loss multiplier; 0 disables the identity term

LR_GEN = 0.0002
LR_DISC = 0.0002

# Params to save pics and weights
num_freq_disp = 10  # write images to TensorBoard every N batches (and on the last batch)
num_freq_save = 1   # save a checkpoint every N epochs

dir_checkpoint = './checkpoint'
dir_data = './datasets'
dir_log = './log'
dir_result = './result'

dir_chck = os.path.join(dir_checkpoint, name_data)
dir_data_train = os.path.join(dir_data, name_data, 'train')  # Dataset appends 'A'/'B'
dir_log_train = os.path.join(dir_log, name_data, 'train')
dir_data_test = os.path.join(dir_data, name_data, 'test')    # Dataset appends 'A'/'B'
dir_result_save = os.path.join(dir_result, name_data, 'images')

for _dir in (dir_chck, dir_log, dir_result_save):
    os.makedirs(_dir, exist_ok=True)

In [None]:
import os
import itertools

import torch
import torch.nn as nn
import torchsummary
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms

import cv2
from statistics import mean

from datetime import datetime
import numpy as np
import time

import matplotlib.pyplot as plt
from skimage import transform
from matplotlib import rcParams
from IPython.display import clear_output

# Run on the first CUDA device when available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Seed NumPy and PyTorch (CPU and current CUDA device) for repeatable runs.
seed = 42
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)
torch.backends.cudnn.deterministic = True

## 3.1 Build the Generator

In [None]:
class UNet(nn.Module):
    """U-Net generator with eight encoder and eight decoder stages.

    Encoder stages 1-7 halve the spatial size (4x4 conv, stride 2). Stage 8 is a
    3x3 stride-1 conv. Decoder stages 7..1 each take the previous decoder output
    concatenated (channel axis) with the encoder output of the same level. The
    result has 3 channels and values in (-1, 1) from the final Tanh.
    The input side must halve cleanly seven times (e.g. 256 x 256) for the skip
    concatenations to line up.
    """

    def __init__(self):
        super().__init__()

        def down(c_in, c_out, kernel_size=4, stride=2, leaky=True):
            # Conv -> BatchNorm -> (Leaky)ReLU
            activation = nn.LeakyReLU(0.2, True) if leaky else nn.ReLU(True)
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                activation,
            )

        def up(c_in, c_out, kernel_size=4, stride=2, dropout=False):
            # ConvTranspose -> BatchNorm -> ReLU [-> Dropout(0.5)]
            layers = [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel_size, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
            if dropout:
                layers.append(nn.Dropout(0.5))
            return nn.Sequential(*layers)

        # ENCODER (spatial sizes shown for a 256 x 256 input)
        self.enc1 = down(3, 64)      # 256 -> 128
        self.enc2 = down(64, 128)    # 128 -> 64
        self.enc3 = down(128, 256)   # 64 -> 32
        self.enc4 = down(256, 512)   # 32 -> 16
        self.enc5 = down(512, 512)   # 16 -> 8
        self.enc6 = down(512, 512)   # 8 -> 4
        self.enc7 = down(512, 512)   # 4 -> 2
        self.enc8 = down(512, 512, kernel_size=3, stride=1, leaky=False)  # 2 -> 2

        # DECODER (input channels double where a skip connection is concatenated)
        self.dec8 = up(512, 512, kernel_size=3, stride=1, dropout=True)
        self.dec7 = up(1024, 512, dropout=True)
        self.dec6 = up(1024, 512, dropout=True)
        self.dec5 = up(1024, 512)
        self.dec4 = up(1024, 256)
        self.dec3 = up(512, 128)
        self.dec2 = up(256, 64)
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4,
                    self.enc5, self.enc6, self.enc7)
        decoders = (self.dec7, self.dec6, self.dec5, self.dec4,
                    self.dec3, self.dec2, self.dec1)

        skips = []
        h = x
        for enc in encoders:
            h = enc(h)
            skips.append(h)

        h = self.dec8(self.enc8(h))
        # dec7 pairs with enc7, dec6 with enc6, ..., dec1 with enc1.
        for dec, skip in zip(decoders, reversed(skips)):
            h = dec(torch.cat([skip, h], dim=1))
        return h

In [None]:
## Code for check right architecture
#genA = UNet()
#torchsummary.summary(genA.to(device), (3,256,256))

## 3.2 Build the Discriminator

In [None]:
class Discriminator(nn.Module):
    """Patch discriminator: maps a 3-channel image to a 1-channel map of raw logits.

    For a 256 x 256 input the output is 32 x 32. There are three stride-2 4x4
    convolutions followed by two stride-1 3x3 convolutions. Each output value
    scores one image region. The values are unbounded logits meant for
    nn.BCEWithLogitsLoss (target 1 = real, 0 = fake).
    """

    def __init__(self):
        super().__init__()

        # 256 x 256 x 3 -> 128 x 128 x 64
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
        )

        # 128 x 128 x 64 -> 64 x 64 x 128
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
        )

        # 64 x 64 x 128 -> 32 x 32 x 256
        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
        )

        # 32 x 32 x 256 -> 32 x 32 x 512
        self.enc4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
        )

        # 32 x 32 x 512 -> 32 x 32 x 1 raw logits.
        # No activation: a ReLU here clamped logits to >= 0, so sigmoid(output)
        # could never fall below 0.5 ("fake") and negative pre-activations got
        # zero gradient. Kept as a Sequential so checkpoint keys stay 'enc5.0.*'.
        self.enc5 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x = self.enc1(x)
        x = self.enc2(x)
        x = self.enc3(x)
        x = self.enc4(x)
        x = self.enc5(x)

        return x

In [None]:
## Code for check right architecture
#discrA = Discriminator()
#torchsummary.summary(discrA.to(device), (3,256,256))

## 4.1 Initialize weights

In [None]:
def init_weights(net, init_type='normal', init_gain=0.02):
    """Re-initialize the weights of every conv/linear and BatchNorm2d layer in `net` in place.

    Parameters:
        net (network)      -- network whose submodules are visited via `net.apply`
        init_type (str)    -- 'normal' | 'xavier' | 'kaiming' | 'orthogonal' for conv/linear weights
        init_gain (float)  -- std for 'normal', gain for 'xavier'/'orthogonal' (ignored by 'kaiming');
                              also the std of BatchNorm2d weights around 1.0

    Layers whose class name contains 'Conv' or 'Linear' get their weights drawn
    by the chosen method and their biases zeroed. BatchNorm2d layers get weights
    from N(1.0, init_gain) and zero biases. An unknown init_type raises
    NotImplementedError, but only when a conv/linear layer is actually reached.
    'normal' with gain 0.02 is what the pix2pix/CycleGAN papers use.
    """
    initializers = {
        'normal': lambda w: nn.init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: nn.init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: nn.init.orthogonal_(w, gain=init_gain),
    }

    def init_func(m):
        layer_name = type(m).__name__
        is_weighted = hasattr(m, 'weight') and ('Conv' in layer_name or 'Linear' in layer_name)
        if is_weighted:
            initialize = initializers.get(init_type)
            if initialize is None:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            initialize(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in layer_name:
            # BatchNorm weight is a vector, so only the normal distribution applies.
            nn.init.normal_(m.weight.data, 1.0, init_gain)
            nn.init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=(0,)):
    """Initialize a network: 1. register GPU device (with multi-GPU support); 2. initialize the network weights.

    Parameters:
        net (network)          -- the network to be initialized
        init_type (str)        -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)      -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int sequence) -- which GPUs the network runs on: e.g., (0, 1, 2).
                                  Ignored when CUDA is not available, so CPU-only
                                  runs work instead of failing an assertion.

    Return the initialized network: wrapped in torch.nn.DataParallel when GPUs are
    used, otherwise `net` itself. The weights of `net` are initialized in place
    either way.
    """
    if gpu_ids and torch.cuda.is_available():
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, list(gpu_ids))  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net

## 4.2 Create a Dataset class that reads images from directories

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Two-domain image dataset read from the directories `<data_dir>A` and `<data_dir>B`.

    Only `.jpg` files are used. Each directory is ordered by the integer formed
    from all digits in the filename (files without digits come first). Ties are
    broken by name. Item `index` pairs the index-th A file with the index-th B file.

    Args:
        data_dir: path prefix; 'A' and 'B' are appended to get the two directories
            (e.g. '.../train' -> '.../trainA' and '.../trainB').
        direction: 'A2B' keeps the domains as they are; any other value swaps them.
        data_type, nch: stored on the instance but not used by this class.
        transform: optional callable applied to the sample dict
            {'dataA': ndarray, 'dataB': ndarray}.
    """

    def __init__(self, data_dir, direction='A2B', data_type='float32', nch=3, transform=None):
        self.data_dir_a = data_dir + 'A'
        self.data_dir_b = data_dir + 'B'
        self.transform = transform
        self.direction = direction
        self.data_type = data_type
        self.nch = nch

        self.names = (self._list_images(self.data_dir_a), self._list_images(self.data_dir_b))

    @staticmethod
    def _sort_key(fname):
        # Numeric order from the digits in the name; int('') would crash, so
        # digit-less names sort first.
        digits = ''.join(filter(str.isdigit, fname))
        return (int(digits) if digits else -1, fname)

    @staticmethod
    def _list_images(directory):
        # Sorted list of .jpg file names in `directory`.
        return sorted((f for f in os.listdir(directory) if f.endswith('.jpg')), key=Dataset._sort_key)

    @staticmethod
    def _read_rgb(path):
        # Load an image as H x W x 3; uint8 is scaled to [0, 1], grayscale is
        # replicated into three channels.
        img = plt.imread(path).squeeze()
        if img.dtype == np.uint8:
            img = img / 255.0
        if img.ndim == 2:
            img = np.tile(img[:, :, np.newaxis], (1, 1, 3))
        return img

    def __getitem__(self, index):
        dataA = self._read_rgb(os.path.join(self.data_dir_a, self.names[0][index]))
        dataB = self._read_rgb(os.path.join(self.data_dir_b, self.names[1][index]))

        if self.direction == 'A2B':
            data = {'dataA': dataA, 'dataB': dataB}
        else:
            data = {'dataA': dataB, 'dataB': dataA}

        if self.transform:
            data = self.transform(data)

        return data

    def __len__(self):
        # Items are paired by index, so only as many as the smaller domain exist.
        return min(len(self.names[0]), len(self.names[1]))

## 4.3 Create additional functions

Я скопировал функции для преобразования тензоров и массивов при работе с картинками, так как стандартные библиотечные преобразования выдавали ошибки. Здесь же всё прозрачно и очевидно.

### 4.3.1 Normalizing

In [None]:
class Normalize(object):
    """Rescale the 'dataA' and 'dataB' images of a sample from [0, 1] to [-1, 1].

    Returns a new dict that holds only those two keys.
    """

    def __call__(self, data):
        return {key: data[key] * 2 - 1 for key in ('dataA', 'dataB')}

class Denormalize(object):
    """Map values from [-1, 1] back to [0, 1] (inverse of Normalize's 2x - 1)."""

    def __call__(self, data):
        shifted = data + 1
        return shifted / 2

### 4.3.2 Tensor <=> Numpy

In [None]:
class ToTensor(object):
    """Turn the 'dataA'/'dataB' numpy images of a sample into float32 torch tensors.

    NumPy images are H x W x C, while torch expects C x H x W, so the axes are
    moved before conversion. Only the two image keys are returned.
    """

    def __call__(self, data):
        return {
            key: torch.from_numpy(data[key].transpose((2, 0, 1)).astype(np.float32))
            for key in ('dataA', 'dataB')
        }

    
class ToNumpy(object):
    """Convert a batch of torch image tensors into a NumPy array of images.

    Input: tensor of shape (N, C, H, W) on any device. It is moved to the CPU
    and detached from the autograd graph.
    Output: ndarray of shape (N, H, W, C).
    """

    def __call__(self, data):
        return data.to('cpu').detach().numpy().transpose(0, 2, 3, 1)

### 4.3.3 Image transforms

In [None]:
class RandomFlip(object):
    """With probability 1/2, mirror both images of a sample left-to-right.

    A single draw from np.random decides for both 'dataA' and 'dataB'.
    """

    def __call__(self, data):
        img_a, img_b = data['dataA'], data['dataB']
        if np.random.rand() <= 0.5:
            return {'dataA': img_a, 'dataB': img_b}
        return {'dataA': np.fliplr(img_a), 'dataB': np.fliplr(img_b)}

class Rescale(object):
    """Resize both images of a sample with skimage.transform.resize.

    Args:
      output_size (tuple or int): Desired output size.
                                  If tuple (h, w), each image is resized to exactly that size.
                                  If int, each image's smaller edge is matched to
                                  output_size and the other edge keeps that image's
                                  own aspect ratio (computed separately for A and B).
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_size(self, h, w):
        # Output (height, width) for an image of size h x w.
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        return int(new_h), int(new_w)

    def __call__(self, data):
        dataA, dataB = data['dataA'], data['dataB']

        dataA = transform.resize(dataA, self._target_size(*dataA.shape[:2]))
        dataB = transform.resize(dataB, self._target_size(*dataB.shape[:2]))

        return {'dataA': dataA, 'dataB': dataB}


class RandomCrop(object):
    """Crop both images of a sample at the same random position.

    Args:
      output_size (tuple or int): Desired output size (h, w).
                                  If int, square crop is made.

    Raises ValueError when the crop is larger than the image.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, data):
        dataA, dataB = data['dataA'], data['dataB']

        h, w = dataA.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError('crop size %s is larger than image size %s' % ((new_h, new_w), (h, w)))

        # np.random.randint's upper bound is exclusive: +1 makes the bottom/right-most
        # crop reachable and allows a crop equal to the image size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        dataA = dataA[top: top + new_h, left: left + new_w]
        dataB = dataB[top: top + new_h, left: left + new_w]

        return {'dataA': dataA, 'dataB': dataB}

### 4.3.4 Helper to freeze and unfreeze network gradients

In [None]:
def set_requires_grad(nets, requires_grad=False):
    """Set `requires_grad` on every parameter of one or more networks.

    Turning it off for networks that are not being optimized avoids unnecessary
    gradient computation.

    Parameters:
        nets (network or list of networks) -- a single network or a list; None entries are skipped
        requires_grad (bool)               -- value assigned to each parameter's requires_grad
    """
    net_list = nets if isinstance(nets, list) else [nets]
    present = [net for net in net_list if net is not None]
    for net in present:
        for param in net.parameters():
            param.requires_grad = requires_grad

## 5. Save and load functions for model

In [None]:
def save(dir_chck, netG_a2b, netG_b2a, netD_a, netD_b, optimG, optimD, epoch):
    """Write all network and optimizer state dicts to `<dir_chck>/model_epoch<NNNN>.pth`.

    `epoch` is zero-padded to four digits. `dir_chck` is created when it does not exist.
    """
    if not os.path.exists(dir_chck):
        os.makedirs(dir_chck)

    state = dict(
        netG_a2b=netG_a2b.state_dict(),
        netG_b2a=netG_b2a.state_dict(),
        netD_a=netD_a.state_dict(),
        netD_b=netD_b.state_dict(),
        optimG=optimG.state_dict(),
        optimD=optimD.state_dict(),
    )
    ckpt_path = '%s/model_epoch%04d.pth' % (dir_chck, epoch)
    torch.save(state, ckpt_path)

def load(dir_chck, netG_a2b, netG_b2a, netD_a=None, netD_b=None, optimG=None, optimD=None, epoch=None, mode='train'):
    """Restore networks (and, in 'train' mode, optimizers) from a checkpoint written by `save`.

    Args:
        dir_chck: checkpoint directory.
        netG_a2b, netG_b2a: generators to load into (always required).
        netD_a, netD_b, optimG, optimD: required only when mode == 'train'.
        epoch: checkpoint epoch to load. A falsy value (None/0) means the highest
            epoch among the `model_epoch<N>.pth` files in `dir_chck`. Other files
            in the directory are ignored.
        mode: 'train' or 'test'.

    Returns:
        'train': (netG_a2b, netG_b2a, netD_a, netD_b, optimG, optimD, epoch)
        'test':  (netG_a2b, netG_b2a, epoch)

    Raises:
        ValueError: unknown `mode`.
        FileNotFoundError: `epoch` not given and no checkpoint file found.
    """
    if mode not in ('train', 'test'):
        raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))

    if not epoch:
        prefix, suffix = 'model_epoch', '.pth'
        epochs = []
        for fname in os.listdir(dir_chck):
            if fname.startswith(prefix) and fname.endswith(suffix):
                number = fname[len(prefix):-len(suffix)]
                if number.isdigit():
                    epochs.append(int(number))
        if not epochs:
            raise FileNotFoundError('no model_epoch*.pth checkpoint found in %s' % dir_chck)
        # Numeric max, so epoch 10000 beats 9999 even past the 4-digit padding.
        epoch = max(epochs)

    # Load onto the CPU so GPU-saved checkpoints also open on CPU-only machines;
    # load_state_dict copies the values onto each model's/optimizer's own device.
    dict_net = torch.load('%s/model_epoch%04d.pth' % (dir_chck, epoch), map_location='cpu')

    print('Loaded %dth network' % epoch)

    netG_a2b.load_state_dict(dict_net['netG_a2b'])
    netG_b2a.load_state_dict(dict_net['netG_b2a'])

    if mode == 'test':
        return netG_a2b, netG_b2a, epoch

    netD_a.load_state_dict(dict_net['netD_a'])
    netD_b.load_state_dict(dict_net['netD_b'])
    optimG.load_state_dict(dict_net['optimG'])
    optimD.load_state_dict(dict_net['optimD'])

    return netG_a2b, netG_b2a, netD_a, netD_b, optimG, optimD, epoch

## 6. Train function

In [None]:
def train(dataset, mode='train', train_continue=False, num_epoch=EPOCHS, lr_G=LR_GEN, lr_D=LR_DISC, batch_size=BATCH_SIZE, beta1=0.5):
    """Train the CycleGAN (two UNet generators, two Discriminators) on `dataset`.

    Args:
        dataset: dataset yielding dicts with 'dataA' and 'dataB' image tensors.
        mode: passed to `load` when resuming ('train' restores nets and optimizers).
        train_continue: if truthy, resume from the latest checkpoint in `dir_chck`.
        num_epoch: last epoch to train; epochs are numbered from 1.
        lr_G, lr_D: Adam learning rates for the generators / discriminators.
        batch_size: DataLoader batch size.
        beta1: Adam beta1 (beta2 is fixed at 0.999).

    Reads the module-level settings `device`, `dir_chck`, `dir_log_train`,
    `wgt_c_a`, `wgt_c_b`, `wgt_i`, `num_freq_disp`, `num_freq_save` and
    `transform_inv`. Per-epoch mean losses and, every `num_freq_disp` batches,
    image batches go to TensorBoard. A checkpoint is saved every `num_freq_save` epochs.
    """
    loader_train = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)
    # Batches per epoch (ceil(len(dataset) / batch_size)), taken from the loader
    # over `dataset` itself rather than from a global dataset variable.
    num_batch_train = len(loader_train)

    # Order of the losses in the progress line and their TensorBoard tags ('loss_' + name).
    loss_names = ('G_a2b', 'G_b2a', 'D_a', 'D_b', 'C_a', 'C_b', 'I_a', 'I_b')

    def mean_or_zero(values):
        # statistics.mean raises on an empty list (identity losses when wgt_i == 0).
        return mean(values) if values else 0.0

    def should(i, freq):
        # True every `freq` batches and on the last batch of the epoch.
        return freq > 0 and (i % freq == 0 or i == num_batch_train)

    ## setup network
    netG_a2b = UNet().to(device)
    netG_b2a = UNet().to(device)

    netD_a = Discriminator().to(device)
    netD_b = Discriminator().to(device)

    for net in (netG_a2b, netG_b2a, netD_a, netD_b):
        init_net(net, init_type='normal', init_gain=0.02)

    ## setup loss & optimization
    fn_Cycle = nn.L1Loss().to(device)   # L1
    fn_GAN = nn.BCEWithLogitsLoss().to(device)
    fn_Ident = nn.L1Loss().to(device)   # L1

    optimG = torch.optim.Adam(itertools.chain(netG_a2b.parameters(), netG_b2a.parameters()),
                              lr=lr_G, betas=(beta1, 0.999))
    optimD = torch.optim.Adam(itertools.chain(netD_a.parameters(), netD_b.parameters()),
                              lr=lr_D, betas=(beta1, 0.999))

    ## load from checkpoints
    st_epoch = 0
    if train_continue:
        netG_a2b, netG_b2a, netD_a, netD_b, optimG, optimD, st_epoch = \
            load(dir_chck, netG_a2b, netG_b2a, netD_a, netD_b, optimG, optimD, mode=mode)

    ## setup tensorboard
    writer_train = SummaryWriter(log_dir=dir_log_train)
    try:
        for epoch in range(st_epoch + 1, num_epoch + 1):
            ## training phase
            netG_a2b.train()
            netG_b2a.train()
            netD_a.train()
            netD_b.train()

            history = {name: [] for name in loss_names}

            for i, data in enumerate(loader_train, 1):
                input_a = data['dataA'].to(device)
                input_b = data['dataB'].to(device)

                # forward netG: translations and cycle reconstructions
                output_b = netG_a2b(input_a)
                output_a = netG_b2a(input_b)

                recon_b = netG_a2b(output_a)
                recon_a = netG_b2a(output_b)

                # ---- discriminator step (generator outputs detached) ----
                set_requires_grad([netD_a, netD_b], True)
                optimD.zero_grad()

                pred_real_a = netD_a(input_a)
                pred_fake_a = netD_a(output_a.detach())
                loss_D_a = 0.5 * (fn_GAN(pred_real_a, torch.ones_like(pred_real_a)) +
                                  fn_GAN(pred_fake_a, torch.zeros_like(pred_fake_a)))

                pred_real_b = netD_b(input_b)
                pred_fake_b = netD_b(output_b.detach())
                loss_D_b = 0.5 * (fn_GAN(pred_real_b, torch.ones_like(pred_real_b)) +
                                  fn_GAN(pred_fake_b, torch.zeros_like(pred_fake_b)))

                loss_D = loss_D_a + loss_D_b
                loss_D.backward()
                optimD.step()

                # ---- generator step (discriminators frozen) ----
                set_requires_grad([netD_a, netD_b], False)
                optimG.zero_grad()

                if wgt_i > 0:
                    # Identity: each generator should leave images of its output domain unchanged.
                    ident_b = netG_a2b(input_b)
                    ident_a = netG_b2a(input_a)

                    loss_I_a = fn_Ident(ident_a, input_a)
                    loss_I_b = fn_Ident(ident_b, input_b)
                else:
                    loss_I_a = 0
                    loss_I_b = 0

                pred_fake_a = netD_a(output_a)
                pred_fake_b = netD_b(output_b)

                loss_G_a2b = fn_GAN(pred_fake_b, torch.ones_like(pred_fake_b))
                loss_G_b2a = fn_GAN(pred_fake_a, torch.ones_like(pred_fake_a))

                loss_C_a = fn_Cycle(input_a, recon_a)
                loss_C_b = fn_Cycle(input_b, recon_b)

                loss_G = (loss_G_a2b + loss_G_b2a) + \
                         (wgt_c_a * loss_C_a + wgt_c_b * loss_C_b) + \
                         (wgt_c_a * loss_I_a + wgt_c_b * loss_I_b) * wgt_i

                loss_G.backward()
                optimG.step()

                # ---- bookkeeping ----
                history['G_a2b'].append(loss_G_a2b.item())
                history['G_b2a'].append(loss_G_b2a.item())
                history['D_a'].append(loss_D_a.item())
                history['D_b'].append(loss_D_b.item())
                history['C_a'].append(loss_C_a.item())
                history['C_b'].append(loss_C_b.item())
                if wgt_i > 0:
                    history['I_a'].append(loss_I_a.item())
                    history['I_b'].append(loss_I_b.item())

                print('TRAIN: EPOCH %d: BATCH %04d/%04d: '
                      'G_a2b: %.4f G_b2a: %.4f D_a: %.4f D_b: %.4f C_a: %.4f C_b: %.4f I_a: %.4f I_b: %.4f'
                      % ((epoch, i, num_batch_train) + tuple(mean_or_zero(history[n]) for n in loss_names)))

                if should(i, num_freq_disp):
                    ## show output
                    step = num_batch_train * (epoch - 1) + i
                    for tag, images in (('input_a', input_a), ('output_a', output_a),
                                        ('input_b', input_b), ('output_b', output_b)):
                        writer_train.add_images(tag, transform_inv(images), step, dataformats='NHWC')

            for name in loss_names:
                writer_train.add_scalar('loss_' + name, mean_or_zero(history[name]), epoch)

            ## save model weights
            if (epoch % num_freq_save) == 0:
                save(dir_chck, netG_a2b, netG_b2a, netD_a, netD_b, optimG, optimD, epoch)
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer_train.close()

## 7. Test function

In [None]:
def test(dataset, mode='test', batch_size=BATCH_SIZE):
    """Translate every sample of `dataset` with both generators and save the results.

    Loads the latest checkpoint from `dir_chck` via `load`. For each sample it
    writes six PNGs to `dir_result_save`: input, translation and cycle
    reconstruction for both domains. It also appends one row per sample to the
    HTML index through `append_index`. Reads the module-level `device`,
    `dir_chck`, `dir_result`, `dir_result_save` and `transform_inv`.
    """
    loader_test = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # Size of the dataset passed in (not of a global dataset variable).
    num_test = len(dataset)

    ## setup network; all weights (incl. BatchNorm statistics) come from the
    ## checkpoint, so no separate random initialization is needed.
    netG_a2b = UNet().to(device)
    netG_b2a = UNet().to(device)
    netG_a2b, netG_b2a, _ = load(dir_chck, netG_a2b, netG_b2a, mode=mode)

    ## test phase
    with torch.no_grad():
        netG_a2b.eval()
        netG_b2a.eval()

        for i, data in enumerate(loader_test, 1):
            input_a = data['dataA'].to(device)
            input_b = data['dataB'].to(device)

            # forward netG
            output_b = netG_a2b(input_a)
            output_a = netG_b2a(input_b)

            recon_b = netG_a2b(output_a)
            recon_a = netG_b2a(output_b)

            # Column name -> batch converted to N x H x W x C in [0, 1].
            # The order defines both the file-saving order and the HTML columns.
            images = {
                'input_a': transform_inv(input_a),
                'output_b': transform_inv(output_b),
                'input_b': transform_inv(input_b),
                'output_a': transform_inv(output_a),
                'recon_a': transform_inv(recon_a),
                'recon_b': transform_inv(recon_b),
            }

            for j in range(input_a.shape[0]):
                name = batch_size * (i - 1) + j
                fileset = {'name': name}
                for key, batch in images.items():
                    fname = '%04d-%s.png' % (name, key)
                    plt.imsave(os.path.join(dir_result_save, fname), batch[j].squeeze())
                    fileset[key] = fname

                append_index(dir_result, fileset)

                print("%d / %d" % (name + 1, num_test))

## 8. Code to create html-result

In [None]:
def append_index(dir_result, fileset, step=False, data_name=None):
    """Append one row of result images to a dataset's HTML index.

    The index file is ``<dir_result>/<data_name>/index.html``. If it does not
    exist yet it is created and a header row is written first: an optional
    "step" column, then "name", then one column per remaining key of
    ``fileset`` (in the mapping's order).

    Args:
        dir_result: Root directory of the results.
        fileset: Mapping with a ``'name'`` entry (rendered as text) and an image
            file name for every other key; each image is referenced as
            ``images/<file name>``. When ``step`` is true it must also contain
            an integer ``'step'`` entry. The mapping is NOT modified.
        step: If true, prepend a "step" column to the row (and header).
        data_name: Dataset sub-directory. Defaults to the notebook-level
            ``name_data`` global for backward compatibility.

    Returns:
        The path of the HTML index file.
    """
    if data_name is None:
        data_name = name_data  # notebook-level setting
    index_path = os.path.join(dir_result, data_name, "index.html")

    # Only these keys are rendered as <img> cells; 'name' and 'step' are text.
    image_keys = [key for key in fileset if key not in ("name", "step")]

    is_new = not os.path.exists(index_path)
    # "a" creates the file when missing; `with` guarantees it is closed.
    with open(index_path, "a") as index:
        if is_new:
            index.write("<html><body><table><tr>")
            if step:
                index.write("<th>step</th>")
            index.write("<th>name</th>")
            for key in image_keys:
                index.write("<th>%s</th>" % key)
            index.write("</tr>")

        index.write("<tr>")
        if step:
            index.write("<td>%d</td>" % fileset["step"])
        index.write("<td>%s</td>" % fileset["name"])
        for key in image_keys:
            index.write("<td><img src='images/%s'></td>" % fileset[key])
        index.write("</tr>")

    return index_path

## 9. Train and test model

Let's train

In [None]:
# Training-time pipeline: normalise, random flip, upscale by 30 px and
# randomly crop back to XY_PIX x XY_PIX, then convert to tensors.
transform_train = transforms.Compose(
    [
        Normalize(),
        RandomFlip(),
        Rescale((XY_PIX + 30, XY_PIX + 30)),
        RandomCrop((XY_PIX, XY_PIX)),
        ToTensor(),
    ]
)
# Inverse pipeline used to turn network tensors back into images.
transform_inv = transforms.Compose([ToNumpy(), Denormalize()])

dataset_train = Dataset(dir_data_train, direction=DIRECTION, transform=transform_train)

# Continue from saved checkpoints whenever the checkpoint folder is non-empty.
train_cont = bool(os.listdir(dir_chck))
train(dataset=dataset_train, train_continue=train_cont)

And run the test

In [None]:
# Test-time pipeline: no augmentation, only normalisation and tensor conversion.
transform_test = transforms.Compose([Normalize(), ToTensor()])
# Inverse pipeline used by test() to turn network outputs back into images.
transform_inv = transforms.Compose([ToNumpy(), Denormalize()])

# Use the same DIRECTION as the training dataset so that dataA/dataB feed the
# generators they were trained for.
dataset_test = Dataset(dir_data_test, direction=DIRECTION, transform=transform_test)

test(dataset=dataset_test)

## 10. View results

Результаты работы сети можно посмотреть в файле `result/<name_ds>/index.html`

`input a | output b | input b | output a | recon a | recon b`<br>
`(apple to orange)  | (orange to apple)`

Также для просмотра лоссов удобен инструмент **tensorboard**: в его логи с определённым интервалом записываются лоссы и сохраняются картинки.

Для его запуска нужно в командной строке ввести следующий код: `tensorboard --logdir ./log/<name_ds> --port 7777`

После запуска по адресу http://localhost:7777 станет доступен Dashboard, на котором во вкладке [scalars](http://localhost:7777/#scalars) находятся лоссы, а во вкладке [images](http://localhost:7777/#images) — картинки

# Часть 2. Реализация своей задачи

Вдохновение для задачи по интересу было получено из [этой статьи](https://towardsdatascience.com/turning-fortnite-into-pubg-with-deep-learning-cyclegan-2f9d339dcdb0). Готовая реализация на Keras есть [здесь](https://github.com/bendangnuksung/fortnite-pubg).

Для себя я выбрал трансформацию игры CS:GO в Overwatch.

Для получения необходимых картинок было скачано несколько роликов с игрой с YouTube. Из видео были извлечены кадры, а лишнее (заставки, меню и т. д.) удалено вручную.

## 1. Код для перевода видео в фото

In [None]:
def get_frame(vidcap, sec, start_cnt, dir_dir, count=0):
    """Grab the frame at ``sec`` seconds and save it as a JPEG file.

    The frame is written to ``<dir_dir>/image<start_cnt + count>.jpg``.

    Args:
        vidcap: An opened ``cv2.VideoCapture``.
        sec: Position in the video, in seconds.
        start_cnt: Index used in the output file name.
        dir_dir: Directory the image is written to.
        count: Optional extra offset added to ``start_cnt`` (default 0).
            It is a parameter because the caller's own frame counter is a
            local variable of the caller and is not visible here.

    Returns:
        Tuple ``(vidcap, has_frame)``. ``has_frame`` is whatever
        ``vidcap.read()`` reported; when it is false nothing is written.
    """
    vidcap.set(cv2.CAP_PROP_POS_MSEC, sec * 1000)
    has_frame, image = vidcap.read()
    if has_frame:
        # save frame as JPG file
        cv2.imwrite("%s/image%s.jpg" % (dir_dir, str(count + start_cnt)), image)
    return vidcap, has_frame

Исходные видеоролики должны лежать в папках **video_cs** и **video_ow** соответственно и называться **N_video.mp4**, где N — порядковый номер (начиная с 0).

In [None]:
def create_photos_from_video(video_dir_name, photo_dir_name, n, frame_rate=0.25):
    """Extract frames from ``n`` videos and save them as JPEG images.

    Videos are read from ``<video_dir_name>/<i>_video.mp4`` for ``i`` in
    ``range(n)``. From each video a frame is taken every ``frame_rate``
    seconds and saved into ``photo_dir_name`` via ``get_frame``. A single
    running counter numbers the frames of all videos, so file names never
    collide between videos.

    Args:
        video_dir_name: Directory containing the numbered videos.
        photo_dir_name: Output directory. Nothing here creates it, so it
            should already exist.
        n: Number of videos to process.
        frame_rate: Sampling interval in seconds. Despite the name, this is a
            period, not a rate.
    """
    count = 0  # running frame index shared by all videos
    for i in range(n):
        print(i, "from", n)
        vidcap = cv2.VideoCapture('%s/%s_video.mp4' % (video_dir_name, i))
        try:
            fps = vidcap.get(cv2.CAP_PROP_FPS)
            if not vidcap.isOpened() or fps <= 0:
                # Missing/unreadable video: fps would be 0 and the length
                # computation below would divide by zero.
                print("Cannot read video", i, "in", video_dir_name)
                continue
            length_sec = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT) / fps)
            total = int(length_sec / frame_rate)
            start_cnt = count  # counter value at the start of this video
            sec = 0
            while sec <= length_sec:
                count += 1
                if count % 50 == 0:
                    print(count - start_cnt, "/", total)
                sec = round(sec + frame_rate, 2)
                # The running counter is the file index -> unique file names.
                vidcap, _ = get_frame(vidcap, sec, count, photo_dir_name)
        finally:
            vidcap.release()  # free the decoder/file handle for every video
    print("Success!", video_dir_name)

In [None]:
# Uncomment and run once to extract frames from the downloaded videos
# (6 CS:GO videos from video_cs, 1 Overwatch video from video_ow).
#create_photos_from_video(video_dir_name='video_cs', photo_dir_name='photo_cs', n=6)
#create_photos_from_video(video_dir_name='video_ow', photo_dir_name='photo_ow', n=1)

Полученные изображения были перенесены в папку `datasets/overwatch2csgo` (по аналогии с другими датасетами).

## 2. Подготовка к обучению

Можно скачать мой датасет, используя следующий код:

In [None]:
import os, sys
import requests
import zipfile

ds_name = 'overwatch2csgo'
zip_path = 'datasets/%s.zip' % ds_name

# The folder may not exist yet when Part 2 is run on its own.
os.makedirs('datasets', exist_ok=True)

# Skip the download if the dataset is already unpacked or its archive is present.
if ds_name in os.listdir('datasets') or os.path.isfile(zip_path):
    print('Dataset %s already downloaded' % ds_name)
else:
    url = 'http://jackssn.com/datasets/%s.zip' % ds_name
    r = requests.get(url, allow_redirects=True)
    r.raise_for_status()  # don't save an HTTP error page as a .zip file
    with open(zip_path, 'wb') as zip_file:
        zip_file.write(r.content)

if os.path.isfile(zip_path):
    unzip_q = input('Want to unzip zip-file? (y/n)\n')
    if unzip_q.lower() == 'y':
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall('datasets')
        print('Dataset %s unzip successfully' % ds_name)

    remove_q = input('Want to remove zip-file? (y/n)\n')
    if remove_q.lower() == 'y':
        os.remove(zip_path)
        # Only the archive is deleted; the extracted dataset stays in place.
        print('Zip-file of dataset %s removed successfully' % ds_name)

print('Files in datasets-dir:', ' | '.join(os.listdir('datasets')))

In [None]:
import os, sys

# --- Dataset and translation direction --------------------------------------
ds_name = 'overwatch2csgo'
name_data = ds_name
DIRECTION = 'A2B'  # switch to 'B2A' to use the reverse direction

# --- Training schedule ------------------------------------------------------
EPOCHS = 100
BATCH_SIZE = 5
XY_PIX = 256  # side length (px) of the square network input

# --- Loss weights (presumably c = cycle, i = identity, going by the names) --
wgt_c_a = 10
wgt_c_b = 10
wgt_i = 0.5

# --- Learning rates ---------------------------------------------------------
LR_GEN = 0.0002
LR_DISC = 0.0002

# --- How often to display pictures / save weights ---------------------------
num_freq_disp = 10  # display every N epochs
num_freq_save = 1   # save every N epochs

# --- Directory layout -------------------------------------------------------
dir_checkpoint = './checkpoint'
dir_data = './datasets'
dir_log = './log'
dir_result = './result'

dir_chck = os.path.join(dir_checkpoint, name_data)
dir_data_train = os.path.join(dir_data, name_data, 'train')
dir_log_train = os.path.join(dir_log, name_data, 'train')
dir_data_test = os.path.join(dir_data, name_data, 'test')
dir_result_save = os.path.join(dir_result, name_data, 'images')

# Only these three folders are created here.
for required_dir in (dir_chck, dir_log, dir_result_save):
    if not os.path.exists(required_dir):
        os.makedirs(required_dir)

## 3. Запуск обучения

Если память видеокарты переполнена, нужно перезапустить ядро (kernel), заново выполнить ячейки с определением моделей и функций, а затем запустить обучение.

In [None]:
# Training-time pipeline: normalise, random flip, upscale by 30 px and
# randomly crop back to XY_PIX x XY_PIX, then convert to tensors.
transform_train = transforms.Compose(
    [
        Normalize(),
        RandomFlip(),
        Rescale((XY_PIX + 30, XY_PIX + 30)),
        RandomCrop((XY_PIX, XY_PIX)),
        ToTensor(),
    ]
)
# Inverse pipeline used to turn network tensors back into images.
transform_inv = transforms.Compose([ToNumpy(), Denormalize()])

dataset_train = Dataset(dir_data_train, direction=DIRECTION, transform=transform_train)

# Continue from saved checkpoints whenever the checkpoint folder is non-empty.
train_cont = bool(os.listdir(dir_chck))
train(dataset=dataset_train, train_continue=train_cont)

In [None]:
# Test-time pipeline: normalise, upscale by 30 px, crop to XY_PIX x XY_PIX,
# convert to tensors.
# NOTE: RandomCrop makes the saved test outputs vary from run to run.
transform_test = transforms.Compose([
    Normalize(),
    Rescale((XY_PIX + 30, XY_PIX + 30)),
    RandomCrop((XY_PIX, XY_PIX)),
    ToTensor(),
])
# Inverse pipeline used by test() to turn network outputs back into images.
transform_inv = transforms.Compose([ToNumpy(), Denormalize()])

# Use the same DIRECTION as the training dataset so that dataA/dataB feed the
# generators they were trained for.
dataset_test = Dataset(dir_data_test, direction=DIRECTION, transform=transform_test)

test(dataset=dataset_test)

## 4. Результаты работы модели

Результаты работы сети можно посмотреть в файле `result/<name_ds>/index.html`

`input a | output b | input b | output a | recon a | recon b`<br>
`(Overwatch to CS:GO)  | (CS:GO to Overwatch)`

Также результаты можно посмотреть по ссылке: https://jackssn.com/result/overwatch2csgo