In [None]:
# Mount Google Drive into the Colab runtime; the project folder entered in the
# next cell lives on the mounted drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)



Mounted at /content/drive


In [None]:
import os

# Enter the project folder on the mounted drive. All dataset and checkpoint
# paths used later ("dataset/...", "saved_models/...") are relative to it.
# An absolute path is used so that the cell is idempotent: with the previous
# relative path, running the cell a second time failed because the working
# directory had already changed.
PROJECT_DIR = '/content/drive/MyDrive/CycleGANs'
os.chdir(PROJECT_DIR)

In [None]:
from random import random
from numpy import load
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randint

import numpy as np
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import torchvision.utils as vutils
import torch.autograd as autograd
torch.set_printoptions(threshold=5000)
from matplotlib import pyplot
from statistics import mean

### Data Loaders

In [None]:
batch_size=5
workers=2
image_size=(256,256)


def _build_image_loader(root):
    """Create the (dataset, dataloader) pair for one image folder.

    Every split uses the same preprocessing: resize to `image_size`, centre
    crop to `image_size`, convert to a tensor, then Normalize with mean 0 and
    std 1 (which leaves the values produced by ToTensor unchanged).
    Loaders shuffle and use `batch_size` / `workers` defined above.
    """
    preprocessing = transforms.Compose([transforms.Resize(image_size),
                                        transforms.CenterCrop(image_size),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0, 0, 0), (1, 1, 1)),])
    dataset = dset.ImageFolder(root=root, transform=preprocessing)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
    return dataset, loader


# training set of monet paintings
dataset_monet_train, dataloader_monet_train = _build_image_loader("dataset/trainMonet")

# training set of real pictures
dataset_picture_train, dataloader_picture_train = _build_image_loader("dataset/trainPicture")

# test set of monet paintings
dataset_monet_test, dataloader_monet_test = _build_image_loader("dataset/testMonet")

# test set of real pictures
dataset_picture_test, dataloader_picture_test = _build_image_loader("dataset/testPicture")

### Model Architectures

Discriminator

In [None]:
class Discriminator(nn.Module):
    """Patch-level discriminator: maps an image to a grid of per-patch scores.

    Four stride-2 4x4 convolutions halve the resolution each time, followed by
    two stride-1 4x4 convolutions (each shrinks the map by one pixel). All
    hidden layers use LeakyReLU(0.2); every hidden layer except the first is
    instance-normalised. The final convolution has no activation, so the
    output is a raw score map.

    Attribute names are part of the saved checkpoints (whole modules are
    pickled by save_models), so they must not be renamed.
    """

    def __init__(self):
        super().__init__()

        # Shared activation for every hidden layer
        self.LeakyReLU = nn.LeakyReLU(0.2)

        # Convolutional layers (kernel size 4, padding 1 everywhere)
        self.c3_64 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
        self.c64_128 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.c128_256 = nn.Conv2d(128, 256, 4, stride=2, padding=1)
        self.c256_512 = nn.Conv2d(256, 512, 4, stride=2, padding=1)
        self.c512_512 = nn.Conv2d(512, 512, 4, padding=1)
        self.c512_1 = nn.Conv2d(512, 1, 4, padding=1)

        # Instance normalization layers (one per normalised conv)
        self.i_128 = nn.InstanceNorm2d(128)
        self.i_256 = nn.InstanceNorm2d(256)
        self.i_512_1 = nn.InstanceNorm2d(512)
        self.i_512_2 = nn.InstanceNorm2d(512)

    def forward(self, x):
        """Score an image batch.

        For a 3x256x256 input the feature maps are:
        64x128x128 -> 128x64x64 -> 256x32x32 -> 512x16x16 -> 512x15x15,
        and the returned score map is 1x14x14.
        """
        activation = self.LeakyReLU

        # First layer: convolution + activation only (no normalisation)
        features = activation(self.c3_64(x))

        # Remaining hidden layers: convolution -> instance norm -> activation
        normalised_stages = (
            (self.c64_128, self.i_128),
            (self.c128_256, self.i_256),
            (self.c256_512, self.i_512_1),
            (self.c512_512, self.i_512_2),
        )
        for conv, norm in normalised_stages:
            features = activation(norm(conv(features)))

        # Fine-grained (per-patch) discrimination scores
        return self.c512_1(features)

