In [42]:
# loading libraries
import yaml
import os
import pandas as pd
import numpy as np
from typing import Callable

In [56]:
# all_parameters.yaml stores an argparse.Namespace under a python/object tag,
# which yaml.safe_load refuses by default. Map that tag to a plain mapping so
# safe_load returns the saved arguments as a dict.
# NOTE: add_constructor is called on yaml.SafeLoader itself, so this affects
# every yaml.safe_load call in the kernel, not just the ones in this notebook.
yaml.SafeLoader.add_constructor(
    tag='tag:yaml.org,2002:python/object:argparse.Namespace',
    constructor=yaml.SafeLoader.construct_mapping,
)

# defining some helper functions
def model_paths(run_path):
    """Return the paths of all model sub-directories inside a run directory.

    Every immediate sub-directory of ``run_path`` is treated as one model,
    except a directory named ``evaluation``, which is skipped.

    Args:
        run_path: Directory containing one sub-directory per model. A trailing
            path separator is optional.

    Returns:
        List of model directory paths, sorted by directory name so that the
        order (and hence the order of any per-model results) is reproducible.
    """
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    labels = sorted(
        name
        for name in os.listdir(run_path)
        if os.path.isdir(os.path.join(run_path, name)) and name != 'evaluation'
    )
    # os.path.join inserts a separator only when run_path lacks one.
    return [os.path.join(run_path, label) for label in labels]



def get_model_data(model_path: str, data_type: str, data_set_selection: list = ["test"]):
    """Load one artefact saved for a single model.

    Args:
        model_path: Directory of one model (e.g. an entry of ``model_paths``).
        data_type: Which artefact to load:
            - ``'metrics'``: ``metrics.csv`` as a DataFrame.
            - ``'predictions'``: rows of ``predictions.csv`` whose ``data_set``
              column is in ``data_set_selection``.
            - ``'model_info'``: ``all_parameters.yaml`` as a dict, with an added
              ``'transform'`` key (``'Spec'`` when ``n_mels`` is None, else
              ``'MelSpec'``).
        data_set_selection: Splits kept for ``'predictions'``. Each must be one
            of ``"train"``, ``"validation"``, ``"test"``. A single split name
            may also be given as a plain string. Ignored for other data types.

    Returns:
        A DataFrame for ``'metrics'`` and ``'predictions'``, a dict for
        ``'model_info'``.

    Raises:
        ValueError: If ``data_type`` or a selected data set is not recognised.
    """
    if data_type == 'metrics':
        return pd.read_csv(f'{model_path}/metrics.csv')

    if data_type == 'predictions':
        # A bare string would otherwise be validated character by character.
        if isinstance(data_set_selection, str):
            data_set_selection = [data_set_selection]

        for data_set in data_set_selection:
            if data_set not in ['train', 'validation', 'test']:
                raise ValueError(f'{data_set} is not a valid data set. Please choose from: ["train", "validation", "test"]')

        data = pd.read_csv(f'{model_path}/predictions.csv')
        return data[data['data_set'].isin(data_set_selection)]

    if data_type == 'model_info':
        # safe_load depends on the argparse.Namespace constructor registered on
        # yaml.SafeLoader earlier in this notebook.
        with open(f'{model_path}/all_parameters.yaml', 'r') as file:
            model_info = yaml.safe_load(file)

        # n_mels of None is labelled 'Spec'; any other value 'MelSpec'.
        model_info['transform'] = 'Spec' if model_info['n_mels'] is None else 'MelSpec'
        return model_info

    raise ValueError(f'{data_type} is not a valid data type. Please choose from: ["metrics", "predictions", "model_info"]')


