In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from fastai.vision import *
from fastai import *
import os
from collections import defaultdict
from fastai.vision.models.cadene_models import *

### Set up paths

In [2]:
train_pd = pd.read_csv('/root/.fastai/data/severstal/train.csv')

In [3]:
train_pd.head(5)

Unnamed: 0,ImageId_ClassId,EncodedPixels
0,0002cc93b.jpg_1,29102 12 29346 24 29602 24 29858 24 30114 24 3...
1,0002cc93b.jpg_2,
2,0002cc93b.jpg_3,
3,0002cc93b.jpg_4,
4,00031f466.jpg_1,


In [4]:
path = Path('/root/.fastai/data/severstal')

In [5]:
path.ls()

[PosixPath('/root/.fastai/data/severstal/train_images.zip'),
 PosixPath('/root/.fastai/data/severstal/sample_submission.csv'),
 PosixPath('/root/.fastai/data/severstal/test_images.zip'),
 PosixPath('/root/.fastai/data/severstal/train.csv'),
 PosixPath('/root/.fastai/data/severstal/train_images'),
 PosixPath('/root/.fastai/data/severstal/test_images')]

In [6]:
train_images = get_image_files(path/'train_images')
train_images[:3]

[PosixPath('/root/.fastai/data/severstal/train_images/5e581254c.jpg'),
 PosixPath('/root/.fastai/data/severstal/train_images/fd2f7b4f4.jpg'),
 PosixPath('/root/.fastai/data/severstal/train_images/82f4c0b69.jpg')]

### Check maximum size of images

In [7]:
def check_img_max_size(folder):
    max_height = 0
    max_width = 0
    for train_image in train_images:
        img = open_image(train_image)
        if max_height < img.shape[1]:
            max_height = img.shape[1]
        if max_width < img.shape[2]:
            max_width = img.shape[2]
    return max_height, max_width

def show_image(images, index):
    img_f = images[index]
    print(type(img_f))
    img = open_image(img_f)
    print(img)
    img.show(figsize=(5,5))

In [8]:
mask_path = Path('/kaggle/mask')
if not os.path.exists(mask_path):
    os.makedirs(str(mask_path))

In [9]:
def convert_encoded_to_array(encoded_pixels):
    pos_array = []
    len_array = []
    splits = encoded_pixels.split()
    pos_array = [int(n) - 1 for i, n in enumerate(splits) if i % 2 == 0]
    len_array = [int(n) for i, n in enumerate(splits) if i % 2 == 1]
    return pos_array, len_array
        
