In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
import skimage
import skimage.io
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, Subset

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid, save_image

# Report library versions and whether CUDA can be used; `use_gpu` drives the
# device choice in the settings cell.
use_gpu = torch.cuda.is_available()
for label, value in (('PyTorch version:', torch.__version__),
                     ('torchvision version:', torchvision.__version__),
                     ('Is GPU available:', use_gpu)):
    print(label, value)

PyTorch version: 0.4.1
torchvision version: 0.2.1
Is GPU available: True


In [2]:
# settings

# device: GPU when CUDA is available (`use_gpu`), otherwise CPU
device = torch.device('cuda' if use_gpu else 'cpu')

# batchsize (same as AVID paper)
batchsize = 16

# Seed every RNG the notebook may use (numpy, torch CPU, torch CUDA).
# Warning: cuDNN can still select nondeterministic kernels, so GPU runs are not
# guaranteed to be bit-for-bit reproducible.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)
if use_gpu:
    torch.cuda.manual_seed(seed)

# directory settings
# Data directory (for IR-MNIST). Every directory keeps its trailing '/' so a file
# name can be appended with plain string concatenation.
data_dir = '../../data/IR-MNIST/'
train_data_dir = data_dir + 'Train_Samples/'
test_data_dir = data_dir + 'Test_Samples/'

# directory to put generated images
output_dir = data_dir + 'output/'
# makedirs(exist_ok=True) also creates missing parents and, unlike an
# exists()/mkdir() pair, does not fail if the directory appears in between.
os.makedirs(output_dir, exist_ok = True)

# directory to save state_dict
save_dir = data_dir + 'save/'
os.makedirs(save_dir, exist_ok = True)

In [3]:
# make dataset class for image loading
class MyDataset(Dataset):
    """Dataset with one item per entry of ``root_dir``, read with ``skimage.io.imread``.

    Args:
        root_dir: directory whose entries are all image files. Entries are not
            filtered, so a non-image file would fail in ``__getitem__``.
        transform: optional callable applied to the array returned by ``imread``.
    """
    def __init__(self, root_dir, transform = None):
        self.root_dir = root_dir
        # os.listdir order is arbitrary and filesystem-dependent; sorting makes
        # item indices (and therefore a seeded train/validation split) reproducible.
        self.list_dir = sorted(os.listdir(root_dir))
        self.transform = transform

    def __len__(self):
        return len(self.list_dir)

    def __getitem__(self, idx):
        # os.path.join works whether or not root_dir ends with a path separator
        img_name = os.path.join(self.root_dir, self.list_dir[idx])
        image = skimage.io.imread(img_name)

        if self.transform is not None:
            image = self.transform(image)

        return image

In [4]:
# define transform
# Normalize pixel values from [0, 255] to [-1, 1]
class Normalize:
    """Rescale pixel values from [0, 255] to [-1, 1] as ``(image - 127.5) / 127.5``.

    Applied to the numpy array from ``imread``; a uint8 array is promoted to
    float64 by the subtraction.
    """
    def __call__(self, image):
        midpoint = 127.5
        shifted = image - midpoint
        return shifted / midpoint


class Tofloat:
    """Cast a tensor to float32 with ``Tensor.float()``."""
    def __call__(self, tensor):
        as_float32 = tensor.float()
        return as_float32
    
# image array -> [-1, 1] float array -> (C, H, W) tensor -> float32 tensor
# (assumes imread yields an (H, W, C) array — TODO confirm for every input file)
tf = transforms.Compose([
    Normalize(),
    transforms.ToTensor(),  # input is no longer uint8, so ToTensor only reorders axes (no /255)
    Tofloat(),              # from_numpy keeps the float64 dtype; cast to float32
])

In [5]:
# make dataset
imgDataset = MyDataset(train_data_dir, transform = tf)

# split to train data and validation data
# Split *indices* instead of the Dataset itself: train_test_split indexes a
# list-like input element by element, which would decode every image and keep
# all the resulting tensors in memory. Subset loads images lazily instead.
# The chosen indices depend only on the sample count and random_state, so this
# yields the same split as splitting the dataset directly.
train_indices, validation_indices = train_test_split(
    list(range(len(imgDataset))), test_size = 0.2, random_state = seed)
train_data = Subset(imgDataset, train_indices)
validation_data = Subset(imgDataset, validation_indices)

print('The number of training data:', len(train_data))
print('The number of validation data:', len(validation_data))

The number of training data: 4000
The number of validation data: 1000


In [6]:
# make DataLoader
# Training batches are reshuffled each epoch; validation order stays fixed so
# evaluations are comparable across epochs.
train_loader = DataLoader(dataset=train_data, batch_size=batchsize, shuffle=True)
validation_loader = DataLoader(dataset=validation_data, batch_size=batchsize, shuffle=False)

