<a href="https://colab.research.google.com/github/leomensah/lung-cancer-classification/blob/main/TRANSFER_LEARNING_WITH_RESNET_ENCODER_FOR_LUNG_CANCER_CLASSIFICATION.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import tifffile as tiff
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
import torch.utils.model_zoo as model_zoo
import torchvision.transforms.functional as tf
from tqdm.notebook import tqdm


# Download locations of the torchvision ImageNet checkpoints used by the
# resnetNN() constructors, keyed by architecture name.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 (size-preserving at stride 1)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d``, e.g. for channel projection."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34).

    conv3x3 -> BN -> ReLU -> conv3x3 -> BN, summed with the input (passed
    through ``downsample`` when given) and followed by a final ReLU.
    ``expansion`` is the ratio of output channels to ``planes`` (1 here).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Attribute names are the checkpoint keys and creation order drives
        # random initialisation, so neither is changed.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (ResNet-50/101/152).

    The 3x3 convolution carries the stride; the last 1x1 expands the channels
    to ``planes * expansion`` (4x). The input, passed through ``downsample``
    when given, is added before the final ReLU.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Attribute names are the checkpoint keys and creation order drives
        # random initialisation, so neither is changed.
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)

In [None]:
class ResNetEncoder(nn.Module):
    """ResNet backbone without the pooling / fully-connected head.

    The stem (conv1 -> bn1 -> relu) is exposed as ``layer0`` and the max-pool
    is folded into ``layer1``, so each child named ``layer*`` is one stage
    whose output can be captured with a forward hook.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``) with an
            ``expansion`` attribute.
        layers: number of blocks in each of the four residual stages.
        num_classes: accepted for signature compatibility; unused because the
            encoder has no classification head.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNetEncoder, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # layer0 shares conv1/bn1/relu, so checkpoint keys `conv1.*`/`bn1.*`
        # populate it too.
        self.layer0 = nn.Sequential(self.conv1, self.bn1, self.relu)
        # The blocks live at index 1 of layer1 (index 0 is the max-pool), so
        # their keys are `layer1.1.<i>.*`; see load_state_dict below.
        self.layer1 = nn.Sequential(self.maxpool, self._make_layer(block, 64, layers[0]))
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.out_dim = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks; only the first may stride/project."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run all five stages and return the deepest feature map."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    @staticmethod
    def _to_encoder_key(key):
        """Map a flat ``layer1.<i>...`` key onto this encoder's ``layer1.1.<i>...`` layout."""
        prefix = 'layer1.'
        if key.startswith(prefix):
            return prefix + '1.' + key[len(prefix):]
        return key

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load weights from either this encoder's layout or a flat ResNet checkpoint.

        Flat (torchvision-style) checkpoints store layer1's blocks as
        ``layer1.<i>.*``; here they sit under ``layer1.1.<i>.*``. With the
        ``strict=False`` loads used by the resnetNN constructors those keys
        would otherwise be skipped silently, leaving layer1 randomly
        initialised. A checkpoint is treated as flat when it has any
        ``layer1.0.*`` key. This encoder's own layer1.0 is the parameter-free
        max-pool, so it never produces such keys.
        """
        if any(key.startswith('layer1.0.') for key in state_dict):
            from collections import OrderedDict

            metadata = getattr(state_dict, '_metadata', None)
            state_dict = OrderedDict(
                (self._to_encoder_key(key), value) for key, value in state_dict.items()
            )
            if metadata is not None:
                # Keep per-module version metadata aligned with the renamed keys.
                state_dict._metadata = OrderedDict(
                    (self._to_encoder_key(key), value) for key, value in metadata.items()
                )
        return super(ResNetEncoder, self).load_state_dict(state_dict, *args, **kwargs)

def resnet18(pretrained=True, **kwargs):
    """Build a ResNet-18 encoder (BasicBlock, stage depths 2-2-2-2).

    Args:
        pretrained (bool): when True, load the ImageNet checkpoint at
            ``model_urls['resnet18']`` with ``strict=False``.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    encoder = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return encoder
    checkpoint = model_zoo.load_url(model_urls['resnet18'])
    encoder.load_state_dict(checkpoint, strict=False)
    return encoder


def resnet34(pretrained=True, **kwargs):
    """Build a ResNet-34 encoder (BasicBlock, stage depths 3-4-6-3).

    Args:
        pretrained (bool): when True, load the ImageNet checkpoint at
            ``model_urls['resnet34']`` with ``strict=False``.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    encoder = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return encoder
    checkpoint = model_zoo.load_url(model_urls['resnet34'])
    encoder.load_state_dict(checkpoint, strict=False)
    return encoder


def resnet50(pretrained=True, **kwargs):
    """Build a ResNet-50 encoder (Bottleneck, stage depths 3-4-6-3).

    Args:
        pretrained (bool): when True, load the ImageNet checkpoint at
            ``model_urls['resnet50']`` with ``strict=False``.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    encoder = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return encoder
    checkpoint = model_zoo.load_url(model_urls['resnet50'])
    encoder.load_state_dict(checkpoint, strict=False)
    return encoder


