In [2]:
# Mount Google Drive into the Colab runtime. The dataset, the pretrained
# backbone weights and the saved checkpoints used by later cells all live
# under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2 # np.array -> torch.tensor
import os
import os.path as osp
from PIL import Image
from torchvision import transforms as T
from torchvision.transforms import functional as Ft

from tqdm import tqdm
from glob import glob
import datetime
import time
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
import matplotlib

In [4]:
# VGG Cell
class conv(nn.Module):
    """VGG-style stage: ``num`` stacked Conv3x3 -> BatchNorm -> ReLU -> Dropout units.

    Args:
        in_ch: channels of the incoming feature map (used by the first unit only).
        out_ch: channels produced by every unit.
        num: number of units requested; at least one unit is always created.
        stride, padding, dilation: forwarded unchanged to every ``nn.Conv2d``.
    """

    def __init__(self, in_ch, out_ch, num, stride=1, padding=1, dilation=1):
        super().__init__()

        def make_unit(channels_in):
            # One bias-free 3x3 convolution followed by BN, ReLU and light dropout.
            return nn.Sequential(
                nn.Conv2d(channels_in, out_ch, 3, stride=stride, padding=padding,
                          dilation=dilation, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout(p=0.15),
            )

        # The first unit maps in_ch -> out_ch; the remaining ones keep out_ch.
        units = [make_unit(in_ch)]
        for _ in range(num - 1):
            units.append(make_unit(out_ch))
        self.conv = nn.ModuleList(units)

    def forward(self, x):
        """Run ``x`` through every unit in order and return the result."""
        for unit in self.conv:
            x = unit(x)
        return x

In [5]:
class FCN8s(nn.Module):
    """FCN-8s style branch on a VGG-like backbone assembled from ``conv`` stages.

    The head fuses the coarse score map with skip connections taken after
    ``pool4`` and ``pool3`` and upsamples the result by 8. Every score /
    upsampling layer works with 128 channels, so the output is a 128-channel
    feature map rather than per-class logits.

    Args:
        n_classes: stored as ``self.n_class``; no layer is sized from it.

    ``forward`` returns ``(out, p3)``: the 128-channel map and the ``pool3``
    features (256 channels).
    """

    def __init__(self, n_classes=6):
        super().__init__()
        self.n_class = n_classes

        def halve():
            # 2x2 max pooling that halves the spatial resolution.
            return nn.MaxPool2d(kernel_size=2, stride=2)

        def score(in_channels):
            # 1x1 projection of a feature map onto the 128-channel score space.
            return nn.Conv2d(in_channels=in_channels, out_channels=128,
                             kernel_size=1)

        def upsample(kernel_size, stride, padding):
            # Learned upsampling that keeps 128 channels.
            return nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                      kernel_size=kernel_size, stride=stride,
                                      padding=padding)

        # VGG backbone. Shape comments assume a [N,3,256,256] input.
        self.conv1 = conv(3, 64, 2)      # -> [N,64,256,256]
        self.pool1 = halve()             # -> [N,64,128,128]

        self.conv2 = conv(64, 128, 2)    # -> [N,128,128,128]
        self.pool2 = halve()             # -> [N,128,64,64]

        self.conv3 = conv(128, 256, 3)   # -> [N,256,64,64]
        self.pool3 = halve()             # -> [N,256,32,32]

        # The last two stages use dilation 2 (padding 2 keeps the size).
        self.conv4 = conv(256, 512, 3, stride=1, padding=2, dilation=2)  # -> [N,512,32,32]
        self.pool4 = halve()             # -> [N,512,16,16]

        self.conv5 = conv(512, 512, 3, stride=1, padding=2, dilation=2)  # -> [N,512,16,16]
        self.pool5 = halve()             # -> [N,512,8,8]

        # Convolutional replacements of the VGG fully connected layers.
        self.conv6 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=4096,
                      kernel_size=7, padding=3),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.conv7 = nn.Sequential(
            nn.Conv2d(in_channels=4096, out_channels=4096,
                      kernel_size=1),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )

        # Score / upsampling head.
        self.score_fr = score(4096)                                   # -> [N,128,8,8]
        self.upscore2 = upsample(kernel_size=4, stride=2, padding=1)  # x2 -> 16x16

        self.score_pool4 = score(512)

        self.upscore_pool4 = upsample(kernel_size=4, stride=2, padding=1)  # x2 -> 32x32
        self.score_pool3 = score(256)

        self.upscore8 = upsample(kernel_size=16, stride=8, padding=4)      # x8 -> 256x256
        self.finalconv = nn.Sequential(nn.Conv2d(128, 128, kernel_size=1, padding=0))

    def forward(self, x):
        """Return ``(out, p3)`` for an image batch ``x``."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
            (self.conv5, self.pool5),
        )
        pooled = []
        feat = x
        for block, pool in stages:
            feat = pool(block(feat))
            pooled.append(feat)
        p3, p4, p5 = pooled[2], pooled[3], pooled[4]

        # Coarsest prediction, upsampled x2 to the pool4 resolution.
        coarse = self.upscore2(self.score_fr(self.conv7(self.conv6(p5))))

        # Fuse with pool4, upsample x2 to the pool3 resolution.
        fused4 = self.score_pool4(p4) + coarse
        up4 = self.upscore_pool4(fused4)

        # Fuse with pool3, then upsample x8.
        fused3 = self.score_pool3(p3) + up4
        up3 = self.upscore8(fused3)

        return self.finalconv(up3), p3

In [6]:
# model = FCN8s()
# x = torch.rand(4, 3, 256, 256)
# y, aux = model(x)
# print(aux.size())

In [7]:
def conv3x3(in_ch, out_ch, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    layer_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_ch, out_ch, **layer_options)


class BasicBlock(nn.Module):
    """Two-convolution residual block (channel expansion factor 1).

    Args:
        in_ch: input channels.
        out_ch: output channels of both 3x3 convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, in_ch, out_ch, stride=1, downsample=None):
        super().__init__()
        # First 3x3 conv carries the stride; the second keeps the resolution.
        self.conv1 = conv3x3(in_ch, out_ch, stride)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Compute ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``out_ch * 4`` channels.

    Args:
        in_ch: input channels.
        out_ch: width of the two inner convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    expansion = 4

    def __init__(self, in_ch, out_ch, stride=1, downsample=None):
        super().__init__()
        wide_ch = out_ch * self.expansion
        # 1x1 reduce.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        # 3x3 spatial convolution (carries the stride).
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # 1x1 expand.
        self.conv3 = nn.Conv2d(out_ch, wide_ch, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide_ch)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet classifier with an optional "deep" three-convolution stem.

    Args:
        block: residual block class; it must expose an ``expansion`` class
            attribute and accept ``(in_ch, out_ch, stride, downsample)``.
        layers: four integers, the number of blocks in layer1..layer4.
        num_classes: output size of the final linear layer.
        deep_base: when True the stem is three 3x3 convolutions (64, 64, 128
            channels); otherwise a single 7x7 convolution with 64 channels.
    """

    def __init__(self, block, layers, num_classes=1000, deep_base=True):
        super(ResNet, self).__init__()
        self.deep_base = deep_base
        if not self.deep_base:
            self.in_ch = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
        else:
            self.in_ch = 128
            self.conv1 = conv3x3(3, 64, stride=2)
            self.bn1 = nn.BatchNorm2d(64)
            self.conv2 = conv3x3(64, 64)
            self.bn2 = nn.BatchNorm2d(64)
            self.conv3 = conv3x3(64, 128)
            self.bn3 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pooling. The previous fixed AvgPool2d(7) only produced
        # a 1x1 map for 7x7 features (i.e. 224x224 inputs); adaptive pooling
        # gives the same result there and also works for other input sizes.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He initialisation for convolutions, identity initialisation for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, out_ch, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one may change shape."""
        downsample = None
        # A projection shortcut is needed whenever the first block changes the
        # resolution or the channel count.
        if stride != 1 or self.in_ch != out_ch * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_ch, out_ch * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch * block.expansion),
            )

        layers = [block(self.in_ch, out_ch, stride, downsample)]
        self.in_ch = out_ch * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_ch, out_ch))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores of shape ``(N, num_classes)``."""
        x = self.relu(self.bn1(self.conv1(x)))
        if self.deep_base:
            x = self.relu(self.bn2(self.conv2(x)))
            x = self.relu(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


def resnet152(pretrained=True,
              model_path='/content/drive/MyDrive/Dataset/resnet152_v2.pth',
              **kwargs):
    """Construct a ResNet-152 (``Bottleneck`` blocks, layout [3, 8, 36, 3]).

    Args:
        pretrained (bool): if True, load weights from ``model_path``.
        model_path (str): checkpoint file holding a state dict; defaults to
            the copy stored on the mounted Google Drive.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``, ``deep_base``).

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        import warnings

        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        # map_location='cpu' lets a checkpoint that stores CUDA tensors load on
        # a CPU-only runtime; the caller moves the model to its device later.
        state_dict = torch.load(model_path, map_location='cpu')
        # strict=False tolerates key differences, so report them instead of
        # silently leaving layers at their random initialisation.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                'resnet152: checkpoint loaded non-strictly '
                f'({len(result.missing_keys)} missing, '
                f'{len(result.unexpected_keys)} unexpected keys)'
            )
    return model

