# Preamble
This section downloads and installs a conda environment on your Colab virtual machine, which makes dealing with dependencies a lot easier.

## Installing condacolab
We will use a package called condacolab, which will download and initialize a conda environment that we can later install packages into.

In [1]:
#Preamble, this will install a conda (mamba actually) environment on your Colab VM, which really makes working with complex dependencies (RDKit in this case) easier.
# YOU WILL LIKELY GET A NOTIFICATION ABOUT YOUR SESSION CRASHING, THIS IS EXPECTED BEHAVIOUR (the install will restart the python kernel). 
# Wait until this cell is done before running the rest of the notebook.
!pip install -q condacolab
import condacolab
condacolab.install()

✨🍰✨ Everything looks OK!


In [2]:
# Check that we now have a working conda environment. You should get the output "Everything looks OK!"
import condacolab
condacolab.check()

✨🍰✨ Everything looks OK!


## Installing the notebook packages
Now that condacolab is up and running, we install the packages. To install the correct version of the CUDA toolkit for PyTorch, we first check which version of CUDA is installed on the VM. If the last line of the output does not start with `cuda_11.1`, you need to change the `cudatoolkit=` version on the mamba install line to match the version reported by `nvcc`.

In [3]:
!nvcc --version  # Check what cuda version is installed, this must match the cudatoolkit=XY.Z we give to the mamba install line

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0


In [4]:
# Install the required packages. Note that condacolab uses mamba (a conda 
# reimplementation) by default. This will likely take between 5-10 minutes.
!mamba install pytorch cudatoolkit=11.1 rdkit -c pytorch -c conda-forge > /dev/null
print("Done installing packages")

Done installing packages


# Graph Neural Network Encoder

This code repeats the code from the GNN Encoder notebook

In [5]:
from collections import defaultdict
from collections.abc import Set

import rdkit
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
IPythonConsole.ipython_useSVG=True  #< set this to False if you want PNGs instead of SVGs
IPythonConsole.drawOptions.addAtomIndices = True  # This will help when looking at the Mol graph representation
IPythonConsole.molSize = 600, 600

import torch
from torch.utils.data import Dataset, DataLoader

# Module-wide dtype choices. The dataset, collate function and model code below
# refer to these names when allocating tensors.
float_type = torch.float32  # We're hardcoding types in the tensors further down
categorical_type = torch.long  # Integer indices; these are what the Embedding lookups consume
mask_type = torch.float32  # We're going to be multiplying our internal calculations with a mask using this type
labels_type = torch.float32 # We're going to use BCEWithLogitsLoss, which expects the labels to be of the same type as the predictions

In [6]:
class ContinuousFeature:
  '''Descriptor for a real-valued feature, identified only by its name.

  Instances are hashable and compare equal when their names match, so they
  can be used as dictionary keys in the per-atom / per-bond feature dicts.
  '''
  def __init__(self, name):
    self.name = name

  def __repr__(self):
    return f'<ContinuousFeature: {self.name}>'

  def __eq__(self, other):
    # Only another ContinuousFeature can be equal. Comparing against an
    # arbitrary object (None, a string, a CategoricalFeature with the same
    # name, ...) must yield False rather than raise AttributeError or report
    # a spurious match.
    return isinstance(other, ContinuousFeature) and self.name == other.name

  def __hash__(self):
    # Must stay consistent with __eq__: equal names -> equal hashes
    return hash(self.name)

class CategoricalFeature:
  '''Descriptor for a categorical feature with a fixed vocabulary of values.

  Each value is assigned an integer index (its position in `self.values`).
  When `add_null_value` is true, `None` is prepended to the vocabulary and
  acts as the "missing" value.

  Args:
    name: Name of the feature.
    values: Iterable of hashable values. Any iterable is accepted, including
      one-shot iterators/generators.
    add_null_value: Whether to prepend `None` as a null value.
  '''
  def __init__(self, name, values, add_null_value=True):
    self.name = name
    self.has_null_value = add_null_value
    # Materialise once up front: `values` may be a one-shot iterator, and we
    # need to walk it both for `self.values` and for the index mapping.
    values = tuple(values)
    if self.has_null_value:
      self.null_value = None
      values = (None,) + values
    self.values = values
    self.value_to_idx_mapping = {v: i for i, v in enumerate(self.values)}
    self.inv_value_to_idx_mapping = {i: v for v, i in self.value_to_idx_mapping.items()}

    if self.has_null_value:
      self.null_value_idx = self.value_to_idx_mapping[self.null_value]

  def get_null_idx(self):
    '''Return the index of the null value.

    Raises:
      RuntimeError: if the feature was created with add_null_value=False.
    '''
    if self.has_null_value:
      return self.null_value_idx
    else:
      raise RuntimeError(f"Categorical variable {self.name} has no null value")

  def value_to_idx(self, value):
    '''Map a value to its integer index (KeyError for unknown values).'''
    return self.value_to_idx_mapping[value]

  def idx_to_value(self, idx):
    '''Map an integer index back to its value (KeyError for unknown indices).'''
    return self.inv_value_to_idx_mapping[idx]

  def __len__(self):
    # Vocabulary size, including the null value if there is one
    return len(self.values)

  def __repr__(self):
    return f'<CategoricalFeature: {self.name}>'

  def __eq__(self, other):
    # Only another CategoricalFeature can be equal; anything else (including
    # objects lacking `.name`/`.values`) compares unequal instead of raising.
    return (isinstance(other, CategoricalFeature)
            and self.name == other.name
            and self.values == other.values)

  def __hash__(self):
    # Consistent with __eq__: same name and same values -> same hash
    return hash((self.name, self.values))

In [7]:
# Vocabulary of element symbols for the atom-symbol feature.
# NOTE(review): a symbol that is not in this list makes
# CategoricalFeature.value_to_idx raise KeyError, since it is a plain dict lookup.
ATOM_SYMBOLS = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 
                'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 
                'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 
                'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 
                'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 
                'Ba', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 
                'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Rf', 'Db', 'Sg', 
                'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Fl', 'Lv', 'La', 'Ce', 'Pr', 
                'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 
                'Lu', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 
                'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr']
# add_null_value defaults to True, so None is prepended and takes index 0
ATOM_SYMBOLS_FEATURE = CategoricalFeature('atom_symbol', ATOM_SYMBOLS)

ATOM_AROMATIC_VALUES = [True, False]
ATOM_AROMATIC_FEATURE = CategoricalFeature('is_aromatic', ATOM_AROMATIC_VALUES)

# In practice you might like to use categorical features for valence, but we use continuous here for demonstration
ATOM_EXPLICIT_VALENCE_FEATURE = ContinuousFeature('explicit_valence')

ATOM_IMPLICIT_VALENCE_FEATURE = ContinuousFeature('implicit_valence')

# The default set of atom features; these objects are also the keys of the dict
# returned by get_atom_features
ATOM_FEATURES = [ATOM_SYMBOLS_FEATURE, ATOM_AROMATIC_FEATURE, ATOM_EXPLICIT_VALENCE_FEATURE, ATOM_IMPLICIT_VALENCE_FEATURE]

def get_atom_features(rd_atom):
  '''Read the feature values of a single RDKit atom.

  Returns a dict keyed by the module-level atom feature objects: the element
  symbol, the aromaticity flag, and the explicit and implicit valences
  (both converted to float).
  '''
  symbol, aromatic = rd_atom.GetSymbol(), rd_atom.GetIsAromatic()
  n_implicit = float(rd_atom.GetImplicitValence())
  n_explicit = float(rd_atom.GetExplicitValence())

  features = {}
  features[ATOM_SYMBOLS_FEATURE] = symbol
  features[ATOM_AROMATIC_FEATURE] = aromatic
  features[ATOM_EXPLICIT_VALENCE_FEATURE] = n_explicit
  features[ATOM_IMPLICIT_VALENCE_FEATURE] = n_implicit
  return features

