In [1]:
# Capture the output lines of the `nvidia-smi` shell command (IPython `!` syntax).
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# The word 'failed' in the command output is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Tue Aug 29 08:46:21 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    24W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# later in this notebook live under that mount point (Colab-only import).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2 # np.array -> torch.tensor
import os
import os.path as osp
from PIL import Image
from torchvision import transforms as T
from tqdm import tqdm
from glob import glob
import datetime
import time

## <center>Multi-scale Parallel Branch PSPNet and Fully Convolutional DenseNets</center>

In [4]:
class Layer(nn.Module):
    """Pre-activation unit: BatchNorm -> ReLU -> Conv2d -> Dropout.

    Args:
        in_ch: number of channels of the incoming feature map.
        kernel_s: convolution kernel size.
        padding: padding argument forwarded to ``nn.Conv2d``.
        dilation: convolution dilation.
        stride: convolution stride.
        grow_rate: number of feature maps produced by the convolution.
        dropRate: dropout probability applied to the convolution output.
    """

    def __init__(self, in_ch, kernel_s=3, padding='same', dilation=1, stride=1, grow_rate=16, dropRate=0.2):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(in_ch)
        self.conv = nn.Conv2d(
            in_ch,
            grow_rate,
            kernel_size=kernel_s,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.droprate = nn.Dropout(p=dropRate)

    def forward(self, x):
        """Return the ``grow_rate``-channel feature map computed from ``x``."""
        normalized = self.bn(x)
        # The in-place ReLU overwrites the BatchNorm output, never ``x`` itself.
        activated = self.relu(normalized)
        features = self.conv(activated)
        return self.droprate(features)

In [5]:
class DenseBlock(nn.Module):
    """Stack of densely connected ``Layer`` modules.

    Layer ``i`` is built for ``in_ch + i * grow_rate`` input channels: it sees
    the block input concatenated with the outputs of all earlier layers.

    Args:
        in_ch: number of channels entering the block.
        kernel_s, padding, dilation, stride, grow_rate: forwarded to every
            ``Layer``.
        n_layers: number of layers in the block.
        Upsample: if True, ``forward`` returns only the feature maps created
            inside the block; otherwise it returns the block input
            concatenated with them.
    """

    def __init__(self, in_ch, kernel_s=3, padding='same', dilation=1, stride=1, grow_rate=16, n_layers=4, Upsample=False):
        super().__init__()
        self.upsample = Upsample
        stack = []
        for idx in range(n_layers):
            stack.append(Layer(in_ch + idx * grow_rate, kernel_s, padding, dilation, stride, grow_rate))
        self.layers = nn.ModuleList(stack)
        self.n_layers = n_layers

    def forward(self, x):
        """Run the dense stack on ``x`` (channels are concatenated on dim 1)."""
        if self.upsample:
            # Decoder-style block: keep only what the layers produced.
            fresh = []
            running = x
            for unit in self.layers:
                produced = unit(running)
                fresh.append(produced)
                running = torch.cat([running, produced], 1)
            return torch.cat(fresh, 1)

        # Encoder-style block: input followed by every layer output, in order.
        dense = x
        for idx in range(self.n_layers):
            dense = torch.cat([dense, self.layers[idx](dense)], 1)
        return dense


In [6]:
class TransitionDown(nn.Module):
    """Channel-preserving down-sampling step.

    Applies BatchNorm -> ReLU -> 1x1 convolution -> Dropout and then halves
    the spatial resolution with a 2x2 max pool of stride 2.

    Args:
        in_ch: number of input (and output) channels.
        dropRate: dropout probability applied before pooling.
    """

    def __init__(self, in_ch, dropRate=0.15):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding='same', bias=False)
        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        self.droprate = nn.Dropout(p=dropRate)

    def forward(self, x):
        """Return the pooled feature map for ``x``."""
        activated = self.relu(self.bn(x))
        projected = self.droprate(self.conv(activated))
        return self.pooling(projected)

In [7]:
def center_crop(layer, max_height, max_width):
    """Cut a centred ``max_height`` x ``max_width`` window out of a 4-D tensor.

    Args:
        layer: tensor whose last two of four dimensions are height and width.
        max_height: height of the returned window.
        max_width: width of the returned window.

    Returns:
        A view of ``layer`` restricted to the centred window; when the margin
        is odd the extra row/column is dropped at the bottom/right.
    """
    height, width = layer.size()[2:]
    top = (height - max_height) // 2
    left = (width - max_width) // 2
    rows = slice(top, top + max_height)
    cols = slice(left, left + max_width)
    return layer[:, :, rows, cols]

