# 1. Setup

In [1]:
!pip install madgrad

Collecting madgrad
  Downloading madgrad-1.1-py3-none-any.whl (7.4 kB)
Installing collected packages: madgrad
Successfully installed madgrad-1.1


## 1.1 Libraries

In [1]:
import torch, torchvision
import torch.nn as nn
from torch import optim
import torch.nn.parallel
import torch.nn.functional as F
from torch.autograd import Function
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision.transforms.functional as TF
import madgrad

import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage import io

from tqdm import notebook
from IPython.display import clear_output

import PIL.Image
from PIL import Image, ImageDraw, ImageFilter
PIL.Image.MAX_IMAGE_PIXELS = None

## 1.2 CUDA

In [2]:
# CUDA
device = 'cpu'
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('GPU accelerator not available.')
else:
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  print(gpu_info)

Mon May 10 05:11:02 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.73.01    Driver Version: 460.73.01    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-PCIE-40GB      Off  | 00000000:01:00.0 Off |                    0 |
| N/A   57C    P0    46W / 250W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## 1.3 CPU

In [None]:
!cat /proc/cpuinfo

# 2. Processing Data for Training and Validation

## 2.1 RLE and Mask Functions

In [4]:
# https://www.kaggle.com/paulorzp/rle-functions-run-lenght-encode-decode
def mask2rle(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels= img.T.ravel()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def rle2mask(mask_rle, shape=(1600,256)):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (width,height) of array to return 
    Returns numpy array, 1 - mask, 0 - background

    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    print("rle2mask completed.")
    return img.reshape(shape).T

## 2.2 Tiling Functions for Training and Validation
The tiles here overlapping i.e. they have a stride smaller than the sides of the image.

In [1]:
def tile(img, atc_mask, mask, x_max=None, y_max=None, tile_size=None, stride=None, 
         img_save_path='path', mask_save_path='path', file_prefix='filename'):
    counter = 0
    y_list = list(range(0,y_max,stride))
    x_list = list(range(0,x_max,stride))

    for i in range(len(y_list)):
        if y_list[i]+tile_size > y_max:
            y_delta = y_list[i] + tile_size - y_max
            y_list[i] = y_list[i] - y_delta - 1

    for i in range(len(x_list)):
        if x_list[i]+tile_size > x_max:
            x_delta = x_list[i]+tile_size - x_max
            x_list[i] = x_list[i] - x_delta - 1

    for y in y_list:
        y_end = y + tile_size

        for x in x_list:
            x_end = x + tile_size
            atc_mask_crop = atc_mask[y:y_end, x:x_end]
            
            if (atc_mask_crop != 0).any():
                # crop and save image
                img_crop = img[y:y_end, x:x_end, :]
                img_file_name = img_save_path + file_prefix+'_'+str(x)+'_'+str(x_end)+'_'+str(y)+'_'+str(y_end)+'.tiff'
                io.imsave(img_file_name, img_crop, check_contrast=False)
                # crop and save mask
                mask_crop = mask[y:y_end, x:x_end]
                mask_file_name = mask_save_path + file_prefix+'_'+str(x)+'_'+str(x_end)+'_'+str(y)+'_'+str(y_end)
                np.save(mask_file_name, mask_crop)
                counter += 1
    print("Tiling "+file_prefix+" completed. "+str(counter))


def anatomic_mask(json_filepath, x_max, y_max):
    read_file = open(json_filepath, "r") 
    data = json.load(read_file)
    data[0]

    polys = []
    for index in range(data.__len__()):
        geom = data[index]['geometry']['coordinates']
        if len(geom)>1:
            for j in range(len(geom)):
                polys.append(geom[j][0])
        else:
            polys.append(geom[0])

    msk = Image.new('L', (x_max, y_max), 0)  # (w, h)
    for i in range(0,len(polys)):
        poly = polys[i]
        ImageDraw.Draw(msk).polygon(tuple(map(tuple, poly)), outline=1, fill=1) 

    return np.array(msk)


def json2mask(json_path, x_max, y_max):
    # Open the mask file
    read_file = open(json_path, "r") 
    data = json.load(read_file)

    polys = []
    for index in range(data.__len__()):
        geom = np.array(data[index]['geometry']['coordinates'])
        polys.append(geom)

    msk = Image.new('L', (x_max, y_max), 0)  # (w, h)
    for i in range(len(polys)):
        poly = polys[i]
        ImageDraw.Draw(msk).polygon(tuple(map(tuple, poly[0])), outline=0, fill=1) 

    # Tiling needs rescaled image size!
    mask = np.array(msk)
    return mask


def make_tiles(tile_size, stride, file_name='none', rle=None, json=False, read_dir='na'):
    img_path = read_dir + file_name + '.tiff'
    json_path = read_dir + file_name + '.json'
    json_path_anat = read_dir + file_name + '-anatomical-structure.json'
    try: img = io.imread(img_path, plugin='pil').squeeze()
    except: img = io.imread(img_path).squeeze()

    y_max = img.shape[0]
    x_max = img.shape[1]
    atc_mask = anatomic_mask(json_path_anat, x_max, y_max)
    if json:
        mask = json2mask(json_path, x_max, y_max)
    elif rle:
        mask = rle2mask(rle, (x_max, y_max)) # WxH
    else:
        mask = np.zeros([y_max, x_max])

    tile(img, atc_mask, mask, x_max, y_max, tile_size, stride, 
         img_save_path='train_image_tiles/', mask_save_path='train_mask_tiles/', file_prefix=file_name)

## 2.3 Create Tiles for Training and Validation

In [6]:
TILE_SIZE = 1536
STRIDE = 1024

In [7]:
os.makedirs('train_image_tiles/')
os.makedirs('train_mask_tiles/')

train_df = pd.read_csv("new_train.csv")

for i in notebook.tqdm(range(len(train_df))):
    read_dir = 'train/'
    file_name = train_df.iloc[i, 0]
    rle = train_df.iloc[i, 1]
    make_tiles(tile_size = TILE_SIZE, stride = STRIDE, file_name=file_name, rle=rle, read_dir=read_dir)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

rle2mask completed.
Tiling 2f6ecfcdf completed. 373
rle2mask completed.
Tiling 8242609fa completed. 870
rle2mask completed.
Tiling aaa6a05cc completed. 162
rle2mask completed.
Tiling cb2d976f4 completed. 997
rle2mask completed.
Tiling b9a3865fc completed. 850
rle2mask completed.
Tiling b2dc8411c completed. 242
rle2mask completed.
Tiling 0486052bb completed. 421
rle2mask completed.
Tiling e79de561c completed. 324
rle2mask completed.
Tiling 095bf7a1f completed. 987
rle2mask completed.
Tiling 54f2eec69 completed. 335
rle2mask completed.
Tiling 4ef6695ce completed. 1397
rle2mask completed.
Tiling 26dc41664 completed. 913
rle2mask completed.
Tiling c68fe75ea completed. 1323
rle2mask completed.
Tiling afa5e8098 completed. 1117
rle2mask completed.
Tiling 1e2425f28 completed. 663



In [10]:
train_df = pd.read_csv("train_augmented.csv")

for i in notebook.tqdm(range(len(train_df))):
    read_dir = ''
    file_name = train_df.iloc[i, 0]
    rle = train_df.iloc[i, 1]
    make_tiles(tile_size = TILE_SIZE, stride = STRIDE, file_name=file_name, rle=rle, read_dir=read_dir)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8.0), HTML(value='')))

