In [25]:
import ast
import json
import os
import sys

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Must run before the project-local imports below (presumably they live under Lab02_NMT).
sys.path.append('Lab02_NMT')

from neural_network import NN_CATALOG
from dataset.bird_clef import load_wav
from dataset.augmentations import Normalize
from experiment.base import get_fold

In [26]:
# Species list that later cells iterate over when scoring thresholds.
with open('data/scored_birds.json') as scored_birds_file:
    test_birds = json.load(scored_birds_file)

# Per-recording metadata; later cells read filename, primary_label and secondary_labels.
all_meta = pd.read_csv('data/train_metadata_extended.csv')

In [27]:
# secondary_labels comes out of the CSV as text that evaluates to a Python list.
# ast.literal_eval accepts literals only, whereas eval would execute any expression
# found in the file. Values that are already parsed are passed through, so
# re-running this cell does not fail.
all_meta['secondary_labels'] = all_meta.secondary_labels.apply(
    lambda labels: ast.literal_eval(labels) if isinstance(labels, str) else labels
)

# Every label attached to a recording: secondary labels followed by the primary one.
all_meta['target_raw'] = all_meta.secondary_labels + all_meta.primary_label.apply(lambda x: [x])

In [28]:
# Vocabulary: every species that appears in any recording's label list, sorted so
# that the species -> column index mapping is deterministic.
# A set comprehension replaces `target_raw.sum()`, which concatenated the per-row
# lists one at a time (quadratic in the number of rows) and failed on an empty frame.
all_species = sorted({s for labels in all_meta.target_raw for s in labels})
species2id = {s: i for i, s in enumerate(all_species)}
id2species = dict(enumerate(all_species))

# Multi-hot target: one 0/1 entry per species, in all_species order.
all_meta['target'] = all_meta.target_raw.apply(lambda species: [int(s in species) for s in all_species])

