In [13]:
import os
import pickle
import numpy as np
from tabulate import tabulate
from eval_utils import *

# Folder path containing the .pkl files
folder_path = 'benchmark_eval_results'

# Initialize dictionary to store results across models
all_results = {}

# Function to compute metrics (adjusted for loaded (preds, reals) tuples)
def extract_metrics(preds, reals):
    mse, mae, mape, directional_accuracy = compute_metrics(preds, reals)
    return {
        'MSE': mse,
        'MAE': mae,
        'MAPE': mape,
        'Directional Accuracy': directional_accuracy
    }

# Read all .pkl files in the folder
for file_name in os.listdir(folder_path):
    if file_name.endswith('.pkl'):
        model_name = file_name.replace('.pkl', '')
        file_path = os.path.join(folder_path, file_name)
        
        # Load the data from the .pkl file
        with open(file_path, 'rb') as f:
            results_dict = pickle.load(f)
        
        # Convert (preds, reals) tuples to metric dictionaries
        model_metrics = {}
        for dataset_name, (preds, reals) in results_dict.items():
            metrics = extract_metrics(preds, reals)
            model_metrics[dataset_name] = metrics
        
        all_results[model_name] = model_metrics

In [20]:
def display_metric_table(metric_name, all_results):
    headers = ['Dataset'] + list(all_results.keys())
    headers = [s.replace('_metrics', '') for s in headers]
    datasets = list(next(iter(all_results.values())).keys())
    
    table_data = []
    model_averages = {model: [] for model in all_results.keys()}
    
    # Populate the table with metric values for each dataset
    for dataset in datasets:
        row = [dataset]
        metric_values = []
        for model, metrics in all_results.items():
            metric_value = metrics.get(dataset, {}).get(metric_name, np.nan)
            metric_values.append(metric_value)
            model_averages[model].append(metric_value)
            row.append(f"{metric_value:.3f}" if not np.isnan(metric_value) else "N/A")
        
        # Highlight the best metric value
        best_index = None
        if metric_name in ['MSE', 'MAE', 'MAPE']:
            best_index = np.nanargmin(metric_values) if np.any(~np.isnan(metric_values)) else None
        elif metric_name == 'Directional Accuracy':
            best_index = np.nanargmax(metric_values) if np.any(~np.isnan(metric_values)) else None

        # Apply bold highlighting
        if best_index is not None:
            row[best_index + 1] = f"\033[1m{row[best_index + 1]}\033[0m"
        
        table_data.append(row)
    
    # Add a row for the average of each model
    avg_row = ["Average"]
    avg_values = []
    for model, values in model_averages.items():
        avg_value = np.nanmean(values) if len(values) > 0 else np.nan
        avg_values.append(avg_value)
        avg_row.append(f"{avg_value:.3f}" if not np.isnan(avg_value) else "N/A")
    
    # Determine the best average to highlight in bold
    best_avg_index = None
    if metric_name in ['MSE', 'MAE', 'MAPE']:
        best_avg_index = np.nanargmin(avg_values) if np.any(~np.isnan(avg_values)) else None
    elif metric_name == 'Directional Accuracy':
        best_avg_index = np.nanargmax(avg_values) if np.any(~np.isnan(avg_values)) else None

    # Apply bold to the best average
    if best_avg_index is not None:
        avg_row[best_avg_index + 1] = f"\033[1m{avg_row[best_avg_index + 1]}\033[0m"
    
    table_data.append(avg_row)

    # Print the table with left-aligned columns
    print(f"\nMetric: {metric_name}\n")
    print(tabulate(table_data, headers=headers, tablefmt='grid', colalign=["left"] * len(headers)))

# List of metrics to display
metrics = ['MSE', 'MAE', 'MAPE', 'Directional Accuracy']

# Display tables for each metric
for metric in metrics:
    display_metric_table(metric, all_results)


Metric: MSE

+------------------+-------------------+----------------+--------------------+----------------+---------------+
| Dataset          | moirai-MoE_base   | moirai_large   | moirai-MoE_small   | moirai_small   | moirai_base   |
| DOT              | 350.614           | 83.665         | 36.714             | [1m28.711[0m         | 31.965        |
+------------------+-------------------+----------------+--------------------+----------------+---------------+
| AMZN             | 162.689           | 167.575        | [1m160.878[0m            | 240.617        | 206.464       |
+------------------+-------------------+----------------+--------------------+----------------+---------------+
| Corn             | 1.566             | 1.386          | [1m1.377[0m              | 2.293          | 2.027         |
+------------------+-------------------+----------------+--------------------+----------------+---------------+
| PFE              | 8.317             | 10.288         | [1m7.97