# Inference pipeline for DHoa's first-place model

Input: an audio file in OGG format.

Output: the predicted music genre.

The first-place solution is briefly [described here](https://www.kaggle.com/competitions/kaggle-pog-series-s01e02/discussion/321281).

In [1]:
! pip install -Uqq huggingface_hub fastai kornia==0.5.8

In [1]:
from huggingface_hub import hf_hub_download
from fastai.vision.all import *
from fastai.learner import load_learner

import kornia

In [3]:
class ReflectionCrop(RandomCrop):
    "`RandomCrop` variant whose `crop_pad` fills any padded border by reflection (`PadMode.Reflection`)."
    def encodes(self, x:(Image.Image,TensorBBox,TensorPoint)):
        # `size`, `tl` and `orig_sz` are presumably prepared by `RandomCrop` before each call.
        crop_size, top_left = self.size, self.tl
        cropped = x.crop_pad(crop_size, top_left, orig_sz=self.orig_sz, pad_mode=PadMode.Reflection)
        return cropped

In [4]:
def get_y(filename):
    """Return the genre label for `filename` by looking up `<stem>.ogg` in `df_train`.

    Only the stem of `filename` is used, so a file with any extension maps back to
    the row whose `filename` column is `<stem>.ogg` (presumably the original audio
    before resampling). Raises IndexError if no row matches.
    """
    ogg_name = f"{filename.stem}.ogg"
    matches = df_train.loc[df_train['filename'] == ogg_name, 'genre']
    return matches.values[0]

In [5]:
class CustomDataBlock(DataBlock):
    """`DataBlock` whose dataset splits can be supplied explicitly.

    `datasets` and `dataloaders` accept a precomputed `splits` (one index collection
    per dataset) instead of relying on the block's splitter. When `splits` is omitted,
    the block falls back to `self.splitter`, as fastai's `DataBlock.datasets` does.
    """
    def datasets(self:DataBlock, source, verbose=False, splits=None):
        "Create `Datasets` from `source`, using `splits` if given, else `self.splitter`."
        self.source = source                     ; pv(f"Collecting items from {source}", verbose)
        items = (self.get_items or noop)(source) ; pv(f"Found {len(items)} items", verbose)
        # The size message below is formatted eagerly (even when verbose=False), so
        # `splits` must be resolved before it; otherwise `len(None)` raises TypeError.
        if splits is None: splits = (self.splitter or RandomSplitter())(items)
        pv(f"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}", verbose)
        return Datasets(items, tfms=self._combine_type_tfms(), splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)
    def dataloaders(self, source, path='.', verbose=False, splits=None, **kwargs):
        "Build `DataLoaders` from `source`, forwarding `splits` to `datasets`."
        dsets = self.datasets(source, verbose=verbose, splits=splits)
        # Call-site kwargs override the block's `dls_kwargs`; `verbose` is always forwarded.
        kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}
        return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)

In [6]:
def convert_MP_to_blurMP(model, layer_type_old, layer_factory=None):
    """Recursively replace every submodule whose exact type is `layer_type_old`.

    Each match is swapped in place for a fresh layer. By default this is
    `kornia.contrib.MaxBlurPool2d(3, True)`; otherwise it is `layer_factory()`.
    Matching uses the exact type, so subclasses of `layer_type_old` are left alone.

    Args:
        model: module whose `_modules` tree is rewritten in place.
        layer_type_old: class to replace (presumably `nn.MaxPool2d` in this notebook).
        layer_factory: optional zero-argument callable that builds the replacement
            layer. It is called once per match.

    Returns:
        `model` itself (mutated in place), so the call can be used in an expression.
    """
    if layer_factory is None:
        def layer_factory():
            return kornia.contrib.MaxBlurPool2d(3, True)

    for name, module in reversed(model._modules.items()):
        # torch allows `None` placeholders in `_modules`; there is nothing to walk or replace.
        if module is None:
            continue
        if next(module.children(), None) is not None:
            # Container module: rewrite its children (mutated in place, then reassigned).
            model._modules[name] = convert_MP_to_blurMP(module, layer_type_old, layer_factory)
        if type(module) is layer_type_old:
            model._modules[name] = layer_factory()

    return model

In [7]:
# Nine exported learner files on the Hugging Face Hub; `learns` keeps this order.
filenames = [f"learn_export_{idx}.pkl" for idx in range(7)] + [
    f"learn_export_101_{idx}.pkl" for idx in range(2)
]

# Fetch each exported file from the Hub repo and load it as a fastai Learner.
learns = [
    load_learner(hf_hub_download(repo_id="kurianbenoy/inference-music-genre-dhoa", filename=name))
    for name in filenames
]

In [8]:
# Sanity check: expect one loaded learner per entry in `filenames`.
len(learns)

9

In [9]:
_before_epoch = [event.before_fit, event.before_epoch]
_after_epoch  = [event.after_epoch, event.after_fit]

@patch
def ttacustom(self:Learner, ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None, beta=0.25, use_max=False):
    """Return per-pass predicted class indices on the `ds_idx` dataset or `dl` using Test Time Augmentation.

    Adapted from fastai's `Learner.tta`. The loader's dataset is switched to split
    index 0 (`set_split_idx(0)`), and `get_preds` is run `n` times. Each pass's
    predictions are reduced to their argmax class index instead of being averaged.

    Returns a list of `n` tensors, one per augmented pass, each holding one class
    index per item yielded by `dl`. `beta` and `use_max` are accepted for signature
    compatibility with `Learner.tta` but are ignored.
    """
    if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
    if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)
    try:
        # Mirror the callback events that `Learner.tta` fires around its prediction passes.
        self(_before_epoch)
        with dl.dataset.set_split_idx(0), self.no_mbar():
            if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))
            aug_preds = []
            for i in self.progress.mbar if hasattr(self,'progress') else range(n):
                self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch
                probs = self.get_preds(dl=dl, inner=True)[0]
                # Assumes `probs` is (n_items, n_classes) — TODO confirm. Taking argmax directly
                # (instead of `probs[None].squeeze().argmax(1)`) keeps a single-item loader
                # working: squeezing (1, 1, C) collapses it to (C,), where argmax(1) fails.
                aug_preds.append(probs.argmax(dim=1))
    finally: self(event.after_fit)
    return aug_preds