In [None]:
#| include: false
#all_slow

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/marcomatteo/steel_segmentation/blob/master/dev_nbs/21_ensemble_unet_fpn_resnet34.ipynb)

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from steel_segmentation.all import *
from fastai.vision.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp

import warnings
import random
import os
import cv2
import pandas as pd
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt

# Presumably fixes the random seeds so repeated runs give the same results.
# NOTE(review): `seed_everything` arrives through the star imports above and is
# called with its default arguments; the seed value is not visible here — confirm.
seed_everything()

In [None]:
# Show the `.type` string of a CUDA device handle ('cuda').
# This only builds a device object; it does not prove that a GPU is present.
torch.device("cuda:0").type

'cuda'

In [None]:
# Check that PyTorch can actually see a usable CUDA device.
torch.cuda.is_available()

True

In [None]:
# Release the GPU memory cached by PyTorch's allocator before the models are loaded.
torch.cuda.empty_cache()

In [None]:
!nvidia-smi

Sun Mar 14 17:25:16 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.36.06    Driver Version: 450.36.06    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro P5000        On   | 00000000:00:05.0 Off |                  Off |
| 32%   48C    P0    46W / 180W |      4MiB / 16278MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Print the entries found under `models_dir` (the saved checkpoints used in this notebook).
print_competition_data(models_dir)

../models/fastai-UNET-ResNet34-256-stage5.pth
../models/fastai-UNET-XResNeXt34-128x800.pth
../models/fastai-UNET-ResNet34-smp-pytorch_dls-stage1.pth
../models/fastai-UNET-XResNeXt34-128x800-finetuning.pth
../models/.ipynb_checkpoints
../models/kaggle-UNET-ResNet34.pth
../models/kaggle-FPN-ResNet34.pth


In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Despite its name, this loader is built from the *training* data
# (root=train_path, df=train_multi). It yields (file names, image batch) pairs.
# NOTE(review): for the competition test images the call is presumably
# `get_test_dls(batch_size=2)` — confirm against `get_test_dls`.
testset = get_test_dls(root=train_path, df=train_multi, batch_size=4)

In [None]:
# Pull a single batch for inspection: a (file names, image tensor) pair.
name, x = next(iter(testset))
# With batch_size=4 the image tensor has shape [4, 3, 256, 1600].
x.shape

torch.Size([4, 3, 256, 1600])

## UNET model

In [None]:
# DataLoaders required to construct the fastai learner for the U-Net.
# NOTE(review): presumably 4 is the batch size and (256, 1600) the image size,
# with batch transforms disabled — confirm against `get_segmentation_dls`.
dls = get_segmentation_dls(4, (256, 1600), with_btfms=False)

In [None]:
# Build a fastai U-Net learner with a ResNet34 encoder ...
unet_trainer = unet_learner(dls=dls, arch=resnet34, metrics=seg_metrics, pretrained=True)
# ... point it at the directory holding the saved checkpoints ...
unet_trainer.model_dir = models_dir
# ... and restore the weights stored as "fastai-UNET-ResNet34-256-stage5".
unet_trainer = unet_trainer.load("fastai-UNET-ResNet34-256-stage5")

In [None]:
# Take the bare PyTorch module out of the learner, place it on the target
# device and switch it to evaluation mode (both calls return the module itself).
unet_model = unet_trainer.model.to(device).eval()

In [None]:
# Run the U-Net on the sample batch and squash the logits into probabilities.
# Inference needs no gradients: `torch.no_grad()` keeps PyTorch from recording
# the autograd graph, which would otherwise hold on to GPU memory.
with torch.no_grad():
    unet_preds = torch.sigmoid(unet_model(x.to(device)))
unet_preds.shape

In [None]:
# Drop channel 0 so the U-Net output (5 channels) lines up with the 4 FPN channels.
# NOTE(review): channel 0 is presumably the background class — confirm.
unet_preds[:, 1:].shape

## FPN model

In [None]:
# FPN with a ResNet34 encoder and 4 output channels. `activation=None` leaves the
# output as raw logits; the sigmoid is applied explicitly where predictions are made.
# NOTE(review): the ImageNet encoder weights are presumably overwritten when the
# Kaggle checkpoint is loaded; `encoder_weights=None` would likely avoid the
# download — confirm before changing.
fpn_model = smp.FPN("resnet34", encoder_weights='imagenet', classes=4, activation=None)

