# STCRpy for Graph Machine Learning Example

This notebook generates a graph dataset from TCR and TCR:pMHC structures, and then builds and trains a small equivariant graph neural network to annotate amino acids by region. 

This is an illustrative example with a few data points, rather than a fully scaled application.

As well as pytorch and pytorch-geometric you may also need to install some small helper packages (`glob` is part of the Python standard library and does not need to be installed): 
```
pip install tqdm einops
```

In [1]:
import random
from glob import glob
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn
from tqdm import tqdm

from stcrpy.tcr_datasets.tcr_graph_dataset import TCRGraphDataset

Pymol package not found. 
Interaction profiler initialising without visualisation capabilitites. 
To enable pymol visualisations, install pymol with:
            
conda install -c conda-forge -c schrodinger numpy pymol-bundle




In [2]:
def structure_id(path):
    """Return the identifier for a structure file: its file name up to the first dot.

    Uses ``pathlib`` so the directory part is stripped with the platform's
    path separator rather than a hard-coded "/".
    """
    return Path(path).name.split(".")[0]


# Sort each glob result so the dataset order (and therefore the seeded
# train/validation split made later) does not depend on filesystem listing order.
pdb_files = sorted(glob("../test/test_files/*.pdb")) + sorted(glob("../test/test_files/*.cif"))

# Map structure identifier -> file path and write it out as the CSV manifest
# that is passed to TCRGraphDataset via `data_paths`.
df = pd.Series({structure_id(f): f for f in pdb_files}, name="path")
df.to_csv("example_graph_data.csv")

In [None]:
# Build the graph dataset from the structures listed in the CSV manifest.
# Processed graphs are rooted at ./input_graphs.
input_dataset = TCRGraphDataset(
    data_paths="example_graph_data.csv",
    root="./input_graphs",
    edge_features="fully_connected",
    mhc_distance_cutoff=15.,
)

Processing...
Header could not be parsedCommon molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antigen
Common molecule GOL found in the binding site - not considered an antige

TCRGraphDataset(20)

In [5]:
# Show the dataset's repr as the cell output (sanity check of how many graphs were built).
input_dataset

TCRGraphDataset(20)

In [8]:
from torch_geometric.loader import DataLoader

from egnn import EGNN

class NodeClassificationEGNN(torch.nn.Module):
    """Per-node classifier built from three stacked EGNN layers.

    Node features are linearly projected to ``dim`` channels, passed through
    three ``EGNN`` layers together with the node coordinates, and finally
    mapped by a small MLP to ``num_classes`` logits per node.

    Args:
        dim: Hidden feature width used by the projection and the EGNN layers.
        num_classes: Number of output classes (size of the logit vector per node).
        input_dim: Number of input features per node.
    """

    def __init__(self, dim=32, num_classes=6, input_dim=69):
        super().__init__()
        self.dim = dim
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.linear_projection = nn.Linear(self.input_dim, self.dim)
        self.layer1 = EGNN(dim=self.dim, num_nearest_neighbors=16)
        self.layer2 = EGNN(dim=self.dim, num_nearest_neighbors=16)
        self.layer3 = EGNN(dim=self.dim, num_nearest_neighbors=16)
        self.out = nn.Sequential(nn.Linear(self.dim, 16), nn.ReLU(), nn.Linear(16, self.num_classes))
        # Not applied in forward(): the model returns raw logits. Kept as a
        # convenience for turning logits into class probabilities; the class
        # axis is given explicitly because the implicit dim is deprecated.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, node_feats, coords, adj_mat=None):
        """Return per-node class logits.

        Args:
            node_feats: Node features, either ``(num_nodes, input_dim)`` or
                ``(batch, num_nodes, input_dim)``.
            coords: Node coordinates, passed unchanged to the EGNN layers.
            adj_mat: Optional adjacency matrix forwarded to every EGNN layer.

        Returns:
            Logits of shape ``(batch, num_nodes, num_classes)``.
        """
        # Accept unbatched input by adding the batch axis explicitly. The linear
        # layers act on the last axis only, so no squeeze/unsqueeze round trip
        # is needed (a dimensionless squeeze would also collapse the node axis
        # for a single-node graph).
        if node_feats.dim() == 2:
            node_feats = node_feats.unsqueeze(0)
        node_feats = self.linear_projection(node_feats)
        for layer in (self.layer1, self.layer2, self.layer3):
            node_feats, coords = layer(node_feats, coords, adj_mat=adj_mat)
        return self.out(node_feats)


