In [1]:
import sys
import os

# Get necessary Imports
import pdb
import os
import cv2
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Dataset
from albumentations import (Normalize, Compose)
from albumentations.pytorch import ToTensor
import torch.utils.data as data
import torchvision.models as models
import torch.nn as nn
from torch.nn import functional as F

from mishefficientnet_hengs	import *
from misheffnet_b5utility import *

In [2]:
SPLIT_DIR = '../input/hengs-split'
DATA_DIR = '../input/severstal-steel-defect-detection'
CHECKPOINT_FILE = '../input/more-train-mish-inference/00095000_model.pth'
SUBMISSION_CSV_FILE = '../working/submission.csv'

In [3]:
def summarise_submission_csv(df):


    text = ''
    df['Class'] = df['ImageId_ClassId'].str[-1].astype(np.int32)
    df['Label'] = (df['EncodedPixels']!='').astype(np.int32)
    num_image = len(df)//4
    num = len(df)

    pos = (df['Label']==1).sum()
    neg = num-pos


    pos1 = ((df['Class']==1) & (df['Label']==1)).sum()
    pos2 = ((df['Class']==2) & (df['Label']==1)).sum()
    pos3 = ((df['Class']==3) & (df['Label']==1)).sum()
    pos4 = ((df['Class']==4) & (df['Label']==1)).sum()

    neg1 = num_image-pos1
    neg2 = num_image-pos2
    neg3 = num_image-pos3
    neg4 = num_image-pos4


    text += 'compare with LB probing ... \n'
    text += '\t\tnum_image = %5d(1801) \n'%num_image
    text += '\t\tnum  = %5d(7204) \n'%num
    text += '\n'

    text += '\t\tpos1 = %5d( 128)  %0.3f\n'%(pos1,pos1/128)
    text += '\t\tpos2 = %5d(  43)  %0.3f\n'%(pos2,pos2/43)
    text += '\t\tpos3 = %5d( 741)  %0.3f\n'%(pos3,pos3/741)
    text += '\t\tpos4 = %5d( 120)  %0.3f\n'%(pos4,pos4/120)
    text += '\n'

    text += '\t\tneg1 = %5d(1673)  %0.3f  %3d\n'%(neg1,neg1/1673, neg1-1673)
    text += '\t\tneg2 = %5d(1758)  %0.3f  %3d\n'%(neg2,neg2/1758, neg2-1758)
    text += '\t\tneg3 = %5d(1060)  %0.3f  %3d\n'%(neg3,neg3/1060, neg3-1060)
    text += '\t\tneg4 = %5d(1681)  %0.3f  %3d\n'%(neg4,neg4/1681, neg4-1681)
    text += '--------------------------------------------------\n'
    text += '\t\tneg  = %5d(6172)  %0.3f  %3d \n'%(neg,neg/6172, neg-6172)
    text += '\n'

    if 1:
        #compare with reference
        pass

    return text

