### setup

In [33]:
import argparse
import glob
import math
import os
import pickle
import random

In [1]:
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
import cv2

In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import Sampler
import torchvision.transforms as transforms

In [22]:
parser = {
    'train_path': '../data/train/',
    'train_hq_path': '../data/train_hq/',
    'train_masks_path': '../data/train_masks/',
    'train_masks_file': '../data/train_masks.csv',
    'intermediate_path': '../intermediate/',
    'split_data': True,
    'batch_size': 16,
    'log_every': 10,
    'train': True,
    'model_name': '',
    'test': False,
    'seed': 20170915,
}
args = argparse.Namespace(**parser)

# Seed every global RNG for reproducibility: torch (weight init), Python's
# `random` (used by RandomSamplerWithLength and the augmentation helper), and
# NumPy's global RNG in case later cells draw from it.
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)

# One output directory per seed. makedirs (unlike mkdir) also creates the
# parent '../intermediate/' when it does not exist yet.
args.intermediate_path = os.path.join(args.intermediate_path, str(args.seed))
if not os.path.isdir(args.intermediate_path):
    os.makedirs(args.intermediate_path)

### model

In [10]:
class ConvBnRelu2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.

    After ``merge_bn()`` the BatchNorm is folded into the convolution's weight
    and bias, ``self.bn`` becomes ``None`` and ``forward`` skips it.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super(ConvBnRelu2d, self).__init__()
        # Attribute names determine state_dict keys; keep them stable.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-4)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.conv(x)
        y = y if self.bn is None else self.bn(y)
        return self.relu(y)

    def merge_bn(self):  # for faster inference
        """Fold the BatchNorm's running statistics into a new biased Conv2d.

        Uses ``running_mean`` / ``running_var``, so the merged module matches
        the eval-mode output of the unmerged one. No-op if already merged.
        """
        if self.bn is None:
            return

        bn = self.bn
        old_conv = self.conv
        # Per-output-channel factor gamma / sqrt(running_var + eps).
        scale = bn.weight.data / torch.sqrt(bn.running_var + bn.eps)
        # Broadcast the factor over (in_channels, kH, kW) of each filter.
        fused_weight = scale.view(-1, 1, 1, 1) * old_conv.weight.data
        fused_bias = bn.bias.data - scale * bn.running_mean

        self.conv = nn.Conv2d(old_conv.in_channels, old_conv.out_channels,
                              old_conv.kernel_size, padding=old_conv.padding,
                              bias=True)
        self.conv.weight.data = fused_weight
        self.conv.bias.data = fused_bias
        self.bn = None

