In [1]:
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import itertools
from torch import optim 
import random 
import os 
from PIL import Image 
import torch.utils.data as data
import os.path
from torch.utils.data import Dataset
from torchvision import datasets,transforms,utils
import torch.nn.functional as F 
import matplotlib.pyplot as plt 
import numpy as np
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from IPython import display
import numpy
import random
from tkinter import *
from tkinter.messagebox import showinfo
import pdb
from tkinter import filedialog
from PIL import Image, ImageTk

In [2]:
# Data set creation functions

# Function that builds a pipeline of torchvision (PyTorch) transforms.
# Each image in the dataset is passed through this pipeline.
# Transforms include: resizing, random cropping, conversion to a tensor, normalization to [-1, 1],
# and, when color == 'grayscale', conversion to 3-channel grayscale.

def get_transform(load_size,crop_size,color):
    """Build the preprocessing pipeline applied to every dataset image.

    Args:
        load_size: side length (pixels) of the square the image is resized to.
        crop_size: side length of the random square crop taken after resizing.
        color: 'grayscale' converts images to 3-channel grayscale; any other
            value keeps the RGB colours.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a tensor normalized to [-1, 1].
    """
    transform_list = [transforms.Resize([load_size, load_size], Image.BICUBIC),
                      transforms.RandomCrop(crop_size)]

    # Grayscale is applied while the sample is still a PIL image, i.e. before
    # ToTensor/Normalize; older torchvision releases only accept PIL input here.
    if color == 'grayscale':
        transform_list.append(transforms.Grayscale(num_output_channels=3))

    # ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to
    # [-1, 1], the same range as the generator's final Tanh.
    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    return transforms.Compose(transform_list)

# Function that gets a chosen number of images from the dataset in order to display the network training progress

def get_test_dataset(num_imgs,dataset):
    """Collect a fixed set of domain-A images used to visualise training progress.

    Takes the first sample (``batch['A'][0]``) of each of the first ``num_imgs``
    batches yielded by ``dataset``.

    Args:
        num_imgs: number of batches to draw; must be >= 1.
        dataset: iterable of dicts whose ``'A'`` entry is a batched tensor
            (e.g. the DataLoader returned by ``createDataset``).

    Returns:
        Tensor stacking the collected images along a new leading dimension.
        It contains fewer than ``num_imgs`` images if ``dataset`` is shorter.

    Raises:
        ValueError: if ``num_imgs`` < 1 or ``dataset`` yields no batches.
    """
    if num_imgs < 1:
        # A non-positive count has no meaningful result; fail loudly instead of
        # iterating over the entire dataset.
        raise ValueError('num_imgs must be >= 1, got %r' % (num_imgs,))

    # islice stops after num_imgs batches without draining the rest of the loader.
    tensor_list_A = [batch['A'][0] for batch in itertools.islice(dataset, num_imgs)]

    if not tensor_list_A:
        raise ValueError('dataset is empty; cannot build a test image set')

    return torch.stack(tensor_list_A)

# Class that defines the structure of the dataset
# Sets up the file paths to each style domain image folder
# Transforms and prepares each image for enumeration
# Creates an indexing system for the dataset