def convert_to_pair(pos_array, rows):
    return [(p % rows, p // rows) for p in pos_array]

def create_positions(single_pos, size):
    return [i for i in range(single_pos, single_pos + size)]

def create_positions_pairs(single_pos, size, row_size):
    return convert_to_pair(create_positions(single_pos, size), row_size)

def convert_to_mask(encoded_pixels, row_size, col_size, category):
    pos_array, len_array = convert_encoded_to_array(encoded_pixels)
    mask = np.zeros([row_size, col_size])
    for(p, l) in zip(pos_array, len_array):
        for row, col in create_positions_pairs(p, l, row_size):
            mask[row][col] = category
    return mask

def save_to_image(masked, image_name):
    im = PIL.Image.fromarray(masked)
    im = im.convert("L")
    image_name = re.sub(r'(.+)\.jpg', r'\1', image_name) + ".png"
    real_path = mask_path/image_name
    im.save(real_path)
    return real_path

def open_single_image(path):
    img = open_image(path)
    img.show(figsize=(20,20))
    
def get_y_fn(x):
    return mask_path/(x.stem + '.png')

def group_by(train_images, train_pd):
    tran_dict = {image.name:[] for image in train_images}
    pattern = re.compile('(.+)_(\d+)')
    for index, image_path in train_pd.iterrows():
        m = pattern.match(image_path['ImageId_ClassId'])
        file_name = m.group(1)
        category = m.group(2)
        tran_dict[file_name].append((int(category), image_path['EncodedPixels']))
    return tran_dict

def display_image_with_mask(img_name):
    full_image = path/'train_images'/img_name
    print(full_image)
    open_single_image(full_image)
    mask_image = get_y_fn(full_image)
    mask = open_mask(mask_image)
    print(full_image)
    mask.show(figsize=(20, 20), alpha=0.5)

### Prepare Transforms

In [10]:
def limited_dihedral_affine(k:partial(uniform_int,0,3)):
    "Randomly flip `x` image based on `k`."
    x = -1 if k&1 else 1
    y = -1 if k&2 else 1
    if k&4: return [[0, x, 0.],
                    [y, 0, 0],
                    [0, 0, 1.]]
    return [[x, 0, 0.],
            [0, y, 0],
            [0, 0, 1.]]

dihedral_affine = TfmAffine(limited_dihedral_affine)

def get_extra_transforms(max_rotate:float=3., max_zoom:float=1.1,
                   max_lighting:float=0.2, max_warp:float=0.2, p_affine:float=0.75,
                   p_lighting:float=0.75, xtra_tfms:Optional[Collection[Transform]]=None)->Collection[Transform]:
    "Utility func to easily create a list of flip, rotate, `zoom`, warp, lighting transforms."
    p_lightings = [p_lighting, p_lighting + 0.2, p_lighting + 0.4, p_lighting + 0.6, p_lighting + 0.7]
    max_lightings = [max_lighting, max_lighting + 0.2, max_lighting + 0.4, max_lighting + 0.6, max_lighting + 0.7]
    res = [rand_crop(), dihedral_affine(), 
           symmetric_warp(magnitude=(-max_warp,max_warp), p=p_affine),
           rotate(degrees=(-max_rotate,max_rotate), p=p_affine),
           rand_zoom(scale=(1., max_zoom), p=p_affine)]
    res.extend([brightness(change=(0.5*(1-mp[0]), 0.5*(1+mp[0])), p=mp[1]) for mp in zip(max_lightings, p_lightings)])
    res.extend([contrast(scale=(1-mp[0], 1/(1-mp[0])), p=mp[1]) for mp in zip(max_lightings, p_lightings)])
    #       train                   , valid
    return (res, [crop_pad()])

def get_simple_transforms(max_rotate:float=3., max_zoom:float=1.1,
                   max_lighting:float=0.2, max_warp:float=0.2, p_affine:float=0.75,
                   p_lighting:float=0.75, xtra_tfms:Optional[Collection[Transform]]=None)->Collection[Transform]:
    "Utility func to easily create a list of flip, rotate, `zoom`, warp, lighting transforms."
    res = [
#         rand_crop(),
        symmetric_warp(magnitude=(-max_warp,max_warp), p=p_affine),
        rotate(degrees=(-max_rotate,max_rotate), p=p_affine),
        rand_zoom(scale=(1., max_zoom), p=p_affine)
          ]
    #       train                   , valid
    return (res, [crop_pad()])

### Prepare data bunch

In [11]:
codes = array(['0', '1', '2', '3', '4'])

In [12]:
train_images = (path/'train_images').ls()
src_size = np.array(open_image(str(train_images[0])).shape[1:])
valid_pct = 0.10

In [13]:
def create_data_bunch(bs, size, start_pos, end_pos):
    src = (SegmentationItemList.from_folder(path/'train_images')
       .split_by_rand_pct(valid_pct=valid_pct)
       .label_from_func(get_y_fn, classes=codes))
    test_files = [f.name for f in get_image_files(path/'test_images')]
    test_files = test_files[start_pos:end_pos]
    print('len(test_files)', len(test_files))
    data = (src.transform(get_transforms(max_rotate=25), size=size, tfm_y=True)
        .add_test(ImageList.from_df(path=path/'test_images', df=pd.DataFrame(test_files)), 
                  tfms=None, tfm_y=False)
        .databunch(bs=bs)
        .normalize(imagenet_stats))
    return src, data

### Create learner and training
Starting with low resolution training

##### Some metrics functions

In [14]:
name2id = {v:k for k,v in enumerate(codes)}
void_code = name2id['0']

def acc_camvid(input, target):
    target = target.squeeze(1)
    mask = target != void_code
    argmax = (input.argmax(dim=1))
    comparison = argmax[mask]==target[mask]
    return torch.tensor(0.) if comparison.numel() == 0 else comparison.float().mean()

def acc_camvid_with_zero_check(input, target):
    target = target.squeeze(1)
    argmax = (input.argmax(dim=1))
    batch_size = input.shape[0]
    total = torch.empty([batch_size])
    for b in range(batch_size):
        if(torch.sum(argmax[b]).item() == 0.0 and torch.sum(target[b]).item() == 0.0):
            total[b] = 1
        else:
            mask = target[b] != void_code
            comparison = argmax[b][mask]==target[b][mask]
            total[b] = torch.tensor(0.) if comparison.numel() == 0 else comparison.float().mean()
    return total.mean()


def calc_dice_coefficients(argmax, target, cats):
    def calc_dice_coefficient(seg, gt, cat: int):
        mask_seg = seg == cat
        mask_gt = gt == cat
        sum_seg = torch.sum(mask_seg.float())
        sum_gt = torch.sum(mask_gt.float())
        if sum_seg + sum_gt == 0:
            return torch.tensor(1.0)
        return (torch.sum((seg[gt == cat] / cat).float()) * 2.0) / (sum_seg + sum_gt)

    total_avg = torch.empty([len(cats)])
    for i, c in enumerate(cats):
        total_avg[i] = calc_dice_coefficient(argmax, target, c)
    return total_avg.mean()


def dice_coefficient(input, target):
    target = target.squeeze(1)
    argmax = (input.argmax(dim=1))
    batch_size = input.shape[0]
    cats = [1, 2, 3, 4]
    total = torch.empty([batch_size])
    for b in range(batch_size):
        total[b] = calc_dice_coefficients(argmax[b], target[b], cats)
    return total.mean()

def calc_dice_coefficients_2(argmax, target, cats):
    def calc_dice_coefficient(seg, gt, cat: int):
        mask_seg = seg == cat
        mask_gt = gt == cat
        sum_seg = torch.sum(mask_seg.float())
        sum_gt = torch.sum(mask_gt.float())
        return (torch.sum((seg[gt == cat] / cat).float())), (sum_seg + sum_gt)

    total_avg = torch.empty([len(cats), 2])
    for i, c in enumerate(cats):
        total_avg[i][0], total_avg[i][1] = calc_dice_coefficient(argmax, target, c)
    total_sum = total_avg.sum(axis=0)
    if (total_sum[1] == 0.0):
        return torch.tensor(1.0)
    return total_sum[0] * 2.0 / total_sum[1]


def dice_coefficient_2(input, target):
    target = target.squeeze(1)
    argmax = (input.argmax(dim=1))
    batch_size = input.shape[0]
    cats = [1, 2, 3, 4]
    total = torch.empty([batch_size])
    for b in range(batch_size):
        total[b] = calc_dice_coefficients_2(argmax[b], target[b], cats)
    return total.mean()


def accuracy_simple(input, target):
    target = target.squeeze(1)
    return (input.argmax(dim=1)==target).float().mean()


def dice_coeff(pred, target):
    smooth = 1.
    num = pred.size(0)
    m1 = pred.view(num, -1)  # Flatten
    m2 = target.view(num, -1)  # Flatten
    intersection = (m1 * m2).sum()
    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

### Customized loss function

In [15]:
class CombinedDiceLoss(nn.Module):
    def __init__(self, zero_cat_factor=0.1):
        super().__init__()
        self.zero_cat_factor = zero_cat_factor

    def forward(self, input, target):
        return self.dice_loss(target, input, self.zero_cat_factor)

    def dice_loss(self, target, output, eps=1e-7, zero_cat_factor=0.1):
        '''
        Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
        Assumes the `channels_last` format.

        # Arguments
            target: b x 1 x X x Y( x Z...) ground truth
            output: b x c x X x Y( x Z...) Network output, must sum to 1 over c channel (such as after softmax)
            epsilon: Used for numerical stability to avoid divide by zero errors

        # References
            V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
            https://arxiv.org/abs/1606.04797
            More details on Dice loss formulation
            https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)

            Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
        '''

        # skip the batch and class axis for calculating Dice score
        num_classes = output.shape[1]
        y_true = F.one_hot(target.long().squeeze(), num_classes)
        y_pred = F.softmax(output, dim=1).permute(0, 2, 3, 1)
        y_true = y_true.type(y_pred.type())
        y_true = y_true.permute(0, 3, 1, 2)
        y_true[:,0,:] *= zero_cat_factor # Factor used to take power away from the zeroth category
        y_true = y_true.permute(0, 2, 3, 1)
        axes = tuple(range(1, len(y_pred.shape)-1))
        numerator = 2. * torch.sum(y_pred * y_true, axes)
        denominator = torch.sum(y_pred ** 2 + y_true ** 2, axes)
        # When intersection and cardinality are all zero you have 100% score and not 0% score
        # For this we use the eps parameter
        loss_array = ((numerator + eps) / (denominator + eps))
        loss_array = (loss_array).mean(dim=0)
        return ((1 - torch.mean(loss_array)) + F.cross_entropy(output, target.squeeze())) / 2.

    def __del__(self): pass

#### Prediction Functions

In [16]:
def predict(img_path):
    pred_class, pred_idx, outputs = inference_learn.predict(open_image(str(img_path)))
    return pred_class, pred_idx, outputs

def encode_classes(pred_class_data):
    pixels = np.concatenate([[0], torch.transpose(pred_class_data.squeeze(), 0, 1).flatten(), [0]])
    classes_dict = {1: [], 2: [], 3: [], 4: []}
    count = 0
    previous = pixels[0]
    for i, val in enumerate(pixels):
        if val != previous:
            if previous in classes_dict:
                classes_dict[previous].append((i - count, count))
            count = 0
        previous = val
        count += 1
    return classes_dict


def convert_classes_to_text(classes_dict, clazz):
    return ' '.join([f'{v[0]} {v[1]}' for v in classes_dict[clazz]])

### Create and Load Model

In [17]:
import time

def create_tta_predictions(start_pos, end_pos, model_to_load='export-4-best', append_to_file=False, submission_file='submission.csv'):
    print(f'TTA prediction from {start_pos} to {end_pos}.')
    bs = 4
    src, data = create_data_bunch(bs, src_size, start_pos, end_pos)
    metrics=accuracy_simple, acc_camvid_with_zero_check, dice_coefficient, dice_coefficient_2
    learn = unet_learner(data, models.resnet34, metrics=metrics, wd=1e-2, bottle=True)
    learn.loss_func = CombinedDiceLoss(zero_cat_factor=0.5)
    learn.model_dir = Path('/kaggle/model')
    learn.load(model_to_load);
    # ys = final_preds
    ys, y = learn.TTA(scale=1.1, ds_type=DatasetType.Test)
    # get the actual predictions
    pred_class = torch.argmax(ys, dim=1)
    test_images = data.test_dl.dataset.items
    assert pred_class.shape[0] == test_images.shape[0], f'{pred_class.shape[0]} != {test_images.shape[0]}'
    test_images = [Path(f) for f in test_images]
    
    start_time = time.time()

    defect_classes = [1, 2, 3, 4]
    append_flag = 'a' if append_to_file else 'w'
    with open(submission_file, append_flag) as submission_file:
        if append_to_file:
            submission_file.write('ImageId_ClassId,EncodedPixels\n')
        for i, test_image in enumerate(test_images):
            encoded_all = encode_classes(pred_class[i])
            for defect_class in defect_classes:
                submission_file.write(f'{test_image.name}_{defect_class},{convert_classes_to_text(encoded_all, defect_class)}\n')
            if i % 5 == 0:
                print(f'Processed {i} images\r', end='')

    print(f"--- {time.time() - start_time} seconds ---")

In [None]:
!rm submission.csv
test_files = get_image_files(path/'test_images')
batch_size = len(test_files) // 2
create_tta_predictions(0, batch_size)
create_tta_predictions(batch_size, len(test_files), append_to_file=True)

### Loop through the test images and create submission csv

In [20]:
!ls -latr 'submission.csv'

-rw-r--r-- 1 root root 114314 Oct 18 16:13 submission.csv


In [21]:
!head -n200 'submission.csv'

1804f41eb.jpg_1,
1804f41eb.jpg_2,
1804f41eb.jpg_3,
1804f41eb.jpg_4,
c90f155dd.jpg_1,
c90f155dd.jpg_2,
c90f155dd.jpg_3,
c90f155dd.jpg_4,
e0b422958.jpg_1,
e0b422958.jpg_2,
e0b422958.jpg_3,
e0b422958.jpg_4,
a631d53aa.jpg_1,
a631d53aa.jpg_2,
a631d53aa.jpg_3,
a631d53aa.jpg_4,
d01da361f.jpg_1,
d01da361f.jpg_2,
d01da361f.jpg_3,
d01da361f.jpg_4,
86fe3cf8c.jpg_1,
86fe3cf8c.jpg_2,
86fe3cf8c.jpg_3,
86fe3cf8c.jpg_4,
54eb4b690.jpg_1,
54eb4b690.jpg_2,
54eb4b690.jpg_3,
54eb4b690.jpg_4,
2efa6b22f.jpg_1,
2efa6b22f.jpg_2,
2efa6b22f.jpg_3,
2efa6b22f.jpg_4,
d6128fbfc.jpg_1,
d6128fbfc.jpg_2,
d6128fbfc.jpg_3,
d6128fbfc.jpg_4,
f625f93a1.jpg_1,
f625f93a1.jpg_2,
f625f93a1.jpg_3,
f625f93a1.jpg_4,
499a9893b.jpg_1,
499a9893b.jpg_2,
499a9893b.jpg_3,
499a9893b.jpg_4,
4c5671c92.jpg_1,
4c5671c92.jpg_2,
4c5671c92.jpg_3,212130 1 212133 2 212141 7 212319 1 212321 2 212354 1 212362 1 212366 1 212382 1 212385 2 212389 3 212395 10 212566 1 212569 2 212573 10 212585 1 212609 2 2

### Alternative prediction methods

In [None]:
preds,y = learn.get_preds(ds_type=DatasetType.Test, with_loss=False)

In [None]:
preds.shape

In [None]:
pred_class_data = preds.argmax(dim=1)

In [None]:
len((path/'test_images').ls())

In [None]:
data.test_ds.x

#### Checking encoding methods

In [None]:
encoded_all = encode_classes(pred_class.data)
print(convert_classes_to_text(encoded_all, 3))

In [None]:
image_name = train_images[16]
print(get_y_fn(image_name))
img = open_mask(get_y_fn(image_name))
img_data = img.data
print(convert_classes_to_text(encode_classes(img_data), 3))
img_data.shape