In [1]:
import json 

import torch 
import numpy as np 

from torch.utils.tensorboard import SummaryWriter
from torch_geometric.transforms import PointPairFeatures, KNNGraph
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

from ingraham.struct2seq.protein_features import ProteinFeatures
from ingraham.struct2seq.data import StructureDataset, StructureLoader
from ingraham.experiments.utils import featurize
from ingraham.struct2seq.struct2seq import Struct2Seq


device = torch.device("cpu")

In [2]:
# Load a small slice of the CATH chain set for experimentation.
# The recorded output below shows 3 entries, matching truncate=3.
# NOTE(review): presumably max_length drops chains longer than 512 residues — confirm against StructureDataset.
dataset = StructureDataset("data/cath/chain_set.jsonl", truncate=3, verbose=True, max_length=512)

len(dataset)

3

In [3]:
# Show the identifier of every entry that was loaded.
[entry["name"] for entry in dataset]

['12as.A', '132l.A', '153l.A']

In [4]:
# Wrap the dataset in the Ingraham loader; the recorded output shows a single batch.
# NOTE(review): batch_size may be a residue budget rather than a structure count — confirm against StructureLoader.
# NOTE: the name `loader` is rebound to a PyG DataLoader in a later cell.
loader = StructureLoader(dataset, batch_size=10_000, shuffle=False)

len(loader)

[330, 129, 185]


1

In [5]:
# Pull one batch for inspection. `next(iter(...))` raises StopIteration on an
# empty loader instead of silently leaving `batch` undefined (or stale from an
# earlier run of the notebook).
batch = next(iter(loader))

# Each element of the batch is a dict; list the fields it carries.
batch[0].keys()

[[1, 2, 0]]


dict_keys(['seq', 'coords', 'num_chains', 'name', 'CATH'])

In [6]:
# Featurize the batch with the Ingraham `featurize` helper.

result = featurize(batch, device)

# Four outputs; in the recorded run X is [3, 330, 4, 3], S and mask are [3, 330],
# and lengths has shape (3,).
X, S, mask, lengths = result 

X.shape, S.shape, mask.shape, lengths.shape

(torch.Size([3, 330, 4, 3]), torch.Size([3, 330]), torch.Size([3, 330]), (3,))

In [7]:
# X is indexed as [structure, residue, atom, xyz].
X[0, 50, 0]  # coords of the N atom, residue 50, first structure 

tensor([-6.0160, 17.0520, 68.6860])

In [8]:
# Feed the featurized batch to the Struct2Seq model.

hidden = 128     # shared width for node features, edge features and hidden layers
neighbors = 16   # value passed as k_neighbors / top_k throughout the notebook

model = Struct2Seq(
    num_letters=20,
    node_features=hidden,
    edge_features=hidden,
    hidden_dim=hidden,
    k_neighbors=neighbors,
    protein_features="full",
)

# The model consumes coordinates, sequence tokens, lengths and the padding mask.
logits = model(X, S, lengths, mask)

logits.shape

Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/Cross.cpp:66.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)


torch.Size([3, 330, 20])

In [9]:
# note that the model as implemented here outputs the logits for each of the residues in each sequence 

In [10]:
# Inside the model there are two stages: `ProteinFeatures` first, then the transformer.
# Run the first stage on its own to look at what it produces.

features = ProteinFeatures(
    hidden,
    hidden,
    num_positional_embeddings=16,
    num_rbf=16,
    top_k=neighbors,
    features_type="full",
)

result = features.forward(X, lengths, mask)

# Three outputs: per-residue features, per-edge features, and neighbour indices.
V, E, E_idx = result

V.shape, E.shape, E_idx.shape

(torch.Size([3, 330, 128]),
 torch.Size([3, 330, 16, 128]),
 torch.Size([3, 330, 16]))

In [11]:
# Reminder of the raw input shape; the recorded output is [3, 330, 4, 3].
X.shape

torch.Size([3, 330, 4, 3])

In [12]:
# same deal for ProteinMPNN 

In [13]:
from ProteinMPNN.training.model_utils import featurize as justas_featurize
from ProteinMPNN.training.model_utils import ProteinMPNN
from ProteinMPNN.training.model_utils import ProteinFeatures as JustasProteinFeatures 
from ProteinMPNN.training.utils import StructureDataset as JustasStructureDataset
from ProteinMPNN.training.utils import StructureLoader as JustasStructureLoader 

