In [2]:
# loading libraries
import os
from typing import Callable, Optional

import numpy as np
import pandas as pd
import yaml

In [76]:
# Teach the safe loader to read `argparse.Namespace` objects that were dumped
# with their Python tag: they are constructed as plain mappings instead of
# being rejected as an unknown tag.
yaml.SafeLoader.add_constructor(
    tag='tag:yaml.org,2002:python/object:argparse.Namespace',
    constructor=yaml.SafeLoader.construct_mapping,
)

class RunEval:
    """
    Evaluate a hyper parameter tuning run.

    A run is a folder whose sub-directories each hold the artefacts of one
    model. The sub-directory named ``evaluation`` is not treated as a model.
    The files read from a model folder are ``metrics.csv``,
    ``predictions.csv`` and ``all_parameters.yaml``.
    """

    # Sub-directory of the run folder that is skipped when collecting models.
    _EVALUATION_DIR = 'evaluation'
    # Values accepted in `eval_subset` (matched against the `data_set` column).
    _DATA_SETS = ('train', 'validation', 'test')

    def __init__(
            self,
            run_path: str):
        """
        run_path: Path to the run folder containing the models.
        """
        self.run_path = run_path
        self.model_paths = self.get_model_paths()

    def get_model_paths(self) -> list:
        """
        Return the paths of all model folders inside `self.run_path`.

        Only directories are considered and the `evaluation` directory is
        excluded. The result is sorted by folder name so that the order of
        every per-model list returned by this class is reproducible
        (`os.listdir` gives no ordering guarantee).
        """
        labels = sorted(
            name for name in os.listdir(self.run_path)
            if name != self._EVALUATION_DIR
            and os.path.isdir(os.path.join(self.run_path, name))
        )
        # Use the instance attribute (not a module-level variable) and let
        # os.path.join deal with a missing or present trailing separator.
        return [os.path.join(self.run_path, label) for label in labels]

    @staticmethod
    def _load_metrics(model_path: str) -> pd.DataFrame:
        """Read `metrics.csv` of one model folder."""
        return pd.read_csv(os.path.join(model_path, 'metrics.csv'))

    @staticmethod
    def _load_predictions(model_path: str, eval_subset: list) -> pd.DataFrame:
        """Read `predictions.csv` and keep the rows whose `data_set` is in `eval_subset`."""
        predictions = pd.read_csv(os.path.join(model_path, 'predictions.csv'))
        return predictions[predictions['data_set'].isin(eval_subset)]

    @staticmethod
    def _load_model_info(model_path: str) -> dict:
        """Read `all_parameters.yaml` and add the derived `transform` entry."""
        path = os.path.join(model_path, 'all_parameters.yaml')
        with open(path, 'r', encoding='utf-8') as file:
            model_info = yaml.safe_load(file)

        # No mel bins configured -> plain spectrogram, otherwise mel spectrogram.
        if model_info['n_mels'] is None:
            model_info['transform'] = 'Spec'
        else:
            model_info['transform'] = 'MelSpec'

        return model_info

    def _make_loader(self, data_type: str, eval_subset: Optional[list] = None) -> Callable:
        """
        Validate the arguments and return a function `model_path -> data`.

        Validation happens here, once, before any file is read.
        """
        if data_type == 'metrics':
            return self._load_metrics

        if data_type == 'predictions':
            subset = ['test'] if eval_subset is None else list(eval_subset)
            for data_set in subset:
                if data_set not in self._DATA_SETS:
                    raise ValueError(f'{data_set} is not a valid data set. Please choose from: ["train", "validation", "test"]')
            return lambda model_path: self._load_predictions(model_path, subset)

        if data_type == 'model_info':
            return self._load_model_info

        raise ValueError(f'{data_type} is not a valid data type. Please choose from: ["metrics", "predictions", "model_info"]')

    def get_model_data(self, data_type: str, eval_subset: Optional[list] = None) -> list:
        """
        Load one kind of data for every model of the run.

        data_type: Type of data to be extracted. Options are: ["metrics", "predictions", "model_info"]
        eval_subset: List of data sets to be extracted. Options are: ["train", "validation", "test"].
            Only used for "predictions". Defaults to ["test"].

        Returns a list with one entry per path in `self.model_paths` (same
        order): a DataFrame for "metrics" and "predictions", the parsed YAML
        content for "model_info".

        Raises ValueError for an unknown `data_type` or an unknown entry in
        `eval_subset`.
        """
        load = self._make_loader(data_type, eval_subset)
        return [load(model_path) for model_path in self.model_paths]

    def get_metrics(self, function: Callable[[list, list], list], eval_subset: Optional[list] = None) -> list:
        """
        Apply a metric function to the predictions of every model.

        function: Function to be applied to the metrics. The function should take y_true and y_pred as inputs and return a list of metrics.
            It is called with the `class_ID` and `class_ID_pred` columns of the predictions.
        eval_subset: List of data sets to be extracted. Options are: ["train", "validation", "test"]. Defaults to ["test"].

        Returns a list with the result of `function` for each model.

        Example:
        self.get_metrics(lambda x, y: [np.mean(x == y)], eval_subset=["test"])
        """
        run_data = self.get_model_data('predictions', eval_subset=eval_subset)
        return [function(data['class_ID'], data['class_ID_pred']) for data in run_data]

    def get_keys(self, data_type: str):
        """
        Return the keys (column names for DataFrames) of the first model's data.

        data_type: Options are: ["metrics", "predictions", "model_info"]

        Only the first model is loaded; its keys are taken as representative.

        Raises ValueError for an unknown `data_type` and IndexError when the
        run folder contains no model.
        """
        load = self._make_loader(data_type)
        if not self.model_paths:
            raise IndexError(f'No model folders found in {self.run_path}')

        return load(self.model_paths[0]).keys()
 


In [78]:
# Root folder of the run to inspect (a relative path).
run_path = "../logs/main_run/"

# Collects the model folders found under `run_path` on construction.
run_eval = RunEval(run_path)

In [79]:
# Load metrics.csv of every model in the run; the result is a list with one
# DataFrame per model.
run_eval.get_model_data('metrics')

[       epoch   step  train_acc_epoch  train_acc_step  train_loss_epoch  \
 0          0      0              NaN        0.100000               NaN   
 1          0      1              NaN        0.300000               NaN   
 2          0      2              NaN        0.000000               NaN   
 3          0      3              NaN        0.100000               NaN   
 4          0      4              NaN        0.000000               NaN   
 ...      ...    ...              ...             ...               ...   
 31160   1354  28452              NaN        0.700000               NaN   
 31161   1354  28453              NaN        0.300000               NaN   
 31162   1354  28454              NaN        0.833333               NaN   
 31163   1354  28454              NaN             NaN               NaN   
 31164   1354  28454         0.601942             NaN          0.475445   
 
        train_loss_step   val_acc  val_loss  
 0             0.347015       NaN       NaN  
 1    

In [72]:
# Show the keys (column names) of the first model's metrics table.
# NOTE(review): the ValueError shown below belongs to execution count 72,
# which is lower than the count of the class cell (76) -- presumably it was
# produced by an earlier version of the class; re-run this cell to confirm.
run_eval.get_keys('metrics')

ValueError: Must pass 2-d input. shape=(1, 31165, 8)