# Post-processing

In [1]:
from os import listdir

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from ast import literal_eval
import numpy as np
from pathlib import Path

In [2]:
def load_results(root, filename):
    """Load and stack every CSV called ``filename`` found below ``root``.

    Args:
        root: Directory that is searched recursively.
        filename: Exact file name (glob pattern) to look for.

    Returns:
        One DataFrame holding the rows of all matching files, with exact
        duplicate rows removed per file. The per-file row indices are kept,
        as before. An empty DataFrame is returned when nothing matches.
    """
    # Sort the paths so the row order does not depend on filesystem traversal.
    frames = [
        pd.read_csv(csv_path).drop_duplicates()
        for csv_path in sorted(Path(root).rglob(filename))
    ]
    # pd.concat raises on an empty list; fall back to an empty frame instead.
    return pd.concat(frames) if frames else pd.DataFrame()


# rglob matches the whole file name, so 'evaluation_results.csv' does not
# pick up the 'baseline_evaluation_results.csv' files.
baseline_df = load_results("../results", 'baseline_evaluation_results.csv')
ft_df = load_results("../results", 'evaluation_results.csv')

In [None]:
# Preview the first rows of the loaded baseline results.
baseline_df.head()

In [None]:
# Preview the first rows of the loaded fine-tuning results.
ft_df.head()

In [5]:
# The language columns hold Python list literals stored as text; parse each
# one and flatten it into a single space-separated label.
for frame, column in (
    (ft_df, 'fine-tuned languages'),
    (baseline_df, 'evaluated languages'),
    (ft_df, 'evaluated languages'),
):
    frame[column] = frame[column].apply(lambda raw: " ".join(literal_eval(raw)))

# Baseline rows get the placeholder '-' as their fine-tuned languages.
baseline_df['fine-tuned languages'] = '-'

In [6]:
# Human-readable identifier per row: "<model> ft with <languages>" for the
# fine-tuned results and "ASV trained <model>" for the baselines.
ft_suffix = " ft with " + ft_df['fine-tuned languages']
ft_df['model_id'] = ft_df['model'] + ft_suffix

baseline_prefix = "ASV trained "
baseline_df['model_id'] = baseline_prefix + baseline_df['model']

In [None]:
# List the distinct fine-tuned model identifiers.
ft_df['model_id'].unique()

In [None]:
# Average every numeric column over rows that share the same model,
# identifier and language configuration.
group_keys = ['model', 'model_id', 'fine-tuned languages', 'evaluated languages']

baseline_numeric = baseline_df.select_dtypes(include=[np.number]).columns
baseline_df = baseline_df.groupby(group_keys)[baseline_numeric].mean()

ft_numeric = ft_df.select_dtypes(include=[np.number]).columns
ft_df = ft_df.groupby(group_keys)[ft_numeric].mean()

ft_df.head()

In [None]:
# Preview the averaged baseline metrics.
baseline_df.head()

In [None]:
# Stack the averaged baseline rows on top of the averaged fine-tuned rows.
averaged_frames = [baseline_df, ft_df]
df = pd.concat(averaged_frames)
df

In [11]:
# Turn the group-by index levels back into ordinary columns so they can be
# filtered and pivoted on. (The former identity `apply` only copied the frame.)
df = df.reset_index()

In [None]:
# Wide EER table: one row per model_id, one column per evaluated-language label.
df.pivot(index=['model_id'], columns='evaluated languages', values=['eer'])

In [None]:
def create_heatmap(dataframe, model_name, metric='eer', metric_label="EER"):
    """Plot, save and show a heatmap of one metric for a single model.

    Rows are the ``'fine-tuned languages'`` values, columns are the
    ``'evaluated languages'`` values and each cell holds ``metric``. When a
    row labelled ``'-'`` exists it is outlined with a black frame.

    Args:
        dataframe: Frame with the columns ``'model'``,
            ``'fine-tuned languages'``, ``'evaluated languages'`` and
            ``metric``. Each language pair may occur only once per model,
            otherwise the pivot raises.
        model_name: Value of the ``'model'`` column to plot.
        metric: Name of the column whose values fill the cells.
        metric_label: Text used for the colour-bar label and the title.

    Side effects:
        Writes a PNG below ``../results/plots`` (creating the directory when
        needed) and shows the figure.
    """
    heatmap_data = dataframe.loc[dataframe['model'] == model_name].pivot(
        index='fine-tuned languages',
        columns='evaluated languages',
        values=metric,
    )
    plt.figure(figsize=(10, 8))
    ax = sns.heatmap(
        heatmap_data,
        annot=True,
        cmap="Blues",
        cbar_kws={'label': metric_label},
        annot_kws={"size": 16},
        fmt='.2f'
    )
    plt.xlabel('Evaluated with', fontsize=16)
    plt.ylabel('Fine-Tuned With', fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    if '-' in heatmap_data.index:
        idx = heatmap_data.index.get_loc('-')
        n_cols = len(heatmap_data.columns)

        # Frame around the whole '-' row. The heatmap's y axis is inverted,
        # so row i spans y in [i, i + 1] in data coordinates.
        ax.add_patch(plt.Rectangle((0, idx), n_cols, 1, fill=False, edgecolor='black', lw=4))

        # Reinforce the left and right edges of that same row. The previous
        # coordinates (n_rows - idx - 1) mirrored the row and drew these
        # lines on a different row than the rectangle.
        ax.vlines(x=[0, n_cols], ymin=idx, ymax=idx + 1, colors='black', linewidth=4)

    # Map internal model identifiers to the names shown in the title.
    display_names = {
        'rawgat_st': 'RawGAT ST',
        'w2v_aasist': 'W2V + AASIST',
        'mesonet': 'LFCC + MesoNet',
        'whisper_aasist': 'Whisper + AASIST',
    }
    model_name = model_name.lower()
    for internal_name, display_name in display_names.items():
        model_name = model_name.replace(internal_name, display_name)

    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=16)
    cbar.set_label(metric_label, fontsize=16)
    plt.title(f'{metric_label} for {model_name}', fontsize=18)

    # savefig does not create missing directories.
    plots_dir = Path("../results/plots")
    plots_dir.mkdir(parents=True, exist_ok=True)
    # Keep the historical file name for the default metric, but give every
    # other metric its own file so it no longer overwrites the EER plot.
    metric_suffix = '' if metric == 'eer' else f'_{metric}'
    file_stem = model_name.lower().replace('+', '-')
    plt.savefig(f"../results/plots/{file_stem}{metric_suffix}.png")
    plt.show()

In [None]:
# One EER heatmap per model. The identity `apply` was a no-op copy of the
# frame, so the 'model' column is read directly.
for name in df['model'].unique():
    create_heatmap(df, name)


In [None]:
# One accuracy heatmap per model. The identity `apply` was a no-op copy of
# the frame, so the 'model' column is read directly.
for name in df['model'].unique():
    create_heatmap(df, name, "accuracy", "Accuracy")