In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
#os.remove('/kaggle/working/datasets/cityscapes.pth')
# for dirname, _, filenames in os.walk('/kaggle/input'):
#    for filename in filenames:
#       print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as functional
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from glob import glob

In [3]:
# File lists of the image-pair files for the training and validation splits.
train_path, valid_path = map(glob, (
    '/kaggle/input/cityscapes-image-pairs/cityscapes_data/train/*',
    '/kaggle/input/cityscapes-image-pairs/cityscapes_data/val/*',
))

In [4]:
class Cityscapes(Dataset):
    """Dataset over side-by-side image-pair files.

    Each file is read with ``plt.imread`` and cut in two along its width
    (axis 1): the left part is returned as the input image and the right
    part as the target.

    Args:
        data_path: indexable collection of file paths, one per sample.
        transform: optional callable applied to the left (input) part.
        target_transform: optional callable applied to the right (target) part.
    """

    def __init__(self, data_path, transform=None, target_transform=None):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of samples, i.e. number of file paths."""
        return len(self.data_path)

    def __getitem__(self, item):
        """Load sample ``item`` and return it as ``(image, target)``."""
        pair = plt.imread(self.data_path[item])
        # Split column: for an odd width the extra column goes to the target.
        half_width = pair.shape[1] // 2
        image = pair[:, :half_width]
        target = pair[:, half_width:]

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

In [5]:
import torch
import torch.nn as nn
from torch import Tensor


class UNet(nn.Module):
    """U-Net style encoder/decoder built from the blocks in this file.

    The encoder is an initial double convolution followed by four
    ``DownSampleConvBlock`` stages that double the channel count each time;
    the decoder is four ``UpSampleConvBlock`` stages that halve it again and
    each receive the matching encoder output as skip connection. A 1x1
    convolution plus batch norm maps the result to ``image_channel`` channels.

    Args:
        image_channel: channels of the input and of the output.
        mid_channel: channels after the first double convolution.
    """

    def __init__(self, image_channel=3, mid_channel=64):
        super().__init__()
        c1 = mid_channel
        c2, c4, c8, c16 = c1 * 2, c1 * 4, c1 * 8, c1 * 16

        # Stem: two 3x3 conv + BN + ReLU stages at full resolution.
        self.two_conv_block = nn.Sequential(
            nn.Conv2d(image_channel, c1, kernel_size=3, padding=1),
            nn.BatchNorm2d(c1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1, c1, kernel_size=3, padding=1),
            nn.BatchNorm2d(c1),
            nn.ReLU(inplace=True),
        )

        # Encoder path.
        self.down_sample_1 = DownSampleConvBlock(c1, c2, c2)
        self.down_sample_2 = DownSampleConvBlock(c2, c4, c4)
        self.down_sample_3 = DownSampleConvBlock(c4, c8, c8)
        self.down_sample_4 = DownSampleConvBlock(c8, c16, c16)

        # Decoder path.
        self.up_sample_1 = UpSampleConvBlock(c16, c8, c8)
        self.up_sample_2 = UpSampleConvBlock(c8, c4, c4)
        self.up_sample_3 = UpSampleConvBlock(c4, c2, c2)
        self.up_sample_4 = UpSampleConvBlock(c2, c1, c1)

        # Channel reduction back to the image channel count.
        self.conv1x1 = nn.Conv2d(c1, image_channel, kernel_size=1)
        self.bn = nn.BatchNorm2d(image_channel)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        enc_1 = self.two_conv_block(x)
        enc_2 = self.down_sample_1(enc_1)
        enc_3 = self.down_sample_2(enc_2)
        enc_4 = self.down_sample_3(enc_3)
        bottleneck = self.down_sample_4(enc_4)

        # Each decoder stage is paired with the encoder output of equal depth.
        dec = self.up_sample_1(bottleneck, enc_4)
        dec = self.up_sample_2(dec, enc_3)
        dec = self.up_sample_3(dec, enc_2)
        dec = self.up_sample_4(dec, enc_1)

        projected = self.conv1x1(dec)
        return self.bn(projected)


class BasicConvBlock(nn.Module):
    """Two consecutive ``Conv2d -> BatchNorm2d -> ReLU`` stages.

    Args:
        in_channels: channels entering the first convolution.
        mid_channels: channels between the first and second convolution.
        out_channels: channels produced by the second convolution.
        kernel_size, stride, padding, dilation: forwarded unchanged to both
            ``nn.Conv2d`` layers.
    """

    def __init__(self, in_channels, mid_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
        stages = []
        # (input, output) channel pair of the first and the second stage.
        for source, sink in ((in_channels, mid_channels), (mid_channels, out_channels)):
            stages.append(nn.Conv2d(source, sink, **conv_options))
            stages.append(nn.BatchNorm2d(sink))
            stages.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x``."""
        return self.layers(x)


