In [2]:
import os
import numbers

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid, save_image

print('PyTorch version:', torch.__version__)
print('torchvision version:', torchvision.__version__)
can_use_gpu = torch.cuda.is_available()
print('Is GPU available:', can_use_gpu)

PyTorch version: 0.4.1
torchvision version: 0.2.1
Is GPU available: True


In [19]:
# general settings

device = torch.device('cuda' if can_use_gpu else 'cpu')

batchsize_train = 64
batchsize_validation = 5

height_for_train_cropping = 128
width_for_train_cropping = 128
height_for_validation_cropping = 512
width_for_validation_cropping = 512

# TODO:seed setting and exclude randomness?

# directory settings
root_dir = '../data/'

# training data directory
image_dir = root_dir + 'images_resized_quarter/'
label_dir = root_dir + 'onex0.8_resized_quarter/'

# directory to save model output
estimated_label_dir = root_dir + 'estimated_onex0.8_resized_quarter/'
if not os.path.exists(estimated_label_dir):
    os.mkdir(estimated_label_dir)

# directory to save model weights and training log
log_dir = root_dir + 'log_onex0.8_resized_quarter/'
if not os.path.exists(log_dir):
    os.mkdir(log_dir)

In [20]:
class DocDataset(Dataset):
    def __init__(self, image_dir, label_dir, file_name_list,
                 transform_sync=None, transform_image=None, transform_label=None):
        assert(image_dir[-1] == '/')
        assert(label_dir[-1] == '/')
        self.image_dir = image_dir
        self.label_dir = label_dir
        
        # image or label filename list in image_dir or label_dir (to speedup train_test_split, I'll split file name list)
        # I expect corresponding image and label have same filename
        # This sort is so that following __getitem__ method expect file_name_list have unique order
        self.file_name_list = sorted(file_name_list) 
        
        # to do same random cropping for corresponding image and label
        self.transform_sync = transform_sync
        self.transform_image = transform_image
        self.transform_label = transform_label
        
    def __len__(self):
        return len(self.file_name_list)
    
    def __getitem__(self, idx):
        image_name = self.image_dir + self.file_name_list[idx]
        label_name = self.label_dir + self.file_name_list[idx]
        
        image = Image.open(image_name)
        label = Image.open(label_name)
        
        if self.transform_sync is not None:
            image, label = self.transform_sync(image, label)
        if self.transform_image is not None:
            image = self.transform_image(image)
        if self.transform_label is not None:
            label = self.transform_label(label) 
            
        return image, label

In [21]:
# split to train data and validation data for simplicity
# TODO:test should be conducted by isolated test data (different document)

# sort to eliminate os.listdir randomness
# I expect corresponding image and label have same filename
file_name = sorted(os.listdir(image_dir))
train_file_name, validation_file_name = train_test_split(file_name, test_size=0.2, random_state=0)

print('The number of training data:', len(train_file_name))
print('The number of validation data:', len(validation_file_name))

The number of training data: 276
The number of validation data: 70


In [22]:
# transform for synchronize cropping for image and label
# warning:this class can't do padding
class RandomCropSync(object):
    def __init__(self, size):
        if isinstance(self, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
    
    def get_params(img, output_size):
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw
    
    def __call__(self, img1, img2):
        assert(img1.size == img2.size)
        i, j, h, w = self.get_params(img1, self.size)
        
        img1_cropped = F.crop(img1, i, j, h, w)
        img2_cropped = F.crop(img2, i, j, h, w)
        
        return img1_cropped, img2_cropped

In [23]:
tf_sync_train = RandomCropSync((height_for_train_cropping, width_for_train_cropping)) # use for training
tf_sync_validation = RandomCropSync((height_for_validation_cropping, width_for_validation_cropping)) # use for validation
tf_image = transforms.ToTensor() # use always
tf_label = transforms.ToTensor() # use always

train_dataset = DocDataset(image_dir, label_dir, train_file_name,
                           tf_sync_train, tf_image, tf_label)
validation_dataset = DocDataset(image_dir, label_dir, validation_file_name,
                                tf_sync_validation, tf_image, tf_label)

train_loader = DataLoader(train_dataset, batch_size=batchsize_train, shuffle=True)
# In validation, I'll save estimated label, therefore shuffle=True to save result for different input
validation_loader = DataLoader(validation_dataset, batch_size=batchsize_validation, shuffle=True)

In [24]:
# TODO:explore other normalization
# define parts for U-net for convenience (for encoder parts)
# downsampling to half size (default)
# conv > batchnorm(optional) > dropout(optional) > relu
class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels, ksize=4, stride=2, padding=1, use_bn=True, drop_prob=0.0):
        super(DownSample, self).__init__()
        self.use_batchnorm = use_bn
        self.use_dropout = drop_prob > 0
        
        self.cv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=padding)
        if self.use_batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        if self.use_dropout:
            self.dr = nn.Dropout(drop_prob)
        self.rl = nn.ReLU(0.2)
        
    def forward(self, x):
        out = self.cv(x)
        if self.use_batchnorm:
            out = self.bn(out)
        if self.use_dropout:
            out = self.dr(out)
        out = self.rl(out)
        return out

