In [None]:
# default_exp models.predict

# Predict

> Prediction and export outputs.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from steel_segmentation.core import *
from steel_segmentation.data import *
from steel_segmentation.dataloaders import *
from steel_segmentation.models.metrics import *
from steel_segmentation.models.model import *

import fastai
from fastai.vision.all import *
from fastai.metrics import *
from fastai.data.all import *

import cv2
import pathlib
import numpy as np
import pandas as pd

# Folder the prediction CSVs are written to (see `Predict.save_df`); it is
# created, together with any missing parents, as soon as this cell runs.
pred_path = path.parent.joinpath("predictions")
pred_path.mkdir(parents=True, exist_ok=True)

In [None]:
# hide
# Keep only a fixed handful of images in each training DataFrame
# (presumably to keep this notebook light to run — TODO confirm).
only_imgs = ["0a1cade03.jpg", "bca4ae758.jpg", "988cf521f.jpg", "b6a257b28.jpg",
             "b2ad335bf.jpg", "72aaba8ad.jpg", "f383950e8.jpg"]
train, train_all, train_multi = (
    frame[frame["ImageId"].isin(only_imgs)].copy()
    for frame in (train, train_all, train_multi)
)

The functions provided in this notebook can be used to run inference with the trained models saved in `models_dir`.

The first section uses the fast.ai API, while the second is for a general PyTorch approach.

In [None]:
# List the files currently stored in `models_dir` (see the output below).
print_competition_data(models_dir)

../models/ResNet18-Unet-kaggle.pth
../models/.ipynb_checkpoints
../models/ResNet18-Classifier-kaggle.pth
../models/ResNet34-Unet-256-stage5.pth
../models/ResNet34-Unet-256-stage3.pth
../models/ResNet34-Unet-128-stage2.5.pth


## Fast.ai prediction