In [8]:
class PPM(nn.Module):
    """Pyramid pooling module.

    For every entry of ``bins`` the input is average-pooled to that grid size,
    projected to ``reduction_dim`` channels (1x1 conv + BN + ReLU) and resized
    back to the input resolution. The input and all branches are concatenated
    along the channel axis.

    Args:
        in_dim: channels of the input feature map.
        reduction_dim: channels produced by each pooled branch.
        bins: iterable of output grid sizes for the adaptive pooling.
    """

    def __init__(self, in_dim, reduction_dim, bins):
        super().__init__()
        branches = [
            nn.Sequential(
                nn.AdaptiveAvgPool2d(grid),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim),
                nn.ReLU(inplace=True),
            )
            for grid in bins
        ]
        self.features = nn.ModuleList(branches)

    def forward(self, x):
        """Return ``x`` concatenated with every upsampled pyramid branch."""
        target_hw = x.size()[2:]
        pyramid = [
            F.interpolate(branch(x), target_hw, mode='bilinear', align_corners=True)
            for branch in self.features
        ]
        return torch.cat([x] + pyramid, 1)

class PSPNet(nn.Module):
    """PSPNet-style branch: ResNet-152 encoder + pyramid pooling + 128-channel head.

    Args:
        bins: grid sizes of the pyramid pooling module.
        dropout: drop probability of the ``Dropout2d`` in the head.
        n_classes: only validated (> 1); the head always emits 128 channels.
        zoom_factor: validated and stored; not used by ``forward``.
        pretrained: forwarded to ``resnet152``; pass False to build the
            network without reading the checkpoint file.

    ``forward`` returns ``(x, x2)``: the 128-channel map resized to the input
    resolution and the output of encoder ``layer2``.
    """

    def __init__(self, bins=(1, 2, 3, 6), dropout=0.15, n_classes=6, zoom_factor=8,
                 pretrained=True):
        super(PSPNet, self).__init__()
        assert 2048 % len(bins) == 0
        assert n_classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        self.zoom_factor = zoom_factor

        resnet = resnet152(pretrained=pretrained)
        # Deep stem split into three pieces; the last one ends with the max pool.
        self.layer0_0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.layer0_1 = nn.Sequential(resnet.conv2, resnet.bn2, resnet.relu)
        self.layer0_2 = nn.Sequential(resnet.conv3, resnet.bn3, resnet.relu,
                                      resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # Turn layer4 into a dilated stage: its 3x3 convs get dilation 2 and
        # all strides become 1, so layer4 no longer reduces the resolution.
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        fea_dim = 2048
        self.ppm = PPM(fea_dim, int(fea_dim / len(bins)), bins)
        # The PPM concatenates its input with len(bins) branches of
        # fea_dim / len(bins) channels each, doubling the channel count.
        fea_dim *= 2
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(512, 128, kernel_size=1)
        )
        self.finalconv = nn.Conv2d(128, 128, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        """Return ``(features_at_input_resolution, layer2_features)``."""
        _, _, h, w = x.size()
        x0_0 = self.layer0_0(x)
        x0_1 = self.layer0_1(x0_0)
        x0_2 = self.layer0_2(x0_1)
        x1 = self.layer1(x0_2)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x_ppm = self.ppm(x4)
        x_cls = self.cls(x_ppm)
        # Resize the head output back to the input resolution.
        x = F.interpolate(x_cls, size=(h, w), mode='bilinear', align_corners=True)
        x = self.finalconv(x)
        return x, x2

In [9]:
class FocalLoss(nn.Module):
    """Focal loss on raw logits: ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``.

    Args:
        gamma: focusing exponent; 0 reduces the loss to (weighted) cross entropy.
        alpha: optional class weights. A float/int ``a`` becomes ``[a, 1 - a]``;
            a list is used as one weight per class.
        size_average: return the mean over all elements if True, else the sum.

    The modulating factor ``(1 - p_t) ** gamma`` is computed from a detached
    probability, so gradients flow only through ``log(p_t)``.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        elif isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Compute the loss for logits ``input`` and integer class ids ``target``."""
        if input.dim() > 2:
            # (N, C, d1, d2, ...) -> (N * d1 * d2 * ..., C)
            batch, channels = input.size(0), input.size(1)
            input = input.view(batch, channels, -1).permute(0, 2, 1)
            input = input.contiguous().view(-1, channels)
        target = target.view(-1, 1)

        # Log-probability of the ground-truth class for every element.
        log_probs = F.log_softmax(input, dim=1)
        logpt = log_probs.gather(1, target).view(-1)
        # Detached probability used by the modulating factor.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Keep the class weights on the same tensor type as the logits.
            if self.alpha.type() != input.detach().type():
                self.alpha = self.alpha.type_as(input.detach())
            class_weight = self.alpha.gather(0, target.detach().view(-1))
            logpt = logpt * class_weight

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

In [10]:
class MainModel(nn.Module):
    """Two-branch segmentation model fusing ``PSPNet`` and ``FCN8s`` features.

    Each branch produces a 128-channel map at input resolution; the two maps
    are concatenated (256 channels) and turned into ``n_classes`` logits by a
    small convolutional head. An auxiliary head on intermediate features of
    both branches provides a deep-supervision loss during training.

    Args:
        n_classes: number of output classes.
        lossFunc: loss module called as ``lossFunc(logits, target)``. When
            ``None`` (default) a fresh ``FocalLoss()`` is created for this
            instance, so models never share one loss object.
    """

    def __init__(self, n_classes=6, lossFunc=None):
        super(MainModel, self).__init__()
        self.criterion = FocalLoss() if lossFunc is None else lossFunc
        # Auxiliary head. It is always built (a module is in training mode
        # right after construction) but only used by forward() when training.
        # 768 input channels = the two auxiliary feature maps concatenated.
        self.aux = nn.Sequential(
            nn.Conv2d(768, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.15),
            nn.Conv2d(256, n_classes, kernel_size=1)
        )

        self.pspnet = PSPNet()
        self.fcn8s = FCN8s()
        # Fusion head: 128 + 128 branch channels -> n_classes logits.
        self.conv = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, n_classes, kernel_size=1, stride=1, dilation=1, padding=0, bias=False),
        )

    def forward(self, x, y=None):
        """Run the model.

        Args:
            x: image batch.
            y: target class-index mask; required in training mode.

        Returns:
            Training mode: ``(pred, main_loss, aux_loss)`` where ``pred`` is the
            per-pixel argmax of the logits.
            Eval mode: the logits tensor.

        Raises:
            ValueError: if the module is in training mode and ``y`` is None.
        """
        _, _, h, w = x.size()
        psp, psp_aux = self.pspnet(x)
        fcn, fcn_aux = self.fcn8s(x)
        x_out = self.conv(torch.cat([psp, fcn], dim=1))

        if not self.training:
            return x_out

        if y is None:
            raise ValueError("MainModel.forward requires the target mask `y` in training mode")

        # Deep supervision: classify the intermediate features and compare
        # them with the target at full resolution.
        aux = self.aux(torch.cat([psp_aux, fcn_aux], dim=1))
        aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
        aux_loss = self.criterion(aux, y)
        main_loss = self.criterion(x_out, y)
        return x_out.max(1)[1], main_loss, aux_loss