In [8]:
class TransitionUp(nn.Module):
    """Up-sampling step that merges a decoder tensor with a skip connection.

    Args:
        in_ch: channels of the tensor to up-sample.
        out_ch: channels produced by the transposed convolution.
        kernel_s: kernel size of the stride-2 transposed convolution.
    """

    def __init__(self, in_ch, out_ch, kernel_s=3):
        super().__init__()
        self.transpose = nn.ConvTranspose2d(
            in_ch, out_ch, kernel_size=kernel_s, stride=2, padding=0, bias=True
        )

    def forward(self, x, skip_connection):
        """Up-sample ``x``, crop it to the skip's spatial size and concatenate.

        The up-sampled tensor comes first on the channel axis, followed by
        ``skip_connection``.
        """
        target_h = skip_connection.size(2)
        target_w = skip_connection.size(3)
        upsampled = center_crop(self.transpose(x), target_h, target_w)
        return torch.cat([upsampled, skip_connection], 1)

In [9]:
class MPB_FCDenseNet(nn.Module):
    """FC-DenseNet style encoder/decoder used as a feature-extraction branch.

    Structure: a 3x3 input convolution, a down path of ``DenseBlock`` +
    ``TransitionDown`` pairs, a bottleneck ``DenseBlock`` and an up path of
    ``TransitionUp`` + ``DenseBlock`` pairs fed by the encoder skip tensors.

    ``forward`` returns decoder *features*, not class scores. ``criterion``,
    ``aux`` and ``finalConv`` are still constructed (they are part of the
    module's ``state_dict``) but ``forward`` does not apply them.

    Args:
        in_ch: channels of the input image.
        down_blocks: number of layers in each encoder dense block.
        up_blocks: number of layers in each decoder dense block. Must have the
            same length as ``down_blocks`` because every decoder stage
            consumes exactly one skip connection.
        bottleneck_layers: number of layers in the bottleneck dense block.
        grow_rate: feature maps added by every dense layer.
        kernel_s: sequence of kernel sizes; only ``kernel_s[0]`` is used.
        padding: padding argument of the dense-block convolutions.
        dilation: accepted for interface compatibility but not used; the
            per-block dilations are fixed inside ``__init__``.
        stride: stride of the encoder dense-block convolutions.
        m: output channels of the first convolution.
        n_classes: output channels of the ``aux`` and ``finalConv`` heads.

    Raises:
        ValueError: if ``down_blocks`` and ``up_blocks`` differ in length.
    """

    def __init__(self, in_ch=3, down_blocks=(4, 5, 6, 7, 8),
                 up_blocks=(8 , 7 , 6, 5 ,4), bottleneck_layers=10,
                 grow_rate=16, kernel_s=(3, 5, 9), padding='same', dilation=1, stride=1, m=48, n_classes=6):
        super(MPB_FCDenseNet, self).__init__()
        if len(down_blocks) != len(up_blocks):
            raise ValueError(
                "down_blocks and up_blocks must have the same length, got "
                f"{len(down_blocks)} and {len(up_blocks)}"
            )
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        self.criterion = nn.CrossEntropyLoss(ignore_index=255)

        #   First convolution   #
        self.first_conv = nn.Conv2d(in_channels=in_ch, out_channels=m, kernel_size=3, stride=1, padding=1, bias=True)

        current_ch = m
        skip_ch = []    # skip channel counts, deepest stage first
        aux_ch = 288    # width of the default configuration; refined below

        #   Downsampling path   #
        self.DB_down_1 = nn.ModuleList([])
        self.TD_1 = nn.ModuleList([])
        for i in range(len(down_blocks)):
            # Dilation 1 for the first block, then alternating 2 / 4.
            if i == 0:
                dilation_d = 1
            elif i % 2 == 0:
                dilation_d = 4
            else:
                dilation_d = 2
            self.DB_down_1.append(DenseBlock(current_ch, kernel_s[0], padding, dilation_d, stride, grow_rate, down_blocks[i], False))
            current_ch += down_blocks[i] * grow_rate
            skip_ch.insert(0, current_ch)
            self.TD_1.append(TransitionDown(current_ch))
            if i == 2:
                # forward() taps the auxiliary feature map after this stage.
                aux_ch = current_ch

        #   Bottleneck   #
        # Upsample=True: the block emits only its newly created feature maps.
        self.bottleneck = DenseBlock(current_ch, kernel_s[0], padding, 2, 1, grow_rate, bottleneck_layers, True)
        prev_ch = bottleneck_layers * grow_rate

        #   Upsampling path   #
        self.DB_up_1 = nn.ModuleList([])
        self.TU_1 = nn.ModuleList([])
        for i in range(len(up_blocks) - 1):
            self.TU_1.append(TransitionUp(prev_ch, prev_ch, 3))
            current_ch = prev_ch + skip_ch[i]
            self.DB_up_1.append(DenseBlock(current_ch, kernel_s[0], padding, 2, 1, grow_rate, up_blocks[i], True))
            prev_ch = grow_rate * up_blocks[i]

        #   Final dense block (keeps its input channels as well)   #
        self.TU_1.append(TransitionUp(prev_ch, prev_ch, 3))
        current_ch = prev_ch + skip_ch[-1]
        self.DB_up_1.append(DenseBlock(current_ch, kernel_s[0], padding, 2, 1, grow_rate, up_blocks[-1], False))
        current_ch += grow_rate * up_blocks[-1]

        #   Auxiliary head   #
        # Built unconditionally: ``self.training`` is always True inside
        # ``__init__``, so guarding on it never skipped this module.
        self.aux = nn.Sequential(
            nn.Conv2d(aux_ch, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.2),
            nn.Conv2d(256, n_classes, kernel_size=1)
        )
        self.finalConv = nn.Conv2d(in_channels=current_ch, out_channels=n_classes,
                                   kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x, y=None):
        """Run the encoder/decoder on a 4-D batch.

        Args:
            x: input batch with four dimensions.
            y: unused; kept so the call signature matches the other models.

        Returns:
            ``(features, aux_in)`` where ``features`` is the output of the last
            decoder dense block and ``aux_in`` is the encoder tensor after the
            third transition-down (``None`` when there are fewer than three
            down blocks). With the default arguments these carry 256 and 288
            channels respectively.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D input batch, got {x.dim()} dimensions")

        x_out_1 = self.first_conv(x)
        skip_connections = []
        aux_in = None

        for i in range(len(self.down_blocks)):
            x_out_1 = self.DB_down_1[i](x_out_1)
            skip_connections.append(x_out_1)
            x_out_1 = self.TD_1[i](x_out_1)
            if i == 2:
                aux_in = x_out_1

        x_out_1 = self.bottleneck(x_out_1)

        for i in range(len(self.up_blocks)):
            # Skips are consumed deepest-first, mirroring the encoder.
            skip = skip_connections.pop()
            x_out_1 = self.TU_1[i](x_out_1, skip)
            x_out_1 = self.DB_up_1[i](x_out_1)

        return x_out_1, aux_in


In [10]:
def conv3x3(in_ch, out_ch, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of one pixel."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_ch, out_ch, **conv_options)


class BasicBlock(nn.Module):
    """Two-convolution residual block (output channels = ``out_ch``).

    Args:
        in_ch: input channels.
        out_ch: channels of both 3x3 convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, in_ch, out_ch, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch, stride)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block; output has ``out_ch * 4`` channels.

    Args:
        in_ch: input channels.
        out_ch: channels of the inner 1x1 and 3x3 convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to build the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    expansion = 4

    def __init__(self, in_ch, out_ch, stride=1, downsample=None):
        super().__init__()
        widened = out_ch * self.expansion
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.conv3 = nn.Conv2d(out_ch, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet classifier with either a single 7x7 stem or a three-conv stem.

    Args:
        block: residual block class exposing an ``expansion`` attribute and
            the constructor signature ``(in_ch, out_ch, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        num_classes: size of the final linear layer.
        deep_base: if True the stem is three 3x3 convolutions (64, 64, 128
            channels); otherwise a single 7x7 convolution with 64 channels.
    """

    def __init__(self, block, layers, num_classes=1000, deep_base=True):
        super().__init__()
        self.deep_base = deep_base
        if self.deep_base:
            self.in_ch = 128
            self.conv1 = conv3x3(3, 64, stride=2)
            self.bn1 = nn.BatchNorm2d(64)
            self.conv2 = conv3x3(64, 64)
            self.bn2 = nn.BatchNorm2d(64)
            self.conv3 = conv3x3(64, 128)
            self.bn3 = nn.BatchNorm2d(128)
        else:
            self.in_ch = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; every stage after the first halves the resolution.
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for number, (width, stage_stride) in enumerate(stage_plan, start=1):
            stage = self._make_layer(block, width, layers[number - 1], stride=stage_stride)
            setattr(self, f'layer{number}', stage)

        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming init for convolutions, identity affine for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, out_ch, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``in_ch``.

        A 1x1 convolution + BatchNorm shortcut is attached to the first block
        whenever the stride or the channel count changes.
        """
        expanded = out_ch * block.expansion
        downsample = None
        if stride != 1 or self.in_ch != expanded:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_ch, expanded, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )

        stage = [block(self.in_ch, out_ch, stride, downsample)]
        self.in_ch = expanded
        for _ in range(1, blocks):
            stage.append(block(self.in_ch, out_ch))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the ``num_classes`` scores for an image batch ``x``."""
        x = self.relu(self.bn1(self.conv1(x)))
        if self.deep_base:
            x = self.relu(self.bn2(self.conv2(x)))
            x = self.relu(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


def resnet152(pretrained=True, model_path='/content/drive/MyDrive/Colab/resnet152_v2.pth', **kwargs):
    """Construct a ResNet-152 model, optionally loading weights from disk.

    Args:
        pretrained (bool): if True, load the state dict stored at
            ``model_path`` into the freshly built model.
        model_path (str): location of the checkpoint file; defaults to the
            path this notebook has always used.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``, ``deep_base``).

    Returns:
        The ``ResNet`` instance (on CPU; move it with ``.to(device)``).
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints that come from a trusted source.
        # map_location='cpu' lets a checkpoint saved on a GPU load on a
        # CPU-only runtime; the caller moves the model to its device later.
        state_dict = torch.load(model_path, map_location='cpu')
        result = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates mismatches; surface them instead of hiding them.
        if result.missing_keys or result.unexpected_keys:
            import warnings
            warnings.warn(
                "resnet152 checkpoint did not match the model exactly: "
                f"{len(result.missing_keys)} missing key(s), "
                f"{len(result.unexpected_keys)} unexpected key(s)"
            )
    return model

