In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path
from collections import defaultdict

In [2]:
# Root directory of the experiment; the log and evaluation output paths below are built from it.
wrkdir = '/pubhome/xli02/project/PLIM/deep_learning/FAST/fast_plim/test_asign_charge'
# Maps a result label to the list of metrics parsed for it (filled while reading the logs).
label_to_res_dict = defaultdict(list)

In [3]:
# Training-set variants; each one is a sub-directory holding its own logs.
log_dir_names = [
    'PDBbind_intersected_Uw',
    'PLANet_Uw',
]

# Model flavours evaluated for every training-set variant.
mdl_types = [
    'complex_6A',
    'lig_alone',
]

# Evaluation sets, listed in the order in which their performance lines
# appear inside a log file. Update this list if the log order changes.
test_types = [
    'valid',
    'train',
    'test',
    'core',
    'core_intersected_Uw',
]

In [4]:
# Walk every (dataset, model type) log directory and pull the five metrics out
# of each 'Performance on test set' line. Results are stored under keys of the
# form '<log_dir_name>_<m_type>_<run number>_<test type>'.
for log_dir_name in log_dir_names:
    for mdl_type in mdl_types:
        # Short tag used inside the result labels.
        m_type = 'cpx' if mdl_type == 'complex_6A' else 'lig_alone'
        log_dir = f'{wrkdir}/3-test/scripts/rm_core_ids/{log_dir_name}/{mdl_type}/log/'

        # Sorted so that the run number assigned to a log is stable between executions.
        log_files = sorted(str(p) for p in Path(log_dir).glob('*log'))
        for run_no, log_f in enumerate(log_files, start=1):
            with open(log_f, 'r') as f:
                lines = f.readlines()

            # Counts the performance lines seen so far in this log; the n-th
            # such line is attributed to test_types[n].
            seen = 0
            for line in lines:
                if 'Performance on test set' not in line:
                    continue
                fields = line.split(',')
                R2 = float(fields[0].split(':')[2])
                mae = float(fields[1].split(':')[1])
                mse = float(fields[2].split(':')[1])
                pearsonr = float(fields[3].split('(')[1])
                spearmanr = float(fields[5].split('=')[1])
                label = f'{log_dir_name}_{m_type}_{run_no}_{test_types[seen]}'
                label_to_res_dict[label] = [R2, mae, mse, pearsonr, spearmanr]
                seen += 1

In [9]:
# One row per '<model>_<test type>' label, one column per parsed metric.
# The former index (the label) becomes the 'model_names_test_type' column.
sum_df = (
    pd.DataFrame.from_dict(
        label_to_res_dict,
        orient='index',
        columns=['R2', 'mae', 'mse', 'pearsonr', 'spearmanr'],
    )
    .reset_index()
    .rename(columns={"index": "model_names_test_type"})
)
sum_df

Unnamed: 0,model_names_test_type,R2,mae,mse,pearsonr,spearmanr
0,PDBbind_intersected_Uw_cpx_1_valid,0.416541,0.983162,1.548035,0.648518,0.573573
1,PDBbind_intersected_Uw_cpx_1_train,0.477236,0.935431,1.401182,0.701890,0.648219
2,PDBbind_intersected_Uw_cpx_1_test,0.411513,0.979637,1.542658,0.648389,0.572081
3,PDBbind_intersected_Uw_cpx_1_core,0.464997,1.277336,2.520483,0.696768,0.685162
4,PDBbind_intersected_Uw_cpx_1_core_intersected_Uw,0.526759,1.177720,2.054616,0.797355,0.771584
...,...,...,...,...,...,...
95,PLANet_Uw_lig_alone_5_valid,0.573174,0.691384,0.819810,0.763093,0.745674
96,PLANet_Uw_lig_alone_5_train,0.796283,0.482763,0.395801,0.892445,0.881875
97,PLANet_Uw_lig_alone_5_test,0.572112,0.702247,0.856615,0.761756,0.745869
98,PLANet_Uw_lig_alone_5_core,0.377481,1.408895,2.932783,0.628508,0.627304


