In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import random

import numpy as np
import torch
import torchvision.transforms
import torchvision.transforms.functional as Func


class ToPILImage(object):
    """Paired transform that converts the image (and the target, if given) to PIL images."""

    def __call__(self, image, target=None):
        """Convert both inputs with ``Func.to_pil_image``.

        :param image: input accepted by ``Func.to_pil_image``
        :param target: optional segmentation target; passed through as ``None`` when absent
        :return: ``(image, target)`` after conversion
        """
        pil_image = Func.to_pil_image(image)
        pil_target = target
        if pil_target is not None:
            pil_target = Func.to_pil_image(pil_target)
        return pil_image, pil_target


class RandomCrop(object):
    """Randomly crop the image and, if given, the target at the SAME location.

    The crop window is sampled once per call and applied to both inputs, so
    the image and its segmentation mask stay spatially aligned.

    :param size: desired output size. An int ``s`` (or a 1-element sequence)
        is interpreted as ``(s, s)``; a 2-element sequence as ``(height, width)``.
    """

    def __init__(self, size):
        self.size = size

    def _output_size(self):
        """Return ``self.size`` normalised to a ``(height, width)`` tuple."""
        if isinstance(self.size, int):
            return self.size, self.size
        size = tuple(self.size)
        if len(size) == 1:
            return size[0], size[0]
        return size

    def __call__(self, image, target=None):
        """Crop ``image`` and ``target`` with one shared random window.

        :param image: PIL image or tensor to crop
        :param target: optional mask cropped with the same window
        :return: ``(image, target)`` after cropping
        """
        # Sample the window exactly once. Calling a torchvision RandomCrop
        # module twice draws two different windows, which misaligns the mask.
        # The global torch RNG is no longer reseeded here, so reproducibility
        # is controlled by the caller's ``torch.manual_seed``.
        top, left, height, width = torchvision.transforms.RandomCrop.get_params(image, self._output_size())
        image = Func.crop(image, top, left, height, width)
        if target is not None:
            target = Func.crop(target, top, left, height, width)
        return image, target


class Resize(object):
    """Paired resize: default interpolation for the image, nearest-neighbour for the target.

    :param size: size argument forwarded unchanged to ``Func.resize``
    """

    def __init__(self, size):
        super().__init__()
        self.size = size

    def __call__(self, image, target=None):
        """Resize ``image`` and, when present, ``target``.

        :return: ``(image, target)`` after resizing
        """
        resized_image = Func.resize(image, self.size)
        if target is None:
            return resized_image, target
        # Nearest-neighbour is requested explicitly for the target.
        nearest = Func.InterpolationMode.NEAREST
        resized_target = Func.resize(target, self.size, interpolation=nearest)
        return resized_image, resized_target


class RandomHorizontalFlip(object):
    """Horizontally flip the image and target together with a given probability.

    :param flip_prob: probability of applying the flip (default 0.5)
    """

    def __init__(self, flip_prob=0.5):
        self.flip_prob = flip_prob

    def __call__(self, image, target=None):
        """Flip both inputs, or neither, based on a single random draw.

        :return: ``(image, target)``, flipped or untouched
        """
        # One draw decides for both inputs so they stay aligned.
        if not (random.random() < self.flip_prob):
            return image, target
        flipped_image = Func.hflip(image)
        flipped_target = target if target is None else Func.hflip(target)
        return flipped_image, flipped_target


