In [1]:
import sys
import numpy as np
from math import pow
import py3Dmol
from pathlib import Path
from io import StringIO
from Bio.PDB import PDBIO
from Bio.PDB import MMCIFParser, Superimposer
sys.path.append(str(Path("./src/simplefold").resolve()))

In [2]:
# following are example amino acid sequences:
example_sequences = {
    # "7ftv_A": "GASKLRAVLEKLKLSRDDISTAAGMVKGVVDHLLLRLKCDSAFRGVGLLNTGSYYEHVKISAPNEFDVMFKLEVPRIQLEEYSNTRAYYFVKFKRNPKENPLSQFLEGEILSASKMLSKFRKIIKEEINDDTDVIMKRKRGGSPAVTLLISEKISVDITLALESKSSWPASTQEGLRIQNWLSAKVRKQLRLKPFYLVPKHAEETWRLSFSHIEKEILNNHGKSKTCCENKEEKCCRKDCLKLMKYLLEQLKERFKDKKHLDKFSSYHVKTAFFHVCTQNPQDSQWDRKDLGLCFDNCVTYFLQCLRTEKLENYFIPEFNLFSSNLIDKRSKEFLTKQIEYERNNEFPVFD",
    # "a2a": "XSSVYITVELAIAVLAILGNVLVCWAVWLNSNLQNVTNYFVVSLAAADIAVGVLAIPFAITISTGFXAAXXGXLFIACFVLVLTQSSIFSLLAIAIDRYIAIRIPLRYNGLVTGTRAKGIIAICWVLSFAIGLTPMLGWNNXGQPKEGKNXSQGXGEGQVAXLFEDVVPMNYMVYFNFFACVLVPLLLMLGVYLRIFLAARRQLKQMESQPLPGERARSTLQKEVXAAKSLAIIVGLFALCWLPLXIINCFTFFXPDXSXAPLWLMYLAIVLSXTNSVVNPFIYAYRIREFRQTFRKIIRSX"
    "a2a_nocappings": "SSVYITVELAIAVLAILGNVLVCWAVWLNSNLQNVTNYFVVSLAAADIAVGVLAIPFAITISTGFXAAXXGXLFIACFVLVLTQSSIFSLLAIAIDRYIAIRIPLRYNGLVTGTRAKGIIAICWVLSFAIGLTPMLGWNNXGQPKEGKNXSQGXGEGQVAXLFEDVVPMNYMVYFNFFACVLVPLLLMLGVYLRIFLAARRQLKQMESQPLPGERARSTLQKEVXAAKSLAIIVGLFALCWLPLXIINCFTFFXPDXSXAPLWLMYLAIVLSXTNSVVNPFIYAYRIREFRQTFRKIIRS"
}
seq_id = "a2a_nocappings"  # choose from example_sequences
aa_sequence = example_sequences[seq_id]
print(f"Predicting structure for {seq_id} with {len(aa_sequence)} amino acids.")

Predicting structure for a2a_nocappings with 300 amino acids.


In [3]:
# --- Model selection ---------------------------------------------------------
# Model size: simplefold_<size>, with size one of 100M, 360M, 700M, 1.1B, 1.6B, 3B.
simplefold_model = "simplefold_100M"
backend = "torch"  # either "mlx" or "torch"

# --- Paths -------------------------------------------------------------------
ckpt_dir = "artifacts"    # handed to ModelWrapper as ckpt_dir
output_dir = "artifacts"  # root under which prediction_dir is created
prediction_dir = f"predictions_{simplefold_model}_{backend}"
output_name = str(seq_id)  # file-name stem for the saved predictions

# --- Sampling ----------------------------------------------------------------
num_steps = 500          # flow-matching inference steps
tau = 0.05               # stochasticity scale of the sampler
plddt = True             # also load/run the pLDDT confidence module
nsample_per_protein = 1  # samples drawn per input protein

In [4]:
# Fix the random seed (Lightning's seed_everything, including dataloader
# workers) so flow-matching sampling is reproducible across runs.
import lightning.pytorch as pl

pl.seed_everything(seed=42, workers=True)

Seed set to 42


42

In [5]:
from src.simplefold.wrapper import ModelWrapper, InferenceWrapper

