In [None]:
import json
import os
from glob import glob

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Load every deep-comparison result file into one long-format DataFrame `df`:
# one row per (results file, top-level model key), tagged with the file's icd_code.
RESULTS_GLOB = 'results/analyses/deep_comparison/*.json'

# sorted() makes row order deterministic; raw glob order is filesystem-dependent.
files = sorted(glob(RESULTS_GLOB))
if not files:
    # Without this, the failure would surface later as a confusing KeyError on 'effect'.
    raise FileNotFoundError(f'no result files match {RESULTS_GLOB!r}')

records = []
for path in files:
    # Filenames must have exactly five '_'-separated fields before the first '.',
    # with the icd_code in the second field (meaning of the other fields not shown here).
    stem = os.path.basename(path).split('.')[0]
    fields = stem.split('_')
    if len(fields) != 5:
        raise ValueError(
            f'unexpected result filename {path!r}: '
            f'expected 5 "_"-separated fields, got {len(fields)}'
        )
    icd_code = fields[1]
    with open(path) as fh:
        data = json.load(fh)
    # Each top-level key names a model; its value is a dict of that model's metrics.
    for model, metrics in data.items():
        records.append({**metrics, 'icd_code': icd_code, 'model': model})
df = pd.DataFrame(records)

# Convert object-dtype columns to float where every value parses as a number;
# genuinely non-numeric columns (e.g. icd_code, model) are left unchanged.
for col in df.columns:
    if df[col].dtype == 'object':
        try:
            df[col] = df[col].astype(float)
        except (TypeError, ValueError):
            pass  # column is not numeric: keep it as object dtype

# Effect divided by its standard error; sem == 0 yields inf/NaN rather than raising.
df['t-score'] = df['effect'] / df['sem']
df

In [None]:
# Shrink matplotlib's default font sizes for ticks, titles, axis labels and legends.
# These are global rcParams, so they apply to every figure created after this cell.
plt.rcParams.update({
    'xtick.labelsize': 'x-small',
    'ytick.labelsize': 'xx-small',
    'axes.titlesize': 'small',
    'axes.labelsize': 'x-small',
    'legend.fontsize': 'x-small',
})

In [None]:
# Figure 1: effect sizes of the ridge / Swin / CNN models (corr == 1, mask == 1),
# one panel per phenotype group, saved to fig1.png.
MODELS = ['ridge_seed_1', 'swin_seed_1', 'cnn_seed_1']

# Legend text per model. The MAE values are hard-coded, not computed from `df`.
# TODO: replace the 'MAE=3.??' placeholder with the real ridge MAE before publishing.
MODEL_LABELS = {
    'ridge_seed_1': 'Ridge (MAE=3.??)',
    'swin_seed_1': 'Swin (MAE=2.67)',
    'cnn_seed_1': 'CNN (MAE=2.66)',
}

# One entry per panel: an ordered {icd_code: tick label} mapping (dict order is the
# left-to-right bar order) and whether to plot |effect| instead of the signed effect.
# Pairing code and label here prevents the tick labels drifting from the bars.
PANELS = [
    ({'F10': 'Alcohol Dependency', 'F31': 'Bipolar Disorder', 'F32': 'Depression'}, False),
    ({'G20': "Parkinson's", 'G40': 'Epilepsy', 'G47': 'Sleep Disorders'}, False),
    ({'fluid-intelligence-custom': 'Fluid Intelligence',
      'stress-bin': 'Severe Stress',
      'socialsupport-bin': 'Social Support'}, True),
]

fig, axes = plt.subplots(1, 3, figsize=(4.5, 1.5))

for ax, (code_labels, use_abs) in zip(axes, PANELS):
    codes = list(code_labels)
    panel_df = df[
        df['model'].isin(MODELS)
        & (df['corr'] == 1)
        & (df['mask'] == 1)
        & df['icd_code'].isin(codes)
    ]
    if use_abs:
        # NOTE(review): this panel compares magnitudes only, presumably because the
        # sign of these traits' effects is not comparable across them; confirm.
        panel_df = panel_df.assign(effect=panel_df['effect'].abs())

    # An explicit order= pins bar positions to `codes`, so the labels below line up.
    sns.barplot(ax=ax, x='icd_code', y='effect', hue='model', data=panel_df,
                order=codes, hue_order=MODELS)
    if ax.get_legend() is not None:
        ax.get_legend().remove()  # a single shared figure legend is added below

    ax.set_xticks(range(len(codes)))
    ax.set_xticklabels([code_labels[code] for code in codes])
    # tick_params stores the rotation on the axis, so it survives tick re-creation.
    ax.tick_params(axis='x', labelrotation=45)
    ax.set_xlabel('')
    ax.set_ylabel('')

axes[0].set_ylabel('effect size')

# Shared legend above the panels. Labels are mapped by model name, not by position.
handles, model_names = axes[0].get_legend_handles_labels()
labels = [MODEL_LABELS.get(name, name) for name in model_names]
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.1),
           ncol=3, frameon=False)

fig.subplots_adjust(wspace=0.4)
# bbox_inches='tight' expands the saved area to include the legend placed above the axes.
fig.savefig('fig1.png', dpi=300, bbox_inches='tight')