Peter Banyas | Nov 13

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader

import torchvision.transforms as transforms

from datasets import ImageDataset
import itertools

import random
import time
import datetime
import sys

from visdom import Visdom
import numpy as np

# Create the Translators

## **Open question: why do the reference implementations omit the ReLU at the output?**

In [11]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolution stages wrapped by an identity skip connection.

    Every stage reflection-pads by one pixel and then applies an unpadded 3x3
    convolution, so height, width and channel count are all preserved.

    Args:
        in_features: number of channels entering (and leaving) the block.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_stage():
            # 1-pixel reflected border + 3x3 conv keeps H x W unchanged;
            # InstanceNorm2d then normalises each channel.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        # "inplace" ReLU overwrites its input instead of allocating a new tensor.
        stages = conv_stage() + [nn.ReLU(inplace=True)] + conv_stage()

        # There is no activation after the second stage, and forward() does not
        # apply one after the sum either.
        self.local_route = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the residual produced by ``local_route``."""
        residual = self.local_route(x)
        return x + residual  # x bypasses local_route (skip connection)
    

In [12]:
class Generator(nn.Module):
    """Convolutional image translator: encode, transform with residual blocks, decode.

    Layout (all stored in ``self.model`` as one ``nn.Sequential``):
      1. 7x7 convolution to 64 feature maps,
      2. two stride-2 convolutions (64 -> 128 -> 256), each halving H and W,
      3. ``n_residual_blocks`` residual blocks at the deepest width (256),
      4. two stride-2 transposed convolutions (256 -> 128 -> 64), each doubling H and W,
      5. 7x7 convolution to ``output_num_channels`` followed by ``tanh``.

    Args:
        input_num_channels: channels of the input image.
        output_num_channels: channels of the generated image.
        n_residual_blocks: number of residual blocks in the middle section.
    """

    def __init__(self, input_num_channels, output_num_channels, n_residual_blocks=9):
        super(Generator, self).__init__()

        ## Initial Convolution Block
        in_features = 64
        model = [
            nn.ReflectionPad2d(3),  # 3-pixel reflected border so the 7x7 conv keeps H x W
            nn.Conv2d(input_num_channels, in_features, 7),
            nn.InstanceNorm2d(in_features),  # normalise each channel
            nn.ReLU(inplace=True)
        ]

        ## Downsample
        for _ in range(2):              # 2 downsampling layers
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),  # doubles features; halves image size
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        ## Residual Blocks
        # in_features is the width of the tensor at this point (256). The
        # blocks were previously built with the *next* width (512), which
        # does not match the incoming channel count.
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]  # processing in deep feature space

        ## Upsample
        for _ in range(2):              # 2 upsampling layers
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),  # halves features; doubles image size
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        ## Output Layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, output_num_channels, 7),  # in_features is back to 64 here
            nn.Tanh()  # squashes output to [-1, 1]
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the full encoder / residual / decoder stack."""
        return self.model(x)

In [13]:
class Discriminator(nn.Module):
    """Fully convolutional critic that returns one score per input image.

    Four stride-2 4x4 convolutions widen the features 64 -> 128 -> 256 -> 512
    while roughly halving the spatial size each time. A final 4x4 convolution
    reduces to a single-channel score map, which ``forward`` averages over all
    spatial positions.

    Args:
        input_num_channels: channels of the images being judged.
    """

    def __init__(self, input_num_channels):
        super().__init__()

        # The first stage has no normalisation layer.
        layers = [
            nn.Conv2d(input_num_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(.2, inplace=True),  # negative inputs are scaled by 0.2
        ]

        # The remaining stages share one pattern: conv -> instance norm -> leaky ReLU.
        for narrow, wide in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(narrow, wide, 4, stride=2, padding=1),
                nn.InstanceNorm2d(wide),
                nn.LeakyReLU(.2, inplace=True),
            ])

        # Classification head: collapse the 512 features into one score channel.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor holding the mean of each image's score map."""
        score_map = self.model(x)
        batch_size = score_map.size()[0]
        spatial_extent = score_map.size()[2:]
        pooled = F.avg_pool2d(score_map, spatial_extent)  # average over H x W
        return pooled.view(batch_size, -1)

# Create instances of the generators and discriminators

In [None]:
# One translator per direction and one critic per image domain (all RGB, 3 channels).
gen_X_to_Y, gen_Y_to_X = Generator(3, 3), Generator(3, 3)  # "G", "F"
discr_Y, discr_X = Discriminator(3), Discriminator(3)  # "D_y", "D_x"