In [7]:
# Preview one training image. Disabled by default so the notebook also runs on
# machines without a GUI; set to True to display it.
SHOW_SAMPLE_IMAGE = False

if SHOW_SAMPLE_IMAGE:
    # builtin next() works across PyTorch versions; the iterator's .next()
    # method was removed in newer releases
    sample = next(iter(train_loader))[0].numpy()           # first image of a batch, (C, H, W)
    sample = np.transpose(sample, [1, 2, 0]) * 0.5 + 0.5   # -> (H, W, C), [-1, 1] -> [0, 1]
    print(sample.shape)
    plt.imshow(sample)

'\nX = iter(train_loader).next()[0].numpy()\nX = np.transpose(X, [1, 2, 0]) * 0.5 + 0.5\nprint(X.shape)\nplt.imshow(X)\n'

In [8]:
# define parts for U-net and FCN for convenience
# downsampling
# conv > batchnorm > dropout > leakyrelu
class Downsample(nn.Module):
    """Encoder stage: Conv2d -> (BatchNorm2d) -> (Dropout) -> LeakyReLU.

    With the defaults (kernel 4, stride 2, padding 1) the output height/width is
    floor(input / 2).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        use_batchnorm: apply ``BatchNorm2d`` after the convolution.
        use_dropout: apply ``Dropout`` after the (optional) batch norm.
        dropout_p: dropout probability (default 0.5, the former hard-coded value).
        negative_slope: LeakyReLU negative slope (default 0.2, the former hard-coded value).
    """
    def __init__(self, in_channels, out_channels, kernel_size = 4 , stride = 2, padding = 1, \
                 use_batchnorm = True, use_dropout = False, dropout_p = 0.5, negative_slope = 0.2):
        super(Downsample, self).__init__()
        self.cv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        # bn and dr are always constructed, so the state_dict keys do not depend on
        # the flags; disabled parts are simply skipped in forward().
        self.bn = nn.BatchNorm2d(out_channels)
        self.dr = nn.Dropout(dropout_p)
        self.rl = nn.LeakyReLU(negative_slope)

        self.use_batchnorm = use_batchnorm
        self.use_dropout = use_dropout

    def forward(self, x):
        out = self.cv(x)

        if self.use_batchnorm:
            out = self.bn(out)

        if self.use_dropout:
            out = self.dr(out)

        out = self.rl(out)

        return out

In [9]:
# define parts for U-net for convenience
# upsampling (using transposed convolution)
# transposed conv > batchnorm > dropout > relu
class Upsample(nn.Module):
    """Decoder stage: ConvTranspose2d -> (BatchNorm2d) -> (Dropout) -> ReLU.

    With the defaults (kernel 4, stride 2, padding 1) the output height/width is
    (input - 1) * 2 - 2 + 4 = 2 * input.

    Args:
        in_channels: channels of the input feature map (including any skip channels
            the caller concatenated).
        out_channels: channels produced by the transposed convolution.
        kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
        use_batchnorm: apply ``BatchNorm2d`` after the transposed convolution.
        use_dropout: apply ``Dropout`` after the (optional) batch norm.
        dropout_p: dropout probability (default 0.5, the former hard-coded value).
    """
    def __init__(self, in_channels, out_channels, kernel_size = 4, stride = 2, padding = 1, \
                 use_batchnorm = True, use_dropout = False, dropout_p = 0.5):
        super(Upsample, self).__init__()
        self.tc = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        # bn and dr are always constructed, so the state_dict keys do not depend on
        # the flags; disabled parts are simply skipped in forward().
        self.bn = nn.BatchNorm2d(out_channels)
        self.dr = nn.Dropout(dropout_p)
        self.rl = nn.ReLU()

        self.use_batchnorm = use_batchnorm
        self.use_dropout = use_dropout

    def forward(self, x):
        out = self.tc(x)

        if self.use_batchnorm:
            out = self.bn(out)

        if self.use_dropout:
            out = self.dr(out)

        out = self.rl(out)

        return out

