Reference  
https://lp-tech.net/articles/hzfn7
    

パッチ分割用ソースコードはimagePatcherSemSeg.py

``` 
$ python imagePatcherSemSeg.py ./datasets/wM1 500 ./results/ wM
```

のようにして実行するとモザイク画像をパッチに分割し，かつセグメンテーションもしてくれる．  
このデータセットを使って学習をするのかと？

Pytorch による画像セグメンテーション  
データセットは https://www.kaggle.com/c/carvana-image-masking-challenge/data  
kaggleアカウント作って電話番号とかでVerifyして色々めんどいけどなんとか．

PytorchのインストールはURLから．  
https://pytorch.org/ を参照  

Windows  
```
$ pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp37-cp37m-win_amd64.whl
$ pip install https://download.pytorch.org/whl/cpu/torchvision-0.3.0-cp37-cp37m-win_amd64.whl
```

Mac  
```
$ pip install torch torchvision
$ brew install libomp
```

https://github.com/pytorch/pytorch/issues/20030

PydensecrfはpipでインストールしようとするとEigenとかいうC++のライブラリでつっかかることがある．その場合は，  
- venvを使っているならアクティベート
- https://github.com/lucasb-eyer/pydensecrf をクローン  
- http://eigen.tuxfamily.org/index.php?title=Main_Page から最新のEigenライブラリをダウンロードし，解凍
- 解凍した中に入っているEigenディレクトリを，pydensecrf/pydensecrf/densecrf/include/Eigenに上書き
- pydensecrfディレクトリの直下で，python setup.py install
- これにてインストール成功  
  
参考: https://github.com/lucasb-eyer/pydensecrf/issues/69

In [10]:
import sys, os, numpy as np
from optparse import OptionParser

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision.transforms as transforms
from torch.autograd import Function, Variable

from tqdm import tqdm
import pydensecrf.densecrf as dcrf
import random
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

プログラム内で使用する変数を定義

In [37]:
DIR_IMG = 'datasets/carvana/train/'
DIR_MASK = 'datasets/carvana/train_masks/'
DIR_IMG_TEST = 'datasets/carvana/test/'
DIR_CHECKPOINT = 'results/carvana/checkpoint/'
VAL_PERCENT = 0.05 # テスト用データの割合
SCALE = 0.5
N = 2
BATCH_SIZE = 2
EPOCH = 5
THRESHOLD = 0.5

画像の読み込み

In [12]:
def to_cropped(ids, dir, suffix):
    for id, pos in ids:
        img = Image.open(dir + id + suffix)
        
        w = img.size[0]
        h = img.size[1]
        newW = int(w * SCALE)
        newH = int(h * SCALE)
        
        img = img.resize((newW, newH))
        img = img.crop((0, 0, newW, newH)) # 左上1/4のみを切り取る意味は実験を早めるため？
        img = np.array(img, dtype=np.float32)
        
        h = img.shape[0]
        if pos == 0:
            img = img[:, :h]
        else:
            img = img[:, -h:]
        
        yield img

画像とマスク画像の読み込み

In [35]:
def get_img_mask(ids):
    img = to_cropped(ids, DIR_IMG, '.jpg')
    img = map(lambda x: np.transpose(x, axes=[2, 0, 1]), img)
    img = map(lambda x: x / 255, img)
    
    mask = to_cropped(ids, DIR_MASK, '_mask.gif')
    
    return zip(img, mask)

バッチ化関数

In [15]:
def batch(iterable, batch_size):
    b = []
    for i, t in enumerate(iterable):
        b.append(t)
        if (i + 1) % batch_size == 0:
            yield b # batch_sizeの数だけまとめてyield
            b = []
    
    if len(b) > 0:
        yield b

データセットの読み込み

In [26]:
# 拡張子を取り払ったID部分の配列
ids_all = [f[:-4] for f in os.listdir(DIR_IMG)]

# 全てのIDを2つにコピー
ids_all = [(idx, i) for i in range(N) for idx in ids_all]

random.shuffle(ids_all)

# 全訓練データ数？
n = int(len(ids_all) * VAL_PERCENT)

