# Setup

## Install required libraries. 

[Torchdrug](https://torchdrug.ai/) is build on top of PyTorch and tailored for drug discovery. GearNet is using Torchdrug.

In [18]:
!pip install torch
!pip install torchdrug
!pip install easydict pyyaml


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.10 -m pip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.10 -m pip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.10 -m pip install --upgrade pip[0m


## Download model and data

Pre-trained models weights for GearNet are stored [here](https://zenodo.org/record/7593637). Several weigths are available, optained using different training techniques.

For the data, I chose some example protein [Free fatty acid receptor 2](https://alphafold.ebi.ac.uk/entry/O15552).

Currently, GearNet works only with `.pdb` files. To load data, you use `data.Protein.from_pdb()` method. Unfortunately, there is no `data.Protein.from_mmcif()`. Under the hood, they are using [rdkit](https://www.rdkit.org/) to parse files, but adding support for parsing `mmcif` files is still [an open issue](https://github.com/rdkit/rdkit/issues/2054).

In [3]:
!wget https://zenodo.org/record/7593637/files/mc_gearnet_edge.pth
!mkdir models
!mv mc_gearnet_edge.pth models/mc_gearnet_edge.pth

--2023-08-22 18:01:17--  https://zenodo.org/record/7593637/files/mc_gearnet_edge.pth
Resolving zenodo.org (zenodo.org)... 188.185.124.72
Connecting to zenodo.org (zenodo.org)|188.185.124.72|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 80700937 (77M) [application/octet-stream]
Saving to: ‘mc_gearnet_edge.pth’


2023-08-22 18:01:23 (14.5 MB/s) - ‘mc_gearnet_edge.pth’ saved [80700937/80700937]



In [4]:
!wget https://alphafold.ebi.ac.uk/files/AF-O15552-F1-model_v4.pdb
!mkdir data
!mv AF-O15552-F1-model_v4.pdb data/AF-O15552-F1-model_v4.pdb

--2023-08-22 18:01:28--  https://alphafold.ebi.ac.uk/files/AF-O15552-F1-model_v4.pdb
Resolving alphafold.ebi.ac.uk (alphafold.ebi.ac.uk)... 34.149.152.8
Connecting to alphafold.ebi.ac.uk (alphafold.ebi.ac.uk)|34.149.152.8|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/octet-stream]
Saving to: ‘AF-O15552-F1-model_v4.pdb’

AF-O15552-F1-model_     [ <=>                ] 213.49K  --.-KB/s    in 0.01s   

2023-08-22 18:01:28 (15.8 MB/s) - ‘AF-O15552-F1-model_v4.pdb’ saved [218618]



# Prepare data

In [28]:
from torchdrug import core, datasets, tasks, models, transforms, data, layers
from torchdrug.layers import geometry

In [21]:
transform = transforms.ProteinView(view="residue")
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

In [37]:
%%timeit

PROTEIN_PATH = './data/AF-O15552-F1-model_v4.pdb'
protein = data.Protein.from_pdb(PROTEIN_PATH, atom_feature="position", bond_feature="length", residue_feature="symbol")

with protein.residue():
    protein.residue_feature = protein.residue_feature.to_dense()
    
item = {"graph": protein}
item = transform(item)

_protein = data.Protein.pack([item['graph']])
protein = graph_construction_model(_protein)

114 ms ± 366 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Prepare model

In [14]:
import torch

WEIGHTS_PATH = './models/mc_gearnet_edge.pth'

# define model architecturemodel
gearnet_edge = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512],
                              num_relation=7, edge_input_dim=59, num_angle_bin=8,
                              batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")

net = torch.load(WEIGHTS_PATH)
gearnet_edge.load_state_dict(net)
gearnet_edge.eval()

GeometryAwareRelationalGraphNeuralNetwork(
  (layers): ModuleList(
    (0): GeometricRelationalGraphConv(
      (batch_norm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (self_loop): Linear(in_features=21, out_features=512, bias=True)
      (linear): Linear(in_features=147, out_features=512, bias=True)
    )
    (1-5): 5 x GeometricRelationalGraphConv(
      (batch_norm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (self_loop): Linear(in_features=512, out_features=512, bias=True)
      (linear): Linear(in_features=3584, out_features=512, bias=True)
    )
  )
  (spatial_line_graph): SpatialLineGraph()
  (edge_layers): ModuleList(
    (0): GeometricRelationalGraphConv(
      (batch_norm): BatchNorm1d(21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (self_loop): Linear(in_features=59, out_features=21, bias=True)
      (linear): Linear(in_features=472, out_features=21, bias=True)


# Compute embeddings

In [35]:
%%timeit
output = gearnet_edge(protein, protein.node_feature.float(), all_loss=None, metric=None)

1.28 s ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
output['graph_feature']

tensor([[ -1.3649, -21.1418,  14.6464,  ...,  28.3855,  21.2432,  48.4165]])

In [40]:
import pandas as pd 
    
pd.DataFrame(output['graph_feature'][0].detach().numpy()).describe()

Unnamed: 0,0
count,3072.0
mean,1221.709106
std,4182.805176
min,-3180.314941
25%,-41.301126
50%,80.202492
75%,701.187408
max,91498.945312