rle2mask completed.
Tiling A_54f2eec69 completed. 335
rle2mask completed.
Tiling A_8242609fa completed. 870
rle2mask completed.
Tiling A_b9a3865fc completed. 896
rle2mask completed.
Tiling A_aaa6a05cc completed. 162
rle2mask completed.
Tiling A_afa5e8098 completed. 1117
rle2mask completed.
Tiling A_cb2d976f4 completed. 997
rle2mask completed.
Tiling A_e79de561c completed. 324
rle2mask completed.
Tiling A_8242609fa2 completed. 870



In [14]:
train_df = pd.read_csv("train_external.csv")

for i in notebook.tqdm(range(len(train_df))):
    read_dir = 'Panorama/'
    file_name = train_df.iloc[i, 0]
    mk = train_df.iloc[i, 1]
    make_tiles(tile_size = TILE_SIZE, stride = STRIDE, file_name=file_name, json=mk, read_dir=read_dir)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7.0), HTML(value='')))

Tiling 095_mw completed. 359
Tiling Colon_hist_slide completed. 144
Tiling Kidney_20x_Objective_5_minutes completed. 317
Tiling Kidney_by_elearning completed. 35
Tiling Kidney_EZ_Microscope completed. 132
Tiling Kidney_FF_TrMass_20x completed. 1280
Tiling Ovarian_neoplasm completed. 24