In [29]:
class TestDataset(Dataset):
    """Inference dataset that returns each recording as a stack of equal-length chunks.

    Every item is a float32 tensor of shape ``(n_chunks, chunk_len)`` with
    ``chunk_len = int(sr * split_size)``. The end of the recording is zero-padded
    so that the last chunk is full.

    Args:
        data_root: Directory that the ``filename`` column is relative to.
        meta_pd: DataFrame with a ``filename`` column, one row per recording.
        augmentations: Optional callable, invoked as ``augmentations(wav, None)``.
        split_size: Chunk length; multiplied by the ``sr`` value returned by
            ``load_wav`` (presumably the sample rate) to get the chunk size in samples.
    """

    def __init__(self, data_root, meta_pd, augmentations=None, split_size=30):
        super().__init__()
        self.data_root = data_root
        self.meta_pd = meta_pd
        self.fnames = meta_pd.filename.values
        self.augmentations = augmentations
        self.split_size = split_size

    def __len__(self):
        """Number of recordings."""
        return len(self.fnames)

    def _chunk_len(self, sr):
        """Chunk size in samples; the single definition shared by padding and reshaping."""
        return int(sr * self.split_size)

    def pad(self, wav, sr):
        """Zero-pad ``wav`` to a whole number of chunks.

        Returns ``wav`` itself when its length is already a non-zero multiple of the
        chunk size, otherwise a new zero-filled array with ``wav`` copied to the front.
        An empty ``wav`` is padded to one full chunk so that downstream reshaping
        always has at least one row.
        """
        chunk_len = self._chunk_len(sr)
        # Ceiling division, with a floor of one chunk for empty input.
        n_chunks = max(1, -(-len(wav) // chunk_len))
        target_len = n_chunks * chunk_len
        if len(wav) == target_len:
            return wav
        padded = np.zeros(target_len)
        padded[:len(wav)] = wav
        return padded

    def _to_chunks(self, wav, sr):
        """Pad ``wav`` and reshape it into a ``(n_chunks, chunk_len)`` float tensor."""
        padded = self.pad(wav, sr)
        # -1 lets torch derive the chunk count; the column count is the same
        # truncated chunk length that pad() used, so the two can never disagree.
        return torch.tensor(padded).float().reshape(-1, self._chunk_len(sr))

    def __getitem__(self, idx):
        """Load recording ``idx``, apply the augmentations and split it into chunks."""
        fpath = os.path.join(self.data_root, self.fnames[idx])
        # NOTE(review): 0 and 300 look like a start offset and a maximum duration
        # in seconds -- confirm against load_wav.
        wav, sr = load_wav(fpath, 0, 300)
        if self.augmentations:
            wav = self.augmentations(wav, None)
        return self._to_chunks(wav, sr)
    

In [30]:
# Hold out 20% of the recordings for threshold tuning; the seed is fixed so the
# split is the same on every run. The first (training) part is not used here.
# NOTE(review): get_fold is imported but never called in this notebook -- confirm
# that this split matches the one the checkpoint was trained with.
_, val_meta = train_test_split(
    all_meta,
    test_size=0.2,
    random_state=42,
)

In [31]:
# Validation recordings, normalised with probability 1 and split into chunks by TestDataset.
test_dataset = TestDataset('data/train_audio', val_meta, augmentations=Normalize(p=1))

# One recording per batch and no shuffling: the inference loop looks up the
# filename and target of batch i at position i of the dataset.
test_dataloader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False
)

### Model

In [70]:
# Experiment config and the trained weights to evaluate.
config_path = 'Lab02_NMT/configs/baseline_config.yaml'
model_path = 'Lab02_NMT/model_save/baseline_072/final-model.pt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open(config_path) as fin:
    config = yaml.safe_load(fin)
model_config, data_config = config['model'], config['data']

# The weights are restored from model_path below, so the backbone is built with
# pretrained=False (presumably this skips loading pretrained backbone weights).
if 'backbone_config' in model_config['params']:
    model_config['params']['backbone_config']['pretrained'] = False

model_class = NN_CATALOG[model_config['name']]
# Second constructor argument: crop_len / test_wav_len, floored.
# NOTE(review): its meaning is defined by the model class -- confirm there.
len_ratio = int(data_config['crop_len'] // data_config['test_wav_len'])
model = model_class(len(all_species), len_ratio, **model_config['params'])
model.to(device)

model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

AttentionNet(
  (audio2image): Sequential(
    (0): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
    (1): AmplitudeToDB()
  )
  (backbone): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
      )
 

In [71]:
pred_list = []
treshold = 0.1  # NOTE(review): not read anywhere in this cell
split_size = test_dataset.split_size  # NOTE(review): not read anywhere in this cell

model.eval()
with torch.no_grad():
    for row_idx, chunk_batch in enumerate(tqdm(test_dataloader)):
        # batch_size is 1, so index 0 drops the batch axis and leaves the
        # (n_chunks, chunk_len) tensor of a single recording.
        chunk_scores = model(chunk_batch[0].to(device))['logits'].cpu().numpy()
        pred_list.append({
            'filename': test_dataset.fnames[row_idx],
            # Recording-level score per class: maximum along axis 0
            # (presumably one row of scores per chunk -- confirm against the model).
            'pred': chunk_scores.max(axis=0),
            'target': test_dataset.meta_pd.iloc[row_idx].target,
        })
pred_pd = pd.DataFrame(pred_list)

100%|██████████| 2971/2971 [01:21<00:00, 36.43it/s]


In [72]:
# Stack the per-recording rows into matrices: one row per recording, one column per species.
pred_np = np.array(pred_pd.pred.tolist())
target_np = np.array(pred_pd.target.tolist())

In [73]:
def score_pred(true, pred):
    """Balanced accuracy of ``pred`` against ``true``.

    Returns NaN instead of a score when ``true`` sums to zero, i.e. when there
    is no positive example to evaluate against.
    """
    has_positives = true.sum() != 0
    return balanced_accuracy_score(true, pred) if has_positives else np.nan

In [74]:
# Candidate thresholds: 0.01 .. 0.09 in steps of 0.01, then a coarser tail.
trsh = [0.01 * i for i in range(1, 10)] + [0.1, 0.15, 0.2, 0.25]

# One dict per threshold, mapping each scored bird to its balanced accuracy
# when that bird's prediction column is binarised at the threshold.
rows = [
    {
        bird: score_pred(
            target_np[:, species2id[bird]],
            pred_np[:, species2id[bird]] > threshold,
        )
        for bird in test_birds
    }
    for threshold in trsh
]

# Label the rows with their thresholds, then transpose: birds become rows,
# thresholds become columns.
score_stat = pd.DataFrame(rows, index=trsh).T

In [75]:
# Balanced accuracy for each scored bird (rows) at each threshold (columns).
# All-NaN rows are birds with no positive target in the validation rows;
# score_pred returns NaN for them.
score_stat

Unnamed: 0,0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.10,0.15,0.20,0.25
akiapo,0.990226,0.992248,0.993428,0.993933,0.995113,0.995956,0.996124,0.99663,0.996967,0.997135,0.873315,0.748652,0.749326
aniani,0.985497,0.993086,0.827263,0.828612,0.829117,0.829792,0.829961,0.830467,0.830467,0.831141,0.832322,0.749325,0.749494
apapan,0.933666,0.945037,0.951996,0.9359,0.937427,0.940143,0.94201,0.943367,0.944725,0.945743,0.930156,0.891853,0.873211
barpet,0.795597,0.822552,0.828111,0.663129,0.663803,0.664477,0.664814,0.664982,0.665656,0.665824,0.66633,0.666498,0.666498
crehon,,,,,,,,,,,,,
elepai,0.89011,0.902256,0.907148,0.909003,0.911702,0.913727,0.915076,0.915751,0.844997,0.845672,0.849889,0.850564,0.852926
ercfra,0.499495,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5
hawama,0.954868,0.967546,0.936691,0.939564,0.941931,0.905497,0.90685,0.908878,0.909554,0.909892,0.912766,0.914456,0.87684
hawcre,,,,,,,,,,,,,
hawgoo,0.980633,0.7468,0.748484,0.748653,0.749326,0.749495,0.749663,0.499663,0.499663,0.499663,0.499663,0.499663,0.5


Best per-bird thresholds for: use teacher, no maxpool loss

In [42]:
# Threshold (column label) with the highest balanced accuracy for each bird;
# birds whose scores are all NaN come out as NaN.
score_stat.idxmax(axis='columns').to_dict()

{'akiapo': 0.01,
 'aniani': 0.01,
 'apapan': 0.01,
 'barpet': 0.01,
 'crehon': nan,
 'elepai': 0.01,
 'ercfra': 0.05,
 'hawama': 0.01,
 'hawcre': nan,
 'hawgoo': 0.2,
 'hawhaw': 0.05,
 'hawpet1': nan,
 'houfin': 0.05,
 'iiwi': 0.2,
 'jabwar': 0.01,
 'maupar': nan,
 'omao': 0.1,
 'puaioh': nan,
 'skylar': 0.05,
 'warwhe1': 0.01,
 'yefcan': 0.01}

Best per-bird thresholds for: no teacher, use maxpool loss

In [49]:
# Threshold (column label) with the highest balanced accuracy for each bird;
# birds whose scores are all NaN come out as NaN.
score_stat.idxmax(axis='columns').to_dict()

{'akiapo': 0.05,
 'aniani': 0.01,
 'apapan': 0.01,
 'barpet': 0.01,
 'crehon': nan,
 'elepai': 0.01,
 'ercfra': 0.01,
 'hawama': 0.01,
 'hawcre': nan,
 'hawgoo': 0.1,
 'hawhaw': 0.01,
 'hawpet1': nan,
 'houfin': 0.01,
 'iiwi': 0.01,
 'jabwar': 0.01,
 'maupar': nan,
 'omao': 0.01,
 'puaioh': nan,
 'skylar': 0.01,
 'warwhe1': 0.01,
 'yefcan': 0.01}

Best per-bird thresholds for: baseline

In [76]:
# Threshold (column label) with the highest balanced accuracy for each bird;
# birds without any score (all NaN) fall back to 0.05.
score_stat.idxmax(axis='columns').fillna(0.05).to_dict()

{'akiapo': 0.1,
 'aniani': 0.02,
 'apapan': 0.03,
 'barpet': 0.03,
 'crehon': 0.05,
 'elepai': 0.08,
 'ercfra': 0.02,
 'hawama': 0.02,
 'hawcre': 0.05,
 'hawgoo': 0.01,
 'hawhaw': 0.01,
 'hawpet1': 0.05,
 'houfin': 0.05,
 'iiwi': 0.25,
 'jabwar': 0.01,
 'maupar': 0.05,
 'omao': 0.25,
 'puaioh': 0.05,
 'skylar': 0.07,
 'warwhe1': 0.01,
 'yefcan': 0.01}

Best per-bird thresholds for: teacher eff focal

In [69]:
# Threshold (column label) with the highest balanced accuracy for each bird;
# birds without any score (all NaN) fall back to 0.05.
score_stat.idxmax(axis='columns').fillna(0.05).to_dict()

{'akiapo': 0.08,
 'aniani': 0.03,
 'apapan': 0.07,
 'barpet': 0.01,
 'crehon': 0.05,
 'elepai': 0.09,
 'ercfra': 0.09,
 'hawama': 0.06,
 'hawcre': 0.05,
 'hawgoo': 0.25,
 'hawhaw': 0.02,
 'hawpet1': 0.05,
 'houfin': 0.1,
 'iiwi': 0.09,
 'jabwar': 0.05,
 'maupar': 0.05,
 'omao': 0.07,
 'puaioh': 0.05,
 'skylar': 0.1,
 'warwhe1': 0.07,
 'yefcan': 0.05}

In [14]:
# Threshold used for birds that have no score at all (no positive validation example).
DEFAULT_THRESHOLD = 0.05

# Best threshold per bird. All-NaN rows are removed before idxmax, because recent
# pandas versions deprecate (and later reject) idxmax over all-NA rows. They are
# then restored by reindex and given the default threshold.
best_threshold = (
    score_stat.dropna(how='all')
    .idxmax(axis=1)
    .reindex(score_stat.index)
    .fillna(DEFAULT_THRESHOLD)
)

# Write next to the checkpoint that produced these scores. The path was previously
# hard-coded to another run's directory, so a fresh run of the notebook stored
# this model's thresholds under a different model.
trsh_path = os.path.join(os.path.dirname(model_path), 'trsh.json')
with open(trsh_path, 'w') as fout:
    json.dump(best_threshold.to_dict(), fout)