In [None]:
#default_exp learner
from nbdev.showdoc import show_doc

In [None]:
#export
from fastai.vision.all import *
from fastcore.foundation import patch
from deepflash2.data import TileDataset
from scipy.stats import entropy
import ttach as tta

# Patches for the `fastai` Learner

> Implements the functions necessary to build a `Learner` suitable for bioimage segmentation.

In [None]:
#export
@patch
def apply_dropout(self:Learner):
    """Switch every dropout layer of `self.model` to training mode.

    Intended for Monte-Carlo dropout: after `self.model.eval()`, calling this
    keeps only the dropout layers stochastic while all other layers stay in
    evaluation mode. All public PyTorch dropout variants are matched
    (`Dropout`, `Dropout1d`, `Dropout2d`, `Dropout3d`, `AlphaDropout`,
    `FeatureAlphaDropout`), not just `nn.Dropout`.
    """
    # `nn.Dropout2d`/`nn.Dropout3d`/`nn.AlphaDropout` are NOT subclasses of
    # `nn.Dropout`, so a plain `isinstance(m, nn.Dropout)` check misses them.
    # `hasattr` guards classes that only exist in newer PyTorch (e.g. Dropout1d).
    dropout_types = tuple(getattr(nn, name) for name in
                          ('Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
                           'AlphaDropout', 'FeatureAlphaDropout')
                          if hasattr(nn, name))
    for m in self.model.modules():
        if isinstance(m, dropout_types):
            m.train()

In [None]:
#export
@patch
def predict_tiles(self:Learner, ds_idx=1, dl=None, mc_dropout=False, n_times=1, use_tta=False, use_max=False):
    """Make predictions and reconstruct tiles, optionally with MC dropout and/or TTA.

    Every batch is predicted `n_times` for each test-time-augmentation (TTA)
    variant. The softmax outputs are de-augmented so they are pixel-aligned, then
    aggregated (mean, or element-wise max if `use_max`). Their spread is reported
    as a standard deviation.

    Args:
        ds_idx: index into `self.dls` used to build a dataloader when `dl` is None.
        dl: dataloader to predict on; must provide `reconstruct_from_tiles`.
        mc_dropout: if True, dropout layers stay active during inference.
        n_times: number of forward passes per augmentation variant.
        use_tta: if True, use horizontal flips and 90/180/270 degree rotations.
        use_max: aggregate with the element-wise max instead of the mean.

    Returns:
        `(smxcores, segmentations, std_deviations)`:
        - `smxcores`: aggregated softmax tiles (channels last) as returned by
          `dl.reconstruct_from_tiles`.
        - `segmentations`: the argmax of each `smxcores` entry over the last axis.
        - `std_deviations`: the reconstructed per-pixel standard deviation.
        If only one prediction is made per tile (no TTA and `n_times=1`), the
        standard deviation is zero rather than NaN.
    """
    if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
    tfms = [tta.HorizontalFlip(), tta.Rotate90(angles=[90,180,270])] if use_tta else []

    self.model.eval()
    if mc_dropout: self.apply_dropout()

    mean_list, std_list = [], []
    for data in progress_bar(dl, leave=False):
        # Batches are either a bare image tensor or a tuple whose first item is the
        # image batch (e.g. image, mask, weights); only the images are needed here.
        images = data if isinstance(data, TensorImage) else data[0]
        out_list = []
        for t in tta.Compose(tfms):
            for _ in range(n_times):
                aug_images = t.augment_image(images)
                with torch.no_grad():
                    # assumes model output is (batch, classes, H, W) -- softmax over classes
                    out = F.softmax(self.model(aug_images), dim=1)
                # Reverse the augmentation so all predictions are pixel-aligned.
                out_list.append(t.deaugment_mask(out))
        out_stack = torch.stack(out_list)
        out_means = torch.max(out_stack, dim=0)[0] if use_max else torch.mean(out_stack, dim=0)
        # The unbiased std of a single sample is NaN; report zero spread instead.
        if len(out_list) > 1:
            out_stds = torch.std(out_stack, dim=0)
        else:
            out_stds = torch.zeros_like(out_means)
        # Move results off the device per batch so accelerator memory stays flat.
        mean_list.append(out_means.cpu())
        std_list.append(out_stds.cpu())

    # Channels-last numpy array per tile, as expected by `reconstruct_from_tiles`.
    smx_tiles = list(torch.cat(mean_list).permute(0,2,3,1).numpy())
    std_tiles = list(torch.cat(std_list).permute(0,2,3,1).numpy())

    smxcores = dl.reconstruct_from_tiles(smx_tiles)
    segmentations = [np.argmax(x, axis=-1) for x in smxcores]
    std_deviations = dl.reconstruct_from_tiles(std_tiles)

    return smxcores, segmentations, std_deviations

## Export -

In [None]:
#hide
from nbdev.export import *
# Convert the project notebooks into library modules; the cell output below
# lists each notebook that was converted.
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 04_callbacks.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted add_information.ipynb.
Converted gt_estimation.ipynb.
Converted index.ipynb.
Converted model_library.ipynb.
Converted predict.ipynb.
Converted train.ipynb.
