In [187]:
import argparse
import os
import time
import subprocess

import numpy
import torch
import torch.nn.functional as F
from torch import sigmoid, softmax, argmax
from torchaudio import load as ta_load

from SSL.util.utils import DotDict
from SSL.util.checkpoint import CheckPoint
from SSL.util.model_loader import load_model
from SSL.util.loaders import load_optimizer, load_preprocesser, load_dataset

# List all saved models

In [188]:
# Directory holding the saved MobileNetV2 checkpoints; print the full path of
# every entry whose file name contains 'aug'.
root = '../model_save/ComParE2021-PRS/supervised/MobileNetV2'
aug_entries = [entry for entry in os.listdir(root) if 'aug' in entry]
for entry in aug_entries:
    print(os.path.join(root, entry))

../model_save/ComParE2021-PRS/supervised/MobileNetV2/MobileNetV2__0.001-lr_0.1-sr_64-bs_1234-seed_True-aug__.last
../model_save/ComParE2021-PRS/supervised/MobileNetV2/MobileNetV2__0.001-lr_1.0-sr_64-bs_1234-seed_True-aug__.last
../model_save/ComParE2021-PRS/supervised/MobileNetV2/MobileNetV2__0.001-lr_1.0-sr_64-bs_1234-seed_train-aug_mixup-max-label-1.0-a_specAugment-12-tdw-1-tsn-4-fdw-1-fsn__.last
../model_save/ComParE2021-PRS/supervised/MobileNetV2/MobileNetV2__0.001-lr_0.1-sr_64-bs_1234-seed_True-aug_mixup-max-label-1.0-a_specAugment-12-tdw-1-tsn-4-fdw-1-fsn__.best
../model_save/ComParE2021-PRS/supervised/MobileNetV2/MobileNetV2__0.001-lr_0.1-sr_64-bs_1234-seed_True-aug__.best
../model_save/ComParE2021-PRS/supervised/MobileNetV2/MobileNetV2__0.001-lr_0.1-sr_64-bs_1234-seed_True-aug_mixup-max-label-1.0-a_specAugment-12-tdw-1-tsn-4-fdw-1-fsn__.last
../model_save/ComParE2021-PRS/supervised/MobileNetV2/MobileNetV2__0.001-lr_1.0-sr_64-bs_1234-seed_train-aug_mixup-max-label-1.0-a_specAugm

In [200]:
# Checkpoint file handed to `checkpoint.load(...)` further down.
# The literal is split in two (directory / file name) purely for readability.
selected_path = (
    '../model_save/ComParE2021-PRS/supervised/MobileNetV2/'
    'MobileNetV2__0.001-lr_1.0-sr_64-bs_1234-seed_True-aug__.best'
)

# Minimum required arguments

In [201]:
# Minimal set of run arguments consumed by the loader helpers below.
args = DotDict(**{
    'dataset': 'ComParE2021-prs',
    'method': 'supervised',
    # The model name must be compatible with the weight file being loaded.
    'model': 'MobileNetV2',
    'nb_class': 5,
})

In [202]:
# Number of output classes for each supported dataset, keyed by the
# lower-cased dataset name (looked up with `args.dataset.lower()`).
nb_class = dict([
    ('ubs8k', 10),
    ('esc10', 10),
    ('esc50', 50),
    ('speechcommand', 35),
    ('compare2021-prs', 5),
    ('audioset-unbalanced', 527),
    ('audioset-balanced', 527),
])

In [203]:
# Build everything needed for evaluation: transforms, model, restored weights
# and the data loaders. Every helper is driven by `args` so that the dataset
# and the training method are selected in a single place.

print('Loading preprocesser ...')
t_transform, v_transform = load_preprocesser(args.dataset, args.method)

print('Loading model ...')
model_func = load_model(args.dataset, args.model)
model = model_func(num_classes=nb_class[args.dataset.lower()])

