# Gibbs-Helmholtz Graph Neural Network (GH-GNN)

An illustration of how to use GH-GNN to obtain predictions of infinite dilution activity coefficients (IDACs).

## 1. Activate the conda environment

Make sure you have activated a conda environment with all the necessary dependencies installed, as described in the [README file](https://github.com/edgarsmdn/GH-GNN/blob/main/README.md) on GitHub.


## 2. Import the model

In [1]:
from GHGNN import GH_GNN

## 3. Making predictions with GH-GNN

Let's first load a set of binary solute-solvent systems for which we will predict infinite dilution activity coefficients.

In [2]:
import pandas as pd
from tqdm import tqdm
import numpy as np

In [3]:
df_brouwer = pd.read_csv('../data/processed/brouwer_edge_test.csv')

Now, let's make the predictions for all systems.

This "user-friendly" implementation of the model as a Python class was built to predict one system at a time. Therefore, it cannot take advantage of parallelizing the calculations over multiple systems. However, extending the class to support such batched predictions should be straightforward.
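
Short of true batching, one cheap speed-up is to build one model per unique solute-solvent pair instead of one per row, since the same pair often appears at several temperatures. This is a minimal sketch, assuming (not confirmed in this notebook) that a single `GH_GNN(solute, solvent)` instance can be reused for several `predict` calls. `StubModel` and `predict_with_cache` are hypothetical names; the stub only mimics the call pattern so the sketch runs without the trained model.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


class StubModel:
    """Hypothetical stand-in with the same call pattern as GH_GNN(solute, solvent).predict(T, AD)."""

    instances_created = 0

    def __init__(self, solute: str, solvent: str) -> None:
        StubModel.instances_created += 1
        self.solute = solute
        self.solvent = solvent

    def predict(self, T: float, AD: Any = None) -> float:
        # Placeholder arithmetic only; the real model returns a predicted ln(IDAC)
        return len(self.solute) - len(self.solvent) + 100.0 / T


def predict_with_cache(
    model_cls: Callable[[str, str], Any],
    solutes: Sequence[str],
    solvents: Sequence[str],
    temperatures: Sequence[float],
) -> List[float]:
    """Predict each row, constructing at most one model per (solute, solvent) pair."""
    models: Dict[Tuple[str, str], Any] = {}
    predictions = []
    for solute, solvent, T in zip(solutes, solvents, temperatures):
        key = (solute, solvent)
        if key not in models:
            models[key] = model_cls(solute, solvent)
        predictions.append(models[key].predict(T, AD=None))
    return predictions


# Ethanol in water appears at two temperatures but needs only one model instance
preds = predict_with_cache(
    StubModel,
    ["CCO", "CCO", "c1ccccc1"],
    ["O", "O", "O"],
    [298.15, 323.15, 298.15],
)
print(StubModel.instances_created)  # 2
```

With the real `GH_GNN`, how much this helps depends on how much work happens in the constructor versus in `predict`, which this notebook does not show; keeping every model in memory is the trade-off.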

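The model is instantiated for each system with the solute and solvent SMILES, ```GH_GNN(solute, solvent)```, and the important method is ```predict```, which takes as arguments:

- ```T```, the temperature in Kelvin (the ```T``` column of the dataset is in degrees Celsius, hence the ```+ 273.15``` conversion in the next cell)
- ```AD```, a flag selecting a valid applicability domain option (explained in the following sections). With ```AD=None```, ```predict``` returns only the predicted $\ln\gamma^\infty$; with ```'class'``` or ```'tanimoto'``` it also returns a feasibility flag and the value of the corresponding indicator. Valid options are:

    - ```'both'```
    - ```'class'```
    - ```'tanimoto'```
    - ```None```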
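
Because the number of values returned by ```predict``` depends on the ```AD``` option (compare the cells in this notebook), a small helper can normalize the output so that one loop works for every option. This is a sketch under the assumption that the two shapes seen in this notebook are the only ones; the return shape for ```'both'``` is not shown here, so the helper rejects unexpected shapes rather than guessing. ```unpack_prediction``` is a hypothetical name, not part of GH-GNN.

```python
from typing import Any, Optional, Tuple


def unpack_prediction(result: Any) -> Tuple[float, Optional[bool], Optional[float]]:
    """Normalize predict() output to (ln_gamma, feasible, indicator).

    Based only on the call patterns in this notebook: AD=None returns a single
    value, while AD='class' or AD='tanimoto' return three values.
    """
    if isinstance(result, (tuple, list)):
        if len(result) != 3:
            # Unknown shape (e.g. possibly AD='both'): fail loudly instead of guessing
            raise ValueError(f"unexpected predict() output of length {len(result)}")
        pred, feasible, indicator = result
        return float(pred), bool(feasible), float(indicator)
    return float(result), None, None


print(unpack_prediction(1.23))              # (1.23, None, None)
print(unpack_prediction((1.23, True, 42)))  # (1.23, True, 42.0)
```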

In [4]:
ln_gamma_predictions = []

solutes = df_brouwer['Solute_SMILES'].tolist()
solvents = df_brouwer['Solvent_SMILES'].tolist()
Ts = df_brouwer['T'].to_numpy() + 273.15  # degrees Celsius -> Kelvin

for solute, solvent, T in tqdm(zip(solutes, solvents, Ts), total=len(Ts)):
    model = GH_GNN(solute, solvent)   # one model instance per binary system
    pred = model.predict(T, AD=None)  # predicted ln(IDAC), no applicability domain check
    ln_gamma_predictions.append(pred)

df_brouwer['GH-GNN'] = ln_gamma_predictions

100%|██████████| 2233/2233 [04:01<00:00,  9.25it/s]


Now, let's check that these predictions match the ones reported in the original paper:

In [5]:
df_brouwer_pred = pd.read_csv('../models/temperature_dependency/discrete_extrapolation/brouwer_edge_predictions.csv')

In [6]:
np.mean(np.array(ln_gamma_predictions) - df_brouwer_pred['GNNGH_T'].to_numpy())

1.8691514783963946e-08
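
A mean of signed differences can be close to zero even when individual predictions differ, because positive and negative deviations cancel. A stricter check looks at the largest absolute deviation or uses `np.allclose`. The sketch below uses made-up arrays, not the notebook's data, to show the difference.

```python
import numpy as np

# Made-up values: two deviations of opposite sign cancel in the mean
reproduced = np.array([1.0, 2.0, 3.0, 4.0])
reference = np.array([1.5, 2.0, 2.5, 4.0])

diff = reproduced - reference
mean_signed = np.mean(diff)                            # 0.0: the +/-0.5 deviations cancel
max_abs = np.max(np.abs(diff))                         # 0.5: reveals the mismatch
agree = np.allclose(reproduced, reference, atol=1e-6)  # False

print(mean_signed, max_abs, agree)
```

Applied to this notebook, the equivalent check would be `np.max(np.abs(np.array(ln_gamma_predictions) - df_brouwer_pred['GNNGH_T'].to_numpy()))`.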

In [10]:
import plotly.express as px
import molplotly
import random

fig_scatter = px.scatter(df_brouwer,
                             x="log-gamma",
                             y='GH-GNN',
                             title='GH-GNN parity plot',
                             labels={'GH-GNN': 'Predicted IDAC (GH-GNN)',
                                     'log-gamma': 'Experimental IDAC',
                                     },
                             width=800,
                             height=700)

y = df_brouwer["log-gamma"].values
fig_scatter.add_shape(
    type="line", line=dict(dash='dash'),
    x0=y.min(), y0=y.min(),
    x1=y.max(), y1=y.max()
)

app_scatter = molplotly.add_molecules(fig=fig_scatter,
                                      df=df_brouwer,
                                      smiles_col=['Solvent_SMILES', 'Solute_SMILES'],
                                      caption_cols=['T'],
                                      caption_transform={'Predicted IDAC (GH-GNN)': lambda x: f"{x:.2f}",
                                                         'Experimental IDAC': lambda x: f"{x:.2f}",
                                                         'T': lambda x: f"{x:.2f}"
                                                         },
                                      )

app_scatter.run_server(mode='inline', port=8020+random.randint(0, 999), height=800)

## 4. Using the applicability domain recommendations

In the original paper, the following applicability domain recommendations are made for obtaining accurate GH-GNN predictions:
- **Chemical class** representation in the training set $\geq25$
- **Tanimoto** indicator $\geq0.35$
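
Expressed as code, the two recommendations reduce to inclusive threshold checks. This sketch uses made-up indicator values for four hypothetical systems; combining the checks with a logical AND mirrors the intersection taken later in this notebook, but whether ```AD='both'``` applies exactly this rule is an assumption.

```python
import numpy as np

CLASS_MIN_SYSTEMS = 25  # chemical class representation in the training set >= 25
TANIMOTO_MIN = 0.35     # Tanimoto indicator >= 0.35

# Made-up indicator values for four hypothetical systems
n_systems_in_training = np.array([3, 25, 120, 40])
tanimoto_indicator = np.array([0.50, 0.20, 0.35, 0.80])

feasible_class = n_systems_in_training >= CLASS_MIN_SYSTEMS
feasible_tanimoto = tanimoto_indicator >= TANIMOTO_MIN
feasible_both = feasible_class & feasible_tanimoto  # feasible for both indicators

print(feasible_class.tolist())     # [False, True, True, True]
print(feasible_tanimoto.tolist())  # [True, False, True, True]
print(feasible_both.tolist())      # [False, False, True, True]
```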

Let's now restrict our predictions to the systems that are feasible according to the **chemical class representation indicator**. Note that the following cell takes a while to finish, because the API of [ClassyFire](http://classyfire.wishartlab.com/), which is queried to classify the molecules, is limited to around 12 queries per minute.
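
When a rate-limited web service is queried inside a loop, caching results and spacing out uncached requests keeps the run within the quota. This is a minimal sketch of that general pattern; `classify` is a hypothetical stand-in returning toy labels, not the ClassyFire client, and this is not necessarily how GH-GNN handles its queries.

```python
import time
from typing import Callable, Dict, Optional


class RateLimitedCache:
    """Cache results of `fetch` and space uncached calls at least 60 / max_per_minute s apart."""

    def __init__(self, fetch: Callable[[str], str], max_per_minute: float = 12) -> None:
        self.fetch = fetch
        self.min_interval = 60.0 / max_per_minute
        self.cache: Dict[str, str] = {}
        self.last_call: Optional[float] = None
        self.remote_calls = 0

    def __call__(self, key: str) -> str:
        if key in self.cache:  # already looked up: no request, no waiting
            return self.cache[key]
        if self.last_call is not None:
            wait = self.min_interval - (time.monotonic() - self.last_call)
            if wait > 0:
                time.sleep(wait)  # stay under the quota
        self.last_call = time.monotonic()
        self.remote_calls += 1
        result = self.fetch(key)
        self.cache[key] = result
        return result


def classify(smiles: str) -> str:
    """Hypothetical stand-in for a ClassyFire request; returns a toy label."""
    return "contains O" if "O" in smiles else "no O"


# A high quota keeps the demo fast; 12 per minute would mean 5 s between requests
lookup = RateLimitedCache(classify, max_per_minute=6000)
classes = [lookup(s) for s in ["CCO", "CCCC", "CCO", "CCO"]]
print(classes, lookup.remote_calls)  # ['contains O', 'no O', 'contains O', 'contains O'] 2
```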

In [11]:
ln_gamma_predictions = []
feasibles = []
n_classes = []

for solute, solvent, T in tqdm(zip(solutes, solvents, Ts), total=len(Ts)):
    model = GH_GNN(solute, solvent)
    # With AD='class', predict also returns the feasibility flag and the
    # chemical class representation indicator (number of systems in training)
    pred, feasible, n_class = model.predict(T, AD='class')
    ln_gamma_predictions.append(pred)
    feasibles.append(feasible)
    n_classes.append(n_class)

df_brouwer['GH-GNN'] = ln_gamma_predictions
df_brouwer['Feasible Chemical Class'] = feasibles
df_brouwer['N systems in training'] = n_classes

100%|██████████| 2233/2233 [23:58<00:00,  1.55it/s] 


In [12]:
fig_scatter = px.scatter(df_brouwer,
                             x="log-gamma",
                             y='GH-GNN',
                             color='Feasible Chemical Class',
                             title='GH-GNN parity plot',
                             labels={'GH-GNN': 'Predicted IDAC (GH-GNN)',
                                     'log-gamma': 'Experimental IDAC',
                                     'Feasible Chemical Class': 'Feasible according to Class?'
                                     },
                             width=800,
                             height=700)

y = df_brouwer["log-gamma"].values
fig_scatter.add_shape(
    type="line", line=dict(dash='dash'),
    x0=y.min(), y0=y.min(),
    x1=y.max(), y1=y.max()
)

app_scatter = molplotly.add_molecules(fig=fig_scatter,
                                      df=df_brouwer,
                                      smiles_col=['Solvent_SMILES', 'Solute_SMILES'],
                                      caption_cols=['T', 'N systems in training'],
                                      caption_transform={'Predicted IDAC (GH-GNN)': lambda x: f"{x:.2f}",
                                                         'Experimental IDAC': lambda x: f"{x:.2f}",
                                                         'T': lambda x: f"{x:.2f}"
                                                         },
                                      color_col='Feasible Chemical Class'
                                      )

app_scatter.run_server(mode='inline', port=8020+random.randint(0, 999), height=800)

Now, let's try only the Tanimoto indicator. This run takes somewhat longer than the one without any applicability domain check (```AD=None```), since the **Tanimoto indicator** is now calculated for every system.
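
For reference, the Tanimoto coefficient between two binary fingerprints $A$ and $B$ is $|A \cap B| / |A \cup B|$. The sketch below computes it on toy fingerprints stored as sets of "on" bit indices, then aggregates a query's similarities to a reference set by averaging the $k$ highest values. That aggregation, and $k = 10$ (suggested only by the variable name `max_10_sim`), are assumptions for illustration; see the original paper for the indicator and fingerprint GH-GNN actually uses. `top_k_mean_similarity` is a hypothetical name.

```python
from typing import Iterable, List, Set


def tanimoto(a: Set[int], b: Set[int]) -> float:
    """Tanimoto coefficient |A & B| / |A | B| for fingerprints given as sets of on-bits."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def top_k_mean_similarity(query: Set[int], references: Iterable[Set[int]], k: int = 10) -> float:
    """Mean of the k highest Tanimoto similarities between `query` and `references`."""
    sims: List[float] = sorted((tanimoto(query, ref) for ref in references), reverse=True)
    top = sims[:k]
    return sum(top) / len(top) if top else 0.0


# Toy fingerprints (made-up bit indices, not real molecular fingerprints)
query = {1, 4, 7, 9}
training = [{1, 4, 7, 9}, {1, 4, 8}, {2, 3, 5}, {1, 7, 9, 11}]

print(tanimoto(query, training[1]))                 # 2 shared / 5 total = 0.4
print(top_k_mean_similarity(query, training, k=2))  # mean of 1.0 and 0.6 = 0.8
```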

In [13]:
ln_gamma_predictions = []
feasibles = []
tanimoto_indicators = []

for solute, solvent, T in tqdm(zip(solutes, solvents, Ts), total=len(Ts)):
    model = GH_GNN(solute, solvent)
    # With AD='tanimoto', predict also returns the feasibility flag and the
    # Tanimoto indicator of this system
    pred, feasible, max_10_sim = model.predict(T, AD='tanimoto')
    ln_gamma_predictions.append(pred)
    feasibles.append(feasible)
    tanimoto_indicators.append(max_10_sim)

df_brouwer['GH-GNN'] = ln_gamma_predictions
df_brouwer['Feasible Tanimoto'] = feasibles
df_brouwer['Tanimoto indicators'] = tanimoto_indicators

100%|██████████| 2233/2233 [05:10<00:00,  7.19it/s]


In [14]:
fig_scatter = px.scatter(df_brouwer,
                             x="log-gamma",
                             y='GH-GNN',
                             color='Feasible Tanimoto',
                             title='GH-GNN parity plot',
                             labels={'GH-GNN': 'Predicted IDAC (GH-GNN)',
                                     'log-gamma': 'Experimental IDAC',
                                     'Feasible Tanimoto': 'Feasible according to Tanimoto?'
                                     },
                             width=800,
                             height=700)

y = df_brouwer["log-gamma"].values
fig_scatter.add_shape(
    type="line", line=dict(dash='dash'),
    x0=y.min(), y0=y.min(),
    x1=y.max(), y1=y.max()
)

app_scatter = molplotly.add_molecules(fig=fig_scatter,
                                      df=df_brouwer,
                                      smiles_col=['Solvent_SMILES', 'Solute_SMILES'],
                                      caption_cols=['T', 'Tanimoto indicators'],
                                      caption_transform={'Predicted IDAC (GH-GNN)': lambda x: f"{x:.2f}",
                                                         'Experimental IDAC': lambda x: f"{x:.2f}",
                                                         'T': lambda x: f"{x:.2f}",
                                                         'Tanimoto indicators': lambda x: f"{x:.2f}"
                                                         },
                                      color_col='Feasible Tanimoto'
                                      )

app_scatter.run_server(mode='inline', port=8020+random.randint(0, 999), height=800)

As the previous two plots show, the two applicability domain indicators disagree on whether several systems are feasible or infeasible. Nevertheless, the Tanimoto indicator appears to identify the applicability domain of GH-GNN more reliably than the chemical class indicator.
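
One way to make this comparison quantitative, rather than judging the plots by eye, is to compute the mean absolute error separately for the systems each indicator flags as feasible and infeasible. A sketch on a small made-up DataFrame that reuses this notebook's column names; on the real data, `df_brouwer` would take the place of `df`.

```python
import pandas as pd

# Made-up data using the column names of this notebook
df = pd.DataFrame({
    "log-gamma": [1.0, 2.0, 0.5, 3.0, -0.5, 1.5],
    "GH-GNN": [1.1, 1.8, 0.5, 4.0, -0.4, 2.5],
    "Feasible Chemical Class": [True, True, False, True, False, False],
    "Feasible Tanimoto": [True, True, True, False, True, False],
})

abs_err = (df["GH-GNN"] - df["log-gamma"]).abs()

results = {}
for indicator in ["Feasible Chemical Class", "Feasible Tanimoto"]:
    # Mean absolute error inside (True) and outside (False) the applicability domain
    results[indicator] = abs_err.groupby(df[indicator]).mean()
    print(indicator, results[indicator].round(3).to_dict())
```

A good indicator should show a clearly lower error for `True` than for `False`.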

Let's look at the intersection of these two indicators, i.e., the systems that are feasible according to both.

In [15]:
df_intersection = df_brouwer[(df_brouwer['Feasible Chemical Class']==True) & (df_brouwer['Feasible Tanimoto']==True)]
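
Restricting predictions to the intersection trades coverage for accuracy, so it is worth reporting how many systems survive each filter. A sketch with made-up flags; combining boolean columns with `&` is the idiomatic equivalent of comparing each column with `== True` when the columns hold booleans.

```python
import pandas as pd

# Made-up feasibility flags for five hypothetical systems
flags = pd.DataFrame({
    "Feasible Chemical Class": [True, True, False, True, False],
    "Feasible Tanimoto": [True, False, True, True, False],
})

both = flags["Feasible Chemical Class"] & flags["Feasible Tanimoto"]

# The mean of a boolean column is the fraction of systems flagged True
coverage = {
    "class": flags["Feasible Chemical Class"].mean(),
    "tanimoto": flags["Feasible Tanimoto"].mean(),
    "both": both.mean(),
}
print(coverage)  # fractions kept: 0.6, 0.6 and 0.4
```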

In [16]:
fig_scatter = px.scatter(df_intersection,
                             x="log-gamma",
                             y='GH-GNN',
                             title='GH-GNN parity plot',
                             labels={'GH-GNN': 'Predicted IDAC (GH-GNN)',
                                     'log-gamma': 'Experimental IDAC',
                                     },
                             width=800,
                             height=700)

y = df_brouwer["log-gamma"].values
fig_scatter.add_shape(
    type="line", line=dict(dash='dash'),
    x0=y.min(), y0=y.min(),
    x1=y.max(), y1=y.max()
)

app_scatter = molplotly.add_molecules(fig=fig_scatter,
                                      df=df_intersection,
                                      smiles_col=['Solvent_SMILES', 'Solute_SMILES'],
                                      caption_cols=['T', 'N systems in training', 'Tanimoto indicators'],
                                      caption_transform={'Predicted IDAC (GH-GNN)': lambda x: f"{x:.2f}",
                                                         'Experimental IDAC': lambda x: f"{x:.2f}",
                                                         'T': lambda x: f"{x:.2f}",
                                                         'Tanimoto indicators': lambda x: f"{x:.2f}"
                                                         },
                                      )

app_scatter.run_server(mode='inline', port=8020+random.randint(0, 999), height=800)

In [17]:
from sklearn.metrics import mean_absolute_error

mean_absolute_error(df_intersection['log-gamma'].to_numpy(), df_intersection['GH-GNN'].to_numpy())

0.2881038590512974
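
`mean_absolute_error` averages $|y_\text{true} - y_\text{pred}|$ over all systems. The sketch below computes it directly with NumPy on made-up values, alongside the root-mean-squared error, which weights large deviations more heavily; reporting RMSE as well is an optional choice, not part of the analysis above.

```python
import numpy as np

# Made-up experimental and predicted ln(IDAC) values
y_true = np.array([1.0, 2.0, 0.5, -0.5])
y_pred = np.array([1.5, 2.0, 0.0, -0.5])

errors = y_pred - y_true
mae = np.mean(np.abs(errors))         # (0.5 + 0 + 0.5 + 0) / 4 = 0.25
rmse = np.sqrt(np.mean(errors ** 2))  # sqrt((0.25 + 0.25) / 4), about 0.354

print(mae, rmse)
```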