In [1]:
import torch
import torchvision
from torch import nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn import init

from settings import EPOCHS
from unet import UNET
from Discriminator import Discriminator
import torchvision.transforms as transforms
from torch.nn import BCELoss, L1Loss, MSELoss
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.image as mpimg
from torch.optim import lr_scheduler


def init_weights(net, init_type='normal', gain=0.02):
    """Re-initialise the Conv/Linear/BatchNorm2d submodules of ``net`` in place.

    Layers whose class name contains ``'Conv'`` or ``'Linear'`` and that have a
    ``weight`` attribute get that weight drawn according to ``init_type``.
    Their bias, if present, is set to 0. Layers whose class name contains
    ``'BatchNorm2d'`` get weight ~ N(1.0, gain) and bias = 0.

    Args:
        net: module whose submodules are visited via ``net.apply``.
        init_type: ``'normal'``, ``'xavier'``, ``'kaiming'`` or ``'orthogonal'``.
        gain: std for ``'normal'`` (and BatchNorm2d); gain for ``'xavier'`` and
            ``'orthogonal'``; ignored by ``'kaiming'``.

    Raises:
        NotImplementedError: if ``init_type`` is unknown and a Conv/Linear
            layer is reached.
    """
    weight_initialisers = {
        'normal': lambda w: init.normal_(w, 0.0, gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=gain),
    }

    def _init_layer(layer):
        layer_name = layer.__class__.__name__
        if hasattr(layer, 'weight') and ('Conv' in layer_name or 'Linear' in layer_name):
            chosen = weight_initialisers.get(init_type) if isinstance(init_type, str) else None
            if chosen is None:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            chosen(layer.weight.data)
            if hasattr(layer, 'bias') and layer.bias is not None:
                init.constant_(layer.bias.data, 0.0)
        elif 'BatchNorm2d' in layer_name:
            # Batch-norm scale starts near 1 so the layer is close to identity.
            init.normal_(layer.weight.data, 1.0, gain)
            init.constant_(layer.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(_init_layer)


def get_scheduler(optimizer):
    """Wrap ``optimizer`` in a ``LambdaLR`` that holds the LR, then decays it linearly.

    For scheduler epoch ``e`` the LR multiplier is
    ``1 - max(0, e + 1 - EPOCHS) / (EPOCHS + 1)``. It stays at 1.0 while
    ``e + 1 <= EPOCHS`` and then drops by ``1 / (EPOCHS + 1)`` per extra epoch.

    NOTE(review): trainPix2Pix steps the scheduler once per epoch, so a run of
    exactly ``EPOCHS`` epochs never leaves the constant phase. Confirm the
    intended hold/decay lengths.
    """
    def lambda_rule(epoch):
        epochs_past_hold = max(0, epoch + 1 - EPOCHS)
        return 1.0 - epochs_past_hold / float(EPOCHS + 1)

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    return scheduler

def update_learning_rate(scheduler, optimizer):
    """Advance ``scheduler`` one step and print the resulting learning rate.

    Only the first parameter group's ``lr`` is reported.
    """
    scheduler.step()
    first_group = optimizer.param_groups[0]
    print('learning rate = {:.7f}'.format(first_group['lr']))

def trainPix2Pix(model, data, totalEpochs=EPOCHS, genLr=0.0001, descLr=0.00005, max_minibatches=12):
    """Train ``model.gen`` and ``model.disc`` with Adam, stepping LR schedulers each epoch.

    Args:
        model: object exposing ``gen`` and ``disc`` modules plus the
            ``generate``/``discriminate`` methods used by ``train_step``.
        data: iterable of ``(color_and_gray, gray_three_channel)`` tensor pairs
            (e.g. a DataLoader). Both tensors are moved to CUDA.
        totalEpochs: number of passes over ``data``.
        genLr: Adam learning rate for the generator. It was previously ignored
            in favour of a hard-coded 0.0001; the default is unchanged.
        descLr: Adam learning rate for the discriminator. It was previously
            ignored in favour of a hard-coded 0.00005; the default is unchanged.
        max_minibatches: cap on minibatches per epoch. ``None`` uses the whole
            epoch. The default of 12 reproduces the old hard-coded
            ``if minibatch > 10: break``.
    """
    from itertools import islice

    genOptimizer = Adam(model.gen.parameters(), lr=genLr)
    discOptimizer = Adam(model.disc.parameters(), lr=descLr)

    net_g_scheduler = get_scheduler(genOptimizer)
    net_d_scheduler = get_scheduler(discOptimizer)

    criterion = MSELoss().cuda()
    criterionL1 = L1Loss().cuda()

    model.gen.train()
    model.disc.train()
    for epoch in range(totalEpochs):
        print("Epoch " + str(epoch))
        # islice stops after max_minibatches items without fetching an extra batch.
        for color_and_gray, gray_three_channel in islice(data, max_minibatches):
            train_step(model, color_and_gray.cuda(), gray_three_channel.cuda(),
                       criterion, genOptimizer, discOptimizer, criterionL1)

        update_learning_rate(net_g_scheduler, genOptimizer)
        update_learning_rate(net_d_scheduler, discOptimizer)

def train_step(model, color, black_white, criterion, gen_optimizer, disc_optimizer, criterion_l1, lambda_l1=1.0):
    """Run one conditional-GAN update: a discriminator step, then a generator step.

    Assumes the minibatch contains only color images (original author's note).

    Args:
        model: object with ``generate(x)`` and ``discriminate(x)``.
        color: target images. They are concatenated with ``black_white`` along
            dim 1 to form the "real" discriminator input.
        black_white: generator input and conditioning images.
        criterion: adversarial loss comparing discriminator outputs with 0/1
            labels.
        gen_optimizer: optimizer over the generator's parameters.
        disc_optimizer: optimizer over the discriminator's parameters.
        criterion_l1: reconstruction loss between the generated images and
            ``color``.
        lambda_l1: weight on the reconstruction term. The default of 1.0 keeps
            the previous weighting.
    """
    generated = model.generate(black_white)

    # ---- Discriminator update ----
    disc_optimizer.zero_grad()

    input_output = torch.cat((black_white, generated), 1)
    input_target = torch.cat((black_white, color), 1)

    # Fake pairs: detach so the discriminator loss does not flow into the generator.
    pred_generated = model.discriminate(input_output.detach())
    loss_false = criterion(pred_generated, torch.zeros_like(pred_generated))

    # Real pairs.
    pred_targets = model.discriminate(input_target)
    loss_true = criterion(pred_targets, torch.ones_like(pred_targets))

    print("Loss false: {} Loss true {}: ".format(loss_false, loss_true))
    loss_discriminator = (loss_false + loss_true) / 2
    loss_discriminator.backward()
    disc_optimizer.step()

    # ---- Generator update ----
    gen_optimizer.zero_grad()
    pred_output = model.discriminate(input_output)

    # The generator is rewarded when the discriminator labels its output as real.
    loss_gan = criterion(pred_output, torch.ones_like(pred_output))
    # Reconstruction term G(A) ~ B. It previously compared black_white with
    # color, which is constant w.r.t. the generator and so contributed no gradient.
    loss_ab = criterion_l1(generated, color) * lambda_l1

    loss_gen = (loss_gan + loss_ab) / 2
    loss_gen.backward()
    gen_optimizer.step()




class pix2pix(nn.Module):
    """Container pairing the pix2pix generator and discriminator, plus TensorBoard logging.

    Attributes:
        gen: ``UNET`` built with ``numclasses=3`` (RGB) and ``numchannels=64``.
        disc: ``Discriminator()`` instance.
        writer: ``SummaryWriter`` logging to ``runs/pix2pix``.
        trainData: list of every loss passed to ``log_metrics``.
    """

    def __init__(self):
        super().__init__()
        numclasses = 3  # RGB
        numchannels = 64
        self.gen = UNET(numclasses, numchannels)
        self.disc = Discriminator()
        self.writer = SummaryWriter('runs/pix2pix')
        # log_metrics appends here. It was previously never initialised, so
        # the first log_metrics call raised AttributeError.
        self.trainData = []

    def log_image(self, images, tag='four_fashion_mnist_images', step=None):
        """Write ``images`` to TensorBoard as a single ``make_grid`` image.

        NOTE(review): the default tag looks like a leftover from a tutorial.
        Pass ``tag`` (and ``step`` to keep a history) for meaningful logs.
        """
        img_grid = torchvision.utils.make_grid(images)
        self.writer.add_image(tag, img_grid, step)

    def log_metrics(self, epoch, loss):
        """Log ``loss`` as the 'training loss' scalar at step ``epoch`` and keep it in ``trainData``."""
        self.writer.add_scalar('training loss', loss, epoch)
        self.trainData.append(loss)

    def generate(self, greyscale):
        """Run the generator on ``greyscale`` and return its output."""
        # TODO: add dropout (original author's note).
        return self.gen(greyscale)

    def discriminate(self, img):
        """Score ``img`` with the discriminator, averaged over the spatial dims.

        The discriminator output is assumed to be (images, features, height,
        width), per the original note. The result is (images, features), i.e.
        one value per image when the discriminator emits a single feature map.
        """
        scores = self.disc(img)
        return scores.mean(dim=(2, 3))






In [2]:
def TenToPic(image):
    """Convert a channel-first (C, H, W) tensor into an (H, W, C) integer numpy array for ``plt.imshow``.

    The tensor is detached, moved to CPU, cast to the default float dtype,
    then truncated toward zero with ``astype(int)``.

    NOTE(review): because of the truncation, images scaled to [0, 1] or
    [-1, 1] mostly become 0. Rescale before calling if that is not intended.
    """
    as_float = image.detach().to(device='cpu', dtype=torch.get_default_dtype())
    channels_last = as_float.permute(1, 2, 0)
    return channels_last.numpy().astype(int)


# In[11]:


from utils import get_datasets
# Load the datasets and cache the first training batch as `ex` for the visual
# checks in later cells. `ex` stays None if the training set is empty.
# NOTE(review): the saved output of this cell is a DataLoader TypeError raised
# by the Normalize call in dataloader.py; that needs fixing there, not here.
train_dataset, test_dataset = get_datasets()
ex = next(iter(train_dataset), None)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/cs253wi20an/PA5/dataloader.py", line 73, in __getitem__
    color_and_gray = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(color_and_gray)
  File "/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 175, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/opt/conda/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 209, in normalize
    raise TypeError('tensor is not a torch image.')
TypeError: tensor is not a torch image.


In [None]:
# Build the model on the GPU, then show sample 0 of each tensor in the cached
# batch `ex` (ex[0] first, then ex[1]).
model = pix2pix().cuda()
for preview_batch in (ex[0], ex[1]):
    plt.figure()
    plt.imshow(TenToPic(preview_batch[0, :, :, :]))

model.cuda()

In [None]:
# Untrained-generator sanity check on CPU: show the scaled output of sample 0
# as an image, and its first channel on its own. Then move the model back to GPU.
cpuModel = model.cpu()
out = model.generate(ex[1]) * 255
first_sample = out[0, :, :, :]
first_channel = out.detach().numpy()[0, 0, :, :]
plt.imshow(TenToPic(first_sample))
plt.figure()
plt.imshow(first_channel)
model.cuda()

In [None]:
# Train for 10 epochs on the training DataLoader (other settings use trainPix2Pix's defaults).
trainPix2Pix(model, data=train_dataset, totalEpochs=10)

In [None]:
# Post-training visual check on CPU. Plot, for sample 0 of the cached batch
# `ex`: the grayscale input, the generated image, each generated channel, the
# color target, and the channel-0 minus channel-1 difference.
cpuModel = model.cpu()
with torch.no_grad():
    # ex[1] is a DataLoader batch (N, C, H, W). Index sample 0 explicitly:
    # the old `.squeeze(0)` only dropped the batch dim when N == 1, and the
    # later indexing assumed the opposite, so the cell always crashed.
    gray_sample = ex[1][0]
    out = model.generate(ex[1])[0]

plt.figure()
plt.imshow(TenToPic(gray_sample))
plt.figure()
plt.imshow(TenToPic(out))
for channel in range(out.shape[0]):
    plt.figure()
    plt.imshow(out.numpy()[channel, :, :])
    plt.colorbar()

model.cuda()
plt.figure()
plt.imshow(TenToPic(ex[0][0, :, :, :]))

import pytorch_ssim
print("-----------")
# TODO: report SSIM between the generated and target images, e.g.
# pytorch_ssim.ssim(out.unsqueeze(0), ex[0][:1]).

diff = out[0, :, :] - out[1, :, :]
plt.figure()
plt.imshow(diff.numpy())