In [22]:
import json 

import torch 

from ingraham.struct2seq.protein_features import ProteinFeatures
from ingraham.struct2seq.data import StructureDataset, StructureLoader
from ingraham.experiments.utils import featurize
from ingraham.struct2seq.struct2seq import Struct2Seq

from torch_geometric.transforms import PointPairFeatures, KNNGraph

device = torch.device("cpu")

In [23]:
# Ingraham-style dataset over the CATH chain set; `truncate=3` keeps this demo
# to 3 structures (see the length printed below).
dataset = StructureDataset(
    "data/cath/chain_set.jsonl",
    max_length=512,
    truncate=3,
    verbose=True,
)

len(dataset)

3

In [24]:
# Names of the structures that made it into the dataset.
[entry["name"] for entry in dataset]

['12as.A', '132l.A', '153l.A']

In [25]:
# Batch the dataset with shuffling disabled; with this `batch_size` the loader
# reports a single batch (see the length printed below).
loader = StructureLoader(
    dataset,
    shuffle=False,
    batch_size=10_000,
)

len(loader)

[330, 129, 185]


1

In [26]:
# Take the first batch from the loader. `next(iter(...))` states the intent
# directly and fails right here if the loader is empty, instead of leaving
# `batch` undefined for the cells below.
batch = next(iter(loader))

batch[0].keys()

[[1, 2, 0]]


dict_keys(['seq', 'coords', 'num_chains', 'name', 'CATH'])

In [27]:
# Featurize the batch with Ingraham's `featurize` helper.
result = featurize(batch, device)
(X, S, mask, lengths) = result

# Show the shape of each returned array, in the order they were returned.
tuple(arr.shape for arr in (X, S, mask, lengths))

(torch.Size([3, 330, 4, 3]), torch.Size([3, 330]), torch.Size([3, 330]), (3,))

In [28]:
X[0, 50, 0]  # coords of the N atom, residue 50, first structure 

tensor([-6.0160, 17.0520, 68.6860])

In [29]:
# We then pass the featurized batch to the model.

# Sizes shared by the model here and by the feature modules in later cells.
hidden, neighbors = 128, 16

model = Struct2Seq(
    num_letters=20,
    node_features=hidden,
    edge_features=hidden,
    hidden_dim=hidden,
    k_neighbors=neighbors,
    protein_features="full",
)
logits = model(X, S, lengths, mask)

logits.shape

torch.Size([3, 330, 20])

In [30]:
# note that the model as implemented here outputs the logits for each of the residues in each sequence 

In [31]:
# but inside the model, there are two steps: first the `ProteinFeatures`, then the transformer

features = ProteinFeatures(hidden, hidden, num_positional_embeddings=16, num_rbf=16, top_k=neighbors, features_type="full")

# Invoke the module through `__call__` (the PyTorch convention) rather than
# calling `.forward` directly, so any registered hooks are honored; this also
# matches how `model(...)` is invoked above.
result = features(X, lengths, mask)

V, E, E_idx = result

V.shape, E.shape, E_idx.shape

(torch.Size([3, 330, 128]),
 torch.Size([3, 330, 16, 128]),
 torch.Size([3, 330, 16]))

In [32]:
X.shape

torch.Size([3, 330, 4, 3])

In [33]:
# same deal for ProteinMPNN 

In [34]:
from ProteinMPNN.training.model_utils import featurize as justas_featurize
from ProteinMPNN.training.model_utils import ProteinMPNN
from ProteinMPNN.training.model_utils import ProteinFeatures as JustasProteinFeatures 
from ProteinMPNN.training.utils import StructureDataset as JustasStructureDataset
from ProteinMPNN.training.utils import StructureLoader as JustasStructureLoader 


In [35]:
justas_raw_data = []
max_samples = 3 

