In [None]:
!pip install rdkit

In [None]:
!pip install torch-geometric

In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch_geometric.loader import DataLoader
from torch_geometric.explain import Explainer, GNNExplainer

from exai_tutorial import get_esol_data, binary_accuracy
from graphrepr import featurise_data, feature_meaning
from cyp_train import Net, train_model, sign_accuracy, best_ranked_accuracy, mse_masked_loss

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda') if use_gpu else torch.device('cpu')
print(f"Using {device}")

# **Part 1** Node-level regression task

In this section, we will focus on a neural network that predicts Crippen contributions (regression) for each atom (node-level task).

## Dataset preparation
We will use ESOL data but instead of predicting solubility, we will calculate Crippen contributions for each atom and use them as labels.

In [None]:
# Load the ESOL dataset and show it as the cell output.
esol = get_esol_data()
esol

In [None]:
# One graph per molecule, with per-atom (node-level) labels.
data_list = featurise_data(esol, node_level=True, device=device)
num_node_features = data_list[0].x.size(1)  # width of the node feature matrix
print(f'Number of features: {num_node_features}')
print(data_list[0])

In [None]:
# Primitive train-test split: the first `num_train` molecules are used for training, the rest for testing.
num_train = 900
# Shuffle the training batches so each epoch sees a different batch composition;
# keep the test order fixed so evaluation is reproducible.
train_loader = DataLoader(data_list[:num_train], batch_size=64, shuffle=True)
test_loader = DataLoader(data_list[num_train:], batch_size=64, shuffle=False)

## GNN definition and training

We will use a standard GCN model that has only one convolutional layer.

In [None]:
# building and training neural network
model_name = 'model_regression.p'

model = Net(hidden_size=512, num_node_features=num_node_features, num_classes=1,
            num_conv_layers=1, num_linear_layers=5, dropout=0.5,
            conv_layer='GCN', skip_connections=False, batch_norm=True, dummy_size=0, device=device).to(device)


if os.path.isfile(model_name):
  # map_location remaps tensors saved on another device (e.g. a GPU checkpoint loaded on a CPU-only machine).
  model.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))
  print('loaded a model')
else:
  n_epochs = 150  # ITCO CPU 100
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
  test_loss, epoch, metrics = train_model(model, train_loader, test_loader,
                                           optimizer, n_epochs=n_epochs,
                                           metrics={
                                               'sign_acc': (sign_accuracy, 'graph'),
                                               'mse': (mse_masked_loss, 'graph'),
                                               },
                                           device=device, model_path=model_name)

  print(f'\n trained a new model in {epoch} epochs and reached loss of {float(test_loss):.4f}')
  print(f'test scores: {metrics}')

# From here on the model is only used for inference: disable dropout and let batch norm use
# its running statistics (train_model may leave the model in training mode).
model.eval()
  print(f'test scores: {metrics}')

Let's analyse the model's errors!

For every atom, we will record the mean squared error of the batch it belongs to, and then visualise the distribution of these errors.

In [None]:
def mse(model, data_loader):
  """Return one error value per atom: the masked MSE of the batch the atom belongs to.

  For every batch, ``mse_masked_loss`` is computed once and repeated
  ``num_nodes`` times. The returned values are therefore per-batch averages
  broadcast to atoms, not individual per-atom errors. Their mean is printed.

  Args:
    model: node-level regression model.
    data_loader: PyG DataLoader yielding batched graphs with labels in ``y``.

  Returns:
    np.ndarray with one entry per atom in ``data_loader``.
  """
  # Deterministic evaluation: dropout off, batch norm uses running statistics,
  # and no autograd graph is built for the predictions.
  model.eval()
  loss = []
  with torch.no_grad():
    for mol in data_loader:
      pred = model(data=mol)
      batch_loss = mse_masked_loss(pred, mol.y).detach().cpu().numpy()
      loss.extend([batch_loss] * mol.num_nodes)

  loss = np.array(loss)
  print(f'MSE loss: {np.mean(loss)}')
  return loss

