    Phase 1 directly trained w/ LovaszHingeLoss. 
    Model: dsDSEUNeXt, no dropout
    Fold: 0/4 (stratified)
    lr_scheduler: CosineAnnealing
    Optimizer: SGD
    initial lr: 0.1
    cycle: 0
    batch_size: 16 
    Epochs: 80

# TGS Salt Identification

In [None]:
import os
import sys
import time
from datetime import datetime
import numpy as np
import imageio
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
import shutil
from collections import OrderedDict
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torchvision import transforms
import torchvision.transforms.functional as TF

from sklearn.metrics import jaccard_similarity_score

print(torch.__version__) # Different PyTorch version for CPU-only/GPU
print(sys.version)

torch.set_printoptions(precision = 3)

## Data Visualization

In [None]:
!ls -la ../input

In [None]:
# Root folder of the training split; image and mask folders are joined onto it below.
train_path = "../input/train/"

# Metadata tables. Both are looked up through their 'id' column in later cells;
# depths_df additionally provides the 'z' (depth) column.
train_df = pd.read_csv('../input/train.csv')
depths_df = pd.read_csv('../input/depths.csv')

# Sample ids; an id is also the base name of the matching image/mask PNG file.
file_list = list(train_df['id'].values)

# Notebook display of the column names of both tables.
train_df.columns, depths_df.columns

In [None]:
def plot_imgs_masks(file_list, train_path, depths_df, size = 1):
    '''
        Plot `size` randomly drawn samples: images on the top row, their masks on the 
        bottom row.

        params:
            file_list: sequence of sample ids (file names without the ".png" suffix)
            train_path: folder containing the "images" and "masks" sub-folders
            depths_df: DataFrame with an 'id' column and a 'z' column; 'z' is shown in 
                       the title of each image
            size: number of samples (columns) to draw; ids are drawn with replacement
    '''
    image_folder = os.path.join(train_path, "images")
    mask_folder = os.path.join(train_path, "masks")
    # squeeze = False keeps `axs` two-dimensional even for size == 1. Without it 
    # plt.subplots(2, 1) returns a 1-D array and axs[0][i] fails for the default size.
    figs, axs = plt.subplots(2, size, figsize = (15, 8), squeeze = False)
    indices = np.random.randint(0, len(file_list), size = size)
    
    for i, idx in enumerate(indices):
        filename = file_list[idx]
        image_path = os.path.join(image_folder, filename + ".png")        
        mask_path = os.path.join(mask_folder, filename + ".png")
        image = np.array(imageio.imread(image_path), dtype = np.uint8)
        # any non-zero mask pixel becomes 1.0; a trailing channel axis is appended
        mask = np.array(imageio.imread(mask_path), dtype = bool).astype(np.float32)[..., np.newaxis]
        depth = np.array(depths_df.loc[depths_df['id'] == filename]['z'], dtype = np.float32).reshape(1, 1, 1)
        axs[0][i].imshow(image)
        axs[0][i].set_title('depth: {}'.format(depth))
        axs[0][i].grid()
        axs[1][i].imshow(mask[:, :, 0])
        axs[1][i].grid()

In [None]:
# Preview five randomly drawn image/mask pairs.
plot_imgs_masks(file_list, train_path, depths_df, 5)

In [None]:
# Histogram of the numeric columns of the depths table.
depths_df.hist()

## Define Pytorch Dataset

In [None]:
class TGSSaltDataset(data.Dataset):    
    '''
        Dataset returning one sample dict per file id.

        Every sample has the keys 'image' and 'depth'. When train is True it also has 
        'mask' and 'empty_mask' (0 for an all-zero mask, 1 otherwise). When train is 
        False no mask is read and, if augmentation is on, 'image' holds the original 
        and its horizontal flip (test time augmentation).
    '''
    def __init__(self, root_path, file_list, depths_df, train = True, augmentation = False):
        '''
            params:
                root_path: folder containing the "images" and "masks" sub-folders
                file_list: sequence of sample ids (file names without ".png")
                depths_df: DataFrame with an 'id' column and a 'z' column
                train: if True, masks are loaded and returned with each sample
                augmentation: train mode: random flip / photometric jitter; 
                              test mode: adds the horizontally flipped image
        '''
        
        self.image_folder = os.path.join(root_path, "images")
        self.mask_folder = os.path.join(root_path, "masks")
        self.file_list = file_list
        self.depths_df = depths_df
        self.train = train
        # statistics of 'z' used to centre and scale the depth in _transform
        self.depth_range = self.depths_df['z'].max() - self.depths_df['z'].min()
        self.depth_mean = self.depths_df['z'].mean()
        self.augmentation = augmentation
        
        print('dataset size:', self.__len__())
    
    def __len__(self):
        '''Number of samples, i.e. the length of the file list.'''
        return len(self.file_list)
    
    def _transform(self, sample):
        '''
            Pad, optionally augment, convert to tensors and normalize `sample` in place.

            The image is normalized with its own mean/std; the depth is centred on the 
            dataset mean and divided by the dataset depth range. Returns the same dict.
        '''

        # guards the divisions by std / depth range against zero
        eps = 1e-5
        
        sample['image'] = TF.to_pil_image(sample['image'])
        # padding is (left, top, right, bottom): 27 extra pixels per axis
        # presumably brings 101x101 inputs to 128x128 - TODO confirm input size
        sample['image'] = TF.pad(sample['image'], padding = (13, 13, 14, 14), padding_mode = 'symmetric')
        
        if self.train:
            sample['mask'] = TF.to_pil_image(sample['mask'])
            sample['mask'] = TF.pad(sample['mask'], padding = (13, 13, 14, 14), padding_mode = 'symmetric')
            
            if self.augmentation:
                # horizontal flip, applied to image and mask together
                if random.random() > 0.5:
                    sample['image'] = TF.hflip(sample['image'])
                    sample['mask'] = TF.hflip(sample['mask'])

                # one photometric change (image only), with a factor within +/- 5%
                if random.random() > 0.5:
                    choice = np.random.choice(3)
                    adjust = np.random.uniform(0.95, 1.05)
                    if choice == 0:
                        sample['image'] = TF.adjust_brightness(sample['image'], adjust)
                    elif choice == 1:
                        sample['image'] = TF.adjust_contrast(sample['image'], adjust)
                    else:
                        sample['image'] = TF.adjust_gamma(sample['image'], adjust)
        
            sample['image'] = TF.to_tensor(sample['image'])
            sample['mask'] = TF.to_tensor(sample['mask'])
            
            '''
                some images are all-zero. if not safe_division, then those empty images 
                would be damaged through normalization (due to zero std), hence prevent 
                network from improving.
            '''
            # per-image standardization
            img_mean = sample['image'].mean()
            img_std = sample['image'].std() + eps
            sample['image'] = TF.normalize(sample['image'], [img_mean], [img_std])
            
            sample['empty_mask'] = torch.from_numpy(sample['empty_mask'])
            
        else:
            # Test Time Augmentation
            test_imgs = list()
            test_imgs.append(TF.to_tensor(sample['image']))
            if self.augmentation:
                test_imgs.append(TF.to_tensor(TF.hflip(sample['image'])))
                # one depth entry per image variant
                sample['depth'] = np.repeat(sample['depth'], 2, axis = 0)
                
            # variants are stacked along dim 0
            sample['image'] = torch.cat(test_imgs, dim = 0)

            # each variant is standardized with its own mean/std
            img_mean = sample['image'].view(len(test_imgs), -1).mean(dim = -1)
            img_std = sample['image'].view(len(test_imgs), -1).std(dim = -1) + eps
            sample['image'] = TF.normalize(sample['image'], img_mean, img_std).unsqueeze(1)  
            sample['depth'] = np.expand_dims(sample['depth'], axis = 1)           
        
        sample['depth'] = torch.from_numpy((sample['depth'] - self.depth_mean) / (self.depth_range + eps))
        
        return sample
    
    def __getitem__(self, index):
        '''
            Load the sample at position `index` of the file list and return it 
            after _transform.
        '''
        
        file_idx = self.file_list[index]
        
        image_path = os.path.join(self.image_folder, file_idx + ".png")
        # only the first channel of the PNG is kept, with a trailing channel axis
        image = np.array(imageio.imread(image_path), dtype = np.uint8)[:, :, 0][..., np.newaxis]
        # reshape(1, 1, 1) presumably assumes exactly one depth row per id - verify against the data
        depth = np.array(self.depths_df.loc[self.depths_df['id'] == file_idx]['z'], dtype = np.float32).reshape(1, 1, 1)
        sample = {'image': image, 'depth': depth}

        if self.train:
            mask_path = os.path.join(self.mask_folder, file_idx + ".png")
            # any non-zero mask pixel becomes 1.0
            mask = np.array(imageio.imread(mask_path), dtype = bool).astype(np.float32)[..., np.newaxis]
            sample['mask'] = mask
            sample['empty_mask'] = (np.sum(mask) != 0).astype(np.float32).reshape(1)    # 0 for empty, 1 for non-empty

        sample = self._transform(sample)
        
        return sample

In [None]:
# Dataset over the whole training list, with train = True and augmentation = True.
train_dataset = TGSSaltDataset(train_path, file_list, depths_df, True, True)

In [None]:
# Sanity check on the first sample: print tensor shapes, value ranges and dtypes.
for i in range(1):
    sample = train_dataset[i]
    print('data size: image:', sample['image'].shape, ' mask:', sample['mask'].shape)
    print('image pixel value range:', (torch.min(sample['image']), torch.max(sample['image'])))
    print('mask pixel value range:', (torch.min(sample['mask']), torch.max(sample['mask'])))
    print('depth:', sample['depth'])
    print('data type: image:', sample['image'].dtype, ' depth:', sample['depth'].dtype, ' mask:', sample['mask'].dtype)

## Split Dataset & Define Pytorch DataLoader

In [None]:
from sklearn.model_selection import StratifiedKFold

def get_dataloader(dataset, args, stratification_metrics = None):
    '''
        Generator yielding (train_loader, valid_loader) pairs over `dataset`.

        With args['k_fold'] set, one pair is yielded per stratified fold; otherwise a 
        single pair from a fixed-seed 80/20 random split is yielded. Both loaders draw 
        from the same `dataset` through SubsetRandomSampler.

        params:
            dataset: a map-style dataset (supports len() and indexing)
            args: dict; reads 'use_cuda', 'batch_size' and 'k_fold'
            stratification_metrics: per-sample class labels used to stratify the 
                                    folds; only used when args['k_fold'] is not None
    '''
    
    def random_split_indices(indices, split_ratio):
        # fixed seed: the hold-out split is the same on every call
        indices_perm = np.random.RandomState(seed = 42).permutation(indices)
        train_size = int(len(indices_perm) * (1 - split_ratio))
        # slicing from train_size (instead of [-valid_size:]) needs no outer variable 
        # and stays correct when the validation part is empty
        return indices_perm[: train_size], indices_perm[train_size :]
    
    def k_fold_indices(indices, stratification_metrics, k):
        skf = StratifiedKFold(n_splits = k, shuffle = True, random_state = 42)
        for train_indices, valid_indices in skf.split(indices, stratification_metrics):
            # repeat the last index when the final batch would hold a single sample
            if len(train_indices) % args['batch_size'] == 1:
                train_indices = np.pad(train_indices, (0, 1), 'edge')
            if len(valid_indices) % args['batch_size'] == 1:
                valid_indices = np.pad(valid_indices, (0, 1), 'edge')
            yield train_indices, valid_indices
    
    def make_loaders(train_indices, valid_indices):
        # one loader per index subset, both configured identically
        return tuple(
            data.DataLoader(dataset, batch_size = args['batch_size'], 
                            sampler = data.sampler.SubsetRandomSampler(subset), 
                            pin_memory = pin_memory, num_workers = num_workers)
            for subset in (train_indices, valid_indices)
        )
    
    num_workers = 8 if args['use_cuda'] else 1
    pin_memory = bool(args['use_cuda'])
        
    data_size = len(dataset)
    
    split_ratio = 0.2
    
    if args['k_fold'] is not None:
        for train_indices, valid_indices in k_fold_indices(np.arange(data_size), stratification_metrics, 
                                                           args['k_fold']):
            yield make_loaders(train_indices, valid_indices)
    else:
        train_indices, valid_indices = random_split_indices(np.arange(data_size), split_ratio)
        yield make_loaders(train_indices, valid_indices)


## Compute Coverage Class by mask & Create Stratified K-fold CV indices

In [None]:
from tqdm import tqdm

def cov_to_class(val):
    '''
        Map a coverage fraction to a class in 0..10: the smallest integer i in that 
        range with val * 10 <= i. Returns None when no such i exists (val * 10 > 10).
    '''
    return next((bucket for bucket in range(0, 11) if val * 10 <= bucket), None)

# Fraction of non-zero mask pixels for every training id, in train_df row order.
coverage = list()
for file_idx in tqdm(train_df['id']):
    mask_path = os.path.join(train_dataset.mask_folder, file_idx + ".png")
    mask = np.array(imageio.imread(mask_path), dtype = bool).astype(np.float32)
    coverage.append(np.mean(mask))
# Coverage bucketed into integer classes by cov_to_class.
train_df['coverage_class'] = list(map(cov_to_class, coverage))

In [None]:
train_df.head()

## Define Training Process

In [None]:
def training(model, criterions, optimizer, train_dataloaders, args, lr_scheduler = None, fold = None, 
             start_cycle = 0, model_dir = './model', verbose = True):
    '''
        Train `model` on every (train_loader, valid_loader) pair produced by 
        `train_dataloaders`, validating after each epoch and saving checkpoints.

        params:
            model: the model to be trained; it is called as model((imgs, deps)) and 
                   must return (segs, binary)
            criterions: dict of loss functions with keys 'all_segs', 'binary' and 
                        'non_empty'
            optimizer: gradient descent optimizer
            train_dataloaders: iterable of (train_loader, valid_loader) pairs, one per 
                               fold (see function get_dataloader)
            args: dict of training arguments; reads 'k_fold', 'epochs' and 'use_cuda'
            lr_scheduler: if provided, changes learning rate according to given rules
            fold: if not None, only train the model on the specified fold split 
                  (ignored when args['k_fold'] is None)
            start_cycle: index of the first CosineAnnealingLR cycle; the running cycle 
                         index is part of the checkpoint filename
            model_dir: the directory path for saving/loading model
            verbose: whether to print progress and checkpoint messages

        returns:
            the per-epoch metrics dict of the last fold that was trained
    '''
    if args['k_fold'] is None:
        fold = None
        
    cycle = start_cycle    
        
    init_filename = model.__class__.__name__ + '_init'
    
    # snapshot of the initial state; it is reloaded at the start of every fold so 
    # that all folds begin from the same weights and optimizer state
    save_checkpoint({'model': model.state_dict(),
                     'optimizer': optimizer.state_dict()
                    }, False, model_dir, init_filename, verbose)
    
    print ('start training...')

    # starting value of every tracked "best" validation metric
    initial_best = {
        'loss': 10e6,
        'accu': 0.0,
        'IoU': 0.0,
        'AvgPrec': 0.0
    }
    # one slot per fold when all folds are trained, a single slot otherwise
    if args['k_fold'] is not None and fold is None:
        n_slots = args['k_fold']
    else:
        n_slots = 1
    best_valid_metrics = {name: [value] * n_slots for name, value in initial_best.items()}

    for fold_idx, dataloader in enumerate(train_dataloaders):

        if fold is not None:
            if fold_idx != fold:
                continue

        # slot of this fold inside best_valid_metrics
        if fold is not None:
            idx = 0
        else:
            idx = fold_idx

        train_loader, valid_loader = dataloader

        init_ckpt = load_checkpoint(False, model_dir, init_filename, verbose)
        model.load_state_dict(init_ckpt['model'])
        optimizer.load_state_dict(init_ckpt['optimizer'])

        # for ReduceLROnPlateau only
        if lr_scheduler is not None:
            if lr_scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                lr_scheduler._reset()

        metrics = {
            'train_losses': [],
            'train_accus': [],
            'valid_losses': [],
            'valid_accus': [],
            'valid_ious': [],
            'valid_avg_iou_precs': []
        }

        for epoch in np.arange(args['epochs']):

            if lr_scheduler is not None:
                if lr_scheduler.__class__.__name__ != 'ReduceLROnPlateau':
                    if lr_scheduler.__class__.__name__ == 'CosineAnnealingLR':
                        cur_cycle = start_cycle + epoch // lr_scheduler.T_max
                        if cycle != cur_cycle:
                            cycle = cur_cycle
                            # a new cycle gets its own "best" checkpoint, so only this 
                            # fold's slot is reset. Replacing the whole dict with 
                            # length-1 lists made later folds index out of range.
                            for name, value in initial_best.items():
                                best_valid_metrics[name][idx] = value
                        if epoch % lr_scheduler.T_max == 0 and verbose:
                            print('cycle {}:'.format(cycle))
                        # the schedule restarts at the beginning of each cycle
                        lr_scheduler.step(epoch % lr_scheduler.T_max)
                    else:
                        lr_scheduler.step(epoch)

            t0 = time.time()

            model.train()

            train_loss = 0.0
            train_accu = 0.0

            for i, batch_data in enumerate(train_loader, 0):
                samples = batch_data  
                imgs = samples['image']
                masks = samples['mask']
                deps = samples['depth'] 
                labels = samples['empty_mask']
                if args['use_cuda']:
                    imgs, deps, masks, labels = imgs.cuda(), deps.cuda(), masks.cuda(), labels.cuda()
                optimizer.zero_grad()
                segs, binary = model((imgs, deps))
                # samples whose mask is non-empty (label 1)
                ne_indices = (labels.squeeze(1) == 1)
                ne_preds, ne_masks = segs[ne_indices], masks[ne_indices]
                # total loss: segmentation on all samples, plus 0.1 * empty/non-empty 
                # classification, plus 0.1 * segmentation on the non-empty samples
                loss = criterions['all_segs'](segs, masks)
                loss += 0.1 * criterions['binary'](binary.reshape(-1, 1), labels)
                if len(ne_preds) > 0:
                    loss += 0.1 * criterions['non_empty'](ne_preds, ne_masks)
                train_loss += loss.detach()
                # pixel accuracy; LovaszHingeLoss output is thresholded at 0 on the 
                # raw scores, any other loss at 0.5 after a sigmoid
                if criterions['all_segs'].__class__.__name__ == 'LovaszHingeLoss':
                    train_accu += torch.mean((torch.eq(torch.ge(segs.detach(), 0).float(), masks)).float())
                else:
                    train_accu += torch.mean((torch.eq(segs.detach().sigmoid().round(), masks)).float())

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm = 1.0)
                optimizer.step()

            metrics['train_losses'].append(train_loss / (i + 1))
            metrics['train_accus'].append(train_accu / (i + 1))

            if verbose:
                print('epoch [{0} / {1}]:'.format(epoch + 1, args['epochs']))
                print('train:  \nLoss: {:.4f}  Accu: {:.4f}'.format(train_loss / (i + 1), train_accu / (i + 1)))

            model.eval()

            valid_loss = 0.0
            valid_accu = 0.0
            valid_iou = 0.0
            valid_avg_iou_prec = 0.0

            with torch.no_grad():
                for i, batch_data in enumerate(valid_loader, 0):
                    samples = batch_data
                    imgs = samples['image']
                    masks = samples['mask']
                    deps = samples['depth'] 
                    labels = samples['empty_mask']
                    if args['use_cuda']:
                        imgs, deps, masks, labels = imgs.cuda(), deps.cuda(), masks.cuda(), labels.cuda()
                    segs, binary = model((imgs, deps))
                    ne_indices = (labels.squeeze(1) == 1)
                    ne_preds, ne_masks = segs[ne_indices], masks[ne_indices]
                    loss = criterions['all_segs'](segs, masks)
                    loss += 0.1 * criterions['binary'](binary.reshape(-1, 1), labels)
                    if len(ne_preds) > 0:
                        loss += 0.1 * criterions['non_empty'](ne_preds, ne_masks)
                    valid_loss += loss.detach()
                    if criterions['all_segs'].__class__.__name__ == 'LovaszHingeLoss':
                        valid_accu += torch.mean((torch.eq(torch.ge(segs.detach(), 0).float(), masks)).float())
                    else:
                        valid_accu += torch.mean((torch.eq(segs.detach().sigmoid().round(), masks)).float())
                    iou = avg_iou_precision(segs.detach(), masks.detach(), criterions['all_segs'].__class__.__name__)
                    valid_iou += iou[0]
                    valid_avg_iou_prec += iou[1]

            metrics['valid_losses'].append(valid_loss / (i + 1))
            metrics['valid_accus'].append(valid_accu / (i + 1))
            metrics['valid_ious'].append(valid_iou / (i + 1))
            metrics['valid_avg_iou_precs'].append(valid_avg_iou_prec / (i + 1))

            if verbose:
                print('valid:  \nLoss: {:.4f}  Accu: {:.4f}  \nIoU: {}  \nAvgPrecision: {}'.\
                      format(valid_loss / (i + 1), valid_accu / (i + 1), valid_iou / (i + 1), 
                             valid_avg_iou_prec / (i + 1)))
                print('epoch time: {}s'.format(int(time.time() - t0)))

            # the 'epoch' here is only for verbose print, hence we use 'epoch + 1' 
            # to indicate at the end of which epoch it reduces lr. In StepLR, we 
            # use epoch instead.
            if lr_scheduler is not None:
                if lr_scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                    lr_scheduler.step(valid_iou / (i + 1), epoch = (epoch + 1))

            # always defined, so an epoch that does not improve IoU can no longer hit 
            # an unassigned `is_best` in the save condition below
            is_best = bool(best_valid_metrics['IoU'][idx] < metrics['valid_ious'][-1])
            if is_best:
                best_valid_metrics['loss'][idx] = metrics['valid_losses'][-1]
                best_valid_metrics['accu'][idx] = metrics['valid_accus'][-1]
                best_valid_metrics['IoU'][idx] = metrics['valid_ious'][-1]
                best_valid_metrics['AvgPrec'][idx] = metrics['valid_avg_iou_precs'][-1]

            # save on a new best IoU, on every 10th epoch and on the last epoch
            if is_best or (epoch % 10 == 9) or (epoch == args['epochs'] - 1):
                filename = model.__class__.__name__ + '_fold{}'.format(fold_idx) + '_cycle{}'.format(cycle)
                save_checkpoint({'epoch': epoch + 1,
                                 'model': model.state_dict(),
                                 'metrics': metrics,
                                 'best_valid_metrics': best_valid_metrics,
                                 'optimizer': optimizer.state_dict()
                                }, is_best, model_dir, filename, verbose)

    summarize_CV(best_valid_metrics)

    print('training finished.')
    
    # the initial snapshot is only needed while training
    del_checkpoint(False, model_dir, init_filename, verbose)
    
    return metrics
    
    

# --------------------------------------- helper function ---------------------------------------


def has_nan(t, id_str = 'test'):
    '''
        Debugging guard: print the offending tensor and exit the process as soon as 
        a NaN is found.

        params:
            t: torch.tensor or sequence of torch.tensor
            id_str: identification string to print
    '''
    def _abort_on_nan(tensor, label):
        if torch.isnan(tensor).any():
            print('NaNs tensor \n{} \nfound at {}'.format(tensor, label))
            sys.exit()

    if not torch.is_tensor(t):
        # sequence input: the position of each element is appended to the label
        for position, element in enumerate(t):
            _abort_on_nan(element, id_str + ' [{}/{}]'.format(position, len(t) - 1))
        return

    _abort_on_nan(t, id_str)
        

def summarize_CV(CV_dict):
    '''
        Print the fold count, then the mean and standard deviation of each metric.

        params:
            CV_dict: dict mapping a metric name to the list of its per-fold best 
                     values; the fold count is the length of the first list
    '''
    names = list(CV_dict.keys())
    n_folds = len(CV_dict[names[0]])
    print('Cross Validation on {}-fold: '.format(n_folds))
    for name in names:
        values = CV_dict[name]
        line = 'best {}:  mean: {:.4f} std: {:.4f}'.format(name, np.mean(values), np.std(values))
        print(line)
    
    
def avg_iou_precision(preds, targets, loss = None):
    '''
        Batch IoU and average precision over the IoU thresholds 0.50, 0.55, ..., 0.95.

        params:
            preds: model outputs; binarized at 0.0 when loss == 'LovaszHingeLoss', 
                   at 0.5 otherwise
            targets: ground-truth masks; cast with .byte() before the bitwise ops
            loss: class name of the segmentation loss, selects the threshold above

        returns:
            (mean IoU, mean average precision). Intersection and union are summed 
            over dims 2 and 3, i.e. per sample and channel.
    '''
    smooth = 10e-8
    cut = 0.0 if loss == 'LovaszHingeLoss' else 0.5

    binary_targets = targets.byte()
    binary_preds = torch.ge(preds, cut)

    overlap = (binary_preds & binary_targets).float().sum(dim = (2, 3))
    combined = (binary_preds | binary_targets).float().sum(dim = (2, 3))
    # smoothing makes an empty prediction on an empty mask score an IoU of 1
    iou = (overlap + smooth) / (combined + smooth)

    iou_levels = np.linspace(0.50, 0.95, 10)
    hits = torch.zeros((preds.shape[0], 1), device = preds.device)
    for level in iou_levels:
        hits += torch.ge(iou, level).float()

    return iou.mean(), (hits / len(iou_levels)).mean()

    
def save_checkpoint(state: dict, is_best: bool, dir_path = './model', filename = 'ckpt', verbose = True):
    '''
        Save `state` to <dir_path>/<filename>.pth.tar with torch.save.

        params:
            state: object to serialize (here: dicts of state_dicts and metrics)
            is_best: if True, the saved file is also copied to 
                     <dir_path>/best_<filename>.pth.tar
            dir_path: target directory, created if missing
            filename: file name without the '.pth.tar' suffix
            verbose: whether to print the path that was written
    '''
    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(dir_path, exist_ok = True)
    file_path = os.path.join(dir_path, filename + '.pth.tar')
    torch.save(state, file_path)
    if is_best:
        best_file_path = os.path.join(dir_path, 'best_' + filename + '.pth.tar')
        shutil.copyfile(file_path, best_file_path)
        if verbose:
            print('model saved at {}'.format(best_file_path))
    else:
        if verbose:
            print('model saved at {}'.format(file_path))
        

def load_checkpoint(is_best: bool, dir_path = './model', filename = 'ckpt', verbose = True):
    '''
        Load <dir_path>/<filename>.pth.tar (or the 'best_' prefixed file when 
        is_best is True) with torch.load.

        returns:
            the loaded object, or None (after printing a notice) when the file 
            does not exist
    '''
    prefix = 'best_' if is_best else ''
    file_path = os.path.join(dir_path, prefix + filename + '.pth.tar')

    if not os.path.isfile(file_path):
        print('checkpoint {} not found.'.format(file_path))
        return None

    checkpoint = torch.load(file_path)
    if verbose:
        # the file's modification time is shown next to the path
        modified = datetime.fromtimestamp(os.path.getmtime(file_path))
        mtime = modified.strftime("%Y-%m-%d %H:%M:%S")
        print ("checkpoint loaded '{}' {}".format(file_path, mtime))
    return checkpoint
        
def del_checkpoint(is_best: bool, dir_path = './model', filename = 'ckpt', verbose = True):
    '''
        Delete <dir_path>/<filename>.pth.tar (or the 'best_' prefixed file when 
        is_best is True). Prints a notice instead when the file does not exist.
    '''
    prefix = 'best_' if is_best else ''
    file_path = os.path.join(dir_path, prefix + filename + '.pth.tar')

    if not os.path.isfile(file_path):
        print('checkpoint {} not found.'.format(file_path))
        return

    if verbose:
        # the modification time has to be read before the file is removed
        modified = datetime.fromtimestamp(os.path.getmtime(file_path))
        mtime = modified.strftime("%Y-%m-%d %H:%M:%S")
        print ("checkpoint deleted '{}' {}".format(file_path, mtime))   
    os.remove(file_path)

## Define Building Blocks

In [None]:
'''
    vanilla conv2d with activation and batchnorm
'''
def _conv2d(in_channel, out_channel, kernel_size = 3, stride = 1, padding = 0, dilation = 1, 
            groups = 1, bn_acti = True):        
    block = nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, dilation, groups)
    )
    
    if bn_acti:
        block.add_module('batchnorm', nn.BatchNorm2d(out_channel))
        block.add_module('activation', nn.ELU())
    return block


'''
    transposed conv2d with activation and batchnorm
'''
def _tconv2d(in_channel, out_channel, kernel_size = 3, stride = 1, padding = 0, out_padding = 0, 
             dilation = 1, groups = 1, bn_acti = True):
    block = nn.Sequential(
        nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, out_padding, 
                           groups, dilation = dilation)
    )
    
    if bn_acti:
        block.add_module('batchnorm', nn.BatchNorm2d(out_channel))
        block.add_module('activation', nn.ELU())
    return block   


'''
    vanilla contracting block in UNet
'''
class _contract(nn.Module):
    '''
        2x2 max pooling followed by two _conv2d blocks, with an optional 
        nn.Dropout2d placed in front of the convolutions.
    '''
    def __init__(self, in_channel, out_channel, kernel_size = 3, stride = 1, padding = 0, dropout = None):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        layers = [] if dropout is None else [nn.Dropout2d(p = dropout)]
        layers.append(_conv2d(in_channel, out_channel, kernel_size, stride, padding))
        layers.append(_conv2d(out_channel, out_channel, kernel_size, stride, padding))
        self.conv = nn.Sequential(*layers)
    
    def forward(self, x):
        pooled = self.maxpool(x)
        return self.conv(pooled)
    

'''
    contracting block with: Maxpool -> Inception module as feature map size reduction
                            conv2d -> ResBlock as feature map increasing
'''
class _contract_v2(nn.Module):
    '''
        Contracting block: an _InceptionModule built with stride = 2, followed by 
        `num_res` _ResBlocks.

        This module has two output heads. One is placed right after the Inception 
        module and the other is placed at the end; forward returns both.

        params:
            in_channel: input channel for Inception module
            incept_channels: a dict of out_channel for Inception module; the sum of 
                             its values is the input channel of the first ResBlock
            dim_reduces: a dict of intermediate feature map dimension reduction with 
                         conv1 before conv3 and conv5
            inter_channel: intermediate channel of the ResBlocks (scalar or list)
            out_channel: output channel of every ResBlock
            kernel_size: kernel_size for the ResBlocks (kernel_size for Inception 
                         module is fixed)
            stride: stride for the ResBlocks (stride for Inception module is fixed)
            padding: padding for the ResBlocks (scalar or list; padding for Inception 
                     module is fixed)
            dilation: dilation for the ResBlocks (scalar or list)
            groups: groups for the ResBlocks (scalar or list)
            num_res: number of resblocks after Inception module
            dropout: an intermediate 2d dropout layer placed before ResBlocks
            scSE: whether to add scSE blocks at the end of both heads

        We intend to maintain in_channel = sum(incept_channels.values()) to have the 
        Inception module act like a Maxpooling layer, however, it seems a little bit 
        hard to do this. Hence, the sum(incept_channels.values()) can be slightly 
        larger than in_channel.
    '''
    def __init__(self, in_channel, incept_channels, dim_reduces, inter_channel, out_channel, kernel_size = 3, 
                 stride = 1, padding = 0, dilation = 1, groups = 1, num_res = 2, dropout = None, scSE = False):
        super().__init__()

        def per_block(setting):
            # a scalar is repeated for every ResBlock; a list is used as given
            return setting if isinstance(setting, list) else [setting] * num_res

        padding = per_block(padding)
        dilation = per_block(dilation)
        groups = per_block(groups)
        inter_channel = per_block(inter_channel)

        # (in, intermediate, out) channels of each ResBlock
        # NOTE(review): blocks after the first read inter_channel[0 .. num_res - 2], so 
        # entry 0 is used twice and the last entry never; confirm whether [i + 1] was meant.
        channels = [(sum(incept_channels.values()), inter_channel[0], out_channel)]
        channels.extend((out_channel, inter_channel[i], out_channel) for i in range(num_res - 1))
        
        self.Inception = nn.Sequential(
            _InceptionModule(in_channel, incept_channels, dim_reduces, stride = 2)
        )

        res_layers = [] if dropout is None else [nn.Dropout2d(dropout)]
        for i in range(num_res):
            res_layers.append(_ResBlock(*channels[i], kernel_size, stride, padding[i], dilation[i], groups[i]))
        self.ResBlocks = nn.Sequential(*res_layers)
        
        if scSE:
            self.Inception.add_module('scSE', _scSE(sum(incept_channels.values())))
            self.ResBlocks.add_module('scSE', _scSE(out_channel))
        
    def forward(self, x):
        inception = self.Inception(x)  
        return inception, self.ResBlocks(inception)

    
'''
    vanilla expanding block in UNet
'''
class _expand(nn.Module):
    '''
        Expanding block: a 2x2, stride-2 transposed convolution halves the channel 
        count, its output is concatenated with `skip` along dim 1 and passed through 
        two _conv2d blocks (optionally preceded by nn.Dropout2d).
    '''
    def __init__(self, in_channel, cat_channel, out_channel, kernel_size = 3, stride = 1, padding = 0, dropout = None):
        super().__init__()
        half_channel = int(in_channel / 2)
        self.tconv = _tconv2d(in_channel, half_channel, kernel_size = 2, stride = 2, bn_acti = False)
        layers = [] if dropout is None else [nn.Dropout2d(p = dropout)]
        layers.append(_conv2d(half_channel + cat_channel, out_channel, kernel_size, stride, padding))
        layers.append(_conv2d(out_channel, out_channel, kernel_size, stride, padding))
        self.conv = nn.Sequential(*layers)
        
    def forward(self, x, skip):
        upsampled = self.tconv(x)
        merged = torch.cat([upsampled, skip], dim = 1)
        return self.conv(merged)
    
    
'''
    expanding block with: tconv2d -> tconv2d as feature map upsampling
                          conv2d -> ResBlock as feature map increasing
'''    
class _expand_v2(nn.Module):
    '''
        Expanding block: a 2x2, stride-2 transposed convolution halves the channel 
        count, its output is concatenated with `skip` along dim 1 and passed through 
        `num_res` _ResBlocks.

        params:
            in_channel: input channel for tconv2d layer
            cat_channel: channel for feature maps to be concatenated with the input
            inter_channel: intermediate channel of the ResBlocks (scalar or list)
            out_channel: out_channel for every ResBlock
            kernel_size: kernel_size for the ResBlocks (kernel_size for tconv2d layer 
                         is fixed)
            stride: stride for the ResBlocks (stride for tconv2d layer is fixed)
            padding: padding for the ResBlocks (scalar or list; padding for tconv2d 
                     layer is fixed)
            dilation: dilation for the ResBlocks (scalar or list)
            groups: groups for the ResBlocks (scalar or list)
            num_res: number of resblocks after tconv2d layer
            dropout: an intermediate 2d dropout layer placed before ResBlocks
            scSE: whether to add scSE blocks at the end
    '''
    def __init__(self, in_channel, cat_channel, inter_channel, out_channel, kernel_size = 3, stride = 1, 
                 padding = 0, dilation = 1, groups = 1, num_res = 2, dropout = None, scSE = False):
        super().__init__()

        def per_block(setting):
            # a scalar is repeated for every ResBlock; a list is used as given
            return setting if isinstance(setting, list) else [setting] * num_res

        padding = per_block(padding)
        dilation = per_block(dilation)
        groups = per_block(groups)
        inter_channel = per_block(inter_channel)

        half_channel = int(in_channel / 2)

        # (in, intermediate, out) channels of each ResBlock
        # NOTE(review): blocks after the first read inter_channel[0 .. num_res - 2], so 
        # entry 0 is used twice and the last entry never; confirm whether [i + 1] was meant.
        channels = [(half_channel + cat_channel, inter_channel[0], out_channel)]
        channels.extend((out_channel, inter_channel[i], out_channel) for i in range(num_res - 1))
        
        self.tconv = nn.Sequential(
            _tconv2d(in_channel, half_channel, kernel_size = 2, stride = 2, bn_acti = False)
        )

        res_layers = [] if dropout is None else [nn.Dropout2d(dropout)]
        for i in range(num_res):
            res_layers.append(_ResBlock(*channels[i], kernel_size, stride, padding[i], dilation[i], groups[i]))
        self.ResBlocks = nn.Sequential(*res_layers)
        
        if scSE:
            self.ResBlocks.add_module('scSE', _scSE(out_channel))
        
    def forward(self, x, skip):
        upsampled = self.tconv(x)
        merged = torch.cat([upsampled, skip], dim = 1)
        return self.ResBlocks(merged)
    

'''
    SE blocks
'''
class _cSE(nn.Module):
    def __init__(self, in_channel):
        super(_cSE, self).__init__()
        self.in_channel = in_channel
        self.project = nn.Sequential(
            nn.Linear(in_channel, int(in_channel / 2)),
            nn.ReLU(),
            nn.Linear(int(in_channel / 2), in_channel),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        x_identity = x
        x = x.reshape(*(x.shape[: 2]), -1).mean(dim = -1)
        output = x_identity * self.project(x).reshape(-1, self.in_channel, 1, 1)
        
        return output
    
class _sSE(nn.Module):
    '''
        spatial squeeze-and-excitation: a kernel_size 1 conv collapses the input to a 
        single channel, a sigmoid turns it into a gate map, and the gate map rescales 
        the input
    '''
    def __init__(self, in_channel):
        super(_sSE, self).__init__()
        self.in_channel = in_channel
        gate_layers = [
            _conv2d(in_channel, 1, kernel_size = 1, bn_acti = False),
            nn.Sigmoid()
        ]
        self.project = nn.Sequential(*gate_layers)
    
    def forward(self, x):
        gate_map = self.project(x)
        
        return x * gate_map
    
class _scSE(nn.Module):
    '''
        concurrent channel + spatial squeeze-and-excitation: the input is gated 
        independently by a _cSE and a _sSE block and the two results are summed
    '''
    def __init__(self, in_channel):
        super(_scSE, self).__init__()
        self.in_channel = in_channel
        self.cSE = _cSE(in_channel)
        self.sSE = _sSE(in_channel)
    
    def forward(self, x):
        channel_gated = self.cSE(x)
        spatial_gated = self.sSE(x)
        
        return channel_gated + spatial_gated
    
    
'''
    zero padding input for skip connection in residual block
'''
def _zero_padding(x, target_dim):
    padding_dim = target_dim - x.shape[1]
    if padding_dim > 0:
        padding_shape = (x.shape[0], padding_dim, x.shape[2], x.shape[3])
        zero_padding = torch.zeros(padding_shape, dtype = x.dtype, device = x.device)
        x_padded = torch.cat([x, zero_padding], dim = 1)
        return x_padded
    return x


'''
    parallel conv2d with different kernel size concatenated together
'''
class _MultiConv2d(nn.Module):
    '''
        parallel conv2d branches applied to the same input and concatenated along 
        the channel axis, optionally followed by batchnorm + ELU
        
        Args:
            in_channel: channel of the input feature maps
            out_channels: output channel of every branch; a single int is shared 
                          by all branches
            kernel_sizes: kernel size of every branch
            paddings: padding of every branch (same length as kernel_sizes)
            strides: stride of every branch, None means stride 1 everywhere
            bn_acti: whether to append batchnorm + ELU after the concatenation
    '''
    def __init__(self, in_channel, out_channels, kernel_sizes, paddings, strides = None, bn_acti = True):
        super(_MultiConv2d, self).__init__()
        assert len(kernel_sizes) == len(paddings), 'inconsistent number of args specified'
        
        num_branches = len(kernel_sizes)
        # normalise the shorthand forms first, so the branches below are built from 
        # the expanded lists and not from the raw (None / int) arguments
        if strides is None:
            strides = [1] * num_branches
        if isinstance(out_channels, int):
            out_channels = [out_channels] * num_branches
        
        self.in_channel = in_channel
        self.out_channels = out_channels
        self.kernel_sizes = kernel_sizes
        self.paddings = paddings
        self.strides = strides
        self.bn_acti = bn_acti
        
        self.conv2d_layers = nn.ModuleList([
            _conv2d(in_channel, out_channels[i], kernel_sizes[i], strides[i], paddings[i], bn_acti = False) 
            for i in range(num_branches)
        ])
        
        self.bn_acti_layers = nn.Sequential()
        if self.bn_acti:
            self.bn_acti_layers.add_module('batchnorm', nn.BatchNorm2d(sum(self.out_channels)))
            self.bn_acti_layers.add_module('activation', nn.ELU())
    
    def forward(self, x):
        branches = [layer(x) for layer in self.conv2d_layers]
        
        return self.bn_acti_layers(torch.cat(branches, dim = 1))

    
'''
    residual (ResNet/ResNeXt) block in ResNeXt
'''
class _ResBlock(nn.Module):
    '''
        bottleneck residual block: kernel_size 1 conv -> (grouped / dilated) 
        kernel_size conv -> kernel_size 1 conv, added to a shortcut of the input, 
        then optional batchnorm + ELU and optional 2d dropout
        
        Args:
            in_channel: channel of the input feature maps
            inter_channel: channel inside the bottleneck
            out_channel: channel of the output feature maps
            kernel_size, stride, padding, dilation, groups: forwarded to the middle conv
            dropout: rate of the trailing Dropout2d, None disables it
            bn_acti: whether to apply batchnorm + ELU after the residual addition
    '''
    def __init__(self, in_channel, inter_channel, out_channel, kernel_size, stride = 1, padding = 0, 
                 dilation = 1, groups = 1, dropout = None, bn_acti = True):
        super(_ResBlock, self).__init__()
        self.in_channel = in_channel
        self.inter_channel = inter_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        self.dropout_rate = dropout
        self.bn_acti = bn_acti
        
        bottleneck = [
            _conv2d(in_channel, inter_channel, kernel_size = 1),
            _conv2d(inter_channel, inter_channel, kernel_size = kernel_size, stride = stride, 
                    padding = padding, dilation = dilation, groups = groups),
            _conv2d(inter_channel, out_channel, kernel_size = 1, bn_acti = False)
        ]
        self.paths = nn.Sequential(*bottleneck)
        
        # shortcut: a kernel_size 1 projection when the channel count changes, 
        # otherwise an empty Sequential
        if in_channel != out_channel:
            self.identity = _conv2d(in_channel, out_channel, kernel_size = 1, stride = stride, bn_acti = False)
        else:
            self.identity = nn.Sequential()
        
        post_layers = nn.Sequential()
        if bn_acti:
            post_layers.add_module('batchnorm', nn.BatchNorm2d(out_channel))
            post_layers.add_module('activation', nn.ELU())
        self.bn_acti_layers = post_layers
        
        drop_layers = nn.Sequential()
        if dropout is not None:
            drop_layers.add_module('dropout', nn.Dropout2d(dropout))
        self.dropout = drop_layers
        
    def forward(self, x):
        # _zero_padding only adds channels when the shortcut has fewer than out_channel
        shortcut = _zero_padding(self.identity(x), self.out_channel)
        residual = self.paths(x)
        
        return self.dropout(self.bn_acti_layers(residual + shortcut))

'''
    Inception block in Inception V3
'''
class _InceptionModule(nn.Module):
    '''
        four parallel branches (kernel_size 1 conv, kernel_size 3 conv, a 5x5 conv 
        factorized into (1, 5) and (5, 1) convs, max-pool + kernel_size 1 conv) 
        concatenated along the channel axis, optionally followed by batchnorm + ELU
        
        Args:
            in_channel: channel of the input feature maps
            out_channels: dict(conv1, conv3, conv5, maxpool), output channel of each branch
            dim_reduces: dict(conv3, conv5), channel of the kernel_size 1 reduction 
                         placed in front of the conv3 / conv5 branch
            stride: stride applied by every branch, must be 1 or 2
            bn_acti: whether to append batchnorm + ELU after the concatenation
        Raises:
            ValueError: if stride is neither 1 nor 2 (no max-pool branch is defined 
                        for other strides)
    '''
    def __init__(self, in_channel, out_channels, dim_reduces, stride = 1, bn_acti = True):
        super(_InceptionModule, self).__init__()
        
        assert sum(out_channels.values()) >= in_channel, 'out_channels must be greater than or equal to in_channel'
        # fail at construction time; previously an unsupported stride left 
        # self.maxpool undefined and only crashed later inside forward()
        if stride not in (1, 2):
            raise ValueError('stride must be 1 or 2, got {}'.format(stride))
        
        self.in_channel = in_channel
        self.out_channels = out_channels
        self.dim_reduces = dim_reduces
        self.stride = stride
        self.bn_acti = bn_acti
        
        self.conv1 = _conv2d(in_channel, out_channels['conv1'], kernel_size = 1, stride = stride, bn_acti = False)
        self.conv3 = nn.Sequential(
            _conv2d(in_channel, dim_reduces['conv3'], kernel_size = 1, bn_acti = True),
            _conv2d(dim_reduces['conv3'], out_channels['conv3'], kernel_size = 3, stride = stride, padding = 1, bn_acti = False)
        )
        self.conv5 = nn.Sequential(
            _conv2d(in_channel, dim_reduces['conv5'], kernel_size = 1, bn_acti = True),
            _conv2d(dim_reduces['conv5'], out_channels['conv5'], kernel_size = (1, 5), stride = (1, stride), padding = (0, 2), bn_acti = False),
            _conv2d(out_channels['conv5'], out_channels['conv5'], kernel_size = (5, 1), stride = (stride, 1), padding = (2, 0), bn_acti = False)
        )
        
        if stride == 1:
            pooling = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)
        else:
            pooling = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.maxpool = nn.Sequential(
            pooling,
            _conv2d(in_channel, out_channels['maxpool'], kernel_size = 1, bn_acti = False)
        )
        
        self.bn_acti_layers = nn.Sequential()
        if self.bn_acti:
            self.bn_acti_layers.add_module('batchnorm', nn.BatchNorm2d(sum(out_channels.values())))
            self.bn_acti_layers.add_module('activation', nn.ELU())

    def forward(self, x):
        branches = [self.conv1(x), self.conv3(x), self.conv5(x), self.maxpool(x)]
        output = torch.cat(branches, dim = 1)
        
        return self.bn_acti_layers(output)
    
    
'''
    residual inception block in Inception V4
'''
class _ResInceptionModule(nn.Module):
    '''
        inception block with a residual connection: the four branches (kernel_size 1 
        conv, kernel_size 3 conv, (1, 5) + (5, 1) convs, max-pool + kernel_size 1 conv) 
        are concatenated, remapped by a kernel_size 1 conv and added to the 
        zero-padded input, optionally followed by batchnorm + ELU
        
        Args:
            in_channel: channel of the input feature maps
            out_channels: dict(conv1, conv3, conv5, maxpool), output channel of each branch
            dim_reduces: dict(conv3, conv5), channel of the kernel_size 1 reduction 
                         placed in front of the conv3 / conv5 branch
            bn_acti: whether to append batchnorm + ELU after the residual addition
    '''
    def __init__(self, in_channel, out_channels, dim_reduces, bn_acti = True):
        super(_ResInceptionModule, self).__init__()
        
        total_channel = sum(out_channels.values())
        assert total_channel >= in_channel, 'out_channels must be greater than or equal to in_channel'
        
        self.in_channel = in_channel
        self.out_channels = out_channels
        self.dim_reduces = dim_reduces
        self.bn_acti = bn_acti
        
        reduce3, reduce5 = dim_reduces['conv3'], dim_reduces['conv5']
        
        self.conv1 = _conv2d(in_channel, out_channels['conv1'], kernel_size = 1, bn_acti = False)
        self.conv3 = nn.Sequential(
            _conv2d(in_channel, reduce3, kernel_size = 1, bn_acti = True),
            _conv2d(reduce3, out_channels['conv3'], kernel_size = 3, padding = 1, bn_acti = False)
        )
        self.conv5 = nn.Sequential(
            _conv2d(in_channel, reduce5, kernel_size = 1, bn_acti = True),
            _conv2d(reduce5, out_channels['conv5'], kernel_size = (1, 5), padding = (0, 2), bn_acti = False),
            _conv2d(out_channels['conv5'], out_channels['conv5'], kernel_size = (5, 1), padding = (2, 0), bn_acti = False)
        )
        self.maxpool = nn.Sequential(
            nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1),
            _conv2d(in_channel, out_channels['maxpool'], kernel_size = 1, bn_acti = False)
        )
        
        self.remap = _conv2d(total_channel, total_channel, kernel_size = 1, bn_acti = False)
        
        post_layers = nn.Sequential()
        if bn_acti:
            post_layers.add_module('batchnorm', nn.BatchNorm2d(total_channel))
            post_layers.add_module('activation', nn.ELU())
        self.bn_acti_layers = post_layers

    def forward(self, x):
        branches = torch.cat([self.conv1(x), self.conv3(x), self.conv5(x), self.maxpool(x)], dim = 1)
        
        # the shortcut is widened with zero channels to match the concatenated branches
        shortcut = _zero_padding(x, sum(self.out_channels.values()))
        
        return self.bn_acti_layers(shortcut + self.remap(branches))
        

## Define Models

In [None]:
class DenseUNet(nn.Module):
    '''
        encoder / decoder segmentation network: an image head, four _contract_v2 
        encoder stages, four _expand_v2 decoder stages fed with encoder skip 
        features, and a two-conv tail producing a single output channel
        
        NOTE(review): the `out_channels` keyword passed to _contract_v2 and the three 
        positional channel arguments passed to _expand_v2 differ from the calls made 
        in dsDSEUNeXt (`incept_channels` / `inter_channel` / `out_channel`, four 
        positional channel arguments). This class may target an older signature of 
        those helpers -- confirm it still constructs.
    '''
    def __init__(self):
        super(DenseUNet, self).__init__()
        # 1-channel image -> 16 feature maps; the depth map is concatenated in forward (16 + 1 = 17)
        self.img_head = _conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1)    # 128
        # trailing comments: spatial size change, then the channels of the two outputs of each stage
        self.Encoder = nn.ModuleDict({
            'c1': _contract_v2(in_channel = 17,
                               out_channels = {'conv1': 4, 'conv3': 4, 'conv5': 4, 'maxpool': 8},
                               dim_reduces = {'conv3': 4, 'conv5': 4},
                               kernel_size = 3, 
                               stride = 1, 
                               padding = 1,
                               num_res = 2,
                               dropout = 0.25
                              ),    # 128 -> 64, 20, 40
            'c2': _contract_v2(in_channel = 60,    # 20 + 40
                               out_channels = {'conv1': 12, 'conv3': 12, 'conv5': 12, 'maxpool': 24},
                               dim_reduces = {'conv3': 12, 'conv5': 12},
                               kernel_size = 3, 
                               stride = 1, 
                               padding = 1,
                               num_res = 2,
                               dropout = 0.5
                              ),    # 64 -> 32, 60, 120
            'c3': _contract_v2(in_channel = 180,    # 60 + 120
                               out_channels = {'conv1': 36, 'conv3': 36, 'conv5': 36, 'maxpool': 72},
                               dim_reduces = {'conv3': 36, 'conv5': 36},
                               kernel_size = 3, 
                               stride = 1, 
                               padding = 1,
                               num_res = 2,
                               dropout = 0.5
                              ),    # 32 -> 16, 180, 360
            'c4': _contract_v2(in_channel = 540,    # 180 + 360
                               out_channels = {'conv1': 108, 'conv3': 108, 'conv5': 108, 'maxpool': 216},
                               dim_reduces = {'conv3': 108, 'conv5': 108},
                               kernel_size = 3, 
                               stride = 1, 
                               padding = 1,
                               num_res = 3,
                               dropout = 0.5
                              ),    # 16 -> 8, 540, 1080
        })
        self.Decoder = nn.ModuleDict({
            'e1': _expand_v2(1080, 360, 360, kernel_size = 3, stride = 1, padding = 1, num_res = 3, dropout = 0.5),    # 8 -> 16, 360
            'e2': _expand_v2(360, 120, 120, kernel_size = 3, stride = 1, padding = 1, num_res = 2, dropout = 0.5),    # 16 -> 32, 120
            'e3': _expand_v2(120, 40, 40, kernel_size = 3, stride = 1, padding = 1, num_res = 2, dropout = 0.5),    # 32 -> 64, 40
            'e4': _expand_v2(40, 17, 17, kernel_size = 3, stride = 1, padding = 1, num_res = 2, dropout = 0.5),    # 64 -> 128, 17
        })
        # last conv has no batchnorm / activation, so the network outputs raw values
        self.tail = nn.Sequential(
            _conv2d(17, 16, kernel_size = 3, stride = 1, padding = 1),
            _conv2d(16, 1, kernel_size = 1, bn_acti = False)
        )
            
    def forward(self, x):
        # presumably x[0] is the image batch and x[1] the depth -- TODO confirm against the dataset
        dep = x[1].expand_as(x[0])
        head = torch.cat([dep, self.img_head(x[0])], dim = 1)
        
        # every encoder stage returns an indexable pair; both outputs are concatenated 
        # as the input of the next stage
        c1 = self.Encoder['c1'](head)
        c2 = self.Encoder['c2'](torch.cat(c1, dim = 1))
        c3 = self.Encoder['c3'](torch.cat(c2, dim = 1))
        c4 = self.Encoder['c4'](torch.cat(c3, dim = 1))
        
        # the second output of each encoder stage is the skip input of the decoder
        e1 = self.Decoder['e1'](c4[1], c3[1])
        e2 = self.Decoder['e2'](e1, c2[1])
        e3 = self.Decoder['e3'](e2, c1[1])
        e4 = self.Decoder['e4'](e3, head)
        
        tail = self.tail(e4)
        
        return tail

In [None]:
class dsDSEUNeXt(nn.Module):
    '''
        encoder / decoder segmentation network with scSE blocks and an extra 
        classification branch: an image head, four _contract_v2 encoder stages, 
        four _expand_v2 decoder stages fed with encoder skip features, a tail 
        producing a single-channel map, and a binary_classifier applied to the 
        deepest encoder output
        
        forward returns (tail, binary_pred)
    '''
    def __init__(self):
        super(dsDSEUNeXt, self).__init__()
        # 1-channel image -> 16 feature maps; the depth map is concatenated in forward (16 + 1 = 17)
        self.img_head = _conv2d(1, 16, kernel_size = 3, stride = 1, padding = 1)    # 128
        # trailing comments: spatial size change, then the channels of the two outputs of each stage
        self.Encoder = nn.ModuleDict({
            'c1': _contract_v2(in_channel = 17,
                               incept_channels = {'conv1': 4, 'conv3': 4, 'conv5': 4, 'maxpool': 8},
                               dim_reduces = {'conv3': 4, 'conv5': 4},
                               inter_channel = 20, 
                               out_channel = 40,
                               kernel_size = 3, 
                               stride = 1, 
                               padding = 1,
                               groups = 20,
                               num_res = 2,
                               dropout = None,
                               scSE = True),    # 128 -> 64, 20, 40
            'c2': _contract_v2(in_channel = 60,    # 20 + 40
                               incept_channels = {'conv1': 12, 'conv3': 12, 'conv5': 12, 'maxpool': 24},
                               dim_reduces = {'conv3': 12, 'conv5': 12},
                               inter_channel = 60,
                               out_channel = 120,
                               kernel_size = 3, 
                               stride = 1, 
                               padding = 1,
                               groups = 60,
                               num_res = 3,
                               dropout = None,
                               scSE = True),    # 64 -> 32, 60, 120
            'c3': _contract_v2(in_channel = 180,    # 60 + 120
                               incept_channels = {'conv1': 36, 'conv3': 36, 'conv5': 36, 'maxpool': 72},
                               dim_reduces = {'conv3': 36, 'conv5': 36},
                               inter_channel = 180,
                               out_channel = 360,
                               kernel_size = 3, 
                               stride = 1, 
                               padding = 1,
                               groups = 180,
                               num_res = 3,
                               dropout = None,
                               scSE = True),    # 32 -> 16, 180, 360
            'c4': _contract_v2(in_channel = 540,    # 180 + 360
                               incept_channels = {'conv1': 108, 'conv3': 108, 'conv5': 108, 'maxpool': 216},
                               dim_reduces = {'conv3': 108, 'conv5': 108},
                               inter_channel = 540,
                               out_channel = 1080,
                               kernel_size = 3, 
                               stride = 1, 
                               padding = 1,
                               groups = 540,
                               num_res = 4,
                               dropout = None,
                               scSE = True)    # 16 -> 8, 540, 1080
        })
        # positional arguments: in_channel, cat_channel, then presumably inter_channel 
        # and out_channel -- TODO confirm against the _expand_v2 signature
        self.Decoder = nn.ModuleDict({
            'e1': _expand_v2(1080, 360, 180, 360, kernel_size = 3, stride = 1, padding = 1, groups = 180, 
                             num_res = 4, dropout = None, scSE = True),    # 8 -> 16, 360
            'e2': _expand_v2(360, 120, 60, 120, kernel_size = 3, stride = 1, padding = 1, groups = 60, 
                             num_res = 3, dropout = None, scSE = True),    # 16 -> 32, 120
            'e3': _expand_v2(120, 40, 20, 40, kernel_size = 3, stride = 1, padding = 1, groups = 20, 
                             num_res = 3, dropout = None, scSE = True),    # 32 -> 64, 40
            'e4': _expand_v2(40, 17, 8, 17, kernel_size = 3, stride = 1, padding = 1, groups = 8, 
                             num_res = 2, dropout = None, scSE = True)   # 64 -> 128, 17
        })
        # last conv has no batchnorm / activation, so the segmentation output is a raw map
        self.tail = nn.Sequential(
            _conv2d(17, 16, kernel_size = 3, stride = 1, padding = 1),
            _scSE(16),
            _conv2d(16, 1, kernel_size = 1, bn_acti = False)
        )
        # classification branch on the deepest encoder features (1080 channels).
        # NOTE(review): unlike the tail, the last _conv2d here does not pass 
        # bn_acti = False; if _conv2d defaults to batchnorm + activation, the binary 
        # output goes through them before reaching the loss -- confirm this is intended.
        self.binary_classifier = nn.Sequential(
            _ResBlock(1080, 270, 540, kernel_size = 3, stride = 2, padding = 1),
            _ResBlock(540, 135, 270, kernel_size = 3, stride = 2, padding = 1),
            _scSE(270),
            nn.AvgPool2d(kernel_size = 2),
            _conv2d(270, 1, kernel_size = 1)
        )
            
    def forward(self, x):
        # presumably x[0] is the image batch and x[1] the depth -- TODO confirm against the dataset
        dep = x[1].expand_as(x[0])
        head = torch.cat([dep, self.img_head(x[0])], dim = 1)
        
        # every encoder stage returns an indexable pair; both outputs are concatenated 
        # as the input of the next stage
        c1 = self.Encoder['c1'](head)
        c2 = self.Encoder['c2'](torch.cat(c1, dim = 1))
        c3 = self.Encoder['c3'](torch.cat(c2, dim = 1))
        c4 = self.Encoder['c4'](torch.cat(c3, dim = 1))
        
        # the second output of each encoder stage is the skip input of the decoder
        e1 = self.Decoder['e1'](c4[1], c3[1])
        e2 = self.Decoder['e2'](e1, c2[1])
        e3 = self.Decoder['e3'](e2, c1[1])
        e4 = self.Decoder['e4'](e3, head)
        
        tail = self.tail(e4)
        
        binary_pred = self.binary_classifier(c4[1])
        
        return tail, binary_pred

## Customize Loss Functions

In [None]:
# https://discuss.pytorch.org/t/solved-class-weight-for-bceloss/3114
def weighted_BCELoss(output, target, weights=None):
    '''
        mean binary cross entropy on probabilities, with optional class weights
        
        Args:
            output: predicted probabilities (floating point tensor)
            target: ground truth tensor of the same shape, 1 for the positive class
            weights: None, or a length-2 sequence where weights[0] scales the 
                     negative term and weights[1] the positive term
        Returns:
            scalar tensor, the negated mean of the (weighted) log-likelihood
        Raises:
            ValueError: if weights is given but does not have exactly 2 entries
    '''
    if weights is not None and len(weights) != 2:
        raise ValueError('weights must have exactly 2 entries, got {}'.format(len(weights)))
    
    # keep probabilities strictly inside (0, 1): log(0) is -inf and 0 * -inf is nan, 
    # which poisoned both the loss and its gradient for saturated predictions
    eps = torch.finfo(output.dtype).eps
    prob = output.clamp(min = eps, max = 1 - eps)
    
    positive_term = target * torch.log(prob)
    negative_term = (1 - target) * torch.log(1 - prob)
    if weights is not None:
        positive_term = weights[1] * positive_term
        negative_term = weights[0] * negative_term

    return torch.neg(torch.mean(positive_term + negative_term))

In [None]:
# https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
# a safe Binary Focal Loss function
class BFWithLogitsLoss(nn.Module):
    '''
        binary focal loss computed from logits
        
        the per-element loss is (1 - p_t) ** gamma * alpha[target] * BCE, where BCE is 
        the numerically stable binary cross entropy with logits and p_t is the 
        predicted probability of the true class
        
        Args:
            gamma: focusing exponent
            alpha: None (both classes weighted 1) or a length-2 sequence where 
                   alpha[i] is the weight of class i
            size_average: return the mean over all elements if True, the sum otherwise
    '''
    def __init__(self, gamma = 2, alpha = None, size_average = True):
        super(BFWithLogitsLoss, self).__init__()
        self.gamma = gamma
        # alpha[i] is the weight of class i
        self.alpha = torch.ones(2) if alpha is None else torch.tensor(alpha)
        self.size_average = size_average

    def forward(self, logits, target):
        '''
            logits and target must have identical shapes, target holds 0 / 1 labels
            
            Raises:
                ValueError: if the shapes differ
        '''
        if not (target.shape == logits.shape):
            raise ValueError("Target size ({}) must be the same as input size ({})".\
                             format(target.size(), logits.size()))
        
        # stable BCE-with-logits: the largest exponent is factored out before exp()
        max_val = (-logits).clamp(min = 0)
        BCE = logits - logits * target + max_val + ((-max_val).exp() + (-max_val - logits).exp()).log()

        # follow the logits to whatever device they live on (any GPU, or back to the CPU) 
        # instead of unconditionally jumping to the default cuda device
        if self.alpha.device != logits.device:
            self.alpha = self.alpha.to(logits.device)
        
        alpha_indexed = torch.index_select(self.alpha, dim = 0, index = target.long().view(-1)).reshape_as(target)
        # sigmoid(-logit * sign) is the probability assigned to the wrong class, i.e. 1 - p_t
        pt = torch.sigmoid(-logits * (target * 2 - 1))
        
        loss = pt ** self.gamma * alpha_indexed * BCE
        
        if self.size_average: 
            return loss.mean()
        return loss.sum()


In [None]:
# https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    
    scores and labels are flattened to 1-D; when 'ignore' is not None, every 
    position whose label equals 'ignore' is dropped from both outputs.
    Returns the tuple (flat_scores, flat_labels).
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. transposed ones
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    return scores[valid], labels[valid]

def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    
    gt_sorted: 1-D tensor of binary ground truth labels, ordered by decreasing error.
    Returns a tensor of the same length.
    """
    # bool / integer labels cannot go through `1 - gt_sorted` and the division below 
    # reliably, so promote them; floating point inputs keep their dtype
    if not gt_sorted.is_floating_point():
        gt_sorted = gt_sorted.float()
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.cumsum(0)
    union = gts + (1 - gt_sorted).cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1: # cover 1-pixel case
        # turn the cumulative Jaccard losses into per-position increments
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard

def lovasz_hinge_flat(logits, labels):
    r"""
        Binary Lovasz hinge loss on flattened predictions
          logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
          labels: [P] Tensor, binary ground truth labels (0 or 1)
          
        F.relu of the reference implementation is replaced by F.elu(...) + 1, in the 
        hope of pushing the predictions to a large margin even when the original loss 
        arrives at 0.
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels - 1.
    errors = 1. - logits * signs
    # the sort permutation is an index tensor, it carries no gradient
    errors_sorted, perm = torch.sort(errors, dim = 0, descending = True)
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    return torch.dot(F.elu(errors_sorted) + 1, grad)

def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(np.isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n

class LovaszHingeLoss(nn.Module):
    """
    Module wrapper around the binary Lovasz hinge loss
      per_image: compute the loss per image and average, instead of over the whole batch
      ignore: void class id, pixels carrying this label are excluded
    """
    def __init__(self, per_image = True, ignore = None):
        super(LovaszHingeLoss, self).__init__()
        self.per_image = per_image
        self.ignore = ignore

    def forward(self, logits, targets):
        r"""
        Binary Lovasz hinge loss
          logits: [B, 1, H, W] Variable, logits at each pixel (between -\infty and +\infty)
          labels: [B, 1, H, W] Tensor, binary ground truth masks (0 or 1)
          per_image: compute the loss per image instead of per batch
          ignore: void class id
        """
        logits = logits.squeeze(1)
        targets = targets.squeeze(1)
        if not self.per_image:
            return lovasz_hinge_flat(*flatten_binary_scores(logits, targets, self.ignore))
        
        # one loss per (image, mask) pair, averaged lazily by mean()
        per_image_losses = (
            lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), self.ignore))
            for log, lab in zip(logits, targets)
        )
        return mean(per_image_losses)



## Training Model

In [None]:
!ls -la ../input

In [None]:
# hyper-parameters of this run
args = dict(
    epochs = 80,
    learning_rate = 0.1,
    use_cuda = torch.cuda.is_available(),
    batch_size = 16,
    k_fold = 5
)

# fold to train and the last finished cycle; last_cycle + 1 is the cycle training starts from
fold, last_cycle = 0, -1

model = dsDSEUNeXt()
optimizer = torch.optim.SGD(
    model.parameters(), 
    lr = args['learning_rate'], 
    momentum = 0.9, 
    weight_decay = 1e-5
)

# uncomment to resume model and optimizer from the checkpoint of a previous cycle
# model_dir = '../input/tgs-salt-p1-final-fold{}-cycle{}/model'.format(fold, last_cycle)
# filename = model.__class__.__name__ + '_fold{}'.format(fold) + '_cycle{}'.format(last_cycle)
# ckpt = load_checkpoint(is_best = False, dir_path = model_dir, filename = filename)
# print('epoch: ', ckpt['epoch'])
# print('best validation metrics: \n', ckpt['best_valid_metrics'])

# model.load_state_dict(ckpt['model'])
# optimizer.load_state_dict(ckpt['optimizer'])

In [None]:
# report the number of trainable parameters and the layer layout
print('model summary:')
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('total params:', total_params)
print(model)

# one criterion per output consumed by training(); the alternatives are kept for reference
# loss = nn.BCEWithLogitsLoss()
# loss = BFWithLogitsLoss(gamma = 0.8, alpha = (0.4, 0.6))
losses = {'non_empty': LovaszHingeLoss(), 
          'binary': nn.BCEWithLogitsLoss(), 
          'all_segs': LovaszHingeLoss()}

if args['use_cuda']:
    model = model.cuda()

# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
#                                                           mode = 'max', 
#                                                           factor = 0.5,
#                                                           patience = 8, 
#                                                           verbose = True,
#                                                           threshold = 10e-4,
#                                                           cooldown = 2,
#                                                           min_lr = 1e-5)
# the annealing period follows the configured number of epochs, so changing 
# args['epochs'] can no longer leave the schedule out of sync
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                          T_max = args['epochs'],
                                                          eta_min = 1e-5)

train_dataloaders = get_dataloader(train_dataset, args, train_df.coverage_class.values)

In [None]:
%%time
# run the training loop, starting at the cycle following the last finished one
metrics = training(
    model, 
    losses, 
    optimizer, 
    train_dataloaders, 
    args, 
    lr_scheduler, 
    fold = fold, 
    start_cycle = last_cycle + 1, 
    model_dir = './model'
)

In [None]:
def plot_metrics(metrics: dict):
    '''
        plot every series of `metrics` (name -> sequence of values) against its 
        index on a single figure, with a grid, a legend and y ticks every 0.1 from 
        0 up to the largest plotted value
    '''
    colors = ['r', 'b', 'g', 'c', 'm', 'k', 'aquamarine', 'orange']
       
    plt.figure(figsize=(18,10))
    max_value = 0.0
    for i, (key, values) in enumerate(metrics.items()):
        # wrap around the palette instead of raising IndexError past 8 metrics
        color = colors[i % len(colors)]
        plt.plot(np.arange(len(values)), values, color = color, label = key)
        # max() of an empty series raises, so empty series do not affect the tick range
        if len(values) > 0:
            max_value = max(max_value, max(values))
    plt.grid()
    plt.legend()
    plt.yticks(np.arange(0.0, max_value, 0.1))
    plt.show()

In [None]:
# visualise the metric history returned by training()
plot_metrics(metrics)