In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # Linear algebra
from PIL import Image, ImageEnhance # Image processing
import matplotlib.pyplot as plt # To plot batche images and its masks
import torch # PyTorch
from torch import nn # To add custom layers
from torch import optim # Optimizer to update model parameters
import torch.nn.functional as F # Functional requirements
from torch.utils.data import Dataset, DataLoader # To make custom dataset class & to batch data
from torchvision import models, transforms # Get pretrained models & to make data augmentation pipeline
from torch.utils.data import Subset # To divide data into train, val & test sets
import os # To iterate through dirs, file read-write stuff.

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
# Locations (inside Kaggle's read-only input directory) of the input images
# and of their semantic label masks, in that order.
img_path, mask_path = (
    '../input/semantic-drone-dataset/semantic_drone_dataset/original_images/',
    '../input/semantic-drone-dataset/semantic_drone_dataset/label_images_semantic/',
)

In [3]:
# Data Augmentation

class RandomCrop(object):
    """
    A callable class to apply random cropping on an image and its mask.

    ...

    Attributes
    ----------
    min_crop_width : int
        Minimum width of a region to randomly crop from image.
    min_crop_height : int
        Minimum height of a region to randomly crop from image.

    Methods
    -------
    __call__(sample):
        Given tuple of input image and its corresponding mask image, with probability 0.5
        the same randomly chosen region is cropped from both image and mask.
        Width of this region will be b.w. min_crop_width and image width, analogously for height.
        The (possibly cropped) image and mask are returned as a tuple.

    """

    def __init__(self, min_crop_width=224, min_crop_height=224):
        """
        Accepts all the necessary attributes for performing random cropping

        Parameters
        ----------
            min_crop_width : int
                Minimum width of a region to randomly crop from image.
            min_crop_height : int
                Minimum height of a region to randomly crop from image.
        """

        # Make sure min val of parameters are at least 224 pixels
        assert (min_crop_width >= 224 and min_crop_height >= 224)

        self.min_crop_width = min_crop_width
        self.min_crop_height = min_crop_height

    def __call__(self, sample):
        """
        Returns randomly cropped region from image and mask.

        Parameters
        ----------
        sample : tuple
            Tuple of image and its corresponding mask image. Both must expose
            `.crop(box)`; the image must also expose `.size` as (width, height).

        Returns
        -------
        Tuple of (image, mask); both are cropped with the same box, or both are
        returned unchanged when the random draw decides not to crop.
        """

        # Get input and its mask from arg.
        img, mask = sample[0], sample[1]
        width, height = img.size

        # Only perform random crop if random num is greater than 0.5
        if np.random.uniform() > 0.5:
            # Never ask for a crop larger than the image itself: without the clamp
            # randint() raises ValueError when the image is smaller than the minimum.
            low_width = min(self.min_crop_width, width)
            low_height = min(self.min_crop_height, height)

            # Randomly choose width b.w. [low_width, image width], analogously for height
            desired_width = np.random.randint(low=low_width, high=width + 1)
            desired_height = np.random.randint(low=low_height, high=height + 1)

            # Randomly select top left point of our region
            x = np.random.randint(low=0, high=width - desired_width + 1)
            y = np.random.randint(low=0, high=height - desired_height + 1)

            # PIL's crop box is (left, upper, right, lower) with right/lower EXCLUSIVE,
            # so the far edge is x + desired_width (subtracting 1 would lose a pixel
            # row/column and make the crop smaller than requested).
            box = (x, y, x + desired_width, y + desired_height)
            img = img.crop(box)
            mask = mask.crop(box)

        return img, mask