class DownSampleConvBlock(nn.Module):
    """Encoder stage: max pooling followed by a ``BasicConvBlock``.

    Args:
        in_channels, mid_channels, out_channels: channel counts handed to the
            ``BasicConvBlock``.
        pool_size: kernel size of the ``nn.MaxPool2d`` layer.
        kernel_size, stride, padding, dilation: convolution settings handed to
            the ``BasicConvBlock``.
    """

    def __init__(self, in_channels, mid_channels, out_channels, pool_size=2, kernel_size=3, stride=1, padding=1,
                 dilation=1):
        super().__init__()
        self.down_sample = nn.MaxPool2d(kernel_size=pool_size)
        self.basic_conv_block = BasicConvBlock(
            in_channels,
            mid_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
        )

    def forward(self, x):
        """Pool ``x``, then run the double convolution on the pooled map."""
        pooled = self.down_sample(x)
        return self.basic_conv_block(pooled)


class UpSampleConvBlock(nn.Module):
    """Decoder stage: transposed-conv up-sampling, skip concat, double conv.

    The up-sampled map (``mid_channels`` channels) is concatenated with the
    skip connection along the channel dimension and the result is fed to a
    ``BasicConvBlock`` that expects ``in_channels`` input channels, so the skip
    connection must carry ``in_channels - mid_channels`` channels.

    Args:
        in_channels: channels of the incoming (low-resolution) map and of the
            concatenated map.
        mid_channels: channels produced by the transposed convolution and
            between the two convolutions.
        out_channels: channels produced by the stage.
        up_size: kernel size and stride of the transposed convolution.
        kernel_size, stride, padding, dilation: convolution settings handed to
            the ``BasicConvBlock`` (``dilation`` is also passed to the
            transposed convolution).
    """

    def __init__(self, in_channels, mid_channels, out_channels, up_size=2, kernel_size=3, stride=1, padding=1,
                 dilation=1):
        super(UpSampleConvBlock, self).__init__()
        self.up_sample = nn.ConvTranspose2d(in_channels, mid_channels, kernel_size=up_size, stride=up_size,
                                            dilation=dilation)
        self.basic_conv_block = BasicConvBlock(in_channels, mid_channels, out_channels, kernel_size, stride, padding,
                                               dilation)

    def forward(self, x, skip_x):
        """Up-sample ``x``, align it to ``skip_x``, concatenate and convolve."""
        # up sample
        x = self.up_sample(x)
        # The pooling in the encoder floors odd sizes, so the up-sampled map can
        # be smaller than the skip map (or larger, with dilation != 1). Zero-pad
        # (positive difference) or crop (negative difference) the last two
        # dimensions so torch.cat does not fail; equal sizes are left untouched.
        diff_h = skip_x.size(2) - x.size(2)
        diff_w = skip_x.size(3) - x.size(3)
        if diff_h != 0 or diff_w != 0:
            x = nn.functional.pad(
                x, (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2))
        # concat x and skip_x in the dimension of channel
        x = torch.cat([x, skip_x], dim=1)
        # two conv block
        return self.basic_conv_block(x)


class AdaptiveFeatureFusionModule(nn.Module):
    """Adaptive Feature Fusion Module (AFFM) -- placeholder, not implemented.

    Intended to fuse a variable number of multi-scale feature maps:
    ``counts`` is meant to equal the number of feature maps and to determine
    the number of layers in the module. At present the constructor only
    stores ``counts`` and ``forward`` returns ``None``.
    """

    def __init__(self, counts):
        super().__init__()
        self.counts = counts

    def forward(self, feature_maps: tuple = None):
        """Not implemented yet; ignores ``feature_maps`` and returns ``None``."""
        return None