In [12]:
train_df = pd.read_csv("train_psudo.csv")

for i in notebook.tqdm(range(len(train_df))):
    read_dir = 'test/'
    file_name = train_df.iloc[i, 0]
    rle = train_df.iloc[i, 1]
    make_tiles(tile_size = TILE_SIZE, stride = STRIDE, file_name=file_name, rle=rle, read_dir=read_dir)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))

rle2mask completed.
Tiling 2ec3f1bb9 completed. 882
rle2mask completed.
Tiling 3589adb90 completed. 331
rle2mask completed.
Tiling d488c759a completed. 717
rle2mask completed.
Tiling aa05346ff completed. 984
rle2mask completed.
Tiling 57512b7f1 completed. 595



## 2.4 Delete Training Data Folders!!

In [None]:
# Code for deleting folders
#!rm -rf train_image_tiles/
#!rm -rf train_mask_tiles/

## 2.5 Empty Trash

In [4]:
!rm -rf ~/.local/share/Trash/*

# 3. Dataset for Training and Validation

## 3.1 Dataset Objects

In [3]:
class HuBMAPDataset(Dataset):
    def __init__(self, image_path='train_image_tiles/', mask_path='train_mask_tiles/', tile_size=1536, mean=None, std=None):
        super().__init__()
        self.image_path = image_path
        self.mask_path = mask_path
        self.tile_size = tile_size
        self.image_list = os.listdir(self.image_path)
        self.mask_list = os.listdir(self.mask_path)
        #self.image_list = [x for x in self.image_list if x[:9] in ['2ec3f1bb9','3589adb90','57512b7f1','aa05346ff','d488c759a']]
        self.nrm = transforms.Compose([transforms.ToTensor(), 
                                       transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05), 
                                       transforms.Normalize(mean=mean, std=std)])
        self.flips = transforms.Compose([transforms.RandomVerticalFlip(p=0.5), 
                                       transforms.RandomHorizontalFlip(p=0.5)])
        self.rand_blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))


    def __getitem__(self, idx):
        img_file_name = self.image_list[idx]
        mask_file_name = img_file_name[:-5]+'.npy'
        image = Image.open(self.image_path + img_file_name).convert("RGB")
        
        if mask_file_name in self.mask_list:
            mask = np.load(self.mask_path + mask_file_name)
            mask = torch.tensor(mask, dtype=torch.long)
        else:
            mask = torch.zeros([self.tile_size,self.tile_size], dtype=torch.long)
        

        # Normalize image and ToTensor
        image = self.nrm(image)
        if img_file_name[:9] not in ['2ec3f1bb9','3589adb90','57512b7f1','aa05346ff','d488c759a'] and img_file_name[:2] != 'A_':
            image = self.rand_blur(image)
        
        seed = torch.randint(0,81261917,[1]) # make a seed with random int generator 
        torch.manual_seed(seed) # needed for torchvision
        image = self.flips(image)
        torch.manual_seed(seed) # needed for torchvision
        mask = self.flips(mask)
        
        return image, mask

    def __len__(self):
        return len(self.image_list)

In [4]:
class TestDataset(Dataset):
    def __init__(self, image_path='train_image_tiles/', mask_path='train_mask_tiles/', tile_size=1536, mean=None, std=None):
        super().__init__()
        self.image_path = image_path
        self.mask_path = mask_path
        self.tile_size = tile_size
        self.image_list = os.listdir(self.image_path)
        self.mask_list = os.listdir(self.mask_path)
        self.image_list = [x for x in self.image_list if x[:9] in ['2ec3f1bb9','3589adb90','57512b7f1','aa05346ff','d488c759a']]
        self.nrm = transforms.Compose([transforms.ToTensor(), 
                                       transforms.Normalize(mean=mean, std=std)])

    def __getitem__(self, idx):
        img_file_name = self.image_list[idx]
        mask_file_name = img_file_name[:-5]+'.npy'
        image = Image.open(self.image_path + img_file_name).convert("RGB")
        
        if mask_file_name in self.mask_list:
            mask = np.load(self.mask_path + mask_file_name)
            mask = torch.tensor(mask, dtype=torch.long)
        else:
            mask = torch.zeros([self.tile_size,self.tile_size], dtype=torch.long)
        

        # Normalize image and ToTensor
        image = self.nrm(image)
        
        return image, mask

    def __len__(self):
        return len(self.image_list)

In [4]:
class TestDataset(Dataset):
    def __init__(self, image_path='test_image_tiles/', mask_path='test_mask_tiles/', tile_size=1536, mean=None, std=None):
        super().__init__()
        self.image_path = image_path
        self.mask_path = mask_path
        self.tile_size = tile_size
        self.image_list = os.listdir(self.image_path)
        self.mask_list = os.listdir(self.mask_path)
        self.nrm = transforms.Compose([transforms.ToTensor(), 
                                       transforms.Normalize(mean=mean, std=std)])

    def __getitem__(self, idx):
        img_file_name = self.image_list[idx]
        mask_file_name = img_file_name[:-5]+'.npy'
        image = Image.open(self.image_path + img_file_name).convert("RGB")
        
        if mask_file_name in self.mask_list:
            mask = np.load(self.mask_path + mask_file_name)
            mask = torch.tensor(mask, dtype=torch.long)
        else:
            mask = torch.zeros([self.tile_size,self.tile_size], dtype=torch.long)
        

        # Normalize image and ToTensor
        image = self.nrm(image)
        
        return image, mask

    def __len__(self):
        return len(self.image_list)

# 4. Network

In [5]:
n_classes = 2
network = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False, num_classes=n_classes, aux_loss=True)

# 5. Training

## 5.1 Evaluation Metrics: Dice Coefficient

In [6]:
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        #inputs = torch.sigmoid(inputs)
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()

        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [7]:
def eval(net, loader, device):
    """Evaluation without the densecrf with the dice coefficient"""
    net.eval()
    n_val = len(loader)  # the number of batch
    DL = DiceLoss()
    tot = 0

    #with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
    for idx, data in enumerate(loader):
        imgs, true_masks = data
        imgs = imgs.to(device=device, dtype=torch.float32)
        true_masks = true_masks.to(device=device)

        with torch.no_grad():
            masks_pred = net(imgs)['out']

        #if net.n_classes > 1:
        #    tot += F.cross_entropy(mask_pred, true_masks).item()
        #else:
        #pred = torch.sigmoid(mask_pred)
        #pred = (pred > 0.5).float()
        pred = masks_pred.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        tot += DL(pred, true_masks).item()
        #pbar.update()

    # net.train()
    return tot / n_val

## 5.2 Set Prameters

In [8]:
mean = [0.6359447051178325,0.4671320752664046,0.6818842088634317]
std = [0.1621760194274512,0.2289463254538449,0.1411499072882262]

dataset = HuBMAPDataset('train_image_tiles/', 'train_mask_tiles/', mean=mean, std=std)

batch_size = 4
n_workers = 7
n_channels = 3
n_train = len(dataset)

if (len(dataset) % batch_size == 0):
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, 
                              num_workers=n_workers, pin_memory=True, drop_last=False)
else:
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, 
                              num_workers=n_workers, pin_memory=True, drop_last=True)

epochs = 3

In [9]:
test_dataset = TestDataset('train_image_tiles/', 'train_mask_tiles/', mean=mean, std=std)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, 
                         num_workers=n_workers, pin_memory=True, drop_last=False)

In [None]:
test_dataset = TestDataset('test_image_tiles/', 'test_mask_tiles/', mean=mean, std=std)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, 
                         num_workers=n_workers, pin_memory=True, drop_last=False)

## 5.3 Configuration

In [10]:
class Head(nn.Module):
    def __init__(self, network):
        super().__init__()
        self.network = network
        self.backbone = self.network.backbone
        self.classifier = self.network.classifier
        self.aux_classifier = self.network.aux_classifier
        
    def forward(self, x):
        input_shape = x.shape[-2:]
        with torch.no_grad():
            x = self.backbone(x)
        
        output = self.classifier(x['out'])
        output = F.interpolate(output, input_shape, mode='bilinear', align_corners=False)
        output_aux = self.aux_classifier(x['aux'])
        output_aux = F.interpolate(output_aux, input_shape, mode='bilinear', align_corners=False)
        
        return {'out':output, 'aux':output_aux}

In [10]:
model_checkpoint = 'DeepLab2'
save_path = 'model/'
# os.makedirs(save_path)

n_classes = 2
# network0 = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False, num_classes=n_classes, aux_loss=True)
checkpoint = torch.load(save_path+model_checkpoint)
network.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# network = Head(network0)
# checkpoint = torch.load(save_path+'DeepLabHead0')
# network.load_state_dict(checkpoint['model_state_dict'])

network = network.to(device)
optimizer = madgrad.MADGRAD(network.parameters(), lr=1e-4, momentum=0.9)

# del checkpoint

criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.CrossEntropyLoss()

# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()

In [11]:
model_checkpoint = 'DeepLab'

## 5.4 Start Training

In [None]:
min_loss = 0.07
for epoch in range(epochs):
    network.train()

    epoch_loss = 0
    with notebook.tqdm(total=len(dataset), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
        for idx, data in enumerate(train_loader):
            imgs, true_masks = data
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.squeeze(1).to(device=device)

            optimizer.zero_grad()
            
            # Casts operations to mixed precision
            with torch.cuda.amp.autocast():
                outputs = network(imgs)
                loss = criterion1(outputs['out'], true_masks) + criterion2(outputs['aux'], true_masks)
            epoch_loss += loss.item()
            pbar.set_postfix(**{'loss (batch)': loss.item()})
            
            # Scales the loss, and calls backward()
            # to create scaled gradients
            scaler.scale(loss).backward()

            # Unscales gradients and calls
            # or skips optimizer.step()
            scaler.step(optimizer)

            # Updates the scale for next iteration
            scaler.update()
            
            pbar.update(imgs.shape[0])
            if (idx % 1100 == 0) and (idx != 0):
                state = {
                    'model_state_dict': network.state_dict()
                }
                torch.save(state, save_path+model_checkpoint+str(epoch%10))
                test_score = eval(network, test_loader, device)
                network.train()
                print("Dice loss:", test_score)
                if test_score < min_loss:
                    min_loss = test_score
                    print("Min Dice loss:", min_loss)
                    torch.save(state, save_path+model_checkpoint+'_min_loss_'+str(min_loss))

    print('Epoch Cross Entropy Loss: ', epoch_loss)
    print('Saving at %s' % save_path)
    state = {
        'model_state_dict': network.state_dict()
    }
    torch.save(state, save_path+model_checkpoint+str(epoch%10))
    print('Save complete:'+str(epoch%10))

HBox(children=(HTML(value='Epoch 1/3'), FloatProgress(value=0.0, max=22119.0), HTML(value='')))

Dice loss: 0.0631866929158999
Min Dice loss: 0.0631866929158999
Dice loss: 0.06033728370776755
Min Dice loss: 0.06033728370776755
Dice loss: 0.0700602450811794
Dice loss: 0.06440035910964702
Dice loss: 0.06440930945335785

Epoch Cross Entropy Loss:  88.20144050481031
Saving at model/
Save complete:0


HBox(children=(HTML(value='Epoch 2/3'), FloatProgress(value=0.0, max=22119.0), HTML(value='')))

Dice loss: 0.07221358649303458
Dice loss: 0.05955204887886268
Min Dice loss: 0.05955204887886268
Dice loss: 0.0577204049667182
Min Dice loss: 0.0577204049667182
Dice loss: 0.07072981623555884


In [None]:
from jarviscloud import jarviscloud
jarviscloud.pause()