class Resize(object):
    """
    A callable class to resize input image and its mask.

    ...

    Attributes
    ----------
    desired_width : int
        Desired width to resize the image to.
    desired_height : int
        Desired height to resize the image to.
    img_resample : int or None
        PIL resampling filter used for the input image only. None keeps the
        historical behaviour (nearest neighbour).

    Methods
    -------
    __call__(sample):
        Given tuple of input image and its corresponding mask image.
        Width and height for resize operation will be provided while object init.
        Resized image and its mask will be returned as a tuple.

    """

    def __init__(self, desired_width=1024, desired_height=1024, img_resample=None):
        """
        Accepts all the necessary attributes for performing resizing operation

        Parameters
        ----------
            desired_width : int
                Desired width to resize the image to.
            desired_height : int
                Desired height to resize the image to.
            img_resample : int, optional
                PIL resampling filter for the input image (e.g. Image.BILINEAR).
                Defaults to None, which means Image.NEAREST as before. The mask is
                always resized with Image.NEAREST regardless of this value.
        """

        # Make sure parameter values are at least 224 pixels
        assert (desired_width >= 224 and desired_height >= 224)

        self.desired_width = desired_width
        self.desired_height = desired_height
        self.img_resample = img_resample

    def __call__(self, sample):
        """
        Returns resized image and mask.

        Parameters
        ----------
        sample : tuple
            Tuple of image and its corresponding mask image

        Returns
        -------
        Tuple of resized image and mask
        """

        # Get input and its mask from arg.
        img, mask = sample[0], sample[1]
        target_size = (self.desired_width, self.desired_height)

        # The filter is resolved at call time so the default stays nearest neighbour.
        img_filter = Image.NEAREST if self.img_resample is None else self.img_resample

        # The mask holds class ids, so it must never be interpolated: always NEAREST.
        img = img.resize(target_size, resample=img_filter)
        mask = mask.resize(target_size, resample=Image.NEAREST)

        return img, mask


class ColorTransform(object):
    """
    Callable that randomly jitters brightness, contrast and saturation of the
    input image while passing the mask through untouched.

    ...

    Attributes
    ----------
    brightness : float
        Maximum deviation from 1 of the brightness factor.
    contrast : float
        Maximum deviation from 1 of the contrast factor.
    saturation : float
        Maximum deviation from 1 of the saturation factor.

    Methods
    -------
    __call__(sample):
        Takes an (image, mask) tuple.
        For each of the three attributes a factor is drawn uniformly from
        [max(0, 1 - attribute), 1 + attribute] and applied to the image.
        Returns (jittered image, unchanged mask).

    """

    def __init__(self, brightness=0., contrast=0., saturation=0.):
        """
        Stores the jitter amounts.

        Parameters
        ----------
            brightness : float
                Maximum deviation from 1 of the brightness factor.
            contrast : float
                Maximum deviation from 1 of the contrast factor.
            saturation : float
                Maximum deviation from 1 of the saturation factor.
        """

        # Negative jitter amounts are rejected
        assert (brightness >= 0 and contrast >= 0 and saturation >= 0)

        self.brightness, self.contrast, self.saturation = brightness, contrast, saturation

    def __call__(self, sample):
        """
        Applies the random color jitter to the image only.

        Parameters
        ----------
        sample : tuple
            Tuple of image and its corresponding mask image

        Returns
        -------
        Tuple of color jittered image and the mask exactly as it was passed in.
        """

        img, mask = sample[0], sample[1]

        # Draw the three factors first, in the order brightness -> contrast -> saturation.
        # factor < 1 weakens the property, factor == 1 keeps it, factor > 1 strengthens it.
        amounts = (self.brightness, self.contrast, self.saturation)
        factors = [np.random.uniform(max(0, 1 - amount), 1 + amount) for amount in amounts]

        # Then apply the matching PIL enhancer for each factor, in the same order.
        enhancers = (ImageEnhance.Brightness, ImageEnhance.Contrast, ImageEnhance.Color)
        for enhancer, factor in zip(enhancers, factors):
            img = enhancer(img).enhance(factor)

        return img, mask


