In [None]:
# default_exp learner

In [None]:
#export 
from fastai2.vision.all import *
from fastcore.foundation import patch
from deepflash2.data import TileDataset
from scipy.stats import entropy

#export 
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from fastai2.learner import Learner
from fastprogress.fastprogress import progress_bar

# Patches for the `fastai` Learner

> Implements the functions needed to build a `Learner` suitable for bioimage segmentation.

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#export 
@patch
def predict_from_tiles(self:Learner, dl=None, tile_ds:TileDataset=None):
    "Predict and reconstruct images from tile dataset; returns `(softmax_scores, segmentations)`."
    # Fall back to the learner's validation data when nothing is passed in.
    dl = self.dls.valid if dl is None else dl
    tile_ds = self.dls.valid_ds if tile_ds is None else tile_ds

    preds, _ = self.get_preds(dl=dl)
    # Move axis 1 (presumably the class channel) to the end so that the
    # reconstruction and the argmax below operate on a trailing channel axis.
    tile_scores = np.moveaxis(preds.cpu().numpy(), 1, -1)

    full_scores = tile_ds.reconstruct_from_tiles(tile_scores)
    label_maps = [np.argmax(score, axis=-1) for score in full_scores]
    return full_scores, label_maps

In [None]:
#export 
@patch
def apply_dropout(self:Learner):
    "Switch every dropout layer of `self.model` to `.train()` mode, leaving all other layers untouched."
    # `nn.Dropout2d`/`nn.Dropout3d`/`nn.AlphaDropout`/`nn.FeatureAlphaDropout` are NOT
    # subclasses of `nn.Dropout`, so checking only `nn.Dropout` would leave them inactive
    # and silently disable Monte-Carlo dropout for models that use them.
    dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout, nn.FeatureAlphaDropout)
    for m in self.model.modules():
        if isinstance(m, dropout_types): m.train()

In [None]:
#export 
@patch
def predict_tiles_with_mc_dropout(self:Learner, dl=None, tile_ds=None, n_times=20):
    """Make predictions with dropout applied (Monte-Carlo dropout).

    Every batch of `dl` is passed through the model `n_times` with dropout active.
    The softmax outputs (softmax over dim 1) are reduced to a mean and a standard
    deviation across the `n_times` passes. Both are stitched back into full images
    with `tile_ds.reconstruct_from_tiles`.

    Returns `(smxcores, segmentations, std_deviations)`:
    - the reconstructed mean softmax scores;
    - their argmax over the last axis;
    - the reconstructed standard deviations.

    The model's previous train/eval mode is restored afterwards.
    Raises `ValueError` if `n_times < 1`.
    """
    if n_times < 1:
        raise ValueError(f'n_times must be >= 1, got {n_times}')
    if dl is None:
        dl = self.dls.valid
    if tile_ds is None:
        tile_ds = self.dls.valid_ds

    was_training = self.model.training
    # Put every layer in eval mode, then re-enable only the dropout layers.
    self.model.eval()
    self.apply_dropout()

    mean_list, std_list = [], []
    try:
        with torch.no_grad():
            for batch in progress_bar(dl):
                # The inputs are the first element of the batch. Any further elements
                # (targets, weights, ...) are not needed, and indexing instead of
                # unpacking works for batches of any length.
                images = batch[0]
                out_stack = torch.stack([F.softmax(self.model(images), dim=1) for _ in range(n_times)])
                # Move the per-batch statistics to the CPU right away, so that device
                # memory does not grow with the size of the dataset.
                mean_list.append(out_stack.mean(dim=0).cpu())
                # NOTE: torch.std is unbiased by default, so n_times == 1 yields NaN.
                std_list.append(out_stack.std(dim=0).cpu())
    finally:
        # Do not leave dropout switched on: restore the mode the model had before.
        self.model.train(was_training)

    # Move axis 1 (presumably the class channel) to the end before reconstruction.
    softmax_score = np.moveaxis(torch.cat(mean_list).numpy(), 1, -1)
    std_scores = np.moveaxis(torch.cat(std_list).numpy(), 1, -1)

    smxcores = tile_ds.reconstruct_from_tiles(softmax_score)
    segmentations = [np.argmax(x, axis=-1) for x in smxcores]
    std_deviations = tile_ds.reconstruct_from_tiles(std_scores)
    return smxcores, segmentations, std_deviations

