In [None]:
# default_exp bayes_inference

# Bayes Inference 

> * adding a measure of uncertainty to predictions


* useful for detecting out of distribution (OOD) samples in your data
* works without an OOD sample dataset
* works with existing trained models
* tradeoff: slower inference due to sampling over distribution
* behind the scenes -> uses the MonteCarlo Dropout Callback (MCDropoutCallback) 
* based on the article : [Bayesian deep learning with Fastai : how not to be uncertain about your uncertainty !](https://towardsdatascience.com/bayesian-deep-learning-with-fastai-how-not-to-be-uncertain-about-your-uncertainty-6a99d1aa686e)
* and on the [github code](https://github.com/dhuynh95/fastai_bayesian) by Daniel Huynh
* updated for fastai v2

In [None]:
#ci
#hide
!pip install -Uqq fastai --upgrade
!pip install -Uqq seaborn

In [None]:
#local
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#export
from fastai.callback.preds import MCDropoutCallback
from fastai.learner import Learner
from fastcore.foundation import patch, L
from fastai.torch_core import to_np

In [None]:
#export
from collections import Counter

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from fastcore.utils import detuplify, tuplify

#### Bayesian Metrics

These metrics are adapted from the fastai_bayesian GitHub code by Daniel Huynh,
modified to use PyTorch tensors instead of NumPy arrays.

In [None]:
#export
def entropy(probs):
    """Return the entropy of the sample-averaged prediction.

    `probs` is a T*N*C tensor with :
        - T : the number of samples
        - N : the batch size
        - C : the number of classes

    The samples are averaged over T, then the entropy (natural log, so in
    nats) is taken over C. The result is a tensor with N entries.
    """
    mean_probs = probs.mean(dim=0)
    # Use the convention 0 * log(0) = 0. Evaluated directly it is 0 * -inf = nan,
    # which would turn the whole entropy of that item into nan.
    plogp = torch.where(mean_probs == 0,
                        torch.zeros_like(mean_probs),
                        mean_probs * torch.log(mean_probs))
    return -plogp.sum(dim=1)

def uncertainty_best_probability(probs):
    """Spread of the sampled probabilities for each item's most likely class.

    The most likely class is the argmax of the predictions averaged over the
    first (sample) dimension; the standard deviation is taken across samples.
    """
    mean_probs = probs.mean(dim=0)
    top_class = mean_probs.argmax(dim=1)
    item_rows = torch.arange(len(top_class))

    # One column per item: that item's sampled values for its top class.
    top_class_samples = probs[:, item_rows, top_class]
    return top_class_samples.std(dim=0)

def BALD(probs):
    """Information Gain, distance between the entropy of averages and average of entropy.

    `probs` is indexed as (sample, item, class): the per-sample entropy sums
    over dim 2 and is then averaged over dim 0. Returns one value per item.
    """
    # Entropy of the sample-averaged prediction.
    entrop1 = entropy(probs)
    # Average of the per-sample entropies. Use 0 * log(0) = 0 so a sample that
    # assigns exactly zero probability to a class does not produce nan.
    plogp = torch.where(probs == 0,
                        torch.zeros_like(probs),
                        probs * torch.log(probs))
    entrop2 = (-plogp.sum(dim=2)).mean(dim=0)

    ig = entrop1 - entrop2
    return ig

def top_k_uncertainty(s, k=5, reverse=True):
    """Return the top k indexes

    Positions of `s` are ranked by their value: largest first when `reverse`
    is True, smallest first otherwise. Each returned index is an element of
    `torch.arange(len(s))`. If `k` exceeds `len(s)` every index is returned.
    """
    sorted_s = sorted(zip(torch.arange(len(s)), s),
                      key=lambda x: x[1], reverse=reverse)
    # Slice rather than index so that k > len(s) does not raise IndexError;
    # max() keeps a negative k yielding an empty list.
    output = [idx for idx, _ in sorted_s[:max(k, 0)]]
    return output
    
def plot_hist_groups(pred,y,metric,bins=None,figsize=(16,16)):
    """Plot the distribution of an uncertainty metric for TP / TN / FP / FN items.

    - pred : sampled predictions; averaged over dim 0, then argmax over dim 1
      gives the predicted class of each item
    - y : true labels, compared against 0 (negative) and 1 (positive)
    - metric : callable applied to `pred`, returning one value per item
      (e.g. `entropy` or `BALD`)
    - bins, figsize : forwarded to `sns.distplot` and `plt.subplots`

    Draws a 2x2 grid of histograms and returns nothing.
    """
    # The predicted class is shared by all four masks, so compute it once.
    pred_class = pred.mean(dim=0).argmax(dim=1)
    correct = pred_class == y

    TP = to_np(correct & (y == 1))
    TN = to_np(correct & (y == 0))
    FP = to_np(~correct & (y == 0))
    FN = to_np(~correct & (y == 1))

    result = metric(pred)

    fig,ax = plt.subplots(2,2,figsize=figsize)

    panels = (
        ("True positive", TP, ax[0,0]),
        ("True negative", TN, ax[0,1]),
        ("False positive", FP, ax[1,0]),
        ("False negative", FN, ax[1,1]),
    )
    # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
    # consider migrating to histplot/displot once the bins handling is checked.
    for title, mask, axis in panels:
        sns.distplot(result[mask],ax=axis,bins=bins)
        axis.set_title(title)

##### Get predictions for a test set
This patches a `bayes_get_preds` method onto `Learner` to make MC Dropout predictions.

In [None]:
#export
@patch
def bayes_get_preds(self:Learner, ds_idx=1, dl=None, n_sample=10,
                    act=None,with_loss=False, **kwargs):
    """Get MC Dropout predictions from a learner and summarise the samples.

    Runs `get_preds` `n_sample` times with an `MCDropoutCallback` attached and
    stacks the predictions.

    - ds_idx, dl, act, with_loss, **kwargs : forwarded to `Learner.get_preds`.
      Callbacks passed as `cbs` are kept; an `MCDropoutCallback` is added only
      if none was supplied.
    - n_sample : number of stochastic forward passes.

    Returns a 6-tuple:
    - preds : stacked predictions, n_sample x n_items x n_classes
    - mean_preds : `preds` averaged over the sample dimension
    - ents : entropy of `mean_preds`, one value per item
    - best_guess : index of the highest mean prediction per item
    - best_prob : value of the highest mean prediction per item
    - best_cat : `self.dls.vocab` entry for each `best_guess`
    """
    # Always keep the caller's callbacks. Detect an existing MC Dropout callback
    # by type (instance or class) rather than by name: fastai's `Callback.name`
    # is the snake-cased class name without "Callback", so comparing against
    # 'MCDropoutCallback' never matches.
    cbs = L(kwargs.pop('cbs', None))
    has_mc_dropout = any(
        isinstance(cb, MCDropoutCallback)
        or (isinstance(cb, type) and issubclass(cb, MCDropoutCallback))
        for cb in cbs)
    if not has_mc_dropout:
        cbs = cbs + [MCDropoutCallback()]

    # `get_preds` returns a longer tuple when `with_loss`/`with_input`/... are
    # set, and puts the inputs first when `with_input` is requested, so pick
    # the predictions by position instead of unpacking a fixed pair.
    pred_idx = 1 if kwargs.get('with_input', False) else 0

    preds = []
    with self.no_bar():
        for _ in range(n_sample):
            res = self.get_preds(ds_idx=ds_idx,dl=dl,act=act,
                                 with_loss=with_loss, cbs=cbs, **kwargs)
            # pred = n_dl x n_vocab
            preds.append(res[pred_idx])
    preds = torch.stack(preds)
    ents = entropy(preds)
    mean_preds = preds.mean(dim=0)
    max_preds = mean_preds.max(dim=1)
    best_guess = max_preds.indices
    best_prob = max_preds.values
    best_cat = L(best_guess,use_list=True).map(lambda o: self.dls.vocab[o.item()])
    return preds, mean_preds, ents,best_guess, best_prob, best_cat

##### Get predictions for an image item

In [None]:
#export
@patch
def bayes_predict(self:Learner,item, rm_type_tfms=None, with_input=False,
                  sample_size=10,reduce=True):
    """Sample MC Dropout predictions for a single item and compute their entropy.

    Runs `get_preds` `sample_size` times on a one-item test dataloader with an
    `MCDropoutCallback` attached.

    Returns a tuple `(targ, dec_pred, pred, ent, best_prob, best_guess, best_cat)`,
    with the decoded input prepended when `with_input` is True:
    - with `reduce=True`, `pred` is the mean prediction, and `targ`/`dec_pred`
      are the single sampled value when all samples agree, else a `Counter`
      of the sampled values
    - with `reduce=False`, they are the full per-sample collections
    - `ent` is the entropy of the mean prediction; `best_guess` is its argmax,
      `best_prob` the corresponding value, `best_cat` the `dls.vocab` entry
    """
    test_dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
    n_inp = getattr(self.dls, 'n_inp', -1)
    mc_cbs = [MCDropoutCallback()]

    # One entry per stochastic forward pass.
    sampled_preds, sampled_targs, sampled_dec_preds = [], [], []
    decoded_input = None
    with self.no_bar():
        for _sample in range(sample_size):
            inp,preds,_,dec_preds = self.get_preds(dl=test_dl, with_input=True,
                                                   with_decoded=True,
                                                   cbs=mc_cbs)
            inp = (inp,) if n_inp==1 else tuplify(inp)
            decoded = self.dls.decode_batch(inp + tuplify(dec_preds))[0]
            dec_inp,dec_targ = map(detuplify, [decoded[:n_inp],decoded[n_inp:]])
            # The decoded input is only kept from the first pass.
            if with_input and decoded_input is None:
                decoded_input = dec_inp
            sampled_targs.append(dec_targ)
            sampled_dec_preds.append(dec_preds[0])
            sampled_preds.append(preds[0])

    dist_preds = torch.stack(sampled_preds)
    dist_dec_preds = L(sampled_dec_preds).map(lambda o: o.item())
    dist_targs = L(sampled_targs)

    mean_pred = dist_preds.mean(dim=0)
    # `entropy` expects a (sample, item, class) tensor, hence the extra item axis.
    ent = entropy(dist_preds.unsqueeze(1)).item()
    best_guess = torch.argmax(mean_pred).item()
    best_prob = mean_pred[best_guess].item()
    best_cat = self.dls.vocab[best_guess]
    summary = (ent, best_prob, best_guess, best_cat)

    if reduce:
        def _collapse(values):
            # Single value when every sample agrees, otherwise the tally.
            distinct = values.unique()
            return Counter(values) if len(distinct) > 1 else distinct[0]

        samples = (_collapse(dist_targs), _collapse(dist_dec_preds), mean_pred)
    else:
        samples = (dist_targs, dist_dec_preds, dist_preds)

    res = samples + summary
    return (decoded_input,) + res if with_input else res
        

##### Add uncertainty threshold to prediction

In [None]:
#export
@patch
def bayes_predict_with_uncertainty(self:Learner, item, rm_type_tfms=None, with_input=False, threshold_entropy=0.2, sample_size=10, reduce=True):
    """Run `bayes_predict` and flag whether its entropy is below a threshold.

    Returns the `bayes_predict` result with one extra leading element: True
    when the entropy is strictly less than `threshold_entropy`.
    """
    prediction = self.bayes_predict(item,
                                    rm_type_tfms=rm_type_tfms,
                                    with_input=with_input,
                                    sample_size=sample_size,
                                    reduce=reduce)
    # The entropy sits one slot later when the decoded input is prepended.
    ent_pos = 4 if with_input else 3
    below_threshold = prediction[ent_pos] < threshold_entropy
    return (below_threshold,) + prediction

### Test Functions

In [None]:
# CI setup: a synthetic learner and dataloaders stand in for the real model,
# so the tests below only check the structure/shape of the results.
from fastai.test_utils import synth_dbunch, synth_learner
try:
    from contextlib import nullcontext # available from python 3.7 onwards
except ImportError as e:
    from contextlib import suppress as nullcontext # fallback for python 3.6 and below
dls = synth_dbunch()
# bayes_get_preds looks up dls.vocab for each best guess; one entry matches CATEGORIES = 1
dls.vocab = [1,]
learner = synth_learner(data=dls)
# replace no_bar (used as a context manager by the bayes_* methods) with a no-op
# NOTE(review): presumably because the synthetic learner has no progress bar to hide - confirm
learner.no_bar = nullcontext
# the names mirror the local setup; here they are just the synthetic train/valid loaders
bears_dl = dls.train
pets_dl = dls.valid
# expected dimensions of the results checked below
N_SAMPLE = 2
CATEGORIES = 1
# NOTE(review): presumably the number of items in the synthetic train loader - confirm
BS = 160

In [None]:
#local
from fastai.learner import load_learner
from fastai.data.transforms import get_image_files
from fastai.data.external import Config
from fastai.vision.core import PILImage
import random
# setup objects using local paths
# (this cell replaces the synthetic `learner`, dataloaders and constants from the CI setup)
cfg = Config()
# NOTE(review): load_learner unpickles the exported file - only load exports you trust
learner = load_learner(cfg.model_path/'bears_classifier'/'export.pkl')
bear_path = cfg.data_path/'bears'
pet_path = cfg.data_path/'pets'
bear_img_files = get_image_files(bear_path)
pet_img_files = get_image_files(pet_path)

# NOTE(review): assumes shuffle() draws from the seeded `random` module - confirm
random.seed(69420) # fix images retrieved
# one image of each kind for single-item predictions
pet_img = PILImage.create(pet_img_files.shuffle()[0])
bear_img = PILImage.create(bear_img_files.shuffle()[0])

# 20 files of each kind for the batch predictions
pet_items = pet_img_files.shuffle()[:20]
bear_items = bear_img_files.shuffle()[:20]

pet_dset = pet_items.map(lambda o: PILImage.create(o))
bear_dset = bear_items.map(lambda o: PILImage.create(o))
pets_dl = learner.dls.test_dl(pet_dset,num_workers=0)

bears_dl = learner.dls.test_dl(bear_dset,num_workers=0)
# xb.shape = torch.size([20,3,224,224])
# expected dimensions of the results checked below
N_SAMPLE = 2
CATEGORIES = 3
BS = 20

In [None]:
from fastcore.test import *

##### Bayes Prediction for Test Set

In [None]:
# N_SAMPLE MC Dropout passes over each dataloader; each result is the
# 6-tuple returned by bayes_get_preds
bear_res = learner.bayes_get_preds(dl=bears_dl, n_sample=N_SAMPLE)
pet_res = learner.bayes_get_preds(dl=pets_dl, n_sample=N_SAMPLE)

In [None]:
# the result has six elements:
# preds, mean_preds, ents, best_guess, best_prob, best_cat
test_eq(len(bear_res),6)
# ci 6
# local 6

In [None]:
# predictions: one stacked set of predictions per MC Dropout sample
test_eq(bear_res[0].shape, [N_SAMPLE,BS,CATEGORIES])
#ci torch.Size([2, 160, 1])
#local torch.Size([2, 20, 3])

In [None]:
# mean predictions: the samples averaged over the sample dimension
test_eq(bear_res[1].shape, [BS, CATEGORIES])
#ci torch.Size([160, 1])
#local torch.Size([20, 3])

In [None]:
# entropy of the mean prediction: one value per item
test_eq(bear_res[2].shape,[BS])
#ci torch.Size([160])
#local torch.Size([20])

In [None]:
# best guess: index of the highest mean prediction, one per item
test_eq(bear_res[3].shape,[BS])
# ci torch.Size([160])
# local torch.Size([20])

In [None]:
# best probability: value of the highest mean prediction, one per item
test_eq(bear_res[4].shape,[BS]) 
#ci torch.Size([160])
#local torch.Size([20])

In [None]:
# best category: vocab entry looked up for each best guess, one per item
test_eq(len(bear_res[5]),BS)
# ci 160
# local 20