class ColorJitter(object):
    """Apply torchvision's ``ColorJitter`` to the image only; the target is passed through.

    With the default arguments (all zeros) torchvision applies no jitter, so
    pass non-zero strengths to actually augment.

    :param brightness: forwarded to ``torchvision.transforms.ColorJitter``
    :param contrast: forwarded to ``torchvision.transforms.ColorJitter``
    :param saturation: forwarded to ``torchvision.transforms.ColorJitter``
    :param hue: forwarded to ``torchvision.transforms.ColorJitter``
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        # Build the transform once instead of on every sample.
        self.color_jitter = torchvision.transforms.ColorJitter(brightness=brightness, contrast=contrast,
                                                               saturation=saturation, hue=hue)

    def __call__(self, image, target=None):
        """Jitter ``image``; ``target`` is returned unchanged.

        ``target`` defaults to ``None`` like the other paired transforms.

        :return: ``(image, target)``
        """
        image = self.color_jitter(image)
        return image, target


class GrayScale(object):
    """Convert the image to grayscale with torchvision's ``Grayscale``; the target is passed through.

    :param num_output_channels: forwarded to ``torchvision.transforms.Grayscale`` (default 1)
    """

    def __init__(self, num_output_channels=1):
        # Build the transform once instead of on every sample.
        self.gray_scale = torchvision.transforms.Grayscale(num_output_channels=num_output_channels)

    def __call__(self, image, target=None):
        """Convert ``image``; ``target`` is returned unchanged.

        ``target`` defaults to ``None`` like the other paired transforms.

        :return: ``(image, target)``
        """
        image = self.gray_scale(image)
        return image, target


class RandomRotation(object):
    """Rotate the image and target by the same random integer angle.

    :param degrees: indexable pair ``(low, high)`` of integer bounds (both inclusive)
    """

    def __init__(self, degrees):
        super().__init__()
        self.degrees = degrees

    def __call__(self, image, target=None):
        """Draw one angle and rotate both inputs by it.

        :return: ``(image, target)`` after rotation
        """
        low = self.degrees[0]
        high = self.degrees[1]
        # A single draw keeps image and target aligned.
        angle = random.randint(low, high)
        rotated_image = Func.rotate(image, angle)
        if target is None:
            return rotated_image, target
        return rotated_image, Func.rotate(target, angle)


class Normalize(object):
    """Normalize the image tensor with the given per-channel mean and std; the target is passed through.

    :param mean: per-channel means forwarded to ``Func.normalize``
    :param std: per-channel standard deviations forwarded to ``Func.normalize``
    """

    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        """Normalize ``image``; ``target`` is returned unchanged.

        ``target`` defaults to ``None`` like the other paired transforms.

        :return: ``(image, target)``
        """
        image = Func.normalize(image, mean=self.mean, std=self.std)
        return image, target


class ToTensor(object):
    """Convert the image with ``Func.to_tensor`` and the target to an int64 tensor."""

    def __call__(self, image, target=None):
        """Convert both inputs to tensors.

        :param image: input accepted by ``Func.to_tensor``
        :param target: optional label map; converted via ``np.array`` to ``torch.int64``
        :return: ``(image, target)`` as tensors (``target`` stays ``None`` if absent)
        """
        image_tensor = Func.to_tensor(image)
        if target is None:
            return image_tensor, target
        label_array = np.array(target)
        return image_tensor, torch.as_tensor(label_array, dtype=torch.int64)


class Compose(object):
    """Chain paired transforms, each taking and returning ``(image, target)``.

    :param transforms: iterable of callables with signature ``t(image, target)``
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target=None):
        """Apply every transform in order.

        ``target`` defaults to ``None`` so the pipeline can also be used on
        images without a target, matching the individual transforms.

        :return: ``(image, target)`` after the last transform
        """
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

In [3]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    """Dataset of (image, target) pairs read from two parallel sequences of file paths.

    :param data_path: indexable sequence of image file paths
    :param target_path: indexable sequence of target file paths, index-aligned with ``data_path``
    :param transforms: optional callable ``transforms(image, target) -> (image, target)``
    """

    def __init__(self, data_path, target_path, transforms=None):
        super().__init__()
        self.data_paths, self.target_paths = data_path, target_path
        self.transforms = transforms

    def __getitem__(self, item):
        """Read the ``item``-th pair, add a trailing axis to the image, then apply the transforms."""
        raw_image = plt.imread(self.data_paths[item])
        target = plt.imread(self.target_paths[item])

        # Append a trailing axis to the image array
        # (presumably the channel axis of a 2-D grayscale image; verify against the data).
        image = np.expand_dims(raw_image, axis=-1)

        if self.transforms is None:
            return image, target
        image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        """Number of samples, taken from the image path sequence."""
        return len(self.data_paths)


def load_data(data_path, target_path, batch_size, drop_last=False, transforms=None):
    """Build a shuffling ``DataLoader`` over a ``CustomDataset``.

    :param data_path: sequence of image file paths
    :param target_path: sequence of target file paths
    :param batch_size: number of samples per batch
    :param drop_last: forwarded to ``DataLoader``
    :param transforms: optional paired transform forwarded to the dataset
    :return: the ``DataLoader`` (always created with ``shuffle=True``)
    """
    dataset = CustomDataset(data_path=data_path, target_path=target_path, transforms=transforms)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