In [10]:
''' Memo : CNN size equation (no dilation)

                                OUT = floor((IN + 2*Padding - Kernel_size) / Stride) + 1
'''
# define Inpainter (Generator for GAN)
# U-net architecture
class Inpainter(nn.Module):
    """U-net generator (the GAN's inpainter).

    Five ``Downsample`` stages halve the spatial size (3 -> 32 -> 64 -> 128 -> 256 -> 512
    channels). Five ``Upsample`` stages double it again; every decoder after the first
    receives the previous decoder output concatenated with the encoder output of the
    same resolution. A 1x1 convolution maps 64 channels to 3, and ``tanh`` bounds the
    result to (-1, 1), the same range the input transform produces.

    Input height/width must be divisible by 32 (2**5) so every skip connection has
    a matching spatial size; the shape comments below assume 224x224 inputs.
    """
    def __init__(self):
        super(Inpainter, self).__init__()

        # U-net encoder
        # default: kernel_size = 4, stride = 2, padding = 1, using batchnorm, no dropout
        self.encoder1 = Downsample(  3,  32, use_batchnorm = False)   # out: (batchsize,  32, 112, 112)
        self.encoder2 = Downsample( 32,  64)                          # out: (batchsize,  64,  56,  56)
        self.encoder3 = Downsample( 64, 128)                          # out: (batchsize, 128,  28,  28)
        self.encoder4 = Downsample(128, 256)                          # out: (batchsize, 256,  14,  14)
        self.encoder5 = Downsample(256, 512)                          # out: (batchsize, 512,   7,   7)

        # U-net decoder: in_channels = previous decoder channels + skip channels
        # default: kernel_size = 4, stride = 2, padding = 1, using batchnorm, no dropout
        self.decoder1 = Upsample(512    , 512)                        # out: (batchsize, 512,  14,  14)
        self.decoder2 = Upsample(512+256, 512)                        # out: (batchsize, 512,  28,  28)
        self.decoder3 = Upsample(512+128, 256)                        # out: (batchsize, 256,  56,  56)
        self.decoder4 = Upsample(256+ 64, 128)                        # out: (batchsize, 128, 112, 112)
        self.decoder5 = Upsample(128+ 32,  64)                        # out: (batchsize,  64, 224, 224)

        # pointwise convolution: 64 -> 3 channels, spatial size unchanged
        self.decoder_final = nn.Conv2d(64, 3, kernel_size = 1, stride = 1, padding = 0)
        self.th = nn.Tanh()

    def forward(self, x):
        # encoding part: keep every stage's output for the skip connections
        skips = []
        hidden = x
        for stage in (self.encoder1, self.encoder2, self.encoder3, self.encoder4, self.encoder5):
            hidden = stage(hidden)
            skips.append(hidden)

        # decoding part: the bottleneck (encoder5 output) feeds decoder1 directly;
        # each later decoder sees [previous decoder output, matching encoder output]
        hidden = self.decoder1(skips.pop())
        for stage in (self.decoder2, self.decoder3, self.decoder4, self.decoder5):
            hidden = stage(torch.cat([hidden, skips.pop()], dim = 1))

        return self.th(self.decoder_final(hidden))

In [17]:
''' Memo : CNN size equation (no dilation)

                                OUT = floor((IN + 2*Padding - Kernel_size) / Stride) + 1
'''
# define Detector (Discriminator for GAN)
# FCN-architecture (PatchGAN discriminator)
''' Issue : Should we condition discriminator by input Image? '''
class Detector(nn.Module):
    """Conditional PatchGAN-style discriminator (the GAN's detector).

    ``x`` and ``y`` each pass through their own first stage (3 -> 32 channels). The
    two feature maps are concatenated to 64 channels and pass through three stride-2
    stages and one stride-1, padding-0 stage. A 1x1 convolution to 1 channel plus a
    sigmoid then gives per-patch scores in (0, 1). For 224x224 inputs the output is
    (batchsize, 1, 11, 11).
    """
    def __init__(self):
        super(Detector, self).__init__()
        # default: kernel_size = 4, stride = 2, padding = 1, using batchnorm, no dropout
        self.fcn1x = Downsample(  3,  32, use_batchnorm = False)   # out: (batchsize,  32, 112, 112)
        self.fcn1y = Downsample(  3,  32, use_batchnorm = False)   # out: (batchsize,  32, 112, 112)
        self.fcn2  = Downsample( 64,  64)                          # out: (batchsize,  64,  56,  56)
        self.fcn3  = Downsample( 64, 128)                          # out: (batchsize, 128,  28,  28)
        self.fcn4  = Downsample(128, 256)                          # out: (batchsize, 256,  14,  14)
        self.fcn5  = Downsample(256, 512, stride = 1, padding = 0) # out: (batchsize, 512,  11,  11)

        # pointwise convolution: 512 -> 1 channel, spatial size unchanged
        self.fcn_final = nn.Conv2d(512, 1, kernel_size = 1, stride = 1, padding = 0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        # x: Inpainter output, y: original image used for conditioning.
        # Separate first stages, then stack along channels (32 + 32 = 64).
        features = torch.cat((self.fcn1x(x), self.fcn1y(y)), dim = 1)
        for stage in (self.fcn2, self.fcn3, self.fcn4, self.fcn5):
            features = stage(features)
        return self.sigmoid(self.fcn_final(features))

In [20]:
# network, optimizer and hyperparameters
# Instantiate the U-net generator (Inpainter) and the conditional discriminator (Detector).
# NOTE(review): neither model is moved to `device` in the lines visible here — confirm a
# later `.to(device)` call before training on the GPU.
inpainter = Inpainter()
detector = Detector()