In [1]:
import os
import time
import copy
import random
import shutil
import zipfile
from collections import defaultdict

import cv2
import torch
import torchmetrics
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn.functional as F
from albumentations import (HorizontalFlip, ShiftScaleRotate, Normalize, Resize, Compose, GaussNoise)
from albumentations.pytorch import ToTensorV2 as ToTensor
from PIL import Image
from skimage import io, transform
from torch import nn
from torch.autograd import Variable
from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout
from torch.optim import Adam, SGD
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils
from tqdm import tqdm as tqdm


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
TRAIN_PATH = '../data/stage1_train/'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 画像データ拡張の関数
def get_train_transform():
   return A.Compose(
       [
        # リサイズ(こちらはすでに適用済みなのでなくても良いです)
        A.Resize(256, 256),
        # 正規化(こちらの細かい値はalbumentations.augmentations.transforms.Normalizeのデフォルトの値を適用)
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        # 水平フリップ（pはフリップする確率）
        A.HorizontalFlip(p=0.25),
        # 垂直フリップ
        A.VerticalFlip(p=0.25),
        ToTensor()
        ])

# Datasetクラスの定義
class LoadDataSet(Dataset):
    WIDTH = 256
    HEIGHT = 256

    def __init__(self, path, transform=None):
        self.path = path
        self.folders = os.listdir(path)
        self.transforms = get_train_transform()

    def __len__(self):
        return len(self.folders)

    def __getitem__(self, idx):
        image_folder = os.path.join(self.path, self.folders[idx], 'images/')
        mask_folder = os.path.join(self.path, self.folders[idx], 'masks/')
        image_path = os.path.join(image_folder, os.listdir(image_folder)[0])

        # 画像データの取得
        img = io.imread(image_path)[:, :, :3].astype('float32')
        img = transform.resize(img, (self.WIDTH, self.HEIGHT))

        mask = self.get_mask(mask_folder, self.WIDTH, self.HEIGHT).astype('float32')

        augmented = self.transforms(image=img, mask=mask)

        img = augmented['image']
        mask = augmented['mask']

        mask = mask.permute(2, 0, 1)
        point = self.get_point()

        label = mask[:, point[0], point[1]]
        point = torch.tensor([point], dtype=torch.int64)

        return img, point, label

    def get_mask(self, mask_folder, IMG_HEIGHT, IMG_WIDTH):
        mask = np.zeros((IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.bool_)
        for mask_ in os.listdir(mask_folder):
                mask_ = io.imread(os.path.join(mask_folder, mask_))
                mask_ = transform.resize(mask_, (IMG_HEIGHT, IMG_WIDTH))
                mask_ = np.expand_dims(mask_, axis=-1)
                mask = np.maximum(mask, mask_)

        return mask

    def get_point(self):
        x = random.randint(0, self.WIDTH - 1)
        y = random.randint(0, self.HEIGHT - 1)
        return x, y


In [3]:
train_dataset = LoadDataSet(TRAIN_PATH, transform=get_train_transform())

In [4]:
image, point, label = train_dataset.__getitem__(0)
print(image.shape)
print(point.shape)
print(label.shape)

torch.Size([3, 256, 256])
torch.Size([1, 2])
torch.Size([1])


In [5]:
train_dataset.__len__()

670

In [6]:
def format_image(img):
    img = np.array(np.transpose(img, (1, 2, 0)))
    # 下は画像拡張での正規化を元に戻しています
    mean = np.array((0.485, 0.456, 0.406))
    std = np.array((0.229, 0.224, 0.225))
    img = std * img + mean
    img = img * 255
    img = img.astype(np.uint8)
    return img


In [7]:
split_ratio = 0.25
train_size=int(np.round(train_dataset.__len__() * (1 - split_ratio), 0))
valid_size=int(np.round(train_dataset.__len__() * split_ratio, 0))
train_data, valid_data = random_split(train_dataset, [train_size, valid_size])
train_loader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
val_loader = DataLoader(dataset=valid_data, batch_size=10)

print("Length of train data: {}".format(len(train_data)))
print("Length of validation data: {}".format(len(valid_data)))

Length of train data: 502
Length of validation data: 168


In [8]:
class UNet(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        # 資料中の『FCN』に当たる部分
        self.conv1 = conv_bn_relu(input_channels,64)
        self.conv2 = conv_bn_relu(64, 128)
        self.conv3 = conv_bn_relu(128, 256)
        self.conv4 = conv_bn_relu(256, 512)
        self.conv5 = conv_bn_relu(512, 1024)
        self.down_pooling = nn.MaxPool2d(2)

        # 資料中の『Up Sampling』に当たる部分
        self.up_pool6 = up_pooling(1024, 512)
        self.conv6 = conv_bn_relu(1024, 512)
        self.up_pool7 = up_pooling(512, 256)
        self.conv7 = conv_bn_relu(512, 256)
        self.up_pool8 = up_pooling(256, 128)
        self.conv8 = conv_bn_relu(256, 128)
        self.up_pool9 = up_pooling(128, 64)
        self.conv9 = conv_bn_relu(128, 64)
        self.conv10 = nn.Conv2d(64, output_channels, 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x, points):
        # 正規化
        x = x/255.

        # 資料中の『FCN』に当たる部分
        x1 = self.conv1(x)
        p1 = self.down_pooling(x1)
        x2 = self.conv2(p1)
        p2 = self.down_pooling(x2)
        x3 = self.conv3(p2)
        p3 = self.down_pooling(x3)
        x4 = self.conv4(p3)
        p4 = self.down_pooling(x4)
        x5 = self.conv5(p4)

        # 資料中の『Up Sampling』に当たる部分, torch.catによりSkip Connectionをしている
        p6 = self.up_pool6(x5)
        x6 = torch.cat([p6, x4], dim=1)
        x6 = self.conv6(x6)

        p7 = self.up_pool7(x6)
        x7 = torch.cat([p7, x3], dim=1)
        x7 = self.conv7(x7)

        p8 = self.up_pool8(x7)
        x8 = torch.cat([p8, x2], dim=1)
        x8 = self.conv8(x8)

        p9 = self.up_pool9(x8)
        x9 = torch.cat([p9, x1], dim=1)
        x9 = self.conv9(x9)

        output = self.conv10(x9)

        b, c, h, w = output.shape
        index = (points[:, :, 0] + w * points[:, :, 1]).unsqueeze(2)
        pred = output.reshape(b, c, h * w).gather(2, index).squeeze(2)

        return pred


def conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
    )

def down_pooling():
    return nn.MaxPool2d(2)

def up_pooling(in_channels, out_channels, kernel_size=2, stride=2):
    return nn.Sequential(
        # 転置畳み込み
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

In [9]:
def test_unet():
    model = UNet(3, 1)
    x, p, t = next(iter(train_loader))
    y = model(x, p)
    print(y.shape, t.shape)
    print(y, t)
    return

test_unet()

torch.Size([10, 1]) torch.Size([10, 1])
tensor([[ 7.8754],
        [ 5.9705],
        [ 0.1203],
        [ 5.6463],
        [ 6.8848],
        [-1.3436],
        [ 6.7791],
        [ 3.5101],
        [ 1.4560],
        [ 2.5458]], grad_fn=<SqueezeBackward1>) tensor([[0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.1186],
        [0.0066],
        [1.0000]])


In [10]:
# <---------------各インスタンス作成---------------------->
model = UNet(3,1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
criterion = nn.BCEWithLogitsLoss()
accuracy_metric = torchmetrics.Accuracy(threshold=0.5).to(device='cuda:0')
num_epochs=20
valid_loss_min = np.Inf

checkpoint_path = 'model/chkpoint_'
best_model_path = 'model/bestmodel.pt'

total_train_loss = []
total_train_score = []
total_valid_loss = []
total_valid_score = []

losses_value = 0
for epoch in range(num_epochs):
    # <---------------トレーニング---------------------->
    train_loss = []
    train_score = []
    valid_loss = []
    valid_score = []
    pbar = tqdm(train_loader, desc = 'description')
    for x_train, p_train, t_train in pbar:
        x_train = torch.autograd.Variable(x_train).cuda()
        p_train = torch.autograd.Variable(p_train).cuda()
        t_train = torch.autograd.Variable(t_train).cuda()
        optimizer.zero_grad()
        y_train = model(x_train, p_train)
        # 損失計算
        loss = criterion(y_train, t_train)
        losses_value = loss.item()
        # 精度評価
        score = accuracy_metric(y_train, t_train.long())
        loss.backward()
        optimizer.step()
        train_loss.append(losses_value)
        train_score.append(score.item())
        pbar.set_description(f"Epoch: {epoch+1}, loss: {losses_value}, IoU: {score}")

    # <---------------評価---------------------->
    with torch.no_grad():
        for x_val, p_val, t_val in val_loader:
            x_val = torch.autograd.Variable(x_val).cuda()
            p_val = torch.autograd.Variable(p_val).cuda()
            t_val = torch.autograd.Variable(t_val).cuda()
            y_val = model(x_val, p_val)
            # 損失計算
            loss = criterion(y_val, t_val)
            losses_value = loss.item()
            # 精度評価
            score = accuracy_metric(y_val, t_val.long())
            valid_loss.append(losses_value)
            valid_score.append(score.item())

    total_train_loss.append(np.mean(train_loss))
    total_train_score.append(np.mean(train_score))
    total_valid_loss.append(np.mean(valid_loss))
    total_valid_score.append(np.mean(valid_score))
    print(f"Train Loss: {total_train_loss[-1]}, Train IOU: {total_train_score[-1]}")
    print(f"Valid Loss: {total_valid_loss[-1]}, Valid IOU: {total_valid_score[-1]}")

    checkpoint = {
        'epoch': epoch + 1,
        'valid_loss_min': total_valid_loss[-1],
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    print("")

Epoch: 1, loss: 5.147805213928223, IoU: 0.5: 100%|██████████| 51/51 [02:44<00:00,  3.22s/it]                 


Train Loss: 1.2554461190513535, Train IOU: 0.7470588239968992
Valid Loss: 0.6106826531536439, Valid IOU: 0.8338235231006846



Epoch: 2, loss: 0.6553490161895752, IoU: 0.5: 100%|██████████| 51/51 [02:05<00:00,  2.46s/it]                


Train Loss: 0.8757495956093657, Train IOU: 0.8058823475650713
Valid Loss: 0.6413560759495286, Valid IOU: 0.8985294103622437



Epoch: 3, loss: 0.03526078909635544, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.46s/it]               


Train Loss: 0.5733212412280195, Train IOU: 0.8607843062456917
Valid Loss: 0.5508027015363469, Valid IOU: 0.8691176386440501



Epoch: 4, loss: 0.17650499939918518, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.46s/it]               


Train Loss: 0.47124496905827057, Train IOU: 0.8686274453705433
Valid Loss: 0.6281522521201302, Valid IOU: 0.8529411729644326



Epoch: 5, loss: 0.15712429583072662, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.46s/it]               


Train Loss: 0.4543606989523944, Train IOU: 0.8745097950393078
Valid Loss: 0.6156334596521714, Valid IOU: 0.8573529404752395



Epoch: 6, loss: 0.16861799359321594, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.46s/it]               


