In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns

In [None]:
# Load every per-run result file (*.json) in RESULTS_DIR into one DataFrame,
# one row per file; the JSON keys become the columns.
RESULTS_DIR = 'noise_results/data'

rows = []
# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting
# makes the row order of `df` reproducible across machines and re-runs.
for file in sorted(os.listdir(RESULTS_DIR)):
    if file.endswith('.json'):
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(os.path.join(RESULTS_DIR, file), 'r', encoding='utf-8') as f:
            data = json.load(f)
        rows.append(data)

# Fail here with a clear message instead of a confusing KeyError in later cells.
assert rows, f"no .json result files found in {RESULTS_DIR!r}"
df = pd.DataFrame(rows)

In [None]:
# Remap a noise level of exactly 0.0 dB to 100.0 dB in both noise columns.
# NOTE(review): presumably 0.0 is what the experiment runner records when that
# noise source is disabled, and 100 dB places "no noise" at the high-SNR end of
# the heatmap axes — confirm against the script that writes the JSON files.
# Uses .loc instead of chained indexing (df[key][mask] = ...): under pandas
# Copy-on-Write (the default from pandas 3.0) chained assignment never writes
# back to `df`, so the original form silently did nothing. This cell is
# idempotent: re-running it finds no remaining 0.0 values.
for key in ['state_noise_db', 'latent_noise_db']:
    df.loc[df[key] == 0.0, key] = 100.0

In [None]:
# Heatmaps of mean `average_accuracy` over (state noise, latent noise) for the
# models whose name does not contain 'implicit'; one panel per noise mode.
# NOTE(review): the layout assumes a single such model — a second one would be
# drawn over the same panels.
explicit_df = df[~df['model_name'].str.contains('implicit', regex=False, na=False)]

# One column per noise mode actually present (sorted, matching groupby order);
# squeeze=False keeps `axes` 2-D even when there is only one panel.
noise_modes = sorted(explicit_df['noise_mode'].dropna().unique())
n_cols = max(1, len(noise_modes))
fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4), squeeze=False)

for model, group in explicit_df.groupby('model_name'):
    for noise_mode, sub_group in group.groupby('noise_mode'):
        # Address the panel by noise mode, so a model missing a mode does not
        # shift the remaining panels left.
        ax = axes[0, noise_modes.index(noise_mode)]
        # 2-D matrix: rows = state noise, columns = latent noise, cells = mean accuracy.
        pivot_table = sub_group.pivot_table(
            index='state_noise_db',
            columns='latent_noise_db',
            values='average_accuracy',
            aggfunc='mean'
        )
        # Fixed colour scale (0 to 0.6) so panels are directly comparable;
        # values above 0.6 saturate at the top colour.
        sns.heatmap(pivot_table, ax=ax, cmap='YlOrRd', vmin=0, vmax=.6,
                    cbar_kws={'label': 'Average accuracy'})

        ax.set_title(f"{model.replace('hf_models/mamba2-130m-', '')} - {noise_mode}")
        ax.set_xlabel('Latent Noise (dB)')
        ax.set_ylabel('State Noise (dB)')

# Lay out once, after every panel has been drawn.
fig.tight_layout()
plt.show()

In [None]:
# Heatmaps of mean `average_accuracy` over (state noise, latent noise):
# one row per model, one column per noise mode.
# Grid size is derived from the data rather than hard-coded to 2x2, so extra
# models or noise modes no longer raise IndexError; squeeze=False keeps `axes`
# 2-D for any grid size.
noise_modes = sorted(df['noise_mode'].dropna().unique())
n_rows = max(1, df['model_name'].nunique())
n_cols = max(1, len(noise_modes))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)

for i, (model, group) in enumerate(df.groupby('model_name')):
    for noise_mode, sub_group in group.groupby('noise_mode'):
        # Address the column by noise mode, so a model missing a mode does not
        # shift its remaining panels left.
        ax = axes[i, noise_modes.index(noise_mode)]
        # 2-D matrix: rows = state noise, columns = latent noise, cells = mean accuracy.
        pivot_table = sub_group.pivot_table(
            index='state_noise_db',
            columns='latent_noise_db',
            values='average_accuracy',
            aggfunc='mean'
        )
        # Fixed colour scale (0 to 0.6) so panels are directly comparable;
        # values above 0.6 saturate at the top colour.
        sns.heatmap(pivot_table, ax=ax, cmap='YlOrRd', vmin=0, vmax=.6,
                    cbar_kws={'label': 'Average accuracy'})

        ax.set_title(f"{model.replace('hf_models/mamba2-130m-', '')} - {noise_mode}")
        ax.set_xlabel('Latent Noise (dB)')
        ax.set_ylabel('State Noise (dB)')

# Lay out once, after every panel has been drawn.
fig.tight_layout()
plt.show()