# Build the model wrapper from the configuration above, then load the
# pretrained folding model and the pLDDT confidence model through it.
model_wrapper = ModelWrapper(
    backend=backend,
    ckpt_dir=ckpt_dir,
    plddt=plddt,
    simplefold_model=simplefold_model,
)
device = model_wrapper.device  # reused by InferenceWrapper and run_inference below
folding_model = model_wrapper.from_pretrained_folding_model()
plddt_model = model_wrapper.from_pretrained_plddt_model()

MLX not installed, skip importing MLX related packages.
Folding model simplefold_100M loaded with torch backend.
pLDDT output module loaded with torch backend.
pLDDT latent module loaded with torch backend.


In [6]:
# One wrapper handles featurization (process_input), sampling (run_inference)
# and writing results (save_result) with the settings chosen above.
# Constructing it also loads the ESM protein language model (see output).
inference_wrapper = InferenceWrapper(
    backend=backend,
    device=device,
    output_dir=output_dir,
    prediction_dir=prediction_dir,
    num_steps=num_steps,
    tau=tau,
    nsample_per_protein=nsample_per_protein,
)

Using cache found in /home/nobilm@usi.ch/.cache/torch/hub/facebookresearch_esm_main


pLM ESM-3B loaded with torch backend.


In [7]:
# Featurize the input sequence (inference itself runs later, in run_inference).
# `batch` holds the model-ready tensors inspected in the next cell; `structure`
# and `record` are what save_result() needs to write the prediction to disk.
batch, structure, record = inference_wrapper.process_input(aa_sequence)


Processing input data.


100%|██████████| 1/1 [00:00<00:00, 18.17it/s]


Processing ESM features for inference...


In [8]:
# Summarize the featurized batch: array-like values print their shape,
# sequences print a 10-item preview, and anything else prints as-is.
# An explicit attribute check replaces the broad `except Exception`, which
# used exceptions for control flow and crashed on non-sliceable values.
for key, value in batch.items():
    shape = getattr(value, "shape", None)
    if shape is not None:
        print(key, type(value), shape)
    elif isinstance(value, (str, bytes, list, tuple)):
        print(key, type(value), value[:10])
    else:
        print(key, type(value), value)

token_index <class 'torch.Tensor'> torch.Size([1, 300])
residue_index <class 'torch.Tensor'> torch.Size([1, 300])
asym_id <class 'torch.Tensor'> torch.Size([1, 300])
entity_id <class 'torch.Tensor'> torch.Size([1, 300])
sym_id <class 'torch.Tensor'> torch.Size([1, 300])
mol_type <class 'torch.Tensor'> torch.Size([1, 300])
res_type <class 'torch.Tensor'> torch.Size([1, 300, 33])
disto_center <class 'torch.Tensor'> torch.Size([1, 300, 3])
token_pad_mask <class 'torch.Tensor'> torch.Size([1, 300])
token_resolved_mask <class 'torch.Tensor'> torch.Size([1, 300])
token_disto_mask <class 'torch.Tensor'> torch.Size([1, 300])
pocket_feature <class 'torch.Tensor'> torch.Size([1, 300, 4])
ref_pos <class 'torch.Tensor'> torch.Size([1, 2304, 3])
atom_resolved_mask <class 'torch.Tensor'> torch.Size([1, 2304])
ref_element <class 'torch.Tensor'> torch.Size([1, 2304, 128])
ref_charge <class 'torch.Tensor'> torch.Size([1, 2304])
ref_atom_name_chars <class 'torch.Tensor'> torch.Size([1, 2304, 4, 64])
ref

In [9]:
# Structure object produced by process_input(); its repr dumps the atom,
# bond and residue arrays, so the displayed output is long and truncated.
structure