In [6]:
def dice_coeff(predict, target, reduce_batch_first=False, epsilon=1e-6):
    """Dice coefficient between ``predict`` and ``target``.

    For a 2-D input, or when ``reduce_batch_first`` is true, a single
    coefficient ``(2 * <p, t> + eps) / (sum(p) + sum(t) + eps)`` is computed
    over all elements. Otherwise the function recurses over the first
    dimension and returns the mean of the per-slice coefficients.

    Args:
        predict: prediction tensor.
        target: tensor of the same size as ``predict``.
        reduce_batch_first: treat the whole tensor as one mask instead of
            averaging over the first dimension.
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        The Dice coefficient as a 0-dim tensor.

    Raises:
        ValueError: ``reduce_batch_first`` is set but the input is 2-D.
        AssertionError: ``predict`` and ``target`` differ in size.
    """
    # Average of Dice coefficient for all batches, or for a single mask
    assert predict.size() == target.size()
    if predict.dim() == 2 and reduce_batch_first:
        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {predict.shape})')

    if predict.dim() == 2 or reduce_batch_first:
        inter = torch.dot(predict.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(predict) + torch.sum(target)
        # Both inputs sum to zero: make numerator and denominator equal so the
        # coefficient evaluates to 1 instead of epsilon / epsilon by accident.
        if sets_sum.item() == 0:
            sets_sum = 2 * inter
        return (2 * inter + epsilon) / (sets_sum + epsilon)

    # compute and average metric for each batch element
    dice = 0
    for i in range(predict.shape[0]):
        # Forward the caller's epsilon; it used to be dropped here, so the
        # per-element path silently fell back to the default smoothing term.
        dice += dice_coeff(predict[i, ...], target[i, ...], epsilon=epsilon)
    # return average dice value of a batch
    return dice / predict.shape[0]


def multiclass_dice_coeff(predict, target, reduce_batch_first=False, epsilon=1e-6):
    """Mean of ``dice_coeff`` over the slices along dimension 1.

    Args:
        predict: prediction tensor; dimension 1 is iterated slice by slice.
        target: tensor of the same size as ``predict``.
        reduce_batch_first: forwarded to ``dice_coeff``.
        epsilon: forwarded to ``dice_coeff``.
    """
    assert predict.size() == target.size()
    channel_count = predict.shape[1]
    total = 0
    channel = 0
    while channel < channel_count:
        predict_slice = predict[:, channel, ...]
        target_slice = target[:, channel, ...]
        total += dice_coeff(predict_slice, target_slice, reduce_batch_first, epsilon)
        channel += 1
    return total / channel_count


def dice_loss(predict, target, multiclass=True, epsilon=1e-6):
    """Dice loss, i.e. ``1 - Dice coefficient`` (objective to minimize).

    Args:
        predict: prediction tensor.
        target: tensor of the same size as ``predict``.
        multiclass: use ``multiclass_dice_coeff`` when truthy, otherwise
            ``dice_coeff``.
        epsilon: smoothing term forwarded to the coefficient function.
    """
    assert predict.size() == target.size()
    if multiclass:
        coefficient = multiclass_dice_coeff(predict, target, reduce_batch_first=True, epsilon=epsilon)
    else:
        coefficient = dice_coeff(predict, target, reduce_batch_first=True, epsilon=epsilon)
    return 1 - coefficient


class DiceLoss(nn.Module):
    """``nn.Module`` wrapper around the multiclass ``dice_loss``.

    Args:
        ep: smoothing term passed to ``dice_loss`` as ``epsilon``.
    """

    def __init__(self, ep=1e-8):
        super().__init__()
        self.ep = ep

    def forward(self, predict, target):
        """Return the multiclass Dice loss; both tensors must share one size."""
        return dice_loss(predict, target, multiclass=True, epsilon=self.ep)

In [7]:
def validate(network_model, valid_loader, loss, device):
    """Average loss of ``network_model`` over ``valid_loader``.

    The model is switched to evaluation mode (and left in it) and gradients
    are disabled while iterating.

    Args:
        network_model: module to evaluate.
        valid_loader: sized iterable yielding ``(input, label)`` batches.
        loss: callable ``loss(prediction, label)`` returning a tensor.
        device: device the batches are moved to.

    Returns:
        Sum of the per-batch loss values divided by ``len(valid_loader)``.
    """
    network_model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for batch_input, batch_label in valid_loader:
            prediction = network_model(batch_input.to(device))
            batch_loss = loss(prediction, batch_label.to(device))
            running_loss += batch_loss.item()
    return running_loss / len(valid_loader)


class SearchBestModel(object):
    """Tracks the best validation loss and counts epochs without improvement.

    Args:
        min_delta: minimum decrease ``best_score - val_loss`` that counts as
            an improvement.
        verbose: print a message whenever a call brings no improvement.

    Attributes:
        best_score: lowest loss accepted so far (``None`` before first call).
        counter: consecutive calls without improvement; 0 right after one.
    """

    def __init__(self, min_delta=0, verbose=True):
        super().__init__()
        self.min_delta = min_delta
        self.verbose = verbose
        self.best_score = None
        self.counter = 0

    def __call__(self, val_loss):
        """Update ``best_score`` / ``counter`` with a new validation loss."""
        # First observation: just remember it.
        if self.best_score is None:
            self.best_score = val_loss
            return

        improved = self.best_score - val_loss >= self.min_delta
        if not improved:
            self.counter += 1
            if self.verbose:
                print('performance reducing: counter {}'.format(self.counter))
            return

        self.best_score = val_loss
        self.counter = 0

In [8]:
def train(train_loader, valid_loader, model, optimizer, loss, epoch, device):
    """Train ``model`` for ``epoch`` epochs and write two checkpoints.

    After every epoch the model is validated with ``validate``; whenever the
    validation loss is the best seen so far (as judged by ``SearchBestModel``)
    a snapshot of the model and optimizer state is kept. Training always runs
    for the full number of epochs -- there is no early stopping.

    Args:
        train_loader: sized iterable of ``(image, label)`` training batches;
            labels are cast to ``float32``.
        valid_loader: loader handed to ``validate``.
        model: module to train.
        optimizer: optimizer over ``model``'s parameters.
        loss: callable ``loss(prediction, label)`` returning a tensor.
        epoch: number of epochs to run.
        device: device the batches are moved to.

    Side effects:
        Writes ``./unet-best.pth`` (state of the best epoch plus its 0-based
        index under ``'epoch'``) and ``./unet-last.pth`` (final state, both
        loss histories and the number of trained epochs).
    """
    # Function-scope import so this cell does not rely on another cell's imports.
    import copy

    loss_change_list = []
    valid_loss_change = []
    save_best = {}
    save_last = {}
    search_best_model = SearchBestModel()
    for i in range(epoch):
        model.train()
        total_loss = 0.0
        for index, (image, label) in enumerate(train_loader):
            image = image.to(device)
            label = label.to(device).to(torch.float32)

            segment_mask = model(image)
            loss_value = loss(segment_mask, label)

            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()

            total_loss = total_loss + loss_value.item()

            print('epoch {} batch {}/{} loss = {:.4f}'.format(i + 1, index + 1, len(train_loader), loss_value.item()))
        # save train loss change history used for model analyse
        train_avg_loss = total_loss / len(train_loader)
        loss_change_list.append(train_avg_loss)
        # use the dataset for validation to validate the trained model
        valid_avg_loss = validate(model, valid_loader, loss, device)
        # save valid loss change history used for model analyse
        valid_loss_change.append(valid_avg_loss)

        print('epoch {} train loss = {:.4f} valid loss = {:.4f}'.format(i + 1, train_avg_loss, valid_avg_loss))

        # update the best-score tracker with this epoch's validation loss
        search_best_model(valid_avg_loss)
        # counter > 0 means this epoch did not improve on the best score,
        # so the previously stored snapshot stays the best one
        if search_best_model.counter > 0:
            continue
        # This epoch is the best so far. state_dict() returns references to the
        # live tensors, which later optimizer steps modify in place, so the
        # snapshot must be deep-copied to really preserve the best weights.
        save_best['model_state_dict'] = copy.deepcopy(model.state_dict())
        # save optimizer used for re-train
        save_best['optimizer_state_dict'] = copy.deepcopy(optimizer.state_dict())
        # save the (0-based) epoch index of the current best model
        save_best['epoch'] = i
    # save loss change history of training and validation
    save_last['train_loss_change'] = loss_change_list
    save_last['valid_loss_change'] = valid_loss_change
    save_last['model_state_dict'] = model.state_dict()
    save_last['optimizer_state_dict'] = optimizer.state_dict()
    save_last['trained_epoch'] = epoch
    torch.save(save_best, './unet-best.pth')
    torch.save(save_last, './unet-last.pth')

In [9]:
if __name__ == '__main__':
    # Entry point: build the datasets and loaders, then train the U-Net.
    current_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_image_path = glob('../input/cityscapes-image-pairs/cityscapes_data/train/*')
    valid_image_path = glob('../input/cityscapes-image-pairs/cityscapes_data/val/*')
    # Fail fast with a clear message instead of an obscure sampler error or a
    # division by zero after the first training epoch.
    if not train_image_path or not valid_image_path:
        raise FileNotFoundError(
            'no image pairs found under ../input/cityscapes-image-pairs/cityscapes_data/{train,val}')

    # Inputs: array -> tensor, then resize so the smaller edge is 256.
    image_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
    ])
    # Targets are only converted to tensors, not resized.
    # NOTE(review): this assumes the stored target halves already have the
    # size the inputs are resized to -- confirm against the dataset.
    target_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])
    valid_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
    ])

    train_cityscapes = Cityscapes(train_image_path, image_transforms, target_transforms)
    valid_cityscapes = Cityscapes(valid_image_path, valid_transforms, target_transforms)

    train_loader = DataLoader(train_cityscapes, batch_size=16, shuffle=True, drop_last=True)
    # Validation must see every sample exactly once: no shuffling, and the
    # final partial batch is kept instead of being silently dropped.
    valid_loader = DataLoader(valid_cityscapes, batch_size=16, shuffle=False, drop_last=False)

    unet = UNet().to(current_device)
    # Despite its name this is a mean-squared-error loss on the target image.
    ce_loss = nn.MSELoss().to(current_device)
    optimizer_adam = optim.Adam(unet.parameters(), lr=0.0001)
    train(train_loader, valid_loader, unet, optimizer_adam, ce_loss, epoch=50, device=current_device)

