In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms


In [2]:
# def load_data():
#     data_loaders = []
#     for img_size, batch_size in img_batch_size:
#         dataset = dset.ImageFolder(root=dataroot,
#                            transform=transforms.Compose([
#                                transforms.Resize(img_size),
#                                transforms.CenterCrop(img_size),
#                                transforms.ToTensor(),
#                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#                            ]))

#         dataload = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
#                                         shuffle=True, num_workers=workers, drop_last=True)
#         data_loaders.append(dataload)
#     # print("Data Loaded")
#     return data_loaders


In [3]:
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
lamb = 10
lr = .0002

In [4]:
# Many different views on which design is best, using original resnet design
class Residual2d(nn.Module):
    def __init__(self, dim):
        super(Residual2d, self).__init__()
        self.res = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, 1, 0),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, 1, 0),
            nn.InstanceNorm2d(dim))
        self.relu = nn.ReLU(True)
    
    def forward(self, x):
        output = x + self.res(x)
        output = self.relu(output)
        return output

        

In [5]:
# Generator Code
# Use mirrored padding always or just beginning and final layer
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
#         3x256x256 -> 64x256x256
        self.c7s1_64 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, 7, 1, 0),
            nn.InstanceNorm2d(64),
            nn.ReLU(True))
    
#         64x256x256 -> 128x128x128
        self.d128 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 128, 3, 2, 0),
            nn.InstanceNorm2d(128),
            nn.ReLU(True))
        
#         128x128x128 -> 256x64x64
        self.d256 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(128, 256, 3, 2, 0),
            nn.InstanceNorm2d(64),
            nn.ReLU(True))
        
#         256x64x64 -> 256x64x64
        res_blocks = [Residual2d(256) for x in range(9)]
        self.R9 = nn.Sequential(*res_blocks)
        
#         Try using upsample + conv
#         For some reason padding beforehand doesnt work with convtranpose, 
#         should work even with tranpose special padding rule
#         256x64x64 -> 128x128x128
        self.u128 = nn.Sequential(
#             nn.ReflectionPad2d(1),
            nn.ConvTranspose2d(256, 128, 3, 2, 1, output_padding=1),
#             nn.ReflectionPad2d(1)
        )
        
#         128x128x128 -> 64x256x256
        self.u256 = nn.Sequential(
#             nn.ReflectionPad2d(1),
            nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding=1),
#             nn.ReflectionPad2d(1)
        )
        
        self.c7s1_3 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7, 1, 0),
            nn.InstanceNorm2d(64),
            nn.ReLU(True))
        
    def forward(self, input):
        output = self.c7s1_64(input)
        output = self.d128(output)
        output = self.d256(output)
        output= self.R9(output)
        output= self.u128(output)
        output= self.u256(output)
        output= self.c7s1_3(output)
#         print(output.shape)
        return output


In [6]:
# https://arxiv.org/pdf/1611.07004.pdf
# Try different patch sizes
# Last two convolutions need to have a stride and padding of 1
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.C64 = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace= True))
    
        self.C128 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace= True))
        
        self.C256 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace= True))
    
        self.C512 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 1, 1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace= True))
        self.last = nn.Conv2d(512,1, 4, 1, 1)
        
    def forward(self, input):
        output = self.C64(input)
        output = self.C128(output)
        output = self.C256(output)
        output = self.C512(output)
        output = self.last(output)



#         print(output.shape)
        return output


In [7]:
# x: input
# y: target
# G: x -> y
# F: y -> x



In [8]:

# Gx = gen_G(x)
# Gy = gen_G(y)
# FGx = gen_F(Gx)

# Fy = gen_F(y)
# Fx = gen_F(x)
# GFy = gen_G(Fy)

# # res_x = dis_x_vs_Fy(x)
# res_Fy = dis_x_vs_Fy(Fy)
# # res_y = dis_y_vs_Gx(y)
# res_Gx = dis_y_vs_Gx(Gx)



In [9]:
# implement buffering images later
# class ImageBuffer
#     def __init__(self):
#         self.buffer = []
#         self.buffer_size = 0
#         self.last_image = -1
    
#     def add_image(image):
#         if self.buffer_size < 50:
#             self.buffer.append(image)
#             self.last_image = self.buffer_size
#             self.buffer_size += 1
#         else:
            