def get_run_data(run_path: str, data_type: str, data_set_selection: list = ["test"]):
    """Load the same artefact for every model directory of a run.

    Args:
        run_path: Directory containing one sub-directory per model.
        data_type: Passed through to ``get_model_data``
            (``'metrics'``, ``'predictions'`` or ``'model_info'``).
        data_set_selection: Passed through to ``get_model_data``.

    Returns:
        A list with one ``get_model_data`` result per model, in the order
        returned by ``model_paths(run_path)``.
    """
    return [
        get_model_data(path, data_type, data_set_selection)
        for path in model_paths(run_path)
    ]

def calculate_metrics(run_path: str, function: Callable[[pd.Series, pd.Series], float], data_set_selection: list = ["test"]):
    """Apply a metric to the predictions of every model in a run.

    Args:
        run_path: Directory containing one sub-directory per model.
        function: Metric called as ``function(y_true, y_pred)``, where both
            arguments are the ``class_ID`` / ``class_ID_pred`` columns (pandas
            Series) of one model's filtered predictions, e.g. ``accuracy``.
        data_set_selection: Splits whose predictions are scored.

    Returns:
        One metric value per model, in the order of ``model_paths(run_path)``.
    """
    run_data = get_run_data(run_path, 'predictions', data_set_selection)
    return [function(data['class_ID'], data['class_ID_pred']) for data in run_data]

def accuracy(y_true, y_pred):
    """Fraction of positions where ``y_pred`` equals ``y_true``.

    Accepts pandas Series, numpy arrays or plain lists. Values are compared
    position by position (a pandas index is not used for alignment).

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, same length as ``y_true``.

    Returns:
        Accuracy in [0, 1]; NaN for empty input.

    Raises:
        ValueError: If the inputs have different shapes.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Mismatched shapes would otherwise broadcast or compare as a scalar False.
    if y_true.shape != y_pred.shape:
        raise ValueError(f'y_true and y_pred must have the same shape, got {y_true.shape} and {y_pred.shape}')
    if y_true.size == 0:
        return np.nan
    return (y_true == y_pred).mean()

run_path = "../logs/main_run/"
# One accuracy value per model directory, scored on the test split only.
calculate_metrics(run_path, function=accuracy, data_set_selection=["test"])

In [58]:
def calculate_metrics(run_path: str, function: Callable[[pd.Series, pd.Series], float], data_set_selection: list = ["test"]):
    """Apply a metric to the predictions of every model in a run.

    Args:
        run_path: Directory containing one sub-directory per model.
        function: Metric called as ``function(y_true, y_pred)``, where both
            arguments are the ``class_ID`` / ``class_ID_pred`` columns (pandas
            Series) of one model's filtered predictions, e.g. ``accuracy``.
        data_set_selection: Splits whose predictions are scored.

    Returns:
        One metric value per model, in the order of ``model_paths(run_path)``.
    """
    run_data = get_run_data(run_path, 'predictions', data_set_selection)
    return [function(data['class_ID'], data['class_ID_pred']) for data in run_data]

def accuracy(y_true, y_pred):
    """Fraction of positions where ``y_pred`` equals ``y_true``.

    Accepts pandas Series, numpy arrays or plain lists. Values are compared
    position by position (a pandas index is not used for alignment).

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, same length as ``y_true``.

    Returns:
        Accuracy in [0, 1]; NaN for empty input.

    Raises:
        ValueError: If the inputs have different shapes.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Mismatched shapes would otherwise broadcast or compare as a scalar False.
    if y_true.shape != y_pred.shape:
        raise ValueError(f'y_true and y_pred must have the same shape, got {y_true.shape} and {y_pred.shape}')
    if y_true.size == 0:
        return np.nan
    return (y_true == y_pred).mean()

run_path = "../logs/main_run/"
# One accuracy value per model directory, scored on the test split only.
calculate_metrics(run_path, function=accuracy, data_set_selection=["test"])

