In [1]:
import ast
import json
import os
import sys

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Project modules are imported from Lab02_NMT, so it has to be on sys.path
# before the imports below run.
sys.path.append('Lab02_NMT')

from neural_network import NN_CATALOG
from dataset.bird_clef import load_wav
from dataset.augmentations import Normalize
from experiment.base import get_fold


Bad key text.latex.preview in file /home/cherepaha/miniconda/envs/py37/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle, line 123 ('text.latex.preview : False')
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.5.1/matplotlibrc.template
or from the matplotlib source distribution

Bad key mathtext.fallback_to_cm in file /home/cherepaha/miniconda/envs/py37/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle, line 155 ('mathtext.fallback_to_cm : True  # When True, use symbols from the Computer Modern')
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.5.1/matplotlibrc.template
or from the matplotlib source distribution

Bad key savefig.jpeg_quality in file /home/cherepaha/miniconda/envs/py37/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle, line 418 ('savefig.jpeg_quality: 95       # w

In [2]:
# Extended training metadata table (one CSV read, no post-processing here).
all_meta = pd.read_csv('data/train_metadata_extended.csv')

# Species codes from scored_birds.json; later cells report scores only for these.
with open('data/scored_birds.json') as scored_file:
    test_birds = json.load(scored_file)

In [3]:
# secondary_labels comes out of the CSV as text and is turned into a Python
# object here (presumably the text of a list of species codes -- the next line
# concatenates it with a list).
# ast.literal_eval only accepts literals, so unlike eval it cannot execute
# arbitrary code that might be embedded in the file. The isinstance guard makes
# the cell safe to re-run: values that were already parsed are left as they are.
all_meta['secondary_labels'] = all_meta.secondary_labels.apply(
    lambda raw: ast.literal_eval(raw) if isinstance(raw, str) else raw
)

# Every label attached to a recording: secondary labels followed by the primary one.
all_meta['target_raw'] = all_meta.secondary_labels + all_meta.primary_label.apply(lambda x: [x])

In [4]:
# Sorted vocabulary of every species code found in target_raw.
# Collected with a set comprehension instead of `target_raw.sum()`: summing a
# Series of lists concatenates them one by one (quadratic in the number of
# rows) and returns 0 for an empty frame, which `set()` cannot consume.
all_species = sorted({s for labels in all_meta.target_raw for s in labels})

# Lookup tables between a species code and its position in all_species.
species2id = {s: i for i, s in enumerate(all_species)}
id2species = dict(enumerate(all_species))


def _multi_hot(labels):
    """Return a 0/1 list aligned with all_species marking which codes are in `labels`."""
    present = set(labels)  # set membership instead of scanning the list per species
    return [int(s in present) for s in all_species]


all_meta['target'] = all_meta.target_raw.apply(_multi_hot)

In [5]:
class TestDataset(Dataset):
    """Serve whole recordings as stacks of fixed-length chunks for inference.

    Each item is one file from ``meta_pd.filename``: the waveform is loaded,
    optionally passed through ``augmentations``, zero-padded to a whole number
    of chunks and reshaped to ``(n_chunks, chunk_samples)``. The number of
    chunks differs from file to file.

    Args:
        data_root: directory the file names in ``meta_pd`` are joined onto.
        meta_pd: metadata frame; its ``filename`` column lists the files.
        augmentations: optional callable invoked as ``augmentations(wav, None)``;
            its return value replaces the waveform.
        split_size: chunk length; one chunk holds ``int(sr * split_size)``
            samples (i.e. seconds, assuming ``sr`` is samples per second).
    """

    def __init__(self, data_root, meta_pd, augmentations=None, split_size=30):
        super().__init__()
        self.data_root = data_root
        self.meta_pd = meta_pd
        self.fnames = meta_pd.filename.values
        self.augmentations = augmentations
        self.split_size = split_size

    def __len__(self):
        """Number of files in the dataset."""
        return len(self.fnames)

    def _crop_size(self, sr):
        """Samples per chunk; the single definition used for padding and splitting."""
        return int(sr * self.split_size)

    def pad(self, wav, sr):
        """Zero-pad ``wav`` at the end up to a whole number of chunks.

        A waveform whose length is already a positive multiple of the chunk
        size is returned unchanged. An empty waveform is padded to one full
        chunk of zeros so that the caller always gets at least one chunk.

        Args:
            wav: 1-D sequence of samples.
            sr: sampling rate used to convert ``split_size`` into samples.

        Returns:
            ``wav`` itself, or a new float64 numpy array holding ``wav``
            followed by zeros.
        """
        crop_size = self._crop_size(sr)
        # Ceiling division, but never fewer than one chunk.
        n_chunks = max(1, -(-len(wav) // crop_size))
        padded_len = n_chunks * crop_size
        if padded_len == len(wav):
            return wav
        padded = np.zeros(padded_len)
        padded[:len(wav)] = wav
        return padded

    def __getitem__(self, idx):
        """Load file ``idx`` and return it as a ``(n_chunks, chunk_samples)`` float tensor."""
        fpath = os.path.join(self.data_root, self.fnames[idx])
        # NOTE(review): 0 and 300 look like an offset and a duration limit --
        # confirm against load_wav.
        wav, sr = load_wav(fpath, 0, 300)
        if self.augmentations:
            wav = self.augmentations(wav, None)
        wav = self.pad(wav, sr)
        wav = torch.tensor(wav).float()

        # Split with the same truncated chunk size that pad() used, so the
        # reshape stays valid even when sr * split_size is not an integer.
        return wav.reshape((-1, self._crop_size(sr)))
    

In [6]:
# Keep only the 20% hold-out part of a seeded random split; thresholds are tuned on it.
# NOTE(review): confirm this matches the split used when the model was trained,
# otherwise part of this hold-out set was seen during training.
_, val_meta = train_test_split(
    all_meta,
    test_size=0.2,
    random_state=42,
)

In [7]:
# Dataset over the hold-out recordings (Normalize(p=1) presumably normalizes
# every waveform -- verify against dataset.augmentations).
test_dataset = TestDataset(
    data_root='data/train_audio',
    meta_pd=val_meta,
    augmentations=Normalize(p=1),
)

# One recording per batch and no shuffling: the inference cell looks up the
# file name and target by the batch index, so order and batch size matter.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)

### Model

In [8]:
# Experiment config and the checkpoint to evaluate.
config_path = 'Lab02_NMT/configs/eff_gem.yaml'
model_path = 'Lab02_NMT/model_save/baseline_n262c4mi/final-model.pt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open(config_path) as fin:
    config = yaml.safe_load(fin)

model_config = config['model']
data_config = config['data']

# Turn off the backbone's `pretrained` flag when the config has one; the
# weights are loaded from the checkpoint below instead.
model_params = model_config['params']
if 'backbone_config' in model_params:
    model_params['backbone_config']['pretrained'] = False

# Build the architecture named in the config. Positional arguments: number of
# classes, and crop_len // test_wav_len from the data config.
model_class = NN_CATALOG[model_config['name']]
model = model_class(
    len(all_species),
    int(data_config['crop_len'] // data_config['test_wav_len']),
    **model_params,
)
model.to(device)

# Restore the trained weights onto the selected device.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

# Last expression of the cell: switches to eval mode and displays the model.
model.eval()

GemAttention(
  (audio2image): Sequential(
    (0): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
    (1): AmplitudeToDB()
  )
  (backbone): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
      )
 

In [9]:
pred_list = []
treshold = 0.1
split_size = test_dataset.split_size

# Run the model over every hold-out recording without tracking gradients.
model.eval()
with torch.no_grad():
    progress = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
    for i, batch in progress:
        # batch_size is 1, so batch[0] is the (n_chunks, samples) stack of one recording.
        chunks = batch[0]
        chunk_scores = model(chunks.to(device))['logits'].cpu().numpy()
        # NOTE(review): the output is read from the 'logits' key but is later
        # compared with thresholds in 0.01-0.25 -- confirm it is already a probability.

        pred_list.append({
            'filename': test_dataset.fnames[i],
            # Recording-level score per class: maximum over its chunks.
            'pred': chunk_scores.max(axis=0),
            # The loader is not shuffled, so row i of meta_pd belongs to batch i.
            'target': test_dataset.meta_pd.iloc[i].target,
        })

pred_pd = pd.DataFrame(pred_list)

  normalized, onesided, return_complex)
  normalized, onesided, return_complex)
100%|██████████| 2971/2971 [01:22<00:00, 35.98it/s]


In [10]:
# Stack the per-recording rows into arrays with one row per recording and one
# column per entry of all_species.
pred_np = np.array(pred_pd.pred.tolist())
target_np = np.array(pred_pd.target.tolist())

In [11]:
def score_pred(true, pred):
    """Balanced accuracy of `pred` against `true`, or NaN without positives.

    Args:
        true: ground-truth labels; must provide `.sum()`.
        pred: predicted labels for the same samples.

    Returns:
        `balanced_accuracy_score(true, pred)`, or `np.nan` when `true.sum()`
        is zero (no positive sample to score against).
    """
    has_positive = true.sum() != 0
    return balanced_accuracy_score(true, pred) if has_positive else np.nan

In [12]:
# Candidate thresholds: 0.01 ... 0.09 in steps of 0.01, then a few coarser values.
trsh = [0.01 * i for i in range(1, 10)] + [0.1, 0.15, 0.2, 0.25]

# One dict per threshold, mapping each scored species to its balanced accuracy
# when predictions are binarized with `score > threshold`.
score_rows = []
for t in trsh:
    row = {}
    for b in test_birds:
        col = species2id[b]
        row[b] = score_pred(target_np[:, col], pred_np[:, col] > t)
    score_rows.append(row)

# Transpose so that rows are species and columns are thresholds.
score_stat = pd.DataFrame(score_rows, index=trsh).T

In [13]:
# Balanced accuracy per scored species (rows) at each candidate threshold (columns);
# a row is all NaN when the species has no positive recording in the hold-out set.
score_stat

Unnamed: 0,0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.10,0.15,0.20,0.25
akiapo,0.839569,0.910347,0.943377,0.960061,0.971689,0.982137,0.868765,0.870787,0.872135,0.873315,0.749326,0.749831,0.75
aniani,0.741821,0.856661,0.912142,0.939123,0.958685,0.978752,0.9052,0.825576,0.827094,0.74511,0.748145,0.749157,0.582659
apapan,0.623218,0.752206,0.834012,0.840007,0.870726,0.894657,0.895193,0.888601,0.897597,0.903537,0.892363,0.815757,0.817284
barpet,0.807278,0.885613,0.921664,0.944744,0.958558,0.970687,0.98467,0.991408,0.994104,0.996799,0.665824,0.666161,0.66633
crehon,,,,,,,,,,,,,
elepai,0.753711,0.781304,0.831068,0.863625,0.887411,0.904786,0.912377,0.91457,0.919631,0.921992,0.855793,0.713611,0.571091
ercfra,0.829966,0.898316,0.928283,0.947475,0.957407,0.965993,0.974242,0.981818,0.991077,0.994781,0.49899,0.499832,0.5
hawama,0.615619,0.740196,0.824206,0.878127,0.912272,0.89832,0.881495,0.896877,0.904821,0.91023,0.840576,0.843787,0.652494
hawcre,,,,,,,,,,,,,
hawgoo,0.852476,0.921354,0.946783,0.962108,0.970697,0.978444,0.988885,0.996127,0.748316,0.748653,0.749832,0.5,0.5


use teacher, no maxpool loss

In [42]:
# Per species, the threshold whose column holds the highest score
# (NaN for species whose scores are all NaN).
score_stat.idxmax(axis='columns').to_dict()

{'akiapo': 0.01,
 'aniani': 0.01,
 'apapan': 0.01,
 'barpet': 0.01,
 'crehon': nan,
 'elepai': 0.01,
 'ercfra': 0.05,
 'hawama': 0.01,
 'hawcre': nan,
 'hawgoo': 0.2,
 'hawhaw': 0.05,
 'hawpet1': nan,
 'houfin': 0.05,
 'iiwi': 0.2,
 'jabwar': 0.01,
 'maupar': nan,
 'omao': 0.1,
 'puaioh': nan,
 'skylar': 0.05,
 'warwhe1': 0.01,
 'yefcan': 0.01}

no teacher, use maxpool loss

In [49]:
# Per species, the threshold whose column holds the highest score
# (NaN for species whose scores are all NaN).
score_stat.idxmax(axis='columns').to_dict()

{'akiapo': 0.05,
 'aniani': 0.01,
 'apapan': 0.01,
 'barpet': 0.01,
 'crehon': nan,
 'elepai': 0.01,
 'ercfra': 0.01,
 'hawama': 0.01,
 'hawcre': nan,
 'hawgoo': 0.1,
 'hawhaw': 0.01,
 'hawpet1': nan,
 'houfin': 0.01,
 'iiwi': 0.01,
 'jabwar': 0.01,
 'maupar': nan,
 'omao': 0.01,
 'puaioh': nan,
 'skylar': 0.01,
 'warwhe1': 0.01,
 'yefcan': 0.01}

baseline

In [76]:
# Per species, the threshold whose column holds the highest score;
# species without a best threshold fall back to 0.05.
score_stat.idxmax(axis='columns').fillna(0.05).to_dict()

{'akiapo': 0.1,
 'aniani': 0.02,
 'apapan': 0.03,
 'barpet': 0.03,
 'crehon': 0.05,
 'elepai': 0.08,
 'ercfra': 0.02,
 'hawama': 0.02,
 'hawcre': 0.05,
 'hawgoo': 0.01,
 'hawhaw': 0.01,
 'hawpet1': 0.05,
 'houfin': 0.05,
 'iiwi': 0.25,
 'jabwar': 0.01,
 'maupar': 0.05,
 'omao': 0.25,
 'puaioh': 0.05,
 'skylar': 0.07,
 'warwhe1': 0.01,
 'yefcan': 0.01}

teacher eff focal

In [69]:
# Per species, the threshold whose column holds the highest score;
# species without a best threshold fall back to 0.05.
score_stat.idxmax(axis='columns').fillna(0.05).to_dict()

{'akiapo': 0.08,
 'aniani': 0.03,
 'apapan': 0.07,
 'barpet': 0.01,
 'crehon': 0.05,
 'elepai': 0.09,
 'ercfra': 0.09,
 'hawama': 0.06,
 'hawcre': 0.05,
 'hawgoo': 0.25,
 'hawhaw': 0.02,
 'hawpet1': 0.05,
 'houfin': 0.1,
 'iiwi': 0.09,
 'jabwar': 0.05,
 'maupar': 0.05,
 'omao': 0.07,
 'puaioh': 0.05,
 'skylar': 0.1,
 'warwhe1': 0.07,
 'yefcan': 0.05}

teacher att focal 

In [15]:
# Per species, the threshold whose column holds the highest score;
# species without a best threshold fall back to 0.01.
score_stat.idxmax(axis='columns').fillna(0.01).to_dict()

{'akiapo': 0.08,
 'aniani': 0.05,
 'apapan': 0.06,
 'barpet': 0.06,
 'crehon': 0.01,
 'elepai': 0.05,
 'ercfra': 0.04,
 'hawama': 0.04,
 'hawcre': 0.01,
 'hawgoo': 0.05,
 'hawhaw': 0.01,
 'hawpet1': 0.01,
 'houfin': 0.1,
 'iiwi': 0.08,
 'jabwar': 0.05,
 'maupar': 0.01,
 'omao': 0.09,
 'puaioh': 0.01,
 'skylar': 0.09,
 'warwhe1': 0.05,
 'yefcan': 0.07}

gem attention

In [14]:
# Per species, the threshold whose column holds the highest score;
# species without a best threshold fall back to 0.01.
score_stat.idxmax(axis='columns').fillna(0.01).to_dict()

{'akiapo': 0.06,
 'aniani': 0.06,
 'apapan': 0.1,
 'barpet': 0.1,
 'crehon': 0.01,
 'elepai': 0.1,
 'ercfra': 0.1,
 'hawama': 0.05,
 'hawcre': 0.01,
 'hawgoo': 0.08,
 'hawhaw': 0.04,
 'hawpet1': 0.01,
 'houfin': 0.1,
 'iiwi': 0.2,
 'jabwar': 0.1,
 'maupar': 0.01,
 'omao': 0.06,
 'puaioh': 0.01,
 'skylar': 0.15,
 'warwhe1': 0.09,
 'yefcan': 0.07}

In [14]:
# Threshold stored for species that could not be scored (no positive
# recording in the hold-out set, so every score in their row is NaN).
DEFAULT_THRESHOLD = 0.05

# Best threshold per species. Rows that are entirely NaN are dropped before
# idxmax and re-inserted afterwards, because idxmax over an all-NaN row is
# deprecated in recent pandas releases; the result is the same as before.
best_trsh = (
    score_stat
    .dropna(how='all')
    .idxmax(axis=1)
    .reindex(score_stat.index)
    .fillna(DEFAULT_THRESHOLD)
)

# Write the thresholds next to the checkpoint that was actually evaluated.
# The path used to be hard-coded to a different run directory than model_path,
# so thresholds tuned on one model ended up stored with another.
trsh_path = os.path.join(os.path.dirname(model_path), 'trsh.json')
with open(trsh_path, 'w') as fout:
    json.dump(best_trsh.to_dict(), fout)