In [10]:
class CycleGan:
    def __init__(self):
        self.gen_G = Generator().to(device)
        self.gen_F = Generator().to(device)
        self.dis_x_vs_Fy = Discriminator().to(device)
        self.dis_y_vs_Gx = Discriminator().to(device)
        self.Gx = -1
        self.Fy = -1
        self.lossGx = -1
        self.lossFy= -1
        self.lossY = -1
        self.lossX = -1
        self.crit = nn.MSELoss()
        self.optimizer_gen_G = torch.optim.Adam(self.gen_G.parameters(), lr = lr)
        self.optimizer_gen_F = torch.optim.Adam(self.gen_F.parameters(), lr = lr)
        self.optimizer_dis_x = torch.optim.Adam(self.dis_x_vs_Fy.parameters(), lr = lr)
        self.optimizer_dis_y = torch.optim.Adam(self.dis_y_vs_Gx.parameters(), lr = lr)
    
    def grad_toggle(self, grad):
        for param in self.dis_x_vs_Fy.parameters():
            param.requires_grad = grad
        for param in self.dis_y_vs_Gx.parameters():
            param.requires_grad = grad
            
    def loss_gen(self, result):
        return self.crit(result, torch.zeros_like(result))
    
#     divide by 2 to slow down
    def loss_dis(self, real, fake):
        loss1 = self.crit(real, torch.ones_like(real))
        loss2 = self.crit(fake, torch.zeros_like(fake))
        return torch.mean((loss1 + loss2) * .5)

    def loss_cyclic(self, real, cycled):
        loss = torch.abs(cycled - real)
        return torch.mean(loss) * lamb

    # Found in 5.2 of paper, .5 found in git
    def loss_identity(self, real, ident):
        loss = torch.abs(ident - real)
        return torch.mean(loss) * lamb * .5
    
    def generator_loss(self, x, y):
        self.Gx = self.gen_G(x)
        Gy = self.gen_G(y)
        FGx = self.gen_F(self.Gx)

        self.Fy = self.gen_F(y)
        Fx = self.gen_F(x)
        GFy = self.gen_G(self.Fy)
        
        self.grad_toggle(False)
        
        lossGx = self.loss_gen(self.dis_y_vs_Gx(self.Gx))
        self.lossGx = lossGx.item()
        
        lossFy = self.loss_gen(self.dis_x_vs_Fy(self.Fy))
        self.lossFy = lossFy.item()
        
        lossFGx = self.loss_cyclic(x, FGx)
        lossGFy = self.loss_cyclic(y, GFy)
        
        identGy = self.loss_identity(y, Gy)
        identFx = self.loss_identity(x, Fx)
        
        loss = lossGx + lossFy + lossFGx + lossGFy + identGy + identFx
        
        self.grad_toggle(True)

        return loss
    
    def discriminator_loss(self, x, y):
        dis_Gx = self.dis_y_vs_Gx(self.Gx.detach())
        dis_y = self.dis_y_vs_Gx(y)
        
        dis_Fy = self.dis_x_vs_Fy(self.Fy.detach())
        dis_x = self.dis_x_vs_Fy(x)
        
        loss_y = self.loss_dis(dis_y, dis_Gx)
        self.lossY = loss_y.item()
        
        loss_x = self.loss_dis(dis_x, dis_Fy)
        self.lossX = loss_x.item()
        
        loss = loss_x + loss_y
        return loss
        
        
    def step(self, x, y):
        self.optimizer_gen_G.zero_grad()
        self.optimizer_gen_F.zero_grad()
        loss_g = self.generator_loss(x,y)
        loss_g.backward()
        self.optimizer_gen_G.step()
        self.optimizer_gen_F.step()
       
        self.optimizer_dis_x.zero_grad()
        self.optimizer_dis_y.zero_grad()
        loss_d = self.discriminator_loss(x, y)
        loss_d.backward()
        self.optimizer_dis_x.step()
        self.optimizer_dis_y.step()
        print("loss Gx: ", self.lossGx, "loss Fy: ", self.lossFy, "loss disY: ", self.lossY, "loss disX: ", self.lossX)


In [11]:
x = torch.rand(1,3,400,400).to(device)
y = torch.rand(1,3,400,400).to(device)
gan = CycleGan()

In [12]:
gan.step(x, y)

loss Gx:  0.19479721784591675 loss Fy:  0.1534479707479477 loss disY:  0.9343559741973877 loss disX:  0.5426048040390015


In [13]:
# Notes to self
# Load data as cropped not scaled images, then generate on full image, 
# Do data loading
# Do checkpointing