In [11]:
class Gleason(Dataset):
    """Image / mask dataset read from two directories.

    Every file name in ``imgdir`` is an image. Unless ``test`` is set, the
    matching mask is looked up in ``maskdir`` under the image name with
    ``.jpg`` replaced by ``_classimg_nonconvex.png``.

    Args:
        imgdir: directory containing the images.
        maskdir: directory containing the masks (unused in test mode).
        train, val, test: mode flags; only ``test`` changes behaviour
            (no masks are loaded and ``__getitem__`` returns the image alone).
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            (or ``transform(image=...)`` in test mode) that returns a mapping
            with ``"image"`` / ``"mask"`` entries.
        target_transform: optional callable applied to the mask afterwards.
    """

    def __init__(self, imgdir, maskdir=None, train=True, val=False,
                 test=False, transform=None, target_transform=None):
        super(Gleason, self).__init__()
        self.imgdir = imgdir
        self.maskdir = maskdir
        # os.listdir returns entries in arbitrary order; sort so that indices
        # refer to the same sample on every run and machine.
        self.imglist = sorted(os.listdir(imgdir))
        if not test:
            self.masklist = [item.replace('.jpg', '_classimg_nonconvex.png') for item in self.imglist]
        else:
            self.masklist = []

        self.train = train
        self.val = val
        self.test = test
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of images found in ``imgdir``."""
        return len(self.imglist)

    def __getitem__(self, idx):
        """Return ``image`` in test mode, otherwise ``(image, mask)``."""
        image = np.array(Image.open(os.path.join(self.imgdir, self.imglist[idx])))
        if self.test:
            # No mask is available; a missing transform returns the raw array.
            if self.transform:
                image = self.transform(image=image)["image"]
            return image
        mask = np.array(Image.open(os.path.join(self.maskdir, self.masklist[idx])))
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
        if self.target_transform:
            mask = self.target_transform(mask)
        return image, mask

In [12]:
def get_dataset(imgdir, maskdir=None, train=True, val=False, test=False,
                transform=None, target_transform=None):
    """Build a ``Gleason`` dataset; every argument is forwarded unchanged."""
    options = dict(
        imgdir=imgdir,
        maskdir=maskdir,
        train=train,
        val=val,
        test=test,
        transform=transform,
        target_transform=target_transform,
    )
    return Gleason(**options)


def get_transform(train):
    """Return the albumentations pipeline for training or evaluation.

    Both pipelines resize to 256x256, normalise with the ImageNet mean/std and
    convert to a torch tensor. The training pipeline additionally applies a
    random shift/scale/rotate followed by pad-if-needed and a 256x256 random
    crop.
    """
    steps = [A.Resize(width=256, height=256)]
    if train:
        # Geometric augmentation, used only for training.
        steps.extend([
            A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),
            A.PadIfNeeded(min_height=256, min_width=256),
            A.RandomCrop(256, 256),
        ])
    steps.append(
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0)
    )
    steps.append(ToTensorV2())  # numpy.ndarray -> torch.Tensor
    return A.Compose(steps)


In [13]:
class AverageMeter(object):
    """Tracks the latest value, weighted sum, weight count and running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        contribution = val * n
        self.val = val
        self.count += n
        self.sum += contribution
        self.avg = self.sum / self.count