In [4]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as functional


class UNet(nn.Module):
    """PyTorch implementation of a U-Net style model.

    Paper: https://arxiv.org/abs/1505.04597
    Overall structure: encoder -> one ConvBlock -> decoder -> one 1x1 conv,
    followed by a sigmoid. The network takes 1 input channel and produces
    2 output channels (binary segmentation, one channel per class).
    """

    def __init__(self):
        super(UNet, self).__init__()
        # Encoder
        self.eb1 = EncoderBlock(1, 64, kernel_size=2)
        self.eb2 = EncoderBlock(64, 128, kernel_size=2)
        self.eb3 = EncoderBlock(128, 256, kernel_size=2)
        self.eb4 = EncoderBlock(256, 512, kernel_size=2)
        # A single ConvBlock sits between encoder and decoder
        self.cb = ConvBlock(512, 1024)
        # Decoder
        self.db1 = DecoderBlock(1024, 512)
        self.db2 = DecoderBlock(512, 256)
        self.db3 = DecoderBlock(256, 128)
        self.db4 = DecoderBlock(128, 64)
        # Final 1x1 conv; binary segmentation, so two output channels
        self.conv1x1 = nn.Conv2d(64, 2, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the network and return per-class probabilities at the input's spatial size.

        :param x: input batch; the last two dimensions are treated as (height, width)
        :return: sigmoid output of the 1x1 conv, center-cropped to ``x``'s height and width
        """
        ex1, skip_x1 = self.eb1(x)
        ex2, skip_x2 = self.eb2(ex1)
        ex3, skip_x3 = self.eb3(ex2)
        ex4, skip_x4 = self.eb4(ex3)
        cbx = self.cb(ex4)
        dx1 = self.db1(cbx, skip_x4)
        dx2 = self.db2(dx1, skip_x3)
        dx3 = self.db3(dx2, skip_x2)
        dx4 = self.db4(dx3, skip_x1)
        # CenterCrop expects (height, width): shape[-2] is height, shape[-1] is width.
        # The previous (shape[-1], shape[-2]) order was only correct for square inputs.
        crop = transforms.CenterCrop(size=(x.shape[-2], x.shape[-1]))
        return self.sigmoid(self.conv1x1(crop(dx4)))


class BottleBlock(nn.Module):
    """Residual bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand, added to the input, then BN + ReLU.

    :param channels: number of input and output channels; the inner width is ``channels // 2``
    """

    def __init__(self, channels):
        super().__init__()
        hidden = channels // 2

        # Layer order (and therefore the state_dict indices) is significant.
        stages = [
            nn.Conv2d(in_channels=channels, out_channels=hidden, kernel_size=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=hidden, out_channels=hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=hidden, out_channels=channels, kernel_size=1),
        ]
        self.layer = nn.Sequential(*stages)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(layer(x) + x))``; the output has the same shape as ``x``."""
        branch = self.layer(x)
        summed = torch.add(branch, x)
        normed = self.bn(summed)
        return self.relu(normed)


class ConvBlock(nn.Module):
    """A 1x1 convolution that changes the channel count, followed by two BottleBlocks.

    :param in_channels: number of input channels of the block
    :param out_channels: number of output channels of the block
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.change_channel = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bk1 = BottleBlock(out_channels)
        self.bk2 = BottleBlock(out_channels)

    def forward(self, x):
        """Project to ``out_channels``, then run both bottleneck blocks in sequence."""
        projected = self.change_channel(x)
        return self.bk2(self.bk1(projected))


class DownSampling(nn.Module):
    """Down-sampling by max pooling; used after the ConvBlock of an encoder stage.

    :param kernel_size: kernel size of the max-pooling layer
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.down_sample = nn.MaxPool2d(kernel_size=kernel_size)

    def forward(self, x):
        """Return the max-pooled feature map."""
        pooled = self.down_sample(x)
        return pooled


class UpSampling(nn.Module):
    """Up-sampling by transposed convolution; used before the ConvBlock of a decoder stage.

    Output size of the transposed convolution:

    C_out = out_channels
    H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
    W_out = (W_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

    With the defaults below this gives H_out = 2 * H_in + 6 (not an exact doubling).

    :param in_channels: input channels of the transposed convolution
    :param out_channels: output channels of the transposed convolution
    :param kernel_size: kernel size of the transposed convolution, default 7
    :param stride: stride of the transposed convolution, default 2
    :param dilation: dilation of the transposed convolution, default 1
    :param padding: padding of the transposed convolution, default 0
    :param output_padding: extra size added to one side of the output, default 1
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, stride=2, dilation=1, padding=0, output_padding=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, dilation=dilation,
                            padding=padding, output_padding=output_padding)
        self.up_sample = nn.ConvTranspose2d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """Return the up-sampled feature map."""
        upsampled = self.up_sample(x)
        return upsampled


class EncoderBlock(nn.Module):
    """One stage of the encoder: ConvBlock followed by max-pool down-sampling.

    :param in_channels: number of input channels of the stage
    :param out_channels: number of output channels of the stage
    :param kernel_size: kernel size of the down-sampling (max-pooling) layer
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv_block = ConvBlock(in_channels, out_channels)
        self.down_sample = DownSampling(kernel_size)

    def forward(self, x):
        """Return ``(down-sampled features, pre-pooling features)``.

        The second element is what the decoder uses as the skip connection.
        """
        skip = self.conv_block(x)
        pooled = self.down_sample(skip)
        return pooled, skip


class ConcatLayer(nn.Module):
    """Skip connection: concatenates two feature maps along the channel dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x, skip_x):
        """Bring ``skip_x`` to ``x``'s spatial size and concatenate on ``dim=1``.

        :param x: decoder feature map
        :param skip_x: encoder feature map passed over the skip connection
        :return: ``torch.cat([x, skip], dim=1)``
        :raises Exception: if the shapes still differ after center-cropping
        """
        # Center-crop the encoder feature map to the spatial size of x.
        # NOTE(review): torchvision's center_crop zero-pads when skip_x is
        # smaller than the requested size - confirm that is intended here.
        spatial_size = [x.shape[-2], x.shape[-1]]
        x1 = functional.center_crop(skip_x, spatial_size)
        if x1.shape == x.shape:
            # Concatenate along the channel dimension
            return torch.cat([x, x1], dim=1)
        raise Exception('要连接的两个特征图尺寸不一致，skip_x.shape={}，x.shape={}'.format(skip_x.shape, x.shape))


class DecoderBlock(nn.Module):
    """One stage of the decoder: UpSampling -> Concat -> ConvBlock.

    :param in_channels: number of input channels of the stage
    :param out_channels: number of output channels of the stage
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up_sample = UpSampling(in_channels, out_channels)
        # The ConvBlock consumes the concatenated map, hence ``in_channels`` inputs
        # (assumes the skip map also has ``out_channels`` channels and
        # in_channels == 2 * out_channels - TODO confirm).
        self.conv_block = ConvBlock(in_channels, out_channels)

    def forward(self, x, skip_x):
        """Up-sample ``x``, concatenate it with ``skip_x`` and run the ConvBlock."""
        upsampled = self.up_sample(x)
        merged = ConcatLayer()(upsampled, skip_x)
        return self.conv_block(merged)

In [5]:
from argparse import ArgumentParser
from glob import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# noinspection PyShadowingNames,SpellCheckingInspection
def train(train_loader, valid_loader, model, criterion, optimizer, total_epoch, current_epoch=0, num_classes=2,
          device='cpu'):
    """Train ``model``, validate after every epoch and save checkpoints at the end.

    Writes three files to the current directory once all epochs are done:
    ``best_model.pth`` (snapshot of the best epoch according to ``SearchBest``),
    ``last_model.pth`` (final state) and ``loss_change.pth`` (loss histories).

    :param train_loader: iterable of ``(x, y)`` training batches
    :param valid_loader: iterable of ``(x, y)`` validation batches
    :param model: network to train; moved to ``device``
    :param criterion: loss called as ``criterion(predict, y_onehot)``; moved to ``device``
    :param optimizer: optimizer over ``model``'s parameters
    :param total_epoch: epoch index to stop at (exclusive upper bound of the loop)
    :param current_epoch: epoch index to start from, default 0
    :param num_classes: number of classes for the one-hot encoding, default 2
    :param device: device string, default ``'cpu'``
    """
    import copy  # local import keeps this block self-contained

    model.to(device)
    criterion.to(device)
    loss_change_list = []
    valid_loss_change_list = []
    saved_last = {}
    saved_best = {}
    loss_change = {}

    search_best = SearchBest()

    for i in range(current_epoch, total_epoch):
        model.train()
        total_loss = 0.0
        for index, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)
            y_onehot = convert_to_one_hot(y, num_classes=num_classes)
            predict = model(x)

            loss_value = criterion(predict, y_onehot)
            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()

            total_loss += loss_value.item()
            print('Epoch {}: Batch {}/{} loss: {:.4f}'.format(i + 1, index + 1, len(train_loader), loss_value.item()))

        train_avg_loss = total_loss / len(train_loader)
        loss_change_list.append(train_avg_loss)
        valid_avg_loss = valid(model, criterion, valid_loader, num_classes, device)
        valid_loss_change_list.append(valid_avg_loss)
        print('Epoch {} train loss: {:.4f} valid loss: {:.4f}'.format(i + 1, train_avg_loss,
                                                                      valid_avg_loss))
        search_best(valid_avg_loss)
        if search_best.counter == 0:
            # Snapshot the best state seen so far. state_dict() returns references
            # to the live tensors, which later optimizer steps update in place, so
            # a deep copy is required or "best" would silently become "last".
            saved_best['best_model_state_dict'] = copy.deepcopy(model.state_dict())
            saved_best['best_optimizer_state_dict'] = copy.deepcopy(optimizer.state_dict())
            saved_best['epoch'] = i + 1
    loss_change['train_loss_change_history'] = loss_change_list
    loss_change['valid_loss_change_history'] = valid_loss_change_list
    saved_last['last_model_state_dict'] = model.state_dict()
    saved_last['last_optimizer_state_dict'] = optimizer.state_dict()
    saved_last['epoch'] = total_epoch
    torch.save(saved_best, './best_model.pth')
    torch.save(saved_last, './last_model.pth')
    torch.save(loss_change, './loss_change.pth')