Generator

In [None]:
class ResNetBlock(nn.Module):
    """Residual block on 256-channel feature maps (spatial size is preserved).

    conv3x3 -> InstanceNorm -> ReLU -> conv3x3, then the block input is ADDED
    to that result (skip connection) before a final InstanceNorm and ReLU.
    """

    def __init__(self):
        super().__init__()
        # Residual branch; the second normalisation is applied after the skip
        # connection, so it lives outside this Sequential.
        self.layers = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding='same'),
            nn.InstanceNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding='same'),
        )
        self.i_256 = nn.InstanceNorm2d(256)
        self.ReLU = nn.ReLU()

    def forward(self, x):
        """Return ReLU(InstanceNorm(residual_branch(x) + x))."""
        return self.ReLU(self.i_256(self.layers(x) + x))

def genBlockSequence(num_blocks=8):
    """Build the list of residual blocks used in the middle of the generator.

    Args:
        num_blocks: number of ResNetBlock modules to create.

    Returns:
        A list of modules, ready to be unpacked into nn.Sequential. When no
        blocks are requested the list holds a single nn.Identity() instance so
        the resulting Sequential is a pass-through. (Previously the
        nn.Identity *class* itself was returned, which cannot be unpacked and
        made Generator(num_ResNetBlocks=0) raise a TypeError.)
    """
    if num_blocks <= 0:
        return [nn.Identity()]
    return [ResNetBlock() for _ in range(num_blocks)]