Train Loss: 0.4760594350450179, Train IOU: 0.864705879314273
Valid Loss: 0.5750580105711433, Valid IOU: 0.8279411757693571



Epoch: 7, loss: 0.17399334907531738, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.45s/it]               


Train Loss: 0.45668286464962304, Train IOU: 0.8725490114268135
Valid Loss: 0.47141002381549163, Valid IOU: 0.8573529404752395



Epoch: 8, loss: 1.2791695594787598, IoU: 0.5: 100%|██████████| 51/51 [02:03<00:00,  2.43s/it]                


Train Loss: 0.4466630531584515, Train IOU: 0.8725490161016876
Valid Loss: 0.5044892469749731, Valid IOU: 0.8455882352941176



Epoch: 9, loss: 0.13555502891540527, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.45s/it]               


Train Loss: 0.4224110010500048, Train IOU: 0.8686274442018247
Valid Loss: 0.3824085549396627, Valid IOU: 0.8970588235294118



Epoch: 10, loss: 0.4094507098197937, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.46s/it]                


Train Loss: 0.45726272390753614, Train IOU: 0.8725490090893764
Valid Loss: 0.4492317530162194, Valid IOU: 0.9029411708607393



Epoch: 11, loss: 2.0090320110321045, IoU: 0.5: 100%|██████████| 51/51 [02:04<00:00,  2.44s/it]                