class FlipTransform(object):
    """
    Callable that randomly mirrors an image together with its mask.

    ...

    Attributes
    ----------
    None

    Methods
    -------
    __call__(sample):
        Takes an (image, mask) tuple.
        Independently, each with probability 0.5, applies a horizontal flip and
        a vertical flip to both image and mask.
        Returns the (possibly flipped) image and mask as a tuple.

    """

    def __call__(self, sample):
        """
        Returns the randomly flipped image and mask.

        Parameters
        ----------
        sample : tuple
            Tuple of image and its corresponding mask image

        Returns
        -------
        Tuple of image and mask, both flipped the same way (or not at all).
        """

        img, mask = sample[0], sample[1]

        # Two independent draws: the first decides the horizontal flip,
        # the second decides the vertical flip.
        flip_horizontally, flip_vertically = np.random.uniform() > 0.5, np.random.uniform() > 0.5

        if flip_horizontally:
            img, mask = img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)

        if flip_vertically:
            img, mask = img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM)

        return img, mask


class AffineTransform(object):
    """
    A callable class to apply affine transform on input image and its mask.

    ...

    Attributes
    ----------
    degrees : float
        Required to rotate image and mask to some random angle
    translate : tuple of 2 vals
        Required to shift (by dx in x-direction and by dy in y-direction) image and mask to some random parameters.
    scale : tuple of 2 vals
        Required to scale image and mask to some random param (like downscale or upscale to some degree).
    shear : tuple of 4 vals
        Required to shear image and mask to some random params.
    resample : int
        Resampling method used for the input IMAGE. Default is bilinear.
        The mask is always resampled with nearest neighbour.

    Methods
    -------
    __call__(sample):
        Given tuple of input image and its corresponding mask image.
        Draw random values given init args for rotation, translation, scaling and shearing.
        Returns affine transformed image and its mask as a tuple.
    """

    def __init__(self, degrees=0., translate=(0, 0), scale=(1., 1.), shear=(0., 0., 0., 0.), resample=Image.BILINEAR):
        """
        Accepts all the necessary attributes for performing affine transform.

        Parameters
        ----------
            degrees : float
                Rotation angle is drawn from [-degrees, degrees]. Must be in [0, 180].
            translate : tuple of 2 vals
                Maximum shift as a fraction of image width / height, each in [0, 1].
            scale : tuple of 2 vals
                (min, max) range the scale factor is drawn from; both must be > 0.
            shear : tuple of 4 vals
                (x_min, x_max, y_min, y_max) ranges the two shear angles are drawn from.
            resample : int
                Resampling method for the input image. The mask ignores this value and
                always uses Image.NEAREST so that no new (invalid) class ids are created.
        """

        assert (0 <= degrees <= 180)
        assert (all(0 <= val <= 1 for val in translate))
        assert (all(val > 0 for val in scale))
        # __call__ reads translate[0..1], scale[0..1] and shear[0..3]; fail early on bad shapes
        assert (len(translate) == 2 and len(scale) == 2 and len(shear) == 4)
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.resample = resample

    def __call__(self, sample):
        """
        Returns affine transformed image and its mask.

        Parameters
        ----------
        sample : tuple
            Tuple of image and its corresponding mask image

        Returns
        -------
        Tuple of affine transformed image and its mask. Both share the same randomly
        drawn angle, translation, scale and shear.
        """

        # Get image and its mask from arg.
        img, mask = sample[0], sample[1]
        width, height = img.size

        # Randomly generate angle b.w. [-self.degrees, self.degrees]
        angle = np.random.uniform(-self.degrees, self.degrees)

        # Randomly generate dx and dy (whole pixels) based on params to translate the image
        max_dx = self.translate[0] * width
        max_dy = self.translate[1] * height
        translations = (int(np.round(np.random.uniform(-max_dx, max_dx))),
                        int(np.round(np.random.uniform(-max_dy, max_dy))))

        # Randomly generate new scale based on params
        new_scale = np.random.uniform(self.scale[0], self.scale[1])

        # Randomly generate shear angles (x then y) based on params
        shear_ranges = [np.random.uniform(self.shear[0], self.shear[1]),
                        np.random.uniform(self.shear[2], self.shear[3])]

        # NOTE(review): newer torchvision releases appear to have renamed the `resample`
        # keyword of functional.affine to `interpolation` - confirm against the installed version.

        # Apply affine transform based on above generated values on the image
        img = transforms.functional.affine(img, angle=angle, translate=translations,
                                           scale=new_scale, shear=shear_ranges, resample=self.resample)

        # The mask stores class ids: interpolating between neighbouring labels would
        # invent classes that are not there, so it is always warped with NEAREST.
        mask = transforms.functional.affine(mask, angle=angle, translate=translations,
                                            scale=new_scale, shear=shear_ranges, resample=Image.NEAREST)

        return img, mask