# 訓練データとテストデータを分けている？
ids = {'train': ids_all[:-n], 'val': ids_all[-n:]}

len_train = len(ids['train'])
len_val = len(ids['val'])

U-Netの実装

In [39]:
class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
    
    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )
    
    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
        
        self.conv = double_conv(in_ch, out_ch)
    
    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffX = x1.size()[2] - x2.size()[2]
        diffY = x1.size()[3] - x2.size()[3]
        x2 = F.pad(x2, (diffX // 2, int(diffX / 2), 
                        diffY // 2, int(diffY / 2)))
        x = torch.cat([x2, x1], dim=1) # catenate? でマージ
        x = self.conv(x)
        return x

class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)
        def forward(self, x):
            x = self.conv(x)
            return x

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128) # Uの左側からやってくるデータとマージを行うので入力2倍
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)
    
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x

Dice係数を算出するクラス  
https://mieruca-ai.com/ai/jaccard_dice_simpson/  
集合同士の類似度の一種  
$$ DSC(A, B) = \frac{2|A \cap B|}{|A| + |B|} $$
片方の要素数がもう片方より多い時には無駄に結果が小さくなってしまうJaccard係数とは違い，2つの集合の平均要素数を計算することで，共通要素数を重視した類似度計算を行う係数

In [33]:
class DiceCoeff(Function):
    """Dice coeff for indivisual examples"""
    def forward(self, inpt, target):
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(inpt.view(-1), target.view(-1))
        self.union = torch.sum(inpt) + torch.sum(target) + eps
        
        t = (2 * self.inter.float() + eps) / self.union.float()
        return t
    
    # 単一の出力しか持たないため，勾配も一つだけ
    def backward(self, grad_output):
        inpt, target = self.saved_variables
        grad_input = grad_target = None
        
        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union + self.inter) \
                            / self.union * self.union
        if self.needs_input_grad[1]:
            grad_target = None
        
        return grad_input, grad_target

def dice_coeff(inpt, target):
    """Dice coeff for batches"""
    if inpt.is_cuda:
        s = torch.FloatTensor(1).cuda().zero_()
    else:
        s = torch.FloatTensor(1).zero_()
    
    for i, c in enumerate(zip(inpt, target)):
        s = s + DiceCoeff().forward(c[0], c[1])
    
    return s / (i + 1)

ネットワークの学習

In [40]:
if torch.cuda.is_available():
    net = UNet(n_channels=3, n_classes=1).cuda()
else:
    net = UNet(n_channels=3, n_classes=1).cpu()

optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=0.0005
)

criterion = nn.BCELoss()

for epoch in range(EPOCH):
    train = get_img_mask(ids['train'])
    val = get_img_mask(ids['val'])
    
    # train section
    epoch_loss = 0
    for i, b in enumerate(batch(train, BATCH_SIZE)):
        img = np.array([i[0] for i in b]).astype(np.float32)
        mask = np.array([i[1] for i in b])
        
        if torch.cuda.is_available():
            img = torch.from_numpy(img).cuda()
            mask = torch.from_numpy(mask).cuda()
        else:
            img = torch.from_numpy(img).cpu()
            mask = torch.from_numpy(mask).cpu()
        
        mask_flat = mask.view(-1)
        
        mask_pred = net(img) # prediction
        mask_prob = F.sigmoid(mask_pred) # probability
        mask_prob_flat = mask_prob.view(-1)
        
        loss = criterion(mask_prob_flat, mask_flat)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        
        if i % 10 == 0:
            print(f'{i} / {int(len_train/BATCH_SIZE)} ---- loss: {loss.item()}')
    
    print(f'Epoch finished ! Loss: {epoch_loss / len_train}')
    
    # ---- val section
    val_dice = 0
    for j, b in enumerate(val):
        if torch.cuda.is_available():
            img = torch.from_numpy(b[0]).unsqueeze(0).cuda()
            mask = torch.from_numpy(b[1]).unsqueeze(0).cuda()
        else:
            img = torch.from_numpy(b[0]).unsqueeze(0).cpu()
            mask = torch.from_numpy(b[1]).unsqueeze(0).cpu()
        
        mask_pred = net(img)[0]
        mask_prob = F.sigmoid(mask_pred)
        mask_bin = (mask_prob > 0.5).float()
        val_dice += dice_coeff(mask_bin, mask).item()
        
        if j % 10 == 0:
            print(f"val: {j}/{len_val}")
    
    torch.save(net.state_dict(), f"{DIR_CHECKPOINT}CP{epoch+1}.pth")
    print(f"Checkpoint {epoch+1} saved !")
    print(f"Validation Dice Coeff: {val_dice / len_val}")