In [9]:
class StackEncoder(nn.Module):
    """U-Net encoder stage: two ConvBnRelu2d layers followed by 2x2 max-pooling.

    ``forward(x)`` returns ``(features, pooled)``. ``features`` is the
    full-resolution output, which UNet1024 passes to a decoder as the skip
    connection. ``pooled`` is ``features`` downsampled by 2 and feeds the
    next stage.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super(StackEncoder, self).__init__()
        layers = [
            ConvBnRelu2d(in_channels, out_channels, kernel_size, padding=padding),
            ConvBnRelu2d(out_channels, out_channels, kernel_size, padding=padding),
        ]
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        features = self.encode(x)
        return features, F.max_pool2d(features, kernel_size=2, stride=2)

In [8]:
class StackDecoder(nn.Module):
    """U-Net decoder stage: upsample, concatenate with the skip, three ConvBnRelu2d.

    Args:
        en_channels: channel count of the skip tensor ``e`` from the encoder.
        in_channels: channel count of the incoming decoder tensor ``x``.
        out_channels: channel count produced by this stage.
    """

    def __init__(self, en_channels, in_channels, out_channels, kernel_size=3, padding=1):
        super(StackDecoder, self).__init__()
        stages = [ConvBnRelu2d(en_channels + in_channels, out_channels,
                               kernel_size=kernel_size, padding=padding)]
        stages += [ConvBnRelu2d(out_channels, out_channels,
                                kernel_size=kernel_size, padding=padding)
                   for _ in range(2)]
        self.decode = nn.Sequential(*stages)

    def forward(self, e, x):
        # Resize x bilinearly to the skip tensor's spatial size. Skip channels
        # come first in the concatenation, which the first conv's weights assume.
        _, _, height, width = e.size()
        upsampled = F.upsample(x, size=(height, width), mode='bilinear')
        return self.decode(torch.cat([e, upsampled], dim=1))

In [11]:
class UNet1024(nn.Module):
    """Six-level U-Net that outputs one logit map per input image.

    ``in_shape`` is ``(channels, height, width)``; only the channel count is
    used. ``forward`` returns logits with the single output channel squeezed
    away, i.e. shape (N, H, W) at the input resolution.
    """

    def __init__(self, in_shape):
        super(UNet1024, self).__init__()
        in_channels, _, _ = in_shape

        # Encoder. Spatial sizes in the comments assume a 1024x1024 input.
        self.down1 = StackEncoder(in_channels, 24)  # 512
        self.down2 = StackEncoder( 24,  64)  # 256
        self.down3 = StackEncoder( 64, 128)  # 128
        self.down4 = StackEncoder(128, 256)  # 64
        self.down5 = StackEncoder(256, 512)  # 32
        self.down6 = StackEncoder(512, 768)  # 16

        self.center = ConvBnRelu2d(768, 768)

        # Decoder: StackDecoder(skip channels, incoming channels, output channels).
        self.up6 = StackDecoder(768, 768, 512)  # 32
        self.up5 = StackDecoder(512, 512, 256)  # 64
        self.up4 = StackDecoder(256, 256, 128)  # 128
        self.up3 = StackDecoder(128, 128,  64)  # 256
        self.up2 = StackDecoder( 64,  64,  24)  # 512
        self.up1 = StackDecoder( 24,  24,  24)  # 1024

        # 1x1 conv to a single logit channel.
        self.mask = nn.Conv2d(24, 1, kernel_size=1)

    def forward(self, x):
        encoders = (self.down1, self.down2, self.down3,
                    self.down4, self.down5, self.down6)
        decoders = (self.up6, self.up5, self.up4,
                    self.up3, self.up2, self.up1)

        skips = []
        out = x
        for encoder in encoders:
            skip, out = encoder(out)
            skips.append(skip)

        out = self.center(out)

        # The deepest decoder pairs with the deepest skip (down6), the last with down1.
        for decoder, skip in zip(decoders, reversed(skips)):
            out = decoder(skip, out)

        return self.mask(out).squeeze(1)

### utils

In [12]:
def dice_score(probs, target, weight=None, use_mask=True, threshold=0.5):
    """Batch-mean of the per-sample (optionally weighted) Dice coefficient.

    Args:
        probs: predicted probabilities with the same number of elements as
            ``target``; the first dimension is the batch.
        target: binary ground-truth masks.
        weight: optional per-pixel weights with as many elements as ``target``.
            Each weight enters the sums squared. ``None`` means uniform weights.
        use_mask: if True, binarise ``probs`` at ``threshold`` first (hard Dice).
            If False, use the probabilities directly (soft, differentiable Dice).
        threshold: binarisation threshold used when ``use_mask`` is True.

    Returns:
        0-dim tensor holding the Dice score averaged over the batch. Add-one
        smoothing is applied symmetrically: a perfect prediction scores 1, and
        an empty prediction of an empty target also scores 1.
    """
    probs = (probs > threshold).float() if use_mask else probs
    n = target.size(0)
    m1 = probs.view(n, -1)
    m2 = target.view(n, -1)
    # Uniform weights are created on target's own device/dtype (no hard-coded .cuda()).
    w = torch.ones_like(m2) if weight is None else weight.view(n, -1)
    w2 = w * w
    intersection = (w2 * m1 * m2).sum(dim=1)
    total = (w2 * m1).sum(dim=1) + (w2 * m2).sum(dim=1)
    # (2*I + 1) / (|A| + |B| + 1). The 2*(I + 1) form exceeds 1 for good
    # predictions, which makes dice_loss negative.
    score = (2 * intersection + 1) / (total + 1)
    return score.sum() / n


def dice_loss(logits, target, weight=None):
    """Soft Dice loss: ``1 - dice_score(sigmoid(logits), target, weight)``."""
    probs = torch.sigmoid(logits)
    return 1 - dice_score(probs, target, weight, use_mask=False)


def criterion(logits, target):
    """Boundary-weighted binary cross-entropy plus soft Dice loss.

    Args:
        logits: raw network outputs with the same shape as ``target``.
        target: float masks in {0, 1}, shape (N, H, W).

    Pixels are weighted as follows:
      * A pixel whose 41x41 neighbourhood average of ``target`` lies in
        [0.01, 0.99] (i.e. near a mask edge) gets weight 3.
      * Every other pixel gets weight 1.
      * The weights are then rescaled so their total equals the pixel count.
    The same weights are used for both terms.
    """
    # Full-resolution box filter. stride=1 keeps the output the size of target;
    # avg_pool2d's default stride (= kernel_size) would shrink it ~41x, so the
    # weight would no longer match the BCE input.
    a = F.avg_pool2d(target.unsqueeze(1), kernel_size=41, stride=1, padding=20).squeeze(1)
    boundary = ((a >= 0.01) & (a <= 0.99)).float()

    weight = torch.ones_like(a)
    w0 = weight.sum()
    weight = weight + 2 * boundary
    w1 = weight.sum()
    weight = weight * w0 / w1

    return (F.binary_cross_entropy_with_logits(logits, target, weight)
            + dice_loss(logits, target, weight))

In [34]:
def image_to_tensor(image, mean=0, std=1):
    """Convert an HxWxC image array into a float32 CxHxW torch tensor.

    The pixels are cast to float32 and normalised as ``(image - mean) / std``
    before the channel axis is moved to the front.
    """
    normalised = (image.astype(np.float32) - mean) / std
    chw = np.transpose(normalised, (2, 0, 1))
    return torch.from_numpy(chw)


def label_to_tensor(label, threshold=0.5):
    """Binarise a mask array into a float32 tensor: 1.0 where ``label > threshold``, else 0.0."""
    binary = np.greater(label, threshold).astype(np.float32)
    return torch.from_numpy(binary)

In [36]:
class CarDataset(Dataset):
    """Car images (and, in train mode, their masks) loaded from disk with OpenCV.

    Args:
        image_path: directory containing the ``*.jpg`` images.
        mask_path: directory containing ``<image stem>_mask.gif`` masks (train mode).
        transform: sequence of callables applied in order. In ``'train'`` mode
            each is called as ``t(image, label)`` and must return
            ``(image, label)``; in ``'test'`` mode it is called as ``t(image)``.
        mode: ``'train'`` yields ``(image, label, index)``; ``'test'`` yields
            ``(image, index)``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, image_path, mask_path='', transform=None, mode='train'):
        super(CarDataset, self).__init__()
        if mode not in ('train', 'test'):
            raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))
        self.img_names = sorted(os.path.basename(p)
                                for p in glob.glob(os.path.join(image_path, '*.jpg')))
        self.img_path  = image_path
        self.mask_path = mask_path
        # None default instead of a shared mutable [] default.
        self.transform = transform if transform is not None else []
        self.mode      = mode

    def get_image(self, index):
        """Load image ``index`` as a float32 array scaled to [0, 1] (OpenCV BGR order)."""
        img_file = os.path.join(self.img_path, self.img_names[index])
        img = cv2.imread(img_file)
        if img is None:  # cv2.imread reports failure by returning None
            raise IOError('could not read image %s' % img_file)
        return img.astype(np.float32) / 255

    def get_label(self, name):
        """Load the mask belonging to image file ``name`` as float32 (pixel value / 255)."""
        mask_file = os.path.join(self.mask_path, name.split('.')[0] + '_mask.gif')
        mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            mask = self._read_first_frame_gray(mask_file)
        return mask.astype(np.float32) / 255

    @staticmethod
    def _read_first_frame_gray(path):
        """Fallback GIF reader: grab frame 0 through cv2.VideoCapture, as grayscale."""
        # Many OpenCV builds cannot decode GIF via imread; VideoCapture can when
        # OpenCV is built with FFmpeg. NOTE(review): confirm on the target setup.
        cap = cv2.VideoCapture(path)
        try:
            ok, frame = cap.read()
        finally:
            cap.release()
        if not ok or frame is None:
            raise IOError('could not read mask %s' % path)
        return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    def get_train_item(self, index):
        image = self.get_image(index)
        label = self.get_label(self.img_names[index])

        for t in self.transform:
            image, label = t(image, label)
        return image_to_tensor(image), label_to_tensor(label), index

    def get_test_item(self, index):
        image = self.get_image(index)

        for t in self.transform:
            image = t(image)
        return image_to_tensor(image), index

    def __getitem__(self, index):
        if self.mode == 'train':
            return self.get_train_item(index)
        if self.mode == 'test':
            return self.get_test_item(index)
        raise ValueError("mode must be 'train' or 'test', got %r" % (self.mode,))

    def __len__(self):
        return len(self.img_names)