In [14]:
#metrics
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    """Per-class intersection, union and target pixel counts.

    Args:
        output: predicted class ids, any shape, each value in ``0 .. K-1``.
        target: ground-truth class ids with the same shape as ``output``;
            positions equal to ``ignore_index`` are excluded from all counts.
        K: number of classes.
        ignore_index: label value to ignore.

    Returns:
        ``(area_intersection, area_union, area_target)``: three tensors of
        length ``K``.

    The inputs are not modified. Integer tensors are accepted; they are
    converted to float because ``torch.histc`` needs floating point data.
    """
    assert output.shape == target.shape
    # Work on a private flat copy so that masking ignored positions below does
    # not write into the caller's tensor.
    output = output.reshape(-1).clone()
    target = target.reshape(-1)
    if not output.is_floating_point():
        output = output.float()
    if not target.is_floating_point():
        target = target.float()
    # Ignored positions get the out-of-range ignore value in the prediction
    # as well, so histc (limited to [0, K-1]) drops them everywhere.
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1)
    area_output = torch.histc(output, bins=K, min=0, max=K-1)
    area_target = torch.histc(target, bins=K, min=0, max=K-1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target

In [18]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Build the fused PSPNet + FCN8s model and move it to the chosen device.
model = MainModel()
model = model.to(device)

In [19]:
# ---- Training configuration ----
batch_size = 10
n_eps = 100
base_lr = 1e-3

# CPU count of the runtime, printed for information only.
n_workers = os.cpu_count()
print("num_workers =", n_workers)
# Worker processes actually used by the DataLoader (capped by the CPU count).
num_workers = min(2, n_workers or 1)

# ---- Optimizer ----
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4)