# Per-atom error arrays for both splits (train is evaluated first, then test).
train_loss, test_loss = mse(model, train_loader), mse(model, test_loader)

In [None]:
# Pool both splits so the two histograms share one x-range and are directly comparable.
tetr = np.concatenate((test_loss, train_loss))
maximum = sorted(tetr)[int(0.9*len(tetr))]  # 90% of errors are smaller/equal than this
error_range = (np.min(tetr), np.max(tetr))

fig, (ax_train, ax_test) = plt.subplots(1, 2, figsize=(10, 3))
for ax, errors, title in ((ax_train, train_loss, 'train errors'), (ax_test, test_loss, 'test errors')):
  ax.hist(errors, range=error_range, bins=120)
  # dashed reference line at the pooled 90th-percentile error
  ax.axvline(maximum, color='k', linestyle='--', linewidth=1, label='90th percentile')
  ax.set_title(title)
  ax.set_xlabel('MSE')
  ax.set_ylabel('number of atoms')
  ax.legend()
plt.show()

**conclusion:**

## Explainability

Let's get accustomed to GNNExplainer explanations!

<mark>First, we must define an explainer. Try figuring out what parameter values to use.</mark>

- [Explainer documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/explain.html#explainer)
- [GNNExplainer documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.explain.algorithm.GNNExplainer.html#torch_geometric.explain.algorithm.GNNExplainer)

In [None]:
# GNNExplainer configured for the node-level regression model above.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',     # Explains the model prediction.
    node_mask_type='attributes',  # Learns a separate mask value for every (node, feature) pair.
    edge_mask_type='object',      # Will mask each edge.
    model_config=dict(
        mode='regression',
        task_level='node',
        return_type='raw',  # not probabilities or log-probabilities
    ),
)

Let's calculate an explanation for a sample molecule.

Since the GNN makes a prediction for each atom separately, we must provide the index of the atom whose prediction we're interested in.

In [None]:
# Explain the prediction made for atom 1 of the fifth-to-last molecule.
data, node_index = data_list[-5], 1
expl = explainer(data.x, data.edge_index, index=node_index)
print(f'Generated explanations in {expl.available_explanations}')

In [None]:
# Header line followed by the explanation's repr, each on its own line.
print('This is how the explanation object looks like:', expl, sep='\n')

<mark>Let's write several functions that will help us analyse the explanations.</mark>

Remember to call `cpu().numpy()` before using `torch.Tensor`s as input to numpy functions.

`important_atom_features:`
- **input:** explanation
- **output:** indices of atom features with nonzero importance as `np.array[n_features, 2]`  with `(node index, feature index)` in each row
- `np.nonzero` will be useful

`neighbours:`
- **input:** explanation, index of atom whose prediction was explained
- **output:** indices of direct neighbours of the atom whose prediction was explained as `np.array`

`important_features_indices:`
- **input:** explanation
- **output:** indices of all features which had nonzero importance for at least one atom
- `np.where` will be useful

`important_edges:`
- **input:** explanation
- **output:** a list of edges with nonzero importance as `np.array[2, n_edges]` `(begin node index, end node index)`

In [None]:
# this will be written by students
def important_atom_features(explanation):
  """Return the ``(node index, feature index)`` pairs whose mask value is nonzero.

  Args:
    explanation: explanation whose ``node_mask`` is a [num_nodes, num_features] tensor.

  Returns:
    np.ndarray of shape [n_important, 2] with one ``(node, feature)`` pair per row.
  """
  # Convert to NumPy first; calling np.nonzero on a torch.Tensor only works because
  # NumPy happens to dispatch to Tensor.nonzero.
  return np.argwhere(explanation.node_mask.cpu().numpy())

def neighbours(explanation, node_id):
  """Return the indices of the atoms directly connected to atom ``node_id``.

  Used to check whether atoms with important features lie close enough to the
  atom being explained. Every edge in ``explanation.edge_index`` whose source
  is ``node_id`` contributes its target.
  """
  sources, targets = explanation.edge_index.cpu().numpy()
  return targets[sources == node_id]

def important_features_indices(explanation):
  """Return the indices of features whose importance, summed over all atoms, is positive.

  The result is the tuple produced by ``np.where``, so it can index
  ``feature_meaning`` directly.
  """
  per_feature_total = explanation.node_mask.cpu().numpy().sum(axis=0)
  return np.where(per_feature_total > 0)

def important_edges(explanation):
  """Return the edges with a positive mask value as a [2, n_edges] array of (begin, end) columns."""
  edge_index = explanation.edge_index.cpu().numpy()
  keep = explanation.edge_mask.cpu().numpy() > 0  # these edges were important
  return edge_index[:, keep]


In [None]:
# Pieces of the current explanation used by the sanity check.
imp_features = important_atom_features(explanation=expl)
neigh_nodes = neighbours(explanation=expl, node_id=node_index)
imp_edges = important_edges(explanation=expl)

<mark>Sanity check</mark>
Check if nodes with important features are those that are close enough to the node that is being explained.

In [None]:
# Atoms a single GCN layer can draw information from: the explained atom and its direct neighbours.
reachable = neigh_nodes.tolist() + [node_index]
for node in set(imp_features[:, 0]):
  print(node, node in reachable)

**conclusion:**

Which features are important (their meaning)?

In [None]:
# Human-readable names of the features that mattered for at least one atom.
feature_meaning[important_features_indices(explanation=expl)]

<mark>Is there a tendency for important features to have a certain feature value (zero or one)?</mark>

- `np.where` will be useful again

In [None]:
# (node, feature) positions whose input value is nonzero -- these should be sets.
nonzero_features = set(zip(*np.nonzero(expl.x.cpu().numpy())))
important_features = set(zip(*imp_features.T))

# Important (node, feature) pairs whose input value is zero.
print(important_features - nonzero_features)

**conclusion:**

## Visualising explanations

In [None]:
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
import io
from PIL import Image
from collections import defaultdict

<mark>Let's write a few functions that will help us with visualisations.</mark>

`important_nodes:`
- **input:** explanation
- **output:** indices of all nodes for which at least one feature had nonzero importance as a `list`
- use `np.nonzero`

`node_importance:`
- **input:** explanation
- **output:** importance of each node, defined as the sum of the importance scores of its features, as a `np.array[n_nodes]`
- use `np.sum`

`edge_importance:`
- **input:** explanation
- **output:** importance of each bond defined analogously as above as a `dict` with keys `(start node index, end node index)`
- mind that each bond appears twice in `edge_mask`

In [None]:
def important_nodes(explanation):
  """Return the indices of atoms with at least one nonzero feature importance.

  Args:
    explanation: explanation with a [num_nodes, num_mask_columns] ``node_mask``.

  Returns:
    list of int node indices whose summed mask is nonzero.
  """
  # Use the argument, not the notebook-global `expl`.
  per_node = np.sum(explanation.node_mask.cpu().numpy(), axis=1)
  return np.nonzero(per_node)[0].tolist()

def node_importance(explanation):
  """Return each atom's importance: the sum of its feature mask values.

  Args:
    explanation: explanation with a [num_nodes, num_mask_columns] ``node_mask``.

  Returns:
    float64 np.ndarray of shape [num_nodes].
  """
  # Use the argument, not the notebook-global `expl`.
  return np.sum(explanation.node_mask.cpu().numpy(), axis=1).astype(float)

def edge_importance(explanation):
  """Sum edge-mask values per bond, keyed by ``(smaller atom index, larger atom index)``.

  Each bond appears twice in ``edge_mask`` (once per direction); both
  directions are folded onto the same key and their values added.
  """
  edges = explanation.edge_index.cpu().numpy().T
  weights = explanation.edge_mask.cpu().numpy()
  totals = defaultdict(float)
  for (u, v), w in zip(edges, weights):
    totals[(int(min(u, v)), int(max(u, v)))] += float(w)
  return dict(totals)

### Visualise everything
Our first function will visualise:
- for which node the explanation was calculated
- what its neighbours are
- how important each atom is
- how important each bond is

In [None]:
def visualise_everything(sample, explanation):
  """Draw ``sample`` annotated with the contents of ``explanation``.

  Highlight colours:
    - transparent blue: the atom whose prediction was explained (only when the
      explanation carries an ``index``),
    - purple: direct neighbours of that atom,
    - green: atoms with nonzero importance; opacity = importance / max atom importance,
    - gray: bonds with nonzero importance; opacity = importance / max bond importance.

  Args:
    sample: graph providing ``mol`` (RDKit molecule) and ``smiles`` (drawn as the legend).
    explanation: explanation with ``node_mask``, ``edge_mask`` and ``edge_index``.

  Returns:
    The finished ``rdMolDraw2D.MolDraw2DSVG``; call ``GetDrawingText()`` for the SVG.
  """
  mol = sample.mol

  no_col = (0.0, 0.0, 0.0, 0.0)
  rgba_color = (0.0, 0.0, 1.0, 0.5) # transparent blue for node being explained
  neigh_col = (0.8, 0.0, 0.8, 0.5)  # purple for direct neighbours
  imp_col = (0.0, 0.8, 0.0)    # green for nodes with non-zero importance
  bonds_col = (0.8, 0.8, 0.8)  # gray for bonds

  imp_nodes = important_nodes(explanation)       # which atoms are important
  atom_importance = node_importance(explanation) # how important each atom is
  atom_normalisation = np.max(atom_importance) if atom_importance.size else 0.0

  # Explanations without an `index` (graph-level) have no explained atom and no neighbours.
  explained = explanation.index if hasattr(explanation, 'index') else None
  neighs = neighbours(explanation, explained).tolist() if explained is not None else []

  imp_edges = important_edges(explanation)  # edges that are important
  bond_importance = edge_importance(explanation)  # how important each bond is
  # `default` avoids a ValueError for molecules without any bond.
  bond_normalisation = max(bond_importance.values(), default=0.0)

  atom_highlights = defaultdict(list)  # highlight colours for each atom
  for atom in mol.GetAtoms():
    a_idx = atom.GetIdx()
    colours = []
    if explained is not None and a_idx == explained:
      colours.append(rgba_color)  # node being explained
    if a_idx in neighs:
      colours.append(neigh_col)   # its neighbours
    if a_idx in imp_nodes:
      # nodes with non-zero importance; opacity relative to the most important atom
      colours.append(imp_col + (float(atom_importance[a_idx] / atom_normalisation),))
    if not colours:
      colours.append(no_col)      # every atom needs an entry in the highlight map
    atom_highlights[a_idx].extend(colours)

  # A bond is highlighted when either of its directed edges has a positive mask value.
  imp_bonds = {mol.GetBondBetweenAtoms(int(begin), int(end)).GetIdx() for begin, end in zip(*imp_edges)}
  bond_highlights = defaultdict(list)
  for bond in mol.GetBonds():
    bid = bond.GetIdx()
    begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    key = (min(begin, end), max(begin, end))
    if bid in imp_bonds:
      bond_highlights[bid].append(bonds_col + (float(bond_importance[key] / bond_normalisation),))
    else:
      bond_highlights[bid].append(no_col)

  # No per-atom highlight radii are passed (empty map).
  arads = {}

  d = rdMolDraw2D.MolDraw2DSVG(400, 200) # MolDraw2DSVG for SVG or MolDraw2DCairo to get PNGs
  d.DrawMoleculeWithHighlights(mol, sample.smiles, dict(atom_highlights), dict(bond_highlights), arads, {})
  d.FinishDrawing()

  return d


In [None]:
# Full annotation: explained atom, its neighbours, atom and bond importance.
d = visualise_everything(sample=data, explanation=expl)
SVG(data=d.GetDrawingText())

**conclusion:**

### Visualise importance only

Our second function will colour atoms based on their importance for the prediction.

In [None]:
def visualise_importance(sample, explanation):
  """Draw ``sample`` with every atom shaded by its importance.

  Each atom's highlight opacity is its importance divided by the largest atom
  importance. If no atom has positive importance, every atom gets opacity 0
  instead of a NaN from 0/0.

  Args:
    sample: graph providing ``mol`` (RDKit molecule) and ``smiles`` (drawn as the legend).
    explanation: explanation with a ``node_mask``.

  Returns:
    The finished ``rdMolDraw2D.MolDraw2DSVG``; call ``GetDrawingText()`` for the SVG.
  """
  mol = sample.mol
  red = (1, 0, 0.3)  # RGB part of the highlight colour

  atom_importance = node_importance(explanation) # how important each atom is
  max_importance = np.max(atom_importance) if atom_importance.size else 0.0

  atom_highlights = defaultdict(list)  # highlight colours for each atom
  for atom in mol.GetAtoms():
    a_idx = atom.GetIdx()
    # guard against 0/0 -> NaN opacity when every atom has zero importance
    alpha = float(atom_importance[a_idx] / max_importance) if max_importance > 0 else 0.0
    atom_highlights[a_idx].append(red + (alpha,))

  d = rdMolDraw2D.MolDraw2DSVG(400, 200) # MolDraw2DSVG for SVG or MolDraw2DCairo to get PNGs
  d.DrawMoleculeWithHighlights(mol, sample.smiles, dict(atom_highlights), dict(), {}, {})
  d.FinishDrawing()

  return d

In [None]:
# Atoms only, shaded by their importance.
d = visualise_importance(sample=data, explanation=expl)
SVG(data=d.GetDrawingText())

<mark>Play time!</mark>

Now let's look at explanations for some molecules. Do they make sense from a chemical point of view?

In [None]:
# Pick a random molecule and a random atom in it, then explain that atom's prediction.
random_sample = np.random.randint(low=len(data_list))
data = data_list[random_sample]
node_index = np.random.randint(low=data.num_nodes)

expl = explainer(data.x, data.edge_index, index=node_index)
print(f'Molecule index: {random_sample}, atom index: {node_index}')
print(f'Generated explanations in {expl.available_explanations}')

In [None]:
# Show the sampled molecule as the cell output.
data.mol

In [None]:
# Full annotation for the randomly chosen atom.
d = visualise_everything(sample=data, explanation=expl)
SVG(data=d.GetDrawingText())

In [None]:
# Importance-only view for the randomly chosen atom.
d = visualise_importance(sample=data, explanation=expl)
SVG(data=d.GetDrawingText())

# **Part 2** Graph-level classification task

In this section, we will focus on a neural network that predicts if a molecule (graph-level task) is soluble (classification).

We will use ESOL data and classify molecules as having solubility higher or lower than **`-3`**.

## Dataset preparation

In [None]:
# One graph per molecule, with a single (graph-level) label per molecule.
data_list = featurise_data(esol, node_level=False, device=device)
num_node_features = data_list[0].x.shape[1]
print(f'Number of features: {num_node_features}')
print(data_list[0])

# Primitive train-test split: the first `num_train` molecules are used for training, the rest for testing.
num_train = 900
# Shuffle the training batches each epoch; keep the test order fixed for reproducible evaluation.
train_loader = DataLoader(data_list[:num_train], batch_size=64, shuffle=True)
test_loader = DataLoader(data_list[num_train:], batch_size=64, shuffle=False)

## Model definition and training

This time, we will use a GCN with several convolutional layers that makes a prediction for the entire molecule.

In [None]:
# building and training neural network
model_name = 'model_classification.p'

model = Net(hidden_size=512, num_node_features=num_node_features, num_classes=2,
            num_conv_layers=3, num_linear_layers=3, dropout=0.5,
            conv_layer='GCN', skip_connections=False, batch_norm=True, dummy_size=0,
            graph_level=True, device=device).to(device)

if os.path.isfile(model_name):
  model.load_state_dict(torch.load(model_name, weights_only=True))
  model.eval()
  print('loaded a model')
else:
  n_epochs = 40  # 30 ITCO CPU
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

  test_loss, epoch, metrics = train_model(model, train_loader, test_loader, optimizer,
                                          n_epochs=n_epochs, model_path=model_name,
                                          metrics={'acc': (binary_accuracy, 'graph'),},
                                          device=device)

  print(f'\n trained a new model in {epoch} epochs and reached loss of {float(test_loss):.4f}')
  print(f'test scores: {metrics}')

In [None]:
# Training-set accuracy: per-batch accuracy weighted by batch size
# (assumes binary_accuracy returns the batch mean -- TODO confirm).
model.eval()  # no dropout; batch norm uses running statistics
acc = 0
with torch.no_grad():
  for batch in train_loader:
    acc += binary_accuracy(model(batch), batch.y) * batch.num_graphs

acc = acc/len(train_loader.dataset)
print(f'Train accuracy: {acc:.4f}')


Let's see for which molecules the model makes errors.

In [None]:
def analyse_mispredictions(model, loader, print_func=lambda x:None):
  """Collect the molecules whose predicted class differs from their label.

  The predicted class is 1 when the second model output exceeds the first.

  Args:
    model: graph-level classifier returning two scores per molecule.
    loader: iterable of batches exposing ``smiles``, ``y`` and ``raw_y``.
    print_func: called with one line per misprediction; silent by default
      (pass ``print`` to list them).

  Returns:
    (list of SMILES, np.ndarray of true ``raw_y`` values, list of prediction rows)
  """
  # Deterministic inference without building autograd graphs.
  model.eval()
  smis, vals, preds = [], [], []
  with torch.no_grad():
    for batch in loader:
      pred = model(batch)
      pred_class = pred[:, 0] < pred[:, 1]
      mask = (pred_class != batch.y).cpu().numpy()

      smis.extend(np.array(batch.smiles)[mask])
      # float() accepts Python numbers and one-element tensors on any device
      vals.extend(float(v) for v in batch.raw_y[mask])
      preds.extend(pred[mask])

  for smi, val in zip(smis, vals):
    print_func(f'true: {float(val):.2f}  {str(smi)} ')

  return smis, np.array(vals), preds


In [None]:
tr_smis, tr_vals, tr_preds = analyse_mispredictions(model, train_loader)
te_smis, te_vals, te_preds = analyse_mispredictions(model, test_loader)

# One histogram per split. The vertical line marks the -3 class threshold; each title
# reads "total = (solubility <= -3) + (solubility > -3)".
fig, axes = plt.subplots(1, 2)
fig.suptitle("solubility of mispredicted molecules from the...")
for ax, split_name, split_vals in zip(axes, ('train', 'test'), (tr_vals, te_vals)):
  ax.hist(split_vals, bins=15)
  ax.plot([-3, -3], [0, max(ax.get_yticks())])
  ax.set_title(f"{split_name} set ({len(split_vals)} = {np.sum(split_vals<=-3)} + {np.sum(split_vals>-3)})")

plt.show()

Do mispredicted molecules come from both classes equally or are they mostly from one of the classes?

Is it the same for the train and test data?

**conclusion:**

Is there a correlation between train/test data and solubility value?

In [None]:
# Solubility of every molecule in dataset order; molecules left of the dashed line form the training set.
y = [float(d.raw_y) for d in data_list]  # float() also handles one-element tensors on any device
plt.figure(figsize=(10,3))
plt.scatter(list(range(len(y))), y, s=5)
plt.axvline(num_train - 0.5, color='k', linestyle='--', label='train | test split')
plt.xlabel('molecule index')
plt.ylabel('solubility (raw_y)')
plt.legend()
plt.show()

**conclusion:**

## Explainability

<mark>What should be the parameter values in this case?</mark>

In [None]:
# Graph-level GNNExplainer: one mask value per node and one per edge.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',  # explain the model's own prediction
    model_config={
        'mode': 'multiclass_classification',
        'task_level': 'graph',
        'return_type': 'log_probs',  # assumes the model outputs log-probabilities -- TODO confirm against Net
    },
    node_mask_type='object',   # one mask value per node
    edge_mask_type='object',   # one mask value per edge
)