In [8]:
# We could use the RDKit enumeration types instead of strings, but the advantage
# of doing it like this is that our representation becomes independent of RDKit
# Vocabulary for the bond-type feature; get_bond_features produces these
# strings by calling str() on the RDKit bond type.
BOND_TYPES = ['UNSPECIFIED', 'SINGLE', 'DOUBLE', 'TRIPLE', 'QUADRUPLE', 
              'QUINTUPLE', 'HEXTUPLE', 'ONEANDAHALF', 'TWOANDAHALF',
              'THREEANDAHALF','FOURANDAHALF', 'FIVEANDAHALF', 'AROMATIC', 
              'IONIC', 'HYDROGEN', 'THREECENTER',	'DATIVEONE', 'DATIVE',
              'DATIVEL', 'DATIVER', 'OTHER', 'ZERO']
TYPE_FEATURE = CategoricalFeature('bond_type', BOND_TYPES)

# NOTE(review): RDKit's bond-direction enum may define more members than are
# listed here (e.g. an UNKNOWN direction) - confirm. A value missing from the
# list makes value_to_idx raise KeyError.
BOND_DIRECTIONS = ['NONE', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT', 'ENDUPRIGHT', 'EITHERDOUBLE' ]
DIRECTION_FEATURE = CategoricalFeature('bond_direction', BOND_DIRECTIONS)

BOND_STEREO = ['STEREONONE', 'STEREOANY', 'STEREOZ', 'STEREOE', 
               'STEREOCIS', 'STEREOTRANS']
STEREO_FEATURE = CategoricalFeature('bond_stereo', BOND_STEREO)

# NOTE(review): this feature has the same name and values as
# ATOM_AROMATIC_FEATURE, so the two objects compare equal and hash equal.
# Harmless while atom and bond feature dicts are kept separate.
AROMATIC_VALUES = [True, False]
AROMATIC_FEATURE = CategoricalFeature('is_aromatic', AROMATIC_VALUES)

# The default set of bond features; these objects are also the keys of the dict
# returned by get_bond_features
BOND_FEATURES = [TYPE_FEATURE, DIRECTION_FEATURE, AROMATIC_FEATURE, STEREO_FEATURE]

def get_bond_features(rd_bond):
  '''Read the feature values of a single RDKit bond.

  The bond type, stereo information and direction are stored as the string
  form of the corresponding RDKit value; aromaticity is stored as returned.
  The result is a dict keyed by the module-level bond feature objects.
  '''
  type_name = str(rd_bond.GetBondType())
  stereo_name = str(rd_bond.GetStereo())
  direction_name = str(rd_bond.GetBondDir())
  aromatic = rd_bond.GetIsAromatic()

  features = {}
  features[TYPE_FEATURE] = type_name
  features[DIRECTION_FEATURE] = direction_name
  features[AROMATIC_FEATURE] = aromatic
  features[STEREO_FEATURE] = stereo_name
  return features

In [9]:
def rdmol_to_graph(mol):
  '''Convert an RDKit molecule into an (atoms, bonds) pair of dicts.

  `atoms` maps each atom index to its feature dict. `bonds` maps the
  frozenset of the two bonded atom indices (an unordered edge) to the
  bond's feature dict.
  '''
  atoms = {}
  for rd_atom in mol.GetAtoms():
    atom_idx = rd_atom.GetIdx()
    atoms[atom_idx] = get_atom_features(rd_atom)

  bonds = {}
  for rd_bond in mol.GetBonds():
    endpoints = frozenset((rd_bond.GetBeginAtomIdx(), rd_bond.GetEndAtomIdx()))
    bonds[endpoints] = get_bond_features(rd_bond)

  return atoms, bonds

In [10]:
def smiles_to_graph(smiles):
  '''Parse a SMILES string and convert it to an (atoms, bonds) graph.

  Args:
    smiles: The SMILES string to parse.

  Returns:
    The (atoms, bonds) tuple produced by rdmol_to_graph.

  Raises:
    ValueError: if the SMILES string could not be parsed (the parser
      returned None).
  '''
  rd_mol = MolFromSmiles(smiles)
  if rd_mol is None:
    # Without this check the failure would surface later as an opaque
    # "'NoneType' object has no attribute 'GetAtoms'" error.
    raise ValueError(f'Could not parse SMILES string: {smiles!r}')
  graph = rdmol_to_graph(rd_mol)
  return graph

In [11]:
# Quick smoke test: build the (atoms, bonds) graph for benzene
g = smiles_to_graph('c1ccccc1')

In [12]:
class GraphDataset(Dataset):
  '''Dataset that turns (nodes, edges) graphs into tensor records.

  Each graph is a `(nodes, edges)` pair: `nodes` maps a node index to a dict
  of feature values, `edges` maps an edge key to a dict of feature values.
  An edge key that is a set (e.g. a frozenset) is treated as undirected and
  written symmetrically; any other two-element key is treated as a directed
  (u, v) pair. Node indices are used directly as row/column indices of the
  produced tensors.
  '''
  def __init__(self, *, graphs, labels, node_variables, edge_variables, metadata=None):
    '''Create a new graph dataset.

    Args:
      graphs: Sequence of (nodes, edges) pairs.
      labels: Sequence of labels, one per graph.
      node_variables: Feature descriptors to extract for nodes.
      edge_variables: Feature descriptors to extract for edges.
      metadata: Optional sequence with one metadata entry per graph.
    '''
    self.graphs = graphs
    self.labels = labels
    assert len(self.graphs) == len(self.labels), "The graphs and labels lists must be the same length"
    self.metadata = metadata
    if self.metadata is not None:
      assert len(self.metadata) == len(self.graphs), "The metadata list needs to be as long as the graphs"
    self.node_variables = node_variables
    self.edge_variables = edge_variables
    # Split the variables by kind; each kind ends up in its own tensor
    self.categorical_node_variables = [var for var in self.node_variables if isinstance(var, CategoricalFeature)]
    self.continuous_node_variables = [var for var in self.node_variables if isinstance(var, ContinuousFeature)]
    self.categorical_edge_variables = [var for var in self.edge_variables if isinstance(var, CategoricalFeature)]
    self.continuous_edge_variables = [var for var in self.edge_variables if isinstance(var, ContinuousFeature)]

  def __len__(self):
    return len(self.graphs)

  @staticmethod
  def _edge_endpoints(edge):
    '''Return `(u, v, is_undirected)` for an edge key.

    Set-like keys are undirected. A set with a single element is a self-loop
    (frozenset((u, u)) collapses to one element); it is reported as (u, u)
    with is_undirected=False since there is no mirrored entry to write.
    '''
    endpoints = tuple(edge)
    if isinstance(edge, Set):
      if len(endpoints) == 1:
        return endpoints[0], endpoints[0], False
      u, v = endpoints
      return u, v, True
    u, v = endpoints
    return u, v, False

  def make_continuous_node_features(self, nodes):
    '''Return an (n_nodes, n_features) float tensor, or None if there are no continuous node variables.'''
    if len(self.continuous_node_variables) == 0:
      return None
    n_nodes = len(nodes)
    n_features = len(self.continuous_node_variables)
    continuous_node_features = torch.zeros((n_nodes, n_features), dtype=float_type)
    for node_idx, features in nodes.items():
      node_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_node_variables], dtype=float_type)
      continuous_node_features[node_idx] = node_features
    return continuous_node_features

  def make_categorical_node_features(self, nodes):
    '''Return an (n_nodes, n_features) index tensor, or None if there are no categorical node variables.'''
    if len(self.categorical_node_variables) == 0:
      return None
    n_nodes = len(nodes)
    n_features = len(self.categorical_node_variables)
    categorical_node_features = torch.zeros((n_nodes, n_features), dtype=categorical_type)

    for node_idx, features in nodes.items():
      for i, categorical_variable in enumerate(self.categorical_node_variables):
        value = features[categorical_variable]
        categorical_node_features[node_idx, i] = categorical_variable.value_to_idx(value)

    return categorical_node_features

  def make_continuous_edge_features(self, n_nodes, edges):
    '''Return an (n_nodes, n_nodes, n_features) float tensor, or None if there are no continuous edge variables.'''
    if len(self.continuous_edge_variables) == 0:
      return None
    n_features = len(self.continuous_edge_variables)
    continuous_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=float_type)
    for edge, features in edges.items():
      edge_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_edge_variables], dtype=float_type)
      u, v, is_undirected = self._edge_endpoints(edge)
      continuous_edge_features[u, v] = edge_features
      if is_undirected:
        continuous_edge_features[v, u] = edge_features

    return continuous_edge_features

  def make_categorical_edge_features(self, n_nodes, edges):
    '''Return an (n_nodes, n_nodes, n_features) index tensor, or None if there are no categorical edge variables.'''
    if len(self.categorical_edge_variables) == 0:
      return None
    n_features = len(self.categorical_edge_variables)
    categorical_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=categorical_type)

    for edge, features in edges.items():
      u, v, is_undirected = self._edge_endpoints(edge)
      for i, categorical_variable in enumerate(self.categorical_edge_variables):
        value = features[categorical_variable]
        value_index = categorical_variable.value_to_idx(value)
        categorical_edge_features[u, v, i] = value_index
        if is_undirected:
          categorical_edge_features[v, u, i] = value_index

    return categorical_edge_features

  def __getitem__(self, index):
    '''Build the tensor record for graph number `index`.

    The returned dict has the keys 'nodes', 'adjacency_matrix',
    'adjacency_list', 'categorical_node_features',
    'continuous_node_features', 'categorical_edge_features',
    'continuous_edge_features' and 'label', plus 'metadata' when the dataset
    was created with metadata. A feature entry is None when the dataset has
    no variables of that kind.
    '''
    # This is where the important stuff happens. We use our node and
    # edge variable attributes to select what node and edge features to use.
    # In practice, we often do this as a pre-processing step, but here we do it
    # in the getitem function for clarity
    nodes, edges = self.graphs[index]
    n_nodes = len(nodes)
    continuous_node_features = self.make_continuous_node_features(nodes)
    categorical_node_features = self.make_categorical_node_features(nodes)
    continuous_edge_features = self.make_continuous_edge_features(n_nodes, edges)
    categorical_edge_features = self.make_categorical_edge_features(n_nodes, edges)

    label = self.labels[index]
    nodes_idx = sorted(nodes.keys())

    # Build the dense adjacency matrix and the adjacency list in a single pass
    adjacency_matrix = torch.zeros((n_nodes, n_nodes), dtype=float_type)
    adjacency_list = defaultdict(list)
    for edge in edges:
      u, v, is_undirected = self._edge_endpoints(edge)
      adjacency_matrix[u, v] = 1
      adjacency_list[u].append(v)
      if is_undirected:
        # This edge is unordered, so it is mirrored
        adjacency_matrix[v, u] = 1
        adjacency_list[v].append(u)

    data_record = {'nodes': nodes_idx,
                   'adjacency_matrix': adjacency_matrix,
                   'adjacency_list': adjacency_list,
                   'categorical_node_features': categorical_node_features,
                   'continuous_node_features': continuous_node_features,
                   'categorical_edge_features': categorical_edge_features,
                   'continuous_edge_features': continuous_edge_features,
                   'label': label}

    # If you need to add extra information (metadata about this graph) you can
    # add an extra key-value pair here. The advantage of using a dict compared
    # to a tuple is that the downstream code doesn't break as long as at least
    # the expected keys are present. The downside is that using a dict adds
    # overhead (accessing a dict compared to unpacking a tuple).
    # A more robust implementation might actually make a separate class for
    # dataset entries
    if self.metadata is not None:
      data_record['metadata'] = self.metadata[index]
    return data_record

  def get_node_variables(self):
    '''Return the node variables grouped by kind ('continuous' / 'categorical').'''
    return {'continuous': self.continuous_node_variables,
            'categorical': self.categorical_node_variables}

  def get_edge_variables(self):
    '''Return the edge variables grouped by kind ('continuous' / 'categorical').'''
    return {'continuous': self.continuous_edge_variables,
            'categorical': self.categorical_edge_variables}