# ---- Data ----
# The first epochs use the deterministic pipeline (train=False); the training
# loop switches to the augmented pipeline later.
train_dataset = get_dataset(imgdir='/content/drive/MyDrive/Dataset/Image_x4/Image',
                  maskdir='/content/drive/MyDrive/Dataset/Image_x4/Mask',
                  train=True, val=False, test=False, transform=get_transform(train=False))
# Load batches in background workers instead of the training process, and pin
# host memory when a GPU is used to speed up host-to-device copies.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                         num_workers=num_workers, pin_memory=torch.cuda.is_available())

# ---- Learning-rate schedule ----
# Polynomial decay, stepped once per iteration: lr = base_lr * (1 - it / max_iter) ** 0.9.
# max_iter is fixed here so the schedule does not depend on later rebinding of
# `trainloader`; the clamp keeps the factor real-valued if stepped past max_iter.
max_iter = n_eps * len(trainloader)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda it: max(0.0, 1 - it / max_iter) ** 0.9)

# ---- Meters ----
train_loss_meter = AverageMeter()
intersection_meter = AverageMeter()
union_meter = AverageMeter()
target_meter = AverageMeter()


In [None]:
start_time = time.time()
# training script
checkpoint_dir = '/content/drive/MyDrive/Model/Savemodel'
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)
# The augmented pipeline is built once, the first time an epoch > 15 starts.
augmentation_enabled = False

