## Imports

In [1]:
import os
from datetime import datetime
# Timestamp taken once when this cell runs; the training log header is tagged with it.
IDENTIFIER   = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

#-----------------------------------

#numerical libs
import math
import numpy as np
import random
import PIL
import cv2
import matplotlib


# std libs
import collections
from collections import defaultdict
import copy
import numbers
import inspect
import shutil
from timeit import default_timer as timer
import itertools
from collections import OrderedDict
from multiprocessing import Pool
import multiprocessing as mp

#from pprintpp import pprint, pformat
import json
import zipfile
from shutil import copyfile

import csv
import pandas as pd
import pickle
import glob
import sys
from distutils.dir_util import copy_tree
import time

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


# constant #
PI  = np.pi   # pi, re-exported from numpy
INF = np.inf  # positive infinity, re-exported from numpy
EPS = 1e-12   # small epsilon; not referenced in the visible part of the notebook

def seed_py(seed):
    """Seed Python's ``random`` module and NumPy's legacy global RNG.

    Returns the seed unchanged so the caller can log it.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    return seed

## Model and Data Parameters

In [2]:
image_size = 512  # side length (pixels) that images and masks are resized to by the dataset
mask_size = 64  # not referenced in the visible code; presumably the side of the first auxiliary mask output - TODO confirm
mask_size1 = 16  # not referenced in the visible code; presumably the side of the second auxiliary mask output - TODO confirm
data_dir = './'  # root folder holding the fold CSV and the image / mask PNG folders
is_mixed_precision = True  # True / False; not referenced in the visible code - presumably read by the training loop
output_dir = './'  # prefix under which run_train creates its result folders

## Defining Classes and Functions

In [3]:
#  https://www.kaggle.com/lextoumbourou/radampytorch#radam.py
#  https://forums.fast.ai/t/meet-ranger-radam-lookahead-optimizer/52886/21
#  https://github.com/nachiket273/lookahead_pytorch
#  https://github.com/mgrankin/over9000

import math
import torch
from torch.optim.optimizer import Optimizer, required
from collections import defaultdict


class RAdam(Optimizer):
    """Rectified Adam (RAdam) optimizer.

    Keeps Adam-style running averages of the gradient and of its square.
    Once the approximated SMA length ``N_sma`` reaches 5 the rectified,
    adaptive step is applied; before that the update falls back to a
    bias-corrected momentum step without the adaptive denominator.

    All arithmetic is done on float32 copies of the parameter and gradient,
    and the result is copied back into the parameter at the end of the step.

    Args:
        params: iterable of parameters or of parameter-group dicts.
        lr: learning rate.
        betas: ``(beta1, beta2)`` coefficients of the two running averages.
        eps: value added to the denominator of the adaptive step.
        weight_decay: decay factor, applied as ``p -= weight_decay * lr * p``
            before the gradient step.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Cache of [step, N_sma, step_size] slots indexed by ``step % 10`` so that
        # parameters at the same step reuse the rectification terms.
        # NOTE(review): the cache is keyed by step only and shared by every
        # parameter group, so groups with different betas would reuse each
        # other's values - confirm a single betas setting is used.
        self.buffer = [[None, None, None] for ind in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    # Keep the moment buffers on the dtype/device of the fp32 working copy.
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    # Keyword ``alpha`` form, matching the other in-place updates in
                    # this method; the positional (Number, Tensor) overload is deprecated.
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                p.data.copy_(p_data_fp32)

        return loss


In [4]:
# torch libs
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import *

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel.data_parallel import data_parallel

from torch.nn.utils.rnn import *


def seed_torch(seed):
    """Seed PyTorch's default generator and the generators of all CUDA devices.

    Returns the seed unchanged so the caller can log it.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    return seed

In [5]:
#from lib.include import *
import os
import pickle
import sys
import pandas as pd
import shutil

import builtins
import re

class Struct(object):
    """Plain attribute bag: keyword arguments become instance attributes."""

    def __init__(self, is_copy=False, **kwargs):
        self.add(is_copy, **kwargs)

    def add(self, is_copy=False, **kwargs):
        """Set one attribute per keyword argument.

        With ``is_copy == False`` the values are stored as given. Otherwise
        each value is deep-copied first; a value that cannot be deep-copied
        is stored as given.
        """
        if is_copy == False:
            for key, value in kwargs.items():
                setattr(self, key, value)
        else:
            for key, value in kwargs.items():
                try:
                    setattr(self, key, copy.deepcopy(value))
                except Exception:
                    # best effort: share the original object when it cannot be copied
                    setattr(self, key, value)

    def drop(self,  missing=None, **kwargs):
        """Delete the attributes named by the keyword arguments.

        Returns a list with one entry per keyword, in order: the value that
        was passed for that keyword if the attribute existed and was deleted,
        ``missing`` otherwise.
        NOTE(review): the stored attribute value itself is not returned -
        confirm this is what callers expect.
        """
        drop_value = []
        for key, value in kwargs.items():
            try:
                delattr(self, key)
            except AttributeError:
                # only "no such attribute" is expected here; anything else propagates
                drop_value.append(missing)
            else:
                drop_value.append(value)
        return drop_value

    def __str__(self):
        return ''.join('\t%s : %s\n' % (k, str(v)) for k, v in self.__dict__.items())



# log ------------------------------------
def remove_comments(lines, token='#'):
    """Strip trailing comments and surrounding whitespace from each line.

    Everything from the first ``token`` onward is discarded. Lines that end
    up empty are skipped. Returns a list (not a generator).
    """
    stripped = (line.split(token, 1)[0].strip() for line in lines)
    return [text for text in stripped if text != '']


def open(file, mode=None, encoding=None):
    """Drop-in for the builtin ``open`` that creates missing parent directories.

    The parent directory is created only for modes that can create the file
    ('w', 'a' or 'x'). Opening for reading never touches the file system
    beyond the open itself.

    Args:
        file: path of the file to open.
        mode: builtin ``open`` mode string; defaults to 'r'.
        encoding: passed through to the builtin ``open``.

    Returns:
        The file object returned by ``builtins.open``.
    """
    if mode is None:
        mode = 'r'

    # The original test ``'w' or 'a' in mode`` was always true, so even a failed
    # read of a missing file left an empty directory behind.
    if any(flag in mode for flag in 'wax'):
        directory = os.path.dirname(file)
        if directory:
            os.makedirs(directory, exist_ok=True)

    return builtins.open(file, mode=mode, encoding=encoding)


def remove(file):
    """Delete ``file`` if the path exists; do nothing otherwise."""
    if not os.path.exists(file):
        return
    os.remove(file)


def empty(dir):
    """Remove ``dir`` with its contents if it is a directory, otherwise create it.

    NOTE(review): an existing directory is deleted and not recreated, so the
    path exists after the call only when it did not exist before - confirm
    callers expect that.
    """
    if not os.path.isdir(dir):
        os.makedirs(dir)
        return
    shutil.rmtree(dir, ignore_errors=True)


# http://stackoverflow.com/questions/34950201/pycharm-print-end-r-statement-not-working
class Logger(object):
    """Tee-style logger that writes to the terminal and, once opened, to a file."""

    def __init__(self):
        self.terminal = sys.stdout  #stdout
        self.file = None

    def open(self, file, mode=None):
        """Open ``file`` as the log file (mode defaults to 'w').

        A log file that is already open is closed first so its handle is not
        leaked.
        """
        if mode is None: mode ='w'
        self.close()
        self.file = open(file, mode)

    def close(self):
        """Close the log file, if any; terminal output keeps working."""
        if self.file is not None:
            self.file.close()
            self.file = None

    def write(self, message, is_terminal=1, is_file=1 ):
        """Write ``message`` to the terminal and/or the log file.

        Messages containing a carriage return (progress lines) are never
        written to the file. File output is skipped when no log file has
        been opened, instead of failing on ``None``.
        """
        if '\r' in message: is_file=0

        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()

        if is_file == 1 and self.file is not None:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # this handles the flush command by doing nothing.
        # you might want to specify some extra behavior here.
        pass

# io ------------------------------------
def write_list_to_file(list_file, strings):
    """Write ``str(item)`` for every item of ``strings`` to ``list_file``, one per line."""
    with open(list_file, 'w') as out:
        out.writelines('%s\n' % str(item) for item in strings)


def read_list_from_file(list_file, comment='#'):
    """Read the non-empty lines of ``list_file`` as a list of stripped strings.

    When ``comment`` is not None, everything from its first occurrence to the
    end of a line is dropped before stripping.
    """
    with open(list_file) as src:
        raw_lines = src.readlines()

    def clean(line):
        if comment is None:
            return line.strip()
        return line.split(comment, 1)[0].strip()

    return [text for text in map(clean, raw_lines) if text != '']



def read_pickle_from_file(pickle_file):
    """Load and return the object pickled in ``pickle_file``."""
    # NOTE: unpickling can run arbitrary code; only load files from a trusted source.
    with open(pickle_file, 'rb') as handle:
        return pickle.load(handle)

def write_pickle_to_file(pickle_file, x):
    """Pickle ``x`` into ``pickle_file`` using the highest available protocol."""
    with open(pickle_file, 'wb') as handle:
        pickle.dump(x, handle, protocol=pickle.HIGHEST_PROTOCOL)



# backup ------------------------------------

#https://stackoverflow.com/questions/1855095/how-to-create-a-zip-archive-of-a-directory
def backup_project_as_zip(project_dir, zip_file):
    """Archive the contents of ``project_dir`` into the zip file ``zip_file``.

    Args:
        project_dir: existing directory whose contents are archived.
        zip_file: destination path. A trailing '.zip' is optional, because
            ``shutil.make_archive`` appends the extension itself. The
            destination directory must already exist; a bare file name
            refers to the current directory.
    """
    assert(os.path.isdir(project_dir))
    assert(os.path.isdir(os.path.dirname(zip_file) or '.'))

    # Strip only the trailing extension: a plain ``replace('.zip', '')`` would also
    # rewrite '.zip' inside directory names and send the archive elsewhere.
    base_name = zip_file[:-len('.zip')] if zip_file.endswith('.zip') else zip_file
    shutil.make_archive(base_name, 'zip', project_dir)


# etc ------------------------------------
def time_to_str(t, mode='min'):
    """Format a duration of ``t`` seconds.

    ``mode='min'`` gives '<h> hr <mm> min', ``mode='sec'`` gives
    '<m> min <ss> sec'. Fractions of a second are dropped first.

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    if mode == 'min':
        hours, minutes = divmod(int(t) / 60, 60)
        return '%2d hr %02d min' % (hours, minutes)

    if mode == 'sec':
        minutes, seconds = divmod(int(t), 60)
        return '%2d min %02d sec' % (minutes, seconds)

    raise NotImplementedError


def np_float32_to_uint8(x, scale=255):
    """Multiply ``x`` by ``scale`` and cast to uint8 (no rounding or clipping is applied)."""
    scaled = x * scale
    return scaled.astype(np.uint8)

def np_uint8_to_float32(x, scale=255):
    """Divide ``x`` by ``scale`` and cast the result to float32."""
    ratio = x / scale
    return ratio.astype(np.float32)


def int_tuple(x):
    """Return the items of ``x`` rounded to the nearest integer, as a tuple of ints."""
    return tuple(int(round(item)) for item in x)




def df_loc_by_list(df, key, values):
    """Select the rows whose ``df[key]`` is in ``values``, ordered as in ``values``.

    Rows sharing the same key keep their original relative order. The
    original index is preserved.

    The ordering is computed on the side instead of through a temporary
    'sort' column, so a caller column that happens to be called 'sort' is
    neither overwritten nor dropped.

    Args:
        df: source DataFrame.
        key: name of the column to match.
        values: wanted key values, in the desired output order.

    Returns:
        A new DataFrame with the selected rows.
    """
    selected = df.loc[df[key].isin(values)]
    rank = pd.Categorical(selected[key], categories=values, ordered=True).codes
    return selected.iloc[np.argsort(rank, kind='stable')]

In [8]:
from sklearn.metrics import roc_auc_score, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import average_precision_score


def np_loss_binary_cross_entropy(probability, truth):
    """Mean binary cross-entropy between flattened ``probability`` and ``truth``.

    Probabilities are clipped to [1e-5, 1] before the log. Only entries whose
    truth is exactly 1 or exactly 0 contribute to the sum; the sum is divided
    by the total number of entries.
    """
    p = probability.reshape(-1)
    t = truth.reshape(-1)

    nll_pos = -np.log(np.clip(p, 1e-5, 1))
    nll_neg = -np.log(np.clip(1 - p, 1e-5, 1))

    total = nll_pos[t == 1].sum() + nll_neg[t == 0].sum()
    return total / len(t)

def np_metric_map_curve_by_class(truth,probability):
    """Average precision of every label column.

    ``probability`` must be 2-D, (samples, labels). Column ``i`` of ``truth``
    is scored against column ``i`` of ``probability`` with
    ``average_precision_score``. Returns a 1-D array with one score per label.
    """
    _, num_label = probability.shape
    return np.array([
        average_precision_score(truth[:, column], probability[:, column])
        for column in range(num_label)
    ])

def np_metric_roc_auc(probability, truth):
    """ROC AUC computed by ``roc_auc_score`` over the flattened truth and probability arrays."""
    flat_truth = truth.reshape(-1)
    flat_probability = probability.reshape(-1)
    return roc_auc_score(flat_truth, flat_probability)

In [6]:
#---------------------------------------------------------------------------------
# Seed every RNG from the wall clock and collect a description of the runtime
# environment in COMMON_STRING, which the training log prints at start-up.
COMMON_STRING = ''

# The seed changes on every run; it is recorded below so a run can be repeated.
seed = int(time.time())
seed_py(seed)
seed_torch(seed)

torch.backends.cudnn.benchmark     = False  # do not let the cudnn auto-tuner pick convolution algorithms
torch.backends.cudnn.enabled       = True
torch.backends.cudnn.deterministic = True

COMMON_STRING += '\tpytorch\n'
COMMON_STRING += '\t\tseed = %d\n'%seed
COMMON_STRING += '\t\ttorch.__version__              = %s\n'%torch.__version__
COMMON_STRING += '\t\ttorch.version.cuda             = %s\n'%torch.version.cuda
COMMON_STRING += '\t\ttorch.backends.cudnn.version() = %s\n'%torch.backends.cudnn.version()
try:
    COMMON_STRING += '\t\tos[\'CUDA_VISIBLE_DEVICES\']     = %s\n'%os.environ['CUDA_VISIBLE_DEVICES']
    NUM_CUDA_DEVICES = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
except KeyError:
    # variable not set: assume a single device
    COMMON_STRING += '\t\tos[\'CUDA_VISIBLE_DEVICES\']     = None\n'
    NUM_CUDA_DEVICES = 1

COMMON_STRING += '\t\ttorch.cuda.device_count()      = %d\n'%torch.cuda.device_count()
# get_device_properties(0) raises when no CUDA device is present, so report 'None'
# in that case; [21:] drops the leading class name from the properties repr.
COMMON_STRING += '\t\ttorch.cuda.get_device_properties() = %s\n' % (
    str(torch.cuda.get_device_properties(0))[21:] if torch.cuda.is_available() else 'None'
)

COMMON_STRING += '\n'

if __name__ == '__main__':
    print (COMMON_STRING)


	pytorch
		seed = 1629378109
		torch.__version__              = 1.8.1+cu102
		torch.version.cuda             = 10.2
		torch.backends.cudnn.version() = 7605
		os['CUDA_VISIBLE_DEVICES']     = None
		torch.cuda.device_count()      = 1
		torch.cuda.get_device_properties() = (name='Tesla V100-SXM2-32GB', major=7, minor=0, total_memory=32510MB, multi_processor_count=80)




## Defining Augmentations

In [7]:
# #--- flip ---
def do_random_hflip(image, mask):
    """Flip image and mask together around the vertical axis when a uniform draw exceeds 0.5."""
    if np.random.rand() <= 0.5:
        return image, mask
    return cv2.flip(image, 1), cv2.flip(mask, 1)


# #--- geometric ---
def do_random_rotate(image, mask, mag=15 ):
    """Rotate image and mask about the image centre by one random angle.

    The angle is drawn uniformly from [-mag, mag). Both arrays are warped
    with the same matrix, linear interpolation and a constant zero border.
    """
    angle = np.random.uniform(-1, 1)*mag

    height, width = image.shape[:2]
    center = (width // 2, height // 2)
    out_size = (width, height)

    transform = cv2.getRotationMatrix2D(center, -angle, 1.0)
    warp_options = dict(flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    rotated_image = cv2.warpAffine(image, transform, out_size, **warp_options)
    rotated_mask = cv2.warpAffine(mask, transform, out_size, **warp_options)

    return rotated_image, rotated_mask


def do_random_scale( image, mask, mag=0.1 ):
    """Zoom image and mask by one random factor, keeping the output size.

    The factor ``s`` is drawn uniformly from [1 - mag, 1 + mag). A window of
    ``int(s*width) x int(s*height)`` pixels is placed at a random offset and
    mapped onto the full output frame. A window larger than the image
    extends past it and the missing area is filled with zeros. Image and
    mask share the same transform.

    Returns ``(image, mask)``; the inputs are returned untouched when the
    window has exactly the image size.
    """
    s = 1 + np.random.uniform(-1, 1)*mag
    height, width = image.shape[:2]
    w,h = int(s*width), int(s*height)
    if (h,w)==image.shape[:2]:
        return image, mask

    def _random_offset(extent):
        # np.random.choice(0) raises. That happened when rounding left one axis
        # unchanged (strongly non-square images); there is nothing to draw then.
        return np.random.choice(extent) if extent > 0 else 0

    # three corners of the output frame: top-left, bottom-right, top-right
    dst = np.array([
        [0,0],[width,height], [width,0],
    ]).astype(np.float32)

    if s>1:
        # window larger than the image: its origin lies up/left of the image origin
        dx = _random_offset(w-width)
        dy = _random_offset(h-height)
        src = np.array([
            [-dx,-dy],[-dx+w,-dy+h], [-dx+w,-dy],
        ]).astype(np.float32)
    else:
        # window smaller than the image: its origin lies inside the image
        dx = _random_offset(width-w)
        dy = _random_offset(height-h)
        src = np.array([
            [dx,dy], [dx+w,dy+h], [dx+w,dy],
        ]).astype(np.float32)

    transform = cv2.getAffineTransform(src, dst)
    image = cv2.warpAffine( image, transform, (width, height), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    mask = cv2.warpAffine( mask, transform, (width, height), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return image, mask


def do_random_stretch_y( image, mask, mag=0.25 ):
    """Stretch or squeeze image and mask vertically by one random factor.

    The factor ``s`` is drawn uniformly from [1 - mag, 1 + mag). A window of
    full width and ``int(s*height)`` rows is placed at a random vertical
    offset and mapped onto the full output frame; area outside the image is
    filled with zeros. The inputs are returned untouched when the window has
    exactly the image height.
    """
    s = 1 + np.random.uniform(-1, 1)*mag
    height, width = image.shape[:2]
    h, w = int(s*height), width
    if h == height:
        return image, mask

    # window taller than the image starts above it; a shorter one starts inside it
    if s > 1:
        top = -np.random.choice(h - height)
    else:
        top = np.random.choice(height - h)

    # three corners: top-left, bottom-right, top-right
    src = np.array([[0, top], [w, top + h], [w, top]], dtype=np.float32)
    dst = np.array([[0, 0], [width, height], [width, 0]], dtype=np.float32)

    transform = cv2.getAffineTransform(src, dst)
    warp_options = dict(flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    image = cv2.warpAffine(image, transform, (width, height), **warp_options)
    mask = cv2.warpAffine(mask, transform, (width, height), **warp_options)
    return image, mask



def do_random_stretch_x( image, mask, mag=0.25 ):
    """Stretch or squeeze image and mask horizontally by one random factor.

    The factor ``s`` is drawn uniformly from [1 - mag, 1 + mag). A window of
    full height and ``int(s*width)`` columns is placed at a random
    horizontal offset and mapped onto the full output frame; area outside
    the image is filled with zeros. The inputs are returned untouched when
    the window has exactly the image width.
    """
    s = 1 + np.random.uniform(-1, 1)*mag
    height, width = image.shape[:2]
    h, w = height, int(s*width)
    if w == width:
        return image, mask

    # window wider than the image starts left of it; a narrower one starts inside it
    if s > 1:
        left = -np.random.choice(w - width)
    else:
        left = np.random.choice(width - w)

    # three corners: top-left, bottom-right, top-right
    src = np.array([[left, 0], [left + w, h], [left + w, 0]], dtype=np.float32)
    dst = np.array([[0, 0], [width, height], [width, 0]], dtype=np.float32)

    transform = cv2.getAffineTransform(src, dst)
    warp_options = dict(flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    image = cv2.warpAffine(image, transform, (width, height), **warp_options)
    mask = cv2.warpAffine(mask, transform, (width, height), **warp_options)
    return image, mask


def do_random_shift( image, mask, mag=32 ):
    """Translate image and mask by the same random offset, padding with zeros.

    Both arrays are padded by ``mag`` pixels on every side and a window of
    the original size is cut out at a random position. The horizontal and
    vertical offsets are drawn independently from ``[0, 2*mag)``.

    ``mag == 0`` means "no shift" and returns the inputs unchanged; the
    random draw would otherwise fail on the empty range ``[0, 0)``.
    """
    b = mag
    if b == 0:
        return image, mask

    height, width = image.shape[:2]

    image = cv2.copyMakeBorder(image, b,b,b,b, borderType=cv2.BORDER_CONSTANT, value=0)
    mask  = cv2.copyMakeBorder(mask, b,b,b,b, borderType=cv2.BORDER_CONSTANT, value=0)
    x = np.random.randint(0,2*b)
    y = np.random.randint(0,2*b)
    image = image[y:y+height,x:x+width]
    mask = mask[y:y+height,x:x+width]

    return image, mask

###########################################################################################3



# #--- noise ---
def do_random_blurout(image, size=0.20, num_cut=3):
    """Replace ``num_cut`` random square patches by their own mean value.

    The patch side is ``size`` times the average of image height and width.
    The patches are written into ``image`` itself, which is also returned.
    """
    height, width = image.shape[:2]
    side = int(size*(height+width)/2)
    for _ in range(num_cut):
        left = np.random.randint(0, width - side)
        top = np.random.randint(0, height - side)
        rows = slice(top, top + side)
        cols = slice(left, left + side)
        image[rows, cols] = image[rows, cols].mean()

    return image

def do_random_guassian_blur(image, mag=[0.1, 2.0]):
    """Blur ``image`` with a 23x23 Gaussian kernel whose sigma is drawn uniformly from ``mag``."""
    sigma_low, sigma_high = mag[0], mag[1]
    sigma = np.random.uniform(sigma_low, sigma_high)
    return cv2.GaussianBlur(image, (23, 23), sigma)

def do_random_noise(image, mag=0.08):
    """Add uniform noise of amplitude ``mag`` (on a 0..1 scale) and return a uint8 image.

    The noise has shape (height, width), taken from the first two axes of
    ``image``.
    """
    height, width = image.shape[:2]

    scaled = image.astype(np.float32)/255
    noise = np.random.uniform(-1,1,size=(height,width))*mag
    noisy = np.clip(scaled+noise, 0, 1)
    return (noisy*255).astype(np.uint8)



# # --- intensity ---
def do_random_intensity_shift_contast(image, mag=[0.3,0.2]):
    """Randomly shift the intensity and apply a random power curve; returns uint8.

    On a 0..1 scale, an offset drawn from [-mag[1], mag[1]) is added. The
    result is clipped and raised to an exponent drawn from
    [1 - mag[0], 1 + mag[0]), then clipped again. Both draws use Python's
    ``random`` module, exponent first.
    """
    exponent = 1 + random.uniform(-1,1)*mag[0]
    offset = random.uniform(-1,1)*mag[1]

    scaled = (image).astype(np.float32)/255
    shifted = np.clip(scaled+offset, 0, 1)
    curved = np.clip(shifted**exponent, 0, 1)
    return (curved*255).astype(np.uint8)

#https://answers.opencv.org/question/12024/use-of-clahe/)
def do_random_clahe(image, mag=[[2,4],[6,12]]):
    """Apply CLAHE with a random clip limit and a random square tile grid.

    ``mag[0]`` is the range the clip limit is drawn from (uniform).
    ``mag[1]`` is the integer range the grid size is drawn from (upper end
    excluded).
    """
    clip_range, grid_range = mag[0], mag[1]
    clip_limit = np.random.uniform(*clip_range)
    grid = np.random.randint(*grid_range)
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(grid, grid))
    return equalizer.apply(image)

# https://github.com/facebookresearch/CovidPrognosis/blob/master/covidprognosis/data/transforms.py
def do_histogram_norm(image, mag=[[2,4],[6,12]]):
    """Histogram-equalise ``image`` through its cumulative distribution.

    A 255-bin histogram is accumulated into a CDF scaled to 0..255, and
    every pixel is replaced by the linearly interpolated CDF value. The
    result keeps the input shape but is a float array. ``mag`` is not used.
    """
    num_bin = 255
    pixels = image.flatten()

    density, edges = np.histogram(pixels, num_bin, density=True)
    cdf = density.cumsum()  # cumulative distribution function
    cdf = 255 * cdf / cdf[-1]  # normalize to 0..255

    # linear interpolation of the cdf at each pixel value gives the new value
    return np.interp(pixels, edges[:-1], cdf).reshape(image.shape)


## Dataloader Class and Functions

In [9]:

def make_fold(mode='train-1'):
    """Split the study table into the training and validation rows of one fold.

    Args:
        mode: for training, a string containing 'train' and ending in the
            fold number, e.g. 'train-1'. The whole trailing number is used,
            so folds with two or more digits work as well.

    Returns:
        ``(df_train, df_valid)`` for a 'train' mode. Rows whose ``fold``
        equals the requested fold form the validation frame, all others the
        training frame; both get a fresh index. Any other mode returns
        ``None`` (a 'test' mode also prints a hint).

    Raises:
        ValueError: if a 'train' mode does not end in a number.
    """
    if 'train' in mode:
        trailing_number = re.search(r'(\d+)$', mode)
        if trailing_number is None:
            raise ValueError('mode must end with a fold number, got %r' % (mode,))
        fold = int(trailing_number.group(1))

        df = pd.read_csv(data_dir+'/df_study_split_binary_negative_eb5ns_eb6eb6_ns_4024.csv')
        df['set'] = "train"

        df_train = df[df.fold != fold].reset_index(drop=True)
        df_valid = df[df.fold == fold].reset_index(drop=True)
        return df_train, df_valid

    if 'test' in mode:
        print("Please use Inference Pipeline")

def null_augment(r):
    """Identity augmentation: return the record ``r`` itself, unchanged."""
    _ = r['image']  # kept from the original: a record without an 'image' entry still fails here
    return r


class SiimDataset(Dataset):
    """Dataset of grayscale PNG images with a mask and two label vectors per row.

    Args:
        df: table with one row per sample; the columns ``set``, ``image``,
            ``none`` and ``negative`` are read.
        augment: callable applied to every sample record before it is
            returned; ``None`` disables augmentation.
    """

    def __init__(self, df, augment=null_augment):
        super().__init__()
        self.df = df
        self.augment = augment
        self.length = len(df)

    def __str__(self):
        num_none = self.df['none'].sum()
        string  = ''
        string += '\tlen = %d\n'%len(self)
        string += '\tdf  = %s\n'%str(self.df.shape)

        string += '\tlabel distribution\n'
        # row 0: samples not flagged 'none'; row 1: samples flagged 'none'
        for i, n in enumerate((len(self.df) - num_none, num_none)):
            string += '\t\t %d %26s: %5d (%0.4f)\n'%(i, 'none', n, n/len(self.df) )
        return string


    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns a dict with the keys 'index', 'd' (the table row), 'image'
        and 'mask' (both ``image_size`` x ``image_size`` grayscale arrays),
        'onehot' and 'onehotns'. The dict is passed through ``self.augment``
        when that is set.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        d = self.df.iloc[index]
        image_file = data_dir+'/%s/%s.png' % (d.set, d.image)
        image = cv2.imread(image_file,cv2.IMREAD_GRAYSCALE)
        if image is None:
            # imread reports failure by returning None; fail with the path rather
            # than with an opaque resize error
            raise FileNotFoundError('cannot read image file: %s' % image_file)
        image = cv2.resize(image, (image_size, image_size))
        onehot = d[['none']].values
        onehotns = d[['negative']].values

        if d.set == 'train':
            mask_file = data_dir+'/%s_mask/%s.png' % (d.set, d.image)
            mask = cv2.imread(mask_file,cv2.IMREAD_GRAYSCALE)
            if mask is None:
                # no readable mask file for this sample: use an all-zero mask
                mask = np.zeros_like(image)
            else:
                mask = cv2.resize(mask, (image_size,image_size))
        else:
            mask = np.zeros_like(image)

        r = {
            'index' : index,
            'd' : d,
            'image' : image,
            'mask' : mask,
            'onehot' : onehot,
            'onehotns' : onehotns,
        }
        if self.augment is not None: r = self.augment(r)
        return r


def null_collate(batch):
    """Collate a list of sample records into one batch dict.

    Every key is first gathered into a list. 'onehot' and 'onehotns' are
    then stacked into float32 tensors. 'image' becomes a float32 tensor of
    shape (batch, 3, image_size, image_size), with the single channel
    repeated three times and values divided by 255. 'mask' becomes a float32
    tensor of shape (batch, 1, image_size, image_size) divided by 255. All
    other keys stay plain lists.
    """
    collate = defaultdict(list)
    for record in batch:
        for key, value in record.items():
            collate[key].append(value)

    batch_size = len(batch)

    for label_key in ('onehot', 'onehotns'):
        stacked = np.ascontiguousarray(np.stack(collate[label_key])).astype(np.float32)
        collate[label_key] = torch.from_numpy(stacked)

    image = np.stack(collate['image']).reshape(batch_size, 1, image_size, image_size)
    image = np.ascontiguousarray(image.repeat(3, 1)).astype(np.float32) / 255
    collate['image'] = torch.from_numpy(image)

    mask = np.stack(collate['mask']).reshape(batch_size, 1, image_size, image_size)
    mask = np.ascontiguousarray(mask).astype(np.float32) / 255
    collate['mask'] = torch.from_numpy(mask)

    return collate


## Defining the Model

In [10]:
from timm.models.efficientnet import *

class Net(nn.Module):
    """Classifier built from timm's ``tf_efficientnet_b6_ns`` with two auxiliary mask heads.

    The backbone is split into stem (``b0``), the seven block groups
    (``b1``..``b7``) and the head convolution (``b8``). ``forward`` returns
    one classification logit per sample plus two single-channel mask logit
    maps taken from intermediate feature maps.
    """

    def __init__(self):
        super(Net, self).__init__()

        # pretrained backbone; its sub-modules are re-registered below as b0..b8
        e = tf_efficientnet_b6_ns(pretrained=True, drop_rate=0.5, drop_path_rate=0.3)
        self.b0 = nn.Sequential(
            e.conv_stem,
            e.bn1,
            e.act1,
        )
        self.b1 = e.blocks[0]
        self.b2 = e.blocks[1]
        self.b3 = e.blocks[2]
        self.b4 = e.blocks[3]
        self.b5 = e.blocks[4]
        self.b6 = e.blocks[5]
        self.b7 = e.blocks[6]
        self.b8 = nn.Sequential(
            e.conv_head, # author's note for other backbones: eb3 384 -> 1536, eb7 640 -> 2560
            e.bn2,
            e.act2,
        )

        # classification head on the pooled b8 features (2304 inputs)
        self.logit = nn.Linear(2304,1)
        # auxiliary mask head applied to the output of b3 (expects 72 channels)
        self.mask = nn.Sequential(
            nn.Conv2d(72, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1, padding=0),
        )

        # auxiliary mask head applied to the output of b6 (expects 344 channels)
        self.mask1 = nn.Sequential(
            nn.Conv2d(344, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1, padding=0),
        )


    # @torch.cuda.amp.autocast()
    def forward(self, image):
        """Run the network.

        Args:
            image: batch of images; presumably (batch, 3, H, W) with values
                in 0..1, as produced by the collate function - TODO confirm.

        Returns:
            ``(logit, mask, mask1)``: the classification logit of shape
            (batch, 1) and the raw outputs of the two mask heads.
        """
        batch_size = len(image)
        x = 2*image-1     # rescale the input: 0 -> -1, 1 -> +1

        # NOTE(review): the per-stage shape notes that used to sit on these lines
        # listed channel counts (48 after b3, 232 after b6, 1536 after b8) that do
        # not match the 72 / 344 / 2304 inputs the heads above are built for; they
        # looked inherited from a smaller backbone and were dropped.
        x = self.b0(x)
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        mask = self.mask(x)  # first auxiliary mask, from the b3 features
        x = self.b4(x)
        x = self.b5(x)
        x = self.b6(x)
        mask1 = self.mask1(x)  # second auxiliary mask, from the b6 features
        x = self.b7(x)
        x = self.b8(x)
        # global average pooling, then flatten to (batch, features)
        x = F.adaptive_avg_pool2d(x,1).reshape(batch_size,-1)
        logit = self.logit(x)
        return logit, mask, mask1

# check #################################################################

def run_check_net():
    """Smoke test: push one random batch through ``Net`` on the GPU and print the shapes.

    Prints, in order, the shapes of the input, the logit and the two mask
    outputs. Requires a CUDA device.
    """
    batch_size = 2
    C, H, W = 3, 512, 512
    image = torch.randn(batch_size, C, H, W).cuda()
    # (the original also allocated a random ``mask`` tensor on the GPU that was
    # never used and was immediately shadowed by the network output)

    net = Net().cuda()
    logit, mask, mask1 = net(image)

    print(image.shape)
    print(logit.shape)
    print(mask.shape)
    print(mask1.shape)

In [11]:
# -----------------------------------
# Loss function library
from pytorch_toolbelt.losses import (
    BinaryLovaszLoss,
    DiceLoss,
    BinaryFocalLoss,
)
from pytorch_toolbelt.losses import BinaryFocalLoss, JointLoss
#----------------
import os
import torch.cuda.amp as amp
from madgrad import MADGRAD
import torch.optim as optim

class AmpNet(Net):
    """``Net`` whose forward pass runs under CUDA automatic mixed precision (autocast)."""

    @torch.cuda.amp.autocast()
    def forward(self, *args):
        return super().forward(*args)


In [12]:
def train_augment(r):
    """Training augmentation: random geometric, flip/rotate and intensity transforms.

    Three pools of transforms are sampled with replacement (4, 3 and 4 draws)
    and the drawn transforms are applied in draw order. Each pool contains an
    identity entry, so a draw may do nothing. The first two pools act on
    image and mask together, the third on the image only. The record ``r``
    is updated in place and returned.
    """
    image, mask = r['image'], r['mask']

    # scale / stretch / shift: applied to image and mask together
    geometric = [
        lambda image, mask : do_random_scale(image, mask, mag=0.30),
        lambda image, mask : do_random_stretch_y(image, mask, mag=0.30),
        lambda image, mask : do_random_stretch_x(image, mask, mag=0.30),
        lambda image, mask : do_random_shift(image, mask, mag=int(0.30*image_size)),
        lambda image, mask : (image, mask)
    ]
    for transform in np.random.choice(geometric, 4):
        image, mask = transform(image, mask)

    # rotation / horizontal flip: applied to image and mask together
    orientation = [
        lambda image, mask : do_random_rotate(image, mask, mag=20),
        lambda image, mask : do_random_hflip(image, mask),
        lambda image, mask : (image, mask)
    ]
    for transform in np.random.choice(orientation, 3):
        image, mask = transform(image, mask)

    # intensity / noise: applied to the image only
    photometric = [
        lambda image : do_random_intensity_shift_contast(image, mag=[0.5,0.5]),
        lambda image : do_random_noise(image, mag=0.30),
        lambda image : do_random_guassian_blur(image),
        lambda image : do_random_blurout(image, size=0.30, num_cut=2),
        lambda image : do_histogram_norm(image),
        lambda image : image,
    ]
    for transform in np.random.choice(photometric, 4):
        image = transform(image)

    r['image'] = image
    r['mask'] = mask
    return r



def do_valid(net, valid_loader, criterion):
    """Run ``net`` over the validation loader and compute the validation metrics.

    The network is switched to eval mode and is left in that mode. A
    progress line is printed after every batch.

    Args:
        net: model returning ``(logit, mask, mask1)``; only the logit is used.
        valid_loader: loader yielding batches with 'index', 'image' and
            'onehot' entries.
        criterion: not used by this function.

    Returns:
        ``[loss, auc, mapp]``: binary cross-entropy, ROC AUC and the
        per-label average precision multiplied by 1/6.
    """
    valid_probability = []
    valid_truth = []
    valid_num = 0

    net.eval()
    start_timer = timer()
    for t, batch in enumerate(valid_loader):
        batch_size = len(batch['index'])
        image = batch['image'].cuda()
        label = batch['onehot']

        with torch.no_grad():
            logit, mask, mask1 = net(image)
            probability = logit.sigmoid()

        valid_num += batch_size
        valid_probability.append(probability.data.cpu().numpy())
        valid_truth.append(label.data.cpu().numpy())
        print('\r %8d / %d  %s'%(valid_num, len(valid_loader.dataset),time_to_str(timer() - start_timer,'sec')),end='',flush=True)

    # every sample of the dataset must have been scored exactly once
    assert(valid_num == len(valid_loader.dataset))

    truth = np.concatenate(valid_truth)
    probability = np.concatenate(valid_probability)
    # (an argsort into an unused ``predict`` array was computed here; removed)

    loss = np_loss_binary_cross_entropy(probability,truth)

    # NOTE(review): the 1/6 factor presumably mirrors the competition's metric
    # weighting - confirm before comparing against other scores.
    mapp  = np_metric_map_curve_by_class(truth, probability)*(1/6)
    auc = np_metric_roc_auc(probability, truth)

    return [loss, auc, mapp]



# start here ! ###################################################################################


def run_train():
    """Train one model per cross-validation fold (folds 0-4).

    For every fold this builds the train/valid loaders, creates the network
    (``AmpNet`` when ``is_mixed_precision`` is set, ``Net`` otherwise),
    optionally resumes from ``initial_checkpoint``, and optimises with RAdam.
    The objective is ``0.5 * BCE(logit, onehot) + 0.5 * BCE(logit, onehotns)``
    plus a Lovasz loss on each of the two auxiliary mask outputs.

    Every ``iter_valid`` iterations the model is validated; when AUC or MAP
    improves at an iteration listed in ``iter_save`` (and not the very first
    iteration) the weights are written to ``best_model_auc.pth`` /
    ``best_model_map.pth`` under ``<out_dir>/checkpoint``.

    Notes:
        * The inner loop over ``train_loader`` has no early exit, so the epoch
          in progress is always completed even after ``iteration`` reaches
          ``num_iteration``.
        * The running training loss shown in the log is refreshed every 100
          iterations, independently of ``iter_log``.
    """
    for fold in [0,1,2,3,4]:
        out_dir = output_dir+'result/eb6ns_512_binary_noisy/fold%d'%fold
        #initial_checkpoint = output_dir+'result/eb6ns_512_binary_noisy/fold%d/checkpoint/best_model_map.pth'%fold
        initial_checkpoint = None

        # Load the checkpoint a single time and keep it under its own name so
        # that later loop variables cannot overwrite it.
        checkpoint = None
        start_iteration = 0
        start_epoch = 0
        if initial_checkpoint is not None:
            checkpoint = torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)
            start_iteration = checkpoint['iteration']
            start_epoch = checkpoint['epoch']

        start_lr   = 0.0001
        batch_size = 16

        num_iteration = start_iteration + 50
        iter_log    = 50
        iter_valid  = 50
        iter_save   = list(range(0, num_iteration+1, 50))

        ## setup  ----------------------------------------
        for sub_dir in ['checkpoint', 'train', 'valid', 'backup']:
            os.makedirs(out_dir + '/' + sub_dir, exist_ok=True)

        log = Logger()
        log.open(out_dir + '/log.train.txt', mode='a')
        log.write('\n--- [START %s] %s\n\n' % (IDENTIFIER, '-' * 64))
        log.write('\t%s\n' % COMMON_STRING)
        log.write('\tout_dir  = %s\n' % out_dir)
        log.write('\n')

        ## dataset ------------------------------------
        df_train, df_valid = make_fold('train-%d'%fold)
        train_dataset = SiimDataset(df_train, train_augment)
        valid_dataset = SiimDataset(df_valid, )

        train_loader = DataLoader(
            train_dataset,
            sampler = RandomSampler(train_dataset),
            batch_size = batch_size,
            drop_last   = True,
            num_workers = 4,
            pin_memory  = True,
            # Give every worker its own numpy seed derived from torch's seed.
            worker_init_fn=lambda worker_id: np.random.seed(torch.initial_seed() // 2 ** 32 + worker_id),
            collate_fn  = null_collate,
        )
        valid_loader  = DataLoader(
            valid_dataset,
            sampler = SequentialSampler(valid_dataset),
            batch_size  = batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
            collate_fn  = null_collate,
        )

        log.write('train_dataset : \n%s\n'%(train_dataset))
        log.write('valid_dataset : \n%s\n'%(valid_dataset))
        log.write('\n')


        ## net ----------------------------------------
        log.write('** net setting **\n')
        if is_mixed_precision:
            scaler = amp.GradScaler()
            net = AmpNet().cuda()
        else:
            net = Net().cuda()

        if checkpoint is not None:
            net.load_state_dict(checkpoint['state_dict'],strict=True)  #True
            checkpoint = None  # release the CPU copy of the weights

        log.write('net=%s\n'%(type(net)))
        log.write('\tinitial_checkpoint = %s\n' % initial_checkpoint)
        log.write('\n')

        criterion = nn.BCEWithLogitsLoss()
        lovasz_loss=BinaryLovaszLoss()

        if torch.cuda.device_count() > 1:
            log.write("Let's use %d GPUs! \n" % (torch.cuda.device_count()))
            net = nn.DataParallel(net)


        optimizer = RAdam(filter(lambda p: p.requires_grad, net.parameters()),lr=start_lr)
        #optimizer = MADGRAD( filter(lambda p: p.requires_grad, net.parameters()), lr=start_lr, momentum= 0.9, weight_decay= 1e-06, eps= 1e-06)

        log.write('optimizer\n  %s\n'%(optimizer))
        log.write('\n')


        ## start training here! ##############################################
        log.write('** start training here! **\n')
        log.write('   fold = %d\n'%(fold))
        log.write('   is_mixed_precision = %s \n'%str(is_mixed_precision))
        log.write('   batch_size = %d\n'%(batch_size))
        log.write('                             |----- VALID --------|---- TRAIN/BATCH --------------\n')
        log.write('rate        iter    epoch    | loss    AUC    MAP | loss0  loss1  loss2 | time          \n')
        log.write('----------------------------------------------------------------------\n')
                  #0.00000   0.00* 0.00  | 0.000  0.000  | 0.000  0.000  |  0 hr 00 min

        def message(mode='print'):
            # Formats one status line from the enclosing loop state.
            # 'print' shows the latest batch loss, 'log' the running average.
            if mode==('print'):
                asterisk = ' '
                loss = batch_loss
            if mode==('log'):
                asterisk = '*' if iteration in iter_save else ' '
                loss = train_loss

            text = \
                '%0.7f  %5.4f%s %4.2f  | '%(rate, iteration/10000, asterisk, epoch,) +\
                '%4.4f  %4.5f  %4.5f  | '%(*valid_loss,) +\
                '%4.3f  %4.3f  %4.3f  | '%(*loss,) +\
                '%s' % (time_to_str(timer() - start_timer,'min'))

            return text

        def save_best(file_name):
            # Persist the current weights; only at save iterations and never at
            # the iteration training started from.
            if iteration in iter_save and iteration != start_iteration:
                torch.save({
                    'state_dict': net.state_dict(),
                    'iteration': iteration,
                    'epoch': epoch,
                }, out_dir + '/checkpoint/' + file_name)

        def forward_losses(image, label, labelns, truth_mask, truth_mask1):
            # Forward pass plus the individual loss terms; shared by the
            # mixed-precision and full-precision update paths.
            logit, mask, mask1 = net(image)
            loss0 = criterion(logit, label)
            loss1 = lovasz_loss(mask, truth_mask)
            loss2 = lovasz_loss(mask1, truth_mask1)
            distill_loss = F.binary_cross_entropy_with_logits(logit, labelns)
            cls_loss = loss0 * (1 - 0.5) + distill_loss * 0.5
            return loss0, loss1, loss2, cls_loss

        #----
        valid_loss = np.zeros(3,np.float32)
        train_loss = np.zeros(3,np.float32)
        batch_loss = np.zeros_like(train_loss)
        sum_train_loss = np.zeros_like(train_loss)
        sum_train = 0

        start_timer = timer()
        iteration = start_iteration
        epoch = start_epoch
        rate = 0
        auc_metric = 0
        map_metric = 0
        while  iteration < num_iteration:

            for t, batch in enumerate(train_loader):

                if (iteration % iter_valid == 0):
                    valid_loss = do_valid(net, valid_loader, criterion)  #
                    auc_val = valid_loss[1]
                    map_val = valid_loss[2]
                    if auc_val >  auc_metric:
                        auc_metric = auc_val
                        save_best('best_model_auc.pth')
                    if map_val >  map_metric:
                        map_metric = map_val
                        save_best('best_model_map.pth')

                if (iteration % iter_log == 0):
                    print('\r', end='', flush=True)
                    log.write(message(mode='log') + '\n')

                rate = optimizer.param_groups[0]["lr"]

                # one iteration update  -------------
                image = batch['image'].cuda()
                truthmask = batch['mask'].cuda()
                truth_mask = F.interpolate(truthmask, size=(mask_size,mask_size), mode='bilinear', align_corners=False)
                truth_mask1 = F.interpolate(truthmask, size=(mask_size1,mask_size1), mode='bilinear', align_corners=False)
                label = batch['onehot'].cuda()
                labelns = batch['onehotns'].cuda()

                #----
                net.train()  # do_valid() leaves the network in eval mode
                optimizer.zero_grad()

                if is_mixed_precision:
                    with amp.autocast():
                        loss0, loss1, loss2, cls_loss = forward_losses(
                            image, label, labelns, truth_mask, truth_mask1)

                    scaler.scale(cls_loss+loss1+loss2).backward()
                    scaler.unscale_(optimizer)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    # Full-precision path: without it no parameter update would
                    # ever happen when mixed precision is switched off.
                    loss0, loss1, loss2, cls_loss = forward_losses(
                        image, label, labelns, truth_mask, truth_mask1)

                    (cls_loss+loss1+loss2).backward()
                    optimizer.step()

                # print statistics  --------
                epoch += 1 / len(train_loader)
                iteration += 1

                batch_loss = np.array([loss0.item(), loss1.item(), loss2.item()])
                sum_train_loss += batch_loss
                sum_train += 1
                if iteration % 100 == 0:
                    train_loss = sum_train_loss / (sum_train + 1e-12)
                    sum_train_loss[...] = 0
                    sum_train = 0

                print('\r', end='', flush=True)
                print(message(mode='print'), end='', flush=True)
        log.write('\n')


In [None]:
# Entry point: start the per-fold training loop when this cell/file is executed directly.
if __name__ == '__main__':
    run_train()