In [None]:
%load_ext autoreload
%autoreload 2

%pip install nglview ase

import sys
sys.path.append('..')


In [None]:
import json
import os
import urllib.request

# Download the model weights.
# NOTE(review): this URL points at a locally running artifact server, so the
# cell only works while that server is reachable.
weights_url = "http://localhost:8000/get-artifact?path=checkpoints%2Fcheckpoint_7600&run_uuid=4a7df87277824ffba36ec8477abf5be5"
weights_file = "model.pt"

# The URL names one fixed checkpoint of one fixed run, so reuse an earlier
# download instead of re-fetching on every Restart & Run All.
# Delete model.pt to force a fresh download.
if not os.path.exists(weights_file):
    # Download under a temporary name and rename only on success, so an
    # interrupted download never leaves a truncated model.pt behind that a
    # later run would mistake for a complete file.
    partial_file = weights_file + ".part"
    urllib.request.urlretrieve(weights_url, partial_file)
    os.replace(partial_file, weights_file)

# Load model config (TODO - shift to storing this in the model.pt file or downloading separately)
with open("../config/config.json", encoding="utf-8") as f:
    config = json.load(f)

In [None]:
import torch
from nanofold.train.model import Nanofold

# Build the model for inference from the hyperparameters stored in `config`.
model = Nanofold(**Nanofold.get_args(config), inference=True)

# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only
# machine; load_state_dict then copies the tensors into the model's parameters.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider weights_only=True if the checkpoint holds only
# tensors and primitive containers (not verified here).
model.load_state_dict(torch.load(weights_file, map_location="cpu")["model"])

In [None]:
from pathlib import Path

import torch

from nanofold.common.msa_metadata import COMPRESSED_MSA_FIELDS
from nanofold.common.residue_definitions import RESIDUE_INDEX_MSA
from nanofold.preprocess.ipc import get_msa_features
from nanofold.train.chain_dataset import encode_one_hot
from nanofold.train.chain_dataset import get_reference_features

# Query sequence. It must correspond to the chain whose MSA is loaded below
# (structure 1qd6, chain C); the MSA is reshaped using len(sequence).
sequence = "FTLYPYDTNYLIYTQTSDLNKEAIASYDWAENARKDEVKFQLSLAFPLWRGILGPNSVLGASYTQKSWWQLSNSEESSPFRETNYEPQLFLGFATD"

# Maximum number of entries kept along the MSA depth axis (dim 0) for each MSA
# field. The profile and deletion_mean are still computed over the full depth.
MAX_MSA_ROWS = 1024

residue_index = torch.arange(len(sequence))

# Reference features, plus template features with a leading size-0 dimension
# (i.e. no templates are supplied).
features = {
    **get_reference_features(sequence, residue_index),
    "template_restype": torch.empty(0, len(sequence), len(RESIDUE_INDEX_MSA)),
    "template_backbone_frame_mask": torch.empty(0, len(sequence), 3),
    "template_distogram": torch.empty(0, len(sequence), 37),
    "template_unit_vector": torch.empty(0, len(sequence), 3),
}

# Each MSA field is stored as sparse COO coords/data/shape. Densify it to
# (msa_depth, len(sequence), feat_size).
sparse_msa_features = get_msa_features(Path("./"), {"_id": {"structure_id": "1qd6", "chain_id": "C"}})
for msa_field in COMPRESSED_MSA_FIELDS:
    sparse_matrix = torch.sparse_coo_tensor(
        sparse_msa_features[f"{msa_field}_coords"],
        sparse_msa_features[f"{msa_field}_data"],
        sparse_msa_features[f"{msa_field}_shape"],
    )
    dense_matrix = (
        torch.stack([sparse_matrix[i] for i in residue_index])
        .to_dense()
        .reshape(len(sequence), -1, COMPRESSED_MSA_FIELDS[msa_field].feat_size)
        .transpose(0, 1)
    )
    if msa_field == "msa":
        features["profile"] = dense_matrix.mean(dim=-3)
    if msa_field == "deletion_value":
        # squeeze(-1) rather than squeeze(): only drop the trailing feature axis,
        # so a length-1 sequence keeps its residue axis.
        features["deletion_mean"] = dense_matrix.mean(dim=-3).squeeze(-1)
    features[msa_field] = dense_matrix[:MAX_MSA_ROWS]

In [None]:
# Predict coordinates for `features`. Autograd is switched off for the forward
# pass; Module.eval() returns the module itself, so the call can be chained.
with torch.set_grad_enabled(False):
    coords = model.eval()(features)

In [None]:
from nanofold.inference.structure import coords_to_bio_structure
import nglview as nv

# Convert the predicted coordinates into a Biopython structure and render it
# with NGLView (the widget is the cell's displayed output).
nv.show_biopython(
    coords_to_bio_structure(sequence, coords)
)

In [None]:
import nglview as nv
from ase import Atom, Atoms

from nanofold.common.residue_definitions import get_3l_res_name
from nanofold.common.residue_definitions import BACKBONE_POSITIONS

# Scale applied to the centred coordinates before display.
# NOTE(review): the reason for 0.19 is not documented here. TODO: confirm the
# coordinate units and the intended scale.
scale_factor = 0.19

# Centre the whole predicted structure on its mean position once (hoisted out of
# the per-residue loop), then scale. Broadcasting gives the same values as
# subtracting the centroid from each coords[i] separately.
centroid = coords.reshape(-1, 3).mean(dim=0)
display_coords = scale_factor * (coords - centroid)

atoms = []
for i, r in enumerate(sequence):
    res_name = get_3l_res_name(r)
    # Pair each BACKBONE_POSITIONS entry for this residue with one coordinate row
    # of display_coords[i]. meta[0] is the atom name; "CA" is mapped to element
    # "C" and other names are used as the element symbol as-is.
    for meta, c in zip(BACKBONE_POSITIONS[res_name], display_coords[i]):
        symbol = meta[0] if meta[0] != "CA" else "C"
        atoms.append(Atom(symbol, c))

# ase.Atoms accepts a list of Atom objects directly (symbols and positions).
nv.show_ase(Atoms(atoms))