In [None]:
import os
from pathlib import Path
import pickle
import yaml
import pyarrow.dataset as ds
import pandas as pd
from common import load_config

In [None]:
# Project configuration as returned by common.load_config().
CONFIG = load_config()

# Output folder: <current working directory>/<CONFIG['output_folder_name']>.
OUTPUT_FOLDER_PATH = Path.cwd().joinpath(CONFIG['output_folder_name'])
# Subfolder holding the Parquet data that is loaded below.
CHUNKS_FOLDER_PATH = OUTPUT_FOLDER_PATH.joinpath('masking')

In [None]:
# load results: open the Parquet data under CHUNKS_FOLDER_PATH as a single pyarrow
# dataset, read it into an Arrow table and convert that table to a pandas DataFrame
dataset = ds.dataset(CHUNKS_FOLDER_PATH, format="parquet")
table = dataset.to_table()
# NOTE(review): `table` stays referenced after the conversion, so the Arrow data
# presumably remains in memory next to `df_raw` -- consider `del table` if memory is tight.
df_raw = table.to_pandas()

# report the number of loaded rows, in millions
print(f"Rows: {len(df_raw)/1e6:.1f}M")

In [None]:
# Save the raw (unprocessed) frame as a pickle in the output folder.
RAW_DUMP_PATH = OUTPUT_FOLDER_PATH/'Table4-raw.pkl.xz'
df_raw.to_pickle(RAW_DUMP_PATH)

In [None]:
# preview: the bare expression is the cell's displayed output
# (raw frame, before any column is dropped or the index is set)
df_raw

In [None]:
# Columns removed from the working frame (they remain available in `df_raw`).
COLNAMES_TO_DROP = [
    'original_molecule_smiles',
    'original_molecule_important_atom_indices',
    'masked_molecule_smiles',
]

# Columns that become the 9-level row MultiIndex. The positional `.loc[...]`
# selectors used further down follow exactly this order.
INDEX_COLNAMES = [
    'dataset_name',
    'split_index',
    'model_name',
    'explanation_method',
    'molecule_index',
    'masked_atom_percentage',
    'direction',
    'masking_method',
    'masked_molecule_index',
]

# Two explicit steps: drop the unneeded columns, then index by the key columns.
df_without_dropped_columns = df_raw.drop(columns=COLNAMES_TO_DROP)
df = df_without_dropped_columns.set_index(INDEX_COLNAMES)

In [None]:
# Number of rows per explanation method, ordered by method name.
explanation_method_labels = df.index.get_level_values('explanation_method')
explanation_method_labels.value_counts().sort_index()

In [None]:
# Number of rows per masking method, ordered by method name.
masking_method_labels = df.index.get_level_values('masking_method')
masking_method_labels.value_counts().sort_index()

In [None]:
# preview: the indexed working frame (dropped columns removed, 9-level row MultiIndex set)
df

In [None]:
# Optional row filter: one selector per index level, listed in index order.
# `slice(None)` (the same as a bare `:`) keeps every value of a level, so with the
# defaults below nothing is filtered out. Replace a selector with a label to
# restrict everything that follows.
ROW_FILTER = {
    'dataset_name':           slice(None),
    'split_index':            slice(None),
    'model_name':             slice(None),
    'explanation_method':     slice(None),
    'molecule_index':         slice(None),
    'masked_atom_percentage': slice(None),
    'direction':              slice(None),
    'masking_method':         slice(None),
    'masked_molecule_index':  slice(None),
}
df = df.loc[tuple(ROW_FILTER.values())]

In [None]:
# preview (more specific): a single molecule for one dataset / split / model /
# explanation method, across every masking percentage, direction, masking method
# and masked variant
PREVIEW_SELECTOR = {
    'dataset_name':           'Solubility_AqSolDB',
    'split_index':            0,
    'model_name':             'MediumGIN',
    'explanation_method':     'gradcam',
    'molecule_index':         0,
    'masked_atom_percentage': slice(None),
    'direction':              slice(None),
    'masking_method':         slice(None),
    'masked_molecule_index':  slice(None),
}
df.loc[tuple(PREVIEW_SELECTOR.values())]

In [None]:
# -- (Optional) Keep only the first 5 masked variants of each molecule (masked_molecule_index < 5), so that the number of results, especially for CReM, is comparable to that for DiffLinker

# df = df.reset_index()[
#     df.reset_index()['masked_molecule_index'] < 5
# ].set_index(INDEX_COLNAMES)

In [None]:
# For every molecule (per dataset, split and model), express each prediction as a
# difference from the prediction of the same molecule with 0% of its atoms masked.

df2 = df.copy()

