In [1]:
import json 

import torch 

from ingraham.struct2seq.protein_features import ProteinFeatures
from ingraham.struct2seq.data import StructureDataset, StructureLoader
from ingraham.experiments.utils import featurize
from ingraham.struct2seq.struct2seq import Struct2Seq

from torch_geometric.transforms import PointPairFeatures, KNNGraph

# Device handed to both featurize calls further down; this notebook runs on CPU.
device = torch.device("cpu")

  Referenced from: <75FFC412-93B5-322B-8E6D-268DA3498CF4> /Users/alex/Documents/inverse-folding-unrolled/.venv/lib/python3.11/site-packages/libpyg.so
  Reason: tried: '/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file), '/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file)
  Referenced from: <75FFC412-93B5-322B-8E6D-268DA3498CF4> /Users/alex/Documents/inverse-folding-unrolled/.venv/lib/python3.11/site-packages/libpyg.so
  Reason: tried: '/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file), '/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file)


In [2]:
# Load a small, truncated slice of the CATH chain set for a quick walkthrough.
dataset = StructureDataset(
    "data/cath/chain_set.jsonl",
    truncate=3,
    verbose=True,
    max_length=512,
)

len(dataset)

3

In [3]:
# Peek at the first record; the recorded output shows 'seq' and 'coords' entries.
dataset[0]