In [25]:
# TODO:explore other normalization (because batch size is very small)
# define parts for U-net for convenience (for decorder)
# upsampling to double size (default) (using transposed convolution)
# conv > batchnorm(optional) > dropout(optional) > relu
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels, ksize=4, stride=2, padding=1, use_bn=True, drop_prob=0.0):
        super(UpSample, self).__init__()
        self.use_batchnorm = use_bn
        self.use_dropout = drop_prob > 0
        
        self.tc = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=padding)
        if self.use_batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        if self.use_dropout:
            self.dr = nn.Dropout(drop_prob)
        self.rl = nn.ReLU()
        
    def forward(self, x):
        out = self.tc(x)
        if self.use_batchnorm:
            out = self.bn(out)
        if self.use_dropout:
            out = self.dr(out)
        out = self.rl(out)
        return out

In [26]:
# TODO : add attribute for switching using dropout or not and batchnorm or not
class U_Net(nn.Module):
    def __init__(self, n_depth_encoder, n_base_channels=32):
        super(U_Net, self).__init__()
        
        self.n_depth_encoder = n_depth_encoder
        n_channels = 3
        # encoder parts
        self.encoder = nn.ModuleList()
        for i in range(self.n_depth_encoder):
            if i == 0:
                self.encoder.append(DownSample(n_channels, n_base_channels))
                n_channels = n_base_channels
            else:
                self.encoder.append(DownSample(n_channels, n_channels*2))
                n_channels = n_channels*2
                
        # decoder parts
        self.decoder = nn.ModuleList()
        for i in range(self.n_depth_encoder):
            if i == 0:
                self.decoder.append(UpSample(n_channels, n_channels))
            else:
                self.decoder.append(UpSample(n_channels + n_channels//2, n_channels//2))
                n_channels = n_channels//2

        # 1x1 convolution to adjust channels and refine result
        self.conv1x1 = nn.Conv2d(n_channels, 3, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        out_encoders = []
        for i in range(self.n_depth_encoder):
            if i != 0:
                out_encoders.append(x)
            x = self.encoder[i](x)
            
        for i in range(self.n_depth_encoder):
            if i == 0:
                x = self.decoder[i](x)
            else:
                concated_input = torch.cat([x, out_encoders[self.n_depth_encoder-i-1]], dim=1)
                x = self.decoder[i](concated_input)
        
        out = self.conv1x1(x)
        return out

In [29]:
net = U_Net(n_depth_encoder=5, n_base_channels=32)
net = net.to(device)

#TODO:explore good initialization

optimizer = optim.Adam(net.parameters(), lr=2e-4)
criterion = nn.BCEWithLogitsLoss()

# count the number of trainable parameters
num_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)

# print settings
print('The number of trainable parameters:', num_trainable_params)
print('Model:\n', net)
print('\nOptimizer:\n', optimizer)
print('Loss:\n', criterion)

The number of trainable parameters: 11165091
Model:
 U_Net(
  (encoder): ModuleList(
    (0): DownSample(
      (cv): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (rl): ReLU(inplace)
    )
    (1): DownSample(
      (cv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (rl): ReLU(inplace)
    )
    (2): DownSample(
      (cv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (rl): ReLU(inplace)
    )
    (3): DownSample(
      (cv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (rl): ReLU(inplace)
    )
    (4): DownSample