molecule_groups = df.groupby(['dataset_name', 'split_index', 'model_name', 'molecule_index'])
for gname, gdf in molecule_groups:
    # every prediction of this molecule
    group_predictions = gdf.loc[:, :, :, :, :, :, :, :, :]['prediction']
    # 6th selector position = 'masked_atom_percentage'; the first '0%' row is the reference
    reference_predictions = gdf.loc[:, :, :, :, :, '0%', :, :, :]['prediction']
    prediction_diff = group_predictions - reference_predictions.iloc[0]
    df2.loc[prediction_diff.index, 'delta_prediction'] = prediction_diff.values

# partial preview: the last group visited by the loop
gdf

In [None]:
# preview: working frame with the added 'delta_prediction' column
df2

In [None]:
# Discard the '0%' rows: they only served as the reference for 'delta_prediction'.
is_masked_row = df2.index.get_level_values('masked_atom_percentage') != '0%'
df3 = df2.loc[is_masked_row]

# Keep only rows flagged as properly masked. The 'is_masking_proper' column itself
# is retained because the following step reads it again.
is_proper_row = df3['is_masking_proper'] == True
df3 = df3[is_proper_row]
df3

In [None]:
# Flag, row by row, whether masking moved the prediction in the direction implied by
# the explanation. The result is stored in a new column of `df4` (a copy of `df3`).
df4 = df3.copy()

# 'direction' is one of the grouping keys, so every row of a group `g` shares a single
# direction value. Commented-out keys below are levels deliberately not grouped by.
for n, g in df3.loc[:, :, :, :, :, :, :, :, :][['is_masking_proper', 'delta_prediction']].groupby([
    'dataset_name',
   #'split_index',
   # molecule_index
   #'model_name',
    'masking_method',
    'explanation_method',
    'direction',
    'masked_atom_percentage'
]):
    # the first row is representative, because the whole group has the same direction
    should_masked_part_increase_prediction = g.index.get_level_values('direction')[0] == '+'
    # NOTE: `df3` was already filtered to is_masking_proper == True, so this factor
    # is expected to be True for every row here
    is_masking_expected_by_explainer_to_change_prediction_in_expected_direction = g['is_masking_proper']
    # direction '+': removing the part should LOWER the prediction (delta < 0);
    # otherwise it should RAISE it (delta > 0). A delta of exactly 0 counts as incorrect.
    is_masking_changing_prediction_in_expected_direction = (
        (g['delta_prediction'] < 0) if should_masked_part_increase_prediction else
        (g['delta_prediction'] > 0)
    )
    # element-wise product of the two flags, used here as a logical AND
    is_expl_dir_corr = (
        is_masking_expected_by_explainer_to_change_prediction_in_expected_direction
       *is_masking_changing_prediction_in_expected_direction 
    )
    # write the group's flags back into the full frame, addressed by the group's index
    df4.loc[is_expl_dir_corr.index, 'is_explanation_direction_correct'] = is_expl_dir_corr.values

df4

In [None]:
# -- aggregation
# Stage 1: fraction of correct-direction rows for each combination of dataset, split,
#          model, masking method, explanation method, direction and masked percentage
#          (i.e. averaged over molecules and their masked variants).
# Stage 2: mean and standard deviation of those fractions over splits, models and
#          directions.
is_direction_correct = df4.loc[:, :, :, :, :, :, :, :, :]['is_explanation_direction_correct']

fraction_correct = is_direction_correct.groupby([
    'dataset_name',
    'split_index',
    'model_name',
    'masking_method',
    'explanation_method',
    'direction',
    'masked_atom_percentage',
]).agg('mean')

fraction_correct_stats = fraction_correct.groupby([
    'dataset_name',
    'masking_method',
    'explanation_method',
    'masked_atom_percentage',
]).agg(['mean', 'std'])

# Move the innermost index level ('masked_atom_percentage') into the columns.
df5 = fraction_correct_stats.unstack()

df5

In [None]:
def _tikz_bar_cell(mean, std):
    """Return the TikZ source of one table cell.

    The cell is a gray bar whose right edge sits at x = 1.8*mean, overlaid with a
    label of the form '<100*mean> +/- <100*std> %' (both rounded to integers).
    """
    bar_right_edge = f"{1.8*mean}"
    mean_percent = f"{100*mean:.0f}"
    std_percent = f"{100*std:.0f}"
    tikz_lines = (
        r'\tikz[baseline]{',
        r'\fill[gray!29] (0, -0.09) rectangle (' + bar_right_edge + r', 0.34);',
        r'\node[anchor=west, font=\normalsize] at (0, 0.1) {' + mean_percent
        + r'{\footnotesize\textcolor{darkgray}{$\,\pm\,$' + std_percent + r'}}\%};',
        r'}',
        '',  # makes the joined text end with a newline
    )
    return '\n'.join(tikz_lines)