In [12]:
# Split each '<dataset>_<model tag>_<run>_<test type>' label into its parts.
# 'core_intersected_Uw' itself contains underscores, so it is handled
# separately from the single-token test types.
model_names, test_kinds = [], []
for label in sum_df['model_names_test_type']:
    tokens = label.split('_')
    if 'core_intersected_Uw' in label:
        model_names.append('_'.join(tokens[:-3]))
        test_kinds.append('core_intersected_Uw')
    else:
        model_names.append('_'.join(tokens[:-1]))
        test_kinds.append(tokens[-1])

# The model name ends in '<tag>_<run>'; the 'cpx' tag is one token long and
# 'lig_alone' is two, hence the different number of trailing tokens dropped.
datasets, model_kinds = [], []
for name in model_names:
    tokens = name.split('_')
    if 'cpx' in name:
        datasets.append('_'.join(tokens[:-2]))
        model_kinds.append('complex')
    else:
        datasets.append('_'.join(tokens[:-3]))
        model_kinds.append('ligand_alone')

sum_df['model_names'] = model_names
sum_df['test_type'] = test_kinds
sum_df['dataset'] = datasets
sum_df['model_type'] = model_kinds
sum_df

Unnamed: 0,model_names_test_type,R2,mae,mse,pearsonr,spearmanr,model_names,test_type,dataset,model_type
0,PDBbind_intersected_Uw_cpx_1_valid,0.416541,0.983162,1.548035,0.648518,0.573573,PDBbind_intersected_Uw_cpx_1,valid,PDBbind_intersected_Uw,complex
1,PDBbind_intersected_Uw_cpx_1_train,0.477236,0.935431,1.401182,0.701890,0.648219,PDBbind_intersected_Uw_cpx_1,train,PDBbind_intersected_Uw,complex
2,PDBbind_intersected_Uw_cpx_1_test,0.411513,0.979637,1.542658,0.648389,0.572081,PDBbind_intersected_Uw_cpx_1,test,PDBbind_intersected_Uw,complex
3,PDBbind_intersected_Uw_cpx_1_core,0.464997,1.277336,2.520483,0.696768,0.685162,PDBbind_intersected_Uw_cpx_1,core,PDBbind_intersected_Uw,complex
4,PDBbind_intersected_Uw_cpx_1_core_intersected_Uw,0.526759,1.177720,2.054616,0.797355,0.771584,PDBbind_intersected_Uw_cpx_1,core_intersected_Uw,PDBbind_intersected_Uw,complex
...,...,...,...,...,...,...,...,...,...,...
95,PLANet_Uw_lig_alone_5_valid,0.573174,0.691384,0.819810,0.763093,0.745674,PLANet_Uw_lig_alone_5,valid,PLANet_Uw,ligand_alone
96,PLANet_Uw_lig_alone_5_train,0.796283,0.482763,0.395801,0.892445,0.881875,PLANet_Uw_lig_alone_5,train,PLANet_Uw,ligand_alone
97,PLANet_Uw_lig_alone_5_test,0.572112,0.702247,0.856615,0.761756,0.745869,PLANet_Uw_lig_alone_5,test,PLANet_Uw,ligand_alone
98,PLANet_Uw_lig_alone_5_core,0.377481,1.408895,2.932783,0.628508,0.627304,PLANet_Uw_lig_alone_5,core,PLANet_Uw,ligand_alone


In [16]:
# Write the per-run summary as a tab-separated file.
# DataFrame.to_csv does not create missing directories, so make sure the
# target directory exists before writing (no-op when it is already there).
summary_csv = f'{wrkdir}/4-evaluation/PDBbind_intersected_Uw_vs_Uw_Rm_core_set/PDBbind_intersected_Uw_vs_Uw_Rm_core_set.csv'
Path(summary_csv).parent.mkdir(parents=True, exist_ok=True)
sum_df.to_csv(summary_csv, sep='\t', index=False)