def resnet101(pretrained=True, **kwargs):
    """Build a ResNet-101 encoder (Bottleneck, stage depths 3-4-23-3).

    Args:
        pretrained (bool): when True, load the ImageNet checkpoint at
            ``model_urls['resnet101']`` with ``strict=False``.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    encoder = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return encoder
    checkpoint = model_zoo.load_url(model_urls['resnet101'])
    encoder.load_state_dict(checkpoint, strict=False)
    return encoder


def resnet152(pretrained=True, **kwargs):
    """Build a ResNet-152 encoder (Bottleneck, stage depths 3-8-36-3).

    Args:
        pretrained (bool): when True, load the ImageNet checkpoint at
            ``model_urls['resnet152']`` with ``strict=False``.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    encoder = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return encoder
    checkpoint = model_zoo.load_url(model_urls['resnet152'])
    encoder.load_state_dict(checkpoint, strict=False)
    return encoder

In [None]:
class ConvLayer(nn.Module):
    """Convolution (or transposed convolution) -> ReLU -> optional BatchNorm.

    Args:
        num_inputs: input channels.
        num_filters: output channels.
        bn: append a ``BatchNorm2d`` after the activation when True.
        kernel_size, stride: forwarded to the convolution.
        padding: explicit padding. When None, regular convolutions use
            ``(kernel_size - 1) // 2`` (size-preserving for odd kernels at
            stride 1) and transposed convolutions use 0.
        transpose: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
        dilation: forwarded only to the transposed convolution.
            NOTE(review): the ``nn.Conv2d`` branch ignores it (upconv2x2
            passes dilation=2 without effect); passing it through would
            change output shapes, so it is left as is.
    """

    def __init__(self, num_inputs, num_filters, bn=True, kernel_size=3, stride=1,
                 padding=None, transpose=False, dilation=1):
        super(ConvLayer, self).__init__()
        if padding is None:
            # Test the flag's truthiness: `transpose is not None` held for both
            # True and False, so the transposed default of 0 was unreachable.
            padding = 0 if transpose else (kernel_size - 1) // 2
        if transpose:
            self.layer = nn.ConvTranspose2d(num_inputs, num_filters, kernel_size=kernel_size,
                                            stride=stride, padding=padding, dilation=dilation)
        else:
            self.layer = nn.Conv2d(num_inputs, num_filters, kernel_size=kernel_size,
                                   stride=stride, padding=padding)
        nn.init.kaiming_uniform_(self.layer.weight, a=np.sqrt(5))
        self.bn_layer = nn.BatchNorm2d(num_filters) if bn else None

    def forward(self, x):
        out = F.relu(self.layer(x))
        return out if self.bn_layer is None else self.bn_layer(out)
    
class ConcatLayer(nn.Module):
    """Concatenate the tensors held in a dict (in insertion order) along ``dim``."""

    def forward(self, x, dim=1):
        tensors = [tensor for tensor in x.values()]
        return torch.cat(tensors, dim=dim)
    