In [39]:
class RandomSamplerWithLength(Sampler):
    """Yield a fresh random subset of ``length`` distinct indices on every iteration.

    Each ``iter()`` shuffles all indices of ``data_source`` with Python's
    ``random`` module and keeps the first ``length``. If ``length`` exceeds the
    dataset size, the whole shuffled dataset is yielded and ``len()`` reports
    that smaller count, so the two always agree.

    Raises:
        ValueError: if ``length`` is negative.
    """

    def __init__(self, data_source, length):
        if length < 0:
            raise ValueError('length must be non-negative, got %r' % (length,))
        self.len_data = len(data_source)
        self.num_samples = length

    def __iter__(self):
        order = list(range(self.len_data))
        random.shuffle(order)
        return iter(order[:self.num_samples])

    def __len__(self):
        # Must match what __iter__ yields; DataLoader uses it for batch counts.
        return min(self.num_samples, self.len_data)

In [None]:
def random_shift_scale_rotateN(images, shift_limit=(-0.0625,0.0625), scale_limit=(1/1.1,1.1),
                               rotate_limit=(-45,45), aspect_limit = (1,1), prob=0.5):
    """Warp every array in ``images`` with one shared random shift/scale/rotation.

    The same perspective matrix is applied to all arrays, so an image and its
    mask stay aligned. With probability ``1 - prob`` the arrays are returned
    unchanged.

    Args:
        images: sequence of arrays with equal height/width (HxW or HxWxC).
        shift_limit: range of the translation, as a fraction of width/height.
        scale_limit: range of the scale factor.
        rotate_limit: range of the rotation angle, in degrees.
        aspect_limit: range of the aspect factor (x stretched relative to y).
        prob: probability of applying the warp.

    Returns:
        A new list holding the (possibly) warped arrays; the input is not modified.
    """
    images = list(images)
    if random.random() >= prob:
        return images

    # [:2] so the first array may be either HxW or HxWxC.
    height, width = images[0].shape[:2]

    angle  = random.uniform(rotate_limit[0], rotate_limit[1])  # degrees
    scale  = random.uniform(scale_limit[0], scale_limit[1])
    aspect = random.uniform(aspect_limit[0], aspect_limit[1])
    sx = scale * aspect / (aspect ** 0.5)
    sy = scale / (aspect ** 0.5)
    dx = round(random.uniform(shift_limit[0], shift_limit[1]) * width)
    dy = round(random.uniform(shift_limit[0], shift_limit[1]) * height)
    cc = math.cos(angle / 180 * math.pi) * sx
    ss = math.sin(angle / 180 * math.pi) * sy
    rotate_matrix = np.array([[cc, -ss], [ss, cc]])

    # Rotate/scale the corners about the image centre, then translate by (dx, dy).
    box0 = np.array([[0, 0], [width, 0], [width, height], [0, height]]).astype(np.float32)
    box1 = box0 - np.array([width / 2, height / 2])
    box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])
    # getPerspectiveTransform requires float32 points; box1 is float64 here.
    mat = cv2.getPerspectiveTransform(box0, box1.astype(np.float32))

    return [cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR,
                                borderMode=cv2.BORDER_REFLECT_101, borderValue=(0, 0, 0))
            for image in images]