def convert_to_one_hot(data, num_classes):
    """Convert a batch of class-index maps into channel-first float one-hot maps.

    :param data: tensor of class indices (any ``torch.Tensor`` subclass is accepted);
        cast to ``torch.int64`` if needed. The permute below requires the one-hot
        result to be 4-D, i.e. ``data`` must be 3-D.
    :param num_classes: number of classes, i.e. size of the new channel dimension
    :return: ``torch.float32`` tensor with the class dimension moved to position 1
    :raises RuntimeError: if ``data`` is not a tensor
    """
    # isinstance (rather than an exact type check) also admits Tensor subclasses
    # such as nn.Parameter.
    if not isinstance(data, torch.Tensor):
        raise RuntimeError('data must be a torch.Tensor')
    if data.dtype is not torch.int64:
        data = data.to(torch.int64)
    # one_hot appends the class dimension last; move it right after the batch dimension.
    data = F.one_hot(data, num_classes=num_classes).permute((0, -1, 1, 2))
    return data.to(torch.float32)


class SearchBest(object):
    """Track the best validation loss and count consecutive calls without improvement.

    ``counter`` is 0 right after an improvement (and after the very first call);
    it grows by one on every call that does not improve.

    :param min_delta: minimum decrease of the loss that counts as an improvement, default 0
    :param verbose: print a message whenever the loss did not improve, default True
    """

    def __init__(self, min_delta=0, verbose=True):
        super().__init__()
        self.counter = 0
        self.min_delta = min_delta
        self.best_score = None
        self.verbose = verbose

    def __call__(self, valid_loss):
        """Update ``best_score`` and ``counter`` with a new validation loss."""
        # First observation only initialises the reference value.
        if self.best_score is None:
            self.best_score = valid_loss
            return

        gain = self.best_score - valid_loss
        if gain >= self.min_delta:
            self.best_score = valid_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print('performance reducing: {}'.format(self.counter))