In [11]:
class PPM(nn.Module):
    """Pyramid pooling module.

    For every entry of ``bins`` the input is average-pooled to that output
    size, projected to ``out_ch`` channels (1x1 conv + BatchNorm + ReLU) and
    bilinearly resized back to the input resolution.

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels produced by each pooling branch.
        bins: iterable of adaptive-pooling output sizes, one per branch.
    """

    def __init__(self, in_ch, out_ch, bins):
        super().__init__()
        branches = [
            nn.Sequential(
                nn.AdaptiveAvgPool2d(bin_size),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )
            for bin_size in bins
        ]
        self.features = nn.ModuleList(branches)

    def forward(self, x):
        """Return ``x`` concatenated (dim 1) with every resized branch output."""
        spatial = x.size()[2:]
        pyramid = [
            F.interpolate(branch(x), spatial, mode='bilinear', align_corners=True)
            for branch in self.features
        ]
        return torch.cat([x, *pyramid], 1)

class PSPNet(nn.Module):
    """Dilated ResNet-152 backbone with an optional pyramid pooling module.

    This variant is a feature extractor: ``forward`` returns a 512-channel
    feature map (the output of ``cls``) together with the ``layer3``
    activations; it produces no class scores and computes no loss.

    Args:
        bins: pooling sizes of the pyramid pooling module.
        dropout: ``Dropout2d`` probability inside ``cls``.
        classes: only validated (must be > 1); no layer depends on it.
        zoom_factor: output up-sampling factor, one of 1, 2, 4, 8.
        use_ppm: whether to insert the pyramid pooling module.
        criterion: loss module stored on the instance (not used by ``forward``).
        pretrained: forwarded to ``resnet152`` to decide whether backbone
            weights are loaded from disk.
    """

    def __init__(self, bins=(1, 2, 3, 6), dropout=0.15, classes=6, zoom_factor=8, use_ppm=True, criterion=nn.CrossEntropyLoss(ignore_index=255), pretrained=True):
        super(PSPNet, self).__init__()
        assert 2048 % len(bins) == 0
        assert classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        self.zoom_factor = zoom_factor
        self.use_ppm = use_ppm
        self.criterion = criterion

        # Honour the caller's choice instead of always loading the checkpoint.
        resnet = resnet152(pretrained=pretrained)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu,
                                    resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # Replace the stride of the last two stages by dilation (2, then 4)
        # so they keep the spatial resolution of layer2.
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        fea_dim = 2048
        if use_ppm:
            # Each branch gets an equal share so the PPM doubles the channels.
            self.ppm = PPM(fea_dim, fea_dim // len(bins), bins)
            fea_dim *= 2

        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding='same', dilation=2, stride=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
        )

    def forward(self, x, y=None):
        """Extract features from an image batch.

        Args:
            x: 4-D batch whose height and width both satisfy
                ``(size - 1) % 8 == 0`` (enforced by an assertion).
            y: unused; kept so the call signature matches the other models.

        Returns:
            ``(features, x_tmp)``: the ``cls`` output, bilinearly resized to
            ``((H - 1) / 8 * zoom_factor + 1, (W - 1) / 8 * zoom_factor + 1)``
            unless ``zoom_factor`` is 1, and the raw ``layer3`` output.
        """
        x_size = x.size()
        assert (x_size[2]-1) % 8 == 0 and (x_size[3]-1) % 8 == 0
        # Exact integer arithmetic; divisibility is guaranteed by the assert.
        h = (x_size[2] - 1) // 8 * self.zoom_factor + 1
        w = (x_size[3] - 1) // 8 * self.zoom_factor + 1

        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)
        if self.use_ppm:
            x = self.ppm(x)
        x = self.cls(x)
        if self.zoom_factor != 1:
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
        return x, x_tmp