NotImplementedError: 

テストデータによる評価

In [None]:
file_img_test = os.listdir(DIR_IMG_TEST)
random.shuffle(file_img_test)

for i, file in enumerate(file_img_test):
    img_original = Image.open(DIR_IMG_TEST+file)
    img = img_original
    
    w = img.size[0]
    h = img.size[1]
    
    newW = int(w * SCALE)
    newH = int(h * SCALE)
    
    img = img.resize((newW, newH))
    img = img.crop((0, 0, newW, newH))
    img = np.array(img, dtype=np.float32)
    img = img / 255
    
    img_left = img[:, :newH]
    img_right = img[;, -newH:]
    
    img_left = np.transpose(img_left, axes=[2, 0, 1])
    img_right = np.transpose(img_right, axes=[2, 0, 1])
    
    if torch.cuda.is_available():
        img_left = torch.from_numpy(img_left).unsqueeze(0).cuda()
        img_right = torch.from_numpy(img_right).unsqueeze(0).cuda()
    else:
        img_left = torch.from_numpy(img_left).unsqueeze(0).cpu()
        img_right = torch.from_numpy(img_right).unsqueeze(0).cpu()
    
    with torch.no_grad():
        mask_left = net(img_left)
        mask_right = net(img_right)

        mask_prob_left = F.sigmoid(mask_left).squeeze(0)
        mask_prob_right = F.sigmoid(mask_right).squeeze(0)
        
        tf = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(h),
                transforms.ToTensor()
        ])
        
        mask_prob_left = tf(mask_prob_left.cpu())
        mask_prob_right = tf(mask_prob_right.cpu())
        
        mask_prob_left_np = mask_prob_left.squeeze().cpu().numpy()
        mask_prob_right_np = mask_prob_right.squeeze().cpu().numpy()
        
        mask_prob_np = np.zeros((h, w), np.float32)
        mask_prob_np[:, :w//2+1] = mask_prob_left_np[:, :w//2+1]
        mask_prob_np[:, w//2+1:] = mask_prob_right_np[:, -(w//2-1):]
            
        
        h = mask_prob_np.shape[0]
        w = mask_prob_np.shape[1]

        mask_prob_np = np.expand_dims(mask_prob_np, 0)
        mask_prob_np = np.append(1 - mask_prob_np, mask_prob_np, axis=0)

        d = dcrf.DenseCRF2D(w, h, 2)
        U = -np.log(mask_prob_np)
        U = U.reshape((2, -1))
        U = np.ascontiguousarray(U)
        img = np.ascontiguousarray(np.array(img_original).astype(np.uint8))

        d.setUnaryEnergy(U)

        d.addPairwiseGaussian(sxy=20, compat=3)
        d.addPairwiseBilateral(sxy=30, srgb=20, rgbim=img, compat=10)

        mask = d.inference(5)
        mask = np.argmax(np.array(mask), axis=0).reshape((h, w))        
        mask = mask_prob_np > args['threshold']
        mask = Image.fromarray((mask[0] * 255).astype(np.uint8))
        
        
        fig = plt.figure(figsize=(27, 7))

        ax1 = fig.add_subplot(131)    
        ax1.imshow(img_original)
        ax1.set_title('input', fontsize=28)

        ax2 = fig.add_subplot(132)
        ax2.imshow(mask)
        ax2.set_title('output', fontsize=28)

        ax3 = fig.add_subplot(133)
        ax3.imshow(img_original)
        ax3.imshow(mask, alpha=0.8)
        ax3.set_title('input and output', fontsize=28)
        
        plt.show()