# Augmentation steps run, in this order, on every training (image, mask) pair
train_augmentations = [
    RandomCrop(2048, 2048),
    Resize(512, 512),
    ColorTransform(brightness=0.5, contrast=0.5, saturation=0.5),
    FlipTransform(),
    AffineTransform(degrees=180, translate=(.2, .2), scale=(0.75, 1.25),
                    shear=[-30, 30, -30, 30], resample=Image.NEAREST),
]
train_transform = transforms.Compose(train_augmentations)

# Validation (and test) samples are only resized, never augmented
val_transform = transforms.Compose([Resize(512, 512)])

In [4]:
def get_class_distribution(mask_path, nb_total_classes):
    """
    Compute the average per-image pixel share of every class over all masks.

    Parameters
    ----------
    mask_path : str
        Directory whose every entry is opened as a mask image.
    nb_total_classes : int
        Number of class ids; mask values must lie in [0, nb_total_classes).

    Returns
    -------
    percs : np.ndarray
        Mean pixel fraction of each class, sorted in descending order.
    indices : np.ndarray
        Class ids in the same (descending) order as `percs`.

    Raises
    ------
    ValueError
        If `mask_path` contains no files, or a mask holds a class id
        >= nb_total_classes.
    """
    mean_fractions = np.zeros(nb_total_classes)
    nb_masks = 0

    for fn in os.listdir(mask_path):
        # Context manager closes the file handle as soon as the pixels are copied out
        with Image.open(os.path.join(mask_path, fn)) as mask:
            w, h = mask.size
            labels = np.ravel(np.asarray(mask, dtype=np.int32))

        counts = np.bincount(labels, minlength=nb_total_classes)
        if counts.shape[0] > nb_total_classes:
            # bincount grows past minlength when it meets a larger value
            raise ValueError('Mask {} contains class id {} but nb_total_classes is {}'.format(
                fn, counts.shape[0] - 1, nb_total_classes))

        mean_fractions += counts / (w * h)
        nb_masks += 1

    if nb_masks == 0:
        # Dividing by zero below would silently return an all-NaN distribution
        raise ValueError('No mask files found in {}'.format(mask_path))

    mean_fractions /= nb_masks

    # One argsort keeps fractions and class ids aligned
    indices = np.argsort(mean_fractions)[::-1]
    percs = mean_fractions[indices]

    return percs, indices

def select_top_k_classes(class_dist, class_indices, k=5):
    """
    Keep the k most frequent non-background classes plus background (class 0).

    Every class that is not kept has its pixel share added to the background
    share. The inverse of each kept share is returned as a class weight, meant
    to be passed to the cross entropy loss.

    Parameters
    ----------
    class_dist : np.ndarray
        Pixel share per class, sorted in descending order.
    class_indices : np.ndarray
        Class ids in the same order as `class_dist`.
    k : int
        Number of non-background classes to keep.

    Returns
    -------
    dict with
        'keep_classes'  : np.ndarray of the kept class ids (includes 0),
        'class_weights' : np.ndarray of 1 / share, aligned with 'keep_classes'.

    Notes
    -----
    The inputs are not modified.
    """
    assert(k <= class_dist.shape[0])
    assert(len(class_dist) == len(class_indices))
    class_info = {}
    if np.any(class_indices[:k] == 0):
        # Background already sits inside the top k, so take one more class
        # to still end up with k non-background classes.
        k += 1
        zero_pos = np.where(class_indices == 0)[0]
        keep_classes = class_indices[:k].copy()
        # Copy: a plain slice is a view, and the += below would otherwise
        # silently change the caller's class_dist.
        keep_classes_dist = class_dist[:k].copy()
        keep_classes_dist[zero_pos] += np.sum(class_dist[k:])
    else:
        # Background is outside the top k: append it, carrying the share of
        # everything that was dropped (which includes its own share).
        keep_classes = np.append(class_indices[:k], 0)
        keep_classes_dist = np.append(class_dist[:k], np.sum(class_dist[k:]))
    class_weights = 1/keep_classes_dist
    class_info['class_weights'] = class_weights
    class_info['keep_classes'] = keep_classes
    return class_info