In [4]:
class ConvGnUp2d(nn.Module):
    def __init__(self, in_channel, out_channel, num_group=32, kernel_size=3, padding=1, stride=1):
        super(ConvGnUp2d, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
        self.gn   = nn.GroupNorm(num_group,out_channel)

    def forward(self,x):
        x = self.conv(x)
        x = self.gn(x)
        x = F.relu(x, inplace=True)
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        return x

def upsize_add(x, lateral):
    return F.interpolate(x, size=lateral.shape[2:], mode='nearest') + lateral

def upsize(x, scale_factor=2):
    x = F.interpolate(x, scale_factor=scale_factor, mode='nearest')
    return x

class Net(nn.Module):
    def load_pretrain(self, skip=['logit.'], is_print=True):
        load_pretrain(self, skip, pretrain_file=PRETRAIN_FILE, conversion=CONVERSION, is_print=is_print)

    def __init__(self, num_class=4, drop_connect_rate=0.2):
        super(Net, self).__init__()

        e = EfficientNetB5(drop_connect_rate)
        self.stem   = e.stem
        self.block1 = e.block1
        self.block2 = e.block2
        self.block3 = e.block3
        self.block4 = e.block4
        self.block5 = e.block5
        self.block6 = e.block6
        self.block7 = e.block7
        self.last   = e.last
        e = None  #dropped

        #---
        self.lateral0 = nn.Conv2d(2048, 64,  kernel_size=1, padding=0, stride=1)
        self.lateral1 = nn.Conv2d( 176, 64,  kernel_size=1, padding=0, stride=1)
        self.lateral2 = nn.Conv2d(  64, 64,  kernel_size=1, padding=0, stride=1)
        self.lateral3 = nn.Conv2d(  40, 64,  kernel_size=1, padding=0, stride=1)

        self.top1 = nn.Sequential(
            ConvGnUp2d( 64, 64),
            ConvGnUp2d( 64, 64),
            ConvGnUp2d( 64, 64),
        )
        self.top2 = nn.Sequential(
            ConvGnUp2d( 64, 64),
            ConvGnUp2d( 64, 64),
        )
        self.top3 = nn.Sequential(
            ConvGnUp2d( 64, 64),
        )
        self.top4 = nn.Sequential(
            nn.Conv2d(64*3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.logit_mask = nn.Conv2d(64,num_class+1,kernel_size=1)

    def forward(self, x):
        batch_size,C,H,W = x.shape
        # x = x.clone()
        # x = x-torch.FloatTensor(IMAGE_RGB_MEAN).to(x.device).view(1,-1,1,1)
        # x = x/torch.FloatTensor(IMAGE_RGB_STD).to(x.device).view(1,-1,1,1)

        x = self.stem(x)            #; print('stem  ',x.shape)
        x = self.block1(x)    ;x0=x #; print('block1',x.shape)
        x = self.block2(x)    ;x1=x #; print('block2',x.shape)
        x = self.block3(x)    ;x2=x #; print('block3',x.shape)
        x = self.block4(x)          #; print('block4',x.shape)
        x = self.block5(x)    ;x3=x #; print('block5',x.shape)
        x = self.block6(x)          #; print('block6',x.shape)
        x = self.block7(x)          #; print('block7',x.shape)
        x = self.last(x)      ;x4=x #; print('last  ',x.shape)

        # segment
        t0 = self.lateral0(x4)
        t1 = upsize_add(t0, self.lateral1(x3)) #16x16
        t2 = upsize_add(t1, self.lateral2(x2)) #32x32
        t3 = upsize_add(t2, self.lateral3(x1)) #64x64

        t1 = self.top1(t1) #128x128
        t2 = self.top2(t2) #128x128
        t3 = self.top3(t3) #128x128

        t = torch.cat([t1,t2,t3],1)
        t = self.top4(t)
        logit_mask = self.logit_mask(t)
        logit_mask = F.interpolate(logit_mask, scale_factor=2.0, mode='bilinear', align_corners=False)

        return logit_mask

In [5]:
def image_to_input(image):
    input = image.astype(np.float32)
    input = input[...,::-1]/255
    input = input.transpose(0,3,1,2)
    # input[:,0] = (input[:,0]-IMAGE_RGB_MEAN[0])/IMAGE_RGB_STD[0]
    # input[:,1] = (input[:,1]-IMAGE_RGB_MEAN[1])/IMAGE_RGB_STD[1]
    # input[:,2] = (input[:,2]-IMAGE_RGB_MEAN[2])/IMAGE_RGB_STD[2]
    return input


class KaggleTestDataset(Dataset):
    def __init__(self):

        df =  pd.read_csv(DATA_DIR + '/sample_submission.csv')
        df['ImageId'] = df['ImageId_ClassId'].apply(lambda x: x.split('_')[0])
        self.uid = df['ImageId'].unique().tolist()

    def __str__(self):
        string  = ''
        string += '\tlen = %d\n'%len(self)
        return string

    def __len__(self):
        return len(self.uid)

    def __getitem__(self, index):
        # print(index)
        image_id = self.uid[index]
        image = cv2.imread(DATA_DIR + '/test_images/%s'%(image_id), cv2.IMREAD_COLOR)
        return image, image_id


def null_collate(batch):
    batch_size = len(batch)
    input = []
    image_id = []
    for b in range(batch_size):
        input.append(batch[b][0])
        image_id.append(batch[b][1])
    input = np.stack(input)
    input = torch.from_numpy(image_to_input(input))
    return input, image_id

def run_length_encode(mask):
    #possible bug for here
    m = mask.T.flatten()
    if m.sum()==0:
        rle=''
    else:
        m   = np.concatenate([[0], m, [0]])
        run = np.where(m[1:] != m[:-1])[0] + 1
        run[1::2] -= run[::2]
        rle = ' '.join(str(r) for r in run)
    return rle

def probability_mask_to_probability_label(probability):
    batch_size,num_class,H,W = probability.shape
    probability = probability.permute(0, 2, 3, 1).contiguous().view(batch_size,-1, 5)
    value, index = probability.max(1)
    probability = value[:,1:]
    return probability

In [6]:
# threshold_label      = [ 0.50, 0.50, 0.50, 0.50,]
# threshold_mask_pixel = [ 0.50, 0.50, 0.50, 0.50,]
# threshold_mask_size  = [ 1,  1,  1,  1,]

threshold_label      = [ 0.80, 0.90, 0.75, 0.60,]
threshold_mask_pixel = [ 0.40, 0.40, 0.40, 0.40,]
threshold_mask_size  = [   40,   40,   40,   40,]

## load net
print('load net ...')
net = Net().cuda()
net.load_state_dict(torch.load(CHECKPOINT_FILE, map_location=lambda storage, loc: storage))
print('')


## load data
print('load data ...')
dataset = KaggleTestDataset()
print(dataset)
#exit(0)

loader  = DataLoader(
    dataset,
    sampler     = SequentialSampler(dataset),
    batch_size  = 4,
    drop_last   = False,
    num_workers = 0,
    pin_memory  = True,
    collate_fn  = null_collate
)


## start here ----------------------------------
image_id_class_id = []
encoded_pixel     = []


net.eval()

import tqdm
start_timer = timer()
for t,(input, image_id) in tqdm.tqdm_notebook(enumerate(loader)):
    if t%200==0:
        print('\r loader: t = %4d / %4d  %s  %s : %s'%(
              t, len(loader)-1, str(input.shape), image_id[0], time_to_str((timer() - start_timer),'sec'),
        ),end='', flush=True)

    with torch.no_grad():
        input = input.cuda()

        logit = net(input) #data_parallel(net,input)  #net(input)
        probability = torch.softmax(logit,1)
        
        probability_mask  = probability[:,1:]#just drop background
        probability_label = probability_mask_to_probability_label(probability[:,1:])

    probability_mask  = probability_mask.data.cpu().numpy()
    probability_label = probability_label.data.cpu().numpy()

    batch_size = len(image_id)
    for b in range(batch_size):
        for c in range(4):
            rle=''
            predict_label = probability_label[b,c]>threshold_label[c]
            if predict_label:
                try:
                    predict_mask = probability_mask[b,c] > threshold_mask_pixel[c]
                    #predict_mask = post_process(predict_mask, threshold_mask_size[c])
                    rle = run_length_encode(predict_mask)

                except:
                    print('An exception occurred : %s'%(image_id[b]+'_%d'%(c+1)))


            image_id_class_id.append(image_id[b]+'_%d'%(c+1))
            encoded_pixel.append(rle)


df = pd.DataFrame(zip(image_id_class_id, encoded_pixel), columns=['ImageId_ClassId', 'EncodedPixels'])
df.to_csv(SUBMISSION_CSV_FILE, index=False)
print('')

## print statistics ----
if 1:
    text = summarise_submission_csv(df)
    print(text)


load net ...

load data ...
	len = 1801



HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

 loader: t =  400 /  450  torch.Size([4, 3, 256, 1600])  e4f7dd7de.jpg :  2 min 03 sec

compare with LB probing ... 
		num_image =  1801(1801) 
		num  =  7204(7204) 

		pos1 =   131( 128)  1.023
		pos2 =    58(  43)  1.349
		pos3 =   151( 741)  0.204
		pos4 =   111( 120)  0.925

		neg1 =  1670(1673)  0.998   -3
		neg2 =  1743(1758)  0.991  -15
		neg3 =  1650(1060)  1.557  590
		neg4 =  1690(1681)  1.005    9
--------------------------------------------------
		neg  =  6753(6172)  1.094  581 


