In [1]:
# Mount Google Drive (Colab only). Every dataset / checkpoint path used later in
# this notebook lives under /content/drive; force_remount=True re-mounts even if
# a mount already exists.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2 # np.array -> torch.tensor
import os
import os.path as osp
from PIL import Image
from torchvision import transforms as T
from tqdm import tqdm
from glob import glob
import datetime
import time
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
import matplotlib

In [3]:
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with 1-pixel zero padding.

    With padding=1 the spatial size is preserved for stride=1 and becomes
    ceil(H / stride) in general.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Two-conv residual block (ResNet-18/34 style).

    conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN, added to the shortcut
    (``downsample(x)`` when provided, else ``x``), followed by ReLU.
    Output width equals ``planes * expansion`` (= ``planes``).
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional projection so the shortcut matches the main path's shape.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))

        h += shortcut
        return self.relu(h)


class Bottleneck(nn.Module):
    """Three-conv residual block (ResNet-50/101/152 style).

    1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand to
    ``planes * expansion`` channels, each followed by BN (ReLU on the first
    two), then added to the shortcut and passed through ReLU.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        wide = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection so the shortcut matches the main path's shape.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)

        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))

        h += identity
        return self.relu(h)

class ResNet(nn.Module):
    """ResNet classifier with an optional three-conv "deep base" stem.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` attribute sets the channel multiplier of each stage.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final ``fc`` layer.
        deep_base: if True the stem is three 3x3 convs (3->64->64->128 ch,
            first one stride 2); otherwise a single 7x7 stride-2 conv (3->64 ch).
    """

    def __init__(self, block, layers, num_classes=1000, deep_base=True):
        super(ResNet, self).__init__()
        self.deep_base = deep_base
        if not self.deep_base:
            self.inplanes = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
        else:
            self.inplanes = 128
            self.conv1 = conv3x3(3, 64, stride=2)
            self.bn1 = nn.BatchNorm2d(64)
            self.conv2 = conv3x3(64, 64)
            self.bn2 = nn.BatchNorm2d(64)
            self.conv3 = conv3x3(64, 128)
            self.bn3 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Channel width of layer4's output; every module after it must agree.
        out_ch = 512 * block.expansion
        # Extra residual block applied after layer4. Its input and output width
        # must both equal out_ch: a fixed Bottleneck(512, 128) only fits
        # BasicBlock nets (out_ch = 512) and makes forward() fail for
        # Bottleneck nets, whose layer4 emits 2048 channels.
        self.bottleneck = Bottleneck(out_ch, out_ch // Bottleneck.expansion)
        # Fixed 7x7 window: presumably assumes 224x224 inputs (7x7 map after layer4).
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(out_ch, num_classes)

        # He init for convs, identity init for BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        The first block carries ``stride`` and, when the shape changes, a 1x1
        conv + BN projection on its shortcut. Updates ``self.inplanes`` to the
        stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem -> maxpool -> layer1..4 -> extra bottleneck -> avgpool -> fc."""
        x = self.relu(self.bn1(self.conv1(x)))
        if self.deep_base:
            x = self.relu(self.bn2(self.conv2(x)))
            x = self.relu(self.bn3(self.conv3(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bottleneck(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet152(pretrained=True, model_path='./initmodel/resnet152_v2.pth', **kwargs):
    """Construct a ResNet-152 (Bottleneck blocks, stages of 3/8/36/3).

    Args:
        pretrained (bool): if True, load weights from ``model_path``.
        model_path (str): checkpoint file to load when ``pretrained`` is True.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``, ``deep_base``).

    Returns:
        ResNet: the constructed model (on CPU).
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        # map_location='cpu' lets a checkpoint saved from GPU load on CPU-only
        # hosts; the caller moves the model to its device afterwards.
        state_dict = torch.load(model_path, map_location='cpu')
        # strict=False: keys missing from / extra in the checkpoint are
        # skipped instead of raising.
        model.load_state_dict(state_dict, strict=False)
    return model

In [4]:
class PPM(nn.Module):
    """Pyramid Pooling Module.

    For every bin size ``b`` a branch pools the input to ``b x b``, projects it
    to ``reduction_dim`` channels (1x1 conv + BN + ReLU) and upsamples it back
    bilinearly. The output is the input concatenated with all branch outputs:
    ``in_dim + len(bins) * reduction_dim`` channels, same spatial size.
    """

    def __init__(self, in_dim, reduction_dim, bins):
        super(PPM, self).__init__()
        branches = [self._branch(in_dim, reduction_dim, size) for size in bins]
        self.features = nn.ModuleList(branches)

    @staticmethod
    def _branch(in_dim, reduction_dim, size):
        # One pyramid level: adaptive pool -> 1x1 projection -> BN -> ReLU.
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(size),
            nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(reduction_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        spatial = x.shape[2:]
        upsampled = [
            F.interpolate(branch(x), spatial, mode='bilinear', align_corners=True)
            for branch in self.features
        ]
        return torch.cat([x] + upsampled, 1)

class PSPNet(nn.Module):
    """PSPNet-style feature extractor on a dilated ResNet-152 backbone.

    No final classifier is built here: ``n_classes`` is only validated and
    ``zoom_factor`` is only validated and stored. ``forward`` returns
    intermediate feature maps for a downstream fusion head.

    Args:
        bins: output sizes of the pyramid-pooling branches.
        dropout: Dropout2d probability used in the ``cls`` and ``conv1_1`` heads.
        n_classes: must be > 1; otherwise unused.
        zoom_factor: one of 1, 2, 4, 8; stored as ``self.zoom_factor``.
    """

    def __init__(self, bins=(1, 2, 3, 6), dropout=0.15, n_classes=5, zoom_factor=8):
        super(PSPNet, self).__init__()
        assert 2048 % len(bins) == 0
        assert n_classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        self.zoom_factor = zoom_factor

        # Backbone is randomly initialised (pretrained=False); only its stem and
        # four stages are kept.
        resnet = resnet152(pretrained=False)
        self.layer0_0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.layer0_1 = nn.Sequential(resnet.conv2, resnet.bn2, resnet.relu)
        self.layer0_2 = nn.Sequential(resnet.conv3, resnet.bn3, resnet.relu,
                                        resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # Replace the stride-2 of layer3 / layer4 with dilation 2 / 4 so their
        # outputs keep layer2's resolution (output stride 8: stem conv /2,
        # maxpool /2, layer2 /2).
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        fea_dim = 2048
        # Each branch reduces 2048 -> 2048 / len(bins) channels, so the PPM
        # output [x, branch_1, ..., branch_k] has 2 * 2048 channels.
        self.ppm = PPM(fea_dim, int(fea_dim/len(bins)), bins)
        fea_dim *= 2

        # Head on the PPM output: 4096 -> 512 -> 256 channels.
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
        )

        # 1x1 refinement of the 256-channel layer1 features (256 -> 48 -> 48 -> 256).
        self.conv1_1 = nn.Sequential(nn.Conv2d(256, 48, kernel_size=1, padding=0, stride=1, bias=False),
                                    nn.BatchNorm2d(48),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(48, 48, kernel_size=1, padding=0, stride=1, bias=False),
                                    nn.BatchNorm2d(48),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(48, 256, kernel_size=1, padding=0, stride=1, bias=False),
                                    nn.BatchNorm2d(256),
                                    nn.ReLU(inplace=True),
                                    nn.Dropout2d(p=dropout),
                )

    def forward(self, x, y=None):
        """Extract features; ``y`` is accepted for interface symmetry and ignored.

        Height and width of ``x`` must both be of the form 8k + 1.

        Returns:
            tuple: (``cls`` features on the PPM output, 256 channels;
            ``conv1_1`` features on layer1's output, 256 channels;
            layer3 output, 1024 channels).
        """
        x_size = x.size()
        assert (x_size[2]-1) % 8 == 0 and (x_size[3]-1) % 8 == 0
        x0 = self.layer0_0(x)
        x0 = self.layer0_1(x0)
        x0 = self.layer0_2(x0)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x_ppm = self.ppm(x4)
        x = self.cls(x_ppm)
        x_conv1_1 = self.conv1_1(x1)

        return x, x_conv1_1, x3

In [5]:
# model = PSPNet()
# x = torch.rand(4, 3, 257, 257)

# y, conv1_1, aux = model(x)
# print('=======')
# print(y.size())
# print(conv1_1.size())
# print(aux.size())

In [6]:
class Layer(nn.Module):
    """Single dense layer in pre-activation order: BN -> ReLU -> Conv -> Dropout.

    Maps ``in_ch`` input channels to ``grow_rate`` output channels.
    """

    def __init__(self, in_ch, kernel_s=3, padding=1, stride=1, grow_rate=16, dropRate=0.2):
        super(Layer, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(in_ch)
        self.conv = nn.Conv2d(in_ch, grow_rate, kernel_size=kernel_s, stride=stride,
                              padding=padding, bias=False)
        self.droprate = nn.Dropout(p=dropRate)

    def forward(self, x):
        h = self.bn(x)
        h = self.relu(h)
        h = self.conv(h)
        return self.droprate(h)

In [7]:
class DenseBlock(nn.Module):
    """Stack of ``Layer``s with dense (concatenative) connectivity.

    Layer ``k`` receives ``in_ch + k * grow_rate`` channels: the block input
    plus the outputs of all earlier layers.

    forward returns:
        * ``Upsample=True``: only the newly produced maps
          (``n_layers * grow_rate`` channels);
        * otherwise: the input concatenated with all new maps
          (``in_ch + n_layers * grow_rate`` channels).
    """

    def __init__(self, in_ch, kernel_s=3, padding=1, stride=1, grow_rate=16, n_layers=4, Upsample=False):
        super(DenseBlock, self).__init__()
        self.upsample = Upsample
        self.layers = nn.ModuleList(
            [Layer(in_ch + k * grow_rate, kernel_s, padding, stride, grow_rate) for k in range(n_layers)]
        )
        self.n_layers = n_layers

    def forward(self, x):
        produced = []
        for layer in self.layers:
            new_maps = layer(x)
            produced.append(new_maps)
            # Running concatenation feeds every later layer.
            x = torch.cat([x, new_maps], 1)
        if self.upsample:
            return torch.cat(produced, 1)
        # x is now cat([input, new_0, ..., new_{n-1}]) along channels.
        return x


In [8]:
class TransitionDown(nn.Module):
    """BN -> ReLU -> 1x1 conv -> Dropout -> 2x2 max-pool (stride 2).

    NOTE: the 1x1 conv uses padding=1, which grows each spatial side by 2
    before pooling, so H -> (H + 2) // 2. This is unusual for FC-DenseNet, but
    the spatial alignment of other modules may depend on it; re-check those
    before changing it.
    """

    def __init__(self, in_ch, dropRate=0.15):
        super(TransitionDown, self).__init__()
        self.bn = nn.BatchNorm2d(in_ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=1, bias=False)
        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        self.droprate = nn.Dropout(p=dropRate)

    def forward(self, x):
        feats = self.relu(self.bn(x))
        feats = self.droprate(self.conv(feats))
        return self.pooling(feats)

In [9]:
def center_crop(layer, max_height, max_width):
    """Return the centred ``max_height x max_width`` window of a 4-D tensor.

    The crop is applied to dims 2 and 3; when the excess is odd, the extra
    row/column is dropped from the bottom/right.
    """
    _, _, height, width = layer.size()
    top = (height - max_height) // 2
    left = (width - max_width) // 2
    rows = slice(top, top + max_height)
    cols = slice(left, left + max_width)
    return layer[:, :, rows, cols]

class TransitionUp(nn.Module):
    """Stride-2 transposed conv, optionally merged with an encoder skip.

    Args:
        in_ch: input channels.
        out_ch: channels produced by the transposed conv.
        kernel_s: transposed-conv kernel size.
        skip: if truthy, the upsampled map is centre-cropped to the skip's
            spatial size and concatenated (upsampled first) with it.
        padding: transposed-conv padding.
    """

    def __init__(self, in_ch, out_ch, kernel_s=3, skip=True, padding=0):
        super(TransitionUp, self).__init__()
        self.transpose = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_s, stride=2, padding=padding, bias=True)
        self.skip = skip

    def forward(self, x, skip_connection):
        """Upsample ``x``; ``skip_connection`` is only read when ``self.skip`` is set."""
        x_out = self.transpose(x)
        if self.skip:
            x_out = center_crop(x_out, skip_connection.size(2), skip_connection.size(3))
            x_out = torch.cat([x_out, skip_connection], 1)
        return x_out

In [10]:
class FCDenseNet(nn.Module):
    """FC-DenseNet (Tiramisu-style) encoder/decoder.

    Encoder: ``first_conv``, then for each entry of ``down_blocks`` a
    ``DenseBlock`` followed by a ``TransitionDown``; then a bottleneck
    ``DenseBlock``. Decoder: ``len(up_blocks)`` stages, each a ``TransitionUp``
    (concatenating the matching encoder skip) followed by a ``DenseBlock``.
    Only the skips of encoder stages 2 .. len(down_blocks)-1 are reused.

    Args:
        in_ch: input channels.
        down_blocks: number of dense layers per encoder DenseBlock.
        up_blocks: number of dense layers per decoder DenseBlock; must have
            exactly ``len(down_blocks) - 2`` entries.
        bottleneck_layers: number of dense layers in the bottleneck.
        grow_rate: channels added by each dense layer.
        kernel_s, padding: conv settings forwarded to every DenseBlock.
        stride: conv stride forwarded to the encoder DenseBlocks only.
        m: channels produced by ``first_conv``.
        dilation, n_classes: accepted for interface compatibility; unused.
    """

    def __init__(self, in_ch=3, down_blocks=(5, 5, 5, 5, 5),
                 up_blocks=(5, 5, 5), bottleneck_layers=5,
                 grow_rate=16, kernel_s=3, padding=1, dilation=1, stride=1, m=48, n_classes=5):
        super(FCDenseNet, self).__init__()
        # forward() pops len(up_blocks) skips from the end of the encoder, while
        # the channel bookkeeping below records skips of stages 2.. only; both
        # agree only when this holds.
        assert len(up_blocks) == len(down_blocks) - 2, \
            "FCDenseNet requires len(up_blocks) == len(down_blocks) - 2"
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks

        #   First convolution   #
        #########################
        self.first_conv = nn.Conv2d(in_channels=in_ch, out_channels=m, kernel_size=3, stride=1, padding=1, bias=True)

        current_ch = m
        # Channel counts of the reused encoder skips, deepest first.
        skip_ch = []
        #   Downsampling path   #
        #########################
        self.DB_down_1 = nn.ModuleList([])
        self.TD_1 = nn.ModuleList([])
        for i in range(len(down_blocks)):
            # upsample=False: the block returns input + new maps.
            self.DB_down_1.append(DenseBlock(current_ch, kernel_s, padding, stride, grow_rate, down_blocks[i], False))
            current_ch += (down_blocks[i] * grow_rate)
            if i > 1:
                skip_ch.insert(0, current_ch)
            self.TD_1.append(TransitionDown(current_ch))

        #   Bottleneck   #
        ##################
        # upsample=True: the block emits only its new maps, i.e.
        # bottleneck_layers * grow_rate channels.
        self.bottleneck = DenseBlock(current_ch, kernel_s, padding, 1, grow_rate, bottleneck_layers, True)
        prev_ch = (bottleneck_layers * grow_rate)

        #   Upsampling path   #
        #######################
        self.DB_up_1 = nn.ModuleList([])
        self.TU_1 = nn.ModuleList([])
        for i in range(len(up_blocks) - 1):
            kernel_tu = 3
            self.TU_1.append(TransitionUp(prev_ch, prev_ch, kernel_tu))
            # TransitionUp concatenates [upsampled (prev_ch), skip].
            current_ch = prev_ch + skip_ch[i]
            self.DB_up_1.append(DenseBlock(current_ch, kernel_s, padding, 1, grow_rate, up_blocks[i], True))
            prev_ch = grow_rate * up_blocks[i]

        #   Final DenseBlock    #
        #########################
        # Last decoder block keeps its input (upsample=False), so the network
        # output has prev_ch + skip_ch[-1] + grow_rate * up_blocks[-1] channels.
        self.TU_1.append(TransitionUp(prev_ch, prev_ch, 3))
        current_ch = prev_ch + skip_ch[-1]
        self.DB_up_1.append(DenseBlock(current_ch, kernel_s, padding, 1, grow_rate, up_blocks[-1], False))

    def forward(self, x, y=None):
        """Run encoder, bottleneck and decoder; ``y`` is accepted and ignored.

        Returns:
            tuple: (decoder output, feature map taken right after the third
            TransitionDown, used as an auxiliary input by the caller).
        """
        x_out = self.first_conv(x)
        skip_connections = []
        aux_in = None

        for i in range(len(self.down_blocks)):
            x_out = self.DB_down_1[i](x_out)
            skip_connections.append(x_out)
            x_out = self.TD_1[i](x_out)
            if i == 2:
                aux_in = x_out

        x_out = self.bottleneck(x_out)

        for i in range(len(self.up_blocks)):
            # Deepest unused encoder skip first.
            skip = skip_connections.pop()
            x_out = self.TU_1[i](x_out, skip)
            x_out = self.DB_up_1[i](x_out)

        return x_out, aux_in


In [11]:
# model = FCDenseNet()
# x = torch.rand(4, 3, 256, 256)

# y, aux = model(x)
# print('===========')
# print(y.size())
# print(aux.size())

In [12]:
class Model(nn.Module):
    """Two-branch segmentation network fusing PSPNet and FC-DenseNet features.

    forward(x, y=None):
        * eval mode: returns logits of shape (B, n_classes, H, W), resized to
          the input's spatial size;
        * train mode: requires integer targets ``y`` of shape (B, H, W) and
          returns ``(argmax prediction, main CE loss, auxiliary CE loss)``.

    Args:
        n_classes: number of output classes (``class_weights`` below has 6
            entries, so the losses assume n_classes == 6).
        dropout: dropout probability of the final and auxiliary heads.
    """

    def __init__(self, n_classes=6, dropout=0.15):
        super(Model, self).__init__()
        self.pspnet = PSPNet(n_classes=n_classes)
        self.fcdense = FCDenseNet(n_classes=n_classes)

        # 2x transposed-conv upsampling of the 256-channel PSPNet head (no skip).
        self.p_upsampling = TransitionUp(256, 256, 3, False, padding=1)
        # 960 = 256 (upsampled PSPNet head) + 256 (PSPNet conv1_1)
        #     + 448 (FCDenseNet output with its default configuration).
        self.ASC1 = nn.Sequential(
                                nn.Conv2d(960, 256, kernel_size=3, stride=1, padding=1, dilation=1, bias=False),
                                nn.BatchNorm2d(256),
                                nn.ReLU(inplace=True)
        )
        self.ASC2 = nn.Sequential(
                                nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation=1, bias=False),
                                nn.BatchNorm2d(256),
                                nn.ReLU(inplace=True)
        )
        self.finalconv = nn.Sequential(
                                nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                nn.BatchNorm2d(256),
                                nn.ReLU(inplace=True),
                                nn.Dropout(p=dropout),
                                nn.Conv2d(256, n_classes, kernel_size=1, stride=1, padding=0, bias=False)
        )

        # Per-class CE weights; class 2 has weight 0, so it never contributes
        # to the loss. Label 255 is ignored entirely.
        class_weights = torch.tensor([0.348, 3.786, 0, 1.394, 0.893, 37.098], dtype=torch.float)
        self.criterion = nn.CrossEntropyLoss(ignore_index=255, weight=class_weights)
        # Auxiliary head, always built so the state_dict layout is fixed.
        # 1312 = 1024 (PSPNet layer3) + 288 (FCDenseNet aux map, default config).
        self.aux = nn.Sequential(
            nn.Conv2d(1312, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(256, n_classes, kernel_size=1)
        )

    def forward(self, x, y=None):
        _, _, h, w = x.size()
        # PSPNet requires side lengths of the form 8k + 1.
        x_p = F.interpolate(x, size=(h+1, w+1), mode='bilinear', align_corners=True)
        p_out, p_conv1, p_aux = self.pspnet(x_p)
        d_out, d_aux = self.fcdense(x)

        p_out = self.p_upsampling(p_out, None)

        out = torch.cat([p_out, p_conv1, d_out], 1)
        out = self.ASC1(out)
        out = self.ASC2(out)
        out = self.finalconv(out)

        # Resize logits to the input resolution so they line up with `y`
        # (same target size as the auxiliary branch below).
        out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True)

        if not self.training:
            return out

        if y is None:
            raise ValueError("Model.forward needs targets `y` in training mode")
        aux_in = torch.cat([p_aux, d_aux], 1)
        aux = self.aux(aux_in)
        aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
        aux_loss = self.criterion(aux, y)
        main_loss = self.criterion(out, y)
        return out.max(1)[1], main_loss, aux_loss

In [13]:
class Gleason(Dataset):
    """Image / segmentation-mask dataset read from two directories.

    Every file in ``imgdir`` is treated as an image; its mask is expected in
    ``maskdir`` under the same name with ``.jpg`` replaced by
    ``_classimg_nonconvex.png``.

    Args:
        imgdir: directory with the input images.
        maskdir: directory with the masks (not read when ``test=True``).
        train, val: stored as attributes; not otherwise used here.
        test: if True, ``__getitem__`` returns only the image.
        transform: callable invoked as ``transform(image=...)`` or
            ``transform(image=..., mask=...)`` and returning a dict with
            ``"image"`` (and ``"mask"``) keys.
        target_transform: optional callable applied to the mask afterwards.
    """

    def __init__(self, imgdir, maskdir=None, train=True, val=False,
                 test=False, transform=None, target_transform=None):
        super(Gleason, self).__init__()
        self.imgdir = imgdir
        self.maskdir = maskdir
        # os.listdir order is arbitrary; sort so an index always maps to the
        # same file across runs.
        self.imglist = sorted(os.listdir(imgdir))
        if not test:
            self.masklist = [item.replace('.jpg', '_classimg_nonconvex.png') for item in self.imglist]
        else:
            self.masklist = []

        self.train = train
        self.val = val
        self.test = test
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.imglist)

    @staticmethod
    def _read(path):
        # Context manager closes the file handle once the pixels are copied.
        with Image.open(path) as img:
            return np.array(img)

    def __getitem__(self, idx):
        image = self._read(os.path.join(self.imgdir, self.imglist[idx]))
        if self.test:
            # No mask in test mode; the transform is optional.
            if self.transform:
                image = self.transform(image=image)["image"]
            return image
        mask = self._read(os.path.join(self.maskdir, self.masklist[idx]))
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
        if self.target_transform:
            mask = self.target_transform(mask)
        return image, mask

In [14]:
def get_dataset(imgdir, maskdir=None, train=True, val=False, test=False,
                transform=None, target_transform=None):
    """Factory returning a ``Gleason`` dataset built from the given options."""
    return Gleason(
        imgdir,
        maskdir=maskdir,
        train=train,
        val=val,
        test=test,
        transform=transform,
        target_transform=target_transform,
    )


def get_transform(train):
    """Build the albumentations pipeline for training or evaluation.

    Both pipelines resize to 256x256, normalise with ImageNet mean/std and
    convert to a CHW tensor. The training pipeline additionally applies a
    random horizontal flip and a random rotation in [-15, 15] degrees.

    Args:
        train (bool): select the augmenting (True) or plain (False) pipeline.

    Returns:
        A.Compose: callable used as ``t(image=..., mask=...)``.
    """
    if train:
        return A.Compose([
            A.Resize(width=256, height=256, interpolation=cv2.INTER_LINEAR),
            A.HorizontalFlip(),
            # limit is (min_angle, max_angle); (-15, -15) would always rotate
            # by exactly -15 degrees instead of sampling a random angle.
            A.geometric.rotate.Rotate(limit=(-15, 15), value=(255,255,255), p=1),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),
            ToTensorV2(),
        ])
    return A.Compose([
        A.Resize(width=256, height=256, interpolation=cv2.INTER_LINEAR),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),
        ToTensorV2(),  # numpy HWC array -> torch CHW tensor
    ])


In [15]:
class UnNormalize(object):
    """Invert a per-channel ``Normalize``: map each channel ``c`` to ``c * std + mean``.

    The input tensor is modified in place.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Undo per-channel normalisation.

        Args:
            tensor (Tensor): normalised image of size (C, H, W); modified in place.
        Returns:
            Tensor: the same tensor object, now de-normalised.
        """
        # Inverse of the forward normalisation t.sub_(m).div_(s).
        for plane, mu, sigma in zip(tensor, self.mean, self.std):
            plane.mul_(sigma)
            plane.add_(mu)
        return tensor


unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

class AverageMeter(object):
    """Track the latest value and a count-weighted running sum / average.

    Works with scalars or tensors (``sum`` starts as int 0 and takes on the
    type of ``val * n`` after the first update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [16]:
#metrics
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    """Per-class pixel counts for segmentation metrics.

    Args:
        output: predicted class indices, any shape.
        target: ground-truth class indices, same shape as ``output``.
        K: number of classes; only values in [0, K-1] are counted.
        ignore_index: label in ``target`` whose pixels are excluded.

    Returns:
        tuple of three float32 tensors of length K:
        (area_intersection, area_union, area_target).

    ``output`` is not modified (a float copy is made). Inputs are cast to
    float because ``torch.histc`` is not implemented for integer tensors on CPU.
    """
    assert output.shape == target.shape
    output = output.to(torch.float32, copy=True)
    target = target.to(torch.float32)
    # Ignored pixels get ignore_index in the prediction too; being outside
    # [0, K-1], they fall outside every histogram bin.
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1)
    area_output = torch.histc(output, bins=K, min=0, max=K-1)
    area_target = torch.histc(target, bins=K, min=0, max=K-1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target

In [17]:
# Prefer the first GPU when present; everything below runs on `device`.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = Model(n_classes=6)
model = model.to(device)

In [18]:
# img = os.listdir('/content/drive/MyDrive/Dataset/zoomx20/Image2')
# mask = os.listdir('/content/drive/MyDrive/Dataset/zoomx20/Mask2')
# a = []

# for i in img:
#   t = i.replace('.jpg', '_classimg_nonconvex.png')
#   if t not in mask:
#     a.append(i)

In [19]:
#load data
batch_size = 7
n_workers = os.cpu_count()
print("num_workers =", n_workers)
train_dataset = get_dataset(imgdir='/content/drive/MyDrive/Dataset/zoomx20/Image2',
                        maskdir='/content/drive/MyDrive/Dataset/zoomx20/Mask2',
                        train=True, val=False, test=False, transform=get_transform(train=False))

trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

#optimizer
base_lr = 1e-3
n_eps = 21
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / (len(trainloader) * n_eps)) ** 0.9)

#meter
train_loss_meter = AverageMeter()
intersection_meter = AverageMeter()
union_meter = AverageMeter()
target_meter = AverageMeter()

num_workers = 2


In [20]:
# Resume from a saved checkpoint; map_location lets weights saved on GPU load on
# whichever `device` is available now.
model.load_state_dict(torch.load('/content/drive/MyDrive/Model/Savemodel/Model_ep9.pth', map_location=device))

<All keys matched successfully>

In [None]:
start_time = time.time()
# Training loop.
# NOTE(review): starts at epoch 11 after loading Model_ep9.pth, so epoch 10 is
# skipped; confirm the intended resume epoch.
for ep in range(11, n_eps):
    for meter in (train_loss_meter, intersection_meter, union_meter, target_meter):
        meter.reset()
    model.train()
    for x, y in tqdm(trainloader):
        # One optimisation step per mini-batch.
        optimizer.zero_grad()
        n = x.shape[0]
        x = x.to(device).float()
        y = y.to(device).long()
        y_hat_mask, main_loss, aux_loss = model(x, y)
        loss = main_loss + aux_loss * 0.4
        loss.backward()
        optimizer.step()
        # The poly schedule is defined per iteration, not per epoch.
        lr_scheduler.step()
        # Accumulate metrics.
        with torch.no_grad():
            # Weight by batch size so a smaller last batch is not over-counted.
            train_loss_meter.update(loss.item(), n)
            intersection, union, target = intersectionAndUnionGPU(y_hat_mask.float(), y.float(), 6)
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)
    # Epoch-level IoU / Dice / pixel accuracy over all batches.
    with torch.no_grad():
        # Use .sum (all batches), not .val (last batch only).
        accuracy = torch.sum(intersection_meter.sum) / (torch.sum(target_meter.sum) + 1e-10)
        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)  # per-class, 6 entries
        # Dice = 2|A∩B| / (|A| + |B|) and |A| + |B| = union + intersection.
        dice_class = (2 * intersection_meter.sum) / (intersection_meter.sum + union_meter.sum + 1e-10)
        mIoU = torch.mean(iou_class)
        mDice = torch.mean(dice_class)

    print(f"\nEP {ep}, train loss = {train_loss_meter.avg}, mIoU = {round(mIoU.item(), 4)}, mDice = {round(mDice.item(), 4)}, Accuracy = {round(accuracy.item(), 4)}")
    if ep % 5 == 0:
        torch.save(model.state_dict(), f'/content/drive/MyDrive/Model/Savemodel/Model_ep{ep}.pth')

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))

 44%|████▍     | 190/429 [15:04<17:51,  4.49s/it]

**<center>EVALUATE THE MODEL ON THE VALIDATION DATASET</center>**

In [None]:
# Validation split: masks are loaded and no augmentation is applied.
Val_train = get_dataset(
    imgdir='/content/drive/MyDrive/MyProject/Val/val_dataset',
    maskdir='/content/drive/MyDrive/MyProject/Val/val_mask',
    train=False, val=True, test=False,
    transform=get_transform(train=False),
)
Valloader = torch.utils.data.DataLoader(
    Val_train,
    batch_size=batch_size,
    shuffle=False,
    num_workers=n_workers,
)

In [None]:
# Evaluate on the validation set: accumulate per-class counts over all batches.
model.eval()
test_intersection_meter = AverageMeter()
test_union_meter = AverageMeter()
test_target_meter = AverageMeter()
with torch.no_grad():
    for x, y in tqdm(Valloader):
        x = x.to(device).float()
        y = y.to(device).long()
        y_hat = model(x)                  # eval mode: (B, n_classes, H', W') logits
        y_hat_mask = y_hat.argmax(dim=1)  # predicted label per pixel

        # The histc-based metric needs floating-point inputs.
        intersection, union, target = intersectionAndUnionGPU(y_hat_mask.float(), y.float(), 6)
        test_intersection_meter.update(intersection)
        test_union_meter.update(union)
        test_target_meter.update(target)

    # Aggregate over the whole validation set (.sum), not just the last batch (.val).
    accuracy = torch.sum(test_intersection_meter.sum) / (torch.sum(test_target_meter.sum) + 1e-10)
    iou_class = test_intersection_meter.sum / (test_union_meter.sum + 1e-10)
    dice_class = 2*test_intersection_meter.sum / (test_intersection_meter.sum + test_union_meter.sum + 1e-10)
    mIoU = torch.mean(iou_class)
    mDice = torch.mean(dice_class)

print(f"\nTEST: mIoU = {round(mIoU.item(), 4)}, mDice = {round(mDice.item(), 4)}, Accuracy = {round(accuracy.item(), 4)}")

**<center>SHOW THE PREDICTION FOR ONE IMAGE</center>**

In [None]:
# Predict on one random image from a single validation core and visualise it.
test_dataset = get_dataset(imgdir='/content/drive/MyDrive/MyProject/Val/slide002_core037/img',
                        maskdir='/content/drive/MyDrive/MyProject/Val/slide002_core037/mask',
                        train=False,
                        val=True,
                        test=False,
                        transform=get_transform(train=False))
# Draw an index that is valid for this dataset's actual size.
sample_idx = np.random.randint(len(test_dataset))
with torch.no_grad():
    model.eval()

    x, y = test_dataset[sample_idx]
    xs = x.unsqueeze(0).to(device).float()
    y = y.to(device).long()
    y_hat = model(xs).argmax(dim=1).squeeze(0)

    # Single image: use the counts directly. histc-based metric needs floats.
    intersection, union, target = intersectionAndUnionGPU(y_hat.float(), y.float(), 6)
    accuracy = intersection.sum() / (target.sum() + 1e-10)
    iou_class = intersection / (union + 1e-10)
    dice_class = 2 * intersection / (intersection + union + 1e-10)
    mIoU = torch.mean(iou_class)
    mDice = torch.mean(dice_class)

    print(f"\nTEST: mIoU = {round(mIoU.item(), 4)}, mDice = {round(mDice.item(), 4)}, "
          f"Accuracy = {round(accuracy.item(), 4)}")

    y_hat = y_hat.cpu().numpy()
    y = y.cpu().numpy()
    # x is the normalised CHW tensor: undo the normalisation on a copy and
    # convert it to an HWC uint8 RGB image for display.
    image = unorm(x.clone()).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    image = (image * 255).astype(np.uint8)

    # NOTE(review): classes 2 and 3 share the colour "blue"; confirm intended.
    colors = ["gray", "green", "blue", "blue", "gold", "red"]
    class_cmap = matplotlib.colors.ListedColormap(colors)
    fig, axes = plt.subplots(1, 3, figsize=(10, 6))
    axes[0].imshow(image)
    axes[0].set_title("Image")
    axes[1].imshow(y, cmap=class_cmap, interpolation="none", vmin=0, vmax=5)
    axes[1].set_title("Ground truth")
    axes[2].imshow(y_hat, cmap=class_cmap, interpolation="none", vmin=0, vmax=5)
    axes[2].set_title("Prediction")
    for ax in axes:
        ax.axis("off")
    plt.show()