# Number of class ids passed to the distribution scan, and how many of the
# most frequent non-background classes are kept for training
NB_TOTAL_CLASSES = 23
TOP_K_CLASSES = 5

class_dist, class_indices = get_class_distribution(mask_path, NB_TOTAL_CLASSES)
keep_classes_info = select_top_k_classes(class_dist, class_indices, TOP_K_CLASSES)

In [5]:
class CustomDataset(Dataset):
    """
    Dataset of (image, mask) PIL pairs with the mask reduced to the kept classes.

    Images are read as '<name>.jpg' from `img_path` and masks as '<name>.png'
    from `mask_path`. Class ids that are not in keep_classes_info['keep_classes']
    are mapped to 0; kept non-zero classes are renumbered 1, 2, ... in the order
    they appear in keep_classes_info['keep_classes'].

    Attributes
    ----------
    nb_classes : int
        Number of classes in the output segmentation map (kept classes incl. 0).
    filenames : list of str
        Extension-less file names, sorted so that indices are stable between runs.
    lut : np.ndarray
        256-entry lookup table mapping an original class id to its new id.
    class_weights : np.ndarray
        Class weights rearranged so that position i holds the weight of new id i.
    """

    def __init__(self, img_path, mask_path, keep_classes_info):
        """
        Parameters
        ----------
            img_path : str
                Directory containing the input images.
            mask_path : str
                Directory containing the mask images.
            keep_classes_info : dict
                Needs 'keep_classes' (np.ndarray of class ids) and 'class_weights'
                (weights aligned with 'keep_classes').
        """
        # Path containing input images
        self.img_path = img_path

        # Path containing mask images
        self.mask_path = mask_path

        # Number of classes in output segmentation map
        self.nb_classes = keep_classes_info['keep_classes'].shape[0]

        # Store extension-less filenames of all the images. os.listdir() returns
        # entries in arbitrary order, so sort to keep a seeded train/val/test split
        # reproducible. splitext (unlike split('.')[0]) keeps names containing dots intact.
        self.filenames = sorted(os.path.splitext(fn)[0] for fn in os.listdir(img_path))
        self.data_len = len(self.filenames)

        # Generate a look up table to remove unwanted class from mask images
        # (every id not assigned below stays 0)
        self.lut = np.zeros(256, dtype=np.uint8)

        # Rearrange class weights according to new class_ids
        self.class_weights = np.zeros(self.nb_classes)
        new_id = 1
        for class_id, class_weight in zip(keep_classes_info['keep_classes'], keep_classes_info['class_weights']):
            if class_id:
                self.lut[class_id] = new_id
                self.class_weights[new_id] = class_weight
                new_id += 1
            else:
                # Class 0 keeps id 0
                self.class_weights[class_id] = class_weight

    def __len__(self):
        '''
            Returns dataset size
        '''
        return self.data_len

    def __getitem__(self, idx):
        '''
            Returns the (image, remapped mask) pair at position idx as PIL images
        '''
        # Use the paths stored on the instance, not same-named notebook globals
        stem = self.filenames[idx]
        img = Image.open(os.path.join(self.img_path, stem + '.jpg'))
        mask = Image.open(os.path.join(self.mask_path, stem + '.png'))

        # Replace unwanted class in mask image from LUT
        np_mask_img = np.array(mask)
        mask = Image.fromarray(self.lut[np_mask_img])

        return img, mask
    