<mark>Play time!</mark>

Let's have a look at some explanations. Does the model look where it should?

In [None]:
# Pick a random molecule and explain the model's graph-level prediction for it.
random_index = np.random.randint(low=len(data_list))
data = data_list[random_index]
mol = data.mol
expl = explainer(data.x, data.edge_index)
print(f'Molecule index: {random_index}')
print(f'Generated explanations in {expl.available_explanations}')
print(expl)

In [None]:
# Show the sampled molecule as the cell output.
mol

In [None]:
# Atoms shaded by their importance for the graph-level prediction.
d = visualise_importance(sample=data, explanation=expl)
SVG(data=d.GetDrawingText())

In [None]:
# Atom and bond importance for the graph-level prediction.
d = visualise_everything(sample=data, explanation=expl)
SVG(data=d.GetDrawingText())

Let's analyse the influence of the number of epochs on explanations produced by GNNExplainer.

Just like (almost) any other method in ML, GNNExplainer has hyperparameters whose values must be carefully chosen. In this case, it is the number of epochs for which the GNNExplainer model is trained.

<mark>What should be the parameter values in this case?</mark>

In [None]:
# Graph-level GNNExplainer whose masks are optimised for only 50 epochs.
explainer50 = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=50),
    explanation_type='model',  # explain the model's own prediction
    model_config={
        'mode': 'multiclass_classification',
        'task_level': 'graph',
        'return_type': 'log_probs',  # assumes the model outputs log-probabilities -- TODO confirm against Net
    },
    node_mask_type='object',   # one mask value per node
    edge_mask_type='object',   # one mask value per edge
)