print('Loading weights ...')
# CheckPoint is constructed with an optimizer, so one is created here even
# though this notebook only runs inference.
# NOTE(review): the learning rate presumably has no effect on evaluation -- confirm.
optimizer = load_optimizer(args.dataset, args.method, model=model, learning_rate=0.003)
checkpoint = CheckPoint(model, optimizer, mode="max", name='./.tmp')
checkpoint.load(selected_path)

print('Loading the dataset ...')
_, train_loader, val_loader = load_dataset(
    args.dataset,
    args.method,  # was a hard-coded "supervised"; kept in sync with the calls above

    dataset_root='../datasets',
    supervised_ratio=1.0,
    batch_size=32,
    train_folds=None,
    val_folds=None,

    train_transform=t_transform,
    val_transform=v_transform,

    # NOTE(review): the original comment stated that with the cache enabled a
    # single worker is faster, yet 4 workers are requested -- confirm which is intended.
    num_workers=4,
    pin_memory=False,

    verbose=1,
)

Loading preprocesser ...
loading dataset: supervised | compare2021-prs
Loading model ...
Loading weights ...
loading dataset: supervised | compare2021-prs
checkpoint initialise at:  /home/lcances/sync/Documents_sync/Projet/semi-supervised/notebooks/.tmp
name:  .tmp
mode:  max
['state_dict', 'optimizer', 'epoch', 'best_metric']
['state_dict', 'optimizer', 'epoch', 'best_metric']
Loading the dataset ...
loading dataset: supervised | compare2021-prs
cache path:  .ComParE2021_PRS/.cache_batch_size=32_seed=1234
split ready, loading cache file
Sort the classes


# Inference

In [204]:
# Run the model over the whole validation loader and collect, batch by batch,
# the raw model outputs and the one-hot encoded targets.
a_logits, a_y = [], []

nb_batch = len(val_loader)
# Loop-invariant: number of classes used to one-hot encode the targets.
num_classes = nb_class[args.dataset.lower()]

start_time = time.time()
print("")

# Switch the model to evaluation mode before inference.
model.eval()

# Gradients are not needed for evaluation.
with torch.no_grad():
    for i, (X, y) in enumerate(val_loader):
        X = X.float()
        y = F.one_hot(y, num_classes=num_classes).float()

        a_logits.append(model(X))
        a_y.append(y)

        # logs: `i` is 0-based, so report i + 1 to finish on "nb_batch / nb_batch"
        print(f'{i + 1} / {nb_batch}', end='\r')




  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore


216 / 217

In [205]:
# Stack the per-batch outputs and targets into one tensor each.
a_logits_, a_y_ = (torch.vstack(batches) for batches in (a_logits, a_y))

# Compute the metrics

In [206]:
from typing import Dict, Tuple
from mlu.metrics import Metric, CategoricalAccuracy, UAR, AveragePrecision

# Metric objects used for the evaluation, reachable as
# `metrics.acc`, `metrics.mAP` and `metrics.uar`.
metrics = DotDict(**{
    'acc': CategoricalAccuracy(),
    'mAP': AveragePrecision(),
    'uar': UAR(),
})

In [207]:
def A(x):
    """Activation applied to the model outputs before computing the metrics.

    Currently the identity: the outputs are passed to the metrics unchanged.
    Alternatives tried previously: torch.sigmoid(x) and torch.softmax(x, dim=1).
    """
    return x

In [208]:
# Report each metric on the stacked outputs (after the activation `A`)
# against the stacked one-hot targets, in the order acc, mAP, UAR.
for label, metric_name in (('acc: ', 'acc'), ('mAP: ', 'mAP'), ('UAR: ', 'uar')):
    print(label, metrics[metric_name](A(a_logits_), a_y_))

acc:  tensor(0.7626)
mAP:  tensor(0.6548, dtype=torch.float64)
UAR:  tensor(0.6008, dtype=torch.float64)
