In [6]:
import torch
import torchvision
from torch import nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

from settings import EPOCHS
from unet import UNET
from Discriminator import Discriminator
import torchvision.transforms as transforms
from torch.nn import BCELoss, L1Loss, MSELoss
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.image as mpimg



def trainPix2Pix(model, data, totalEpochs=EPOCHS, genLr=0.0001, descLr=0.00005):
    """Train the generator and discriminator of ``model`` on ``data``.

    Args:
        model: object exposing ``gen`` and ``disc`` sub-modules (e.g. ``pix2pix``).
        data: iterable yielding ``(color_and_gray, gray_three_channel)`` tensor
            pairs; it is iterated once per epoch.
        totalEpochs: number of passes over ``data``.
        genLr: Adam learning rate for the generator.
        descLr: Adam learning rate for the discriminator.
    """
    genOptimizer = Adam(model.gen.parameters(), lr=genLr)
    discOptimizer = Adam(model.disc.parameters(), lr=descLr)

    # Follow whatever device the generator already lives on instead of
    # hard-coding CUDA, so the same loop also runs on a CPU-only machine.
    device = next(model.gen.parameters()).device
    criterion = MSELoss().to(device)
    criterionL1 = L1Loss().to(device)

    model.gen.train()
    model.disc.train()
    for _ in range(totalEpochs):
        for color_and_gray, gray_three_channel in data:
            train_step(
                model,
                color_and_gray.to(device),
                gray_three_channel.to(device),
                criterion,
                genOptimizer,
                discOptimizer,
                criterionL1,
            )


# Assumes the minibatch contains only colored images.
def train_step(model, color, black_white, criterion, gen_optimizer, disc_optimizer, criterion_l1):
    """Run one discriminator update followed by one generator update.

    Args:
        model: object exposing ``gen``/``disc`` modules plus ``generate`` and
            ``discriminate`` methods (e.g. ``pix2pix``).
        color: target image batch.
        black_white: generator input batch.
        criterion: adversarial loss applied to discriminator predictions.
        gen_optimizer: optimizer over the generator parameters.
        disc_optimizer: optimizer over the discriminator parameters.
        criterion_l1: reconstruction loss between generated and target images.
    """
    model.gen.zero_grad()
    model.disc.zero_grad()

    # generate images
    generated = model.generate(black_white)

    # Pair the input with the generated / target image along dim 3.
    # NOTE(review): assumes tensors are (batch, channels, height, width), so
    # dim 3 places the images side by side; a channel-wise pairing would be
    # dim 1 and needs a discriminator accepting the summed channels — confirm.
    input_output = torch.cat((black_white, generated), 3)
    input_target = torch.cat((black_white, color), 3)

    # ---- discriminator update ----
    # Detach so this backward pass does not reach the generator.
    pred_generated = model.discriminate(input_output.detach())
    loss_false = criterion(pred_generated, torch.zeros_like(pred_generated))

    pred_targets = model.discriminate(input_target)
    loss_true = criterion(pred_targets, torch.ones_like(pred_targets))

    loss_discriminator = (loss_false + loss_true) / 2
    loss_discriminator.backward()
    disc_optimizer.step()

    # ---- generator update ----
    # The generator is rewarded when the discriminator labels its output real.
    pred_output = model.discriminate(input_output)
    loss_adversarial = criterion(pred_output, torch.ones_like(pred_output))

    # G(A) = B: the L1 term must compare the *generated* image with the target.
    # Comparing the raw input with the target is constant w.r.t. the generator
    # and therefore contributes no gradient.
    loss_ab = criterion_l1(generated, color) * 10  # weight, L1 term

    # NOTE(review): only the L1 term is halved here (operator precedence);
    # kept as in the original — confirm whether (adv + l1) / 2 was intended.
    loss_gen = loss_adversarial + loss_ab / 2

    loss_gen.backward()
    gen_optimizer.step()




