In [1]:
import os
import glob

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
%matplotlib inline

In [3]:
import torch
import torchvision as tv
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision.transforms.functional as TF
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader

In [4]:
# Device selection: prefer the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")


In [5]:
# Display the CUDA availability flag computed by torch.cuda.is_available().
use_cuda

False

# use cycleGAN

## what to use?

For the generator, use an encoder / transformer (residual blocks) / decoder combination.

For the discriminator, use a PatchGAN.

# making generator

First, we need to build the residual blocks.

In [6]:
# def activation_func(activation):
#     return  nn.ModuleDict([
#         ['relu', nn.ReLU(inplace=True)],
#         ['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
#         ['selu', nn.SELU(inplace=True)],
#         ['none', nn.Identity()]
#     ])[activation]

In [7]:
class ResidualBlock(nn.Module):
    """Generic residual block: ``activation_fn(blocks(x) + shortcut(x))``.

    ``blocks`` and ``shortcut`` both default to ``nn.Identity`` and are meant
    to be replaced by subclasses.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        activation_fn: callable (e.g. an ``nn.Module`` or a function from
            ``torch.nn.functional``) applied after the residual addition.
    """

    def __init__(self, in_channels, out_channels, activation_fn):
        super(ResidualBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.blocks = nn.Identity()
        self.activation_fn = activation_fn
        self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``activation_fn(blocks(x) + residual)`` without mutating ``x``."""
        residual = x
        if self.apply_shortcut:
            residual = self.shortcut(x)
        # Out-of-place addition: with identity ``blocks`` an in-place ``+=``
        # would overwrite the caller's input tensor.
        x = self.blocks(x) + residual
        x = self.activation_fn(x)
        return x

    @property
    def apply_shortcut(self):
        """True when input and output channel counts differ."""
        return self.in_channels != self.out_channels
        
        

In [8]:
class ResNetResidualBlock(ResidualBlock):
    """Residual block specialisation; currently only forwards its arguments.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        activation_fn: callable applied after the residual addition.
    """

    def __init__(self, in_channels, out_channels, activation_fn):
        # Pass the activation through explicitly; the parent requires it and
        # this signature has no *args/**kwargs to forward.
        super(ResNetResidualBlock, self).__init__(in_channels,
                                                  out_channels,
                                                  activation_fn)
        
        
        
        

In [9]:
class ResNetBlock(nn.Module):
    """Channel-preserving residual block.

    Computes ``x + norm2(conv2(relu1(norm1(conv1(x)))))`` where both
    convolutions are 3x3 with padding 1, so the spatial size and the channel
    count of ``x`` are unchanged. No activation is applied after the skip
    addition.

    Args:
        input_dim: number of input (and output) channels.
    """

    def __init__(self, input_dim):
        super().__init__()
        channels = input_dim

        # First conv -> instance norm -> leaky ReLU stage.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=True)
        self.norm1 = nn.InstanceNorm2d(channels)
        self.relu1 = nn.LeakyReLU(negative_slope=0.01, inplace=True)

        # Second conv -> instance norm stage (no activation).
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=True)
        self.norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the output of the two-stage convolutional branch."""
        branch = self.relu1(self.norm1(self.conv1(x)))
        branch = self.norm2(self.conv2(branch))
        return x + branch

In [10]:
class CycleGenerator(nn.Module):
    """ResNet-style CycleGAN generator for single-channel images.

    Layout (adapted from
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py):

    1. stem: reflection pad + 7x7 conv (1 -> 64 channels),
    2. two stride-2 downsampling convs (64 -> 128 -> 256),
    3. six ``ResNetBlock``s at 256 channels,
    4. two stride-2 transposed convs (256 -> 128 -> 64),
    5. output head: reflection pad + 7x7 conv (64 -> 1) + tanh.

    The output has the same shape as the input when the input height and
    width are divisible by 4.
    """

    def __init__(self):
        super(CycleGenerator, self).__init__()

        # Open question kept from the original notes: do we need this many
        # filter channels for a 1-channel image rather than a 3-channel one?
        base_filters = 64
        n_downsampling = 2
        num_resnet_blocks = 6

        # stem
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(in_channels=1, out_channels=base_filters,
                           kernel_size=7, padding=0,
                           bias=True),
                 nn.InstanceNorm2d(base_filters),
                 nn.LeakyReLU(negative_slope=0.01, inplace=True)]

        # downsampling layers: each halves H/W and doubles the channels
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(in_channels=base_filters * mult,
                                out_channels=base_filters * mult * 2,
                                kernel_size=3, stride=2, padding=1,
                                bias=True),
                      nn.InstanceNorm2d(base_filters * mult * 2),
                      nn.LeakyReLU(negative_slope=0.01, inplace=True)]

        # resnet blocks at the bottleneck width (computed explicitly rather
        # than from the leaked loop variable of the downsampling loop)
        bottleneck_channels = base_filters * 2 ** n_downsampling
        for _ in range(num_resnet_blocks):
            model += [ResNetBlock(bottleneck_channels)]

        # upsampling layers: each doubles H/W and halves the channels
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(in_channels=base_filters * mult,
                                         out_channels=base_filters * mult // 2,
                                         kernel_size=3, stride=2, padding=1,
                                         # without output_padding the result is
                                         # 2*H - 1, which never matches the input
                                         output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(base_filters * mult // 2),
                      nn.LeakyReLU(negative_slope=0.01, inplace=True)]

        # output head: map the 64 feature channels back to a 1-channel image.
        # NOTE(review): tanh assumes images are normalised to [-1, 1] - confirm
        # against the data pipeline.
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(in_channels=base_filters, out_channels=1,
                            kernel_size=7, padding=0,
                            bias=True),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Translate a batch of 1-channel images; returns a tensor shaped like ``x``."""
        return self.model(x)
        

In [11]:
class CycleDiscriminator(nn.Module):
    """PatchGAN discriminator for single-channel images.

    Three stride-2 4x4 convs (1 -> 64 -> 128 -> 256), one stride-1 4x4 conv
    (256 -> 512) and a final 4x4 conv that yields a 1-channel map of
    per-patch real/fake scores (raw logits; no sigmoid is applied).

    Reference: https://github.com/aitorzip/PyTorch-CycleGAN/blob/master/models.py
    """

    def __init__(self):
        super(CycleDiscriminator, self).__init__()

        base_filters = 64
        n_layers = 3

        model = [nn.Conv2d(1, base_filters, kernel_size=4, stride=2, padding=1),
                 nn.LeakyReLU(0.2, True)]

        # stride-2 stages: channel multiplier doubles each time, capped at 8
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            model += [
                nn.Conv2d(base_filters * nf_mult_prev, base_filters * nf_mult,
                          kernel_size=4, stride=2, padding=1, bias=True),
                nn.InstanceNorm2d(base_filters * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # stride-1 stage: widen once more based on n_layers (not on the loop
        # variable left over from the loop above)
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        model += [
            nn.Conv2d(base_filters * nf_mult_prev, base_filters * nf_mult,
                      kernel_size=4, stride=1, padding=1, bias=True),
            nn.InstanceNorm2d(base_filters * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # 1 channel prediction map
        model += [nn.Conv2d(base_filters * nf_mult, 1, kernel_size=4, stride=1, padding=1)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the per-patch prediction map for a batch of 1-channel images."""
        out = self.model(x)
        return out

In [None]:
class ImageBuffer():
    """Fixed-size history of previously seen images.

    While the buffer is not full, pushed images are stored and returned
    unchanged. Once it is full, each pushed image is, with probability 0.5,
    swapped with a randomly chosen stored image (the stored one is returned);
    otherwise the pushed image is returned and the buffer is left untouched.

    Args:
        max_size: maximum number of images kept in the buffer (must be > 0).
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "max_size must be positive"
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch of images and return a batch of the same size.

        Args:
            data: tensor whose first dimension indexes the images of the batch.

        Returns:
            A tensor, detached from the autograd graph, with one entry per
            input image: either the image itself or an older buffered image.
        """
        batch = data.detach()
        to_return = []
        for element in batch:
            element = element.unsqueeze(0)  # restore the batch dimension
            if len(self.data) < self.max_size:
                # still filling up: remember the image and hand it back
                self.data.append(element)
                to_return.append(element)
            elif torch.rand(1).item() > 0.5:
                # swap: return an old image and store the new one in its slot
                idx = int(torch.randint(0, self.max_size, (1,)).item())
                to_return.append(self.data[idx].clone())
                self.data[idx] = element
            else:
                to_return.append(element)
        if not to_return:
            # empty batch: torch.cat cannot concatenate an empty list
            return batch
        return torch.cat(to_return)

In [14]:
# Scratch check: range(1, 4) yields 1, 2, 3 (the upper bound is exclusive).
for i in (1, 2, 3):
    print(i)

1
2
3


In [12]:
class CycleGAN(nn.Module):
    def __init__(self):
        '''
        Initialise the CycleGAN: build the two generators and the two
        discriminators, move them to the GPU when ``use_cuda`` is set, and
        register them as submodules.

        TODO: add the loss functions and the image buffer.
        '''
        # nn.Module must be initialised before submodules can be assigned
        super(CycleGAN, self).__init__()

        #start with models
        generator_clean = CycleGenerator() # dirty to clean
        discriminator_clean = CycleDiscriminator() # clean is fake/real

        generator_dirty = CycleGenerator() # clean to dirty
        discriminator_dirty = CycleDiscriminator() # dirty is fake/real

        # turn on cuda
        if use_cuda:
            generator_clean.cuda()
            discriminator_clean.cuda()

            generator_dirty.cuda()
            discriminator_dirty.cuda()

        # keep the networks on the instance so their parameters are tracked
        self.generator_clean = generator_clean
        self.discriminator_clean = discriminator_clean
        self.generator_dirty = generator_dirty
        self.discriminator_dirty = discriminator_dirty
        
        
        