In [12]:
class MyModel(nn.Module):
    """Two-branch segmentation model: PSPNet features + FC-DenseNet features.

    The input is fed to ``PSPNet`` as is and, cropped by one pixel in height
    and width, to ``MPB_FCDenseNet``. Both feature maps are concatenated,
    resized to the input resolution and classified by a 1x1 convolution.
    An auxiliary head on the concatenated intermediate features supplies a
    second loss during training.

    Args:
        n_classes: number of output classes of the main and auxiliary heads.
    """

    def __init__(self, n_classes=6):
        super(MyModel, self).__init__()
        self.criterion=nn.CrossEntropyLoss(ignore_index=255)
        self.pspnet = PSPNet()
        self.MPB = MPB_FCDenseNet()
        # Built unconditionally: ``self.training`` is always True inside
        # ``__init__``, so guarding on it never skipped this module.
        # 1312 input channels = 1024 (ResNet layer3) + 288 (dense-net tap)
        # with the default branch configurations.
        self.aux = nn.Sequential(
            nn.Conv2d(1312, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.15),
            nn.Conv2d(256, n_classes, kernel_size=1)
        )
        # 768 input channels = 512 (PSPNet ``cls``) + 256 (dense-net decoder)
        # with the default branch configurations.
        self.finalConv = nn.Conv2d(in_channels=768, out_channels=n_classes,
                                    kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x, y=None):
        """Segment an image batch.

        Args:
            x: 4-D image batch; height and width must satisfy the PSPNet
                constraint ``(size - 1) % 8 == 0``.
            y: target class-index map; required in training mode.

        Returns:
            Training mode: ``(pred_mask, main_loss, aux_loss)`` where
            ``pred_mask`` holds the per-pixel arg-max class.
            Eval mode: the raw class scores from ``finalConv``.

        Raises:
            ValueError: if the module is in training mode and ``y`` is None.
        """
        if self.training and y is None:
            raise ValueError("MyModel needs the target mask `y` in training mode")

        _, _, h, w = x.size()
        # The dense-net branch works on the input cropped by one pixel per axis.
        x_fcdense = center_crop(x, h - 1, w - 1)
        pspnet, pspnet_aux = self.pspnet(x)
        fcdense, fcdense_aux = self.MPB(x_fcdense)

        # Crop the PSPNet tensors by one pixel so both branches line up.
        pspnet_aux = center_crop(pspnet_aux, pspnet_aux.size(2)-1, pspnet_aux.size(3)-1)
        pspnet = center_crop(pspnet, pspnet.size(2)-1, pspnet.size(3)-1)

        x_out = torch.cat([pspnet, fcdense], dim=1)
        aux_in = torch.cat([pspnet_aux, fcdense_aux], dim=1)

        x_out = F.interpolate(x_out, size=(h, w), mode='bilinear', align_corners=True)
        x_out = self.finalConv(x_out)
        if self.training:
            aux = self.aux(aux_in)
            aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
            aux_loss = self.criterion(aux, y)
            main_loss = self.criterion(x_out, y)
            return x_out.max(1)[1], main_loss, aux_loss
        else:
            return x_out