In [14]:
# Median of every metric over the repeated runs, per (test set, dataset, model type).
# numeric_only=True restricts the aggregation to the metric columns: sum_df also
# holds string columns (labels / model names) on which median() raises a
# TypeError in pandas >= 2.0 instead of silently dropping them.
grouped_median = (
    sum_df
    .groupby(['test_type', 'dataset', 'model_type'])
    .median(numeric_only=True)
    .reset_index()
)
grouped_median

Unnamed: 0,test_type,dataset,model_type,R2,mae,mse,pearsonr,spearmanr
0,core,PDBbind_intersected_Uw,complex,0.481174,1.254616,2.44427,0.697031,0.696258
1,core,PDBbind_intersected_Uw,ligand_alone,0.350008,1.402436,3.062212,0.597368,0.606296
2,core,PLANet_Uw,complex,0.259747,1.494221,3.487446,0.533767,0.523904
3,core,PLANet_Uw,ligand_alone,0.363284,1.395829,2.999668,0.618361,0.606949
4,core_intersected_Uw,PDBbind_intersected_Uw,complex,0.552291,1.120039,1.943768,0.792828,0.771584
5,core_intersected_Uw,PDBbind_intersected_Uw,ligand_alone,0.465158,1.208176,2.322066,0.716789,0.656771
6,core_intersected_Uw,PLANet_Uw,complex,0.350822,1.316727,2.818465,0.634827,0.622413
7,core_intersected_Uw,PLANet_Uw,ligand_alone,0.46623,1.263153,2.317409,0.742276,0.705093
8,test,PDBbind_intersected_Uw,complex,0.410012,0.982232,1.546595,0.646297,0.572081
9,test,PDBbind_intersected_Uw,ligand_alone,0.345011,1.045288,1.716986,0.599512,0.525921


In [15]:
# Rebuild the short '<dataset>_<tag>' model name used to look medians up when
# annotating the plots: 'complex' maps to the 'cpx' tag, anything else to 'lig_alone'.
median_model_names = []
for dataset_name, model_kind in zip(grouped_median['dataset'], grouped_median['model_type']):
    tag = 'cpx' if model_kind == 'complex' else 'lig_alone'
    median_model_names.append(f'{dataset_name}_{tag}')
grouped_median['model_name'] = median_model_names

# Persist the medians as a tab-separated file next to the per-run summary.
grouped_median.to_csv(
    f'{wrkdir}/4-evaluation/PDBbind_intersected_Uw_vs_Uw_Rm_core_set/PDBbind_intersected_Uw_vs_Uw_Rm_median.csv',
    sep='\t',
    index=False,
)

In [23]:
# For every metric and every evaluation set, draw a box + swarm plot comparing
# the datasets (x axis) and model types (hue), annotate each box with its
# median, and save the figure under '<metric>/'.

# R2 and the correlation coefficients cannot exceed 1, so their y axis is
# capped there. mae / mse are unbounded: capping them at 1 would clip the data
# (or invert the axis when all values are above 1).
bounded_metrics = {'R2', 'pearsonr', 'spearmanr'}