In [None]:
# export
class Predict:
    """Run a fastai segmentation `learner` over a set of images and export RLE masks.

    `source` selects the images to predict and may be:

    - a `pathlib.Path`: the image files collected from it with `get_image_files`,
    - a `pd.DataFrame`: one image per row, read from its `ImageId` column and
      prefixed with `source_path`,
    - a `str`: a single image file name located inside `source_path`.

    `source_path` falls back to the module-level `train_path` when it is not
    given. Any other `source` type raises `TypeError`.

    `threshold`, `min_size`, `size_fold` and `folds` are only set by
    `get_predictions`, so `post_process` / `get_RLEs` must not be called
    before it (or before setting `threshold` and `min_size` by hand).
    """

    # Column names of every DataFrame returned by `get_RLEs`/`get_predictions`.
    _columns = ['ImageId_ClassId', 'EncodedPixels']

    def __init__(self,
                 source,
                 learner,
                 source_path:pathlib.Path=None):
        self.source = source
        self.learner = learner
        self.source_path = source_path if source_path else train_path

        self.single_prediction = False

        if isinstance(self.source, pathlib.Path):
            self.img_paths = self.get_path_source_list()
        elif isinstance(self.source, pd.DataFrame):
            self.img_paths = self.get_df_source_list()
        elif isinstance(self.source, str):
            self.single_prediction = True
            self.img_paths = L(self.source_path / self.source)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # the missing `img_paths` a few lines below.
            raise TypeError(
                "`source` must be a pathlib.Path, a pandas DataFrame or a str, "
                f"got {type(self.source).__name__}")

        self.elems = len(self.img_paths)

    def get_df_source_list(self):
        """Load `source` if it's a DataFrame instance."""
        tfm = ColReader("ImageId", pref=self.source_path)
        return L([tfm(o) for o in self.source.itertuples()])

    def get_path_source_list(self):
        """Load `source` if it's a pathlib.Path instance."""
        return get_image_files(self.source)

    def predict(self, selected_imgs):
        """Get the predictions on the `selected_imgs`.

        For a single-image source the first element returned by
        `learner.predict` is returned; otherwise a test DataLoader is built
        from `selected_imgs` and the first element of `learner.get_preds`
        is returned.
        """
        if self.single_prediction:
            # NOTE(review): this returns the first (decoded) output of
            # `learner.predict`, while `get_RLEs` iterates its input as a batch
            # of per-class maps — confirm the single-image path yields the
            # shape `get_RLEs` expects.
            pred_full_dec, _, _ = self.learner.predict(selected_imgs)
            return pred_full_dec

        test_dl = self.learner.dls.test_dl(test_items=selected_imgs)
        pred_probs, _, _ = self.learner.get_preds(dl=test_dl, with_decoded=True)

        return pred_probs

    def post_process(self, probability):
        """
        Post processing of each predicted mask, components with lesser number of pixels
        than `min_size` are ignored.

        `probability` is a 2D array of per-pixel scores. It is binarised with
        `self.threshold`, split into connected components, and only the
        components with more than `self.min_size` pixels are kept.

        Returns `(predictions, num)`: a float32 mask with the same shape as
        `probability` and the number of components that were kept.
        """
        mask = cv2.threshold(probability, self.threshold, 1, cv2.THRESH_BINARY)[1]
        num_component, component = cv2.connectedComponents(mask.astype(np.uint8))
        # Size the output from the input instead of assuming 256x1600 images.
        predictions = np.zeros(probability.shape, np.float32)
        num = 0
        # Label 0 is skipped: only labels 1..num_component-1 are components.
        for c in range(1, num_component):
            p = (component == c)
            if p.sum() > self.min_size:
                predictions[p] = 1
                num += 1
        return predictions, num

    def get_RLEs(self, img_names, pred_probs):
        """Build the RLE rows for a batch of predictions.

        For each image in `pred_probs` and each class channel except
        channel 0, the map is post-processed and encoded with `mask2rle`.
        Returns a DataFrame with columns `ImageId_ClassId` (the image name
        followed by `_<class_id>`, class ids starting at 1) and `EncodedPixels`.
        """
        predictions = []

        for num_pred, t_pred in enumerate(pred_probs): # img in bs
            np_pred = t_pred.numpy()

            # iterate through class_id without class_id 0
            for class_id, prob in enumerate(np_pred[1:]):
                pred, _ = self.post_process(prob)
                rle = mask2rle(pred)
                name = img_names[num_pred] + f"_{class_id+1}"
                predictions.append([name, rle])

        return pd.DataFrame(predictions, columns=self._columns)

    def save_df(self, df, file_name):
        """Save the final DataFrame into the `pred_path` folder."""
        df.to_csv(pred_path/file_name, index=False)

    def get_predictions(self, size_fold:int, threshold:float, min_size:int):
        """Predict every source image in folds and return all RLEs in one DataFrame.

        `size_fold` is the maximum number of images predicted at once,
        `threshold` binarises each per-class map and `min_size` is the pixel
        count a connected component must exceed to be kept.

        Always returns a DataFrame with columns `ImageId_ClassId` and
        `EncodedPixels` (empty when there are no images); missing
        `EncodedPixels` values are replaced by "".

        Raises `ValueError` if `size_fold` is smaller than 1.
        """
        if size_fold < 1:
            raise ValueError(f"`size_fold` must be at least 1, got {size_fold}")

        self.size_fold = min([self.elems, size_fold])
        self.threshold = threshold
        self.min_size = min_size

        if self.elems == 0:
            # Nothing to predict: avoid dividing by a zero fold size.
            self.folds = 0
            return pd.DataFrame(columns=self._columns)

        self.folds = self.elems // self.size_fold
        if (self.elems % self.size_fold) != 0:
            self.folds += 1

        df_preds = []

        for fold in range(self.folds):
            start = fold*self.size_fold
            # Clamp so the last (partial) fold reports its real upper bound.
            end = min((fold+1)*self.size_fold, self.elems)
            print(f"From {start} to {end} of {self.elems}")

            selected_imgs = self.img_paths[start:end]

            img_names = [Path(p).name for p in selected_imgs]
            pred_probs = self.predict(selected_imgs)

            tmp_df = self.get_RLEs(img_names, pred_probs)
            df_preds.append(tmp_df)

            torch.cuda.empty_cache()

        df = pd.concat(df_preds, axis=0, ignore_index=True)
        # `fillna(..., inplace=True)` returns None, so return the filled copy.
        return df.fillna("")

First, we need to load a `segmentation_learner` with the right `parameters`.

In [None]:
# Backbone architecture and batch size used to build the learner.
arch = resnet34
bs = 4 
# DataLoaders come from the project helper, fed with `train_multi` and
# `size=(256, 1600)`.
dls = get_segmentation_dls_from_df(train_df=train_multi, bs=bs, size=(256, 1600))
segmentation_learner = unet_learner(dls=dls, arch=arch, metrics=seg_metrics, pretrained=True)
# Point the learner at `models_dir`, presumably so that a later `load`
# resolves the saved weights from there — TODO confirm.
segmentation_learner.model_dir = models_dir

In [None]:
# missing
# Load the saved "ResNet34-Unet-256-stage5" weights into the learner.
segmentation_learner = segmentation_learner.load("ResNet34-Unet-256-stage5")