In [13]:
def make_molecular_graph_dataset(smiles_records, atom_features=ATOM_FEATURES, bond_features=BOND_FEATURES):
  '''
  Build a GraphDataset from an iterable of SMILES record dictionaries.

  Every record must provide the keys 'smiles' and 'label'. The SMILES string
  is converted to a graph, and the complete record is kept as that graph's
  metadata entry.
  '''
  converted = []
  for record in smiles_records:
    smiles, label = record['smiles'], record['label']
    converted.append((smiles_to_graph(smiles), label, record))

  return GraphDataset(graphs=[graph for graph, _, _ in converted],
                      labels=[label for _, label, _ in converted],
                      node_variables=atom_features,
                      edge_variables=bond_features,
                      metadata=[record for _, _, record in converted])
  

In [14]:
import torch
from torch.nn import Embedding, Module, ModuleList

In [15]:
class Embedder(Module):
  '''Embeds several categorical variables and sums their embedding vectors.

  One embedding table is created per categorical variable, all with the same
  embedding dimension.
  '''
  def __init__(self, categorical_variables, embedding_dim):
    super().__init__()
    self.categorical_variables = categorical_variables

    def build_table(variable):
      vocabulary_size = len(variable)
      if variable.has_null_value:
        # Missing values are common; the null index is used as padding_idx so
        # that it is embedded as the zero vector
        return Embedding(vocabulary_size, embedding_dim, padding_idx=variable.get_null_idx())
      return Embedding(vocabulary_size, embedding_dim)

    self.embeddings = ModuleList([build_table(variable) for variable in categorical_variables])

  def forward(self, categorical_features):
    '''Embed `categorical_features` and sum over the variables.

    The last axis of the input indexes the categorical variables; any number
    of leading axes is allowed, so the same code handles node features, edge
    features and their mini-batched versions.
    '''
    # `...` keeps every leading axis and picks one variable from the last axis
    per_variable = [table(categorical_features[..., column])
                    for column, table in enumerate(self.embeddings)]

    # Summation is used to merge the variables; concatenation would be an
    # alternative design
    return torch.stack(per_variable, dim=0).sum(dim=0)

In [16]:
class FeatureCombiner(Module):
  '''Joins embedded categorical features with continuous features.

  Categorical features are embedded first; the result and the continuous
  features are then concatenated along the last (feature) axis.
  '''
  def __init__(self, categorical_variables, embedding_dim):
    super().__init__()
    self.categorical_variables = categorical_variables
    self.embedder = Embedder(self.categorical_variables, embedding_dim)

  def forward(self, continuous_features, categorical_features):
    '''Return the combined features; either argument may be None, but not both.

    Raises:
      RuntimeError: if both inputs are None.
    '''
    # It is not uncommon to use only one kind of feature, so each is optional
    if categorical_features is None and continuous_features is None:
      raise RuntimeError('No features to combine')

    parts = []
    if categorical_features is not None:
      # Embedded categorical features come first in the concatenation
      parts.append(self.embedder(categorical_features))
    if continuous_features is not None:
      parts.append(continuous_features)
    return torch.cat(parts, dim=-1)  # Concatenate along the feature dimension