# Map inputs to desired state by passing them through augmentation pipeline
class TransformDataset(Dataset):
    """
    Wraps a dataset of (image, mask) pairs and turns each pair into tensors.

    For every sample the optional `transform` is applied to the pair, the image
    is converted to a tensor and normalized with `mean`/`std`, and the mask is
    one-hot encoded into a float32 tensor with the class axis first.
    """

    def __init__(self, dataset, nb_classes, transform = None, mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]):
        """
        Parameters
        ----------
            dataset : indexable
                dataset[idx] must return an (image, mask) pair.
            nb_classes : int
                Number of classes used for one-hot encoding the mask.
            transform : callable, optional
                Called with the (image, mask) tuple, must return an (image, mask) tuple.
                None (default) leaves the pair untouched.
            mean, std : sequence of float
                Per-channel statistics handed to transforms.Normalize.
        """
        self.dataset = dataset
        self.nb_classes = nb_classes
        self.transform = transform
        self.normalize = transforms.Normalize(mean=mean, std=std)
        self.toTensor = transforms.ToTensor()

    def __getitem__(self, idx):
        """
        Returns (normalized image tensor, one-hot float32 mask tensor) for sample idx.
        """
        # Always fetch the raw pair first so transform=None works
        # (img/mask used to be undefined in that case).
        img, mask = self.dataset[idx]
        if self.transform:
            img, mask = self.transform((img, mask))

        # Convert image to torch tensor
        img = self.toTensor(img)
        mask = torch.from_numpy(np.array(mask))

        # Mean normalization on image and one hot encode the mask image
        img = self.normalize(img)
        # one_hot appends the class axis last; permute moves it to the front
        mask = F.one_hot(mask.to(torch.int64), num_classes = self.nb_classes).permute(2, 0, 1).to(torch.float32)

        return img, mask

    def __len__(self):
        """Returns the size of the wrapped dataset."""
        return len(self.dataset)

# Instantiate the base dataset of (image, mask) pairs restricted to the kept classes
data = CustomDataset(img_path=img_path,
                     mask_path=mask_path,
                     keep_classes_info=keep_classes_info)

In [6]:
# Fraction of the data used for validation; the test set gets the same share
val_split = .1

# Shuffle all sample indices with a fixed seed
dataset_size = len(data)
np.random.seed(4)
indices = np.random.permutation(dataset_size)
split = int(np.floor(val_split * dataset_size))

# First `split` indices -> test, next `split` -> val, everything after -> train
test_indices = indices[:split]
val_indices = indices[split: (2 * split)]
train_indices = indices[(2 * split):]

# Only the training subset goes through the augmentation pipeline
train_dataset = TransformDataset(Subset(data, train_indices), data.nb_classes, train_transform)
val_dataset = TransformDataset(Subset(data, val_indices), data.nb_classes, val_transform)
test_dataset = TransformDataset(Subset(data, test_indices), data.nb_classes, val_transform)

# Batch the three subsets; only the training batches are reshuffled
train_loader = DataLoader(train_dataset, batch_size=8, num_workers=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=40, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=40, num_workers=2)