class pix2pix(nn.Module):
    """Bundle of a UNET generator, a Discriminator and a TensorBoard writer."""

    def __init__(self):
        super(pix2pix, self).__init__()
        numclasses = 3 #RGB
        numchannels = 64
        self.gen = UNET(numclasses, numchannels)
        self.disc = Discriminator()
        self.writer = SummaryWriter('runs/pix2pix')
        # Loss history filled by log_metrics; it must exist before the first
        # call, otherwise the append there raises AttributeError.
        self.trainData = []

    def log_image(self, images):
        """Write ``images`` to TensorBoard as a single image grid."""
        img_grid = torchvision.utils.make_grid(images)
        # NOTE(review): the tag looks copied from a Fashion-MNIST example; it
        # is kept unchanged so existing TensorBoard runs stay comparable.
        self.writer.add_image('four_fashion_mnist_images', img_grid)

    def log_metrics(self, epoch, loss):
        """Log ``loss`` to TensorBoard at step ``epoch`` and record it in ``trainData``."""
        self.writer.add_scalar('training loss', loss, epoch)
        self.trainData.append(loss)

    def generate(self, greyscale):
        """Return the generator output for ``greyscale``."""
        # TODO: need to add dropout
        return self.gen(greyscale)

    def discriminate(self, img):
        """Return the discriminator output for ``img``.

        NOTE(review): an earlier comment described the input as
        (images, features, height, width) and the result as one averaged value
        for all images; this method returns ``self.disc(img)`` unchanged, so
        any averaging would have to happen inside ``Discriminator`` — confirm.
        """
        return self.disc(img)







In [None]:
def TenToPic(image):
    """Convert a 3-D tensor laid out as (dim0, dim1, dim2) into an integer
    numpy array laid out as (dim1, dim2, dim0).

    The values are copied into a freshly allocated zero tensor and then cast
    with ``astype(int)``, which drops any fractional part.
    """
    depth, rows, cols = image.size()
    channels_last = torch.zeros(rows, cols, depth)
    # One bulk copy with the leading axis moved to the end, instead of
    # filling the buffer one slice at a time.
    channels_last.copy_(image.permute(1, 2, 0))
    as_numpy = channels_last.detach().numpy()
    return as_numpy.astype(int)


# In[11]:


from utils import get_datasets
train_dataset, test_dataset = get_datasets()

# Keep only the first item produced by the training data as a fixed example
# for the visual checks below; stays None when the data yields nothing.
ex = next(iter(train_dataset), None)

In [None]:
# Build the model directly on the GPU; .cuda() returns the module itself, so
# no second model.cuda() call is needed afterwards.
model = pix2pix().cuda()

# Preview the first sample of both tensors in the example batch. In
# train_step, ex[0] plays the role of the target and ex[1] the generator input.
plt.figure()
plt.title("ex[0]: training target")
plt.imshow(TenToPic(ex[0][0, :, :, :]))
plt.figure()
plt.title("ex[1]: generator input")
plt.imshow(TenToPic(ex[1][0, :, :, :]));  # trailing ';' hides the AxesImage repr

In [None]:
# Run the generator on the CPU so the result can be handed to numpy for
# plotting. nn.Module.cpu() moves the module in place and returns it, so
# cpuModel is the same object as model.
cpuModel = model.cpu()
with torch.no_grad():  # plotting only; no autograd graph is needed
    out = cpuModel.generate(ex[1])

plt.imshow(TenToPic(out[0, :, :, :]))
plt.figure()
plt.imshow(out.detach().numpy()[0, 0, :, :])

# Move the model back to the GPU for training; ';' suppresses the module repr.
model.cuda();

In [5]:
# Train for 10 epochs, overriding trainPix2Pix's default learning rates
# (0.0001 / 0.00005) with values 100x larger.
trainPix2Pix(
    model,
    train_dataset,
    totalEpochs=10,
    genLr=0.01,
    descLr=0.005,
)

RuntimeError: Given groups=1, weight of size 64 3 4 4, expected input[1, 6, 224, 224] to have 3 channels, but got 6 channels instead

In [None]:
# Re-run the generator on the CPU and inspect its output. nn.Module.cpu()
# moves the module in place and returns it, so cpuModel is the same object
# as model.
cpuModel = model.cpu()
with torch.no_grad():  # plotting only; no autograd graph is needed
    out = cpuModel.generate(ex[1])

# NOTE(review): the +20 offset is kept from the original; presumably it lifts
# the values into a visible range before the integer cast — confirm.
plt.imshow(TenToPic(out[0, :, :, :] + 20))

# One figure with a colorbar per output channel of the first sample.
out_np = out.detach().numpy()
for channel in range(out_np.shape[1]):
    plt.figure()
    plt.imshow(out_np[0, channel, :, :])
    plt.colorbar()

# Move the model back to the GPU for further training.
model.cuda()

# Reference image from the example batch for side-by-side comparison.
plt.figure()
plt.imshow(TenToPic(ex[0][0, :, :, :]));  # trailing ';' hides the AxesImage repr