class DatasetStructure():
    """Unpaired two-domain image dataset rooted at ``<path>/trainA`` and ``<path>/trainB``.

    Call ``initialize`` before use. Indexing returns a dict with the transformed
    images ``'A'`` and ``'B'`` plus their file paths ``'A_paths'`` / ``'B_paths'``.
    The dataset length is the size of the larger domain. Indices wrap around
    (modulo) in the smaller domain.
    """

    # File-name suffixes (compared case-insensitively) that count as images.
    IMG_EXTENSIONS = ('.jpg', '.png')

    @staticmethod
    def _find_images(directory):
        """Return the sorted paths of all image files below ``directory`` (recursive).

        A missing directory yields an empty list, because ``os.walk`` produces nothing.
        """
        image_paths = []
        for root, _, files in os.walk(directory):
            for file_name in files:
                if file_name.lower().endswith(DatasetStructure.IMG_EXTENSIONS):
                    image_paths.append(os.path.join(root, file_name))
        return sorted(image_paths)

    @staticmethod
    def _load_rgb(image_path):
        """Open ``image_path`` and return an RGB copy, closing the underlying file."""
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def initialize(self,path,loadSize,fineSize,color):
        """Index the image files of both domains and build the shared transform.

        Args:
            path: dataset root containing ``trainA`` and ``trainB`` folders.
            loadSize: resize side length passed to ``get_transform``.
            fineSize: random-crop side length passed to ``get_transform``.
            color: colour mode passed to ``get_transform`` ('grayscale' or other).

        Raises:
            ValueError: if either domain folder contains no images. Without this
                check, indexing would later fail with ZeroDivisionError in ``__getitem__``.
        """
        self.dir_A = os.path.join(path, 'trainA')
        self.dir_B = os.path.join(path, 'trainB')

        self.A_paths = self._find_images(self.dir_A)
        self.B_paths = self._find_images(self.dir_B)

        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)

        for directory, size in ((self.dir_A, self.A_size), (self.dir_B, self.B_size)):
            if size == 0:
                raise ValueError('no .jpg/.png images found in %r' % (directory,))

        self.transform = get_transform(loadSize,fineSize,color)

    def __getitem__(self, index):
        """Return the transformed A/B image pair and their paths for ``index``.

        Both domains are indexed modulo their own size, so A and B are paired
        deterministically by position rather than aligned by content.
        """
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index % self.B_size]

        # A is transformed before B so random-crop RNG consumption keeps the same order.
        A = self.transform(self._load_rgb(A_path))
        B = self.transform(self._load_rgb(B_path))

        return {'A': A, 'B': B,
                'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Number of items: the size of the larger of the two domains."""
        return max(self.A_size, self.B_size)
    
# Creates the dataset and wraps it in a PyTorch DataLoader (batch size 1, no shuffling)
# The load/crop sizes and the color mode ('grayscale' or color) are forwarded to the image transforms

def createDataset(path,load_size,crop_size,color,batch_size=1,shuffle=False,num_workers=0):
    """Build a ``DatasetStructure`` rooted at ``path`` and wrap it in a DataLoader.

    Args:
        path: dataset root containing ``trainA`` and ``trainB`` folders.
        load_size: side length images are resized to.
        crop_size: side length of the random crop taken after resizing.
        color: 'grayscale' for 3-channel grayscale, anything else for colour.
        batch_size: samples per batch (default 1, the previously fixed value).
        shuffle: whether to reshuffle every epoch (default False, the previously
            fixed value).
        num_workers: DataLoader worker processes (default 0, i.e. load in the
            main process).

    Returns:
        ``torch.utils.data.DataLoader`` over the dataset.
    """
    dataset = DatasetStructure()
    dataset.initialize(path,load_size,crop_size,color)

    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers)

def new_dataset(path,size,color):
    """Load every domain-A sample of the dataset at ``path`` into one tensor.

    Resize and crop both use ``size``, so the random crop covers the whole
    resized image.

    Args:
        path: dataset root containing ``trainA`` and ``trainB`` folders.
        size: side length used for both resizing and cropping.
        color: colour mode forwarded to the transforms ('grayscale' or other).

    Returns:
        Tensor stacking one transformed domain-A image per DataLoader batch.
        The loader yields max(#A, #B) batches, so domain-A images repeat (wrap
        around) when ``trainB`` holds more files than ``trainA``.

    Raises:
        ValueError: if the loader yields no batches.

    NOTE: each item also loads and transforms a domain-B image, which is discarded.
    """
    loader = createDataset(path,size,size,color)

    # The loader uses batch_size=1, so ``batch['A'][0]`` is that batch's single image.
    images_A = [batch['A'][0] for batch in loader]

    if not images_A:
        raise ValueError('no images found under %r' % (path,))

    return torch.stack(images_A)