for ep in range(0, n_eps):
    train_loss_meter.reset()
    intersection_meter.reset()
    union_meter.reset()
    target_meter.reset()
    if ep > 15 and not augmentation_enabled:
        # Switch from the deterministic pipeline to the augmented one.
        train_dataset = get_dataset(imgdir='/content/drive/MyDrive/Dataset/Image_x4/Image',
                    maskdir='/content/drive/MyDrive/Dataset/Image_x4/Mask',
                    train=True, val=False, test=False, transform=get_transform(train=True))
        trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                 num_workers=num_workers)
        augmentation_enabled = True

    model.train()
    for batch_id, (x, y) in enumerate(tqdm(trainloader), start=1):
        # one optimisation step of the model per batch
        optimizer.zero_grad()
        n = x.shape[0]
        x = x.to(device).float()
        y = y.to(device).long()
        y_hat_mask, main_loss, aux_loss = model(x, y)
        # The auxiliary (deep supervision) loss is down-weighted by 0.4.
        loss = main_loss + aux_loss*0.4
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # save metrics
        with torch.no_grad():
            # Weight by batch size so a smaller last batch does not skew the mean.
            train_loss_meter.update(loss.item(), n)
            intersection, union, target = intersectionAndUnionGPU(y_hat_mask.float(), y.float(), 6)
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)
    # compute accuracy, IoU and Dice over the whole epoch
    with torch.no_grad():
        # Use the accumulated sums (not .val, which holds only the last batch).
        accuracy = intersection_meter.sum.sum() / (target_meter.sum.sum() + 1e-10)
        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)  # one value per class
        # Dice = 2I / (|A| + |B|), and |A| + |B| = I + U.
        dice_class = (2 * intersection_meter.sum) / (intersection_meter.sum + union_meter.sum + 1e-10)  # one value per class
        mIoU = torch.mean(iou_class)
        mDice = torch.mean(dice_class)

    print(f"EP {ep}, train loss = {train_loss_meter.avg}, mIoU = {round(mIoU.item(), 4)}, mDice = {round(mDice.item(), 4)}, Accurancy = {round(accuracy.item(), 4)}")
    # Checkpoint every 10 epochs and always after the final epoch, so the
    # fully trained weights are not lost.
    if ep % 10 == 0 or ep == n_eps - 1:
        torch.save(model.state_dict(), f'{checkpoint_dir}/Model_ep{ep}.pth')

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))

 90%|█████████ | 687/762 [3:48:42<24:48, 19.85s/it]