Next we need a `source` of images to run inference on. The source can be a folder path (`pathlib.Path`), a `DataFrame` with an `ImageId` column, or a single image file name (`str`).

In [None]:
# One-row DataFrame built from the first row of `train`.
df_tmp = train.iloc[0].to_frame().T

# test arguments passed to `Predict.get_predictions` below
size_fold = 100
threshold = 0.5
min_size = 3000

In [None]:
# DataFrame source: image paths are read from its `ImageId` column.
pred = Predict(df_tmp, segmentation_learner)

In [None]:
show_doc(Predict.get_predictions)

<h4 id="Predict.get_predictions" class="doc_header"><code>Predict.get_predictions</code><a href="__main__.py#L80" class="source_link" style="float:right">[source]</a></h4>

> <code>Predict.get_predictions</code>(**`size_fold`**:`int`, **`threshold`**:`float`, **`min_size`**:`int`)



In [None]:
# Predict in folds of at most `size_fold` images; the result holds one row
# per image and class.
df_pred = pred.get_predictions(size_fold, threshold, min_size)

From 0 to 1 of 1


  return np.nanmean(binary_dice_scores)


In [None]:
df_pred

Unnamed: 0,ImageId_ClassId,EncodedPixels
0,0002cc93b.jpg_1,77411 11 77665 17 77918 20 78172 24 78427 26 78681 31 78937 30 79194 32 79229 1 79231 7 79239 1 79241 1 79271 1 79273 1 79275 1 79277 1 79279 1 79281 1 79283 1 79449 33 79485 1 79487 1 79489 9 79499 1 79519 1 79521 1 79523 1 79525 1 79527 1 79529 13 79543 1 79705 53 79759 1 79771 1 79773 31 79962 31 79996 1 79999 14 80015 1 80019 1 80021 1 80023 1 80025 35 80218 31 80255 3 80259 11 80271 48 80474 31 80513 13 80527 1 80529 1 80531 1 80533 43 80730 32 80767 67 80987 30 81024 1 81027 2 81030 3 81034 1 81036 1 81040 1 81043 46 81243 30 81279 15 81295 2 81299 47 81499 29 81538 1 81540 1 81542 1...
1,0002cc93b.jpg_2,
2,0002cc93b.jpg_3,
3,0002cc93b.jpg_4,


In [None]:
# with 10 elements
df_tmp = segmentation_learner.dls.valid.items.iloc[:5]
df_tmp.shape

(5, 2)

In [None]:
# Same predictor, now over the five validation rows selected above.
pred = Predict(df_tmp, segmentation_learner)

In [None]:
# Five images fit in a single fold because `size_fold` is 100.
df_pred = pred.get_predictions(size_fold, threshold, min_size)

From 0 to 5 of 5


  return np.nanmean(binary_dice_scores)


In [None]:
# One row per image and predicted class: 5 images give the 20 rows shown below.
print(df_pred.shape)
df_pred.head()

(20, 2)


Unnamed: 0,ImageId_ClassId,EncodedPixels
0,b5352d213.jpg_1,
1,b5352d213.jpg_2,
2,b5352d213.jpg_3,219065 4 219317 11 219572 13 219828 13 220083 14 220338 17 220594 17 220849 18 221104 20 221360 20 221616 21 221871 22 222126 23 222382 24 222638 24 222893 25 223149 25 223405 26 223661 26 223916 27 224172 27 224427 28 224683 28 224938 29 225194 29 225450 29 225706 29 225961 30 226217 30 226473 30 226729 30 226985 29 227240 30 227496 30 227752 30 228007 31 228263 31 228519 30 228774 31 229030 31 229286 31 229542 30 229798 30 230054 30 230310 29 230566 29 230821 30 231077 29 231333 29 231589 28 231845 28 232101 28 232357 28 232613 27 232870 26 233126 25 233382 25 233638 25 233894 25 234149 ...
3,b5352d213.jpg_4,
4,ecb50399d.jpg_1,


## Pytorch prediction

In [None]:
pass

In [None]:
# hide
# Convert the project notebooks into library modules (see the
# "Converted ..." lines in the output).
from nbdev.export import notebook2script

notebook2script()

Converted 01_core.ipynb.
Converted 02_data.ipynb.
Converted 03_dataloaders.ipynb.
Converted 04_model.metrics.ipynb.
Converted 05_models.unet.ipynb.
Converted 06_models.model.ipynb.
Converted 07_model.predict.ipynb.
Converted index.ipynb.