In [14]:
# Collects the records after they have been rewritten by transform_for_justas.
justas_raw_data = []
# Number of records to read here; `max_samples` is reassigned to 128 in a later cell.
max_samples = 3 

def transform_for_justas(pkg):
    """Add chain-suffixed keys to a chain-set record, in place, and return it.

    The chain identifier is taken from the character at index 5 of
    ``pkg["name"]`` (for example ``"A"`` in ``"12as.A"``), so names are
    expected to look like a 4-character id, a separator, then the chain.
    NOTE(review): these keys are presumably what ProteinMPNN's
    StructureDataset reads — confirm against that class.

    Args:
        pkg: dict with at least ``"name"``, ``"seq"`` and ``"coords"``;
            ``"coords"`` must hold the ``"N"``, ``"CA"``, ``"C"`` and ``"O"`` entries.

    Returns:
        The same dict object, mutated. ``pkg["coords"]`` also gains the
        ``"<atom>_chain_<id>"`` aliases and is shared (not copied) under
        ``"coords_chain_<id>"``.
    """
    pkg["masked_list"] = []

    chain_id = pkg["name"][5]
    pkg["visible_list"] = [chain_id]
    pkg[f"seq_chain_{chain_id}"] = pkg["seq"]
    pkg["num_of_chains"] = 1

    # Alias each backbone atom's coordinates under a chain-specific key.
    coords = pkg["coords"]
    for atom in ("N", "CA", "C", "O"):
        coords[f"{atom}_chain_{chain_id}"] = coords[atom]
    pkg[f"coords_chain_{chain_id}"] = coords

    return pkg

# Read the first `max_samples` records of the chain set and rewrite each one
# with transform_for_justas. The file object is iterated lazily, so only the
# lines actually needed are read rather than the whole JSONL file.
with open("data/cath/chain_set.jsonl") as fn:
    count = 0
    for line in fn:
        pkg = json.loads(line)
        pkg = transform_for_justas(pkg)
        justas_raw_data.append(pkg)
        count += 1
        if count >= max_samples:
            break

len(justas_raw_data)

3

In [15]:
# Build the ProteinMPNN dataset from the rewritten records.
# The recorded output shows all 3 records are kept with max_length=512.
justas_dataset = JustasStructureDataset(justas_raw_data, max_length=512)

len(justas_dataset)

3

In [16]:
# ProteinMPNN counterpart of the Ingraham loader, with the same batch_size and no shuffling.
justas_loader = JustasStructureLoader(justas_dataset, batch_size=10_000, shuffle=False)

justas_loader

<ProteinMPNN.training.utils.StructureLoader at 0x30f058a50>

In [17]:
# Featurize with ProteinMPNN's featurizer.
# NOTE: this rebinds X, S, mask and lengths, which until now held the outputs
# of the Ingraham `featurize` call. The loop has no `break`, so after it
# finishes every name holds the values from the last batch the loader yields.
for justas_batch in justas_loader:
    X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = justas_featurize(justas_batch, device)

len(justas_batch)

3

In [18]:
# Instantiate ProteinMPNN with the same neighbour count used for Struct2Seq.
justas_model = ProteinMPNN(k_neighbors=neighbors)

# NOTE: rebinds `logits`, which previously held the Struct2Seq output.
logits = justas_model.forward(X, S, mask, chain_M, residue_idx, chain_encoding_all)



In [19]:
# The recorded output has a last dimension of 21 here, versus 20 for the Struct2Seq logits earlier.
logits.shape 

torch.Size([3, 330, 21])

In [20]:
# same thing here, the model takes the XYZ directly and then, internally, does ... 

In [21]:
# Run ProteinMPNN's feature stage on its own. Unlike the Ingraham ProteinFeatures
# call earlier, it returns two values (edge features and neighbour indices).
justas_features = JustasProteinFeatures(hidden, hidden, num_positional_embeddings=16, num_rbf=16, top_k=neighbors, num_chain_embeddings=16)

# NOTE: rebinds E and E_idx, which previously held the Ingraham ProteinFeatures outputs.
E, E_idx = justas_features.forward(X, mask, residue_idx, chain_encoding_all)

E.shape, E_idx.shape 

(torch.Size([3, 330, 16, 128]), torch.Size([3, 330, 16]))

In [22]:
# so internally, the model takes our XYZ features and then outputs edge features `E` as well as a matrix `E_idx` which provides a list of the neighbors for each of the 330 residues in the 3 proteins in the batch