Train Loss: 0.5665649428379302, Train IOU: 0.8333333269053814
Valid Loss: 0.4946618921616498, Valid IOU: 0.8544117597972646



Epoch: 12, loss: 0.08109353482723236, IoU: 1.0: 100%|██████████| 51/51 [02:04<00:00,  2.44s/it]               


Train Loss: 0.4035295972637102, Train IOU: 0.8921568522266313
Valid Loss: 0.37977421327548866, Valid IOU: 0.8955882226719576



Epoch: 13, loss: 1.1264389753341675, IoU: 0.5: 100%|██████████| 51/51 [02:05<00:00,  2.46s/it]                


Train Loss: 0.4641800341652889, Train IOU: 0.8666666640954859
Valid Loss: 0.4859687011031544, Valid IOU: 0.8602941106347477



Epoch: 14, loss: 0.17912383377552032, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.46s/it]               


Train Loss: 0.4081304492611511, Train IOU: 0.8803921493829465
Valid Loss: 0.41591571490554247, Valid IOU: 0.8882352850016426



Epoch: 15, loss: 0.1645144522190094, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.47s/it]                


Train Loss: 0.4576927575410581, Train IOU: 0.8686274477079803
Valid Loss: 0.5035043674356797, Valid IOU: 0.833823519594529



