In [5]:
import numpy as np
import pandas as pd
# Raw sensitivity-sweep output analysed in this notebook (read from ../out/).
file = 'sensitivity_10_1e5.csv'

# Show numpy arrays with 3 decimal places when printed.
np.set_printoptions(precision=3)

def mad(df):
    """Median absolute deviation: median(|x - median(x)|).

    Accepts a Series (returns a scalar) or a DataFrame (returns one value per
    column). No normal-consistency scale factor is applied. The function's
    name is used as the column label when it is passed to ``groupby().agg``.
    """
    centre = df.median()
    deviations = (df - centre).abs()
    return deviations.median()

In [13]:
# Load the raw sweep: one row per run with parameters r1..r3, t1..t3, the
# computed `det` and the reference `det_true`. The t* columns are not used here.
# NOTE(review): np.loadtxt splits on whitespace by default, although the file
# is named .csv — confirm the on-disk format if loading ever fails.
SENS_COLUMNS = ['r1', 'r2', 'r3', 't1', 't2', 't3', 'det', 'det_true']
sens_raw = (
    pd.DataFrame(np.loadtxt(f'../out/{file}'), columns=SENS_COLUMNS)
    .filter(['r1', 'r2', 'r3', 'det', 'det_true'])
)

# Robust spread (MAD) and centre (median) of each metric per (r1, r2, r3).
# 'median' is passed as a string rather than np.median: pandas >= 2.1 warns
# that numpy callables given to agg will change behaviour in the future.
# The resulting column labels are ('det', 'mad'), ('det', 'median'), etc.
sens = sens_raw.groupby(['r1', 'r2', 'r3']).agg([mad, 'median']).reset_index()

# Relative spread of `det`: its MAD as a fraction of its median.
sens['ratio'] = sens[('det', 'mad')] / sens[('det', 'median')]

sens

Unnamed: 0_level_0,r1,r2,r3,det,det,det_true,det_true,ratio
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,mad,median,mad,median,Unnamed: 8_level_1
0,0.1,0.1,0.1,0.022266,14.31259,1.750557,2.518841,0.001556
1,0.1,0.1,1.0,0.002636,10.469039,1.760188,2.509755,0.000252
2,0.1,0.1,10.0,0.013415,5.379578,1.727879,2.494383,0.002494
3,0.1,1.0,0.1,0.001512,9.0707,1.810319,2.587702,0.000167
4,0.1,1.0,1.0,0.001117,7.873545,1.805144,2.59023,0.000142
5,0.1,1.0,10.0,0.010037,4.199752,1.729278,2.499147,0.00239
6,0.1,10.0,0.1,0.009294,2.570576,1.767242,2.553254,0.003616
7,0.1,10.0,1.0,0.002669,2.54914,1.777965,2.548916,0.001047
8,0.1,10.0,10.0,0.005267,2.173498,1.769052,2.533557,0.002423
9,1.0,0.1,0.1,0.000403,8.115368,1.729485,2.533428,5e-05


In [20]:
import plotly.graph_objects as go
from plotly.colors import DEFAULT_PLOTLY_COLORS
from plotly.subplots import make_subplots
import re

# One heatmap panel per r3 value showing `ratio` (relative MAD of det) in
# percent, with r1 along x and r2 along y. Panel titles, grid shape and tick
# labels are derived from the data so they always match the plotted traces.
# NOTE(review): the labels r_2 / r_3 / r_4 are offset by one from the data
# columns r1 / r2 / r3 — presumably to match the paper's notation; confirm.
r1_vals = np.sort(sens['r1'].unique())
r2_vals = np.sort(sens['r2'].unique())
r3_groups = list(sens.groupby('r3'))  # groupby sorts keys ascending

fig = make_subplots(
    1, len(r3_groups),
    subplot_titles=[f'$r_4 = {r3:g}$' for r3, _ in r3_groups],
    horizontal_spacing=0.05,
)

x_labels = [f'{v:g}' for v in r1_vals]
# Rows are flipped below (largest r2 first), so the y labels are reversed too.
y_labels = [f'{v:g}' for v in r2_vals[::-1]]

for i, (r3, d1) in enumerate(r3_groups):
    # Sorting by (r2, r1) makes r1 vary fastest: after the reshape, rows index
    # r2 and columns index r1. reshape raises if the grid is incomplete.
    z = d1.sort_values(['r2', 'r1'])['ratio'].to_numpy().reshape(len(r2_vals), len(r1_vals))
    z_pct = np.flipud(z) * 100
    fig.add_trace(go.Heatmap(
        z=z_pct,
        x=x_labels,
        y=y_labels,
        coloraxis='coloraxis',
        text=z_pct,
        texttemplate='%{text:.2f}%',
        textfont={'size': 8},
    ), row=1, col=i + 1)

fig.update_layout(
    coloraxis={'colorscale': 'viridis'},
    height=300,
)
fig.update_xaxes(title='$r_2$', row=1)
fig.update_yaxes(title='$r_3$', col=1)
fig.show()
fig.write_image('../figures/sensitivity_mad_optim.png')