In [17]:
from collections.abc import Set # We assume that edges as sets are for undirected graphs

def _stack_padded_features(batch, key, n_node_axes, max_nodes, dtype):
  '''Zero-pad and stack the per-graph feature tensors stored under `key`.

  `n_node_axes` is 1 for node features (nodes, features) and 2 for edge
  features (nodes, nodes, features). Allocation is deferred until the first
  graph that actually has the feature; returns None if no graph has it.
  '''
  stacked = None
  for i, graph in enumerate(batch):
    g_features = graph[key]
    if g_features is None:
      continue
    if stacked is None:
      num_features = g_features.shape[-1]
      shape = (len(batch),) + (max_nodes,) * n_node_axes + (num_features,)
      stacked = torch.zeros(shape, dtype=dtype)
    g_nodes = g_features.shape[0]
    # Select graph i and the leading g_nodes entries along every node axis
    target = (i,) + (slice(0, g_nodes),) * n_node_axes
    stacked[target] = g_features
  return stacked


def collate_graph_batch(batch):
  '''Collate a batch of graph dictionaries produced by a GraphDataset.

  Every graph is zero-padded to the size of the largest graph in the batch.

  Args:
    batch: Non-empty list of record dicts with the keys 'nodes',
      'adjacency_matrix', 'label', the four '*_features' entries and
      optionally 'metadata'.

  Returns:
    A dict with 'adjacency_matrices', the four stacked '*_features' entries
    (None for a feature kind no graph has), 'nodes_mask', 'edge_mask' and
    'labels'. The masks hold 1 for real entries and 0 for padding. If any
    graph has metadata, 'metadata' is a list with one entry per graph (None
    for graphs without it).

  Raises:
    ValueError: if `batch` is empty.
  '''
  batch_size = len(batch)
  if batch_size == 0:
    raise ValueError('Cannot collate an empty batch of graphs')

  max_nodes = max(len(graph['nodes']) for graph in batch)

  adjacency_matrices = torch.zeros((batch_size, max_nodes, max_nodes), dtype=float_type)
  labels = torch.tensor([graph['label'] for graph in batch], dtype=labels_type)
  nodes_mask = torch.zeros((batch_size, max_nodes), dtype=mask_type)
  edge_mask = torch.zeros((batch_size, max_nodes, max_nodes), dtype=mask_type)

  for i, graph in enumerate(batch):
    # The adjacency matrix tells us how many nodes this graph has
    adjacency_matrix = graph['adjacency_matrix']
    g_nodes = adjacency_matrix.shape[0]
    adjacency_matrices[i, :g_nodes, :g_nodes] = adjacency_matrix

    # Now that we know how many of the entries are valid, we set those to 1s
    # in the masks
    edge_mask[i, :g_nodes, :g_nodes] = 1
    nodes_mask[i, :g_nodes] = 1

  # All the feature constructions follow the same recipe, implemented once in
  # the helper. Continuous features use float_type so they match the other
  # float tensors of the batch instead of torch's default dtype.
  stacked_continuous_node_features = _stack_padded_features(
      batch, 'continuous_node_features', 1, max_nodes, float_type)
  stacked_categorical_node_features = _stack_padded_features(
      batch, 'categorical_node_features', 1, max_nodes, categorical_type)
  stacked_continuous_edge_features = _stack_padded_features(
      batch, 'continuous_edge_features', 2, max_nodes, float_type)
  stacked_categorical_edge_features = _stack_padded_features(
      batch, 'categorical_edge_features', 2, max_nodes, categorical_type)

  batch_record = {'adjacency_matrices': adjacency_matrices,
                  'categorical_node_features': stacked_categorical_node_features,
                  'continuous_node_features': stacked_continuous_node_features,
                  'categorical_edge_features': stacked_categorical_edge_features,
                  'continuous_edge_features': stacked_continuous_edge_features,
                  'nodes_mask': nodes_mask,
                  'edge_mask': edge_mask,
                  'labels': labels}
  if any('metadata' in graph for graph in batch):
    # .get keeps the list aligned with the batch even if some graphs lack metadata
    batch_record['metadata'] = [graph.get('metadata') for graph in batch]

  return batch_record


In [18]:
from torch.nn import Module, Embedding, ModuleList, Linear, Sequential, ReLU, LayerNorm, Dropout
from torch.nn.functional import layer_norm, dropout

class BasicGNNConfig:
  '''Hyperparameters of the basic GNN; all arguments are keyword-only.

  Attributes:
    d_model: Width of the node/edge feature vectors.
    n_layers: Number of graph layers.
    ffn_dim: Hidden width of the MLPs inside each graph layer.
  '''
  def __init__(self, *, d_model: int, n_layers: int, ffn_dim: int):
    self.d_model, self.n_layers, self.ffn_dim = d_model, n_layers, ffn_dim
    
class BasicGraphLayer(Module):
  '''A single message-passing layer.

  For every node, each neighbour's features are added to the features of the
  connecting edge and passed through an MLP; the resulting messages are
  summed over the neighbours, added to an MLP transform of the node's own
  features, and passed through an output MLP.
  '''
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.input_dim = config.d_model
    self.output_dim = config.d_model
    self.ffn_dim = config.ffn_dim

    def make_mlp():
      # Two-layer perceptron: d_model -> ffn_dim -> d_model
      return Sequential(Linear(self.input_dim, self.ffn_dim),
                        ReLU(),
                        Linear(self.ffn_dim, self.output_dim))

    self.neighbour_edges_mlp = make_mlp()
    self.center_mlp = make_mlp()
    self.output_mlp = make_mlp()

  def forward(self, adjacency_matrix, node_features, edge_features, node_mask, edge_mask):
    '''Compute updated node features.

    The indexing below expects batched inputs: adjacency_matrix and edge_mask
    of shape (batch, nodes, nodes), node_features (batch, nodes, d_model),
    edge_features (batch, nodes, nodes, d_model) and node_mask
    (batch, nodes). Returns a tensor shaped like node_features, with padded
    nodes zeroed out.
    '''
    center_updated_node_features = self.center_mlp(node_features)

    # Entry [b, i, j] is the message sent to node i from node j: the features
    # of edge (i, j) plus the features of the *neighbour* j. Inserting the new
    # axis at -3 broadcasts node j along axis i. (Inserting it at -2 would add
    # node i's own features instead, so no information would ever flow
    # between nodes.)
    edge_and_node_features = edge_features + node_features.unsqueeze(dim=-3)
    neighbourhood = self.neighbour_edges_mlp(edge_and_node_features)

    # Keep only messages along real edges, and drop padded entries
    masked_edge_and_node_features = neighbourhood * adjacency_matrix.unsqueeze(dim=-1)
    masked_edge_and_node_features = masked_edge_and_node_features * edge_mask.unsqueeze(dim=-1)

    # Sum over the neighbour axis j
    reduced_neighbourhoods = masked_edge_and_node_features.sum(dim=2)

    aggregated_neighbourhoods = reduced_neighbourhoods + center_updated_node_features
    updated_node_features = self.output_mlp(aggregated_neighbourhoods)
    masked_updated_features = updated_node_features * node_mask.unsqueeze(dim=-1)
    return masked_updated_features