# New ('stats', '10%') column: one rendered TikZ cell per row, built from the mean and
# standard deviation obtained at 10% masked atoms.
df5[('stats', '10%')] = [
    _tikz_bar_cell(mean, std)
    for mean, std in zip(df5[('mean', '10%')], df5[('std', '10%')])
]

df5

In [None]:
# Masking methods shown in the table, in presentation order
# (commented-out entries are available variants that are currently left out).
masking_methods_shown = [
#   'counterfactual_crem',
    'counterfactual_crem_rings',
    'counterfactual_difflinker',
#   'counterfactual_difflinker_rings',
    'feature_zeroing',
#   'feature_zeroing_rings',
]

# Explanation methods shown for each masking method, in presentation order.
explanation_methods_shown = [
    'gradcam',
    'igradients',
#   'gnnexplainer_node_object',
    'gnnexplainer_node_attributes',
    'saliency',
    'random',
]

# Datasets shown as table columns, in presentation order.
dataset_names_shown = [
    'CYP2D6_Veith',
    'CYP3A4_Veith',
    'Solubility_AqSolDB',
    'Lipophilicity_AstraZeneca',
    'hERG_Karim',
]

# Row keys: every (masking method, explanation method) pair, masking method outermost.
rows = [
    (masking_method, explanation_method)
    for masking_method in masking_methods_shown
    for explanation_method in explanation_methods_shown
]

# Column keys: the rendered 'stats' cell at 10% masked atoms, one per dataset.
cols = [
    ('stats', dataset_name, '10%')
    for dataset_name in dataset_names_shown
]

# Re-index by (masking method, explanation method, dataset) and pivot the innermost
# level, the dataset, into the columns.
df5_flat = df5.reset_index()
df6 = df5_flat.set_index(['masking_method', 'explanation_method', 'dataset_name']).unstack()

# The pivot appended the dataset as the last column level; swap the last two levels
# so that the dataset comes before the percentage, as the `cols` keys expect.
df7 = df6.copy()
df7.columns = df7.columns.swaplevel(1, 2)

df7

In [None]:
# Pick the table's rows, then its columns (both in presentation order),
# and export the result as LaTeX.
df8 = df7.loc[rows]
df8 = df8[cols]
df8.to_latex('table3.tex')

df8

In [None]:
# Post-process the LaTeX file written by `DataFrame.to_latex`: replace internal
# identifiers with human-readable labels.
#
# Changes compared with the previous version of this cell:
#  * the file is read and written as UTF-8 explicitly (pandas documents UTF-8 as the
#    default encoding of `to_latex`; a bare `open()` uses the platform's locale encoding);
#  * the result is written with `f.write`, so re-running the cell no longer appends
#    one more trailing newline each time (`print` added one per run);
#  * the file name is defined once instead of three placeholder-less f-strings.
TABLE_TEX_PATH = 'table3.tex'

# (old, new) pairs, applied top to bottom. The catch-all '_' -> ' ' MUST stay last:
# the patterns above it contain underscores and would no longer match otherwise.
LATEX_REPLACEMENTS = [
    # column / index header names
    ('explanation_method', r'Explanation method'),
    ('masking_method', r'Masking method'),
    ('dataset_name', r'Dataset'),
    ('direction', r'Direction'),
    # dataset-name suffixes
    ('_Veith', ''),
    ('_AstraZeneca', ''),
    ('_AqSolDB', ''),
    ('_Karim', ''),
    # explanation methods
    ('gradcam', 'Grad-CAM'),
    ('igradients', 'Integrated Gradients'),
    ('gnnexplainer', 'GNNExplainer'),
    ('saliency', 'Saliency'),
    ('random', 'Random'),
    # masking methods
    ('counterfactual_', ''),
    ('crem', 'CReM'),
    ('difflinker', 'DiffLinker'),
    ('feature_zeroing', r'Feature zeroing'),
    # method-variant suffixes
    ('_node_object', r'\textsubscript{n=o}'),
    ('_node_attributes', r''),
    ('_rings', r'\textsubscript{rings}'),
    # direction symbols appearing as table cells
    ('& -', r'& Decr.'),
    ('& +', r'& Incr.'),
    # center (instead of right-align) two-column headers
    ('{2}{r}', '{2}{c}'),
    # any remaining underscore becomes a space
    ('_', ' '),
]

with open(TABLE_TEX_PATH, 'r', encoding='utf-8') as f:
    contents = f.read()

for old, new in LATEX_REPLACEMENTS:
    contents = contents.replace(old, new)

with open(TABLE_TEX_PATH, 'w', encoding='utf-8') as f:
    f.write(contents)