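# CycleGAN

Unpaired image-to-image translation between horses (domain A) and zebras (domain B) on the `horse2zebra` dataset, with training losses and sample translation cycles logged to Weights & Biases.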

## Setup

In [1]:
!pip3 install numpy torch wandb torchvision sklearn tqdm matplotlib



## Imports

In [2]:
import numpy as np

import matplotlib.pyplot as plt

import wandb

import torch
from torch import nn
from torch.utils import data

from torchvision import transforms,datasets
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

from sklearn.metrics import accuracy_score

from tqdm.auto import trange, tqdm

In [3]:
# Turn cuDNN off globally for this kernel.
# NOTE(review): the motivation is not recorded here (possibly a workaround for a
# cuDNN error or for reproducibility) — confirm; `device` below is CPU anyway.
cudnn_backend = torch.backends.cudnn
cudnn_backend.enabled = False

### Device and W&B setup

In [4]:
# Compute device. CPU by default; set USE_CUDA = True to run on the first GPU
# when one is available.
USE_CUDA = False
device = torch.device('cuda:0' if USE_CUDA and torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [5]:
# Authenticate with Weights & Biases; the cell displays the returned status.
is_logged_in = wandb.login()
is_logged_in

[34m[1mwandb[0m: Currently logged in as: [33mlalunedeidees[0m (use `wandb login --relogin` to force relogin)


True

## Data

In [6]:
from PIL import Image


class ImageLoader():
    """Load an image file as RGB and optionally apply a transform.

    Args:
        transform: optional callable applied to the RGB ``PIL.Image``
            (e.g. a torchvision ``Compose``); ``None`` returns the image as is.
    """

    def __init__(self, transform=None):
        self.transform = transform

    def load(self, path):
        """Return the image at ``path`` converted to RGB, transformed if set."""
        # Open through a context manager so the file handle is released even
        # for multi-frame formats (e.g. TIFF), which Pillow otherwise keeps
        # open after loading. ``convert`` returns a new, fully loaded image,
        # so ``out`` stays valid after the file is closed.
        with Image.open(path) as img:
            out = img.convert('RGB')

        if self.transform is not None:
            out = self.transform(out)

        return out

In [7]:
import os
import glob

from torch.utils.data import Dataset


class ImageDataset(Dataset):
    """Map-style dataset over the image files found directly inside ``root``.

    ``__getitem__`` returns the image at that index, loaded as RGB and passed
    through ``transform`` (via ``ImageLoader``).

    Args:
        root: directory to scan (non-recursive).
        transform: optional callable applied to each loaded image.
    """

    # Recognised image extensions (matched case-insensitively).
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    def __init__(self, root, transform=None):
        self.transform = transform
        self.loader = ImageLoader(transform)

        self.paths = self.get_paths(root)

    def get_paths(self, root):
        """Return the image paths directly inside ``root``, sorted.

        ``glob`` yields entries in arbitrary, filesystem-dependent order, so
        the result is sorted to keep indices reproducible across machines.
        Extensions are compared case-insensitively so e.g. ``.JPG`` files are
        not silently skipped. A missing directory yields an empty list.
        """
        pattern = os.path.join(root, '*')
        return sorted(
            path for path in glob.glob(pattern)
            if path.lower().endswith(self.IMG_EXTENSIONS)
        )

    def __getitem__(self, index):
        """Load and return the (transformed) image at position ``index``."""
        return self.loader.load(self.paths[index])

    def __len__(self):
        """Number of image files found."""
        return len(self.paths)

In [8]:
import torchvision.transforms as transforms


def get_transform(image_size, aug=False):
    """Build the preprocessing pipeline for square ``image_size`` images.

    Args:
        image_size: side length (pixels) of the square output image.
        aug: if True, add training augmentation (upscale, random resized crop,
            random horizontal flip); otherwise only resize.

    Returns:
        ``transforms.Compose`` ending in ToTensor + Normalize(0.5, 0.5), i.e.
        pixel values mapped from [0, 1] to [-1, 1] (matching a Tanh output).
    """
    if not aug:
        geometric = [transforms.Resize((image_size, image_size))]
    else:
        # Upscale to the next multiple of 8 strictly above image_size
        # (128 -> 136) before cropping back to image_size.
        # NOTE(review): RandomResizedCrop samples its own scale/aspect ratio
        # (torchvision default scale=(0.08, 1.0)), so the pre-resize has little
        # effect; a RandomCrop (resize-then-crop) may have been intended — confirm.
        upscaled_size = (image_size // 8 + 1) * 8
        geometric = [
            transforms.Resize((upscaled_size, upscaled_size)),
            transforms.RandomResizedCrop((image_size, image_size)),
            transforms.RandomHorizontalFlip(),
        ]

    tensorize = [
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(geometric + tensorize)

In [9]:
# Download the horse2zebra dataset (uncomment on the first run):
# !wget https://people.eecs.berkeley.edu/%7Etaesung_park/CycleGAN/datasets/horse2zebra.zip
# !unzip horse2zebra.zip

In [10]:
image_size = 128
data_root = 'horse2zebra'

# Augmented pipeline for training, plain resize for validation.
train_transforms = get_transform(image_size, aug=True)
train_dataset_a = ImageDataset(root=f'{data_root}/trainA', transform=train_transforms)
train_dataset_b = ImageDataset(root=f'{data_root}/trainB', transform=train_transforms)

val_transforms = get_transform(image_size, aug=False)
val_dataset_a = ImageDataset(root=f'{data_root}/testA', transform=val_transforms)
val_dataset_b = ImageDataset(root=f'{data_root}/testB', transform=val_transforms)

# A missing/empty directory gives a 0-item ImageDataset (glob finds nothing),
# and a joint loader over it yields no batches — training would then silently
# do nothing. Fail early with a clear message instead.
for _split, _dataset in {
    'trainA': train_dataset_a, 'trainB': train_dataset_b,
    'testA': val_dataset_a, 'testB': val_dataset_b,
}.items():
    if len(_dataset) == 0:
        raise FileNotFoundError(
            f'no images found in {data_root}/{_split}; download and unzip horse2zebra first'
        )

In [11]:
from torch.utils.data import DataLoader


class JointDataLoader():
    """Iterate over two unpaired datasets in lock-step, yielding ``(a, b)`` batches.

    One ``DataLoader`` is built per dataset with ``drop_last=True``, so every
    batch holds exactly ``batch_size`` items. One pass ("epoch") yields
    ``min(len(loader_a), len(loader_b))`` pairs.

    With ``shuffle=True`` both loaders are wrapped once in endless generators,
    so each new pass continues where the previous one stopped. With
    ``shuffle=False`` every pass restarts both loaders from the beginning.

    Args:
        dataset_a, dataset_b: map-style datasets for the two domains.
        batch_size: items per batch for both loaders.
        shuffle: shuffle each loader (and make them endless, see above).
        num_workers: worker processes per ``DataLoader``.
    """

    def __init__(self, dataset_a, dataset_b, batch_size, shuffle=False, num_workers=1):
        self.dataset_a = dataset_a
        self.dataset_b = dataset_b

        self.dataloader_a = DataLoader(
            dataset=dataset_a,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=True
        )
        self.dataloader_b = DataLoader(
            dataset=dataset_b,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=True
        )

        # Computed before the loaders are (possibly) replaced by generators,
        # which have no len().
        self.n = min(len(self.dataloader_a), len(self.dataloader_b))

        if shuffle:
            self.dataloader_a = self.infinit_dataloader(self.dataloader_a)
            self.dataloader_b = self.infinit_dataloader(self.dataloader_b)

    def infinit_dataloader(self, dataloader):
        """Yield batches from ``dataloader`` forever, restarting it when exhausted."""
        while True:
            yield from dataloader

    def __len__(self):
        """Number of ``(a, b)`` batch pairs produced per pass."""
        return self.n

    def __iter__(self):
        # ``range`` comes first in ``zip`` so iteration stops after ``self.n``
        # pairs without pulling an extra batch from the (possibly endless) loaders.
        for _, a, b in zip(range(self.n), self.dataloader_a, self.dataloader_b):
            yield a, b

## Model

In [12]:
class DownSampleBlock(nn.Module):
    """Conv2d -> InstanceNorm2d -> activation.

    With the defaults (kernel 4, stride 2, padding 1) the spatial size is
    halved; with stride 1 it acts as a plain conv/norm/activation block.

    Args:
        in_planes, out_planes: input / output channel counts.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        activation: ``'ReLU'`` selects ReLU; any other value selects LeakyReLU(0.2).
    """

    def __init__(self, in_planes, out_planes, kernel_size=4,
                 stride=2, padding=1, activation='ReLU'):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding)
        self.ins = nn.InstanceNorm2d(out_planes)
        # Attribute is called ``relu`` even when it holds a LeakyReLU.
        use_plain_relu = activation == 'ReLU'
        self.relu = nn.ReLU() if use_plain_relu else nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.relu(self.ins(self.conv(x)))

# DownSampleBlock(64 * k_ch, 128 * k_ch, activation='LeakyReLU')
# replaces
# nn.Conv2d(64*k_ch,128*k_ch,4,2,1,bias=False),
# nn.InstanceNorm2d(128*k_ch),
# nn.LeakyReLU(0.2)
# (note: the block's Conv2d keeps the default bias=True)

class CenterDownSampleBlock(nn.Module):
    """Conv2d followed by LeakyReLU(0.2), with no normalisation.

    With the defaults (kernel 2, stride 1, padding 0) each spatial side
    shrinks by one pixel.

    Args:
        in_planes, out_planes: input / output channel counts.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=2,
                 stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding)
        self.relu = nn.LeakyReLU(0.2)  # named ``relu`` although it is a LeakyReLU

    def forward(self, x):
        return self.relu(self.conv(x))

# CenterDownSampleBlock(128 * k_ch, 256 * k_ch)
# replaces
# nn.Conv2d(128*k_ch,256*k_ch,2,bias=False),
# nn.LeakyReLU(0.2),
# (note: the block's Conv2d keeps the default bias=True)

class CenterUpSampleBlock(nn.Module):
    """ConvTranspose2d -> InstanceNorm2d -> LeakyReLU(0.2).

    With the defaults (kernel 2, stride 1, padding 0) each spatial side grows
    by one pixel.

    Args:
        in_planes, out_planes: input / output channel counts.
        kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=2,
                 stride=1, padding=0):
        super(CenterUpSampleBlock, self).__init__()

        self.trans_conv = nn.ConvTranspose2d(
            in_planes, out_planes, kernel_size=kernel_size,
            stride=stride, padding=padding,
        )
        self.ins = nn.InstanceNorm2d(out_planes)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.trans_conv(x)
        # The instance norm was constructed but previously never applied.
        x = self.ins(x)
        x = self.relu(x)

        return x

# CenterUpSampleBlock(256 * k_ch, 128 * k_ch)
# replaces
# nn.ConvTranspose2d(256*k_ch,128*k_ch,2,bias=False),
# nn.InstanceNorm2d(128*k_ch),
# nn.LeakyReLU(0.2),
# (note: the block's ConvTranspose2d keeps the default bias=True)

class UpSampleBlock(nn.Module):
    """ConvTranspose2d -> InstanceNorm2d -> activation.

    With the defaults (kernel 4, stride 2, padding 1) the spatial size is
    doubled; with stride 1 and kernel 3 / padding 1 it is preserved.

    Args:
        in_planes, out_planes: input / output channel counts.
        kernel_size, stride, padding, output_padding: forwarded to
            ``nn.ConvTranspose2d``.
        activation: ``'ReLU'`` selects ReLU; any other value selects LeakyReLU(0.2).
    """

    def __init__(self, in_planes, out_planes, kernel_size=4,
                 stride=2, padding=1, activation='ReLU', output_padding=0):
        super().__init__()
        self.trans_conv = nn.ConvTranspose2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.ins = nn.InstanceNorm2d(out_planes)
        if activation == 'ReLU':
            self.relu = nn.ReLU()
        else:
            self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        upsampled = self.trans_conv(x)
        return self.relu(self.ins(upsampled))

class ResnetBlock(nn.Module):
    """Residual block: ``x + IN(conv3x3(ReLU(IN(conv3x3(x)))))``.

    Both convolutions are 3x3 / stride 1 / padding 1, so the spatial size is
    preserved; the residual sum requires ``in_planes == out_planes``.

    NOTE(review): ``kernel_size``, ``stride`` and ``padding`` are accepted but
    ignored (kept only for signature compatibility).
    """

    def __init__(self, in_planes, out_planes, kernel_size=4,
                 stride=2, padding=1):
        super().__init__()
        # Despite its name, this DownSampleBlock uses stride 1 (no downsampling).
        self.down_sample = DownSampleBlock(in_planes, out_planes, 3, 1, 1, activation='ReLU')
        self.conv2d = nn.Conv2d(out_planes, out_planes, 3, 1, 1)
        self.ins = nn.InstanceNorm2d(out_planes)

    def forward(self, x):
        residual = self.ins(self.conv2d(self.down_sample(x)))
        return residual + x
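# (This note refers to UpSampleBlock, defined above ResnetBlock.)
# UpSampleBlock(4 * k_ch, 2 * k_ch, activation='LeakyReLU')
# replaces
# nn.ConvTranspose2d(4*k_ch,2*k_ch,4,2,1,bias=False),
# nn.InstanceNorm2d(2*k_ch),
# nn.LeakyReLU(0.2),
# (note: the block's ConvTranspose2d keeps the default bias=True)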

In [13]:
class Generator(nn.Module):
    """Image-to-image generator: conv stem -> residual blocks -> conv head -> Tanh.

    Every layer uses stride 1, so the output has the same spatial size as the
    input, with 3 channels in [-1, 1] (Tanh).

    NOTE(review): despite their names, the Down/UpSampleBlock layers here do
    not change resolution, so the 256-channel residual blocks run at full
    input resolution (costly compared with a stride-2 stem).

    Args:
        input_size: unused; kept (now optional) so existing calls such as
            ``Generator(0)`` / ``Generator(1)`` keep working.
        n_res_blocks: number of 256-channel ``ResnetBlock``s (default 6).
    """

    def __init__(self, input_size=None, n_res_blocks=6):
        super(Generator, self).__init__()
        if n_res_blocks < 0:
            raise ValueError(f'n_res_blocks must be >= 0, got {n_res_blocks}')
        k_ch = 3  # number of color channels
        layers = [
            DownSampleBlock(1 * k_ch, 64, 7, 1, 3, activation='ReLU'),
            DownSampleBlock(64, 128, 3, 1, 1, activation='ReLU'),
            DownSampleBlock(128, 256, 3, 1, 1, activation='ReLU'),
        ]
        layers += [ResnetBlock(256, 256) for _ in range(n_res_blocks)]
        layers += [
            UpSampleBlock(256, 128, 3, 1, 1, activation='ReLU'),
            UpSampleBlock(128, 64, 3, 1, 1, activation='ReLU'),
            nn.ConvTranspose2d(64, 3, 7, 1, 3),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


In [14]:
# Smoke test: one training image through an untrained generator keeps its shape.
# unsqueeze(0) adds the batch dimension for any image_size (no hard-coded 128).
with torch.no_grad():
    res = Generator(1)(train_dataset_a[0].unsqueeze(0))
res.shape

torch.Size([1, 3, 128, 128])

In [15]:
class Discriminator(nn.Module):
    """Patch discriminator: five conv blocks, then a 1-channel conv of raw scores.

    Each output location scores a region of the input. No sigmoid is applied,
    which fits the MSE adversarial criterion used for training in this
    notebook. For a 3 x 128 x 128 input the output is 1 x 9 x 9.
    """

    def __init__(self):
        super().__init__()
        k_ch = 3  # number of color channels
        # (in_channels, out_channels, stride); all blocks use kernel 4, padding 2.
        block_specs = [
            (1 * k_ch, 64, 2),
            (64, 128, 2),
            (128, 256, 2),
            (256, 512, 2),
            (512, 512, 1),
        ]
        layers = [
            DownSampleBlock(c_in, c_out, 4, stride, 2, activation='LeakyReLU')
            for c_in, c_out, stride in block_specs
        ]
        layers.append(nn.Conv2d(512, 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [16]:
# Smoke test: the discriminator maps the generated image to a grid of patch scores.
patch_scores = Discriminator()(res)
patch_scores.shape

torch.Size([1, 1, 9, 9])

## Training

In [17]:
def log_losses(title, losses, step):
    """Log CycleGAN losses to W&B under ``'{title}_losses/<loss name>'``.

    Args:
        title: key prefix, e.g. ``'Train'``.
        losses: scalars in the order of ``loss_names`` below (generator,
            cycle and identity losses, then the two discriminator losses);
            values beyond the eight names are ignored.
        step: W&B step to log at.
    """
    prefix = f'{title}_losses'

    # Key names (including the "indentity" spelling) are kept unchanged so
    # existing W&B charts keep matching.
    loss_names = [
        'gen_ab_loss',
        'gen_ba_loss',
        'cycle_aba_loss',
        'cycle_bab_loss',
        'indentity_a_loss',
        'indentity_b_loss',
        'disc_a_loss',
        'disc_b_loss'
    ]

    metrics = {f'{prefix}/{name}': value for value, name in zip(losses, loss_names)}
    # One wandb.log call per step instead of one call per loss.
    if metrics:
        wandb.log(metrics, step=step)

In [18]:

def log_images(title, x, step, grid_size=(5, 5)):
    """Log the first images of batch ``x`` to W&B as a single grid image.

    Args:
        title: W&B key for the image.
        x: batch of image tensors, presumably (N, C, H, W) as ``make_grid``
            expects — TODO confirm.
        step: W&B step.
        grid_size: ``grid_size[0]`` images per grid row; at most
            ``grid_size[0] * grid_size[1]`` images are shown.
    """
    per_row = grid_size[0]
    max_images = grid_size[0] * grid_size[1]
    batch = x[:max_images].cpu()
    grid = make_grid(batch, nrow=per_row)
    wandb.log({f'{title}': [wandb.Image(grid)]}, step=step)


In [19]:
def calc_disc_loss(a, b, gen_ab, disc_b, criterion):
    """Discriminator loss for the target domain of ``gen_ab``.

    ``disc_b`` is trained to output 1 on real ``b`` and 0 on ``gen_ab(a)``.

    Args:
        a: batch from the source domain (generator input).
        b: batch of real target-domain images.
        gen_ab: generator source -> target.
        disc_b: discriminator for the target domain.
        criterion: loss comparing discriminator outputs with 0/1 targets.

    Returns:
        Scalar tensor ``criterion(D(b), 1) + criterion(D(G(a)), 0)``.
    """
    # The fakes are only discriminator inputs here: build them without autograd
    # so backward does not run through (and deposit gradients into) the generator.
    with torch.no_grad():
        b_fake = gen_ab(a)

    out_real = disc_b(b)
    out_fake = disc_b(b_fake)

    # ``*_like`` already allocates on the outputs' device and dtype.
    loss_real = criterion(out_real, torch.ones_like(out_real))
    loss_fake = criterion(out_fake, torch.zeros_like(out_fake))

    return loss_real + loss_fake

In [20]:
def calc_losses_discs(gen_ab, gen_ba, disc_a, disc_b, a, b, criterion):
    """Discriminator losses for both domains.

    Returns:
        ``(loss, losses)``: ``loss`` is the summed tensor to backpropagate;
        ``losses`` is ``np.array([disc_a_loss, disc_b_loss])`` of Python
        floats for logging.
    """
    disc_b_loss = calc_disc_loss(a, b, gen_ab, disc_b, criterion)  # real b vs gen_ab(a)
    disc_a_loss = calc_disc_loss(b, a, gen_ba, disc_a, criterion)  # real a vs gen_ba(b)

    total = disc_a_loss + disc_b_loss
    scalars = np.array([item.item() for item in (disc_a_loss, disc_b_loss)])
    return total, scalars

In [21]:
def calc_gen_loss(a, gen_ab, gen_ba, disc_b, criterion_bce, criterion_l2):
    """Adversarial and cycle-consistency losses for the a -> b -> a direction.

    Args:
        a: batch of real source-domain images.
        gen_ab, gen_ba: generators source -> target and target -> source.
        disc_b: discriminator for the target domain.
        criterion_bce: adversarial criterion (despite the name, this notebook
            passes ``nn.MSELoss``).
        criterion_l2: reconstruction criterion (despite the name, this
            notebook passes ``nn.L1Loss``).

    Returns:
        ``(gen_ab_loss, aba_loss)``: the adversarial loss pushing
        ``disc_b(gen_ab(a))`` towards 1, and the distance between
        ``gen_ba(gen_ab(a))`` and ``a``.
    """
    translated = gen_ab(a)
    scores = disc_b(translated)
    # The generator wants its fakes to be scored as real (target 1).
    adversarial = criterion_bce(scores, torch.ones_like(scores).to(device))

    reconstructed = gen_ba(translated)
    cycle = criterion_l2(reconstructed, a)

    return adversarial, cycle

In [22]:
def calc_losses_gens(gen_ab, gen_ba, disc_a, disc_b, a, b, criterion_bce, criterion_l2,
                     lambda_cycle=1.0, lambda_identity=1.0):
    """Total generator objective: adversarial + cycle + identity terms.

    Args:
        gen_ab, gen_ba: generators a -> b and b -> a.
        disc_a, disc_b: discriminators for domains a and b.
        a, b: batches of real images from each domain.
        criterion_bce: adversarial criterion (MSE in this notebook).
        criterion_l2: cycle/identity criterion (L1 in this notebook).
        lambda_cycle: weight of the two cycle-consistency losses
            (the CycleGAN paper uses 10; default 1 keeps this notebook's weighting).
        lambda_identity: weight of the two identity losses (default 1).

    Returns:
        ``(loss, losses)``: the weighted total tensor to backpropagate, and an
        ``np.array`` of the six UNweighted scalar losses for logging
        (gen_ab, gen_ba, cycle_aba, cycle_bab, identity_a, identity_b).
    """
    gen_ab_loss, aba_loss = calc_gen_loss(a, gen_ab, gen_ba, disc_b, criterion_bce, criterion_l2)
    gen_ba_loss, bab_loss = calc_gen_loss(b, gen_ba, gen_ab, disc_a, criterion_bce, criterion_l2)

    # Identity terms: a generator fed an image already in its output domain
    # should leave it unchanged.
    identity_a_loss = criterion_l2(gen_ba(a), a)
    identity_b_loss = criterion_l2(gen_ab(b), b)

    # Same left-to-right summation order as before, so the default weights of
    # 1.0 reproduce the unweighted total exactly.
    loss = gen_ab_loss + gen_ba_loss
    loss = loss + lambda_cycle * aba_loss + lambda_cycle * bab_loss
    loss = loss + lambda_identity * identity_a_loss + lambda_identity * identity_b_loss

    losses = np.array([
        gen_ab_loss.item(),
        gen_ba_loss.item(),
        aba_loss.item(),
        bab_loss.item(),
        identity_a_loss.item(),
        identity_b_loss.item()
    ])

    return loss, losses

In [23]:
def train_step(gen_ab, gen_ba, disc_a, disc_b, a, b, criterion_bce, criterion_l2,
               optim_gens, optim_discs, step):
    """One optimisation step: update both generators, then both discriminators.

    The discriminator losses are computed after the generator update, i.e.
    with the freshly updated generators. All eight scalar losses are then
    logged under the ``'Train'`` prefix at ``step``.
    """
    # --- generators ---
    gen_objective, gen_scalars = calc_losses_gens(
        gen_ab, gen_ba, disc_a, disc_b, a, b, criterion_bce, criterion_l2)
    optim_gens.zero_grad()
    gen_objective.backward()
    optim_gens.step()

    # --- discriminators ---
    # When optim_discs holds the discriminator parameters (as set up in this
    # notebook), this zero_grad() also discards the gradients the generator
    # backward pass left in them via the adversarial term.
    disc_objective, disc_scalars = calc_losses_discs(
        gen_ab, gen_ba, disc_a, disc_b, a, b, criterion_bce)
    optim_discs.zero_grad()
    disc_objective.backward()
    optim_discs.step()

    log_losses('Train', np.concatenate([gen_scalars, disc_scalars]), step)

   

In [24]:
def train_epoch(gen_ab, gen_ba, disc_a, disc_b, criterion_bce, criterion_l2,
                optim_gens, optim_discs, dataloader, step,
                vis_loader=None, log_every=100):
    """Train for one pass over ``dataloader`` and return the updated global step.

    Every ``log_every`` steps, the first batch of ``vis_loader`` is pushed
    through both translation cycles and logged as images.

    Args:
        gen_ab, gen_ba, disc_a, disc_b: the four networks (set to train mode).
        criterion_bce, criterion_l2, optim_gens, optim_discs: forwarded to
            ``train_step``.
        dataloader: iterable of ``(a, b)`` training batches.
        step: global step at the start of the epoch.
        vis_loader: iterable of ``(a, b)`` batches for the image logs;
            ``None`` falls back to the notebook-level ``valloader``.
        log_every: image-logging period in steps (default 100).

    Returns:
        ``step`` advanced by the number of batches processed.
    """
    for module in (gen_ab, gen_ba, disc_a, disc_b):
        module.train()

    for a, b in dataloader:
        a = a.to(device)
        b = b.to(device)

        train_step(gen_ab, gen_ba, disc_a, disc_b, a, b, criterion_bce, criterion_l2,
                   optim_gens, optim_discs, step)

        if step % log_every == 0:
            source = vis_loader if vis_loader is not None else valloader
            vis_a, vis_b = next(iter(source))
            # Tensor.to() is not in-place: the result must be assigned/used.
            # Visualisation only, so no autograd graph is needed.
            with torch.no_grad():
                log_generate_cycles(vis_a.to(device), vis_b.to(device), gen_ab, gen_ba, step)

        step += 1
    return step

In [25]:

def log_cycle(a, gen_ab, gen_ba, step, title='forward cycle'):
    """Log one translation cycle for the first image of batch ``a``.

    Logs a 3-image strip ``[a, gen_ab(a), gen_ba(gen_ab(a))]`` under ``title``.

    Args:
        a: batch of source-domain images; only ``a[:1]`` is used.
        gen_ab: generator source -> target.
        gen_ba: generator target -> source.
        step: W&B step.
        title: W&B key for the image strip.
    """
    # Visualisation only: skip building an autograd graph. This is also reached
    # from the training loop, where grad mode is otherwise enabled.
    with torch.no_grad():
        source = a[:1].to(device)
        translated = gen_ab(source)         # to the other domain
        reconstructed = gen_ba(translated)  # and back
        strip = torch.cat((source, translated, reconstructed), 0)

    log_images(title, strip, step, grid_size=(3, 1))


def log_generate_cycles(a, b, gen_ab, gen_ba, step):
    """Log both cycle visualisations: a -> b -> a ("forward") and b -> a -> b ("reverse")."""
    directions = (
        (a, gen_ab, gen_ba, 'Images/forward cycle'),
        (b, gen_ba, gen_ab, 'Images/reverse cycle'),
    )
    for batch, there, back, key in directions:
        log_cycle(batch, there, back, step, title=key)
    

In [26]:
@torch.no_grad()
def val_epoch(gen_ab,gen_ba,disc_a,disc_b, criterion_bce, valloader, step):
    gen_ab.eval()
    gen_ba.eval()
    disc_a.eval()
    disc_b.eval()
    
    (a,b),*_ = valloader
    # b = a
    a.to(device)
    b.to(device)
    
    log_generate_cycles(a,b,gen_ab,gen_ba,step)
    

In [27]:
def train(gen_ab, gen_ba, disc_a, disc_b, criterion_bce, criterion_l2, optim_gens, optim_discs,
          trainloader, valloader, epochs, step=0):
    """Run ``epochs`` rounds of training + validation logging.

    After each epoch the epoch index is logged to W&B at the current step.

    Args:
        step: global step to start (or resume) from.

    Returns:
        The global step after the last epoch; pass it back as ``step`` to resume.
    """
    networks = (gen_ab, gen_ba, disc_a, disc_b)
    for epoch in range(epochs):
        step = train_epoch(*networks, criterion_bce, criterion_l2,
                           optim_gens, optim_discs, trainloader, step)
        val_epoch(*networks, criterion_bce, valloader, step)
        wandb.log({'epoch': epoch}, step=step)
    return step

In [28]:
def concatenate(a, b):
    """Return a new list with the items of iterable ``a`` followed by those of ``b``."""
    return [*a, *b]

In [29]:
# Two generators (a -> b, b -> a) and one discriminator per domain.
# Generator's positional argument is unused. FIXME: drop it.
gen_ab = Generator(0).to(device)
gen_ba = Generator(0).to(device)
disc_a = Discriminator().to(device)
disc_b = Discriminator().to(device)

# Despite the names: criterion_bce is a least-squares (MSE) adversarial loss and
# criterion_l2 is an L1 loss used for the cycle and identity terms.
criterion_bce = nn.MSELoss()
criterion_l2 = nn.L1Loss()

# One Adam optimiser for both generators, one for both discriminators.
# NOTE(review): beta1 = 0.3 here; CycleGAN commonly uses 0.5 — confirm intended.
adam_kwargs = dict(lr=0.0002, betas=(0.3, 0.999))
params_gens = [*gen_ab.parameters(), *gen_ba.parameters()]
params_discs = [*disc_a.parameters(), *disc_b.parameters()]
optim_gens = torch.optim.Adam(params_gens, **adam_kwargs)
optim_discs = torch.optim.Adam(params_discs, **adam_kwargs)


In [30]:
batch_size = 1
loader_kwargs = dict(batch_size=batch_size, num_workers=8)

# Training: shuffled, and endless per-domain loaders (see JointDataLoader).
trainloader = JointDataLoader(
    dataset_a=train_dataset_a,
    dataset_b=train_dataset_b,
    shuffle=True,
    **loader_kwargs,
)
# Validation: fixed order, so the first batch (used for the image logs) is
# always the same pair of images.
valloader = JointDataLoader(
    dataset_a=val_dataset_a,
    dataset_b=val_dataset_b,
    shuffle=False,
    **loader_kwargs,
)

In [31]:
# Inspect GPU memory and utilisation before training.
# NOTE(review): `device` is set to CPU earlier in this notebook, so the usage
# shown here presumably belongs to other processes — confirm.
! nvidia-smi

Sun May  9 16:43:56 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  On   | 00000000:8B:00.0 Off |                    0 |
| N/A   53C    P0    44W / 250W |  23159MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:

# Start a W&B run and train from scratch. The returned `step` allows resuming
# training later by passing it back to `train`.
EPOCHS = 100
wandb.init(project='GAN')
step = train(gen_ab, gen_ba, disc_a, disc_b,
             criterion_bce, criterion_l2,
             optim_gens, optim_discs,
             trainloader, valloader,
             epochs=EPOCHS)

In [None]:
# To continue training from the last step (note: valloader is a required argument):
# step = train(gen_ab,gen_ba,disc_a,disc_b, criterion_bce,criterion_l2,optim_gens, optim_discs, trainloader,valloader, epochs=100,step=step)

In [1]:
# Re-check GPU usage. This cell ran in a fresh kernel (execution count 1),
# so it does not reflect this notebook's training run — presumably other
# processes; confirm.
!nvidia-smi

Sun May  9 17:01:26 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  On   | 00000000:8B:00.0 Off |                    0 |
| N/A   59C    P0   215W / 250W |  23159MiB / 32510MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces