In [None]:
# Debug: list the checkpoint dataset (run_train's initial_checkpoint is loaded from this folder).
!ls ../input/effb5-mishfpn

In [None]:
# Debug: list ../usr/lib (presumably where the utility-script modules imported below live — verify).
!ls ../usr/lib

In [None]:
import numpy as np
import pandas as pd
import os
import sys

import random 
from timeit import default_timer as timer
import cv2
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, Sampler
import torch.utils.data as data
import torchvision.models as models
import torch.nn as nn
from torch.nn import functional as F
import torch

from mishefficientnet_hengs import EfficientNetB5
from mishefficientnet_hengs import CONVERSION
from misheffnet_b5utility  import *
# from heng_s_utility_functions import *

PI = np.pi  # not referenced by the cells shown in this notebook
# Per-channel RGB mean/std (the values commonly used for ImageNet normalisation).
# NOTE(review): null_collate0 only divides by 255, so these are unused here.
IMAGE_RGB_MEAN = [0.485, 0.456, 0.406]
IMAGE_RGB_STD  = [0.229, 0.224, 0.225]
# One colour per mask label 0..4; presumably for visualising masks — unused in the cells shown.
DEFECT_COLOR = [(0,0,0),(0,0,255),(0,255,0),(255,0,0),(0,255,255)]
SEED = 69  # global seed: default for seed_everything() and written to the training log

In [None]:
def seed_everything(seed=SEED):
    """Seed every RNG the training pipeline touches so runs are repeatable.

    Seeds Python's `random`, NumPy's global RNG and torch (CPU and every
    visible CUDA device), sets PYTHONHASHSEED, and forces cuDNN into
    deterministic, non-autotuned mode.

    Args:
        seed: integer seed; defaults to the notebook-wide SEED constant.
    """
    random.seed(seed)
    # Only affects Python interpreters started after this point; the hash
    # seed of the already-running process is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning picks kernels by timing, which can differ between runs;
    # pin it off explicitly in case an imported module enabled it.
    torch.backends.cudnn.benchmark = False

seed_everything()

In [None]:
# Folder holding the split files: .npy arrays of 'folder/image_id' strings loaded by SteelDataset.
SPLIT_DIR = '../input/hengs-split'
# Data root: SteelDataset reads the csv files from here and images from DATA_DIR/<folder>/<image_id>.
DATA_DIR = '../input/severstal-mine'

In [None]:
class Logger(object):
    """Tee-style writer: echoes messages to stdout and, once `open()` is called, to a file.

    Until a file is attached (or after `close()`), writes go to the terminal only.
    """

    def __init__(self):
        self.terminal = sys.stdout  # captured at construction time
        self.file = None

    def open(self, file, mode=None):
        """Attach log file `file`, opened with `mode` (default 'w'; use 'a' to append)."""
        if mode is None:
            mode = 'w'
        self.close()  # don't leak a previously attached handle
        self.file = open(file, mode)

    def write(self, message, is_terminal=1, is_file=1 ):
        """Write `message` (no newline added) to the terminal and/or the log file.

        Args:
            message: text to emit.
            is_terminal: 1 to echo to the terminal.
            is_file: 1 to append to the log file. Messages containing a carriage
                return (in-place progress lines) are never written to the file.
        """
        if '\r' in message:
            is_file = 0

        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()

        # No file attached yet: skip it instead of failing on None.write.
        if is_file == 1 and self.file is not None:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        """No-op, so the object can stand in for a stream (write() already flushes)."""
        pass

    def close(self):
        """Close the attached log file, if any. Safe to call repeatedly."""
        if self.file is not None:
            self.file.close()
            self.file = None


In [None]:
class SteelDataset(Dataset):
    """Severstal steel-defect dataset: one item per image, four defect classes.

    Each item is `(image, label, mask, infor)`:
      * image - cv2.imread(DATA_DIR/<folder>/<image_id>, IMREAD_COLOR)
      * label - list of 4 ints, 1 where that class has a non-empty RLE
      * mask  - (256, 1600) array of class ids 0..4; where RLEs overlap the
                highest class id wins (element-wise max)
      * infor - Struct(index, folder, image_id)
    If `augment` is set, it receives those four values and its result is returned.

    NOTE: the RLE lookup used by __getitem__ is built from `self.df` in
    __init__; rebuild it with `_build_rle_lookup` if `self.df` is replaced.
    """

    def __init__(self, split, csv, mode, augment=None):
        """
        Args:
            split: list of .npy file names under SPLIT_DIR holding 'folder/image_id' strings.
            csv: list of csv file names under DATA_DIR with ImageId_ClassId / EncodedPixels columns.
            mode: tag; 'train' adds per-class statistics to __str__.
            augment: optional callable(image, label, mask, infor) -> same 4-tuple.
        """
        self.split   = split
        self.csv     = csv
        self.mode    = mode
        self.augment = augment

        self.uid = list(np.concatenate([np.load(SPLIT_DIR + '/%s'%f , allow_pickle=True) for f in split]))
        df = pd.concat([pd.read_csv(DATA_DIR + '/%s'%f).fillna('') for f in csv])

        # ImageId_ClassId is '<image_id>_<class>', so the last character is the class 1..4.
        df['Class'] = df['ImageId_ClassId'].str[-1].astype(np.int32)
        df['Label'] = (df['EncodedPixels']!='').astype(np.int32)
        df = df_loc_by_list(df, 'ImageId_ClassId', [ u.split('/')[-1] + '_%d'%c  for u in self.uid for c in [1,2,3,4] ])
        self.df = df
        self.num_image = len(df)//4

        # Constant-time lookup for __getitem__: filtering the whole frame four
        # times per item made every epoch quadratic in the dataset size.
        self._rle_lookup = self._build_rle_lookup(self.df)

    @staticmethod
    def _build_rle_lookup(df):
        """Map ImageId_ClassId -> EncodedPixels; for duplicated keys keep the first row."""
        lookup = {}
        for key, rle in zip(df['ImageId_ClassId'].values, df['EncodedPixels'].values):
            lookup.setdefault(key, rle)
        return lookup

    def __str__(self):
        num = len(self)

        string  = ''
        string += '\tmode    = %s\n'%self.mode
        string += '\tsplit   = %s\n'%self.split
        string += '\tcsv     = %s\n'%str(self.csv)
        string += '\tnum_image = %8d\n'%self.num_image
        string += '\tlen       = %8d\n'%len(self)
        if self.mode == 'train':
            for c in [1, 2, 3, 4]:
                is_class = self.df['Class'] == c
                pos = (is_class & (self.df['Label'] == 1)).sum()
                neg = is_class.sum() - pos
                # Ratios are relative to the number of images in the split.
                string += '\t\tpos%d, neg%d = %5d  %0.3f,  %5d  %0.3f\n'%(c, c, pos, pos/num, neg, neg/num)
        return string

    def __len__(self):
        return len(self.uid)

    def __getitem__(self, index):
        folder, image_id = self.uid[index].split('/')

        rle = [self._rle_lookup[image_id + '_%d' % c] for c in [1, 2, 3, 4]]
        image = cv2.imread(DATA_DIR + '/%s/%s'%(folder,image_id), cv2.IMREAD_COLOR)
        label = [ 0 if r=='' else 1 for r in rle]
        # One decoded mask per class, filled with its class id; the max merges them.
        mask  = np.array([run_length_decode(r, height=256, width=1600, fill_value=c) for c,r in zip([1,2,3,4],rle)])
        mask  = mask.max(0, keepdims=0)

        infor = Struct(
            index    = index,
            folder   = folder,
            image_id = image_id,
        )

        if self.augment is None:
            return image, label, mask, infor
        else:
            return self.augment(image, label, mask, infor)

In [None]:
class FiveBalanceClassSampler(Sampler):
    """Balances "no defect" against each of the four defect classes.

    Every epoch draws `num_image` indices with replacement from each of five
    pools: images with no defect at all, and images positive for class 1..4.
    The draws are interleaved as neg, c1, c2, c3, c4, neg, c1, ..., so each
    consecutive group of five touches every pool once. One epoch is
    5 * num_image indices long, and an image may appear in several pools.
    """

    def __init__(self, dataset):
        self.dataset = dataset

        # Assumes dataset.df has 4 consecutive rows (classes 1..4) per image.
        per_image = self.dataset.df['Label'].values.reshape(-1, 4)
        has_no_defect = per_image.sum(axis=1) == 0

        self.neg_index = np.flatnonzero(has_no_defect)
        self.pos1_index, self.pos2_index, self.pos3_index, self.pos4_index = (
            np.flatnonzero(per_image[:, k]) for k in range(4)
        )

        self.num_image = len(self.dataset.df) // 4
        self.length = 5 * self.num_image

    def __iter__(self):
        pools = (self.neg_index, self.pos1_index, self.pos2_index,
                 self.pos3_index, self.pos4_index)
        # One draw per pool, always in this order.
        draws = [np.random.choice(pool, self.num_image, replace=True) for pool in pools]
        # (num_image, 5) flattened row-major -> interleaved neg/c1/c2/c3/c4.
        return iter(np.column_stack(draws).ravel())

    def __len__(self):
        return self.length

In [None]:
# Class which is used by the infor object in __getitem__
class Struct(object):
    """Simple attribute bag: Struct(a=1, b=2).a == 1.

    With is_copy=True each value is deep-copied where possible; values that
    cannot be deep-copied are stored by reference.
    """

    def __init__(self, is_copy=False, **kwargs):
        self.add(is_copy, **kwargs)

    def add(self, is_copy=False, **kwargs):
        """Set each keyword argument as an attribute, deep-copying it if is_copy is true."""
        # Imported locally because `copy` is not imported explicitly anywhere in this
        # notebook. The resulting NameError used to be swallowed by the except below,
        # so is_copy silently stored references.
        import copy

        for key, value in kwargs.items():
            if is_copy:
                try:
                    value = copy.deepcopy(value)
                except Exception:
                    # Best effort: keep the original reference for uncopyable objects.
                    pass
            setattr(self, key, value)

    def __str__(self):
        return ''.join('\t%s : %s\n' % (k, str(v)) for k, v in self.__dict__.items())
    
# Creating masks
def run_length_decode(rle, height=256, width=1600, fill_value=1):
    """Decode a run-length-encoded string into a (height, width) float32 mask.

    `rle` is a whitespace-separated list of "start length" pairs. Starts are
    treated as 1-based. Pixels are numbered down each column, then across
    columns, which is why the flat buffer is reshaped to (width, height) and
    transposed.

    Args:
        rle: encoded string; '' (or whitespace only) gives an all-zero mask.
        height, width: output size.
        fill_value: value written into every encoded pixel.

    Returns:
        np.float32 array of shape (height, width).
    """
    mask = np.zeros((height, width), np.float32)
    # split() without an argument tolerates repeated/trailing whitespace,
    # where split(' ') produced '' tokens that broke int().
    tokens = rle.split()
    if not tokens:
        return mask

    runs = np.array([int(tok) for tok in tokens]).reshape(-1, 2)
    flat = mask.reshape(-1)
    for start, length in runs:
        start = start - 1  # 1-based -> 0-based
        flat[start:(start + length)] = fill_value
    return flat.reshape(width, height).T

In [None]:
def null_collate0(batch):
    """Collate a list of (image, label, mask, infor) samples into tensors.

    Returns:
        input:       float32 tensor (B, C, H, W), pixel values divided by 255
                     (assumes each image is (H, W, C))
        truth_label: float tensor stacked from the per-sample labels
        truth_mask:  int64 tensor (B, 1, H, W)
        infor:       list of the per-sample infor objects, unchanged
    """
    images, labels, masks, infor = [], [], [], []
    for sample in batch:
        images.append(sample[0])
        labels.append(sample[1])
        masks.append(sample[2])
        infor.append(sample[3])

    # (B, H, W, C) -> (B, C, H, W), scaled to [0, 1]
    pixels = (np.stack(images).astype(np.float32) / 255).transpose(0, 3, 1, 2)

    image_tensor = torch.from_numpy(pixels).float()
    label_tensor = torch.from_numpy(np.stack(labels)).float()
    mask_tensor = torch.from_numpy(np.stack(masks)).long().unsqueeze(1)
    return image_tensor, label_tensor, mask_tensor, infor

def null_collate(batch):
    """Collate samples and derive label and attention targets from the mask.

    Returns (input, truth_label, truth_mask, truth_attention, infor), where
      * truth_label (B, 4) is recomputed from the (possibly cropped/augmented)
        mask: class c is positive when more than one mask pixel equals c;
      * truth_attention (B, 4, H//32, W//32) is 1 for every 32x32 cell that
        contains at least one pixel of that class.
    """
    input, truth_label, truth_mask, infor = null_collate0(batch)
    with torch.no_grad():
        class_ids = torch.arange(1, 5, device=truth_mask.device).view(1, 4, 1, 1)
        per_class = (truth_mask.repeat(1, 4, 1, 1) == class_ids).float()

        cell_fraction = F.avg_pool2d(per_class, kernel_size=(32, 32), stride=(32, 32))
        truth_attention = (cell_fraction > 0).float()

        # Relabel from the mask so labels stay consistent after cropping etc.
        pixel_count = per_class.sum(dim=[2, 3])
        truth_label = (pixel_count > 1).float()

    return input, truth_label, truth_mask, truth_attention, infor

In [None]:
def train_augment1(image, label, mask, infor):
    """Random training augmentation applied jointly to image and mask.

    Steps, each drawing from NumPy's global RNG in this order:
      1. with probability 1/3 each: nothing, random crop+rescale, or random
         crop+rotate+rescale;
      2. do_random_crop(image, mask, 400, 256);
      3. cutout with probability 0.75;
      4. left-right flip, then up-down flip, each with probability 0.5;
      5. log-contrast jitter (gain=[0.50, 1.75]) with probability 0.5;
      6. additive noise (noise=8) with probability 0.5.
    `label` and `infor` pass through untouched.
    """
    geometric = np.random.choice(3)
    if geometric == 1:
        image, mask = do_random_crop_rescale(image, mask, 1600 - (256 - 180), 180)
    elif geometric == 2:
        image, mask = do_random_crop_rotate_rescale(image, mask, 1600 - (256 - 200), 200)
    # geometric == 0: keep the image as-is

    image, mask = do_random_crop(image, mask, 400, 256)

    if np.random.rand() > 0.25:
        image, mask = do_random_cutout(image, mask)

    if np.random.rand() > 0.5:
        image, mask = do_flip_lr(image, mask)
    if np.random.rand() > 0.5:
        image, mask = do_flip_ud(image, mask)

    if np.random.rand() > 0.5:
        image = do_random_log_contast(image, gain=[0.50, 1.75])

    add_noise = np.random.choice(2) == 1
    if add_noise:
        image = do_random_noise(image, noise=8)

    return image, label, mask, infor

# Learning Rate Schedule
class NullScheduler:
    """Constant learning-rate schedule: the same `lr` for every iteration."""

    def __init__(self, lr=0.01 ):
        super().__init__()
        self.lr = lr
        self.cycle = 0  # not read anywhere in this notebook

    def __call__(self, time):
        """Return the learning rate for iteration `time` (ignored; always self.lr)."""
        return self.lr

    def __str__(self):
        return 'NullScheduler\n' + 'lr=%0.5f ' % self.lr

In [None]:
BatchNorm2d = nn.BatchNorm2d

PRETRAIN_FILE = '../input/efficientnet-pytorch-b0-b7/efficientnet-b5-b6417697.pth'
def load_pretrain(net, skip=(), pretrain_file=PRETRAIN_FILE, conversion=CONVERSION, is_print=True):
    """Copy pretrained weights (EfficientNet-B5 by default) into `net` via a key mapping.

    `conversion` is a flat sequence read as rows of 4:
    (net_key, _, pretrain_key, _). Only the 1st and 3rd fields are used. Each
    net_key in net.state_dict() receives pretrain_state_dict[pretrain_key].

    Args:
        net: module whose state_dict contains every net_key that is not skipped.
        skip: iterable of substrings; rows whose net_key contains any of them
            are not copied ('.num_batches_tracked' is always skipped).
        pretrain_file: checkpoint path; tensors are loaded onto the CPU.
        conversion: key mapping (defaults to the module-level CONVERSION).
        is_print: print one line per copied tensor with both shapes.
    """
    print('\tload pretrain_file: %s'%pretrain_file)

    # Returning `storage` unchanged deserialises every tensor onto the CPU.
    pretrain_state_dict = torch.load(pretrain_file, map_location=lambda storage, loc: storage)
    state_dict = net.state_dict()
    skip_tokens = ['.num_batches_tracked'] + list(skip)

    i = 0
    # Use the `conversion` argument; the global CONVERSION is only its default.
    for key, _, pretrain_key, _ in np.array(conversion).reshape(-1, 4):
        if any(s in key for s in skip_tokens):
            continue

        if is_print:
            print('\t\t','%-48s  %-24s  <---  %-32s  %-24s'%(
                key, str(state_dict[key].shape),
                pretrain_key, str(pretrain_state_dict[pretrain_key].shape),
            ))
        state_dict[key] = pretrain_state_dict[pretrain_key]
        i = i+1

    net.load_state_dict(state_dict)
    print('')
    print('len(pretrain_state_dict.keys()) = %d'%len(pretrain_state_dict.keys()))
    print('len(state_dict.keys())          = %d'%len(state_dict.keys()))
    print('loaded    = %d'%i)
    print('')

class ConvGnUp2d(nn.Module):
    """Conv (no bias) -> GroupNorm -> ReLU -> 2x bilinear upsample.

    The attribute names `conv` / `gn` form the state_dict keys, so saved
    checkpoints depend on them.
    """

    def __init__(self, in_channel, out_channel, num_group=32, kernel_size=3, padding=1, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channel, out_channel,
            kernel_size=kernel_size, padding=padding, stride=stride, bias=False,
        )
        self.gn = nn.GroupNorm(num_group, out_channel)

    def forward(self,x):
        activated = F.relu(self.gn(self.conv(x)), inplace=True)
        return F.interpolate(activated, scale_factor=2, mode='bilinear', align_corners=False)


def upsize_add(x, lateral):
    """Resize `x` (nearest neighbour) to `lateral`'s spatial size and add `lateral`.

    This is the top-down merge step used in Net.forward.
    """
    target_hw = lateral.shape[2:]
    resized = F.interpolate(x, size=target_hw, mode='nearest')
    return resized + lateral

def upsize(x, scale_factor=2):
    """Upsample `x` by `scale_factor` using nearest-neighbour interpolation."""
    return F.interpolate(x, scale_factor=scale_factor, mode='nearest')

'''
Reference output of model.py's main function for a batch of 10 images:
the feature-map shape after each EfficientNet-B5 backbone stage.

stem   torch.Size([10, 48, 128, 128])
block1 torch.Size([10, 24, 128, 128])

block2 torch.Size([10, 40, 64, 64])

block3 torch.Size([10, 64, 32, 32])

block4 torch.Size([10, 128, 16, 16])
block5 torch.Size([10, 176, 16, 16])

block6 torch.Size([10, 304, 8, 8])
block7 torch.Size([10, 512, 8, 8])
last   torch.Size([10, 2048, 8, 8])

success!
'''


class Net(nn.Module):
    """EfficientNet-B5 encoder with an FPN-style decoder producing per-pixel logits.

    forward(x) returns logits with `num_class + 1` channels. criterion_mask and
    the label metrics treat channel 0 as background and 1..num_class as the
    defect classes. The decoder projects four backbone taps (block2, block3,
    block5, last) to 64 channels with 1x1 lateral convs and merges them
    top-down with upsize_add. It brings three merged levels to a common
    resolution with ConvGnUp2d stacks, concatenates them and fuses with a
    3x3 conv, then applies a 1x1 classifier and a final 2x bilinear upsample.

    NOTE(review): attribute names and the order in which modules are created
    determine the state_dict keys (checkpoint compatibility) and the seeded
    initialisation, so do not rename or reorder them.
    """
    def load_pretrain(self, skip=['logit.'], is_print=True):
        # Delegates to the module-level load_pretrain(); keys containing any `skip`
        # substring keep their current values.
        # NOTE(review): the head here is `logit_mask`, which does not contain 'logit.',
        # so the default skip only matters if CONVERSION maps such keys.
        load_pretrain(self, skip, pretrain_file=PRETRAIN_FILE, conversion=CONVERSION, is_print=is_print)



    def __init__(self, num_class=4, drop_connect_rate=0.2):
        super(Net, self).__init__()

        # Borrow the backbone stages; the EfficientNetB5 wrapper itself is then discarded.
        e = EfficientNetB5(drop_connect_rate)
        self.stem   = e.stem
        self.block1 = e.block1
        self.block2 = e.block2
        self.block3 = e.block3
        self.block4 = e.block4
        self.block5 = e.block5
        self.block6 = e.block6
        self.block7 = e.block7
        self.last   = e.last
        e = None  #dropped

        #---
        # 1x1 lateral projections to 64 channels; the input widths must match the tapped
        # stages used in forward(): last (2048), block5 (176), block3 (64), block2 (40).
        self.lateral0 = nn.Conv2d(2048, 64,  kernel_size=1, padding=0, stride=1)
        self.lateral1 = nn.Conv2d( 176, 64,  kernel_size=1, padding=0, stride=1)
        self.lateral2 = nn.Conv2d(  64, 64,  kernel_size=1, padding=0, stride=1)
        self.lateral3 = nn.Conv2d(  40, 64,  kernel_size=1, padding=0, stride=1)

        # Each ConvGnUp2d doubles the resolution: 3, 2 and 1 stages bring the three
        # merged levels to the same size before concatenation.
        self.top1 = nn.Sequential(
            ConvGnUp2d( 64, 64),
            ConvGnUp2d( 64, 64),
            ConvGnUp2d( 64, 64),
        )
        self.top2 = nn.Sequential(
            ConvGnUp2d( 64, 64),
            ConvGnUp2d( 64, 64),
        )
        self.top3 = nn.Sequential(
            ConvGnUp2d( 64, 64),
        )
        # Fuse the concatenated 3 x 64 channels back to 64.
        self.top4 = nn.Sequential(
            nn.Conv2d(64*3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # num_class defect classes + 1 background channel.
        self.logit_mask = nn.Conv2d(64,num_class+1,kernel_size=1)






    def forward(self, x):
        # Unpacked but unused (it does require a 4-D input).
        batch_size,C,H,W = x.shape

        # Backbone; x0 is kept but unused, x1..x4 feed the decoder.
        x = self.stem(x)            #; print('stem  ',x.shape)
        x = self.block1(x)    ;x0=x #; print('block1',x.shape)
        x = self.block2(x)    ;x1=x #; print('block2',x.shape)
        x = self.block3(x)    ;x2=x #; print('block3',x.shape)
        x = self.block4(x)          #; print('block4',x.shape)
        x = self.block5(x)    ;x3=x #; print('block5',x.shape)
        x = self.block6(x)          #; print('block6',x.shape)
        x = self.block7(x)          #; print('block7',x.shape)
        x = self.last(x)      ;x4=x #; print('last  ',x.shape)

        # segment (spatial sizes in the comments assume a 256x256 input)
        t0 = self.lateral0(x4)
        t1 = upsize_add(t0, self.lateral1(x3)) #16x16
        t2 = upsize_add(t1, self.lateral2(x2)) #32x32
        t3 = upsize_add(t2, self.lateral3(x1)) #64x64

        t1 = self.top1(t1) #128x128
        t2 = self.top2(t2) #128x128
        t3 = self.top3(t3) #128x128

        t = torch.cat([t1,t2,t3],1)
        t = self.top4(t)
        logit_mask = self.logit_mask(t)
        # Final 2x upsample; presumably undoes the stem's stride-2 downsampling to
        # return to input resolution — confirm for non-256 inputs.
        logit_mask = F.interpolate(logit_mask, scale_factor=2.0, mode='bilinear', align_corners=False)

        return logit_mask

In [None]:
# METRICS
# use topk
# def criterion_label(logit, truth, weight=None):
#     batch_size,num_class,H,W = logit.shape
#     K=5
#
#     logit = logit.view(batch_size,num_class,-1)
#     value, index = logit.topk(K)
#
#     logit_k = torch.gather(logit,dim=2,index=index)
#     truth_k = truth.view(batch_size,num_class,1).repeat(1,1,5)
#
#
#     if weight is None: weight=[1,1,1,1]
#     weight = torch.FloatTensor(weight).to(truth.device).view(1,-1,1)
#
#
#     loss = F.binary_cross_entropy_with_logits(logit_k, truth_k, reduction='none')
#     #https://arxiv.org/pdf/1909.07829.pdf
#     if 1:
#         gamma=2.0
#         p = torch.sigmoid(logit_k)
#         focal = (truth_k*(1-p) + (1-truth_k)*(p))**gamma
#         weight = weight*focal /focal.sum().item()
#
#     loss = loss*weight
#     loss = loss.mean()
#     return loss


#use top only
# def criterion_label(logit, truth, weight=None):
#     batch_size,num_class,H,W = logit.shape
#     logit = F.adaptive_max_pool2d(logit,1).view(-1,4)
#     truth = truth.view(-1,4)
#
#     if weight is None: weight=[1,1,1,1]
#     weight = torch.FloatTensor(weight).to(truth.device).view(1,-1)
#
#     loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none')
#     loss = loss*weight
#     loss = loss.mean()
#     return loss




#https://discuss.pytorch.org/t/numerical-stability-of-bcewithlogitsloss/8246
def criterion_attention(logit, truth, weight=None, focal_gamma=None):
    """Per-class weighted binary cross-entropy on (B, C, H, W) logits.

    Args:
        logit: raw scores, shape (B, C, H, W).
        truth: targets in [0, 1], same shape as `logit`.
        weight: optional per-class weights of length C; defaults to all ones
            for whatever C the logits have.
        focal_gamma: optional focal exponent (https://arxiv.org/pdf/1909.07829.pdf).
            When set, each element's weight is multiplied by
            (truth*(1-p) + (1-truth)*p)**gamma / sum(...) * H*W, where the sum
            runs over the whole batch. None (the default) gives plain weighted BCE.

    Returns:
        Scalar mean loss.
    """
    batch_size, num_class, H, W = logit.shape

    if weight is None:
        weight = [1] * num_class
    weight = torch.tensor(weight, dtype=torch.float32, device=truth.device).view(1, -1, 1, 1)

    # binary_cross_entropy_with_logits is the numerically stable form (see link above).
    loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none')

    if focal_gamma is not None:
        p = torch.sigmoid(logit)
        focal = (truth * (1 - p) + (1 - truth) * p) ** focal_gamma
        weight = weight * focal / focal.sum().item() * H * W

    return (loss * weight).mean()

#
# def criterion_mask(logit, truth, weight=None):
#     if weight is not None: weight = torch.FloatTensor([1]+weight).cuda()
#     batch_size,num_class,H,W = logit.shape
#
#     logit = logit.permute(0, 2, 3, 1).contiguous().view(batch_size,-1, 5)
#     log_probability = -F.log_softmax(logit,-1)
#
#
#     truth = truth.permute(0, 2, 3, 1).contiguous().view(-1,1)
#     onehot = torch.FloatTensor(batch_size*H*W, 5).to(truth.device)
#     onehot.zero_()
#     onehot.scatter_(1, truth, 1)
#     onehot = onehot.view(batch_size,-1, 5)
#
#     #loss = F.cross_entropy(logit, truth, weight=weight, reduction='none')
#     loss = log_probability*onehot
#
#     loss = loss*weight
#     loss = loss.mean()
#     return loss

#focal loss
def criterion_mask(logit, truth, weight=None):
    """Image-normalised focal cross-entropy for multi-class segmentation.

    Args:
        logit: raw scores, shape (B, C, H, W); channel 0 is background and
            channels 1..C-1 are the defect classes.
        truth: int64 class map, shape (B, 1, H, W), values in [0, C).
        weight: optional per-defect-class weights (length C-1). Background
            always has weight 1. Defaults to all ones.

    Returns:
        Scalar: mean over all B*H*W*C entries of
        -log p[target] * w[target] * focal, with focal = (1 - p[target])**2
        rescaled per image so that it sums to H*W. Only one of the C entries
        per pixel is non-zero, so this equals the per-pixel loss divided by C.
    """
    batch_size, num_class, H, W = logit.shape
    if weight is None:
        weight = [1] * (num_class - 1)
    class_weight = torch.tensor([1] + list(weight), dtype=torch.float32, device=truth.device)

    logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, num_class)   # (B*H*W, C)
    truth = truth.permute(0, 2, 3, 1).contiguous().view(-1, 1)           # (B*H*W, 1)

    log_probability = -F.log_softmax(logit, -1)  # positive NLL for every class
    probability = F.softmax(logit, -1)

    onehot = torch.zeros(batch_size * H * W, num_class, device=truth.device)
    onehot.scatter_(dim=1, index=truth, value=1)
    loss = log_probability * onehot

    # Image-based focusing: down-weight well-classified pixels, then rescale so
    # the focal weights of each image sum to H*W.
    alpha = 2
    p_target = torch.gather(probability.view(batch_size, H * W, num_class), dim=-1,
                            index=truth.view(batch_size, H * W, 1))
    focal = (1 - p_target) ** alpha
    # If every pixel of an image has p_target == 1 the sum is 0 and 0/0 gives NaN.
    focal_sum = focal.sum(dim=[1, 2], keepdim=True).detach().clamp(min=1e-12)
    weight = class_weight.view(1, 1, num_class) * focal / focal_sum * H * W
    weight = weight.view(-1, num_class)

    return (loss * weight).mean()

#----
def logit_mask_to_probability_label(logit):
    """Image-level class scores from segmentation logits.

    Applies softmax over the channel axis, takes each channel's maximum over
    all pixels, and drops channel 0 (background).

    Args:
        logit: (B, C, H, W) logits, any C >= 1.
    Returns:
        (B, C-1) tensor of the maximum per-pixel probability of each non-background class.
    """
    batch_size, num_class = logit.shape[:2]
    probability = F.softmax(logit, 1)
    # Max over the flattened spatial axis. These are the same values as permuting to
    # (B, H*W, C) and reducing dim 1, without the contiguous copy.
    value = probability.reshape(batch_size, num_class, -1).max(dim=2)[0]
    return value[:, 1:]

def metric_label(probability, truth, threshold=0.5):
    """Per-class true-negative / true-positive counts for image-level predictions.

    Args:
        probability: per-image class scores, reshaped to (B, K); K = 4 here,
            but any K works.
        truth: per-image 0/1 labels with the same number of elements.
        threshold: score above which a class counts as predicted positive.

    Returns:
        (tn, tp, num_neg, num_pos): NumPy arrays of length K, summed over the
        batch. tn/tp are float counts; num_neg/num_pos are int32 counts of
        truth == 0 and truth != 0.
    """
    batch_size = len(truth)

    with torch.no_grad():
        probability = probability.reshape(batch_size, -1)
        truth = truth.reshape(batch_size, -1)

        neg_index = (truth == 0).float()
        pos_index = 1 - neg_index

        p = (probability > threshold).float()
        t = (truth > 0.5).float()
        tp = (p * t).sum(0)              # predicted positive and truly positive
        tn = ((1 - p) * (1 - t)).sum(0)  # predicted negative and truly negative

        return (
            tn.cpu().numpy(),
            tp.cpu().numpy(),
            neg_index.sum(0).cpu().numpy().astype(np.int32),
            pos_index.sum(0).cpu().numpy().astype(np.int32),
        )

def truth_to_onehot(truth, num_class=4):
    """One-hot encode a (B, 1, H, W) class map into (B, num_class, H, W) floats.

    Output channel k marks pixels of class k+1; class 0 (background) is all zeros.
    """
    class_ids = torch.arange(1, num_class + 1).view(1, num_class, 1, 1).to(truth.device)
    tiled = truth.repeat(1, num_class, 1, 1)
    return torch.eq(tiled, class_ids).float()

def predict_to_onehot(predict, num_class=4):
    """Split the arg-max prediction into per-class channels carrying the winning value.

    Takes the max over dim 1 of `predict` (e.g. (B, num_class+1, H, W) softmax
    probabilities with channel 0 = background). Output channel k holds the max
    value at pixels where channel k+1 won, and 0 elsewhere. Pixels won by
    channel 0 are zero in every output channel.
    """
    best_value, best_index = torch.max(predict, 1, keepdim=True)
    class_ids = torch.arange(1, num_class + 1).view(1, num_class, 1, 1).to(predict.device)
    winner = (best_index.repeat(1, num_class, 1, 1) == class_ids).float()
    return best_value.repeat(1, num_class, 1, 1) * winner

def metric_mask(logit, truth, threshold=0.5, sum_threshold=100):
    """Per-class segmentation scores summed over the batch.

    The prediction is the arg-max class of softmax(logit), kept where its
    probability exceeds `threshold`. For each image and defect class:
      * class absent from `truth`: scores 1 when fewer than `sum_threshold`
        pixels are predicted for it (accumulated into dn);
      * class present: scores the Dice coefficient 2*|P & T| / (|P| + |T|)
        (accumulated into dp).

    Returns:
        (dn, dp, num_neg, num_pos): NumPy arrays of length 4, holding the summed
        scores and int32 counts of class-absent / class-present images.
    """
    with torch.no_grad():
        truth_onehot = truth_to_onehot(truth)
        predicted = predict_to_onehot(torch.softmax(logit, 1))

        batch_size, num_class, H, W = truth_onehot.shape
        p = (predicted.view(batch_size, num_class, -1) > threshold).float()
        t = (truth_onehot.view(batch_size, num_class, -1) > 0.5).float()

        t_sum = t.sum(-1)
        p_sum = p.sum(-1)
        empty_score = (p_sum < sum_threshold).float()
        dice_score = 2 * (p * t).sum(-1) / ((p + t).sum(-1) + 1e-12)

        is_neg = (t_sum == 0).float()
        is_pos = 1 - is_neg

        dn = (is_neg * empty_score).sum(0).data.cpu().numpy()
        dp = (is_pos * dice_score).sum(0).data.cpu().numpy()
        num_neg = is_neg.sum(0).data.cpu().numpy().astype(np.int32)
        num_pos = is_pos.sum(0).data.cpu().numpy().astype(np.int32)

    return dn, dp, num_neg, num_pos

def probability_mask_to_probability_label(probability):
    """Image-level class scores from a (B, C, H, W) probability mask.

    Returns each channel's maximum over all pixels with channel 0
    (background) dropped, i.e. a (B, C-1) tensor. Works for any channel count C.
    """
    batch_size, num_class = probability.shape[:2]
    per_class_max = probability.reshape(batch_size, num_class, -1).max(dim=2)[0]
    return per_class_max[:, 1:]

In [None]:
def adjust_learning_rate(optimizer, lr):
    """Set learning rate `lr` on every parameter group of `optimizer`."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

def get_learning_rate(optimizer):
    """Return the learning rate of `optimizer`, which must have exactly one param group."""
    rates = [group['lr'] for group in optimizer.param_groups]
    assert(len(rates)==1) #we support only one param_group
    return rates[0]

In [None]:
def do_valid(net, valid_loader, out_dir=None):
    """Evaluate `net` on `valid_loader` and return the averaged metrics.

    Returns a length-17 float32 array:
      [0]      mean criterion_mask loss per image
      [1:5]    label true-negative rate per class (metric_label tn / label negatives)
      [5:9]    label true-positive rate per class (metric_label tp / label positives)
      [9:13]   metric_mask dn / class-absent images
      [13:17]  metric_mask dp (Dice) / class-present images
    An entry is NaN (with a NumPy warning) when its denominator is zero.

    Batches must be in null_collate's 5-tuple format. Requires CUDA.
    `out_dir` is accepted for API compatibility but unused.
    """
    valid_loss = np.zeros(17, np.float32)
    valid_num  = np.zeros_like(valid_loss)

    net.eval()  # loop-invariant: set once, not per batch
    for input, truth_label, truth_mask, truth_attention, infor in valid_loader:
        batch_size = len(infor)
        input = input.cuda()
        truth_label = truth_label.cuda()
        truth_mask  = truth_mask.cuda()

        with torch.no_grad():
            logit_mask = net(input)
            loss = criterion_mask(logit_mask, truth_mask)

            probability_mask  = F.softmax(logit_mask,1)
            probability_label = probability_mask_to_probability_label(probability_mask)
            # Keep label-level and mask-level counts separate. null_collate marks a
            # class positive only with > 1 pixel, while metric_mask counts any pixel,
            # so the two sets of denominators can differ.
            tn, tp, num_neg_label, num_pos_label = metric_label(probability_label, truth_label)
            dn, dp, num_neg_mask, num_pos_mask = metric_mask(logit_mask, truth_mask)

        l = np.array([ loss.item()*batch_size, *tn, *tp, *dn, *dp])
        n = np.array([ batch_size, *num_neg_label, *num_pos_label, *num_neg_mask, *num_pos_mask])
        valid_loss += l
        valid_num  += n

        print('\r %4d/%4d'%(valid_num[0], len(valid_loader.dataset)),end='',flush=True)

    assert(valid_num[0] == len(valid_loader.dataset))
    valid_loss = valid_loss/valid_num

    return valid_loss

In [None]:
# Column layout shared by the log file rows and the live progress line in run_train.
_STATUS_FORMAT = '%0.5f %5.1f%s %5.1f | %5.3f  [%0.2f %0.2f %0.2f %0.2f : %0.2f %0.2f %0.2f %0.2f]  [%0.2f %0.2f %0.2f %0.2f : %0.2f %0.2f %0.2f %0.2f] | %5.3f  [%0.2f : %0.2f %0.2f %0.2f %0.2f] | %s'


def _format_status(rate, iteration, asterisk, epoch, valid_loss, train_loss, start):
    """Format one status row: lr, iteration/1000, epoch, 17 valid stats, 6 train stats, elapsed time."""
    return _STATUS_FORMAT % (
        rate, iteration/1000, asterisk, epoch,
        *valid_loss,
        *train_loss,
        time_to_str((timer() - start),'min'),
    )


def run_train():
    """Fine-tune Net on the Severstal split, logging to ../working/log.train.txt.

    Resumes from `initial_checkpoint` (model weights, loaded with strict=False).
    Iteration/epoch counters come from the matching *_optimizer.pth if it
    exists. Validates every `iter_valid` iterations and saves checkpoints to
    ../working at the iterations in `iter_save`. Stops after `num_iters`
    iterations or when the scheduler returns a negative learning rate.
    Requires CUDA.
    """
    out_dir = ''
    initial_checkpoint = '../input/effb5-mishfpn/00045000_model.pth'

    sampler     = FiveBalanceClassSampler  # alternative: RandomSampler
    # NOTE(review): loss_weight is only logged; criterion_mask below is called without it.
    loss_weight = None  # e.g. [5,10,2,5]

    schduler = NullScheduler(lr=0.001)
    iter_accum = 1
    batch_size = 4  # previously 8

    log = Logger()
    log.open('../working/log.train.txt',mode='a')
    log.write('\n')
    log.write('\tSEED         = %u\n' % SEED)

    log.write('** dataset setting **\n')
    train_dataset = SteelDataset(
        mode    = 'train',
        csv     = ['train.csv',],
        split   = ['train_b1_11568.npy',],
        augment = train_augment1,
    )
    train_loader  = DataLoader(
        train_dataset,
        sampler     = sampler(train_dataset),
        batch_size  = batch_size,
        drop_last   = True,
        num_workers = 4,
        pin_memory  = True,
        collate_fn  = null_collate
    )

    valid_dataset = SteelDataset(
        mode    = 'train',
        csv     = ['train.csv',],
        split   = ['valid_b1_1000.npy',],
        augment = None,
    )
    valid_loader = DataLoader(
        valid_dataset,
        sampler     = SequentialSampler(valid_dataset),
        batch_size  = 4,
        drop_last   = False,
        num_workers = 4,
        pin_memory  = True,
        collate_fn  = null_collate
    )

    assert(len(train_dataset)>=batch_size)
    log.write('batch_size = %d\n'%(batch_size))
    log.write('train_dataset : \n%s\n'%(train_dataset))
    log.write('valid_dataset : \n%s\n'%(valid_dataset))
    log.write('\n')

    log.write('** net setting **\n')
    net = Net().cuda()
    if initial_checkpoint is not None:
        state_dict = torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)
        # strict=False silently tolerates key mismatches; at least record how many there were.
        incompatible = net.load_state_dict(state_dict,strict=False)
        if incompatible is not None:
            log.write('\tinitial_checkpoint = %s (missing keys: %d, unexpected keys: %d)\n' % (
                initial_checkpoint, len(incompatible.missing_keys), len(incompatible.unexpected_keys)))
    else:
        net.load_pretrain(is_print=False)

    log.write('%s\n'%(type(net)))
    log.write('loss_weight=%s\n'%(str(loss_weight)))
    log.write('sampler=%s\n'%(str(sampler)))
    log.write('\n')

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=schduler(0), momentum=0.9, weight_decay=0.0001)

    num_iters   = 3000*1000
    iter_smooth = 50
    iter_log    = 200
    iter_valid  = 200
    # A set, since membership is tested on every iteration.
    iter_save   = set([0, num_iters-1] + list(range(0, num_iters, 2500)))

    start_iter = 0
    start_epoch= 0
    rate       = 0
    if initial_checkpoint is not None:
        initial_optimizer = initial_checkpoint.replace('_model.pth','_optimizer.pth')
        if os.path.exists(initial_optimizer):
            # Only the iteration/epoch counters are restored; the optimizer state
            # itself is not saved by this script.
            checkpoint  = torch.load(initial_optimizer, map_location=lambda storage, loc: storage)
            start_iter  = checkpoint['iter' ]
            start_epoch = checkpoint['epoch']

    log.write('optimizer\n  %s\n'%(optimizer))
    log.write('schduler\n  %s\n'%(schduler))
    log.write('\n')

    log.write('** start training here! **\n')
    log.write('   batch_size=%d,  iter_accum=%d\n'%(batch_size,iter_accum))
    log.write('                     |------------------------------------------- VALID------------------------------------------------|---------------------- TRAIN/BATCH ---------------------\n')
    log.write('rate     iter  epoch |  loss           [tn1,2,3,4  :  tp1,2,3,4]                    [dn1,2,3,4  :  dp1,2,3,4]          |  loss    [tn :  tp1,2,3,4]          | time             \n')
    log.write('--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n')

    valid_loss = np.zeros(17,np.float32)
    train_loss = np.zeros( 6,np.float32)
    batch_loss = np.zeros_like(train_loss)
    iteration  = 0
    i          = 0
    stop       = False

    start = timer()
    while iteration < num_iters and not stop:
        sum_train_loss = np.zeros_like(train_loss)
        sum_train = np.zeros_like(train_loss)

        optimizer.zero_grad()
        for input, truth_label, truth_mask, truth_attention, infor in train_loader:

            batch_size = len(infor)
            iteration  = i + start_iter
            # Stop exactly at num_iters instead of finishing the current epoch.
            if iteration >= num_iters:
                break
            epoch = (iteration-start_iter)*batch_size/len(train_dataset) + start_epoch

            if iteration % iter_valid == 0:
                valid_loss = do_valid(net, valid_loader, out_dir)

            if iteration % iter_log == 0:
                print('\r',end='',flush=True)
                asterisk = '*' if iteration in iter_save else ' '
                log.write(_format_status(rate, iteration, asterisk, epoch, valid_loss, train_loss, start))
                log.write('\n')

            if iteration in iter_save:
                torch.save({
                    'iter'     : iteration,
                    'epoch'    : epoch,
                }, '../working/%08d_optimizer.pth'%(iteration))
                if iteration != start_iter:
                    torch.save(net.state_dict(),'../working/%08d_model.pth'%(iteration))

            # learning rate schduler -------------
            lr = schduler(iteration)
            if lr < 0:
                # A negative lr is the stop signal. Leave both loops; breaking only
                # the inner one would let the while loop start another epoch forever.
                stop = True
                break
            adjust_learning_rate(optimizer, lr)
            rate = get_learning_rate(optimizer)

            # one iteration update  -------------
            net.train()
            input = input.cuda()
            truth_label = truth_label.cuda()
            truth_mask  = truth_mask.cuda()

            logit_mask = net(input)
            loss = criterion_mask(logit_mask, truth_mask)
            probability_mask  = F.softmax(logit_mask,1)
            probability_label = probability_mask_to_probability_label(probability_mask)
            tn,tp, num_neg,num_pos = metric_label(probability_label, truth_label)

            (loss/iter_accum).backward()
            if (iteration % iter_accum)==0:
                optimizer.step()
                optimizer.zero_grad()

            # print statistics  --------
            l = np.array([ loss.item()*batch_size,tn.sum(),*tp ])
            n = np.array([ batch_size, num_neg.sum(),*num_pos ])
            batch_loss      = l/(n+1e-8)
            sum_train_loss += l
            sum_train      += n
            if iteration % iter_smooth == 0:
                train_loss = sum_train_loss/(sum_train+1e-12)
                sum_train_loss[...] = 0
                sum_train[...]      = 0

            print('\r',end='',flush=True)
            print(_format_status(rate, iteration, ' ', epoch, valid_loss, batch_loss, start), end='',flush=True)
            i=i+1

    log.write('\n')
    log.file.close()

In [None]:
# Echo the training-log column header (run_train writes these same two lines to the log file).
print('                     |------------------------------------------- VALID------------------------------------------------|---------------------- TRAIN/BATCH ---------------------\n')
print('rate     iter  epoch |  loss           [tn1,2,3,4  :  tp1,2,3,4]                    [dn1,2,3,4  :  dp1,2,3,4]          |  loss    [tn :  tp1,2,3,4]          | time             \n')

In [None]:
# Start training. This is long-running and needs CUDA: run_train moves the model and batches to the GPU.
run_train()