# **肺セグメンテーションをやってみよう**

## **深層学習によるセグメンテーション**
セグメンテーションの課題に対しては，「ピクセル毎の分類問題を解く」という問題設定で深層学習を適用することが一般的です．
最も簡単な例としては，正方形のパッチを切り取っては，中心のピクセルが関心領域であるかを分類することを，ピクセル数だけ繰り返す方法があります（図）（引用）．

## **U-Netによるセグメンテーション**
今回は，画像部会【リンク】にて公開されている胸部X線写真における，肺野を抽出する課題に挑戦しましょう．

## **全体の流れ**
- データの読み込み
- U-Netの定義
- 誤差関数の定義
- 評価指標の定義
- 学習ループの設計
- モデルの学習
- モデルの評価

### **データセットの読み込み**
data\
&emsp; \- train\
&emsp;&emsp;&emsp; \- label\
&emsp;&emsp;&emsp; \- org\
&emsp; \- val\
&emsp;&emsp;&emsp; \- label\
&emsp;&emsp;&emsp; \- org\
&emsp; \- test\
&emsp;&emsp;&emsp; \- label\
&emsp;&emsp;&emsp; \- org

上記のフォルダ構造を念頭に置きながら，データローダを作ります．\
データをtrain, validation, testに分ける意義については，【リンク】をご覧ください．\
画像部会が公開している形式では，trainとtestのみに分かれていますが，trainのうち10例をvalとして分け直しました．\
\
まずは必要なライブラリをインポートしましょう．
今回使うライブラリは以下の通りです．

In [8]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import glob

# **ミニバッチ学習のおさらい**
<img src="imgs/minibatch.PNG">

In [9]:
class ChestDataset(Dataset):
    def __init__(self, train=True, shuffle=True):
        super().__init__()
        
        if train == True:
            self.img_path_list = sorted(glob.glob("data/train/org/*"))
            self.label_path_list = sorted(glob.glob("data/train/label/*"))
        else:
            self.img_path_list = sorted(glob.glob("data/test/org/*"))
            self.label_path_list = sorted(glob.glob("data/test/label/*"))
    
    def __len__(self):
        return len(self.img_path_list)
    
    def __getitem__(self, index):
        image_path = self.img_path_list[index]
        label_path = self.label_path_list[index]
        img = np.array(Image.open(image_path))
        img = add_axis(img)
        label = np.array(Image.open(label_path))
        label = self._label_transform(label)
        label = add_axis(label)
        return img, label
    
    def _label_transform(self, label):# transformの関数に書き換える
        nonzero = np.nonzero(label)
        label[nonzero] = 1
        return label
    
    def add_axis(arr):# transformの関数に書き換える
        arr_new = np.expand_dims(arr, 0)
        return arr_new

Transformについて説明する

In [10]:
def add_axis(arr):
    arr_new = np.expand_dims(arr, 0)
    return arr_new

In [11]:
chest_train = ChestDataset(train=True, shuffle=True)
chest_test = ChestDataset(train=False, shuffle=False)
train_loader = DataLoader(chest_train, batch_size=5, shuffle=True)
test_loader = DataLoader(chest_test, batch_size=5, shuffle=False)

train_loaderがきちんと画像を読み込めているか確認しましょう．\
ここでshapeについて詳しく説明する

In [12]:
for i, data in enumerate(train_loader):
    print(data[0].shape)

torch.Size([5, 1, 256, 256])
torch.Size([5, 1, 256, 256])
torch.Size([5, 1, 256, 256])
torch.Size([5, 1, 256, 256])
torch.Size([5, 1, 256, 256])
torch.Size([5, 1, 256, 256])
torch.Size([5, 1, 256, 256])
torch.Size([5, 1, 256, 256])
torch.Size([5, 1, 256, 256])
torch.Size([5, 1, 256, 256])


### **ネットワークの定義**
<img src="imgs/unet.png">