def transform_for_justas(pkg):
    """Add the per-chain keys used by the ProteinMPNN ("Justas") pipeline.

    The entry is modified in place and the same object is returned.

    Args:
        pkg: dict describing one single-chain structure. It must contain
            ``"name"`` in the form ``"<pdb_id>.<chain>"`` (e.g. ``"12as.A"``),
            ``"seq"``, and ``"coords"`` (a dict with the backbone keys
            ``"N"``, ``"CA"``, ``"C"`` and ``"O"``).

    Returns:
        The same ``pkg`` object with ``masked_list``, ``visible_list``,
        ``num_of_chains``, ``seq_chain_<chain>`` and ``coords_chain_<chain>``
        added, plus ``<atom>_chain_<chain>`` entries inside ``coords``.

    Raises:
        ValueError: if ``pkg["name"]`` has no ``.<chain>`` suffix.
    """
    name = pkg["name"]
    # Read the chain ID from after the last "." rather than from a fixed
    # character position, so IDs that are not exactly 4 characters long do not
    # silently produce the wrong chain.
    _, dot, chain = name.rpartition(".")
    if not dot or not chain:
        raise ValueError(f"expected a name like '<pdb_id>.<chain>', got {name!r}")

    pkg["masked_list"] = []
    pkg["visible_list"] = [chain]
    pkg[f"seq_chain_{chain}"] = pkg["seq"]
    pkg["num_of_chains"] = 1

    coords = pkg["coords"]
    for backbone_atom in ("N", "CA", "C", "O"):
        # Same coordinate object under a chain-suffixed key; the original key stays.
        coords[f"{backbone_atom}_chain_{chain}"] = coords[backbone_atom]
    # The per-chain coords entry is the very same dict, not a copy.
    pkg[f"coords_chain_{chain}"] = coords

    return pkg

# Stream the JSONL file one line at a time and stop after `max_samples`
# records, instead of loading every line into memory with `readlines()` first.
# Checking the limit before parsing also means `max_samples <= 0` loads nothing.
with open("data/cath/chain_set.jsonl") as fn:
    for line_number, line in enumerate(fn):
        if line_number >= max_samples:
            break
        pkg = transform_for_justas(json.loads(line))
        justas_raw_data.append(pkg)

len(justas_raw_data)

3

In [36]:
# Wrap the converted entries in ProteinMPNN's dataset class.
justas_dataset = JustasStructureDataset(
    justas_raw_data,
    max_length=512,
)

len(justas_dataset)

3

In [37]:
# ProteinMPNN loader with the same batch settings as the Ingraham loader above.
justas_loader = JustasStructureLoader(
    justas_dataset,
    shuffle=False,
    batch_size=10_000,
)

justas_loader

<ProteinMPNN.training.utils.StructureLoader at 0x3118294d0>

In [38]:
# Only one batch is needed here: take the first one instead of featurizing
# every batch in a loop and keeping whatever the last iteration produced.
justas_batch = next(iter(justas_loader))

# NOTE: this rebinds X, S, mask and lengths, which previously held the output
# of the Ingraham `featurize` call; cells below use the ProteinMPNN versions.
X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = justas_featurize(justas_batch, device)

len(justas_batch)

3

In [39]:
justas_model = ProteinMPNN(k_neighbors=neighbors)

# Invoke the module through `__call__` (the PyTorch convention) rather than
# calling `.forward` directly, so any registered hooks are honored.
logits = justas_model(X, S, mask, chain_M, residue_idx, chain_encoding_all)

In [40]:
logits.shape 

torch.Size([3, 330, 21])

In [41]:
# same thing here: the model takes the XYZ coordinates directly and computes its features internally, using its own `ProteinFeatures` module (shown below)

In [42]:
justas_features = JustasProteinFeatures(hidden, hidden, num_positional_embeddings=16, num_rbf=16, top_k=neighbors, num_chain_embeddings=16)

# Invoke the module through `__call__` (the PyTorch convention) rather than
# calling `.forward` directly, so any registered hooks are honored.
E, E_idx = justas_features(X, mask, residue_idx, chain_encoding_all)

E.shape, E_idx.shape

(torch.Size([3, 330, 16, 128]), torch.Size([3, 330, 16]))

In [43]:
# so internally, the model takes our XYZ coordinates and outputs edge features `E` together with an index tensor `E_idx`, which lists the `top_k` = 16 neighbors of each of the 330 residue positions in each of the 3 proteins in the batch