In [1]:
import torchvision.utils as vutils
import copy
import math
import os
import numpy as np
from PIL import Image, ImageFile
from matplotlib import pyplot as plt
import matplotlib.image as mpimg
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torch.utils.data as td
import torchvision as tv
import pandas as pd
from torch.autograd import Variable
from io import BytesIO
import itertools
from image_pool import ImagePool
import time
from dataset import ArtDataset,LandscapeDataset,myimshow
from model import weights_init,Generator,Discriminator,cal_loss_Cycle,cal_loss_Gan
# from Unetmodel import weights_init,Generator,Discriminator,cal_loss_Cycle,cal_loss_Gan
# from DnCNNmodel import weights_init,Generator,Discriminator,cal_loss_Cycle,cal_loss_Gan
'''
default setting is to run the model of Zhu, Park et al's
to set unet=True below to run the other two models
'''

"\ndefault setting is to run the model of Zhu, Park et al's\nto set unet=True below to run the other two models\n"

In [2]:
# Pick the compute device once: use the GPU when PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [3]:
# Root folders of the two image domains (landscape photos and paintings).
art_root_dir = '/datasets/ee285f-public/wikiart'
landscape_root_dir = '/datasets/ee285f-public/flickr_landscape/'

# Number of images per batch, shared by both data loaders.
batchsize = 1

In [4]:
# Domain X: landscape photographs.
X_sets = LandscapeDataset(landscape_root_dir)
# Materialise every batch up front so the training loop can index batches
# by position (X_loader[i]) and ask for len(X_loader).
X_loader = [
    batch
    for batch in td.DataLoader(
        X_sets,
        batch_size=batchsize,
        shuffle=True,
        pin_memory=True,
    )
]

In [5]:
# Domain Y: paintings.
Y_sets = ArtDataset(art_root_dir)
# Materialise every batch up front so the training loop can index batches
# by position (Y_loader[i]).
Y_loader = [
    batch
    for batch in td.DataLoader(
        Y_sets,
        batch_size=batchsize,
        shuffle=True,
        pin_memory=True,
    )
]

In [6]:
input_nc = 3    # the number of channels of input data
output_nc = 3   # the number of channels of output data

# Create the discriminator D_X -- distinguishes real images X from F(Y)
D_X = Discriminator(input_nc).to(device)
D_X.apply(weights_init)

# Create the discriminator D_Y -- distinguishes real images Y from G(X)
D_Y = Discriminator(output_nc).to(device)
# Fixed copy-paste slip: this line used to re-initialise D_X a second time,
# leaving D_Y with whatever default initialisation its constructor gives it.
D_Y.apply(weights_init)


unet = False  # if true, use Unet or DnCNN, if false, use encoder-resnet block-decoder
if unet:
    # The Unet / DnCNN Generator (see the commented-out imports in the first
    # cell) is constructed with a single argument.
    # Create the generator G -- maps images from domain X to domain Y
    G = Generator(6).to(device)
    # Create the generator F -- maps images from domain Y to domain X
    F = Generator(6).to(device)
else:
    G = Generator(input_nc, output_nc).to(device)
    F = Generator(input_nc, output_nc).to(device)
G.apply(weights_init)
# Left as the cell's last expression so the notebook still echoes the module.
F.apply(weights_init)