{'seq': 'MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKRQTLGQHDFSAGEGLYTHMKALRPDEDRLSPLHSVYVDQWDWERVMGDGERQFSTLKSTVEAIWAGIKATEAAVSEEFGLAPFLPDQIHFVHSQELLSRYPDLDAKGRERAIAKDLGAVFLVGIGGKLSDGHRHDVRAPDYDDWSTPSELGHAGLNGDILVWNPVLEDAFELSSMGIRVDADTLKHQLALTGDEDRLELEWHQALLRGEMPQTIGGGIGQSRLTMLLLQLPHIGQVQAGVWPAAVRESVPSLL',
 'coords': {'N': array([[        nan,         nan,         nan],
         [        nan,         nan,         nan],
         [        nan,         nan,         nan],
         [ 1.1751e+01,  3.7846e+01,  2.9016e+01],
         [ 1.4235e+01,  3.9531e+01,  2.6906e+01],
         [ 1.6789e+01,  3.9630e+01,  2.8369e+01],
         [ 1.6368e+01,  3.7519e+01,  3.0261e+01],
         [ 1.5825e+01,  3.5211e+01,  2.8535e+01],
         [ 1.8356e+01,  3.5312e+01,  2.7173e+01],
         [ 2.0058e+01,  3.4382e+01,  2.9069e+01],
         [ 1.9058e+01,  3.1934e+01,  2.9727e+01],
         [ 1.9717e+01,  3.0651e+01,  2.7158e+01],
         [ 2.2605e+01,  3.0674e+01,  2.7231

In [4]:
# Wrap the dataset in Ingraham's batching loader, keeping the original order.
# NOTE(review): batch_size is presumably a residue budget rather than a
# structure count — confirm against StructureLoader.
loader = StructureLoader(
    dataset,
    batch_size=10_000,
    shuffle=False,
)

len(loader)

[330, 129, 185]


1

In [5]:
# Grab just the first batch. next(iter(...)) states that intent directly and
# fails right here (StopIteration) if the loader yields nothing, instead of
# leaving `batch` undefined for the next line to trip over.
batch = next(iter(loader))

# Keys available on the first raw structure entry of the batch.
batch[0].keys()

[[1, 2, 0]]


dict_keys(['seq', 'coords', 'num_chains', 'name', 'CATH'])

In [6]:
# Featurize the batch with Ingraham's helper.
# NOTE(review): X looks like per-residue backbone coordinates (4 atoms x xyz),
# S the sequence tokens and mask a padding mask — confirm against
# ingraham.experiments.utils.featurize.

result = featurize(batch, device)

# The helper returns a 4-tuple; unpack it into named pieces.
X, S, mask, lengths = result 

X.shape, S.shape, mask.shape, lengths.shape

(torch.Size([3, 330, 4, 3]), torch.Size([3, 330]), torch.Size([3, 330]), (3,))

In [7]:
# Run the featurized batch through the full Struct2Seq model.

# One width shared by the node features, edge features and hidden layers.
hidden = 128

model = Struct2Seq(
    num_letters=20,
    node_features=hidden,
    edge_features=hidden,
    hidden_dim=hidden,
)

logits = model(X, S, lengths, mask)

logits.shape

Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/Cross.cpp:66.)
  n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)


torch.Size([3, 330, 20])

In [8]:
# Inside the model there are two steps: first `ProteinFeatures`, then the
# transformer. Run the feature step on its own to inspect what it produces.

features = ProteinFeatures(hidden, hidden)

# Call the module itself rather than `.forward(...)`: PyTorch routes
# __call__ through any registered hooks, which a direct forward() call skips.
# NOTE(review): assumes ProteinFeatures is an nn.Module — confirm.
result = features(X, lengths, mask)

V, E, E_idx = result

V.shape, E.shape, E_idx.shape

(torch.Size([3, 330, 128]),
 torch.Size([3, 330, 30, 128]),
 torch.Size([3, 330, 30]))

In [9]:
# Shape of the featurized input X, for comparison with V / E / E_idx above.
X.shape

torch.Size([3, 330, 4, 3])

In [10]:
# Same deal for ProteinMPNN: build its dataset and loader, featurize a batch, then run the model.

In [11]:
from ProteinMPNN.training.model_utils import featurize as justas_featurize
from ProteinMPNN.training.model_utils import ProteinMPNN
from ProteinMPNN.training.model_utils import ProteinFeatures as JustasProteinFeatures 
from ProteinMPNN.training.utils import StructureDataset as JustasStructureDataset
from ProteinMPNN.training.utils import StructureLoader as JustasStructureLoader 


In [12]:
# Records converted by transform_for_justas accumulate in this list.
justas_raw_data = []
# Upper bound on how many lines of the JSONL file are loaded below.
max_samples = 3 

def transform_for_justas(pkg):
    """Add the chain-suffixed keys used for the ProteinMPNN loader to ``pkg``.

    The chain identifier is the character at index 5 of ``pkg["name"]``
    (presumably names look like ``"<4-char id>.<chain>"`` — TODO confirm).

    The record is modified in place and also returned. It gains:

    * ``masked_list`` (empty) and ``visible_list`` (just the chain id),
    * ``seq_chain_<id>`` holding ``pkg["seq"]``,
    * ``num_of_chains`` set to 1,
    * ``<atom>_chain_<id>`` entries inside ``pkg["coords"]`` for N, CA, C, O,
    * ``coords_chain_<id>`` pointing at that same ``pkg["coords"]`` dict.
    """
    chain_id = pkg["name"][5]

    pkg.update(
        {
            "masked_list": [],
            "visible_list": [chain_id],
            f"seq_chain_{chain_id}": pkg["seq"],
            "num_of_chains": 1,
        }
    )

    # The suffixed atom keys are added next to the plain ones in the same
    # dict, which is then exposed under the chain-specific coords key as well.
    coords = pkg["coords"]
    for atom in ("N", "CA", "C", "O"):
        coords[f"{atom}_chain_{chain_id}"] = coords[atom]
    pkg[f"coords_chain_{chain_id}"] = coords

    return pkg

# Stream the JSONL file line by line: only the first `max_samples` records are
# needed, so avoid readlines(), which would read the whole file into memory
# before the loop even starts. JSON text is UTF-8, so say so explicitly
# instead of relying on the platform default encoding.
with open("data/cath/chain_set.jsonl", encoding="utf-8") as handle:
    for count, line in enumerate(handle, start=1):
        justas_raw_data.append(transform_for_justas(json.loads(line)))
        # Stop once enough records have been collected.
        if count >= max_samples:
            break

len(justas_raw_data)

3

In [13]:
# Wrap the converted records in ProteinMPNN's dataset class, with the same
# max_length=512 used for the Ingraham dataset earlier.
justas_dataset = JustasStructureDataset(justas_raw_data, max_length=512)

len(justas_dataset)

3

In [14]:
# Same loader settings as the Ingraham loader earlier: batch_size=10_000, no shuffling.
justas_loader = JustasStructureLoader(justas_dataset, batch_size=10_000, shuffle=False)

justas_loader

<ProteinMPNN.training.utils.StructureLoader at 0x1682384d0>

In [15]:
# Featurize every batch the loader yields; only the tensors from the last
# batch remain bound once the loop finishes.
# NOTE: X, S, mask and lengths are rebound here, replacing the values produced
# by the Ingraham featurize call earlier in the notebook.
for justas_batch in justas_loader:
    (
        X,
        S,
        mask,
        lengths,
        chain_M,
        residue_idx,
        mask_self,
        chain_encoding_all,
    ) = justas_featurize(justas_batch, device)

len(justas_batch)

3

In [16]:
# ProteinMPNN with its default constructor arguments.
justas_model = ProteinMPNN()

# Call the module itself rather than `.forward(...)`: PyTorch routes
# __call__ through any registered hooks, which a direct forward() call skips.
logits = justas_model(X, S, mask, chain_M, residue_idx, chain_encoding_all)



In [17]:
# In the recorded run the last dimension is 21 here, versus 20 for Struct2Seq.
logits.shape 

torch.Size([3, 330, 21])