In [1]:
import numpy as np
import pandas as pd
# Print NumPy arrays with 3 digits of floating-point precision.
np.set_printoptions(precision=3)

# Name of the sensitivity results file; a later cell loads it from ../out/.
file = 'sensitivity_10_1e5.csv'

def mad(df):
    """Return the median absolute deviation of ``df``.

    Computed as ``median(|x - median(x)|)`` using the object's own
    ``median`` and ``abs`` methods, so the result has whatever shape
    ``df.median()`` produces (a scalar for a Series, one value per column
    for a DataFrame).
    """
    centre = df.median()
    abs_deviation = (df - centre).abs()
    return abs_deviation.median()

In [2]:
# Load the raw results table; the column names are assigned here in file order.
sens_raw = pd.DataFrame(
    np.loadtxt(f'../out/{file}'),
    columns=['r1', 'r2', 'r3', 't1', 't2', 't3', 'det', 'det_true'],
)

# For every (r1, r2, r3) combination, compute the MAD and the median of `det`
# and `det_true`. The string alias 'median' is used instead of `np.median`:
# pandas >= 2.1 emits a FutureWarning when NumPy reducers are passed to `agg`,
# while the alias yields the same 'median' column label and values.
sens = (
    sens_raw
    .filter(['r1', 'r2', 'r3', 'det', 'det_true'])
    .groupby(['r1', 'r2', 'r3'])
    .agg([mad, 'median'])
    .reset_index()
)

# Relative spread of `det`: MAD divided by the median.
sens['ratio'] = sens[('det', 'mad')] / sens[('det', 'median')]

sens

Unnamed: 0_level_0,r1,r2,r3,det,det,det_true,det_true,ratio
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,mad,median,mad,median,Unnamed: 8_level_1
0,0.1,0.1,0.1,0.0003427852,11.488274,1.664354,4.308576,2.983783e-05
1,0.1,0.1,1.0,0.002445812,9.317828,1.661183,4.302993,0.0002624874
2,0.1,0.1,10.0,0.0005822396,6.845921,1.698494,4.302488,8.504914e-05
3,0.1,1.0,0.1,0.001211174,8.447628,1.574803,4.350253,0.0001433744
4,0.1,1.0,1.0,0.0009926816,7.48898,1.575495,4.351099,0.0001325523
5,0.1,1.0,10.0,0.0009620019,5.699785,1.615442,4.326397,0.0001687786
6,0.1,10.0,0.1,0.0008723427,4.75176,1.575203,4.350253,0.0001835831
7,0.1,10.0,1.0,0.0004801992,4.636697,1.57488,4.350248,0.0001035649
8,0.1,10.0,10.0,0.0009036149,4.045476,1.603087,4.315084,0.0002233643
9,1.0,0.1,0.1,1.332268e-15,7.6113,1.555677,4.326486,1.750381e-16


In [3]:
import plotly.graph_objects as go
from plotly.colors import DEFAULT_PLOTLY_COLORS
from plotly.subplots import make_subplots
import re

# One heatmap panel per value of the `r3` column. The figure labels use the
# symbols r_2, r_3, r_4 for the data columns r1, r2, r3 respectively.
fig = make_subplots(1, 3, subplot_titles=[
    f'$r_4 = {r3}$' for r3 in [0.1, 1, 10]
], horizontal_spacing=0.05, vertical_spacing=0.1)

# Tick labels for the 3x3 grid, in ascending order for both axes.
levels = ['0.1', '1', '10']

for i, (r3, d1) in enumerate(sens.groupby('r3')):
    # After sorting by (r2, r1) and reshaping, rows follow r2 and columns
    # follow r1, both ascending, so z[j][k] pairs with y=levels[j], x=levels[k].
    z = d1.sort_values(['r2', 'r1']).ratio.to_numpy().reshape(3, 3)
    pct = z * 100  # MAD/median ratio expressed as a percentage
    fig.add_trace(go.Heatmap(
        z=pct,
        x=levels,
        y=levels,
        coloraxis='coloraxis',
        text=pct,
        texttemplate='%{text:.2f}%',
        textfont={'size': 8}
    ), row=1, col=i+1)

fig.update_layout(
    coloraxis={'colorscale': 'viridis'},
    height=300,
)
# The labels are numeric-looking strings, which plotly's axis auto-typing
# (layout.autotypenumbers = 'convert types' by default) turns into linear
# axes with unequal cell sizes at 0.1 / 1 / 10. Forcing categorical axes
# gives a uniform grid; categories are laid out in order of appearance, so
# the ascending labels keep 0.1 at the bottom/left without flipping z.
fig.update_xaxes(type='category')
fig.update_yaxes(type='category')
fig.update_xaxes(title='$r_2$', row=1)
fig.update_yaxes(title='$r_3$', col=1)
fig.show()
fig.write_image('../figures/sensitivity_mad_optim.png')