In [3]:
# Network functions

# initializes NN weights

def init_weights(net, init_gain=0.02):
    """Initialise the conv, linear and batch-norm parameters of ``net`` in place.

    - Modules whose class name contains 'Conv' or 'Linear' (including
      ConvTranspose2d): weight ~ N(0, init_gain), bias = 0.
    - Modules whose class name contains 'BatchNorm2d': weight ~ N(1, init_gain),
      bias = 0.

    Args:
        net: ``nn.Module``; every sub-module is visited via ``net.apply``.
        init_gain: standard deviation of the normal initialisers (default 0.02).
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            init.normal_(m.weight.data, 0.0, init_gain)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm2d(affine=False) has weight/bias set to None; skip those
            # instead of crashing on ``None.data``.
            if m.weight is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    net.apply(init_func)
    
class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a 1-channel map of patch scores.

    Layout: conv(stride 2) + LeakyReLU, then three stride-2 and one stride-1
    conv/norm/LeakyReLU stages (channels ndf -> 2*ndf -> 4*ndf -> 8*ndf -> 8*ndf),
    then a 1-channel conv, with an optional Sigmoid. All convs are 4x4 with padding 1.
    The layer order matches the previous hard-coded version, so state_dict keys
    are unchanged.

    Args:
        norm_layer: normalisation layer class or ``functools.partial`` of one.
        use_sigmoid: append a Sigmoid. It is needed when the output feeds
            ``nn.BCELoss``, as ``gan_loss`` does.
        input_nc: number of input image channels (default 3).
        ndf: channel width of the first conv (default 64).
    """
    def __init__(self, norm_layer=nn.BatchNorm2d, use_sigmoid=False, input_nc=3, ndf=64):
        super(Discriminator, self).__init__()

        # Norm layers that subtract a per-channel mean (BatchNorm) cancel any conv
        # bias, so convs skip it; InstanceNorm2d models keep the conv bias.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        # The first conv has no normalisation.
        layers = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]

        in_channels = ndf
        for width_mult, stride in ((2, 2), (4, 2), (8, 2), (8, 1)):
            out_channels = ndf * width_mult
            layers += [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=use_bias),
                       norm_layer(out_channels), nn.LeakyReLU(0.2, True)]
            in_channels = out_channels

        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
    