In [13]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):

    def __init__(self, n_class=2, input_channel=1, output_channel=1):
        super(UNet, self).__init__()
        self.n_class = n_class
        
        self.input_channel = input_channel
        self.output_channel = output_channel
        
        self.enco1_1 = nn.Conv2d(self.input_channel, 64, kernel_size=3, stride=1, padding=1)
        self.enco1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        
        self.enco2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.enco2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.enco3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.enco3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.enco4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.enco4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.enco5_1 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1)
        self.enco5_2 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1)

        self.deco6_1 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1)
        self.deco6_2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)

        self.deco7_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.deco7_2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)

        self.deco8_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.deco8_2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)

        self.deco9_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.deco9_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.final_layer = nn.Conv2d(64, self.output_channel, kernel_size=1)

        self.bn1_1 = nn.BatchNorm2d(  64)
        self.bn1_2 = nn.BatchNorm2d(  64)

        self.bn2_1 = nn.BatchNorm2d(  128)
        self.bn2_2 = nn.BatchNorm2d(  128)

        self.bn3_1 = nn.BatchNorm2d(  256)
        self.bn3_2 = nn.BatchNorm2d(  256)

        self.bn4_1 = nn.BatchNorm2d(  512)
        self.bn4_2 = nn.BatchNorm2d(  512)

        self.bn5_1 = nn.BatchNorm2d(  1024)
        self.bn5_2 = nn.BatchNorm2d(  512)

        self.bn6_1 = nn.BatchNorm2d(  512)
        self.bn6_2 = nn.BatchNorm2d(  256)

        self.bn7_1 = nn.BatchNorm2d(  256)
        self.bn7_2 = nn.BatchNorm2d(  128)

        self.bn8_1 = nn.BatchNorm2d(  128)
        self.bn8_2 = nn.BatchNorm2d(  64)

        self.bn9_1 = nn.BatchNorm2d(  64)
        self.bn9_2 = nn.BatchNorm2d(  64)
        
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  

    def forward(self, x): #x = (batchsize, 3, 360, 480)
        #if LRN:
        #    x = F.local_response_normalization(x) #Needed for preventing from overfitting

        h1_1 = F.relu(self.bn1_1(self.enco1_1(x)))
        h1_2 = F.relu(self.bn1_2(self.enco1_2(h1_1)))
        pool1, pool1_indice = F.max_pool2d(h1_2, 2, stride=2, return_indices=True) #(batchsize,  64, 180, 240)

        h2_1 = F.relu(self.bn2_1(self.enco2_1(pool1)))
        h2_2 = F.relu(self.bn2_2(self.enco2_2(h2_1)))
        pool2, pool2_indice = F.max_pool2d(h2_2, 2, stride=2, return_indices=True) #(batchsize, 128,  90, 120) 

        h3_1 = F.relu(self.bn3_1(self.enco3_1(pool2)))
        h3_2 = F.relu(self.bn3_2(self.enco3_2(h3_1)))
        pool3, pool3_indice = F.max_pool2d(h3_2, 2, stride=2, return_indices=True) #(batchsize, 256,  45,  60) 

        h4_1 = F.relu(self.bn4_1(self.enco4_1(pool3)))
        h4_2 = F.relu(self.bn4_2(self.enco4_2(h4_1)))
        pool4, pool4_indice = F.max_pool2d(h4_2, 2, stride=2, return_indices=True) #(batchsize, 256,  23,  30) 

        h5_1 = F.relu(self.bn5_1(self.enco5_1(pool4)))
        h5_2 = F.relu(self.bn5_2(self.enco5_2(h5_1)))
        
        up5 = F.max_unpool2d(h5_2, pool4_indice, kernel_size=2, stride=2, output_size=(pool3.shape[2], pool3.shape[3]))
        h6_1 = F.relu(self.bn6_1(self.deco6_1(torch.cat((up5, h4_2), dim=1))))
        h6_2 = F.relu(self.bn6_2(self.deco6_2(h6_1)))

        up6 = F.max_unpool2d(h6_2, pool3_indice, kernel_size=2, stride=2, output_size=(pool2.shape[2], pool2.shape[3]))
        h7_1 = F.relu(self.bn7_1(self.deco7_1(torch.cat((up6, h3_2), dim=1))))
        h7_2 = F.relu(self.bn7_2(self.deco7_2(h7_1)))

        up7 = F.max_unpool2d(h7_2, pool2_indice, kernel_size=2, stride=2, output_size=(pool1.shape[2], pool1.shape[3]))
        h8_1 = F.relu(self.bn8_1(self.deco8_1(torch.cat((up7, h2_2), dim=1))))
        h8_2 = F.relu(self.bn8_2(self.deco8_2(h8_1)))

        up8 = F.max_unpool2d(h8_2, pool1_indice, kernel_size=2, stride=2, output_size=(x.shape[2], x.shape[3])) #x = (batchsize, 128, 360, 480)
        h9_1 = F.relu(self.bn9_1(self.deco9_1(torch.cat((up8, h1_2), dim=1))))
        h9_2 = F.relu(self.bn9_2(self.deco9_2(h9_1)))

        h = self.final_layer(h9_2)
        #print(h.shape)
        #print(t.shape)
        predict = h
        #loss = 	nn.BCEWithLogitsLoss(h, t)
        
        #predict = nn.Softmax(h)
        return torch.sigmoid(predict)