# noinspection PyShadowingNames
def valid(model, criterion, valid_loader, num_classes, device):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    :param model: network to evaluate; switched to eval mode
    :param criterion: loss called as ``criterion(predict, y_onehot)``
    :param valid_loader: iterable of ``(x, y)`` validation batches
    :param num_classes: number of classes for the one-hot encoding
    :param device: device the batches are moved to
    :return: validation loss averaged over the number of batches
    """
    model.eval()
    running_loss = 0.0
    for inputs, labels in valid_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        labels_onehot = convert_to_one_hot(labels, num_classes)
        with torch.no_grad():
            batch_loss = criterion(model(inputs), labels_onehot)
            running_loss += batch_loss.item()
    return running_loss / len(valid_loader)

In [6]:
if __name__ == '__main__':
    # Entry point: build the model, data pipelines and loaders, then train
    # from scratch or resume from a saved checkpoint.
    lr = 1e-3

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = UNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))

    # TODO: switch to a Dice loss later and compare the results
    criterion = nn.BCELoss()

    # glob() returns paths in arbitrary order; sort both lists so that the i-th
    # image is paired with the i-th mask (assumes image and mask files share
    # the same names - verify against the dataset).
    train_image_path = sorted(glob('../input/covidxray/covid19-xray/train/images/*'))
    train_mask_path = sorted(glob('../input/covidxray/covid19-xray/train/masks/*'))

    val_image_path = sorted(glob('../input/covidxray/covid19-xray/val/images/*'))
    val_mask_path = sorted(glob('../input/covidxray/covid19-xray/val/masks/*'))

    train_transforms = Compose([
        # TODO: try Resize first
        ToPILImage(),
        RandomHorizontalFlip(),
        RandomRotation(degrees=(0, 180)),
        # ColorJitter(),  # something is wrong with this one, but what exactly?
        # GrayScale(),  # this one seems problematic too, even more so
        RandomCrop(256),
        ToTensor(),
        Normalize((0.5,), (0.5,))
    ])
    valid_transforms = Compose([
        ToPILImage(),
        Resize(256),
        ToTensor(),
        Normalize((0.5,), (0.5,))
    ])

    train_loader = load_data(train_image_path, train_mask_path, batch_size=8, drop_last=True,
                             transforms=train_transforms)
    valid_loader = load_data(val_image_path, val_mask_path, batch_size=8, transforms=valid_transforms)

    continue_train = False
    torch.cuda.empty_cache()
    if not continue_train:
        epoch = 100
        train(train_loader, valid_loader, model, criterion, optimizer, epoch, device=device)
    else:
        total_epoch = 200
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
        pretrain_params = torch.load('../input/covid-xray-unet/last_model.pth', map_location=device)
        model.load_state_dict(pretrain_params['last_model_state_dict'])
        optimizer.load_state_dict(pretrain_params['last_optimizer_state_dict'])
        current_epoch = pretrain_params['epoch']
        model.train()
        train(train_loader, valid_loader, model, criterion, optimizer, total_epoch, current_epoch, device=device)

Epoch 1: Batch 1/466 loss: 0.8009
Epoch 1: Batch 2/466 loss: 0.7220
Epoch 1: Batch 3/466 loss: 0.6329
Epoch 1: Batch 4/466 loss: 0.5628
Epoch 1: Batch 5/466 loss: 0.5124
Epoch 1: Batch 6/466 loss: 0.4527
Epoch 1: Batch 7/466 loss: 0.4400
Epoch 1: Batch 8/466 loss: 0.4309
Epoch 1: Batch 9/466 loss: 0.3927
Epoch 1: Batch 10/466 loss: 0.4497
Epoch 1: Batch 11/466 loss: 0.3834
Epoch 1: Batch 12/466 loss: 0.4208
Epoch 1: Batch 13/466 loss: 0.3649
Epoch 1: Batch 14/466 loss: 0.3193
Epoch 1: Batch 15/466 loss: 0.3452
Epoch 1: Batch 16/466 loss: 0.3138
Epoch 1: Batch 17/466 loss: 0.3822
Epoch 1: Batch 18/466 loss: 0.3228
Epoch 1: Batch 19/466 loss: 0.2901
Epoch 1: Batch 20/466 loss: 0.3520
Epoch 1: Batch 21/466 loss: 0.3336
Epoch 1: Batch 22/466 loss: 0.3565
Epoch 1: Batch 23/466 loss: 0.3533
Epoch 1: Batch 24/466 loss: 0.3171
Epoch 1: Batch 25/466 loss: 0.3430
Epoch 1: Batch 26/466 loss: 0.3134
Epoch 1: Batch 27/466 loss: 0.2916
Epoch 1: Batch 28/466 loss: 0.3546
Epoch 1: Batch 29/466 loss: 0