In [7]:
# Loss criterion for image segmentation problem
# Combination of dice loss and cross entropy loss is excellent choice
class DiceBCELoss(nn.Module):
    """
    Sum of (optionally class-weighted) cross entropy and a soft dice loss.

    Parameters
    ----------
    weight : array-like or None
        Per-class weights for the cross entropy term, one value per class.
        Stored as a float32 buffer; it is moved to the device of the
        predictions on every forward pass, so the loss works on CPU and GPU.
    """

    def __init__(self, weight = None):
        super(DiceBCELoss, self).__init__()
        if weight is not None:
            # as_tensor accepts lists, numpy arrays and tensors alike and, unlike
            # torch.cuda.FloatTensor, does not require a GPU to be present.
            weight = torch.as_tensor(weight, dtype=torch.float32)
        # Buffer (not parameter): follows .to(device) but is never trained
        self.register_buffer('weight', weight)

    def forward(self, inputs, targets, smooth=1):
        """
        Parameters
        ----------
        inputs : torch.Tensor
            Raw class scores (logits) with the class axis at dim 1.
        targets : torch.Tensor
            One-hot ground truth with the same shape as `inputs`.
        smooth : float
            Added to numerator and denominator of the dice ratio.

        Returns
        -------
        Scalar tensor: cross entropy + (1 - soft dice).
        """

        # Apply softmax to get prob. distribution for each pixel
        probs = F.softmax(inputs, 1)

        # Cross entropy expects class indices, so undo the one-hot encoding
        with torch.no_grad():
            gt = torch.argmax(targets, 1)

        weight = self.weight
        if weight is not None:
            weight = weight.to(device=inputs.device, dtype=inputs.dtype)

        BCE = F.cross_entropy(inputs, gt, weight = weight, reduction='mean')

        # reshape (not view) so non-contiguous tensors, e.g. permuted one-hot
        # targets, are accepted as well
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(probs.sum() + targets.sum() + smooth)

        Dice_BCE = BCE + dice_loss

        return Dice_BCE
    
# Dice score used as an accuracy metric
def calculateDiceScore(preds, targets, smooth = 1):
    """
    Soft dice score between softmax probabilities and one-hot targets.

    Parameters
    ----------
    preds : torch.Tensor
        Raw class scores (logits) with the class axis at dim 1.
    targets : torch.Tensor
        One-hot ground truth with the same shape as `preds`.
    smooth : float
        Added to numerator and denominator of the dice ratio.

    Returns
    -------
    Scalar tensor computed over all elements of the batch at once. The
    probabilities are used as they are (no argmax).
    """
    preds = F.softmax(preds, 1)

    # reshape (not view) also accepts non-contiguous tensors
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)

    intersection = (preds * targets).sum()
    return (2.*intersection + smooth)/(preds.sum() + targets.sum() + smooth)

In [8]:
# Pretrained DeepLabV3 (ResNet-101 backbone) from torchvision
model = models.segmentation.deeplabv3_resnet101(pretrained=True, progress=False, num_classes=21, aux_loss=None)

# Swap the final 1x1 convolution of both heads so they predict our classes
# (k + 1 in total, background included)
nb_out_classes = data.nb_classes
model.classifier[4] = nn.Conv2d(256, nb_out_classes, (1, 1), (1, 1))
model.aux_classifier[4] = nn.Conv2d(256, nb_out_classes, (1, 1), (1, 1))

# Fine-tuning: freeze the first 165 parameter tensors (in named_parameters() order)
for idx, (name, param) in enumerate(model.named_parameters()):
    # print(idx, name)
    if idx >= 165:
        break
    param.requires_grad = False

# Use the GPU when one is available, otherwise fall back to the CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Move model to device
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /root/.cache/torch/checkpoints/resnet101-5d3b4d8f.pth


HBox(children=(FloatProgress(value=0.0, max=178728960.0), HTML(value='')))




Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /root/.cache/torch/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth


In [9]:
# Combined dice + cross entropy loss, weighted with the per-class weights of the dataset
criterion = DiceBCELoss(weight=data.class_weights)

# Adam (default settings) updates only the parameters that were left trainable
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params)

# Lower the learning rate once the monitored loss has not improved for more than 2 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

In [10]:
# Path at which to store model and required configs
model_save_path = 'aerial_imagery_seg_model.pt'

# Total epochs to train for
epochs = 50

# Keep track of min val loss. Start at +inf so the first epoch always produces a
# checkpoint, whatever the scale of the loss (a fixed 100 would skip saving
# entirely if the loss never dropped below it).
min_val_loss = float('inf')