In [28]:
from torch.autograd import Function
class DiceCoeff(Function):
    """Dice coeff for individual examples"""

    def forward(self, input, target):
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):

        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                         / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Dice coeff for batches"""
    if input.is_cuda:
        s = torch.FloatTensor(1).cuda().zero_()
    else:
        s = torch.FloatTensor(1).zero_()

    for i, c in enumerate(zip(input, target)):
        s = s + DiceCoeff().forward(c[0], c[1])

    return s / (i + 1)


class BCEDiceLoss(nn.Module):
    def __init__(self):
        super(BCEDiceLoss, self).__init__()

    def forward(self, input, target):
        bce = F.binary_cross_entropy_with_logits(input, target)
        smooth = 1e-5
        input = torch.sigmoid(input)
        num = target.size(0)
        input = input.view(num, -1)
        target = target.view(num, -1)
        intersection = (input * target)
        dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
        dice = 1 - dice.sum() / num
        return 0.5 * bce + dice

### 学習ループの作成


In [34]:
import torch.optim as optim

# モデルの定義
model = UNet()
model = model.to("cuda")

# optimizerの準備
optimizer = optim.RAdam(model.parameters())

# 誤差関数の定義
#criterion = BCEDiceLoss()
criterion = nn.BCELoss()

# 訓練ループ

batchsize = 5

train_size = len(train_loader)

epochs = 100

train_loss_list = [] # epoch毎のtrain_lossを保存しておくための入れ物
val_loss_list = [] # epoch毎のvalidation_lossを保存しておくための入れ物

for epoch in range(epochs):
    train_loss_add = 0 # 1エポック分の誤差を累積しておくための変数
    model.train() #学習モードであることを明示
    for i, data in enumerate(train_loader):
        x, t = data
        x = torch.tensor(x) / 255
        t = torch.tensor(t).float()
        
        x = x.to("cuda")
        t = t.to("cuda")
        
        predict = model(x)
        
        loss = criterion(predict, t)
        model.zero_grad()
        loss.backward() # 誤差逆伝播法により，各パラメータについての勾配を求める
        optimizer.step() # 上で求めた勾配を用いて，山が低くなっている方へ一歩進む
        
        train_loss_add += loss.data
        
    loss_mean = train_loss_add / int(train_size/batchsize)
    print("epoch" + str(epoch+1))
    print("train_loss:" + str(loss_mean))
    train_loss_list.append(loss_mean)
    
    # validation
    model.eval()
    val_loss_add = 0
    num = 0
    for i, data in enumerate(test_loader):
            
        #cudaに変換
        x, t = data
        x = torch.tensor(x) / 255
        t = torch.tensor(t).float()
        x = x.to("cuda")
        t = t.to("cuda")
        predict = model(x)
        
        loss = criterion(predict, t)
        val_loss_add += loss.data

  x = torch.tensor(x) / 255
  t = torch.tensor(t).float()


epoch1
train_loss:tensor(3.0998, device='cuda:0')


  x = torch.tensor(x) / 255
  t = torch.tensor(t).float()


epoch2
train_loss:tensor(2.3974, device='cuda:0')
epoch3
train_loss:tensor(1.7527, device='cuda:0')
epoch4
train_loss:tensor(1.2409, device='cuda:0')
epoch5
train_loss:tensor(1.0304, device='cuda:0')
epoch6
train_loss:tensor(0.8997, device='cuda:0')
epoch7
train_loss:tensor(0.8478, device='cuda:0')
epoch8
train_loss:tensor(0.7624, device='cuda:0')
epoch9
train_loss:tensor(0.7214, device='cuda:0')
epoch10
train_loss:tensor(0.7198, device='cuda:0')
epoch11
train_loss:tensor(0.6549, device='cuda:0')
epoch12
train_loss:tensor(0.6096, device='cuda:0')
epoch13
train_loss:tensor(0.5567, device='cuda:0')
epoch14
train_loss:tensor(0.5771, device='cuda:0')
epoch15
train_loss:tensor(0.5592, device='cuda:0')
epoch16
train_loss:tensor(0.5214, device='cuda:0')
epoch17
train_loss:tensor(0.4976, device='cuda:0')
epoch18
train_loss:tensor(0.4560, device='cuda:0')
epoch19
train_loss:tensor(0.4319, device='cuda:0')
epoch20
train_loss:tensor(0.4142, device='cuda:0')
epoch21
train_loss:tensor(0.3856, devic

### セグメンテーションの精度評価

In [110]:
# DICE係数とIOUについて説明，実装
# 質的評価のための可視化方法

Tue Mar  1 02:18:54 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-DGXS...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   49C    P0   138W / 300W |  21425MiB / 32505MiB |     77%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-DGXS...  On   | 00000000:08:00.0 Off |                    0 |
| N/A   50C    P0    67W / 300W |  10351MiB / 32508MiB |      0%      Default |
|       

In [12]:
img = Image.open("data/train/org/1.png")
img = np.array(img)
img.shape

(256, 256)

In [15]:
glob.glob("data/train/org/*")

['data/train/org/22.png',
 'data/train/org/23.png',
 'data/train/org/4.png',
 'data/train/org/47.png',
 'data/train/org/9.png',
 'data/train/org/5.png',
 'data/train/org/7.png',
 'data/train/org/25.png',
 'data/train/org/1.png',
 'data/train/org/8.png',
 'data/train/org/12.png',
 'data/train/org/26.png',
 'data/train/org/39.png',
 'data/train/org/35.png',
 'data/train/org/28.png',
 'data/train/org/43.png',
 'data/train/org/21.png',
 'data/train/org/42.png',
 'data/train/org/34.png',
 'data/train/org/27.png',
 'data/train/org/32.png',
 'data/train/org/41.png',
 'data/train/org/13.png',
 'data/train/org/50.png',
 'data/train/org/36.png',
 'data/train/org/40.png',
 'data/train/org/49.png',
 'data/train/org/17.png',
 'data/train/org/3.png',
 'data/train/org/33.png',
 'data/train/org/18.png',
 'data/train/org/6.png',
 'data/train/org/48.png',
 'data/train/org/46.png',
 'data/train/org/31.png',
 'data/train/org/37.png',
 'data/train/org/11.png',
 'data/train/org/20.png',
 'data/train/org/29.

In [83]:
a = np.array([0,2,0,0,2])
mask = np.nonzero(a)
a[mask] = 1
a

array([0, 1, 0, 0, 1])