[0.44594594594594594,
 0.47297297297297297,
 0.35135135135135137,
 0.3783783783783784,
 0.47297297297297297,
 0.3783783783783784,
 0.5405405405405406,
 0.4189189189189189,
 0.5405405405405406,
 0.40540540540540543,
 0.6081081081081081,
 0.5405405405405406,
 0.3918918918918919,
 0.5540540540540541,
 0.4594594594594595,
 0.4189189189189189,
 0.6486486486486487,
 0.3783783783783784,
 0.4594594594594595,
 0.33783783783783783,
 0.5405405405405406,
 0.5675675675675675,
 0.6081081081081081,
 0.3783783783783784,
 0.581081081081081,
 0.4594594594594595,
 0.5675675675675675,
 0.5,
 0.527027027027027,
 0.17567567567567569,
 0.5675675675675675,
 0.5540540540540541,
 0.3783783783783784,
 0.4864864864864865,
 0.28378378378378377,
 0.22972972972972974]

In [53]:
run_path = "../logs/main_run/"
data_set_selection = ["test"]

# Load the filtered predictions of every model, then show the first one.
run_data = get_run_data(run_path, data_type='predictions', data_set_selection=data_set_selection)
run_data[0]

Unnamed: 0.1,Unnamed: 0,file_name,species,class_ID,data_set,original_file_name,data_path,class_ID_pred
6,6,Nemobiussylvestris_MixPre-304.M.wav,Nemobiussylvestris,11,test,MixPre-304.M.wav,./data/Orthoptera,11
10,10,Pseudochorthippusparallelus_VOC_150705-0325.M.wav,Pseudochorthippusparallelus,28,test,VOC_150705-0325.M.wav,./data/Orthoptera,30
11,11,Chorthippusbiguttulus_Take91.wav,Chorthippusbiguttulus,2,test,Take91.wav,./data/Orthoptera,2
14,14,Tettigoniaviridissima_MixPre-198.wav,Tettigoniaviridissima,31,test,MixPre-198.wav,./data/Orthoptera,11
17,17,Pholidopteragriseoaptera_dat022-004.wav,Pholidopteragriseoaptera,13,test,dat022-004.wav,./data/Orthoptera,13
...,...,...,...,...,...,...,...,...
315,172,Myopsaltamelanobasis_Myopsalta_melanobasis_20k...,Myopsaltamelanobasis,9,test,Myopsalta_melanobasis_20km_S_Taroom.wav,./data/Cicadidae,9
323,180,Platypleurasp11cfhirtipennis_MHV%20859%20P.sp1...,Platypleurasp11cfhirtipennis,25,test,MHV%20859%20P.sp11%20Trawal%20%232.wav,./data/Cicadidae,23
324,181,Platypleuraintercapedinis_MHV%20947%20P.cf_.br...,Platypleuraintercapedinis,21,test,MHV%20947%20P.cf_.brunea%20Morningside%20Ranch...,./data/Cicadidae,16
326,183,Platypleurasp13_MHV%201487%20P.sp%2013%20nr%20...,Platypleurasp13,27,test,MHV%201487%20P.sp%2013%20nr%20Thornhill%20%231...,./data/Cicadidae,22


In [25]:
path_best_model = '../logs/main_run/mel064_nblock4_lr0.001_ks5'

# get_model_data returns the filtered predictions as a single DataFrame;
# unpacking it into two names would iterate over its columns and fail.
best_predictions = get_model_data(path_best_model, 'predictions')

# Store the result under a new name so the `accuracy` helper is not rebound
# to a float (later calculate_metrics(..., accuracy) calls depend on it).
best_accuracy = accuracy(best_predictions['class_ID'], best_predictions['class_ID_pred'])

print(f'Accuracy: {best_accuracy}')

Accuracy: 0.6486486486486487


In [40]:
run_path = "../logs/main_run/"

run_data = get_run_data(run_path, data_type = 'predictions', data_set_selection = ["test"])

# Type of the y_true argument that calculate_metrics passes to metric functions.
type(run_data[0]['class_ID'])

pandas.core.series.Series