# Training loop
for epoch in range(1, epochs + 1):
    train_loss = 0.
    val_loss = 0.

    # Switch model to training mode
    model.train()

    # A pass through train dataset
    for imgs, masks in train_loader:
        # Move batch data to device
        imgs, masks = imgs.to(device), masks.to(device)

        # Clear previous gradients
        optimizer.zero_grad()

        # Forward pass the batch and get predictions
        preds = model(imgs)['out']

        # Calculate Loss
        loss = criterion(preds, masks)

        # Weight the batch-mean loss by batch size so the epoch average is per sample
        train_loss += (loss.item() * imgs.size(0))

        # Backpropagate gradients
        loss.backward()

        # Make weight updates
        optimizer.step()

        # Empty cuda cache to clear useless data from VRAM for better utilization
        # NOTE(review): doing this after every batch may slow training down -
        # confirm it is actually needed on the target GPU.
        torch.cuda.empty_cache()

    # Switch model to inference mode
    model.eval()

    # Check model performance on val set
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)

            preds = model(imgs)['out']

            loss = criterion(preds, masks)
            val_loss += (loss.item() * imgs.size(0))

            torch.cuda.empty_cache()

    # Calculate avg train and avg val loss
    train_loss /= len(train_dataset)
    val_loss /= len(val_dataset)

    # Decrease lr depending on val_loss & given params (object init)
    scheduler.step(val_loss)

    # Keep only the best model: overwrite the checkpoint when val loss improves.
    # The whole model object is saved (not just a state_dict).
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        torch.save(model, model_save_path)

    print('Epoch {}:\tTrain Loss: {}\tVal Loss: {}'.format(epoch, train_loss, val_loss))

Epoch 1:	Train Loss: 1.5789325475692748	Val Loss: 0.9795246720314026
Epoch 2:	Train Loss: 0.9426460683345794	Val Loss: 0.6723008155822754
Epoch 3:	Train Loss: 0.8570941805839538	Val Loss: 0.6240894794464111
Epoch 4:	Train Loss: 0.8086203053593636	Val Loss: 0.8219184279441833
Epoch 5:	Train Loss: 0.7614537358283997	Val Loss: 0.7400764226913452
Epoch 6:	Train Loss: 0.7334292367100715	Val Loss: 0.8786985874176025
Epoch 7:	Train Loss: 0.6875955708324909	Val Loss: 0.5164713859558105
Epoch 8:	Train Loss: 0.6216550342738628	Val Loss: 0.5003544688224792
Epoch 9:	Train Loss: 0.6054667018353939	Val Loss: 0.4883110523223877
Epoch 10:	Train Loss: 0.5991924457252026	Val Loss: 0.47872114181518555
Epoch 11:	Train Loss: 0.5734261631965637	Val Loss: 0.4770793616771698
Epoch 12:	Train Loss: 0.5523261554539204	Val Loss: 0.4732629954814911
Epoch 13:	Train Loss: 0.5539035253226757	Val Loss: 0.46066129207611084
Epoch 14:	Train Loss: 0.5280256122350693	Val Loss: 0.4582401216030121
Epoch 15:	Train Loss: 0.570

KeyboardInterrupt: 

In [13]:
# Load trained model from saved file.
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU can also be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints you created yourself.
model = torch.load(model_save_path, map_location=device)

# Move model to device
model.to(device)

# Switch model to inference mode
model.eval()

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [14]:
# Evaluate on the held-out test set: average dice score per sample, reported as accuracy
with torch.no_grad():
    test_acc = 0
    for imgs, masks in test_loader:
        imgs = imgs.to(device)
        masks = masks.to(device)

        # Only the main segmentation output is scored
        preds = model(imgs)['out']
        dice_score = calculateDiceScore(preds, masks).item()

        # Weight the batch score by its size so the final mean is per sample
        test_acc += (dice_score * imgs.size(0))

        torch.cuda.empty_cache()
    print('Test Acc: {}'.format(test_acc / len(test_dataset)))

Test Acc: 0.8391628265380859