Generator(
  (conv): ModuleList(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (bn): ModuleList(
    (0): BatchNorm2d(3, eps=3, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm2d(3, eps=3, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm2d(3, eps=3, momentum=0.1, affine=True, track_running_stats=True)
  )
  (relu): ReLU(inplace)
)

In [7]:
def _set_requires_grad(net, flag):
    """Enable or disable gradient tracking for every parameter of ``net``."""
    for param in net.parameters():
        param.requires_grad = flag


def run(num_epoch, D_X, D_Y, G, F, X_loader, Y_loader,  lamda = 10,
        start_epoch = None, checkpoint_dir = 'output2'):
    """Train the two generators and two discriminators of the CycleGAN.

    Each iteration first updates G and F together (least-squares adversarial
    terms plus ``lamda`` times the cycle term returned by ``cal_loss_Cycle``),
    then updates D_X and D_Y with ``cal_loss_Gan`` on real images and on fake
    images drawn from an ``ImagePool`` history buffer.

    Args:
        num_epoch: Exclusive upper bound of the epoch counter.
        D_X: Discriminator for domain X (real X vs. F(Y)).
        D_Y: Discriminator for domain Y (real Y vs. G(X)).
        G: Generator mapping X -> Y.
        F: Generator mapping Y -> X.
        X_loader: Indexable sequence of domain-X batches (tensors).
        Y_loader: Indexable sequence of domain-Y batches (tensors).
        lamda: Weight of the cycle term in the generator loss.
        start_epoch: First epoch index. When ``None`` (default) the
            module-level ``start_epoch`` variable is used if it exists, which
            is how earlier versions of this function obtained it; otherwise 0.
        checkpoint_dir: Directory that receives ``G.pth``, ``F.pth``,
            ``D_X.pth`` and ``D_Y.pth`` after every epoch. Created on demand.

    Returns:
        ``(loss_Dx_plot, loss_Dy_plot, loss_cycle_plot)``: three lists of
        Python floats holding the last-batch losses of every epoch whose
        index is a multiple of 20.

    Raises:
        ValueError: If at least one epoch is requested but either loader
            contains no batches.
    """
    if start_epoch is None:
        # Backward compatibility with callers that set a notebook-level
        # `start_epoch` instead of passing it in.
        start_epoch = globals().get('start_epoch', 0)

    if start_epoch < num_epoch and min(len(X_loader), len(Y_loader)) == 0:
        raise ValueError("X_loader and Y_loader must each contain at least one batch")

    # Run on whatever device the models already live on rather than
    # hard-coding CUDA.
    device = next(G.parameters()).device

    # Least-squares adversarial criterion used for the generator update.
    criterion1 = nn.MSELoss()

    # Setup Adam optimizers for D_X, D_Y and (jointly) G, F
    optimizerD_X = optim.Adam(D_X.parameters(), lr = 1e-3, betas=(0.5, 0.999))
    optimizerD_Y = optim.Adam(D_Y.parameters(), lr = 1e-3, betas=(0.5, 0.999))
    optimizerGenerator = optim.Adam(itertools.chain(G.parameters(), F.parameters()), lr = 1e-3, betas=(0.5, 0.999))

    loss_Dx_plot = []
    loss_Dy_plot = []
    loss_cycle_plot = []

    for epoch in range(start_epoch, num_epoch):
        # The history buffers are recreated at the start of every epoch.
        fake_X_pool = ImagePool(50)
        fake_Y_pool = ImagePool(50)

        # zip() stops at the shorter loader, so unequal dataset sizes no
        # longer raise IndexError part-way through an epoch.
        for batch_X, batch_Y in zip(X_loader, Y_loader):
            real_X = batch_X.to(device)
            real_Y = batch_Y.to(device)

            # ---- Update G and F ----
            # Freeze the discriminators *before* their forward pass so that
            # autograd does not compute gradients for their weights here.
            _set_requires_grad(D_X, False)
            _set_requires_grad(D_Y, False)

            fake_X = F(real_Y)
            fake_Y = G(real_X)
            pred_fake_Y = D_Y(fake_Y)
            pred_fake_X = D_X(fake_X)

            optimizerGenerator.zero_grad()
            # Targets match the discriminator output shape exactly, avoiding
            # reliance on broadcasting inside MSELoss.
            gan_loss_X = criterion1(pred_fake_Y, torch.ones_like(pred_fake_Y))
            gan_loss_Y = criterion1(pred_fake_X, torch.ones_like(pred_fake_X))

            loss_cycle = cal_loss_Cycle(F, real_X, fake_Y) + cal_loss_Cycle(G, real_Y, fake_X)
            loss_G = lamda * loss_cycle + gan_loss_X + gan_loss_Y
            loss_G.backward()
            optimizerGenerator.step()

            # ---- Update D_X ----
            # Intended objective per the original notes:
            # minimize D_X(F(Y))**2 + (D_X(X) - 1)**2
            _set_requires_grad(D_X, True)
            D_X.zero_grad()
            # Detach so the pool stores plain images (no generator graph) and
            # the discriminator loss cannot back-propagate into F.
            pooled_fake_X = fake_X_pool.query(fake_X.detach())
            loss_D_X = cal_loss_Gan(D_X, real_X, pooled_fake_X)
            loss_D_X.backward()
            optimizerD_X.step()

            # ---- Update D_Y ----
            # Intended objective per the original notes:
            # minimize D_Y(G(X))**2 + (D_Y(Y) - 1)**2
            _set_requires_grad(D_Y, True)
            D_Y.zero_grad()
            pooled_fake_Y = fake_Y_pool.query(fake_Y.detach())
            loss_D_Y = cal_loss_Gan(D_Y, real_Y, pooled_fake_Y)
            loss_D_Y.backward()
            optimizerD_Y.step()

        # Report the losses of the last batch as plain floats so that no
        # autograd graph or device memory is kept alive by the history lists.
        dx_value = loss_D_X.item()
        dy_value = loss_D_Y.item()
        cycle_value = loss_cycle.item()
        total_value = dx_value + dy_value + cycle_value

        print("Epoch: {}/{}".format(epoch, num_epoch))
        print("Dx = {}, Dy = {}, cycle = {}, total loss = {}".format(dx_value, dy_value, cycle_value, total_value))
        if epoch % 20 == 0:
            plt.figure()
            # Preview G on the first X batch; no gradients are needed here.
            with torch.no_grad():
                preview = G(X_loader[0].to(device))[0]
            myimshow(preview)
            loss_Dx_plot.append(dx_value)
            loss_Dy_plot.append(dy_value)
            loss_cycle_plot.append(cycle_value)

        # Save model checkpoints (overwritten every epoch)
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(G.state_dict(), os.path.join(checkpoint_dir, 'G.pth'))
        torch.save(F.state_dict(), os.path.join(checkpoint_dir, 'F.pth'))
        torch.save(D_X.state_dict(), os.path.join(checkpoint_dir, 'D_X.pth'))
        torch.save(D_Y.state_dict(), os.path.join(checkpoint_dir, 'D_Y.pth'))
    return loss_Dx_plot,loss_Dy_plot,loss_cycle_plot

In [None]:
start_epoch=0   # set to a non-zero epoch index to resume from saved checkpoints
num_epoch = 50
if start_epoch != 0:
    # Resume: restore all four networks from disk. map_location lets
    # checkpoints written on a GPU machine load on a CPU-only one.
    # NOTE(review): run() writes its checkpoints under 'output2/' while this
    # cell reads from 'output/' -- confirm which directory is intended.
    G.load_state_dict(torch.load('output/G.pth', map_location=device))
    F.load_state_dict(torch.load('output/F.pth', map_location=device))
    D_X.load_state_dict(torch.load('output/D_X.pth', map_location=device))
    D_Y.load_state_dict(torch.load('output/D_Y.pth', map_location=device))
else:
    # Fresh run: (re-)initialise every network so re-executing this cell
    # always starts from newly initialised weights.
    G.apply(weights_init)
    F.apply(weights_init)
    D_X.apply(weights_init)
    D_Y.apply(weights_init)
loss_Dx_plot,loss_Dy_plot,loss_cycle_plot=run(num_epoch, D_X, D_Y, G, F, X_loader, Y_loader)