def train_augment(image, label):
    """Training augmentation: a random shift + scale (rotation disabled).

    The same warp is applied to the image and its mask, with the helper's
    default application probability. Returns ``(image, label)``.
    """
    # The list previously referenced an undefined name `mask`; the argument is `label`.
    image, label = random_shift_scale_rotateN([image, label], shift_limit=(-0.0625,0.0625),
                                              scale_limit=(0.91,1.21), rotate_limit=(-0,0))
    return image, label

In [None]:
def split_data():
    """Placeholder for splitting the training data into train/validation subsets.

    Per the original note, the intended split is 4320 train / 768 val images.
    It is not implemented yet: the function does nothing and returns None.
    """
    return None

### prepare

In [None]:
# Regenerate the train/validation split only when the config asks for it.
do_split = bool(args.split_data)
if do_split:
    split_data()

In [None]:
# Training data: random shift/scale augmentation. Each pass over the loader draws
# up to 4320 random images (RandomSamplerWithLength); incomplete batches are dropped.
# train_augment is passed directly: the `lambda x,y: train_augment(x,y)` wrapper
# was redundant and cannot be pickled for spawn-started DataLoader workers.
train_dataset = CarDataset(args.train_path, args.train_masks_path,
                           transform=[train_augment], mode='train')
train_loader  = DataLoader(train_dataset, args.batch_size,
                           sampler=RandomSamplerWithLength(train_dataset, 4320),
                           drop_last=True, num_workers=8)

In [None]:
# Validation: a deterministic pass over the held-out set, with no augmentation and
# a fixed order. mode='train' is kept so items include the labels.
# NOTE(review): args.val_path / args.val_masks_path are not defined in the config
# cell; add them (e.g. once split_data() is implemented) before running this cell.
from torch.utils.data.sampler import SequentialSampler

val_dataset = CarDataset(args.val_path, args.val_masks_path,
                         transform=[], mode='train')
val_loader  = DataLoader(val_dataset, args.batch_size,
                         sampler=SequentialSampler(val_dataset),
                         drop_last=False, num_workers=8)

### train