Epoch: 16, loss: 0.13467150926589966, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.47s/it]               


Train Loss: 0.3898341683488266, Train IOU: 0.8882352861703611
Valid Loss: 0.45051609417971444, Valid IOU: 0.8691176421502057



Epoch: 17, loss: 0.14069905877113342, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.47s/it]               


Train Loss: 0.42486515758084314, Train IOU: 0.8803921540578207
Valid Loss: 0.39226873219013214, Valid IOU: 0.8808823508374831



Epoch: 18, loss: 0.9484347105026245, IoU: 0.5: 100%|██████████| 51/51 [02:04<00:00,  2.44s/it]                


Train Loss: 0.4519154544846684, Train IOU: 0.8549019554082085
Valid Loss: 0.43593404222937193, Valid IOU: 0.8632352843004114



Epoch: 19, loss: 0.14521490037441254, IoU: 1.0: 100%|██████████| 51/51 [02:04<00:00,  2.43s/it]               


Train Loss: 0.3877033432032548, Train IOU: 0.8980392054015515
Valid Loss: 0.4077554910498507, Valid IOU: 0.8999999866766089



Epoch: 20, loss: 0.1698000133037567, IoU: 1.0: 100%|██████████| 51/51 [02:05<00:00,  2.47s/it]                


Train Loss: 0.43346985970057694, Train IOU: 0.8803921517203835
Valid Loss: 0.3806152291157666, Valid IOU: 0.916176462874693



In [None]:
import seaborn as sns

plt.figure(1)
plt.figure(figsize=(15,5))
sns.set_style(style="darkgrid")
plt.subplot(1, 2, 1)
sns.lineplot(x=range(1,num_epochs+1), y=total_train_loss, label="Train Loss")
sns.lineplot(x=range(1,num_epochs+1), y=total_valid_loss, label="Valid Loss")
plt.title("Loss")
plt.xlabel("epochs")
plt.ylabel("DiceLoss")

plt.subplot(1, 2, 2)
sns.lineplot(x=range(1,num_epochs+1), y=total_train_score, label="Train Score")
sns.lineplot(x=range(1,num_epochs+1), y=total_valid_score, label="Valid Score")
plt.title("Score (IoU)")
plt.xlabel("epochs")
plt.ylabel("IoU")
plt.show()

In [None]:
def visualize_predict(model, n_images):
    figure, ax = plt.subplots(nrows=n_images, ncols=3, figsize=(15, 18))
    with torch.no_grad():
        for data,mask in val_loader:
            data = torch.autograd.Variable(data, volatile=True).cuda()
            mask = torch.autograd.Variable(mask, volatile=True).cuda()
            o = model(data)
            break
    for img_no in range(0, n_images):
        tm=o[img_no][0].data.cpu().numpy()
        img = data[img_no].data.cpu()
        msk = mask[img_no].data.cpu()
        img = format_image(img)
        msk = format_mask(msk)
        ax[img_no, 0].imshow(img)
        ax[img_no, 1].imshow(msk, interpolation="nearest", cmap="gray")
        ax[img_no, 2].imshow(tm, interpolation="nearest", cmap="gray")
        ax[img_no, 0].set_title("Input Image")
        ax[img_no, 1].set_title("Label Mask")
        ax[img_no, 2].set_title("Predicted Mask")
        ax[img_no, 0].set_axis_off()
        ax[img_no, 1].set_axis_off()
        ax[img_no, 2].set_axis_off()
    plt.tight_layout()
    plt.show()

visualize_predict(model, 6)