grouped = sum_df.groupby('test_type')
for metric in ['R2', 'mae', 'mse', 'pearsonr', 'spearmanr']:
    out_dir = f'{wrkdir}/4-evaluation/PDBbind_intersected_Uw_vs_Uw_Rm_core_set/{metric}'
    # parents/exist_ok make the cell safe to run on a fresh tree and to re-run.
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for test_type in test_types:
        test_df = grouped.get_group(test_type)
        fig, ax = plt.subplots()
        # Pass ax explicitly instead of relying on the implicit current axes.
        sns.boxplot(x="dataset", y=metric, data=test_df, hue="model_type", order=log_dir_names, linewidth=2.5, ax=ax)
        sns.swarmplot(x="dataset", y=metric, data=test_df, hue="model_type", order=log_dir_names, size=4, dodge=True, edgecolor="black", linewidth=0.7, ax=ax)
        ax.set_title(f'{metric} on {test_type} set')
        fig.autofmt_xdate()
        # Both plot calls add legend entries; keep only the first pair.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[:2], labels[:2])
        vertical_offset = test_df[metric].median() * 0.07 # offset from median for display
        for i, modl in enumerate(log_dir_names):
            for tp in ['cpx', 'lig_alone']:
                is_match = (grouped_median['model_name'] == f'{modl}_{tp}') & (grouped_median['test_type'] == test_type)
                median_metric = round(grouped_median.loc[is_match][metric].values[0], 3)
                # The complex box sits left of the tick, the ligand-alone box right of it.
                x_pos = i - .2 if tp == 'cpx' else i + .2
                ax.text(x_pos, median_metric - vertical_offset, median_metric, horizontalalignment='center', size='small', weight='semibold')
        y_low = test_df[metric].min() - 0.05
        y_high = test_df[metric].max() + 0.05
        if metric in bounded_metrics:
            y_high = min(y_high, 1)
        ax.set_ylim(y_low, y_high)
        fig.savefig(f'{out_dir}/{metric}_on_{test_type}_set(assign_charge_0).png', dpi=300, bbox_inches='tight')
        # Close this figure explicitly so figures do not pile up in memory across iterations.
        plt.close(fig)

In [22]:
# For every metric and every dataset, draw a box + swarm plot comparing the
# evaluation sets (x axis) and model types (hue), annotate each box with its
# median, and save the figure under '<metric>/performance_inner_dataset/'.

# R2 and the correlation coefficients cannot exceed 1, so their y axis is
# capped there. mae / mse are unbounded: capping them at 1 would clip the data
# (or invert the axis when all values are above 1).
bounded_metrics = {'R2', 'pearsonr', 'spearmanr'}
# Display order of the evaluation sets on the x axis.
set_order = ['train', 'valid', 'test', 'core', 'core_intersected_Uw']

data_grouped = sum_df.groupby('dataset')
for metric in ['R2', 'mae', 'mse', 'pearsonr', 'spearmanr']:
    out_dir = f'{wrkdir}/4-evaluation/PDBbind_intersected_Uw_vs_Uw_Rm_core_set/{metric}/performance_inner_dataset'
    # parents=True: the '<metric>' directory may not exist yet when this cell
    # runs first; exist_ok=True keeps the cell re-runnable.
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for test_data in log_dir_names:
        data_df = data_grouped.get_group(test_data)
        fig, ax = plt.subplots()
        # Pass ax explicitly instead of relying on the implicit current axes.
        sns.boxplot(x="test_type", y=metric, data=data_df, hue="model_type", order=set_order, linewidth=2.5, ax=ax)
        sns.swarmplot(x="test_type", y=metric, data=data_df, hue="model_type", order=set_order, size=4, dodge=True, edgecolor="black", linewidth=0.7, ax=ax)
        ax.set_title(f'{metric} on {test_data} set')
        fig.autofmt_xdate()
        # Both plot calls add legend entries; keep only the first pair.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[:2], labels[:2])
        vertical_offset = data_df[metric].median() * 0.07 # offset from median for display
        for i, test_tp in enumerate(set_order):
            for tp in ['cpx', 'lig_alone']:
                is_match = (grouped_median['model_name'] == f'{test_data}_{tp}') & (grouped_median['test_type'] == test_tp)
                median_metric = round(grouped_median.loc[is_match][metric].values[0], 3)
                # The complex box sits left of the tick, the ligand-alone box right of it.
                x_pos = i - .2 if tp == 'cpx' else i + .2
                ax.text(x_pos, median_metric - vertical_offset, median_metric, horizontalalignment='center', size='small', weight='semibold')
        y_low = data_df[metric].min() - 0.05
        y_high = data_df[metric].max() + 0.05
        if metric in bounded_metrics:
            y_high = min(y_high, 1)
        ax.set_ylim(y_low, y_high)
        fig.savefig(f'{out_dir}/{metric}_on_{test_data}(assign_charge_0).png', dpi=300, bbox_inches='tight')
        # Close this figure explicitly so figures do not pile up in memory across iterations.
        plt.close(fig)