class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable ``f`` so it can be used as an ``nn.Module``."""

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        fn = self.f
        return fn(x)

def upconv2x2(inplanes, outplanes, size=None, stride=1):
    """Return ``[ConvLayer(k=2), bilinear nn.Upsample]`` as a list of modules.

    The upsampling targets ``size`` when given, otherwise it doubles the
    spatial resolution (``scale_factor=2``).
    """
    upsample_target = {'scale_factor': 2} if size is None else {'size': size}
    return [
        ConvLayer(inplanes, outplanes, kernel_size=2, dilation=2, stride=stride),
        nn.Upsample(mode='bilinear', align_corners=True, **upsample_target),
    ]

In [None]:
class DecoderConnect(nn.Module):
    """Deepest decoder stage of the U-Net.

    ``bottom_process`` widens the input to ``2 * inplanes`` channels, projects
    back to ``inplanes`` and resizes to ``output_size``. ``concat_process``
    concatenates that result with the stage input along channels and fuses
    the ``2 * inplanes`` channels back down to ``inplanes``.
    """

    def __init__(self, inplanes, output_size):
        super(DecoderConnect, self).__init__()
        doubled = inplanes * 2
        bottom_ops = [
            ConvLayer(inplanes, doubled, kernel_size=3),
            ConvLayer(doubled, doubled, kernel_size=3),
            *upconv2x2(doubled, inplanes, size=output_size),
        ]
        self.bottom_process = nn.Sequential(*bottom_ops)
        fuse_ops = [
            ConcatLayer(),
            ConvLayer(doubled, doubled, kernel_size=1),
            ConvLayer(doubled, inplanes, kernel_size=3),
            ConvLayer(inplanes, inplanes, kernel_size=3),
        ]
        self.concat_process = nn.Sequential(*fuse_ops)

    def forward(self, x):
        upsampled = self.bottom_process(x)
        return self.concat_process({0: x, 1: upsampled})

In [None]:
class DynamicUNet(nn.Module):
    """U-Net whose decoder is generated from the stage shapes of an encoder.

    Every encoder child whose name starts with ``layer`` is one stage; its
    output is captured with a forward hook and consumed, deepest first, by the
    matching decoder stage.

    Args:
        encoder: module with ``layer*`` children (e.g. ``ResNetEncoder``).
        input_size: (H, W) used once at construction to probe stage shapes.
        num_output_channels: channels of the final output; when None the
            first stage's input channel count is used.
        verbose: print hook bookkeeping when >= 1.
    """

    def __init__(self, encoder, input_size=(224, 224), num_output_channels=None, verbose=0):
        super(DynamicUNet, self).__init__()
        self.encoder = encoder
        self.verbose = verbose
        self.input_size = input_size
        self.num_input_channels = 3  # This must be 3 because we're using a ResNet encoder
        self.num_output_channels = num_output_channels

        self.decoder = self.setup_decoder()

    def _encoder_stages(self):
        """Return the encoder children named ``layer*``, in registration order."""
        return [
            child for name, child in self.encoder.named_children()
            if name.startswith('layer')
        ]

    def forward(self, x):
        """Run the encoder, then the decoder stages from deepest to shallowest."""
        encoder_outputs = []

        def encoder_output_hook(module, inputs, output):
            encoder_outputs.append(output)

        handles = [stage.register_forward_hook(encoder_output_hook) for stage in self._encoder_stages()]

        try:
            self.encoder(x)
        finally:
            if self.verbose >= 1:
                print("Removing all forward handles")
            for handle in handles:
                handle.remove()

        # The deepest stage output enters the bottom decoder stage alone; each
        # shallower output is paired (as a dict) with the previous decoder output.
        prev_output = None
        for reo, rdl in zip(reversed(encoder_outputs), self.decoder):
            if prev_output is not None:
                prev_output = rdl({0: reo, 1: prev_output})
            else:
                prev_output = rdl(reo)
        return prev_output

    def setup_decoder(self):
        """Probe the encoder with one random image and build the decoder from the recorded shapes."""
        input_sizes = []
        output_sizes = []

        def shape_hook(module, inputs, output):
            input_sizes.append(inputs[0].shape)
            output_sizes.append(output.shape)

        handles = [stage.register_forward_hook(shape_hook) for stage in self._encoder_stages()]

        # Probe in eval mode so BatchNorm uses running statistics for the
        # single-sample batch, then restore the caller's mode instead of
        # leaving the encoder stuck in eval mode.
        was_training = self.encoder.training
        self.encoder.eval()
        first_param = next(self.encoder.parameters(), None)
        probe_device = first_param.device if first_param is not None else None
        test_input = torch.randn(1, self.num_input_channels, *self.input_size, device=probe_device)
        try:
            with torch.no_grad():  # only shapes are needed
                self.encoder(test_input)
        finally:
            self.encoder.train(was_training)
            if self.verbose >= 1:
                print("Removing all shape hook handles")
            for handle in handles:
                handle.remove()
        decoder = self.construct_decoder(input_sizes, output_sizes, num_output_channels=self.num_output_channels)
        return decoder

    def construct_decoder(self, input_sizes, output_sizes, num_output_channels=None):
        """Build one decoder stage per encoder stage, returned deepest-first.

        Each stage is sanity-checked by running random tensors of the recorded
        shapes through it, so a shape mismatch fails here rather than during
        training. ``num_output_channels`` defaults to the instance setting.
        NOTE(review): the checks run the new modules in training mode, so
        their BatchNorm running statistics are updated once by random data.
        """
        if num_output_channels is None:
            num_output_channels = self.num_output_channels
        decoder_layers = []
        last_index = len(input_sizes) - 1
        with torch.no_grad():
            for layer_index, (input_size, output_size) in enumerate(zip(input_sizes, output_sizes)):
                # How much this encoder stage shrank space / changed channels.
                upsampling_size_factor = int(input_size[-1] / output_size[-1])
                upsampling_channel_factor = input_size[-3] / output_size[-3]
                needs_upconv = upsampling_size_factor > 1 or upsampling_channel_factor != 1
                channels = output_size[-3]
                next_layer = []
                if layer_index == last_index:
                    last_layer_ops = DecoderConnect(channels, output_size[2:])
                    last_layer_concat_ops_output = last_layer_ops(torch.randn(*output_size))
                    next_layer.append(last_layer_ops)
                    if needs_upconv:
                        upconv = upconv2x2(channels, input_size[-3], size=input_size[2:])
                        nn.Sequential(*upconv)(last_layer_concat_ops_output)
                        next_layer.extend(upconv)
                elif layer_index == 0:
                    first_layer_concat_ops = [
                        ConcatLayer(),
                        ConvLayer(channels * 2, channels * 2, kernel_size=1),
                        *upconv2x2(
                            channels * 2,
                            channels,
                            size=[dim * upsampling_size_factor for dim in output_size[2:]]
                        ),
                        ConvLayer(channels, channels, kernel_size=3),
                        ConvLayer(
                            channels,
                            input_size[-3] if num_output_channels is None else num_output_channels,
                            kernel_size=1
                        ),
                    ]
                    nn.Sequential(*first_layer_concat_ops)(
                        {0: torch.randn(*output_size), 1: torch.randn(*output_size)}
                    )
                    next_layer.extend(first_layer_concat_ops)
                else:
                    middle_layer_concat_ops = [
                        ConcatLayer(),
                        ConvLayer(channels * 2, channels * 2, kernel_size=1),
                        ConvLayer(channels * 2, channels, kernel_size=3),
                        ConvLayer(channels, channels, kernel_size=3)
                    ]
                    middle_layer_concat_ops_output = nn.Sequential(*middle_layer_concat_ops)(
                        {0: torch.randn(*output_size), 1: torch.randn(*output_size)}
                    )
                    next_layer.extend(middle_layer_concat_ops)
                    if needs_upconv:
                        upconv = upconv2x2(channels, input_size[-3], size=input_size[2:])
                        nn.Sequential(*upconv)(middle_layer_concat_ops_output)
                        next_layer.extend(upconv)
                decoder_layers.append(nn.Sequential(*next_layer))
        return nn.ModuleList(reversed(decoder_layers))

In [None]:
# Pinned release: MyLidcDataset is written against this albumentations API
# (e.g. ElasticTransform's alpha_affine argument).
# NOTE(review): confirm compatibility before upgrading.
!pip install albumentations==0.4.6

Collecting albumentations==0.4.6
  Downloading albumentations-0.4.6.tar.gz (117 kB)
[K     |████████████████████████████████| 117 kB 33.3 MB/s 
Collecting imgaug>=0.4.0
  Downloading imgaug-0.4.0-py2.py3-none-any.whl (948 kB)
[K     |████████████████████████████████| 948 kB 55.7 MB/s 
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.4.6-py3-none-any.whl size=65174 sha256=0c505645e8dac9c7bd0132094d6a7722091bc83d6dae80f0de350dea7d57a3e1
  Stored in directory: /root/.cache/pip/wheels/cf/34/0f/cb2a5f93561a181a4bcc84847ad6aaceea8b5a3127469616cc
Successfully built albumentations
Installing collected packages: imgaug, albumentations
  Attempting uninstall: imgaug
    Found existing installation: imgaug 0.2.9
    Uninstalling imgaug-0.2.9:
      Successfully uninstalled imgaug-0.2.9
  Attempting uninstall: albumentations
    Found existing installation: albu

In [None]:
import os
import numpy as np
import glob
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms.functional as TF
import torchvision
from torchvision import transforms

import albumentations as albu
from albumentations.pytorch import ToTensorV2

class MyLidcDataset(Dataset):
    """Paired image/mask ``.npy`` dataset with light albumentations augmentation.

    Each item loads an image and its mask and views both as 512x512x3 arrays.
    It then applies elastic / horizontal-flip augmentation plus ToTensorV2 and
    returns ``(image, mask)`` as float tensors of shape (3, 512, 512).
    """

    def __init__(self, images_paths, mask_paths):
        self.image_paths = images_paths
        self.mask_paths = mask_paths

        self.albu_transformations = albu.Compose([
            albu.ElasticTransform(alpha=1.1, alpha_affine=0.5, sigma=5, p=0.15),
            albu.HorizontalFlip(p=0.15),
            ToTensorV2(),
        ])

        # Kept for backward compatibility; transform() does not use it.
        self.transformations = transforms.Compose([transforms.ToTensor()])

    def transform(self, image, mask):
        """Augment one image/mask pair and return both as CHW float tensors."""
        image = image.reshape(512, 512, 3)
        mask = mask.reshape(512, 512, 3).astype('uint8')
        augmented = self.albu_transformations(image=image, mask=mask)
        image = augmented['image']
        mask = augmented['mask']
        # ToTensorV2 moves the image to CHW but (with transpose_mask unset)
        # leaves the mask HWC. reshape() would interleave pixels across
        # channels, so swap the axes explicitly.
        if mask.shape[-1] == 3:
            mask = mask.permute(2, 0, 1).contiguous()
        return image.type(torch.FloatTensor), mask.type(torch.FloatTensor)

    def __getitem__(self, index):
        image = np.load(self.image_paths[index])
        mask = np.load(self.mask_paths[index])
        return self.transform(image, mask)

    def __len__(self):
        return len(self.image_paths)

In [None]:
# pandas is otherwise only imported in a later cell; import it here so this
# cell runs in notebook order.
import pandas as pd

alpha = pd.read_csv('/content/drive/MyDrive/data/Meta_Beta/meta.csv')
alpha.head(5)

Unnamed: 0.1,Unnamed: 0,index,patient_id,nodule_no,slice_no,original_image,mask_image,malignancy,is_cancer,is_clean,is_nodule,data_split
0,0,0,1,0,0,0001_NI000_slice000,0001_MA000_slice000,5,True,False,True,Train
1,1,1,1,0,1,0001_NI000_slice001,0001_MA000_slice001,5,True,False,True,Train
2,2,2,1,0,2,0001_NI000_slice002,0001_MA000_slice002,5,True,False,True,Train
3,3,3,1,0,3,0001_NI000_slice003,0001_MA000_slice003,5,True,False,True,Train
4,4,4,1,0,4,0001_NI000_slice004,0001_MA000_slice004,5,True,False,True,Train


In [None]:
# Directory holding the image .npy files; joined with file stems below.
image_dir = "/content/drive/MyDrive/data/RGB_IMAGES/Image/"

In [None]:
alpha.loc[0, 'original_image']  # file stem of the first row

'0001_NI000_slice000'

In [None]:
# Full path of each image: <image_dir><stem>.npy
alpha['original_image_prime'] = alpha['original_image'].map(lambda stem: image_dir + stem + '.npy')
alpha.head(5)

Unnamed: 0.1,Unnamed: 0,index,patient_id,nodule_no,slice_no,original_image,mask_image,malignancy,is_cancer,is_clean,is_nodule,data_split,original_image_prime
0,0,0,1,0,0,0001_NI000_slice000,0001_MA000_slice000,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...
1,1,1,1,0,1,0001_NI000_slice001,0001_MA000_slice001,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...
2,2,2,1,0,2,0001_NI000_slice002,0001_MA000_slice002,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...
3,3,3,1,0,3,0001_NI000_slice003,0001_MA000_slice003,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...
4,4,4,1,0,4,0001_NI000_slice004,0001_MA000_slice004,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...


In [None]:
alpha.loc[0, 'original_image_prime']  # full image path of the first row

'/content/drive/MyDrive/data/RGB_IMAGES/Image/0001_NI000_slice000.npy'

In [None]:
# Directory holding the mask .npy files; joined with file stems below.
mask_dir = "/content/drive/MyDrive/data/RGB_IMAGES/Mask/"

In [None]:
# NOTE(review): the mask path is built from `original_image` (e.g. *_NI*),
# not from `mask_image` (*_MA*); confirm the Mask folder uses image stems.
alpha['mask_image_prime'] = alpha['original_image'].map(lambda stem: mask_dir + stem + '.npy')

In [None]:
alpha.head(n=5)  # preview the first rows, including the new path columns

Unnamed: 0.1,Unnamed: 0,index,patient_id,nodule_no,slice_no,original_image,mask_image,malignancy,is_cancer,is_clean,is_nodule,data_split,original_image_prime,mask_image_prime
0,0,0,1,0,0,0001_NI000_slice000,0001_MA000_slice000,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...,/content/drive/MyDrive/data/RGB_IMAGES/Mask/00...
1,1,1,1,0,1,0001_NI000_slice001,0001_MA000_slice001,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...,/content/drive/MyDrive/data/RGB_IMAGES/Mask/00...
2,2,2,1,0,2,0001_NI000_slice002,0001_MA000_slice002,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...,/content/drive/MyDrive/data/RGB_IMAGES/Mask/00...
3,3,3,1,0,3,0001_NI000_slice003,0001_MA000_slice003,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...,/content/drive/MyDrive/data/RGB_IMAGES/Mask/00...
4,4,4,1,0,4,0001_NI000_slice004,0001_MA000_slice004,5,True,False,True,Train,/content/drive/MyDrive/data/RGB_IMAGES/Image/0...,/content/drive/MyDrive/data/RGB_IMAGES/Mask/00...


In [None]:
alpha.loc[0, 'mask_image_prime']  # full mask path of the first row

'/content/drive/MyDrive/data/RGB_IMAGES/Mask/0001_NI000_slice000.npy'

In [None]:
import pandas as pd

# Directory of Image, Mask folder generated from the preprocessing stage ###
image_dir = '/content/drive/MyDrive/data/RGB_IMAGES/Image/'
mask_dir = '/content/drive/MyDrive/data/RGB_IMAGES/Mask/'
#mask_dir = '/content/drive/MyDrive/data/Nodule_data/mask/'
meta = pd.read_csv('/content/drive/MyDrive/data/Meta_Beta/meta.csv')

# Full .npy paths. NOTE(review): mask paths are built from `original_image`
# (not `mask_image`); confirm the Mask folder uses the image file stems.
meta['original_image_prime'] = meta['original_image'].apply(lambda x: image_dir + x + '.npy')
meta['mask_image_prime'] = meta['original_image'].apply(lambda x: mask_dir + x + '.npy')

train_meta = meta[meta['data_split'] == 'Train']
val_meta = meta[meta['data_split'] == 'Validation']

# Get all *npy images into list for Train
train_image_paths = list(train_meta['original_image_prime'])
train_mask_paths = list(train_meta['mask_image_prime'])

# Get all *npy images into list for Validation
val_image_paths = list(val_meta['original_image_prime'])
val_mask_paths = list(val_meta['mask_image_prime'])

# An empty training split would otherwise raise ZeroDivisionError here.
val_train_ratio = len(val_image_paths) / len(train_image_paths) if train_image_paths else float('nan')

print("*" * 50)
print("The length of image: {}, mask folders: {} for train".format(len(train_image_paths), len(train_mask_paths)))
print("The length of image: {}, mask folders: {} for validation".format(len(val_image_paths), len(val_mask_paths)))
# `.2f` = two decimal places (`2f` would only set a minimum width of 2).
print("Ratio between Val/ Train is {:.2f}".format(val_train_ratio))
print("*" * 50)

print(train_image_paths[0] if train_image_paths else "(no training images)")
print(train_mask_paths[0] if train_mask_paths else "(no training masks)")

# Create Dataset
train_dataset = MyLidcDataset(train_image_paths, train_mask_paths)
val_dataset = MyLidcDataset(val_image_paths, val_mask_paths)

batch_size = 4
tr_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
    num_workers=2)

val_dl = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
    num_workers=2)

**************************************************
The length of image: 8126, mask folders: 8126 for train
The length of image: 2927, mask folders: 2927 for validation
Ratio between Val/ Train is 0.360202
**************************************************
/content/drive/MyDrive/data/RGB_IMAGES/Image/0001_NI000_slice000.npy
/content/drive/MyDrive/data/RGB_IMAGES/Mask/0001_NI000_slice000.npy


In [None]:
class DiceLoss(nn.Module):
    """
    Module to compute the Dice segmentation loss. Based on the following discussion:
    https://discuss.pytorch.org/t/one-hot-encoding-with-autograd-dice-loss/9781

    Args:
        weights: optional per-channel weights (scalar or length-C tensor);
            None weighs every channel 1.
        ignore_index: label value whose pixels are excluded from the loss.
        eps: added to the denominator to avoid division by zero.
    """
    def __init__(self, weights=None, ignore_index=None, eps=0.0001):
        super(DiceLoss, self).__init__()
        self.weights = weights
        self.ignore_index = ignore_index
        self.eps = eps

    def forward(self, output, target):
        """
        Arguments:
            output: (N, C, H, W) tensor of probabilities for the predicted output
            target: (N, H, W) integer tensor corresponding to the pixel-wise labels
        Returns:
            loss: the Dice loss averaged over channels
        """
        # zeros_like stays zero even where output holds inf/NaN (output * 0 would not).
        encoded_target = torch.zeros_like(output)
        if self.ignore_index is not None:
            mask = target == self.ignore_index
            target = target.clone()
            target[mask] = 0
            encoded_target.scatter_(1, target.unsqueeze(1), 1)
            mask = mask.unsqueeze(1).expand_as(encoded_target)
            encoded_target[mask] = 0
        else:
            encoded_target.scatter_(1, target.unsqueeze(1), 1)

        # Resolve the default locally rather than overwriting the configured attribute.
        weights = 1 if self.weights is None else self.weights

        intersection = output * encoded_target
        numerator = 2 * intersection.sum(0).sum(1).sum(1)
        denominator = output + encoded_target

        if self.ignore_index is not None:
            denominator[mask] = 0
        denominator = denominator.sum(0).sum(1).sum(1) + self.eps
        loss_per_channel = weights * (1 - (numerator / denominator))

        return loss_per_channel.sum() / output.size(1)

In [None]:
def dice_similarity(output, target, weights=None, ignore_index=None, eps=1e-8):
    """Hard (argmax) Dice similarity between predicted and true label maps.

    Arguments:
        output: (N, C, H, W) tensor of model output (logits or probabilities)
        target: (N, H, W) integer tensor corresponding to the pixel-wise labels
        weights: optional per-channel weights (scalar or length-C tensor)
        ignore_index: label value whose pixels are excluded
        eps: smoothing added to numerator and denominator
    Returns:
        the Dice similarity averaged over channels (1 = perfect overlap)
    """
    # argmax of the raw scores equals argmax of softmax, and scatter_ needs
    # integer class indices (not probabilities).
    prediction = output.argmax(dim=1)
    encoded_prediction = torch.zeros_like(output)
    encoded_prediction.scatter_(1, prediction.unsqueeze(1), 1)

    encoded_target = torch.zeros_like(output)
    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    if weights is None:
        weights = 1

    intersection = encoded_prediction * encoded_target
    numerator = 2 * intersection.sum(0).sum(1).sum(1) + eps
    # Dice denominator is |prediction| + |target|.
    denominator = encoded_prediction + encoded_target
    if ignore_index is not None:
        denominator[mask] = 0
    denominator = denominator.sum(0).sum(1).sum(1) + eps
    acc_per_channel = weights * (numerator / denominator)

    return acc_per_channel.sum() / output.size(1)

In [None]:
# Preview instance with 32 output channels (the training setup cell builds
# its own model).
model = DynamicUNet(resnet34(), input_size=(512, 512), num_output_channels=32)

In [None]:
[*model.parameters()][-1]  # inspect the last registered parameter of the model

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

In [None]:
def iou_score(output, target):
    """Intersection-over-union of binarised prediction and target.

    A tensor ``output`` is treated as logits (sigmoid applied); a non-tensor
    output is used as-is. Both sides are thresholded at 0.5, and 1e-5
    smoothing keeps the ratio defined when both masks are empty.
    """
    smooth = 1e-5
    pred = torch.sigmoid(output).data.cpu().numpy() if torch.is_tensor(output) else output
    truth = target.data.cpu().numpy() if torch.is_tensor(target) else target
    pred_mask = pred > 0.5
    true_mask = truth > 0.5
    overlap = (pred_mask & true_mask).sum()
    combined = (pred_mask | true_mask).sum()
    return (overlap + smooth) / (combined + smooth)

def dice_coef(output, target):
    """Soft Dice coefficient between sigmoid(``output``) and ``target``.

    ``output`` holds logits (the U-Net's raw output). Both tensors are
    flattened and moved to NumPy; 1e-5 smoothing keeps empty masks defined.
    """
    smooth = 1e-5
    probs = torch.sigmoid(output).view(-1).data.cpu().numpy()
    truth = target.view(-1).data.cpu().numpy()
    overlap = np.sum(probs * truth)
    return (2. * overlap + smooth) / (np.sum(probs) + np.sum(truth) + smooth)

def dice_coef2(output, target):
    """Hard Dice coefficient used for validation.

    ``output`` is binarised with a 0.5 threshold as-is (no sigmoid), so it
    is expected to already hold probabilities. 1e-5 smoothing keeps empty
    masks defined.
    """
    smooth = 1e-5
    pred = (output.view(-1) > 0.5).float().cpu().numpy()
    truth = target.view(-1).data.cpu().numpy()
    overlap = (pred * truth).sum()
    return (2. * overlap + smooth) / (pred.sum() + truth.sum() + smooth)

def str2bool(v):
    """Parse a boolean flag such as a command-line argument.

    Accepts 'true'/'1' and 'false'/'0' (case-insensitive, surrounding
    whitespace ignored), the ints 1/0 and real bools. Any other value returns
    None, as before.
    """
    if isinstance(v, bool):
        return v
    # str() lets int input work: the old list entries 1/0 could never match a
    # lowercased string, and int.lower() raised AttributeError.
    token = str(v).strip().lower()
    if token in ('true', '1'):
        return True
    if token in ('false', '0'):
        return False
    return None

def count_params(model):
    """Return the number of trainable (``requires_grad``) parameter elements in ``model``."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(lambda param: param.numel(), trainable))

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['BCEDiceLoss']


class BCEDiceLoss(nn.Module):
    """0.5 * BCE-with-logits plus soft Dice loss for binary segmentation.

    ``input`` holds logits and ``target`` holds 0/1 values of the same shape.
    The Dice term is computed per sample (all non-batch dimensions flattened)
    and averaged over the batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        bce_term = F.binary_cross_entropy_with_logits(input, target)
        smooth = 1e-5
        batch = target.size(0)
        probs = torch.sigmoid(input).view(batch, -1)
        labels = target.view(batch, -1)
        overlap = (probs * labels).sum(1)
        per_sample_dice = (2. * overlap + smooth) / (probs.sum(1) + labels.sum(1) + smooth)
        dice_term = 1 - per_sample_dice.sum() / batch
        return 0.5 * bce_term + dice_term

class AverageMeter(object):
    """Track the latest value and the running count-weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running average, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
# Class for handling the dataset
class Dataset(torch.utils.data.Dataset):
    """Image / label-map dataset read from files with PIL.

    Args:
        data: iterable of (image_path, label_path) pairs.
        resize_shape: (H, W) both files are resized to; labels use nearest
            neighbour so class ids are not blended.
        is_train: apply random horizontal / vertical flips when True.

    Items are ``(image, target)``: an ImageNet-normalised float image tensor
    and a long tensor of class ids clamped to [0, 32].
    """

    def __init__(self, data, resize_shape=(360, 480), is_train=True):
        # Materialise once: a generator would be exhausted by the first pass.
        pairs = list(data)
        self.images = [pair[0] for pair in pairs]
        self.labels = [pair[1] for pair in pairs]
        self.resize_shape = resize_shape
        self.is_train = is_train

    def transform(self, index):
        # Context managers close the files once the resized copies exist.
        with Image.open(self.images[index]) as image_file, \
                Image.open(self.labels[index]) as label_file:
            image = tf.resize(image_file, self.resize_shape)
            target = tf.resize(label_file, self.resize_shape, interpolation=Image.NEAREST)
        if self.is_train:
            horizontal_draw = torch.rand(1).item()
            vertical_draw = torch.rand(1).item()
            if horizontal_draw > 0.5:
                image, target = tf.hflip(image), tf.hflip(target)
            if vertical_draw > 0.5:
                image, target = tf.vflip(image), tf.vflip(target)

        image, target = tf.to_tensor(image), tf.to_tensor(target)
        # to_tensor scales 8-bit labels into [0, 1]; scale back and round
        # before the integer cast so float error cannot truncate a class id
        # (e.g. 2.9999 -> 2).
        target = (255 * target).round().clamp(0, 32).long()
        return tf.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), target

    def __getitem__(self, index):
        return self.transform(index)

    def __len__(self):
        return len(self.images)