class BasicGraphEncoder(torch.nn.Module):
  '''Encodes a collated batch of graphs into per-node feature vectors.

  Node and edge features are each built by a FeatureCombiner (embedded
  categorical features concatenated with continuous features) and then
  propagated through `config.n_layers` graph layers of type `layer_type`.

  Args:
    config: Model hyperparameters (d_model, n_layers, ffn_dim).
    continuous_node_variables: Continuous node feature descriptors, or None for none.
    categorical_node_variables: Categorical node feature descriptors, or None for none.
    continuous_edge_variables: Continuous edge feature descriptors, or None for none.
    categorical_edge_variables: Categorical edge feature descriptors, or None for none.
    layer_type: Class used to build each graph layer; called as layer_type(config).
  '''
  def __init__(self, *,
               config: BasicGNNConfig,
               continuous_node_variables=None,
               categorical_node_variables=None,
               continuous_edge_variables=None,
               categorical_edge_variables=None,
               layer_type=BasicGraphLayer):
    super().__init__()

    self.config = config
    self.layer_type = layer_type

    # The variable lists default to None; treat that as "no variables of this
    # kind" so that the len() calls and the featurizers below don't crash
    self.continuous_node_variables = continuous_node_variables if continuous_node_variables is not None else []
    self.categorical_node_variables = categorical_node_variables if categorical_node_variables is not None else []
    self.continuous_edge_variables = continuous_edge_variables if continuous_edge_variables is not None else []
    self.categorical_edge_variables = categorical_edge_variables if categorical_edge_variables is not None else []

    # We want the embeddings together with the continuous values to be of
    # dimension d_model, therefore we allocate
    # d_model - len(continuous_variables) as the embeddings dim
    self.categorical_node_embeddings_dim = config.d_model - len(self.continuous_node_variables)
    self.categorical_edge_embeddings_dim = config.d_model - len(self.continuous_edge_variables)

    self.node_featurizer = FeatureCombiner(self.categorical_node_variables,
                                           self.categorical_node_embeddings_dim)
    self.edge_featurizer = FeatureCombiner(self.categorical_edge_variables,
                                           self.categorical_edge_embeddings_dim)

    # Notice that we use the supplied layer type when creating the graph
    # layers. This allows us to easily change the kind of graph layers
    # we use later on
    self.graph_layers = ModuleList([layer_type(config) for _ in range(config.n_layers)])

  def forward(self, batch):
    '''Encode a batch.

    Reads the keys 'nodes_mask', 'edge_mask', 'adjacency_matrices' and the
    four '*_features' entries from `batch`, and returns the node states
    produced by the last graph layer.
    '''
    # First order of business is to build the node features
    node_mask = batch['nodes_mask']

    continuous_node_features = batch['continuous_node_features']
    categorical_node_features = batch['categorical_node_features']
    node_features = self.node_featurizer(continuous_node_features, categorical_node_features)
    masked_node_features = node_features * node_mask.unsqueeze(-1)

    # Same for the edge features
    continuous_edge_features = batch['continuous_edge_features']
    categorical_edge_features = batch['categorical_edge_features']
    edge_features = self.edge_featurizer(continuous_edge_features, categorical_edge_features)
    edge_mask = batch['edge_mask']
    masked_edge_features = edge_features * edge_mask.unsqueeze(-1)

    # We have now built the node features, we'll propagate them through our
    # graph layers
    adjacency_matrix = batch['adjacency_matrices']
    memory_state = masked_node_features
    for graph_layer in self.graph_layers:
      memory_state = graph_layer(adjacency_matrix, memory_state, masked_edge_features, node_mask, edge_mask)

    return memory_state

# A complete GNN
Now that we've seen how a Graph Neural Network encoder works, we will go through how to combine it with a prediction head to perform graph prediction. We'll use the same molecules as in the previous notebook (the BBBP dataset).

Your task will be to take the GNN Encoder defined above and combine it with a prediction head to train the network. See the _task_ section below


## The dataset
We will be using the same BBBP dataset as in the previous notebook.

In [19]:
# We'll start by downloading the dataset. We're relying on the direct download link from MoleculeNet
! wget -q https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv

In [20]:
import pandas as pd
from collections import defaultdict, deque # We'll use this to construct the dataset splits
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmilesFromSmiles

In [21]:
bbbp_table = pd.read_csv('BBBP.csv')
bbbp_table