In [None]:
# Read the FPN checkpoint. `map_location="cpu"` lets the file be opened on a
# machine without the GPU it was saved from; the model is moved to `device`
# in a later cell.
# NOTE: `torch.load` unpickles the file, so only load checkpoints from a trusted source.
loaded_dict = torch.load(models_dir/"kaggle-FPN-ResNet34.pth", map_location="cpu")
# strict=True: fail loudly if any parameter name is missing or unexpected.
fpn_model.load_state_dict(loaded_dict["state_dict"], strict=True)

<All keys matched successfully>

In [None]:
# Place the FPN on the target device and switch it to evaluation mode
# (both calls return the module itself, so they can be chained).
fpn_model = fpn_model.to(device).eval()

In [None]:
# Run the FPN on the sample batch and squash the logits into probabilities.
# Inference needs no gradients: `torch.no_grad()` keeps PyTorch from recording
# the autograd graph, which would otherwise hold on to GPU memory.
with torch.no_grad():
    fpn_preds = torch.sigmoid(fpn_model(x.to(device)))
fpn_preds.shape

## Ensemble

To build the ensemble, we wrap the models in a dedicated `nn.Module` class that combines their predictions.

In [None]:
class Ensemble(nn.Module):
    """Average the sigmoid probabilities of several segmentation models.

    Args:
        models: iterable of `nn.Module` objects. Every model must map the same
            input batch to logits of identical shape.
    """

    def __init__(self, models):
        super().__init__()
        # nn.ModuleList registers the members as sub-modules, so `.to()`,
        # `.eval()` and `.parameters()` on the ensemble reach each of them.
        self.models = nn.ModuleList(models)

    def forward(self, x):
        """Return the element-wise mean probability over all member models.

        Args:
            x: input batch passed (as a copy) to every model.

        Returns:
            Tensor with the same shape as a single model's output.

        Raises:
            RuntimeError: raised by `torch.stack` when the models' outputs do
                not share one shape (e.g. a different number of classes).
        """
        # Each model gets its own copy of the input so that an in-place
        # operation inside one model cannot affect the others.
        probs = [torch.sigmoid(model(x.clone())) for model in self.models]
        # Stack along a new leading "model" axis and average over it, which
        # keeps the class/channel axis intact.
        return torch.stack(probs, dim=0).mean(dim=0)

In [None]:
# Members of the ensemble.
# NOTE: a forward pass cannot combine them as they are, because the U-Net
# emits 5 channels while the FPN emits 4.
models = [fpn_model, unet_model]
# `Ensemble` requires the list of models; calling it without arguments raises a TypeError.
ensemble = Ensemble(models)

## Inference

In [None]:
def post_process(probability, threshold, min_size):
    """Binarise a probability map and drop small connected components.

    Args:
        probability: 2-D array of per-pixel probabilities for one class
            (assumed float32, as required by `cv2.threshold` — TODO confirm
            for other callers).
        threshold: pixels with a probability strictly greater than this value
            are considered foreground.
        min_size: a connected component is kept only when it has strictly
            more than `min_size` pixels.

    Returns:
        A tuple `(predictions, num)`: `predictions` is a float32 mask with the
        same shape as `probability` holding 1 on the kept components and 0
        elsewhere; `num` is the number of kept components.
    """
    mask = cv2.threshold(probability, threshold, 1, cv2.THRESH_BINARY)[1]
    num_component, component = cv2.connectedComponents(mask.astype(np.uint8))
    # Size the output after the input rather than hard-coding 256x1600, so the
    # function works for any image size.
    predictions = np.zeros(probability.shape, np.float32)
    num = 0
    # Label 0 is the background, so real components start at 1.
    for c in range(1, num_component):
        p = (component == c)
        if p.sum() > min_size:
            predictions[p] = 1
            num += 1
    return predictions, num

In [None]:
# Probability cut-off passed to `post_process` to binarise the averaged predictions.
best_threshold = 0.5
# Minimum component size in pixels.
# NOTE(review): the inference loops re-assign `min_size` from `min_sizes` for
# every class, so this value is overwritten there — confirm it is still needed.
min_size = 3000