In [13]:
class Gleason(Dataset):
    """Image / mask dataset read from two flat directories.

    Every file in ``imgdir`` is one sample. Unless ``test`` is set, the mask
    file name is derived from the image name by replacing ``'.jpg'`` with
    ``'_classimg_nonconvex.png'`` and is looked up in ``maskdir``.

    Args:
        imgdir: directory holding the images.
        maskdir: directory holding the masks; required unless ``test`` is set.
        train, val, test: mode flags stored on the instance; only ``test``
            changes behaviour (items are returned without a mask).
        transform: optional callable invoked with keyword arguments
            ``image=`` (and ``mask=`` outside test mode) that returns a
            mapping with the same keys.
        target_transform: optional callable applied to the mask afterwards.

    Raises:
        ValueError: if ``maskdir`` is missing while ``test`` is false.
    """

    def __init__(self, imgdir, maskdir=None, train=True, val=False,
                 test=False, transform=None, target_transform=None):
        super(Gleason, self).__init__()
        if not test and maskdir is None:
            # Fail at construction instead of on the first item access.
            raise ValueError("maskdir is required unless test=True")
        self.imgdir = imgdir
        self.maskdir = maskdir
        self.imglist = sorted(os.listdir(imgdir))

        if not test:
            self.masklist = [item.replace('.jpg', '_classimg_nonconvex.png') for item in self.imglist]
        else:
            self.masklist = []

        self.train = train
        self.val = val
        self.test = test
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of image files found in ``imgdir``."""
        return len(self.imglist)

    @staticmethod
    def _load(path):
        """Read an image file into a numpy array and close the file handle."""
        with Image.open(path) as img:
            return np.array(img)

    def __getitem__(self, idx):
        """Return ``image`` in test mode, otherwise ``(image, mask)``."""
        image = self._load(osp.join(self.imgdir, self.imglist[idx]))
        if self.test:
            # Test mode has no mask; the transform stays optional here too.
            if self.transform:
                image = self.transform(image=image)["image"]
            return image

        mask = self._load(osp.join(self.maskdir, self.masklist[idx]))
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
        if self.target_transform:
            mask = self.target_transform(mask)
        return image, mask

In [14]:
def get_dataset(imgdir, maskdir=None, train=True, val=False, test=False,
                transform=None, target_transform=None):
    """Create a ``Gleason`` dataset; every argument is passed through as is."""
    options = dict(
        imgdir=imgdir,
        maskdir=maskdir,
        train=train,
        val=val,
        test=test,
        transform=transform,
        target_transform=target_transform,
    )
    return Gleason(**options)


def get_transform(train):
    """Build the albumentations pipeline for training or evaluation.

    Both pipelines resize to 257x257, normalize with the given mean/std and
    end with ``ToTensorV2``. The training pipeline additionally applies a
    horizontal flip and a brightness/contrast jitter right after the resize.

    Args:
        train: truthy to include the augmentation steps.
    """
    steps = [A.Resize(width=257, height=257)]
    if train:
        steps.extend([A.HorizontalFlip(), A.RandomBrightnessContrast()])
    steps.append(
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0)
    )
    steps.append(ToTensorV2())  # numpy array -> torch tensor
    return A.Compose(steps)


In [15]:
class UnNormalize(object):
    """Undo a per-channel ``(x - mean) / std`` normalization in place."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): normalized image whose first dimension indexes
                the channels; it is modified in place.
        Returns:
            Tensor: the same tensor object, now holding ``x * std + mean``
            for every channel.
        """
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            # Inverse of the normalization step ``t.sub_(m).div_(s)``.
            channel.mul_(channel_std)
            channel.add_(channel_mean)
        return tensor

# Inverse of the mean/std normalization used by the data transforms.
unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

class AverageMeter(object):
    """Running accumulator exposing the last value, sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [16]:
#metrics
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    """Per-class intersection, union and target pixel counts.

    Args:
        output: predicted class indices, any shape, values in ``0 .. K-1``.
        target: ground-truth class indices with the same shape as ``output``;
            positions equal to ``ignore_index`` are excluded from all counts.
        K: number of classes.
        ignore_index: label value marking pixels to ignore.

    Returns:
        ``(area_intersection, area_union, area_target)``: three float tensors
        of length ``K``.

    The inputs are left untouched: the function works on a private copy of
    ``output``, accepts non-contiguous tensors and converts to float so the
    histogram is computed on a floating-point dtype whatever dtype comes in.
    """
    assert output.shape == target.shape
    # clone(): the ignore mask below must not leak into the caller's tensor.
    output = output.reshape(-1).clone().float()
    target = target.reshape(-1).float()
    # Ignored pixels get the ignore label on both sides; it lies outside
    # [0, K-1], so torch.histc drops those positions from every histogram.
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1)
    area_output = torch.histc(output, bins=K, min=0, max=K-1)
    area_target = torch.histc(target, bins=K, min=0, max=K-1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target

In [17]:
# device: first CUDA GPU when available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load data
batch_size = 5
n_workers = os.cpu_count()  # one DataLoader worker per CPU core reported by the OS
print("num_workers =", n_workers)
# NOTE: the initial training set uses the non-augmenting transform
# (get_transform(train=False)); the training cell switches to the augmented
# pipeline at epoch 7.
train_dataset = get_dataset(imgdir='/content/drive/MyDrive/MyProject/Train_cropped_8',
                        maskdir='/content/drive/MyDrive/MyProject/Mask_cropped_8',
                        train=True, val=False, test=False, transform=get_transform(train=False))

trainloader = DataLoader(train_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=n_workers)

# model
model = MyModel().to(device)

# loss
# NOTE(review): this module-level criterion is not referenced by the training
# cell, which uses the losses returned by the model itself.
criterion = nn.CrossEntropyLoss(ignore_index=255)

# optimizer
base_lr = 1e-3
n_eps = 21  # the training cell iterates over range(1, n_eps)
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4)
# Polynomial ("poly") decay with power 0.9 over len(trainloader) * n_eps
# scheduler steps; the scheduler is stepped once per batch.
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / (len(trainloader) * n_eps)) ** 0.9)