In [None]:
# Ten independent runs per explainer, to measure how much the learned masks vary between runs.
e200 = [explainer(data.x, data.edge_index) for _ in range(10)]
e50 = [explainer50(data.x, data.edge_index) for _ in range(10)]

In [None]:
# Stack the node masks of all runs: shape (runs, num_nodes, 1) for 'object' node masks.
nm200 = torch.stack([e.node_mask.cpu() for e in e200]).numpy()
nm50 = torch.stack([e.node_mask.cpu() for e in e50]).numpy()

# Per-node standard deviation across runs (how unstable each node's importance is).
nv200 = np.std(nm200, axis=0)
nv50 = np.std(nm50, axis=0)

print('node mask std: 50 epochs | 200 epochs')
for std50, std200 in zip(nv50[:, 0], nv200[:, 0]):
  print(f'{std50:.4f} {std200:.4f}')

**conclusion:**

In [None]:
# Stack the edge masks of all runs: shape (runs, num_edges).
em200 = torch.stack([e.edge_mask.cpu() for e in e200]).numpy()
em50 = torch.stack([e.edge_mask.cpu() for e in e50]).numpy()

# Per-edge standard deviation across runs.
ev200 = np.std(em200, axis=0)
ev50 = np.std(em50, axis=0)

print('edge mask std: 50 epochs | 200 epochs')
for std50, std200 in zip(ev50, ev200):
  print(f'{std50:.4f} {std200:.4f}')

**conclusion:**