In [None]:
# Minimum connected-component size (in pixels) per class, indexed by the
# channel position (0..3) of the averaged predictions.
min_sizes = [3000, 3000, 3000, 3000]

In [None]:
# Predict on the images served by `testset` (validation run) and collect one
# run-length-encoded mask per image and class.
predictions = []
# Inference only: without `torch.no_grad()` every forward pass would record an
# autograd graph and waste GPU memory.
with torch.no_grad():
    for fnames, images in tqdm(testset):
        images = images.to(device)

        # FPN: 4 class channels
        fpn_probs = torch.sigmoid(fpn_model(images)).cpu().numpy()

        # UNET: 5 channels, channel 0 is dropped to match the FPN output
        unet_probs = torch.sigmoid(unet_model(images))[:, 1:].cpu().numpy()

        # Plain average of the two models' probabilities
        batch_preds = (fpn_probs + unet_probs) / 2
        for fname, preds in zip(fnames, batch_preds):
            for cls, pred in enumerate(preds):
                mask, _ = post_process(pred, best_threshold, min_sizes[cls])
                # Class ids in the output file are 1-based
                predictions.append([fname + f"_{cls+1}", mask2rle(mask)])

# save the predictions of the validation run
df = pd.DataFrame(predictions, columns=['ImageId_ClassId', 'EncodedPixels'])
df.to_csv(sub_path/"ensemble_validation.csv", index=False)

100%|██████████| 3142/3142 [19:31<00:00,  2.68it/s]


In [None]:
# Predict on the competition test images and write the submission file.
# `testset` above is built from the training data, so a dedicated loader is
# created here; re-using `testset` would silently write training predictions
# into the submission.
# NOTE(review): `get_test_dls(batch_size=2)` is assumed to default to the test
# images (2753 batches in the recorded run) — confirm against `get_test_dls`.
test_dl = get_test_dls(batch_size=2)

predictions = []
# Inference only: without `torch.no_grad()` every forward pass would record an
# autograd graph and waste GPU memory.
with torch.no_grad():
    for fnames, images in tqdm(test_dl):
        images = images.to(device)

        # FPN: 4 class channels
        fpn_probs = torch.sigmoid(fpn_model(images)).cpu().numpy()

        # UNET: 5 channels, channel 0 is dropped to match the FPN output
        unet_probs = torch.sigmoid(unet_model(images))[:, 1:].cpu().numpy()

        # Plain average of the two models' probabilities
        batch_preds = (fpn_probs + unet_probs) / 2
        for fname, preds in zip(fnames, batch_preds):
            for cls, pred in enumerate(preds):
                mask, _ = post_process(pred, best_threshold, min_sizes[cls])
                # Class ids in the output file are 1-based
                predictions.append([fname + f"_{cls+1}", mask2rle(mask)])

# save the predictions as the submission file
df = pd.DataFrame(predictions, columns=['ImageId_ClassId', 'EncodedPixels'])
df.to_csv(sub_path/"ensemble_submission.csv", index=False)

100%|██████████| 2753/2753 [08:45<00:00,  5.24it/s]


In [None]:
# Preview the first rows of the prediction table (an empty EncodedPixels
# entry means no defect of that class was predicted).
df.head()

Unnamed: 0,ImageId_ClassId,EncodedPixels
0,0002cc93b.jpg_1,76644 13 76895 22 77149 26 77403 30 77657 36 77694 1 77696 1 77912 43 78166 49 78216 1 78218 1 78220 1 78422 55 78677 72 78750 25 78776 1 78933 104 79188 107 79445 106 79700 109 79956 110 80212 111 80469 110 80725 112 80982 111 81238 111 81494 111 81750 113 82007 112 82263 50 82316 1 82318 57 82520 37 82580 51 82777 34 82837 52 83033 32 83095 50 83290 31 83352 51 83546 31 83609 50 83802 31 83867 54 83922 1 83938 1 83940 1 83942 1 83944 1 84059 28 84125 58 84184 1 84190 11 84315 28 84383 82 84466 1 84468 1 84470 3 84572 25 84654 82 84829 24 84914 79 85085 23 85177 72 85342 20 85435 70 85599...
1,0002cc93b.jpg_2,
2,0002cc93b.jpg_3,
3,0002cc93b.jpg_4,
4,00031f466.jpg_1,


## References