# meters: running loss and per-class intersection / union / target counts
train_loss_meter = AverageMeter()
intersection_meter = AverageMeter()
union_meter = AverageMeter()
target_meter = AverageMeter()

num_workers = 8


In [18]:
# Batch cap used by the training loop: half the number of entries in the
# training image directory.
max_batch = len(os.listdir('/content/drive/MyDrive/MyProject/Train_cropped_8'))//2
max_batch

1289

In [None]:
start_time = time.time()
# training script
# range(1, n_eps) runs epochs 1 .. n_eps - 1.
for ep in range(1, n_eps):
    # Meters are per-epoch: clear them before every pass over the data.
    train_loss_meter.reset()
    intersection_meter.reset()
    union_meter.reset()
    target_meter.reset()
    model.train()
    # NOTE(review): max_iter is assigned but not used anywhere in this cell.
    max_iter = n_eps * len(trainloader)
    # From epoch 7 on, rebuild the loader with the augmenting transform.
    if ep == 7:
      train_dataset_agu = get_dataset(imgdir='/content/drive/MyDrive/MyProject/Train_cropped_8',
                        maskdir='/content/drive/MyDrive/MyProject/Mask_cropped_8',
                        train=True, val=False, test=False, transform=get_transform(train=True))
      trainloader = DataLoader(train_dataset_agu, batch_size=batch_size, shuffle=True, num_workers=n_workers)

    for batch_id, (x, y) in enumerate(tqdm(trainloader), start=1):
        # Stop the epoch early once the batch cap is reached.
        if batch_id == max_batch - 1:
            break
        # model training step for one batch
        optimizer.zero_grad()
        n = x.shape[0]
        x = x.to(device).float()
        y = y.to(device).long()
        # In training mode the model returns the predicted mask and both losses.
        y_hat_mask, main_loss, aux_loss = model(x, y)
        # The auxiliary loss is weighted by 0.4.
        loss = main_loss + aux_loss*0.4
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # per-batch learning-rate update
        # save metrics
        with torch.no_grad():
            train_loss_meter.update(loss.item())
            intersection, union, target = intersectionAndUnionGPU(y_hat_mask.float(), y.float(), 6)
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)
    # compute iou, dice from the counts accumulated over the epoch
    with torch.no_grad():
        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) #vector 6D
        dice_class = (2 * intersection_meter.sum) / (intersection_meter.sum + union_meter.sum + 1e-10) #vector 6D
        mIoU = torch.round(torch.mean(iou_class), decimals=3) #mean vector 6D
        mDice = torch.round(torch.mean(dice_class), decimals=3) #mean vector 6D

    print(f"\nEP {ep}, train loss = {train_loss_meter.avg}, mIoU = {mIoU}, mDice = {mDice}")

# Report the wall-clock training time as H:MM:SS.
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))

100%|██████████| 645/645 [06:04<00:00,  1.77it/s]



EP 1, train loss = 1.389914178386215, mIoU = 0.2549999952316284, mDice = 0.3580000102519989


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 2, train loss = 1.1560389695703521, mIoU = 0.3240000009536743, mDice = 0.43299999833106995


100%|██████████| 645/645 [05:00<00:00,  2.15it/s]



EP 3, train loss = 1.0329085405482803, mIoU = 0.37599998712539673, mDice = 0.4950000047683716