class Generator(nn.Module):
    """ResNet-style image-to-image generator: encoder, residual blocks, decoder.

    - Encoder: reflect-padded 7x7 conv, then two stride-2 3x3 convs (spatial /4,
      channels ngf -> 4*ngf).
    - Transformer: ``n_blocks`` ResnetBlocks at 4*ngf channels.
    - Decoder: two stride-2 transposed convs (spatial x4, channels 4*ngf -> ngf),
      then a reflect-padded 7x7 conv to ``output_nc`` channels and Tanh.

    The layer order matches the previous hard-coded version, so state_dict keys
    are unchanged.

    Args:
        norm_layer: normalisation layer class or ``functools.partial`` of one.
        use_dropout: add Dropout(0.5) inside each residual block.
        n_blocks: number of residual blocks (>= 0).
        padding_type: padding mode forwarded to every ResnetBlock.
        input_nc: input image channels (default 3).
        output_nc: output image channels (default 3).
        ngf: channel width of the first conv (default 64).
    """
    def __init__(self,norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect',
                 input_nc=3, output_nc=3, ngf=64):
        assert(n_blocks >= 0)
        super(Generator, self).__init__()

        # Convs followed by a mean-subtracting BatchNorm skip their bias; InstanceNorm2d keeps it.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        encoder = [nn.ReflectionPad2d(3),
                   nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                   norm_layer(ngf), nn.ReLU(True)]
        for mult in (1, 2):
            encoder += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                        norm_layer(ngf * mult * 2), nn.ReLU(True)]

        transformer = [ResnetBlock(ngf * 4, padding_type=padding_type, norm_layer=norm_layer,
                                   use_dropout=use_dropout, use_bias=use_bias)
                       for _ in range(n_blocks)]

        decoder = []
        for mult in (4, 2):
            decoder += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=3, stride=2, padding=1,
                                           output_padding=1, bias=use_bias),
                        norm_layer(ngf * mult // 2), nn.ReLU(True)]
        # Final projection; Tanh keeps outputs in [-1, 1], the range of the normalized inputs.
        decoder += [nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                    nn.Tanh()]

        self.model = nn.Sequential(*(encoder + transformer + decoder))

    def forward(self, input):
        return self.model(input)
    
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    F = pad, conv3x3, norm, ReLU, [Dropout(0.5)], pad, conv3x3, norm. The channel
    count ``dim`` and the spatial size are preserved.

    Args:
        dim: number of input/output channels.
        padding_type: 'reflect', 'replicate' or 'zero'.
        norm_layer: normalisation layer class or ``functools.partial`` of one.
        use_dropout: insert Dropout(0.5) between the two convs.
        use_bias: whether the convs have a bias term.

    Raises:
        NotImplementedError: for any other ``padding_type``.
    """
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()

        self.block = self.build_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return (explicit padding layers, conv padding) for ``padding_type``.

        A fresh padding module is created on every call.
        """
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            # No explicit padding layer; the conv pads with zeros itself.
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the residual branch F as an ``nn.Sequential``."""
        pad_layers, conv_pad = self._padding(padding_type)
        block = pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias),
                              norm_layer(dim), nn.ReLU(True)]

        if use_dropout:
            block += [nn.Dropout(0.5)]

        pad_layers, conv_pad = self._padding(padding_type)
        block += pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias),
                               norm_layer(dim)]

        return nn.Sequential(*block)

    def forward(self, x):
        # Skip connection: the branch learns a residual added to the input.
        return x + self.block(x)
    
def createModels(device='cuda'):
    """Create and initialise the two generators and two discriminators.

    All four networks use affine BatchNorm2d. The discriminators end in a Sigmoid
    because ``gan_loss`` applies ``nn.BCELoss`` to their output. The generators
    have 9 residual blocks and no dropout.

    Args:
        device: device the networks are moved to before initialisation
            (default 'cuda').

    Returns:
        ``(GAB, GBA, DA, DB)``.
    """
    norm_layer = functools.partial(nn.BatchNorm2d, affine=True)

    def _place_and_init(net):
        # Move to the device, then overwrite the default init with init_weights.
        net.to(device)
        init_weights(net)
        return net

    # Construction order matches the original (DA, DB, GAB, GBA), so RNG draws line up.
    DA = _place_and_init(Discriminator(norm_layer, use_sigmoid=True))
    DB = _place_and_init(Discriminator(norm_layer, use_sigmoid=True))
    GAB = _place_and_init(Generator(norm_layer, use_dropout=False, n_blocks=9))
    GBA = _place_and_init(Generator(norm_layer, use_dropout=False, n_blocks=9))

    return GAB,GBA,DA,DB

def createNN(NN_type,gpu):
    """Create one initialised network of the requested kind.

    Args:
        NN_type: 'generator' (9 residual blocks, no dropout) or 'discriminator'
            (with a Sigmoid output, matching ``gan_loss``'s BCELoss).
        gpu: truthy to place the network on 'cuda', otherwise on 'cpu'.

    Returns:
        The network, moved to its device and initialised with ``init_weights``.

    Raises:
        ValueError: for any other ``NN_type``. Previously this surfaced as an
            UnboundLocalError.
    """
    if NN_type not in ('generator', 'discriminator'):
        raise ValueError("NN_type must be 'generator' or 'discriminator', got %r" % (NN_type,))

    norm_layer = functools.partial(nn.BatchNorm2d, affine=True)

    if NN_type == 'generator':
        NN = Generator(norm_layer, use_dropout=False, n_blocks=9)
    else:
        NN = Discriminator(norm_layer, use_sigmoid=True)

    NN.to('cuda' if gpu else 'cpu')
    init_weights(NN)

    return NN

In [4]:
# Training functions

def update_pool(imgs,pool):
    """Swap generated images with a history buffer of previously generated ones.

    ``pool`` is a mutable list ``[capacity, count, images]`` updated in place.
    While fewer than ``capacity`` images are stored, each incoming image is
    stored and returned unchanged. Once the pool is full, each image is, with
    probability ~1/2, swapped with a randomly chosen stored image: the stored one
    is returned and the new one takes its slot. Otherwise the image is returned
    unchanged. A capacity <= 0 disables the buffer.

    Args:
        imgs: batch of generated images, iterated along dim 0.
        pool: ``[capacity, count, images]`` list, mutated in place.

    Returns:
        Tensor with the same batch size as ``imgs``, detached from autograd (via ``.data``).
    """
    capacity = pool[0]

    # Without this guard a zero-capacity pool would reach
    # random.randint(0, -1) and raise ValueError about half the time.
    if capacity <= 0:
        return imgs.data

    pool_imgs = []
    for img in imgs:
        img = torch.unsqueeze(img.data, 0)
        if pool[1] < capacity:
            pool[1] += 1
            pool[2].append(img)
            pool_imgs.append(img)
        elif random.uniform(0, 1) > 0.5:
            r_id = random.randint(0, capacity - 1)
            pool_imgs.append(pool[2][r_id].clone())
            pool[2][r_id] = img
        else:
            pool_imgs.append(img)
    return torch.cat(pool_imgs, 0)

def gan_loss(input,target):
    """Binary cross-entropy GAN loss against an all-real or all-fake target.

    Args:
        input: discriminator output. It must already lie in [0, 1], because
            ``nn.BCELoss`` applies no sigmoid; use a Discriminator built with
            ``use_sigmoid=True``.
        target: truthy for "real" (target 1.0), falsy for "fake" (target 0.0).

    Returns:
        Scalar mean BCE loss tensor.
    """
    # full_like creates the target on input's own device and dtype. This works on
    # CPU and on any GPU, with no hard-coded .cuda() transfer per call.
    target_tensor = torch.full_like(input, 1.0 if target else 0.0)

    return nn.BCELoss()(input, target_tensor)

# learning rate functions

def get_scheduler(optimizer, n_epochs=100, n_epochs_decay=100, epoch_count=1):
    """Create a constant-then-linear-decay learning-rate schedule.

    The LR multiplier for scheduler epoch ``e`` (starting at 0) is
    ``1 - max(0, e + epoch_count - n_epochs) / (n_epochs_decay + 1)``. It stays
    1.0 until ``e + epoch_count`` exceeds ``n_epochs`` and then falls linearly
    towards 0 over the next ``n_epochs_decay`` epochs. The defaults reproduce
    the previously hard-coded ``max(0, e - 99) / 101``.

    Args:
        optimizer: optimizer whose LR is scaled.
        n_epochs: epochs at the initial learning rate.
        n_epochs_decay: epochs over which the LR decays linearly.
        epoch_count: number of the first epoch (offsets the schedule).

    Returns:
        ``lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

def update_learning_rate():
    """Step every scheduler in the notebook-global ``schedulers`` list once.

    Thin wrapper around ``update_learning_rate1`` that uses the globals
    ``schedulers`` and ``optimizers`` created in the training cell, so the
    stepping logic lives in one place.

    Returns:
        Whatever ``update_learning_rate1`` returns for those globals.
    """
    return update_learning_rate1(schedulers, optimizers)
    
def update_learning_rate1(schedulers,optimizers):
    """Step each scheduler once and report the resulting learning rate.

    Args:
        schedulers: iterable of LR schedulers to step.
        optimizers: list of optimizers. The LR of the first parameter group of
            the first one is reported.

    Returns:
        ``optimizers[0].param_groups[0]['lr']`` after stepping, or None if
        ``optimizers`` is empty.
    """
    for scheduler in schedulers:
        scheduler.step()

    if not optimizers:
        return None
    return optimizers[0].param_groups[0]['lr']
    
def createOptimizers(lr=0.0002, betas=(0.5, 0.999)):
    """Create Adam optimizers for the notebook-global networks GAB, GBA, DA and DB.

    Args:
        lr: Adam learning rate for both optimizers (default 2e-4).
        betas: Adam ``(beta1, beta2)`` coefficients (default (0.5, 0.999)).

    Returns:
        ``[optimizer_G, optimizer_D]``: one optimizer over both generators and
        one over both discriminators. ``optimize_parameters`` relies on this order.
    """
    optimizer_G = torch.optim.Adam(itertools.chain(GAB.parameters(), GBA.parameters()), lr, betas=betas)
    optimizer_D = torch.optim.Adam(itertools.chain(DA.parameters(), DB.parameters()), lr, betas=betas)

    return [optimizer_G, optimizer_D]

def _set_requires_grad(nets, requires_grad):
    """Set ``requires_grad`` on every parameter of every module in ``nets``."""
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad


def optimize_parameters(real_A,real_B,optimizers,lambda_cycle=10.0,lambda_identity=0.5):
    """Run one training step: update both generators, then both discriminators.

    Uses the notebook-global networks ``GAB`` (A -> B), ``GBA`` (B -> A), ``DA``
    (scores domain-B images: real_B vs. GAB fakes) and ``DB`` (scores domain-A
    images: real_A vs. GBA fakes), plus the image pools ``A_pool`` and ``B_pool``.

    Args:
        real_A, real_B: image batches from domains A and B.
        optimizers: ``[optimizer_G, optimizer_D]`` as returned by ``createOptimizers``.
        lambda_cycle: weight of the cycle-consistency L1 terms (default 10.0).
        lambda_identity: weight of the identity L1 terms, relative to
            ``lambda_cycle`` (default 0.5).
    """
    l1 = torch.nn.L1Loss()

    # Forward both translation cycles: A -> B -> A and B -> A -> B.
    fake_B = GAB(real_A)
    re_A = GBA(fake_B)
    fake_A = GBA(real_B)
    re_B = GAB(fake_A)

    # ---- Generator update ----
    # Freeze the discriminators so the generator backward pass does not
    # accumulate gradients into their parameters.
    _set_requires_grad([DA, DB], False)
    optimizers[0].zero_grad()

    g_loss = (gan_loss(DA(fake_B), True) + gan_loss(DB(fake_A), True)
              + l1(re_A, real_A) * lambda_cycle
              + l1(re_B, real_B) * lambda_cycle
              # Identity terms: each generator should leave images of its target domain unchanged.
              + l1(GAB(real_B), real_B) * lambda_cycle * lambda_identity
              + l1(GBA(real_A), real_A) * lambda_cycle * lambda_identity)
    g_loss.backward()
    optimizers[0].step()

    # ---- Discriminator update ----
    _set_requires_grad([DA, DB], True)
    optimizers[1].zero_grad()

    # DA: real domain-B images vs. generator-AB fakes (drawn through the history pool).
    fake_B = update_pool(fake_B, B_pool)
    pred_real_DA = DA(real_B)
    pred_fake_DA = DA(fake_B.detach())
    DA_loss = (gan_loss(pred_real_DA, True) + gan_loss(pred_fake_DA, False)) * 0.5
    DA_loss.backward()

    # DB: real domain-A images vs. generator-BA fakes (drawn through the history pool).
    fake_A = update_pool(fake_A, A_pool)
    pred_real_DB = DB(real_A)
    pred_fake_DB = DB(fake_A.detach())
    DB_loss = (gan_loss(pred_real_DB, True) + gan_loss(pred_fake_DB, False)) * 0.5
    DB_loss.backward()

    optimizers[1].step()
    
def startTraining(epoch_limit,in_size,out_size,sample_num,path,save_image_path,data_path=None):
    """Train the notebook-global networks for ``epoch_limit`` epochs.

    Every 100 iterations it clears the cell output, translates a fixed set of
    sample images with GAB, and calls ``saveImages`` and ``saveModel``. After
    each epoch the global LR schedulers are stepped.

    Args:
        epoch_limit: number of epochs to run (numbered 1..epoch_limit).
        in_size: resize side length passed to ``createDataset``.
        out_size: random-crop side length passed to ``createDataset``.
        sample_num: number of fixed sample images used for progress snapshots.
        path: forwarded to ``saveModel`` (presumably the checkpoint directory).
        save_image_path: forwarded to ``saveImages``.
        data_path: dataset root containing trainA/ and trainB/. Defaults to the
            notebook-global ``train_path``.

    NOTE(review): ``saveImages`` and ``saveModel`` are not defined anywhere in this
    notebook. The execution counts jump from [4] to [6], so a cell appears to be
    missing. Restore them, or this fails on a fresh kernel.
    """
    if data_path is None:
        data_path = train_path
    use_cuda = torch.cuda.is_available()

    dataset = createDataset(data_path, in_size, out_size, 'color')

    img_list = get_test_dataset(sample_num, dataset)
    if use_cuda:
        img_list = img_list.cuda()

    # range's end is exclusive, so +1 runs exactly epoch_limit epochs.
    for epoch in range(1, epoch_limit + 1):
        for i, data in enumerate(dataset):
            real_A = data['A']
            real_B = data['B']
            if use_cuda:
                real_A = real_A.cuda()
                real_B = real_B.cuda()

            optimize_parameters(real_A, real_B, optimizers)

            # Progress snapshot every 100 iterations.
            if i % 100 == 0:
                display.clear_output(True)
                # Inference only: no_grad avoids building an autograd graph for the samples.
                with torch.no_grad():
                    test_images_B = GAB(img_list).cpu()

                saveImages(save_image_path, test_images_B, epoch, i, 'cyclegan', 'test_save', i)

                print(i)
                print('epoch: ', epoch, '/', epoch_limit)

                saveModel(GAB, epoch, str(i), path)

        update_learning_rate()

In [6]:
# The notebook's working directory. A notebook has no __file__; the previous
# abspath('__file__') merely resolved the literal name '__file__' against the
# current directory, so this states the actual intent directly.
main_path = os.getcwd()

print(main_path)

C:\Users\ian_000\Desktop\GAN_final_code


In [None]:
# Start training process

# ---- Configuration ----
SEED = 0          # seeds Python, NumPy and torch RNGs (weight init, crops, image pool)
EPOCHS = 10       # passed to startTraining as epoch_limit
LOAD_SIZE = 96    # images are resized to LOAD_SIZE x LOAD_SIZE ...
CROP_SIZE = 64    # ... then randomly cropped to CROP_SIZE x CROP_SIZE
SAMPLE_NUM = 1    # number of fixed images used for progress snapshots
pool_size = 50    # capacity of each generated-image history pool

# NOTE(review): absolute drive paths; make these configurable before sharing the notebook.
save_model_path = 'F:/GAN/saved_models/test_models/'
save_image_path = 'F:/GAN/saved_images/test'
train_path = 'F:/GAN/datasets/color2render'

# Seed before the models are built so weight initialisation is repeatable.
# CUDA kernels may still introduce run-to-run nondeterminism.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

GAB,GBA,DA,DB = createModels()

optimizers = createOptimizers()

schedulers = [get_scheduler(optimizer) for optimizer in optimizers]

# Each pool is [capacity, number of stored images, stored images]; update_pool mutates it in place.
A_num_imgs = 0
A_imgs = []
B_num_imgs = 0
B_imgs = []

A_pool = [pool_size,A_num_imgs,A_imgs]

B_pool = [pool_size,B_num_imgs,B_imgs]

startTraining(EPOCHS, LOAD_SIZE, CROP_SIZE, SAMPLE_NUM, save_model_path, save_image_path)