class Generator(nn.Module):
    """Image-to-image generator: encoder -> residual blocks -> decoder.

    For a 3x256x256 input:
      * encoder: 7x7 conv (reflection padded) to 64x256x256, then two stride-2
        3x3 convs to 128x128x128 and 256x64x64;
      * residual_network_blocks: `num_ResNetBlocks` blocks from
        genBlockSequence, operating on the 256-channel maps;
      * decoder: two (ConvTranspose2d + PixelShuffle(2)) stages that each
        double the resolution (128x128x128, then 64x256x256), followed by a
        reflection-padded 7x7 conv to 3 channels and Tanh, so the output lies
        in [-1, 1].

    The layer order inside each Sequential determines the checkpoint layout
    and must be kept.
    """

    def __init__(self, num_ResNetBlocks=8):
        super().__init__()

        encoder_layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, 7, stride=1, padding=0), nn.InstanceNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.InstanceNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.InstanceNorm2d(256), nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.residual_network_blocks = nn.Sequential(*genBlockSequence(num_ResNetBlocks))

        decoder_layers = [
            # PixelShuffle(2) trades 4x channels for 2x resolution: 512 -> 128
            nn.ConvTranspose2d(256, 512, 3, 1, 1), nn.PixelShuffle(2), nn.InstanceNorm2d(128), nn.ReLU(),
            # ... and 256 -> 64
            nn.ConvTranspose2d(128, 256, 3, 1, 1), nn.PixelShuffle(2), nn.InstanceNorm2d(64), nn.ReLU(),
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7, 1, 0),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Translate an image batch `x` (3 channels) to the other domain."""
        return self.decoder(self.residual_network_blocks(self.encoder(x)))


### Loss Functions

In [None]:
def discriminator_loss(real_score, fake_score):
    """Least-squares discriminator loss.

    Pushes scores of real images towards 1 and scores of generated images
    towards 0: mean(fake_score^2) + mean((real_score - 1)^2).
    """
    fake_term = torch.mean(fake_score.pow(2))
    real_term = torch.mean((real_score - 1).pow(2))
    return fake_term + real_term

def generator_loss(disc_results):
    """Least-squares generator loss: mean((1 - D(fake))^2).

    `disc_results` are the discriminator scores for generated images; the loss
    is zero when the discriminator scores every patch as 1 ("real").
    """
    distance_from_real = 1 - disc_results
    return torch.mean(distance_from_real.pow(2))

In [None]:
def find_grad_norm_on_interpolates(D, real_images, gen_images, batch_size=5):
    """Gradient penalty term for a Wasserstein-style critic.

    Random convex combinations of real and generated images are fed to `D`,
    and the result is mean((||grad_x D(x)||_2 - 1)^2), where the norm is taken
    per sample over all of its elements.

    Changes from the previous version:
      * the norm is computed per sample over the flattened gradient (it used
        to be taken over the channel dimension only, giving one norm per
        pixel instead of one per image);
      * the batch size and device are read from `real_images` rather than
        hard-coded to 5 / CUDA, so partial batches and CPU tensors work;
      * the images are detached before interpolating, so the penalty only
        produces gradients for the critic's parameters.

    Args:
        D: callable critic; its output is summed through `grad_outputs` of ones.
        real_images: 4-D batch of real images.
        gen_images: batch of generated images with the same shape.
        batch_size: kept for backward compatibility; the batch size is now
            taken from `real_images` and this value is ignored.

    Returns:
        Scalar tensor, differentiable w.r.t. the parameters of `D`
        (create_graph=True).
    """
    num_samples = real_images.size(0)

    # One mixing coefficient per sample, broadcast over channels and pixels
    alpha = torch.rand(num_samples, 1, 1, 1, device=real_images.device)

    interpolates = alpha * real_images.detach() + (1 - alpha) * gen_images.detach()
    interpolates.requires_grad_(True)

    critic_scores = D(interpolates)

    gradients = autograd.grad(outputs=critic_scores, inputs=interpolates,
                              grad_outputs=torch.ones_like(critic_scores),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    per_sample_norm = gradients.reshape(num_samples, -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)

### Training

In [None]:
#initialize dictionary of losses
def reset_losses():
    """Return a fresh loss log: one empty list per tracked loss.

    Keys: total generator loss, the two discriminator losses, the two
    adversarial generator losses (as judged by the discriminators), the two
    cycle-consistency losses and the two identity losses.
    """
    tracked_losses = (
        "Total_Gen",
        "D_Monet",
        "D_Picture",
        "Gen_Monet2Picture",
        "Gen_Picture2Monet",
        "Cycle_Monet",
        "Cycle_Picture",
        "Identity_Picture2Monet",
        "Identity_Monet2Picture",
    )
    return {loss_name: [] for loss_name in tracked_losses}

def avg_losses(LossesInEpoch):
    """Average every loss list collected during one epoch.

    Args:
        LossesInEpoch: dict as produced by reset_losses(); each value is a
            list of per-iteration losses (Python numbers or 0-dim tensors).

    Returns:
        Tuple of nine Python floats, in this order: Total_Gen, D_Monet,
        D_Picture, Gen_Monet2Picture, Gen_Picture2Monet, Cycle_Monet,
        Cycle_Picture, Identity_Picture2Monet, Identity_Monet2Picture.
        Entries are converted with float(), so the averages never hold on to
        tensors. A loss with no recorded values averages to NaN instead of
        raising ZeroDivisionError.
    """
    ordered_keys = (
        "Total_Gen",
        "D_Monet",
        "D_Picture",
        "Gen_Monet2Picture",
        "Gen_Picture2Monet",
        "Cycle_Monet",
        "Cycle_Picture",
        "Identity_Picture2Monet",
        "Identity_Monet2Picture",
    )

    def _mean(values):
        if not values:
            return float("nan")
        return sum(float(v) for v in values) / len(values)

    return tuple(_mean(LossesInEpoch[key]) for key in ordered_keys)


In [None]:
from torchvision.transforms.functional import to_grayscale
def _sample_from_pool(pool, num_samples):
    """Draw `num_samples` images (with replacement) from a pool of stored fakes."""
    idx = torch.randint(0, pool.size(0), (num_samples,), device=pool.device)
    return pool[idx]


def _update_image_pool(pool, new_images, capacity):
    """Add the latest generated images to a pool of past generator outputs.

    * empty pool (None): start it from the new images;
    * pool not yet full: prepend the new images (truncated to `capacity`);
    * pool full: overwrite randomly chosen slots with the new images.

    The images are detached before being stored so the pool never keeps a
    reference to the generator's autograd graph.
    """
    new_images = new_images.detach().clone()
    if pool is None:
        return new_images[:capacity]
    if pool.size(0) < capacity:
        return torch.cat((new_images, pool))[:capacity]
    num_replace = min(new_images.size(0), capacity)
    slots = torch.randperm(capacity, device=pool.device)[:num_replace]
    pool[slots] = new_images[:num_replace]
    return pool


def train(dataloader_picture_train, dataloader_monet_train,
        dataloader_picture_test, dataloader_monet_test,
        G_Monet2Picture, G_Picture2Monet, D_Monet, D_Picture,
        optimizer_D_Monet, optimizer_D_Picture, optimizer_G_Monet2Picture, optimizer_G_Picture2Monet,
        start_epoch, max_num_epochs=200, pool_size=25, pool_sample_size=5):
    """Train the two generators and two discriminators of the CycleGAN on GPU.

    Per batch: both discriminators are updated with the least-squares loss
    (every third iteration on fakes sampled from a pool of past generator
    outputs, otherwise on the current fakes), then both generators are updated
    on adversarial + 10 * cycle-consistency (L1) + 5 * identity (L1) losses.
    After each epoch the models are saved, test images are plotted and the
    epoch-average of every loss is recorded.

    Fixes relative to the previous version:
      * every loss is logged exactly once per iteration as a Python float
        (discriminator losses used to be logged twice, and generator losses
        were stored as tensors still attached to the autograd graph);
      * the pool of past fakes is actually refreshed once it is full (the old
        replacement condition could never be true, so the pool froze after
        `pool_size` images) and stores detached images;
      * the 'newest' checkpoint is written after every epoch, including the
        epochs that also get a numbered checkpoint.

    Args:
        dataloader_picture_train, dataloader_monet_train: training loaders;
            they are zipped, so an epoch ends when the shorter one runs out.
        dataloader_picture_test, dataloader_monet_test: loaders passed to
            plot_images_test after every epoch.
        G_Monet2Picture, G_Picture2Monet: generators.
        D_Monet, D_Picture: discriminators.
        optimizer_*: optimizer for the corresponding model.
        start_epoch: index of the first epoch to run.
        max_num_epochs: training stops before this epoch index.
        pool_size: maximum number of past fakes kept per domain (positive).
        pool_sample_size: number of pooled fakes shown to a discriminator on
            the iterations that use the pool.

    Returns:
        Dict (same keys as reset_losses()) mapping each loss name to the list
        of its per-epoch averages.
    """
    # Move generators and discriminators to GPU
    G_Monet2Picture = G_Monet2Picture.cuda()
    G_Picture2Monet = G_Picture2Monet.cuda()
    D_Monet = D_Monet.cuda()
    D_Picture = D_Picture.cuda()

    # Criterion for cycle and identity loss, and the weights of those terms
    criterion = torch.nn.L1Loss()
    cycle_weight = 10
    identity_weight = 5

    # Order in which avg_losses returns its averages
    loss_names = ("Total_Gen", "D_Monet", "D_Picture",
                  "Gen_Monet2Picture", "Gen_Picture2Monet",
                  "Cycle_Monet", "Cycle_Picture",
                  "Identity_Picture2Monet", "Identity_Monet2Picture")

    # Epochs that additionally get their own numbered checkpoint. Saving
    # every epoch separately would take far too much disk space.
    checkpoint_epochs = (15, 25, 35, 45, 50, 60, 70, 75, 85, 95, 100)

    # Dictionaries of losses to keep track of progress
    LossesInEpoch = reset_losses()
    LossesAcrossEpochs = reset_losses()

    # Pools of past generator outputs (None until the first batch is done)
    pool_picture_fake = None
    pool_monet_fake = None

    for epoch in range(start_epoch, max_num_epochs):

        iters = 0

        for data_picture, data_monet in zip(dataloader_picture_train, dataloader_monet_train):

            # Move data to GPU
            picture_real = data_picture[0].cuda()
            monet_real = data_monet[0].cuda()

            # Forward passes (reconstructions are used by the cycle loss)
            picture_fake = G_Monet2Picture(monet_real)
            monet_reconstructed = G_Picture2Monet(picture_fake)
            monet_fake = G_Picture2Monet(picture_real)
            picture_reconstructed = G_Monet2Picture(monet_fake)

            # Every third iteration (once a pool exists) the discriminators
            # see fakes sampled from the pool instead of the current fakes.
            use_pool = pool_monet_fake is not None and iters % 3 == 0

            # ---- Monet discriminator ----
            optimizer_D_Monet.zero_grad()
            if use_pool:
                monet_fake_for_D = _sample_from_pool(pool_monet_fake, pool_sample_size)
            else:
                monet_fake_for_D = monet_fake.detach()
            Disc_loss_Monet = discriminator_loss(D_Monet(monet_real), D_Monet(monet_fake_for_D))
            Disc_loss_Monet.backward()
            optimizer_D_Monet.step()

            # ---- Picture discriminator ----
            optimizer_D_Picture.zero_grad()
            if use_pool:
                picture_fake_for_D = _sample_from_pool(pool_picture_fake, pool_sample_size)
            else:
                picture_fake_for_D = picture_fake.detach()
            Disc_loss_Picture = discriminator_loss(D_Picture(picture_real), D_Picture(picture_fake_for_D))
            Disc_loss_Picture.backward()
            optimizer_D_Picture.step()

            # ---- Generators ----
            optimizer_G_Picture2Monet.zero_grad()
            optimizer_G_Monet2Picture.zero_grad()

            # Adversarial losses as judged by the discriminators
            Gen_loss_Monet2Picture = generator_loss(D_Picture(picture_fake))
            Gen_loss_Picture2Monet = generator_loss(D_Monet(monet_fake))

            # Cycle consistency: each term goes through both generators
            Cycle_loss_Monet = criterion(monet_reconstructed, monet_real)*cycle_weight
            Cycle_loss_Picture = criterion(picture_reconstructed, picture_real)*cycle_weight

            # Identity: a generator fed an image of its target domain should
            # return it unchanged
            Id_loss_Picture2Monet = criterion(G_Picture2Monet(monet_real), monet_real)*identity_weight
            Id_loss_Monet2Picture = criterion(G_Monet2Picture(picture_real), picture_real)*identity_weight

            # Total loss for the generators to back-prop through
            Total_Loss = Gen_loss_Monet2Picture + Gen_loss_Picture2Monet + \
                    Cycle_loss_Monet + Cycle_loss_Picture + \
                    Id_loss_Picture2Monet + Id_loss_Monet2Picture

            Total_Loss.backward()
            optimizer_G_Picture2Monet.step()
            optimizer_G_Monet2Picture.step()

            # Log every loss once, as a plain float
            batch_losses = {
                "Total_Gen": Total_Loss.item(),
                "D_Monet": Disc_loss_Monet.item(),
                "D_Picture": Disc_loss_Picture.item(),
                "Gen_Monet2Picture": Gen_loss_Monet2Picture.item(),
                "Gen_Picture2Monet": Gen_loss_Picture2Monet.item(),
                "Cycle_Monet": Cycle_loss_Monet.item(),
                "Cycle_Picture": Cycle_loss_Picture.item(),
                "Identity_Picture2Monet": Id_loss_Picture2Monet.item(),
                "Identity_Monet2Picture": Id_loss_Monet2Picture.item(),
            }
            for loss_name, loss_value in batch_losses.items():
                LossesInEpoch[loss_name].append(loss_value)

            # Remember the fakes of this iteration for later discriminator updates
            pool_picture_fake = _update_image_pool(pool_picture_fake, picture_fake, pool_size)
            pool_monet_fake = _update_image_pool(pool_monet_fake, monet_fake, pool_size)

            iters += 1

            # Release GPU memory before the next batch is processed
            del data_picture, data_monet, monet_real, picture_real, monet_fake, picture_fake

            # Print intermediate results
            if iters % 50 == 0:
                print('Epoch %d \tLosses: \tGenTotal: %.4f\tGen_Monet2Picture: %.4f\tGen_Picture2Monet: %.4f\tCycle_Monet: %.4f\tCycle_Picture: %.4f\tID_Picture2Monet: %.4f\tID_Monet2Picture: %.4f\tDisc_Monet: %.4f\tDisc_Picture: %.4f'
                            % (epoch,
                                batch_losses["Total_Gen"],
                                batch_losses["Gen_Monet2Picture"], batch_losses["Gen_Picture2Monet"],
                                batch_losses["Cycle_Monet"], batch_losses["Cycle_Picture"],
                                batch_losses["Identity_Picture2Monet"], batch_losses["Identity_Monet2Picture"],
                                batch_losses["D_Monet"], batch_losses["D_Picture"]))

        # Save models: always refresh 'newest' so training can be resumed from
        # the last finished epoch, and keep a numbered copy at the checkpoints.
        save_models(G_Monet2Picture, G_Picture2Monet, D_Monet, D_Picture, 'newest')
        if epoch in checkpoint_epochs:
            save_models(G_Monet2Picture, G_Picture2Monet, D_Monet, D_Picture, str(epoch))

        # Test and plot the images
        plot_images_test(dataloader_monet_test, dataloader_picture_test)

        # Record the average of each loss over the epoch, then start a new log
        for loss_name, epoch_average in zip(loss_names, avg_losses(LossesInEpoch)):
            LossesAcrossEpochs[loss_name].append(epoch_average)
        LossesInEpoch = reset_losses()

    return LossesAcrossEpochs

### Save and Load Models

In [None]:
def save_models(generator1, generator2, discriminator1, discriminator2, name):
    """Save the four networks under saved_models/<name>_<role>.pt.

    The objects are written whole with torch.save (pickled modules, not just
    state dicts). The "saved_models" directory is created if it is missing;
    previously the first save failed when the folder did not exist yet.

    Args:
        generator1: saved as "<name>_G_Monet2Picture.pt".
        generator2: saved as "<name>_G_Picture2Monet.pt".
        discriminator1: saved as "<name>_D_Monet.pt".
        discriminator2: saved as "<name>_D_Picture.pt".
        name: checkpoint label (converted with str()).
    """
    os.makedirs("saved_models", exist_ok=True)
    targets = (
        (generator1, "_G_Monet2Picture.pt"),
        (generator2, "_G_Picture2Monet.pt"),
        (discriminator1, "_D_Monet.pt"),
        (discriminator2, "_D_Picture.pt"),
    )
    for model, suffix in targets:
        torch.save(model, "saved_models/"+str(name)+suffix)

def load_models(name):
    """Load the four networks written by save_models for checkpoint `name`.

    The checkpoints are whole pickled modules, so they must be loaded with
    weights_only=False: recent PyTorch releases default to weights_only=True,
    which refuses to unpickle arbitrary module classes.

    SECURITY: unpickling can execute arbitrary code. Only load checkpoint
    files that this notebook produced itself.

    Returns:
        (G_Monet2Picture, G_Picture2Monet, D_Monet, D_Picture)
    """
    def _load_whole_module(path):
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older PyTorch versions have no `weights_only` argument and
            # always perform a full unpickle.
            return torch.load(path)

    prefix = "saved_models/"+str(name)
    G_Monet2Picture = _load_whole_module(prefix+"_G_Monet2Picture.pt")
    G_Picture2Monet = _load_whole_module(prefix+"_G_Picture2Monet.pt")
    D_Monet = _load_whole_module(prefix+"_D_Monet.pt")
    D_Picture = _load_whole_module(prefix+"_D_Picture.pt")
    return G_Monet2Picture, G_Picture2Monet, D_Monet, D_Picture

### Test/Plot

In [None]:
def _show_image_grid(images, title):
    """Plot the first four images of a CPU batch as one grid with a title."""
    plt.figure(figsize=(10,10))
    plt.imshow(np.transpose(vutils.make_grid((images[:4]+1)/2,
                                              padding=2, normalize=True).cpu(),(1,2,0)))
    plt.axis("off")
    plt.title(title)
    plt.show()


def plot_images_test(dataloader_monet_test, dataloader_picture_test):
    """Show one test batch of each domain next to its translation.

    Takes the first batch yielded by each test loader, translates it with the
    notebook-level generators `G_Monet2Picture` / `G_Picture2Monet` (globals,
    not parameters) on the GPU, and plots four figures: real Monets, generated
    pictures, real pictures, generated Monets.

    The generators run under torch.no_grad(): these forward passes are only
    for display, so no autograd graph needs to be built or kept.
    """
    with torch.no_grad():
        batch_a_test = next(iter(dataloader_monet_test))[0].cuda()
        real_a_test = batch_a_test.cpu()
        fake_b_test = G_Monet2Picture(batch_a_test).cpu()

    _show_image_grid(real_a_test, "Real Monets")
    _show_image_grid(fake_b_test, "Generated Pictures")

    with torch.no_grad():
        batch_b_test = next(iter(dataloader_picture_test))[0].cuda()
        real_b_test = batch_b_test.cpu()
        fake_a_test = G_Picture2Monet(batch_b_test).cpu()

    _show_image_grid(real_b_test, "Real Pictures")
    _show_image_grid(fake_a_test, "Generated Monets")

### Execution

In [None]:
# Hyper-parameters shared by the four Adam optimizers created below:
# learning rate and the first beta coefficient (the second is fixed at 0.999).
lr = 2e-4
beta1 = 0.5

In [None]:
# OPTION A - starting training over:
# create untrained generators and discriminators when there are no saved
# models to build on.
G_Monet2Picture, G_Picture2Monet = Generator(), Generator()
D_Monet, D_Picture = Discriminator(), Discriminator()

In [None]:
# OPTION B - resume: load the generators and discriminators that this notebook
# saved during an earlier run.
# NOTE: this overwrites the models created in the cell above; run only the
# cell for the option you want.
G_Monet2Picture, G_Picture2Monet, D_Monet, D_Picture = load_models("newest")
# Pass "newest", or the epoch number of a numbered checkpoint as a string
# (e.g. "15", "25", "35", "45", "50", ...).

In [None]:
# Temporary cell: resume from the checkpoint named "63".
# NOTE(review): 63 is not one of the epochs train() saves under a numbered
# name, so this file was presumably created by hand - confirm it exists.
# Running this cell overwrites whatever models the cells above created.
G_Monet2Picture, G_Picture2Monet, D_Monet, D_Picture = load_models("63")

In [None]:
# Index of the next epoch to train: one more than the epoch index of the
# checkpoint loaded above (train() iterates over
# range(start_epoch, max_num_epochs)). Use 0 when training from scratch.
start_epoch = 64

In [None]:
# Initialize optimizers: one Adam optimizer per network, all with the same
# learning rate and betas.
def _make_adam(model):
    return torch.optim.Adam(model.parameters(), lr=lr, betas=(beta1, 0.999))

optimizer_G_Monet2Picture = _make_adam(G_Monet2Picture)
optimizer_G_Picture2Monet = _make_adam(G_Picture2Monet)
optimizer_D_Monet = _make_adam(D_Monet)
optimizer_D_Picture = _make_adam(D_Picture)

# Run training; the result maps each loss name to its per-epoch averages
Losses_Across_Epochs = train(
    dataloader_picture_train=dataloader_picture_train,
    dataloader_monet_train=dataloader_monet_train,
    dataloader_picture_test=dataloader_picture_test,
    dataloader_monet_test=dataloader_monet_test,
    G_Monet2Picture=G_Monet2Picture,
    G_Picture2Monet=G_Picture2Monet,
    D_Monet=D_Monet,
    D_Picture=D_Picture,
    optimizer_D_Monet=optimizer_D_Monet,
    optimizer_D_Picture=optimizer_D_Picture,
    optimizer_G_Monet2Picture=optimizer_G_Monet2Picture,
    optimizer_G_Picture2Monet=optimizer_G_Picture2Monet,
    start_epoch=start_epoch,
)

Output hidden; open in https://colab.research.google.com to view.