100%|██████████| 645/645 [05:00<00:00,  2.15it/s]



EP 4, train loss = 0.9211631573909937, mIoU = 0.45399999618530273, mDice = 0.5789999961853027


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 5, train loss = 0.825453359928242, mIoU = 0.5, mDice = 0.6200000047683716


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 6, train loss = 0.7332797919363938, mIoU = 0.5569999814033508, mDice = 0.6660000085830688


100%|██████████| 645/645 [05:00<00:00,  2.15it/s]



EP 7, train loss = 0.9408440104750699, mIoU = 0.45500001311302185, mDice = 0.5830000042915344


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 8, train loss = 0.8466503408297088, mIoU = 0.5, mDice = 0.621999979019165


100%|██████████| 645/645 [05:00<00:00,  2.15it/s]



EP 9, train loss = 0.7594469275354415, mIoU = 0.5339999794960022, mDice = 0.6480000019073486


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 10, train loss = 0.7086426886477212, mIoU = 0.5580000281333923, mDice = 0.6660000085830688


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 11, train loss = 0.6696290494397629, mIoU = 0.5699999928474426, mDice = 0.675000011920929


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 12, train loss = 0.6122102198674697, mIoU = 0.6140000224113464, mDice = 0.7059999704360962


100%|██████████| 645/645 [05:00<00:00,  2.15it/s]



EP 13, train loss = 0.5437817117271497, mIoU = 0.6349999904632568, mDice = 0.7200000286102295


100%|██████████| 645/645 [05:00<00:00,  2.15it/s]



EP 14, train loss = 0.5013113162314244, mIoU = 0.6669999957084656, mDice = 0.7409999966621399


100%|██████████| 645/645 [05:00<00:00,  2.15it/s]



EP 15, train loss = 0.4722265360660331, mIoU = 0.6669999957084656, mDice = 0.7400000095367432


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 16, train loss = 0.44680915851463643, mIoU = 0.6840000152587891, mDice = 0.7509999871253967


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 17, train loss = 0.418000176756881, mIoU = 0.6869999766349792, mDice = 0.753000020980835


100%|██████████| 645/645 [04:59<00:00,  2.15it/s]



EP 18, train loss = 0.4194678149251051, mIoU = 0.6959999799728394, mDice = 0.7580000162124634


 11%|█         | 68/645 [00:32<04:27,  2.16it/s]

In [None]:
# Validation data: separate image/mask directories, non-augmenting transform,
# fixed sample order (shuffle=False).
Val_train = get_dataset(imgdir='/content/drive/MyDrive/MyProject/Train_cropped_8_1',
                        maskdir='/content/drive/MyDrive/MyProject/Mask_cropped_8_1',
                        train=False,
                        val=True,
                        test=False,
                        transform=get_transform(train=False))
Valloader = torch.utils.data.DataLoader(Val_train, batch_size=batch_size,
                                          shuffle=False, num_workers=n_workers)

In [None]:
# Evaluate on the validation loader: in eval mode the model returns raw class
# scores, which are reduced to a predicted mask with argmax over dim 1.
model.eval()
test_intersection_meter = AverageMeter()
test_union_meter = AverageMeter()
test_target_meter = AverageMeter()
with torch.no_grad():
    for batch_id, (x, y) in enumerate(tqdm(Valloader), start=1):
        n = x.shape[0]
        x = x.to(device).float()
        y = y.to(device).long()
        y_hat = model(x)
        # NOTE(review): squeeze(1) only removes dim 1 when it has size 1;
        # presumably a no-op for multi-class scores - confirm.
        y_hat = y_hat.squeeze(1)
        y_hat_mask = y_hat.argmax(dim=1)

        intersection, union, target = intersectionAndUnionGPU(y_hat_mask, y, 6)
        test_intersection_meter.update(intersection)
        test_union_meter.update(union)
        test_target_meter.update(target)

    # Per-class IoU and Dice from the counts summed over all batches.
    iou_class = test_intersection_meter.sum / (test_union_meter.sum + 1e-10)
    dice_class = 2*test_intersection_meter.sum / (test_intersection_meter.sum + test_union_meter.sum + 1e-10)

    mIoU = torch.mean(iou_class)
    mDice = torch.mean(dice_class)

print("TEST: IoU = {}, dice = {}".format(mIoU, mDice))

In [None]:
from torchvision import models
from torchsummary import summary


# Print a layer-by-layer summary for a single 3 x 257 x 257 input.
summary(model, (3, 257, 257))