In [1]:
# This code is provided for Deep Learning (CS 482/682) Homework 6 practice.
# The network structure is a simplified U-net. You need to finish the last layers
# @Copyright Cong Gao, the Johns Hopkins University, cgao11@jhu.edu
# Modified by Hongtao Wu on Oct 11, 2019 for Fall 2019 Machine Learning: Deep Learning HW6

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms 
import torchvision.transforms.functional as TF
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
# Functions for adding the convolution layer
def add_conv_stage(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=True, useBN=False):
    if useBN:
        # Use batch normalization
        return nn.Sequential(
          nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
          nn.BatchNorm2d(dim_out),
          nn.LeakyReLU(0.1),
          nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
          nn.BatchNorm2d(dim_out),
          nn.LeakyReLU(0.1)
        )
    else:
        # No batch normalization
        return nn.Sequential(
          nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
          nn.ReLU(),
          nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
          nn.ReLU()
        )


# Upsampling
def upsample(ch_coarse, ch_fine):
    return nn.Sequential(
        nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False),
        nn.ReLU()
        )


# U-Net
class unet(nn.Module):
    def __init__(self, useBN=False):
        super(unet, self).__init__()
        # Downgrade stages
        self.conv1 = add_conv_stage(1, 32, useBN=useBN)
        self.conv2 = add_conv_stage(32, 64, useBN=useBN)
        self.conv3 = add_conv_stage(64, 128, useBN=useBN)
        self.conv4 = add_conv_stage(128, 256, useBN=useBN)
        # Upgrade stages
        self.conv3m = add_conv_stage(256, 128, useBN=useBN)
        self.conv2m = add_conv_stage(128,  64, useBN=useBN)
        self.conv1m = add_conv_stage( 64,  32, useBN=useBN)
        # Maxpool
        self.max_pool = nn.MaxPool2d(2)
        # Upsample layers
        self.upsample43 = upsample(256, 128)
        self.upsample32 = upsample(128,  64)
        self.upsample21 = upsample(64 ,  32)
        # weight initialization
        # You can have your own weight intialization. This is just an example.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                if m.bias is not None:
                    m.bias.data.zero_()

        #TODO: Design your last layer & activations
#         self.convLast = add_conv_stage(32, 8, useBN=useBN, kernel_size=1, padding=0)
        self.convLast = nn.Conv2d(32, 3, kernel_size=1)

    def forward(self, x):
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(self.max_pool(conv1_out))
        conv3_out = self.conv3(self.max_pool(conv2_out))
        conv4_out = self.conv4(self.max_pool(conv3_out))

        conv4m_out_ = torch.cat((self.upsample43(conv4_out), conv3_out), 1)
        conv3m_out  = self.conv3m(conv4m_out_)

        conv3m_out_ = torch.cat((self.upsample32(conv3m_out), conv2_out), 1)
        conv2m_out  = self.conv2m(conv3m_out_)

        conv2m_out_ = torch.cat((self.upsample21(conv2m_out), conv1_out), 1)
        conv1m_out  = self.conv1m(conv2m_out_)

        #TODO: Design your last layer & activations
        convfinal_out = self.convLast(conv1m_out)

        return torch.softmax(convfinal_out, dim=1)

In [None]:
######################## Hyperparameters #################################
# Batch size can be changed if it does not match your memory, please state your batch step_size
# in your report.
train_batch_size = 10
validation_batch_size=10
# Please use this learning rate for Prob1(a) and Prob1(b)
learning_rate = 0.1
# This num_epochs is designed for running to be long enough, you need to manually stop or design
# your early stopping method.
num_epochs = 20

# TODO: Design your own dataset
class ImageDataset(Dataset):
    def __init__(self, input_dir):
        datalist = []
        labellist = []
        onehotlabellist = []
        for folders in os.listdir(input_dir):
            new_dir = os.path.join(input_dir, folders)
            if os.path.isdir(new_dir):
                impath = new_dir+ '/'+folders+'_gray.jpg'
                labelpath = new_dir+ '/'+folders+'_input.jpg'
                im = Image.open(impath)
                imLab = Image.open(labelpath)
                im = np.array(im)
                im = np.expand_dims(im, axis = 2)
                datalist.append(im)
                labellist.append(np.array(imLab))
                
        self.data = torch.from_numpy(np.transpose(np.array(datalist), (0, 3, 1, 2))).type('torch.FloatTensor')
        self.label = torch.from_numpy(np.transpose(np.array(labellist), (0, 3, 1, 2))).type('torch.FloatTensor')
        
        
    def mask(self, label, val):
        return label & val == val
        
    def __len__ (self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]


loss = nn.MSELoss()

# TODO: Use your designed dataset for dataloading
train_dataset=ImageDataset(input_dir = "./HW6_data/colorization/train_cor/")
validation_dataset=ImageDataset(input_dir =  "./HW6_data/colorization/validation_cor/")
model = unet(useBN = True)
model.cuda()
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                       batch_size=train_batch_size, 
                                       shuffle=True)
    
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, 
                                           batch_size=validation_batch_size, 
                                           shuffle=True)

training_loss = []
validation_loss = []
training_iter = iter(train_loader)
validation_iter = iter(validation_loader)
print("Start Training...")
for epoch in range(num_epochs):

    ########################### Training #####################################
    print("\nEPOCH " +str(epoch+1)+" of "+str(num_epochs)+"\n")
    # TODO: Design your own training section
    model.train()
    torch.set_grad_enabled(True)
    for ind, training_data in enumerate(train_loader):
        x = autograd.Variable(training_data[0]).cuda()
        y = autograd.Variable(training_data[1]).cuda()

        optimizer.zero_grad()
        y_hat = model(x)
        train_loss = loss(y_hat, y)
        train_loss.backward()
        optimizer.step()

        training_loss.append(train_loss.data)
    print(train_loss.data)

    ########################### Validation #####################################
    # TODO: Design your own validation section
    print("\nStart Validation...")
    torch.set_grad_enabled(False)
    for ind, valid_data in enumerate(validation_loader):
        x = autograd.Variable(valid_data[0]).cuda()
        y = autograd.Variable(valid_data[1]).cuda()
        y_hat = model(x)
        val_loss = loss(y_hat, y)
        print(val_loss.data)
        validation_loss.append(val_loss.data)

Start Training...

EPOCH 1 of 20

tensor(12025.0566, device='cuda:0')

Start Validation...
tensor(10672.5254, device='cuda:0')
tensor(11425.9062, device='cuda:0')
tensor(11024.6084, device='cuda:0')
tensor(10872.1689, device='cuda:0')
tensor(10790.8223, device='cuda:0')

EPOCH 2 of 20

tensor(10794.3252, device='cuda:0')

Start Validation...
tensor(10774.1621, device='cuda:0')
tensor(10810.3398, device='cuda:0')
tensor(11061.5449, device='cuda:0')
tensor(10747.0918, device='cuda:0')
tensor(11392.8936, device='cuda:0')

EPOCH 3 of 20



In [None]:
# Plot training loss, validation loss
def plot(training_loss, validation_loss):
    plt.figure(300)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')   
    plt.plot(training_loss, 'b')
    plt.figure(400)
    plt.title('Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')   
    plt.plot(validation_loss, 'b')
    plt.show()

In [None]:
plot(training_loss, validation_loss)