Unnamed: 0,num,name,p_np,smiles
0,1,Propanolol,1,[Cl].CC(C)NCC(O)COc1cccc2ccccc12
1,2,Terbutylchlorambucil,1,C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl
2,3,40730,1,c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...
3,4,24,1,C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C
4,5,cloxacillin,1,Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)...
...,...,...,...,...
2045,2049,licostinel,1,C1=C(Cl)C(=C(C2=C1NC(=O)C(N2)=O)[N+](=O)[O-])Cl
2046,2050,ademetionine(adenosyl-methionine),1,[C@H]3([N]2C1=C(C(=NC=N1)N)N=C2)[C@@H]([C@@H](...
2047,2051,mesocarb,1,[O+]1=N[N](C=C1[N-]C(NC2=CC=CC=C2)=O)C(CC3=CC=...
2048,2052,tofisoline,1,C1=C(OC)C(=CC2=C1C(=[N+](C(=C2CC)C)[NH-])C3=CC...


We'll now filter the dataset and create the kind of data our dataset building function expects

In [22]:
# There are about 50 problematic SMILES; RDKit will emit about as many ERRORs
# and WARNINGs, which is expected behaviour. We suppress these messages here.
# 11 of the SMILES can't be parsed at all.
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')  

# Collect one record per molecule that RDKit can parse; unparsable rows are
# reported and skipped
smiles_records = []
for i, num, name, p_np, smiles in bbbp_table.to_records():
  if MolFromSmiles(smiles) is None:
    # RDKit rejected this SMILES string
    print(f'Molecule {smiles} on row {i} could not be parsed by RDKit')
    continue
  smiles_record = {'smiles': smiles, 'label': p_np, 'metadata': {'row': i}}
  smiles_records.append(smiles_record)
  

Molecule O=N([O-])C1=C(CN=C1NCCSCc2ncccc2)Cc3ccccc3 on row 59 could not be parsed by RDKit
Molecule c1(nc(NC(N)=[NH2])sc1)CSCCNC(=[NH]C#N)NC on row 61 could not be parsed by RDKit
Molecule Cc1nc(sc1)\[NH]=C(\N)N on row 391 could not be parsed by RDKit
Molecule s1cc(CSCCN\C(NC)=[NH]\C#N)nc1\[NH]=C(\N)N on row 614 could not be parsed by RDKit
Molecule c1c(c(ncc1)CSCCN\C(=[NH]\C#N)NCC)Br on row 642 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1ccccc1 on row 645 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1cccc(c1)N on row 646 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1cccc(c1)NC(C)=O on row 647 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1cccc(c1)N\C(NC)=[NH]\C#N on row 648 could not be parsed by RDKit
Molecule s1cc(nc1\[NH]=C(\N)N)C on row 649 could not be parsed by RDKit
Molecule c1(cc(N\C(=[NH]\c2cccc(c2)CC)C)ccc1)CC on row 685 could not be parsed by RDKit


We will create some "canonical" splits we can use in the remainder of the notebook. These are just random selections of the molecules.

In [23]:
import random
random.seed(1729)  # fixed seed so the split is the same on every run

# Fractions of the data used for training and development; the rest is the test set
training_fraction, dev_fraction = 0.8, 0.1
n_examples = len(smiles_records)
n_training_examples = int(training_fraction * n_examples)
n_dev_examples = int(dev_fraction * n_examples)
n_training_and_dev_examples = n_training_examples + n_dev_examples

# Shuffle the example indices once, then cut the shuffled list into three
# consecutive slices
indices = list(range(n_examples))
random.shuffle(indices)  # shuffle is in place
training_indices = indices[:n_training_examples]
dev_indices = indices[n_training_examples:n_training_and_dev_examples]
test_indices = indices[n_training_and_dev_examples:]

training_smiles_records, dev_smiles_records, test_smiles_records = (
    [smiles_records[record_idx] for record_idx in split_indices]
    for split_indices in (training_indices, dev_indices, test_indices))

In [24]:
# Build one graph dataset per split, in the order training, dev, test
training_graph_dataset, dev_graph_dataset, test_graph_dataset = (
    make_molecular_graph_dataset(split_records)
    for split_records in (training_smiles_records, dev_smiles_records, test_smiles_records))

## The training loop
To make training models a bit easier, we encapsulate most of the training loop and its state in a Trainer class.

In [25]:
from tqdm.notebook import tqdm, trange
from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score

In [26]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device is", device)

Device is cpu


In [27]:
from tqdm.notebook import tqdm, trange
from torch.nn import BCEWithLogitsLoss, MSELoss
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score, mean_squared_error, mean_absolute_error, median_absolute_error
import matplotlib.pyplot as plt
import seaborn as sns

def batch_to_device(batch, device):
  """Return a new dict where every tensor value of `batch` lives on `device`.

  Values that are not tensors are carried over unchanged, and the input dict
  itself is not modified.
  """
  return {
    key: (value.to(device) if torch.is_tensor(value) else value)
    for key, value in batch.items()
  }

class Trainer:
  """Holds a model, its optimizer and the training/dev loops.

  All constructor arguments are keyword-only.

  Args:
    model: the network to train; it is moved to `device` on construction.
    loss_fn: callable `loss_fn(predictions, labels)` returning a scalar loss
      tensor.
    training_dataloader: iterable of batch dicts used for the optimization
      steps. Every batch must contain a 'labels' entry.
    dev_dataloader: iterable of batch dicts used to report the dev loss after
      each epoch.
    device: device the model and the batches are moved to.
    lr: learning rate handed to AdamW. Defaults to 1e-4.

  Attributes:
    total_epochs: number of training epochs completed so far.
  """
  def __init__(self, *, model, 
               loss_fn, training_dataloader, 
               dev_dataloader, device=device, lr=1e-4):
    self.model = model
    self.training_dataloader = training_dataloader
    self.dev_dataloader = dev_dataloader
    self.device = device
    self.model.to(device)
    self.total_epochs = 0
    self.optimizer = AdamW(self.model.parameters(), lr=lr)
    self.loss_fn = loss_fn

  def train(self, epochs):
    """Train for `epochs` epochs, evaluating on the dev set after each one.

    The running train/dev losses are shown in the description of the epoch
    progress bar. Nothing is returned.
    """
    with trange(epochs, desc='Epoch', position=0) as epoch_progress:
      for _ in epoch_progress:
        train_loss = self._run_training_epoch()
        self.total_epochs += 1
        dev_loss = self._run_dev_epoch()
        epoch_progress.set_description(f"Epoch: train loss {train_loss: .3f}, dev loss {dev_loss: .3f}")

  def _run_training_epoch(self):
    """Run one optimization pass over the training data and return the batch-size-weighted mean loss."""
    self.model.train()
    total_loss = 0.0
    n_examples = 0
    for training_batch in tqdm(self.training_dataloader, desc='Training batch', leave=False):
      self.optimizer.zero_grad()
      # Move all tensors to the device
      training_batch = batch_to_device(training_batch, self.device)
      prediction = self.model(training_batch)
      labels = training_batch['labels']
      # By default the predictions have shape (batch_size, 1). Only the last
      # axis is squeezed: a bare squeeze() would also drop the batch axis of a
      # final batch holding a single example and break the loss computation.
      loss = self.loss_fn(prediction.squeeze(-1), labels)
      loss.backward()
      self.optimizer.step()
      batch_n = len(labels)
      total_loss += batch_n * loss.item()
      n_examples += batch_n
    return self._mean_loss(total_loss, n_examples)

  def _run_dev_epoch(self):
    """Evaluate on the dev data without gradients and return the batch-size-weighted mean loss."""
    self.model.eval()
    total_loss = 0.0
    n_examples = 0
    with torch.no_grad():
      for dev_batch in tqdm(self.dev_dataloader, desc="Dev batch", leave=False):
        dev_batch = batch_to_device(dev_batch, self.device)
        # Squeeze only the last axis, see _run_training_epoch
        prediction = self.model(dev_batch).squeeze(-1)
        labels = dev_batch['labels']
        loss = self.loss_fn(prediction, labels)
        batch_n = len(labels)
        total_loss += batch_n * loss.item()
        n_examples += batch_n
    return self._mean_loss(total_loss, n_examples)

  @staticmethod
  def _mean_loss(total_loss, n_examples):
    """Average an accumulated loss; an empty dataloader yields NaN instead of a ZeroDivisionError."""
    if n_examples == 0:
      return float('nan')
    return total_loss / n_examples


def evaluate_model(trainer, dataloader, label=None, hue_order=[0,1]):
  """Evaluate the trainer's model on `dataloader` and plot the prediction distributions.

  The model is run in eval mode without gradients. The batch-size-weighted
  mean loss and the ROC AUC are computed from the raw model outputs, and a KDE
  plot plus a rug plot of the predictions, coloured by target, are drawn on
  the current matplotlib axes. Nothing is returned.

  Args:
    trainer: object providing `model`, `loss_fn`, `total_epochs` and `device`.
    dataloader: iterable of batch dicts; every batch must contain 'labels'.
    label: optional name of the dataset, used as prefix of the plot title.
    hue_order: order of the target values passed to the seaborn plots. The
      default list is only read, never modified.

  Raises:
    ValueError: if the dataloader yields no examples.
  """
  eval_predictions = []
  eval_labels = []
  eval_loss = 0
  eval_n = 0
  model = trainer.model
  loss_fn = trainer.loss_fn
  total_epochs = trainer.total_epochs
  model.eval()
  with torch.no_grad():
    for eval_batch in tqdm(dataloader, desc='batch'):
      # Use the trainer's device (where the model was moved), not the
      # notebook-level global, so that batch and model always agree.
      eval_batch = batch_to_device(eval_batch, trainer.device)
      # By default the predictions have shape (batch_size, 1). Only the last
      # axis is squeezed so that a batch with a single example keeps its
      # batch axis and tolist() still returns a list.
      prediction = model(eval_batch).squeeze(-1)
      eval_predictions.extend(prediction.tolist())
      labels = eval_batch['labels']
      eval_labels.extend(labels.tolist())
      loss = loss_fn(prediction, labels)
      batch_n = len(labels)
      eval_loss += batch_n*loss.item()
      eval_n += batch_n
  if eval_n == 0:
    raise ValueError("Cannot evaluate: the dataloader yielded no examples")
  average_loss = eval_loss/eval_n
  roc_auc = roc_auc_score(eval_labels, eval_predictions)
  eval_df = pd.DataFrame(data={'target': eval_labels, 'predictions': eval_predictions})
  sns.kdeplot(data=eval_df, x='predictions', hue='target', hue_order=hue_order)
  sns.rugplot(data=eval_df, x='predictions', hue='target', hue_order=hue_order)
  
  if label is not None:
    title = f"{label} dataset after {total_epochs} epochs\nloss {average_loss}\nROC AUC {roc_auc}"
  else:
    title = f"After {total_epochs} epochs\nloss {average_loss}\nROC AUC {roc_auc}"
  plt.title(title)

# Task 1

Now that we've prepared the dataset and GNN Encoder, your task is to flesh out the GraphPredictionHead and combine it with the GNNEncoder to train a network for Graph Prediction on the BBBP dataset. The main challenge is to take the output of the Encoder, a tensor with a `max_nodes` axis, and reduce it to a single feature vector per graph in the batch. Look to the final task of the previous notebook.

**Implement the reductions of the output of the Encoder and combine the encoder with the prediction head in the `GraphPredictionNeuralNetwork` class and train the neural network**

In [28]:
class GraphPredictionHeadConfig(BasicGNNConfig):
  """BasicGNNConfig extended with the pooling mode used by the graph prediction head.

  Args:
    pooling_type: keyword-only name of the node pooling, stored as
      `self.pooling_type`. GraphPredictionHead accepts 'sum' and 'mean'.
    **kwargs: forwarded unchanged to BasicGNNConfig.
  """
  def __init__(self, *, pooling_type='sum', **kwargs):
    # Pooling type can be 'sum' or 'mean'; it is not validated here, an
    # unsupported value is only rejected in GraphPredictionHead.forward
    super().__init__(**kwargs)
    self.pooling_type = pooling_type

class GraphPredictionHead(Module):
  """Pools node features into one vector per graph and maps it to a prediction.

  The prediction is produced by a two-layer MLP
  (input_dim -> config.ffn_dim -> output_dim, with a ReLU in between).

  NOTE: this is an exercise scaffold. The pooling in `forward` is left as
  `...` placeholders and has to be implemented before the head can be used.
  """
  def __init__(self, input_dim, output_dim, config):
    """Store the dimensions and config and build the predictor MLP.

    Args:
      input_dim: size of the pooled feature vector fed to the MLP.
      output_dim: size of the prediction per graph.
      config: configuration object; `ffn_dim` and `pooling_type` are read from it.
    """
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.config = config
    self.predictor = Sequential(Linear(self.input_dim, self.config.ffn_dim), 
                                ReLU(), 
                                Linear(self.config.ffn_dim, self.output_dim))

  def forward(self, node_features, node_mask):
    """Pool `node_features` per graph and return the predictor output.

    Args:
      node_features: tensor of shape (*leading_axes, max_nodes, d_model).
      node_mask: mask of the valid nodes; not used by the scaffold code
        (presumably needed for the 'mean' pooling - to be confirmed when
        implementing it).

    Raises:
      ValueError: if config.pooling_type is neither 'sum' nor 'mean'.
    """
    # You need to take the node_features, a tensor of shape 
    # (*leading_axes, max_nodes, d_model) and reduce this either by 'sum' or
    # 'mean'. You can assume these features are valid, i.e. that they have been
    # masked at a previous step. Assign the result to the name *pooled_nodes*
    if self.config.pooling_type == 'sum':
      # Implement 'sum' based pooling here
      pooled_nodes = ...
    elif self.config.pooling_type == 'mean':
      # Implement 'mean' based pooling here; the result must also be assigned
      # to *pooled_nodes*, which is what the predictor below consumes
      ...
    else:
      raise ValueError(f'Unsupported pooling type {self.config.pooling_type}')
    
    prediction = self.predictor(pooled_nodes)
    return prediction

In [29]:
class GraphPredictionNeuralNetwork(Module):
  """Graph-level prediction model: a BasicGraphEncoder followed by a GraphPredictionHead.

  NOTE: this is an exercise scaffold. `forward` is a placeholder and has to be
  implemented before the model can be trained.
  """
  def __init__(self, config, dataset, 
               output_dim=1,
               graph_layer_type=BasicGraphLayer):
    """Build the encoder and the prediction head.

    Args:
      config: configuration shared by encoder and head; `d_model` is used as
        the input size of the head.
      dataset: dataset whose continuous/categorical node and edge variable
        descriptions are handed to the encoder.
      output_dim: size of the prediction per graph.
      graph_layer_type: class passed to the encoder as `layer_type`.
    """
    super().__init__()
    self.config = config
    self.encoder = BasicGraphEncoder(config=config,
                                     continuous_node_variables=dataset.continuous_node_variables,
                                     categorical_node_variables=dataset.categorical_node_variables,
                                     continuous_edge_variables=dataset.continuous_edge_variables,
                                     categorical_edge_variables=dataset.categorical_edge_variables,
                                     layer_type=graph_layer_type)
    
    self.prediction_head = GraphPredictionHead(input_dim=config.d_model, 
                                               output_dim=output_dim, config=config)
  def forward(self, batch):
    """Return the prediction of the head for `batch` (to be implemented)."""
    # Here you should combine the BasicGraphEncoder with the GraphPredictionHead.
    # assign the output of the head to the name *prediction*
    # (the head's forward takes the node features and a node mask)
    ...
    return prediction

In [30]:
batch_size = 32
num_dataloader_workers = 2  # The colab instances are very limited in number of cpus

# All three loaders share the same settings; only the training split is shuffled.
training_dataloader, dev_dataloader, test_dataloader = [
  DataLoader(split_dataset,
             batch_size=batch_size,
             shuffle=shuffle_split,
             num_workers=num_dataloader_workers,
             drop_last=False,
             collate_fn=collate_graph_batch)
  for split_dataset, shuffle_split in [(training_graph_dataset, True),
                                       (dev_graph_dataset, False),
                                       (test_graph_dataset, False)]
]

# Seed immediately before the model is built so the initial weights are reproducible.
torch.manual_seed(1729)
config = GraphPredictionHeadConfig(d_model=32,
                                   n_layers=4,
                                   ffn_dim=32,
                                   pooling_type='sum')
model = GraphPredictionNeuralNetwork(config=config, dataset=training_graph_dataset)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
                  loss_fn=loss_fn,
                  training_dataloader=training_dataloader,
                  dev_dataloader=dev_dataloader)

In [31]:
# NOTE(review): this cell defines exactly the same loaders, config, model and
# trainer as the cell before it (In [30]); one of the two is redundant.
batch_size = 32
num_dataloader_workers = 2  # The colab instances are very limited in number of cpus

# All three loaders share the same settings; only the training split is shuffled.
training_dataloader, dev_dataloader, test_dataloader = [
  DataLoader(split_dataset,
             batch_size=batch_size,
             shuffle=shuffle_split,
             num_workers=num_dataloader_workers,
             drop_last=False,
             collate_fn=collate_graph_batch)
  for split_dataset, shuffle_split in [(training_graph_dataset, True),
                                       (dev_graph_dataset, False),
                                       (test_graph_dataset, False)]
]

# Seed immediately before the model is built so the initial weights are reproducible.
torch.manual_seed(1729)
config = GraphPredictionHeadConfig(d_model=32,
                                   n_layers=4,
                                   ffn_dim=32,
                                   pooling_type='sum')
model = GraphPredictionNeuralNetwork(config=config, dataset=training_graph_dataset)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
                  loss_fn=loss_fn,
                  training_dataloader=training_dataloader,
                  dev_dataloader=dev_dataloader)

In [None]:
# Short run of two epochs on the baseline model
trainer.train(epochs=2)

# Task 2

It has been observed that Graph Neural Networks suffer from depth: the repeated neighbourhood aggregation ends up diffusing the information in the network.

A simple remedy for this is to use residual connections. This is a way of constructing deep networks where the output of a transforming layer (e.g. a linear layer followed by a nonlinear function) is simply added to the input of that layer. If the transform block is an MLP ($f_{MLP}()$) we could write it like:

$$
f_{residual}(\mathbf{x}) = f_{MLP}(\mathbf{x}) + \mathbf{x}
$$

With residual connections, GNNs suffer much less from these depth issues.

**Task: Implement residual connections in the GraphEncoder. You can do it in the `forward()` of the Encoder or in
the GNN layer. Note that we should do this for all GNN layers**




**Does the performance of the network change when you add residual connections?**

**Note: for this to work, the dimensionality of the input vector must be the same as the output of the layer**

**Either** implement your change in the `ResidualGraphLayer` or the `ResidualGraphEncoder` below.

## In the graph layer

In [32]:
from torch.nn import Module, Embedding, ModuleList, Linear, Sequential, ReLU, LayerNorm, Dropout
from torch.nn.functional import layer_norm, dropout

class ResidualGraphLayer(BasicGraphLayer):
  """Graph layer used for the residual-connection exercise.

  As given, `forward` contains no residual connection yet; adding it is the
  exercise. It uses the `center_mlp`, `neighbour_edges_mlp` and `output_mlp`
  modules of the parent class.
  """
  def forward(self, adjacency_matrix, node_features, edge_features, node_mask, edge_mask):
    """Update every node from its own features and the sum of its neighbourhood messages.

    Returns the updated node features with the entries of masked-out nodes
    zeroed.
    """
    # Implement the residual connection here

    # Transform of each node's own features
    own_term = self.center_mlp(node_features)

    # One message per node pair, built from the edge features plus the
    # broadcast node features
    pair_messages = self.neighbour_edges_mlp(edge_features + node_features.unsqueeze(dim=-2))

    # Zero the messages of pairs that are not adjacent or whose edge is masked out
    kept_messages = pair_messages * adjacency_matrix.unsqueeze(dim=-1) * edge_mask.unsqueeze(dim=-1)

    # Sum the remaining messages per node and combine them with the node's own term
    neighbourhood_sum = kept_messages.sum(dim=2)
    new_features = self.output_mlp(neighbourhood_sum + own_term)

    # Zero out the rows of masked-out nodes
    return new_features * node_mask.unsqueeze(dim=-1)

In [33]:
batch_size = 32
num_dataloader_workers = 2  # The colab instances are very limited in number of cpus

# All three loaders share the same settings; only the training split is shuffled.
training_dataloader, dev_dataloader, test_dataloader = [
  DataLoader(split_dataset,
             batch_size=batch_size,
             shuffle=shuffle_split,
             num_workers=num_dataloader_workers,
             drop_last=False,
             collate_fn=collate_graph_batch)
  for split_dataset, shuffle_split in [(training_graph_dataset, True),
                                       (dev_graph_dataset, False),
                                       (test_graph_dataset, False)]
]

# Seed immediately before the model is built so the initial weights are reproducible.
torch.manual_seed(1729)
config = GraphPredictionHeadConfig(d_model=32,
                                   n_layers=4,
                                   ffn_dim=32,
                                   pooling_type='sum')
# Same network as the baseline, but built from the residual graph layer
model = GraphPredictionNeuralNetwork(config,
                                     training_graph_dataset,
                                     graph_layer_type=ResidualGraphLayer)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
                  loss_fn=loss_fn,
                  training_dataloader=training_dataloader,
                  dev_dataloader=dev_dataloader)

In [None]:
# Short run of two epochs on the model using ResidualGraphLayer
trainer.train(epochs=2)

## In the encoder

In [34]:
class ResidualGraphEncoder(BasicGraphEncoder):
  """Graph encoder used for the residual-connection exercise.

  As given, `forward` contains no residual connection yet; adding it around
  the graph layers is the exercise. It uses the `node_featurizer`,
  `edge_featurizer` and `graph_layers` of the parent class.
  """
  def forward(self, batch):
    """Embed the nodes and edges of `batch` and run them through all graph layers.

    Reads the keys 'nodes_mask', 'continuous_node_features',
    'categorical_node_features', 'continuous_edge_features',
    'categorical_edge_features', 'edge_mask' and 'adjacency_matrices' from
    `batch` and returns the node features produced by the last graph layer.
    """
    node_mask = batch['nodes_mask']
    # The sizes are not used below; the unpacking requires the mask to be 2-D
    batch_size, max_nodes = node_mask.shape

    # Embed the nodes and zero the features of masked-out nodes
    node_features = self.node_featurizer(batch['continuous_node_features'],
                                         batch['categorical_node_features'])
    masked_node_features = node_features * node_mask.unsqueeze(-1)

    # Same for the edges
    edge_features = self.edge_featurizer(batch['continuous_edge_features'],
                                         batch['categorical_edge_features'])
    edge_mask = batch['edge_mask']
    masked_edge_features = edge_features * edge_mask.unsqueeze(-1)

    adjacency_matrix = batch['adjacency_matrices']

    # Feed the node state through the layers one after the other
    memory_state = masked_node_features
    for graph_layer in self.graph_layers:
      # Implement the residual connection here
      memory_state = graph_layer(adjacency_matrix, memory_state, masked_edge_features, node_mask, edge_mask)

    return memory_state

class ResidualGraphPredictionNeuralNetwork(GraphPredictionNeuralNetwork):
  """GraphPredictionNeuralNetwork whose encoder is a ResidualGraphEncoder.

  Args:
    config: configuration shared by encoder and prediction head.
    dataset: dataset whose continuous/categorical node and edge variable
      descriptions are handed to the encoder.
    output_dim: size of the prediction per graph.
    graph_layer_type: class passed to the encoder as `layer_type`.
  """
  def __init__(self, config, dataset, 
               output_dim=1,
               graph_layer_type=BasicGraphLayer):
    # Forward every argument to the parent so that `config` and the prediction
    # head are set up there once, with the requested output_dim, instead of
    # being built with defaults and then rebuilt here.
    super().__init__(config, dataset,
                     output_dim=output_dim,
                     graph_layer_type=graph_layer_type)
    # Only the encoder differs from the parent: swap in the residual variant.
    self.encoder = ResidualGraphEncoder(config=config,
                                        continuous_node_variables=dataset.continuous_node_variables,
                                        categorical_node_variables=dataset.categorical_node_variables,
                                        continuous_edge_variables=dataset.continuous_edge_variables,
                                        categorical_edge_variables=dataset.categorical_edge_variables,
                                        layer_type=graph_layer_type)


In [35]:
batch_size = 32
num_dataloader_workers = 2  # The colab instances are very limited in number of cpus

# All three loaders share the same settings; only the training split is shuffled.
training_dataloader, dev_dataloader, test_dataloader = [
  DataLoader(split_dataset,
             batch_size=batch_size,
             shuffle=shuffle_split,
             num_workers=num_dataloader_workers,
             drop_last=False,
             collate_fn=collate_graph_batch)
  for split_dataset, shuffle_split in [(training_graph_dataset, True),
                                       (dev_graph_dataset, False),
                                       (test_graph_dataset, False)]
]

# Seed immediately before the model is built so the initial weights are reproducible.
torch.manual_seed(1729)
config = GraphPredictionHeadConfig(d_model=32,
                                   n_layers=4,
                                   ffn_dim=32,
                                   pooling_type='sum')
# Variant with the residual connection implemented in the encoder
model = ResidualGraphPredictionNeuralNetwork(config, training_graph_dataset)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
                  loss_fn=loss_fn,
                  training_dataloader=training_dataloader,
                  dev_dataloader=dev_dataloader)

In [None]:
# Short run of two epochs on the model using ResidualGraphEncoder
trainer.train(epochs=2)