In [9]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = NodeClassificationEGNN(num_classes=4, input_dim=58)
model = model.to(device)

# Reproducible 75/25 split of the graphs into training and validation sets.
random.seed(0)
n_graphs = len(input_dataset)
train_index = random.sample(list(range(n_graphs)), k=int(0.75 * n_graphs))
train_index_set = set(train_index)  # O(1) membership test for the complement
val_index = [i for i in range(n_graphs) if i not in train_index_set]
train_dl = DataLoader(input_dataset[train_index], batch_size=10)
val_dl = DataLoader(input_dataset[val_index], batch_size=10)


optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Approximate per-class node counts. Each class is weighted by total / count,
# i.e. by its inverse frequency, so rare classes contribute more to the loss.
# The weights are moved to `device` so they sit alongside the model.
class_counts = torch.tensor([187., 237.,  62.,   9.])
loss_fn = nn.CrossEntropyLoss(weight=(class_counts.sum() / class_counts).to(device))

model.train()

NodeClassificationEGNN(
  (linear_projection): Linear(in_features=58, out_features=32, bias=True)
  (layer1): EGNN(
    (edge_mlp): Sequential(
      (0): Linear(in_features=65, out_features=130, bias=True)
      (1): Identity()
      (2): SiLU()
      (3): Linear(in_features=130, out_features=16, bias=True)
      (4): SiLU()
    )
    (node_norm): Identity()
    (coors_norm): Identity()
    (node_mlp): Sequential(
      (0): Linear(in_features=48, out_features=64, bias=True)
      (1): Identity()
      (2): SiLU()
      (3): Linear(in_features=64, out_features=32, bias=True)
    )
    (coors_mlp): Sequential(
      (0): Linear(in_features=16, out_features=64, bias=True)
      (1): Identity()
      (2): SiLU()
      (3): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (layer2): EGNN(
    (edge_mlp): Sequential(
      (0): Linear(in_features=65, out_features=130, bias=True)
      (1): Identity()
      (2): SiLU()
      (3): Linear(in_features=130, out_features=16, bias=Tru

In [10]:
def edge_index_to_adj_mat(edge_index, N, self_connections=False):
    """Convert an edge list into a dense boolean adjacency matrix.

    Args:
        edge_index: Integer tensor (or nested sequence) of shape ``(2, num_edges)``
            whose first row holds row indices and second row holds column indices.
        N: Number of nodes; the result has shape ``(N, N)``.
        self_connections: If True, the diagonal is set so every node is
            connected to itself.

    Returns:
        ``torch.bool`` tensor of shape ``(N, N)`` with True at every
        ``(edge_index[0][k], edge_index[1][k])``, created on the same device
        as ``edge_index``.
    """
    edge_index = torch.as_tensor(edge_index)
    if self_connections:
        # torch.eye takes the size as an int, not as an (N, N) tuple.
        adj_mat = torch.eye(N, dtype=torch.bool, device=edge_index.device)
    else:
        adj_mat = torch.zeros((N, N), dtype=torch.bool, device=edge_index.device)
    # Explicit row/column indexing; star-unpacking inside a subscript is only
    # valid syntax from Python 3.11 onwards.
    adj_mat[edge_index[0], edge_index[1]] = True
    return adj_mat


In [11]:
from einops import rearrange
def get_distances(coors, adj_mat=None):
    """Compute pairwise Euclidean distances between nodes.

    Args:
        coors: Coordinate tensor of shape ``(batch, num_nodes, dims)``.
        adj_mat: Optional ``(num_nodes, num_nodes)`` adjacency matrix. Where it
            is False the distance is replaced by ``inf``; the same mask is
            applied to every batch element.

    Returns:
        Tensor of shape ``(batch, num_nodes, num_nodes, 1)`` holding the
        distance between node ``i`` and node ``j`` at ``[b, i, j, 0]``.
    """
    # Broadcast (b, i, 1, d) against (b, 1, j, d) to get all pairwise offsets.
    rel_coors = coors.unsqueeze(2) - coors.unsqueeze(1)
    rel_dist = (rel_coors ** 2).sum(dim=-1, keepdim=True).sqrt()
    if adj_mat is not None:
        # Shape (1, n, n, 1) so the mask broadcasts over the batch axis and the
        # trailing keepdim axis, on the same device as the distances.
        not_connected = ~adj_mat.to(device=rel_dist.device, dtype=torch.bool)
        not_connected = not_connected.unsqueeze(0).unsqueeze(-1)
        rel_dist = rel_dist.masked_fill(not_connected, float("inf"))
    return rel_dist

In [None]:
# Train for a fixed number of epochs, reporting the summed training loss and
# the validation accuracy after every epoch.
torch.manual_seed(0)
model.train()
for epoch in range(1000):
    epoch_loss = 0.
    for data in tqdm(train_dl):
        # Build the adjacency matrix first, then move everything to the
        # model's device so the loop also works when the model is on the GPU.
        adj_mat = edge_index_to_adj_mat(data.edge_index, data.num_nodes).to(device)
        data = data.to(device)
        # The first 4 feature columns are used as the (one-hot) class target;
        # columns 11 onwards are the model input.
        target = data.x[:, :4]
        optimizer.zero_grad()
        out = model(data.x[:, 11:].float().unsqueeze(0), data.pos.unsqueeze(0), adj_mat=adj_mat)
        # Drop only the batch axis so a single-node batch keeps its node axis.
        loss = loss_fn(out.squeeze(0), target.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch: {epoch} - Loss: {epoch_loss}")
    # Distribution of predicted classes in the last training batch.
    print(out.argmax(dim=-1).flatten().unique(return_counts=True))

    # Validation pass: eval mode, no gradient tracking.
    model.eval()
    val_preds = []
    val_targets = []
    with torch.no_grad():
        for data in val_dl:
            adj_mat = edge_index_to_adj_mat(data.edge_index, data.num_nodes).to(device)
            data = data.to(device)
            target = data.x[:, :4]
            val_targets.append(target.argmax(dim=-1))
            out = model(data.x[:, 11:].float().unsqueeze(0), data.pos.unsqueeze(0), adj_mat=adj_mat)
            val_preds.append(out.argmax(dim=-1))
    model.train()  # back to training mode for the next epoch
    val_preds = torch.concat(val_preds, dim=-1).squeeze()
    val_targets = torch.concat(val_targets, dim=-1).squeeze()
    val_accuracy = (val_preds == val_targets).sum() / len(val_preds)
    print(f"Epoch: {epoch} - Val accuracy: {val_accuracy:.4}")


100%|██████████| 2/2 [00:06<00:00,  3.13s/it]


Epoch: 0 - Loss: 8.88794231414795
(tensor([0]), tensor([1910]))
Epoch: 0 - Val accuracy: 0.4085


100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


Epoch: 1 - Loss: 8.82473087310791
(tensor([0]), tensor([1910]))
Epoch: 1 - Val accuracy: 0.4085


100%|██████████| 2/2 [00:05<00:00,  2.80s/it]


Epoch: 2 - Loss: 8.788976669311523
(tensor([0, 2]), tensor([1909,    1]))
Epoch: 2 - Val accuracy: 0.343


100%|██████████| 2/2 [00:05<00:00,  2.84s/it]


Epoch: 3 - Loss: 8.751437187194824
(tensor([0, 2]), tensor([1829,   81]))
Epoch: 3 - Val accuracy: 0.4007


100%|██████████| 2/2 [00:05<00:00,  2.83s/it]


Epoch: 4 - Loss: 8.698257684707642
(tensor([0, 2]), tensor([1837,   73]))
Epoch: 4 - Val accuracy: 0.3582


In [None]:
# Persist the trained model and the optimizer, each as a whole pickled object.
for checkpoint_obj, checkpoint_path in (
    (model, "egnn_model.pt"),
    (optimizer, "egnn_optimizer.pt"),
):
    torch.save(checkpoint_obj, checkpoint_path)

In [None]:
# Accuracy of the trained model over all nodes in the training graphs.
was_training = model.training
model.eval()
train_preds = []
train_targets = []
with torch.no_grad():
    for data in train_dl:
        # Build the adjacency matrix, then move the batch to the model's device
        # so evaluation also works when the model is on the GPU.
        adj_mat = edge_index_to_adj_mat(data.edge_index, data.num_nodes).to(device)
        data = data.to(device)
        target = data.x[:, :4]
        train_targets.append(target.argmax(dim=-1))
        out = model(data.x[:, 11:].float().unsqueeze(0), data.pos.unsqueeze(0), adj_mat=adj_mat)
        train_preds.append(out.argmax(dim=-1))
model.train(was_training)  # restore whichever mode the model was in
train_preds = torch.concat(train_preds, dim=-1).squeeze()
train_targets = torch.concat(train_targets, dim=-1).squeeze()
train_accuracy = (train_preds == train_targets).sum() / len(train_preds)
# `epoch` is the last epoch reached by the training loop cell.
print(f"Epoch: {epoch} - Train accuracy: {train_accuracy:.4}")

Epoch: 12 - Train accuracy: 0.9444


In [None]:

# Accuracy of the trained model over all nodes in the validation graphs.
was_training = model.training
model.eval()
val_preds = []
val_targets = []
with torch.no_grad():
    for data in val_dl:
        # Build the adjacency matrix, then move the batch to the model's device
        # so evaluation also works when the model is on the GPU.
        adj_mat = edge_index_to_adj_mat(data.edge_index, data.num_nodes).to(device)
        data = data.to(device)
        target = data.x[:, :4]
        val_targets.append(target.argmax(dim=-1))
        out = model(data.x[:, 11:].float().unsqueeze(0), data.pos.unsqueeze(0), adj_mat=adj_mat)
        val_preds.append(out.argmax(dim=-1))
model.train(was_training)  # restore whichever mode the model was in
val_preds = torch.concat(val_preds, dim=-1).squeeze()
val_targets = torch.concat(val_targets, dim=-1).squeeze()
val_accuracy = (val_preds == val_targets).sum() / len(val_preds)
# `epoch` is the last epoch reached by the training loop cell.
print(f"Epoch: {epoch} - Val accuracy: {val_accuracy:.4}")

Epoch: 12 - Val accuracy: 0.9393


In [None]:
# Number of validation nodes predicted as each class index.
val_preds.bincount()

tensor([20453, 25674,  3875,  1253])

In [None]:
# Per-class validation accuracy: of the nodes whose true class is `class_id`,
# the fraction that the model predicted correctly. The class order follows the
# target columns: 0 = TCR a, 1 = TCR b, 2 = peptide, 3 = MHC.
for class_id, class_name in enumerate(("TCR a", "TCR b", "peptide", "MHC")):
    class_mask = val_targets == class_id
    val_accuracy = (val_preds[class_mask] == val_targets[class_mask]).sum() / class_mask.sum()
    print(f"Val accuracy {class_name}: {val_accuracy:.4}")

Val accuracy TCR a: 0.9326
Val accuracy TCR b: 0.9599
Val accuracy peptide: 0.8692
Val accuracy MHC: 0.7776


In [None]:
# Baseline: guess each node's class at random, with probabilities proportional
# to the approximate per-class node counts used for the loss weights.
class_counts = np.array([187., 237.,  62.,   9.])
class_probs = class_counts / class_counts.sum()  # float64, sums to 1

# Seeded generator so the baseline numbers are reproducible across runs.
baseline_rng = np.random.default_rng(0)
baseline_random_preds = torch.as_tensor(
    baseline_rng.choice(len(class_probs), size=tuple(val_preds.shape), p=class_probs)
).to(val_targets.device)

# Per-class accuracy of the random baseline, in the same class order as the
# model's per-class report: 0 = TCR a, 1 = TCR b, 2 = peptide, 3 = MHC.
for class_id, class_name in enumerate(("TCR a", "TCR b", "peptide", "MHC")):
    class_mask = val_targets == class_id
    val_accuracy = (baseline_random_preds[class_mask] == val_targets[class_mask]).sum() / class_mask.sum()
    print(f"Baseline val accuracy {class_name}: {val_accuracy:.4}")

Val accuracy TCR a: 0.3743
Val accuracy TCR b: 0.4787
Val accuracy peptide: 0.1273
Val accuracy MHC: 0.0208