In [None]:
model = DynamicUNet(resnet34(), input_size=(512, 512), num_output_channels=3)
if torch.cuda.is_available():
    model = model.cuda()

# Only the decoder's parameters go to the optimizer, so the encoder weights
# are not updated for now.
optimizer = optim.AdamW(list(model.decoder.parameters()))

criterion = DiceLoss()

# Training specific parameters
num_epochs = 10
num_up_epochs, num_down_epochs = 3, 7
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-2,
    total_steps=num_epochs * len(tr_dl),
    pct_start=num_up_epochs / num_epochs,  # 3/10 == 0.3, OneCycleLR's default
)


In [None]:
# NOTE(review): DiceLoss and the pixel accuracy below expect `y` to hold
# integer class ids of shape (N, H, W) or (N, 1, H, W). MyLidcDataset yields
# float (N, 3, H, W) masks, so confirm the intended target format.
device = next(model.parameters()).device
model.train()
losses = []
accuracies = []

tqdm_iterator = tqdm(range(num_epochs), position=0)
for epoch in tqdm_iterator:
    tr_loss, tr_correct_pixels, tr_total_pixels, total = 0.0, 0, 0, 0
    tqdm_epoch_iterator = tqdm(tr_dl, position=1, leave=False)
    for x, y in tqdm_epoch_iterator:
        optimizer.zero_grad()
        # Identical preprocessing on CPU and GPU.
        x, y = x.to(device), y.squeeze(dim=1).to(device)
        output = model(x)
        # DiceLoss expects per-class probabilities, not raw logits.
        loss = criterion(torch.softmax(output, dim=1), y)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Pixel accuracy compares predicted class ids (argmax) with the labels.
        prediction = output.argmax(dim=1)
        tr_correct_pixels += (prediction == y).sum().item()
        tr_total_pixels += y.numel()
        tr_loss += loss.item() * len(y)
        total += len(y)
        tqdm_epoch_iterator.set_postfix({
            "Loss": tr_loss / total, "Accuracy": tr_correct_pixels / tr_total_pixels
        })
    # max(..., 1) keeps an empty loader from dividing by zero.
    overall_loss = tr_loss / max(total, 1)
    overall_acc = tr_correct_pixels / max(tr_total_pixels, 1)
    losses.append(overall_loss)
    accuracies.append(overall_acc)
    tqdm_iterator.set_postfix({"Loss": overall_loss, "Accuracy": overall_acc})