Structure(atoms=array([([46,  0,  0,  0], 7, 0, [0., 0., 0.], [ 6.7465025e-01,  1.5018703e+00, -5.3672951e-01],  True, 0),
       ([35, 33,  0,  0], 6, 0, [0., 0., 0.], [ 1.3792863e-04,  4.9664670e-01,  2.8510505e-01],  True, 0),
       ([35,  0,  0,  0], 6, 0, [0., 0., 0.], [ 9.9410099e-01, -5.3746176e-01,  7.3505038e-01],  True, 0),
       ...,
       ([47,  0,  0,  0], 8, 0, [0., 0., 0.], [ 1.0545242e+00, -8.6835456e-01,  1.9495397e+00],  True, 0),
       ([35, 34,  0,  0], 6, 0, [0., 0., 0.], [-1.1279289e+00, -1.6593763e-01, -5.1609635e-01],  True, 0),
       ([47, 39,  0,  0], 8, 0, [0., 0., 0.], [-1.8135979e+00, -1.0852497e+00,  2.8947514e-01],  True, 0)],
      shape=(2300,), dtype=[('name', 'i1', (4,)), ('element', 'i1'), ('charge', 'i1'), ('coords', '<f4', (3,)), ('conformer', '<f4', (3,)), ('is_present', '?'), ('chirality', 'i1')]), bonds=array([], dtype=[('atom_1', '<i4'), ('atom_2', '<i4'), ('type', 'i1')]), residues=array([('SER', 17,   0,    0,  6,    1,    4,  True,  Tru

In [10]:
# Metadata record produced by process_input() (chains, entity ids, residue
# count); passed to save_result() together with `structure`.
record

Record(id='input', structure={'resolution': None, 'method': None, 'deposited': None, 'released': None, 'revised': None, 'num_chains': 1, 'num_interfaces': None}, chains=[{'chain_id': 0, 'chain_name': 'A', 'mol_type': 0, 'cluster_id': -1, 'msa_id': -1, 'num_residues': 300, 'valid': True, 'entity_id': 0}], interfaces=[], inference_options={'binders': [], 'pocket': []})

In [11]:
# Sample structures with the folding model (plddt_model is passed for
# confidence estimation), then write them to disk. save_result returns the
# paths of the written files.
results = inference_wrapper.run_inference(
    batch, folding_model, plddt_model, device=device
)
save_paths = inference_wrapper.save_result(
    structure, record, results, out_name=output_name
)

Sampling: 100%|██████████| 500/500 [00:45<00:00, 10.89it/s]


In [12]:
# Visualize the first predicted structure.
# The file is read on the kernel side and embedded in the viewer. py3Dmol's
# `query` argument is forwarded to $3Dmol.download in the browser, which treats
# an unrecognized string as a URL relative to the notebook page, not as a file
# on the kernel's disk.
pdb_path = Path(save_paths[0])  # despite the name, the file written here is mmCIF
view = py3Dmol.view()
view.addModel(pdb_path.read_text(), pdb_path.suffix.lstrip(".").lower() or "pdb")

In [13]:
# Cartoon coloring: with pLDDT, color by the per-atom 'b' property (presumably
# where save_result stores confidence; verify) on a red-orange-yellow-green-blue
# gradient from 0 to 100. Without pLDDT, fall back to a rainbow spectrum.
if plddt:
    cartoon_style = {'colorscheme': {'prop': 'b', 'gradient': 'roygb', 'min': 0, 'max': 100}}
else:
    cartoon_style = {'color': 'spectrum'}
view.setStyle({'cartoon': cartoon_style})
view.zoomTo()
view.show()

In [14]:
# visualize the all-atom structure
# Restyles the same `view` object used above, replacing the cartoon with
# sticks for all atoms, then re-centers and renders it again.
view.setStyle({'stick':{}})
view.zoomTo()
view.show()

In [15]:
# Path of the first predicted structure written by save_result
# (an mmCIF file despite the variable name).
pdb_path

PosixPath('artifacts/predictions_simplefold_100M_torch/a2a_nocappings_sampled_0.cif')

In [16]:
# Helpers for comparing the prediction against a ground-truth (GT) structure:
# a TM-score function for pre-superimposed coordinates and a shared mmCIF parser.

def calculate_tm_score(coords1, coords2, L_target=None):
    """
    Compute the TM-score of two already-superimposed, paired coordinate sets.

    The score is evaluated for the superposition the caller supplies. Unlike
    the TM-score / TM-align programs, it does not search for the
    TM-score-maximizing superposition. With an RMSD-minimizing fit (e.g.
    Bio.PDB.Superimposer), the result is therefore a lower bound on the
    TM-score for the same residue pairing.

    Parameters
    ----------
    coords1, coords2 : array_like, shape (N, D)
        Paired coordinates (typically CA atoms, D = 3). Row i of ``coords1``
        is compared with row i of ``coords2``.
    L_target : int, optional
        Length used for normalization and for d0. Defaults to N.

    Returns
    -------
    float
        sum_i 1 / (1 + (d_i / d0)**2) / L_target, where d_i is the distance
        between paired rows and d0 = max(0.5, 1.24 * cbrt(L_target - 15) - 1.8).

    Raises
    ------
    ValueError
        If the arrays are not 2-D with identical shapes, or L_target < 1.
    """
    coords1 = np.asarray(coords1, dtype=float)
    coords2 = np.asarray(coords2, dtype=float)
    if coords1.shape != coords2.shape or coords1.ndim != 2:
        raise ValueError(
            f"Aligned coords must be 2-D with the same shape, got {coords1.shape} and {coords2.shape}"
        )
    N = coords1.shape[0]

    if L_target is None:
        L_target = N
    if L_target < 1:
        raise ValueError(f"L_target must be a positive length, got {L_target}")

    # Distances between paired atoms.
    dists = np.linalg.norm(coords1 - coords2, axis=1)

    # d0 scaling factor (Zhang & Skolnick, 2004). np.cbrt is defined for
    # negative arguments, so short chains (L_target < 15) no longer raise the
    # math-domain error that math.pow did. The 0.5 floor matches the TM-score
    # program, which uses d0 = 0.5 for short chains.
    d0 = max(0.5, 1.24 * float(np.cbrt(L_target - 15)) - 1.8)

    return float(np.sum(1.0 / (1.0 + (dists / d0) ** 2)) / L_target)

# Shared mmCIF parser used by the comparison cells below; QUIET=True
# suppresses Biopython's warnings while structures are being built.
parser = MMCIFParser(QUIET=True)


In [17]:
# Compare the prediction with the FApo reference structure.
# The reference path is relative to the repository root, which is the
# notebook's working directory (already assumed by "artifacts" and
# "./src/simplefold" above), instead of a user-specific absolute path.
reference_path = Path("data/pdb_inapo/FApo_no_caps.cif")
struct1 = parser.get_structure("ref", reference_path)
struct2 = parser.get_structure("prd", pdb_path)

# Collect CA atoms from the first model only, and only from residues with an
# N-CA-C backbone. Matching on the atom name alone would also pick up calcium
# ions (also named "CA") and atoms from any additional models.
atoms1 = [res["CA"] for res in struct1[0].get_residues() if all(name in res for name in ("N", "CA", "C"))]
atoms2 = [res["CA"] for res in struct2[0].get_residues() if all(name in res for name in ("N", "CA", "C"))]
print(len(atoms1), len(atoms2))

# Atoms are paired by position, which assumes both files list the same
# residues in the same order.
if len(atoms1) != len(atoms2):
    raise ValueError(
        f"Cannot pair CA atoms: reference has {len(atoms1)}, prediction has {len(atoms2)}"
    )

# Superimpose: the reference (atoms1) stays fixed; the whole predicted model
# is moved onto it.
sup = Superimposer()
sup.set_atoms(atoms1, atoms2)
sup.apply(list(struct2.get_atoms()))

# TM-score on the RMSD-minimizing superposition above. TM-score is defined as
# a maximum over superpositions, so this value is a lower bound on what the
# TM-score program would report for the same residue pairing.
coords1 = np.array([a.coord for a in atoms1])
coords2 = np.array([a.coord for a in atoms2])
tm_score = calculate_tm_score(coords1, coords2)

print("TM-score (0-1, higher is better): {:.3f}".format(tm_score))
print("RMSD (lower is better): {:.3f}".format(sup.rms))

# Serialize both (now aligned) structures to PDB-format strings.
pdb_io = PDBIO()
s1_buf, s2_buf = StringIO(), StringIO()
pdb_io.set_structure(struct1)
pdb_io.save(s1_buf)
pdb_io.set_structure(struct2)
pdb_io.save(s2_buf)

# Visualize in py3Dmol: reference in blue, prediction in red.
view = py3Dmol.view(width=600, height=400)
view.addModel(s1_buf.getvalue(), "pdb")
view.addModel(s2_buf.getvalue(), "pdb")
view.setStyle({'model': 0}, {'cartoon': {'color': 'blue'}})
view.setStyle({'model': 1}, {'cartoon': {'color': 'red'}})

# Legend pinned to the viewer's top-left corner in screen pixels, rather than
# at model-space coordinates near the origin that may lie far from the protein.
label_style = {'fontColor': 'white', 'fontSize': 12, 'useScreen': True}
view.addLabel("Ground Truth", {**label_style, 'position': {'x': 10, 'y': 10, 'z': 0}, 'backgroundColor': 'blue'})
view.addLabel("Predicted", {**label_style, 'position': {'x': 10, 'y': 35, 'z': 0}, 'backgroundColor': 'red'})

view.zoomTo()
view.show()

300 300
TM-score (0-1, higher is better): 0.859
RMSD (lower is better): 3.700


In [None]:
# Compare the prediction with the INApo reference structure.
# The reference path is relative to the repository root, which is the
# notebook's working directory (already assumed by "artifacts" and
# "./src/simplefold" above), instead of a user-specific absolute path.
reference_path = Path("data/pdb_inapo/INApo_no_caps.cif")
struct1 = parser.get_structure("ref", reference_path)
struct2 = parser.get_structure("prd", pdb_path)

# Collect CA atoms from the first model only, and only from residues with an
# N-CA-C backbone. Matching on the atom name alone would also pick up calcium
# ions (also named "CA") and atoms from any additional models.
atoms1 = [res["CA"] for res in struct1[0].get_residues() if all(name in res for name in ("N", "CA", "C"))]
atoms2 = [res["CA"] for res in struct2[0].get_residues() if all(name in res for name in ("N", "CA", "C"))]
print(len(atoms1), len(atoms2))

# Atoms are paired by position, which assumes both files list the same
# residues in the same order.
if len(atoms1) != len(atoms2):
    raise ValueError(
        f"Cannot pair CA atoms: reference has {len(atoms1)}, prediction has {len(atoms2)}"
    )

# Superimpose: the reference (atoms1) stays fixed; the whole predicted model
# is moved onto it.
sup = Superimposer()
sup.set_atoms(atoms1, atoms2)
sup.apply(list(struct2.get_atoms()))

# TM-score on the RMSD-minimizing superposition above. TM-score is defined as
# a maximum over superpositions, so this value is a lower bound on what the
# TM-score program would report for the same residue pairing.
coords1 = np.array([a.coord for a in atoms1])
coords2 = np.array([a.coord for a in atoms2])
tm_score = calculate_tm_score(coords1, coords2)

print("TM-score (0-1, higher is better): {:.3f}".format(tm_score))
print("RMSD (lower is better): {:.3f}".format(sup.rms))

# Serialize both (now aligned) structures to PDB-format strings.
pdb_io = PDBIO()
s1_buf, s2_buf = StringIO(), StringIO()
pdb_io.set_structure(struct1)
pdb_io.save(s1_buf)
pdb_io.set_structure(struct2)
pdb_io.save(s2_buf)

# Visualize in py3Dmol: reference in blue, prediction in red.
view = py3Dmol.view(width=600, height=400)
view.addModel(s1_buf.getvalue(), "pdb")
view.addModel(s2_buf.getvalue(), "pdb")
view.setStyle({'model': 0}, {'cartoon': {'color': 'blue'}})
view.setStyle({'model': 1}, {'cartoon': {'color': 'red'}})

# Legend pinned to the viewer's top-left corner in screen pixels, rather than
# at model-space coordinates near the origin that may lie far from the protein.
label_style = {'fontColor': 'white', 'fontSize': 12, 'useScreen': True}
view.addLabel("Ground Truth", {**label_style, 'position': {'x': 10, 'y': 10, 'z': 0}, 'backgroundColor': 'blue'})
view.addLabel("Predicted", {**label_style, 'position': {'x': 10, 'y': 35, 'z': 0}, 'backgroundColor': 'red'})

view.zoomTo()
view.show()

300 300
TM-score (0-1, higher is better): 0.970
RMSD (lower is better): 1.252
