[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hesther/rxn_workshop/blob/main/workshop_exercise.ipynb)

# GCNNs for molecules and reactions

Welcome to the workshop! This is a nearly empty notebook that we will complete together during the workshop. If you are reading through it afterwards, open `workshop_solution.ipynb` instead.

Let's install and import all the packages we will need.

In [None]:
!pip install -q rdkit numpy scikit-learn chemprop torch==2.0.1
!pip install -q torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
!pip install -q torch_geometric

In [None]:
from rdkit import Chem
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F

import torch_geometric as tg
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool

from sklearn.metrics import mean_absolute_error, mean_squared_error

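The following cell contains pre-made functions that compute a feature vector for each atom (39 entries) and each bond (7 entries), plus a helper that parses a SMILES string while keeping explicitly written hydrogens. You can customize these functions or replace them entirely.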

In [None]:
def atom_features(atom):
    # Concatenate one-hot encodings (each with a trailing "unknown" slot) of element symbol,
    # total degree, formal charge, number of attached Hs and hybridization, followed by
    # an aromaticity flag and the atomic mass scaled by 0.01. Resulting length: 39.
    features = onek_encoding_unk(atom.GetSymbol(), ['H', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'Br', 'I']) + \
        onek_encoding_unk(atom.GetTotalDegree(), [0, 1, 2, 3, 4, 5]) + \
        onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]) + \
        onek_encoding_unk(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4]) + \
        onek_encoding_unk(int(atom.GetHybridization()),[Chem.rdchem.HybridizationType.SP,
                                                        Chem.rdchem.HybridizationType.SP2,
                                                        Chem.rdchem.HybridizationType.SP3,
                                                        Chem.rdchem.HybridizationType.SP3D,
                                                        Chem.rdchem.HybridizationType.SP3D2
                                                        ]) + \
        [1 if atom.GetIsAromatic() else 0] + \
        [atom.GetMass() * 0.01]
    return features

def bond_features(bond):
    # 7 features: "no bond" flag, one-hot bond type (single/double/triple/aromatic),
    # conjugation flag and ring-membership flag
    bond_fdim = 7

    if bond is None:
        # Placeholder vector for a missing bond: only the "no bond" flag is set
        fbond = [1] + [0] * (bond_fdim - 1)
    else:
        bt = bond.GetBondType()
        fbond = [
            0,  # bond is not None
            bt == Chem.rdchem.BondType.SINGLE,
            bt == Chem.rdchem.BondType.DOUBLE,
            bt == Chem.rdchem.BondType.TRIPLE,
            bt == Chem.rdchem.BondType.AROMATIC,
            (bond.GetIsConjugated() if bt is not None else 0),
            (bond.IsInRing() if bt is not None else 0)
        ]
    return fbond

def onek_encoding_unk(value, choices):
    # One-hot encode value over choices; a value not in choices sets the last ("unknown") slot
    encoding = [0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else -1
    encoding[index] = 1
    return encoding

def make_mol(smi):
    # Parse a SMILES string without removing explicitly written hydrogen atoms
    params = Chem.SmilesParserParams()
    params.removeHs = False
    return Chem.MolFromSmiles(smi, params)
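The bond features above describe undirected chemical bonds, while message-passing layers such as torch_geometric's `MessagePassing` send messages along the directed edges listed in `edge_index`. A common way to bridge this is to store every bond twice, once per direction, with both copies sharing the same features. The sketch below is a minimal illustration of that idea, not part of the original workshop code: the function name `to_directed_edges` and the hand-written ethanol bond list are hypothetical, and plain NumPy is used so the example runs without RDKit.

```python
import numpy as np

def to_directed_edges(bonds, bond_feats):
    """Expand undirected bonds into directed edges for message passing.

    bonds:      list of (i, j) atom-index pairs, one entry per chemical bond
    bond_feats: list of per-bond feature vectors, in the same order as bonds
    Returns edge_index with shape (2, 2 * n_bonds) and, for at least one bond,
    edge_attr with shape (2 * n_bonds, n_features).
    """
    src, dst, attr = [], [], []
    for (i, j), feats in zip(bonds, bond_feats):
        src += [i, j]           # edge i -> j, then edge j -> i
        dst += [j, i]
        attr += [feats, feats]  # both directions share the bond's features
    edge_index = np.array([src, dst], dtype=np.int64)
    edge_attr = np.array(attr, dtype=np.float32)
    return edge_index, edge_attr


# Hypothetical example: heavy-atom graph of ethanol (C0-C1-O2), written by hand.
# [0, 1, 0, 0, 0, 0, 0] is what bond_features() returns for a plain single bond
# that is neither conjugated nor in a ring.
single = [0, 1, 0, 0, 0, 0, 0]
edge_index, edge_attr = to_directed_edges([(0, 1), (1, 2)], [single, single])
print(edge_index)       # [[0 1 1 2]
                        #  [1 0 2 1]]
print(edge_attr.shape)  # (4, 7)
```

In a torch_geometric pipeline, arrays like these would typically be converted with `torch.from_numpy` and stored, together with the atom feature matrix, in a `torch_geometric.data.Data(x=..., edge_index=..., edge_attr=...)` object.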