In [None]:
###export 
@patch
def get_mc_dropout_results(self:Learner, plot=True, dl=None, tile_ds:TileDataset=None, 
                           max_n=9, n_times=20, figsize=(15,15), **kwargs):
    """Get results with MC Dropout enabled. Plot results is enabled by default.

    Returns `(smxs, segs, std_devs)` as produced by `predict_tiles_with_mc_dropout`.
    When `plot` is true, draws one row per file for at most `max_n` files (all files
    if `max_n` is None). Each row shows the image, the target, the prediction and the
    std map. `kwargs` is accepted for compatibility but currently unused.
    """
    # `self` must be annotated: fastcore's `@patch` attaches the function to the class
    # of the first annotated parameter, which would otherwise be `TileDataset`.
    if dl is None:
        dl = self.dls.valid
    if tile_ds is None:
        tile_ds = self.dls.valid_ds

    smxs, segs, std_devs = self.predict_tiles_with_mc_dropout(dl, tile_ds, n_times)

    if plot:
        imgs = tile_ds.get_images()
        for i, path in enumerate(tile_ds.files):
            if max_n is not None and i >= max_n:
                break
            img = imgs[i]
            # Use the stored label when available for this file, otherwise a placeholder
            # of ones shaped like this image (not like the whole image collection).
            msk = tile_ds.lbl_wgt_pdf[path.name][0] if path.name in tile_ds.lbl_wgt_pdf else np.ones_like(img)
            pred = segs[i]
            # Channel 1 of the std map; assumes at least two classes — TODO confirm the
            # intended channel/metric.
            std_map = std_devs[i][..., 1]
            # `entropy` is applied along its default axis; its outputs are averaged into one score.
            entr = entropy(std_map).mean()

            fig, axs = plt.subplots(nrows=1, ncols=4, figsize=figsize)
            panels = [
                (img, 'binary_r', 'Image {}'.format(path.name)),
                (msk, 'binary_r', 'Target'),
                (pred, 'binary_r', 'Prediction'),
                (std_map, None, 'Std ({} Entropy)'.format(np.round(entr, 2))),
            ]
            for ax, (data, cmap, title) in zip(axs, panels):
                ax.imshow(data, cmap=cmap)
                ax.set_axis_off()
                ax.set_title(title)
            plt.show()

    return smxs, segs, std_devs

In [None]:

def test_results(self:Learner, ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs):
    "Predict one batch and display it with `self.dls.show_results` (default batch: from `self.dls[ds_idx]`, reshuffled)."
    # NOTE(review): this function is not decorated with `@patch`, so it is not a Learner
    # method; call it as `test_results(learn, ...)`.
    if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
    b = dl.one_batch()
    # With `with_decoded=True`, `get_preds` yields a third element; only that one is displayed.
    _, _, decoded = self.get_preds(dl=[b], with_decoded=True)
    self.dls.show_results(b, decoded, max_n=max_n, **kwargs)

In [None]:
@patch
def siampredict(self:Learner, item, rm_type_tfms=None, with_input=False):
    "Predict on a Siamese `item`, show its two images titled with the predicted similarity, and return the `predict` result."
    # Forward the caller's options; they were previously hard-coded to the defaults.
    res = self.predict(item, rm_type_tfms=rm_type_tfms, with_input=with_input)
    # fastai's `Learner.predict` puts the decoded input first when `with_input=True`,
    # which shifts the decoded prediction to index 1.
    pred = res[1] if with_input else res[0]
    title = 'Prediction: Not similar' if pred == tensor(0) else 'Prediction: Similar'
    SiameseImage(item[0], item[1], title).show()
    return res

In [None]:
@typedispatch
def show_results(x:ImgMskTuple, ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs):
    # FIXME(review): this body duplicates `test_results` but has no `self` parameter, so
    # `self` resolves to a global name. That normally raises NameError when called.
    # NOTE(review): `ImgMskTuple` is not defined by the imports visible in this notebook
    # view; confirm it is in scope, or the annotation fails when the cell runs.
    # NOTE(review): fastcore's `typedispatch` keys on the annotations of the first two
    # parameters. The later `show_results(x:ImgMskTuple, y, ...)` presumably has the same
    # key and replaces this registration. Confirm, and consider removing this cell.
    if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
    b = dl.one_batch()
    _, _, preds  = self.get_preds(dl=[b], with_decoded=True)
    print(preds.shape)
    print(b)
    self.dls.show_results(b, preds, max_n=max_n, **kwargs)

In [None]:

@typedispatch
def show_results(x:ImgMskTuple, y, samples, outs, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    "Show up to `max_n` items of batch `x` on a grid, titled with actual vs. predicted similarity; return the axes used."
    # NOTE(review): the body indexes x[0], x[1], x[2] and y[2] and draws `SiameseImage`s,
    # i.e. it treats the batch like a Siamese (img1, img2, label) tuple. Confirm that
    # this matches the layout of `ImgMskTuple`.
    if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
    # Forward `nrows`; it used to be replaced by None, silently ignoring the argument.
    if ctxs is None: ctxs = get_grid(min(x[0].shape[0], max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    labels = ["Not similar", "Similar"]
    for i, ctx in enumerate(ctxs):
        title = f'Actual: {labels[x[2][i].item()]} \n Prediction: {labels[y[2][i].item()]}'
        SiameseImage(x[0][i], x[1][i], title).show(ctx=ctx)
    return ctxs

## Export 

In [None]:
#hide
from nbdev.export import *
# Export the library code from the project's notebooks; each converted notebook is listed in the output below.
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 04_callbacks.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted index.ipynb.