epoch 1 batch 1/185 loss = 1.2194
epoch 1 batch 2/185 loss = 1.1514
epoch 1 batch 3/185 loss = 1.1129
epoch 1 batch 4/185 loss = 1.0736
epoch 1 batch 5/185 loss = 1.0443
epoch 1 batch 6/185 loss = 1.0304
epoch 1 batch 7/185 loss = 1.0007
epoch 1 batch 8/185 loss = 0.9995
epoch 1 batch 9/185 loss = 0.9713
epoch 1 batch 10/185 loss = 0.9735
epoch 1 batch 11/185 loss = 0.9668
epoch 1 batch 12/185 loss = 0.9730
epoch 1 batch 13/185 loss = 0.9513
epoch 1 batch 14/185 loss = 0.9430
epoch 1 batch 15/185 loss = 0.9322
epoch 1 batch 16/185 loss = 0.9389
epoch 1 batch 17/185 loss = 0.9291
epoch 1 batch 18/185 loss = 0.9359
epoch 1 batch 19/185 loss = 0.9300
epoch 1 batch 20/185 loss = 0.9191
epoch 1 batch 21/185 loss = 0.9235
epoch 1 batch 22/185 loss = 0.9169
epoch 1 batch 23/185 loss = 0.9185
epoch 1 batch 24/185 loss = 0.9037
epoch 1 batch 25/185 loss = 0.8952
epoch 1 batch 26/185 loss = 0.9016
epoch 1 batch 27/185 loss = 0.9142
epoch 1 batch 28/185 loss = 0.9001
epoch 1 batch 29/185 loss = 0