# NOTE(review): `opt` is assigned in the HYPERPARAMETERS cell further down the
# file, so that cell has to be executed before this one.
if opt.cuda:
    gen_X_to_Y, gen_Y_to_X, discr_X, discr_Y = (
        network.cuda() for network in (gen_X_to_Y, gen_Y_to_X, discr_X, discr_Y)
    )
    
# Initialize weights
def weights_init_normal(m):
    """Re-initialise a single module in place, selected by its class name.

    Intended to be passed to ``nn.Module.apply``.

    * class name contains ``Conv``: weights drawn from N(0, 0.02)
    * class name contains ``BatchNorm2d``: weights drawn from N(1, 0.02), bias set to 0
    * anything else is left untouched

    Args:
        m: the module to initialise.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm2d' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

# Run weights_init_normal over each network; nn.Module.apply calls the
# function on every submodule and then on the network itself.
gen_X_to_Y.apply(weights_init_normal)
gen_Y_to_X.apply(weights_init_normal)

discr_X.apply(weights_init_normal)
# Last expression of the cell: the Discriminator repr shown below appears to
# be this call's echoed return value.
discr_Y.apply(weights_init_normal)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [None]:
# def loss_adversarial(X, Y, fake_X, fake_Y)
#     catch_fake_Y = ( discr_Y(gen_X_to_Y(X)) - 1 )**2
#     catch_fake_X = ( discr_X(gen_Y_to_X(Y)) - 1 )**2
#     appreciate_true_Y = ( discr_Y(Y) - 1 )**2
#     appreciate_true_X = ( discr_X(X) - 1 )**2
    



In [None]:
## HYPERPARAMETERS
from types import SimpleNamespace

# Plain attribute container for the run configuration. This replaces the
# previous `opt = lambda: None` trick; attribute access (`opt.lr`, ...) is
# unchanged for every consumer.
opt = SimpleNamespace(
    epoch=0,                  # epoch to start from
    num_epochs=200,           # total number of training epochs
    batchSize=1,
    dataroot='datasets/horse2zebra',
    lr=.0002,                 # Adam learning rate
    b1=.5,                    # Adam beta1
    b2=.999,                  # Adam beta2
    decay_epoch=100,          # epoch at which the linear LR decay begins
    size=256,                 # side length of the square training crops
    input_num_channels=3,
    output_num_channels=3,
    # Use the GPU only when one is actually present, so the notebook also
    # runs on CPU-only machines instead of failing on the first .cuda() call.
    cuda=torch.cuda.is_available(),
    num_cpu=8,                # DataLoader worker processes
)

# Relative weights of the loss terms.
IDENTITY_WEIGHT = 5.0
CYCLE_WEIGHT = 10.0
ADVERSARIAL_WEIGHT = 1.0


In [None]:
## LOSSES
# L1 distance for the identity and cycle terms, squared error for the adversarial term.
criterion_identity, criterion_cycle = nn.L1Loss(), nn.L1Loss()
criterion_GAN = nn.MSELoss()

##############################################################################################
## Optimizers, Learning Rate Schedulers
##############################################################################################
# A single Adam instance updates both generators together.
optimizer_gen = optim.Adam(
    itertools.chain(gen_X_to_Y.parameters(), gen_Y_to_X.parameters()),
    lr=opt.lr,
    betas=(opt.b1, opt.b2),
)
# Each discriminator gets its own Adam instance with the same settings.
optimizer_discr_X, optimizer_discr_Y = (
    optim.Adam(critic.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    for critic in (discr_X, discr_Y)
)

class LambdaLR():
    """Linear learning-rate decay factor, for use as an ``lr_lambda`` callback.

    The factor stays at 1.0 until ``decay_start_epoch`` and then falls
    linearly, reaching 0.0 at ``num_epochs``.

    Args:
        num_epochs: total number of epochs in the training session.
        offset: value added to the epoch passed to ``step``.
        decay_start_epoch: epoch at which the decay begins; must be smaller
            than ``num_epochs``.
    """

    def __init__(self, num_epochs, offset, decay_start_epoch):
        assert ((num_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.num_epochs, self.offset, self.decay_start_epoch = num_epochs, offset, decay_start_epoch

    def step(self, epoch):
        """Return the multiplicative learning-rate factor for ``epoch``."""
        epochs_into_decay = epoch + self.offset - self.decay_start_epoch
        decay_span = self.num_epochs - self.decay_start_epoch
        # Before the decay starts the difference is negative and is clamped to 0.
        decayed_fraction = max(0, epochs_into_decay) / decay_span
        return 1.0 - decayed_fraction

# One linear-decay schedule per optimizer; each scheduler gets its own
# LambdaLR helper instance built from the same hyperparameters.
lr_scheduler_gen, lr_scheduler_discr_X, lr_scheduler_discr_Y = (
    optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=LambdaLR(opt.num_epochs, opt.epoch, opt.decay_epoch).step,
    )
    for optimizer in (optimizer_gen, optimizer_discr_X, optimizer_discr_Y)
)

##############################################################################################
## Inputs, memory allocation
##############################################################################################
# Tensor constructor matching the configured device.
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
# Pre-allocated (uninitialised) image buffers; the training loop copy_()s each batch into them.
input_X = Tensor(opt.batchSize, opt.input_num_channels, opt.size, opt.size)
input_Y = Tensor(opt.batchSize, opt.input_num_channels, opt.size, opt.size)
# Targets are shaped (batchSize, 1) to match the (batch, 1) tensor returned by
# Discriminator.forward. A flat (batchSize,) target would make MSELoss broadcast
# to (batchSize, batchSize) for batchSize > 1 and emit a size-mismatch warning
# for batchSize == 1.
target_real = Variable(Tensor(opt.batchSize, 1).fill_(1.0), requires_grad=False) # all 1s: "real" label for the discriminator outputs
target_fake = Variable(Tensor(opt.batchSize, 1).fill_(0.0), requires_grad=False) # all 0s: "fake" label

class ReplayBuffer():
    """Bounded pool of previously generated images.

    Args:
        max_size: maximum number of images kept in the pool; must be positive.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), "Buffer size must be greater than 0"
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same length.

        For each image in ``data``:
          * while the pool has room, the image is stored and returned as is;
          * once the pool is full, with probability 0.5 a random stored image
            is returned and the new image takes its slot; otherwise the new
            image is returned and the pool is left unchanged.

        Args:
            data: batch tensor; iterated along its first dimension via ``data.data``.

        Returns:
            The selected images concatenated along dimension 0.
        """
        batch_out = []
        for sample in data.data:
            sample = torch.unsqueeze(sample, 0)  # restore the leading batch dimension

            if len(self.data) < self.max_size:
                # Still filling the pool: keep the image and pass it straight through.
                self.data.append(sample)
                batch_out.append(sample)
                continue

            swap_with_stored = torch.rand(1).item() > .5  # 50% chance
            if not swap_with_stored:
                batch_out.append(sample)
                continue

            slot = torch.randint(0, self.max_size, (1,)).item()  # random slot to replace
            batch_out.append(self.data[slot].clone())  # hand back the stored image
            self.data[slot] = sample  # the new image takes its place
        return Variable(torch.cat(batch_out))
    
# Separate pools of generated images for each domain.
fake_X_buffer = ReplayBuffer()
fake_Y_buffer = ReplayBuffer()

##############################################################################################
## Dataset loader
##############################################################################################
dataloader = DataLoader(
    ImageDataset(
        opt.dataroot,
        # NOTE(review): the keyword was previously spelled `tranfroms_`;
        # `transforms_` is assumed to be ImageDataset's parameter name.
        # Confirm against datasets.py.
        transforms_=[
            # Resize slightly larger than the target, then random-crop back
            # to opt.size. The interpolation mode comes from torchvision
            # itself, because `Image` (PIL) is never imported in this file.
            transforms.Resize(int(opt.size * 1.12), transforms.InterpolationMode.BICUBIC),
            transforms.RandomCrop(opt.size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 per channel: maps [0, 1] to [-1, 1].
            transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
        ],
        unaligned=True,
    ),
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=opt.num_cpu,
)

##############################################################################################
## Loss plot
##############################################################################################
def tensor2image(tensor):
    """Convert the first image of a batch tensor to a ``uint8`` NumPy array.

    Values are mapped with ``127.5 * (v + 1)``, so -1 becomes 0 and +1 becomes
    255. A single-channel image is repeated three times along the channel axis.

    Args:
        tensor: batch tensor; only ``tensor[0]`` is used.

    Returns:
        ``uint8`` array with the channel axis first.
    """
    first_image = tensor[0].cpu().float().numpy()
    pixels = (first_image + 1.0) * 127.5
    is_single_channel = pixels.shape[0] == 1
    if is_single_channel:
        pixels = np.tile(pixels, (3, 1, 1))  # grey -> three identical channels
    return pixels.astype(np.uint8)

class Logger():
    """Progress reporter: one console status line per batch plus Visdom plots.

    ``log`` is meant to be called once per batch. It accumulates each named
    loss over the current epoch, prints the running mean and an ETA, pushes
    the supplied images to Visdom, and at the end of every epoch plots the
    epoch-mean of each loss and resets the accumulators.

    Args:
        n_epochs: total number of epochs (used for the display and the ETA).
        batches_epoch: number of batches per epoch.
    """

    def __init__(self, n_epochs, batches_epoch):
        self.viz = Visdom()
        self.n_epochs = n_epochs
        self.batches_epoch = batches_epoch
        self.epoch = 1            # 1-based epoch counter
        self.batch = 1            # 1-based batch counter within the epoch
        self.prev_time = time.time()
        self.mean_period = 0      # total seconds spent between log() calls so far
        self.losses = {}          # loss name -> sum over the current epoch
        self.loss_windows = {}    # loss name -> Visdom window handle
        self.image_windows = {}   # image name -> Visdom window handle

    @staticmethod
    def _loss_value(loss):
        """Return a Python float from a one-element tensor or a plain number."""
        # `.item()` works for 0-dim tensors, where the old `.data[0]` indexing raises.
        return loss.item() if hasattr(loss, 'item') else float(loss)

    def log(self, losses=None, images=None):
        """Record one batch.

        Args:
            losses: mapping of loss name to a one-element tensor (or number);
                ``None`` is treated as "no losses".
            images: mapping of window title to an image batch tensor;
                ``None`` is treated as "no images".
        """
        # The signature advertises None defaults, so make them usable.
        losses = {} if losses is None else losses
        images = {} if images is None else images

        now = time.time()
        self.mean_period += (now - self.prev_time)
        self.prev_time = now

        sys.stdout.write('\rEpoch %03d/%03d [%04d/%04d] -- ' % (self.epoch, self.n_epochs, self.batch, self.batches_epoch))

        n_losses = len(losses)
        for position, (loss_name, loss) in enumerate(losses.items(), start=1):
            self.losses[loss_name] = self.losses.get(loss_name, 0.0) + self._loss_value(loss)

            # The last entry is terminated with ' -- ', the others are separated by ' | '.
            if position == n_losses:
                sys.stdout.write('%s: %.4f -- ' % (loss_name, self.losses[loss_name]/self.batch))
            else:
                sys.stdout.write('%s: %.4f | ' % (loss_name, self.losses[loss_name]/self.batch))

        batches_done = self.batches_epoch*(self.epoch - 1) + self.batch
        batches_left = self.batches_epoch*(self.n_epochs - self.epoch) + self.batches_epoch - self.batch
        sys.stdout.write('ETA: %s' % (datetime.timedelta(seconds=batches_left*self.mean_period/batches_done)))

        # Draw images: create a window on first sight, update it afterwards.
        for image_name, tensor in images.items():
            if image_name not in self.image_windows:
                self.image_windows[image_name] = self.viz.image(tensor2image(tensor.data), opts={'title':image_name})
            else:
                self.viz.image(tensor2image(tensor.data), win=self.image_windows[image_name], opts={'title':image_name})

        # End of epoch
        if (self.batch % self.batches_epoch) == 0:
            # Plot the epoch-mean of every loss
            for loss_name, loss in self.losses.items():
                if loss_name not in self.loss_windows:
                    self.loss_windows[loss_name] = self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]),
                                                                    opts={'xlabel': 'epochs', 'ylabel': loss_name, 'title': loss_name})
                else:
                    self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]), win=self.loss_windows[loss_name], update='append')
                # Reset losses for next epoch
                self.losses[loss_name] = 0.0

            self.epoch += 1
            self.batch = 1
            sys.stdout.write('\n')
        else:
            self.batch += 1

# Reporter sized to the full run: len(dataloader) is the number of batches per epoch.
logger = Logger(n_epochs=opt.num_epochs, batches_epoch=len(dataloader))


# Training

In [None]:
# Generator half of the training step. For every batch, build the combined
# generator loss (identity + adversarial + cycle terms) and update both
# generators with a single optimizer step.
# NOTE(review): no discriminator update, scheduler step or logger call happens
# inside this loop yet, and ADVERSARIAL_WEIGHT is not applied to the GAN terms.
for epoch in range(opt.epoch, opt.num_epochs):
    for i, batch in enumerate(dataloader):

        # Copy the real images into the pre-allocated input buffers.
        # NOTE(review): assumes ImageDataset yields dicts keyed 'A' / 'B' for
        # domains X / Y; confirm against datasets.py.
        real_X = Variable(input_X.copy_(batch['A']))
        real_Y = Variable(input_Y.copy_(batch['B']))

        # Activate the generators
        optimizer_gen.zero_grad() # start with zero gradients

        #### Identity Loss
        allegedly_same_Y = gen_X_to_Y(real_Y) # G(Y) should remain Y
        loss_identity_Y = criterion_identity(allegedly_same_Y, real_Y) * IDENTITY_WEIGHT

        allegedly_same_X = gen_Y_to_X(real_X) # F(X) should remain X
        loss_identity_X = criterion_identity(allegedly_same_X, real_X) * IDENTITY_WEIGHT

        #### GAN Loss
        fake_Y = gen_X_to_Y(real_X)
        gullibility_to_fake_Y = discr_Y(fake_Y) # D_y's score for the generated Y; the generator wants it close to target_real (1s)
        loss_GAN_Y = criterion_GAN(gullibility_to_fake_Y, target_real) # how far away from 1 are we

        fake_X = gen_Y_to_X(real_Y)
        gullibility_to_fake_X = discr_X(fake_X) # D_x's score for the generated X; the generator wants it close to target_real (1s)
        loss_GAN_X = criterion_GAN(gullibility_to_fake_X, target_real)

        #### Cycle Consistency Loss
        allegedly_reconstructed_X = gen_Y_to_X(fake_Y) # complete the cycle: X -> Y -> X
        loss_cycle_X = criterion_cycle(allegedly_reconstructed_X, real_X) * CYCLE_WEIGHT

        allegedly_reconstructed_Y = gen_X_to_Y(fake_X) # complete the cycle: Y -> X -> Y
        loss_cycle_Y = criterion_cycle(allegedly_reconstructed_Y, real_Y) * CYCLE_WEIGHT

        #### TOTAL Loss
        loss_gen = loss_identity_X + loss_identity_Y + loss_GAN_X + loss_GAN_Y + loss_cycle_X + loss_cycle_Y
        loss_gen.backward()

        # optimizer_gen holds the parameters of both generators, so this updates G and F together.
        optimizer_gen.step()

## Discriminators

In [None]:
# NOTE(review): work-in-progress fragment. These statements sit at top level,
# not inside the training loop, so as written they run once after that loop
# has finished and use the `real_X` left over from its final batch.
optimizer_discr_X.zero_grad()

#### Real Loss
visibility_of_real_X = discr_X(real_X) # D_x's score for a real X image
loss_real_X = criterion_GAN(visibility_of_real_X, target_real) # distance of that score from the "real" target (1s)

#### Fake Loss
# TODO: not implemented yet.



# **NEXT TO DO: TRAINING FROM /Users/peterbanyas/Desktop/ECE 661/Project 661/PyTorch-CycleGAN/train**

old

In [None]:


# Loss criteria re-declared with the same constructors as in the LOSSES cell.
criterion_identity, criterion_cycle, criterion_GAN = nn.L1Loss(), nn.L1Loss(), nn.MSELoss()

# goals: uninitialised buffer shaped like one batch of input images
target_X = Tensor(opt.batchSize, opt.input_num_channels, opt.size, opt.size)

# When input is already desired output, it should stay the same.
def identity_loss(X, Y):
    """Return the weighted identity penalty for both generators.

    Each generator is fed an image that is already in its output domain
    (``gen_X_to_Y`` gets ``Y``, ``gen_Y_to_X`` gets ``X``). The penalty is the
    sum of the two ``criterion_identity`` distances between input and output,
    scaled by ``IDENTITY_WEIGHT``.

    Args:
        X: batch of images from domain X.
        Y: batch of images from domain Y.
    """
    regenerated_Y, regenerated_X = gen_X_to_Y(Y), gen_Y_to_X(X)
    penalty_Y = criterion_identity(Y, regenerated_Y)
    penalty_X = criterion_identity(X, regenerated_X)
    unweighted = penalty_Y + penalty_X
    return unweighted * IDENTITY_WEIGHT

# Appreciate truth, disbelieve lies.
def GAN_loss(X, Y, discr_X, ):
    """Unfinished draft: score real and generated images with both discriminators.

    NOTE(review): the four scores below are computed but never combined or
    returned, so this function currently returns ``None``.

    Args:
        X: batch of images from domain X.
        Y: batch of images from domain Y.
        discr_X: discriminator for domain X. This parameter shadows the
            module-level ``discr_X``, while ``discr_Y`` is still read from
            module scope.
    """
    # Translate each batch into the other domain.
    fake_Y = gen_X_to_Y(X)
    fake_X = gen_Y_to_X(Y)
    # Discriminator scores for the generated images...
    believe_fake_Y = discr_Y(fake_Y)
    believe_fake_X = discr_X(fake_X)
    # ...and for the real ones.
    believe_true_Y = discr_Y(Y)
    believe_true_X = discr_X(X)

    



    


    


