In [None]:
# default_exp inference

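# Inference

Helpers to decode a trained model's predictions on a test set, aggregate them per `video_name` by majority vote, and format them as submission strings.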

In [None]:
#hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#export

import numpy as np
from fastai.basics import *

In [None]:
#export

from scipy.stats import mode

def most_freq(seq):
    """Return the most frequent element of `seq` (its statistical mode).

    Ties are broken by returning the smallest of the tied values. This is the
    same convention `scipy.stats.mode` used, which this helper replaces.

    `scipy.stats.mode(seq)[0][0]` no longer works on current SciPy, for two reasons:
    - `keepdims` now defaults to False, so the second `[0]` indexes a scalar.
    - Non-numeric input is rejected, but this helper is applied to decoded
      string labels.

    Elements must be hashable and mutually comparable.

    Raises:
        ValueError: if `seq` is empty.
    """
    from collections import Counter

    counts = Counter(seq)
    if not counts:
        raise ValueError("most_freq() arg is an empty sequence")
    top = max(counts.values())
    # Smallest value among those sharing the highest count (scipy-style tie-break).
    return min(value for value, count in counts.items() if count == top)

def decode_preds(vocabs, preds):
    """Lazily yield `(distortion_label, severity_label)` pairs.

    `preds[0]` and `preds[1]` are moved to CPU and converted to numpy arrays
    of class indices, which are then looked up in `vocabs[0]` and
    `vocabs[1]` respectively. Pairs come out in row order. Iteration stops
    at the shorter of the two index arrays, as with `zip`.
    """
    distortion_idx = preds[0].cpu().numpy()
    severity_idx = preds[1].cpu().numpy()
    yield from (
        (vocabs[0][d_idx], vocabs[1][s_idx])
        for d_idx, s_idx in zip(distortion_idx, severity_idx)
    )

def fill_preds(dataf, vocabs, preds):
    """Write decoded predictions into `dataf` and return it.

    Two columns, `distortion_preds` and `severity_preds`, are assigned on
    `dataf` itself, so the caller's frame is modified in place. The same
    object is returned. Presumably one decoded pair per row of `dataf` is
    expected — confirm with the caller.
    """
    # Transpose [(d0, s0), (d1, s1), ...] into (d0, d1, ...), (s0, s1, ...).
    columns = list(zip(*decode_preds(vocabs, preds)))
    distortion_col, severity_col = columns
    dataf['distortion_preds'] = distortion_col
    dataf['severity_preds'] = severity_col
    return dataf

def aggregate_preds(dataf):
    """Collapse per-row predictions into one row per `video_name`.

    The result is indexed by `video_name` and contains:
    - `distortion` and `severity`: the value returned by `first` for each
      group (presumably fastcore's `first`, via the fastai star import).
    - `distortion_preds` and `severity_preds`: the group's predictions
      collected into lists.
    - `distortion_inference` and `severity_inference`: the majority vote of
      those lists, computed with `most_freq`.
    """
    aggregations = {
        'distortion': first,
        'distortion_preds': list,
        'severity': first,
        'severity_preds': list,
    }
    per_video = dataf.groupby(by='video_name').agg(aggregations)
    vote_columns = (
        ('distortion_preds', 'distortion_inference'),
        ('severity_preds', 'severity_inference'),
    )
    for preds_col, vote_col in vote_columns:
        per_video[vote_col] = per_video[preds_col].apply(most_freq)
    return per_video

def format_pred(pred):
    """Format one `(distortion, severity)` prediction as a submission string.

    `pred[0]` is an underscore-joined list of distortion codes such as
    ``"D1_D5"``. Each code loses its leading character, and the remainder is
    paired with the severity ``pred[1]`` as ``"<code>_<severity>"``. The
    tokens are sorted as strings and comma-joined, e.g.
    ``("D1_D5", "4") -> "1_4,5_4"``.

    Returns ``''`` when `pred[0]` contains no codes, e.g. an empty string.
    """
    # `str.split` never returns an empty list ("".split("_") == [""]), so
    # empty fragments are filtered out explicitly. This makes the
    # empty-prediction branch reachable and tolerates stray separators.
    distortions = [code for code in pred[0].split("_") if code]
    if not distortions:
        return ''
    severity = pred[1]
    # `code[1:]` rather than `code[1]`, so multi-digit codes ("D10")
    # are not truncated.
    return ','.join(sorted(f"{code[1:]}_{severity}" for code in distortions))

def get_test_inferences(dls, learn, tst_df):
    """Predict on `tst_df` with `learn`'s model and aggregate the results per video.

    Steps:
    1. Build a fresh `Learner` around `learn.model`, reusing its loss function
       and `learn.model.splitter`. NOTE(review): presumably this avoids the
       training learner's callbacks/metrics — confirm.
    2. Call `get_preds(..., with_decoded=True)` on `dls.test_dl(tst_df)`.
    3. Write the decoded predictions into `tst_df` in place (via `fill_preds`).
    4. Return the per-`video_name` frame built by `aggregate_preds`.
    """
    test_loader = dls.test_dl(tst_df)
    model = learn.model
    inference_learner = Learner(dls, model, loss_func=learn.loss_func, splitter=model.splitter)
    # Only the decoded predictions are needed; the probabilities and targets are discarded.
    _, _, decoded = inference_learner.get_preds(dl=test_loader, with_decoded=True)
    with_preds = fill_preds(tst_df, dls.vocab, decoded)
    return aggregate_preds(with_preds)

def make_submission_preds(distortion, severity):
    """Pair `distortion` and `severity` element-wise and format each pair.

    Returns a fastcore `L` of strings produced by `format_pred`. Its length is
    that of the shorter input, as with `zip`.
    """
    formatted = [format_pred(pair) for pair in zip(distortion, severity)]
    return L(formatted)

In [None]:
#hide
# format_pred strips each code's leading letter, pairs it with the severity,
# then sorts and comma-joins the tokens.
test_eq(format_pred(("D1_D5", "4")), '1_4,5_4')
test_eq(format_pred(("D1", "1")), '1_1')