In [23]:
# now all we really need is a new network

# that accepts the XYZ and outputs the logits 

In [24]:
# but first let's load this into PyG data 

In [25]:
# Amino-acid vocabulary and the lookup tables between letters and integer tokens.
alphabet = "ACDEFGHIKLMNPQRSTVWY"
itos = dict(enumerate(alphabet))                           # index -> letter
stoi = {letter: index for index, letter in itos.items()}   # letter -> index
vocab_size = len(frozenset(alphabet))

# Rebinds `features` with the same arguments as the earlier ProteinFeatures cell;
# transform_for_pyg below reads this module-level object.
# NOTE(review): if ProteinFeatures has randomly initialised weights, this new instance will not reproduce the earlier V/E values — confirm.
features = ProteinFeatures(hidden, hidden, num_positional_embeddings=16, num_rbf=16, top_k=neighbors, features_type="full")

def transform_for_pyg(pkg):
    """Convert one chain-set record into a PyTorch Geometric ``Data`` graph.

    The record is featurized on its own (a batch of one) with the Ingraham
    ``featurize`` helper and the module-level ``features`` object. Each
    residue becomes a node, and each (residue, neighbour) pair listed in
    ``E_idx`` becomes an edge.

    Args:
        pkg: dict for a single chain; must contain ``"seq"`` plus whatever
            ``featurize`` reads from it.

    Returns:
        ``Data`` with
        ``x``          — node features ``V[0]``, one row per residue;
        ``edge_index`` — ``[2, num_edges]`` long tensor, row 0 the residue
                         position and row 1 its neighbour's position;
        ``edge_attr``  — ``[num_edges, feature_dim]`` edge features in the
                         dtype produced by ``features``;
        ``y``          — integer token per residue.

    Raises:
        KeyError: if ``pkg["seq"]`` contains a letter that is not in ``stoi``.
    """
    X, S, mask, lengths = featurize([pkg], device)
    V, E, E_idx = features.forward(X, lengths, mask)

    tokens = torch.tensor([stoi[letter] for letter in pkg["seq"]])

    # E_idx[0] is [num_residues, k]; flattening row-major gives edges ordered
    # by residue position first and neighbour slot second.
    # Assumes E_idx has one row per position of S (true in the recorded shapes).
    neighbor_idx = E_idx[0].detach()
    num_residues, num_neighbors = neighbor_idx.shape
    sources = torch.arange(num_residues, device=neighbor_idx.device).repeat_interleave(num_neighbors)
    targets = neighbor_idx.reshape(-1).to(torch.long)
    edges = torch.stack((sources, targets))

    # Keep the edge features in their original floating-point dtype; casting
    # them to torch.long would truncate every value to an integer.
    edge_attr = E[0].detach().reshape(-1, E.shape[-1])

    data = Data(x=V[0].detach(), edge_attr=edge_attr, edge_index=edges, y=tokens)
    return data


# Build up to `max_samples` PyG graphs from the chain set.
max_samples = 128
pyg_data = []
names = []     # name of each record that made it into pyg_data (same order)
skipped = []   # (line number, error repr) for every record that failed to convert

with torch.no_grad():
    with open("data/cath/chain_set.jsonl") as fn:
        count = 0
        # Iterate the file lazily; only the lines actually needed are read.
        for line_number, line in enumerate(fn, start=1):
            try:
                pkg = json.loads(line)
                graph = transform_for_pyg(pkg)
            except Exception as err:
                # Still best-effort (a bad record must not stop the run), but
                # KeyboardInterrupt/SystemExit are no longer swallowed and
                # each failure is recorded instead of vanishing silently.
                skipped.append((line_number, repr(err)))
            else:
                pyg_data.append(graph)
                names.append(pkg.get("name"))
                count += 1
            if count >= max_samples:
                break

pyg_data[0], len(pyg_data)

(Data(x=[330, 128], edge_index=[2, 5280], edge_attr=[5280, 128], y=[330]), 128)

In [26]:
# Batch the graphs with the PyG DataLoader (this rebinds `loader`, which
# earlier held the Ingraham StructureLoader).
loader = DataLoader(pyg_data, batch_size=1)

# Take the first batch. `next(iter(...))` raises StopIteration if the loader is
# empty, whereas `for ...: break` would leave `batch` bound to the earlier
# Ingraham batch and fail confusingly on the attribute access below.
batch = next(iter(loader))

batch.edge_attr.shape

torch.Size([5280, 128])