In [7]:
import os
import torch as pt
from tqdm import tqdm
from glob import glob
from src.dataset import StructuresDataset, col_batch
from src.data_encoding import en_structure, en_features, ext_topology
from src.structure import encode_bfactor, concatenate_chains, split_by_chain
from src.structure_io import save_pdb
from src.config import config_model
from model.model import Model

In [8]:
# Folder that is scanned for input .pdb files. Replace the demo folder with
# your own location, for example:
# data_path = "your_pdb_data_path"
data_path = "test_tmp"

In [9]:
# Checkpoint folder and the weights file expected inside it.
save_path = "checkpoints"
model_filepath = os.path.join(save_path, "model.pt")

In [10]:
# Inference runs on CPU; the checkpoint tensors are mapped onto the same device.
device = pt.device("cpu")

# Instantiate the network from the project configuration, then restore weights.
model = Model(config_model)
# NOTE(review): pt.load unpickles the file, so only load checkpoints that come
# from a trusted source.
model.load_state_dict(pt.load(model_filepath, map_location=device))

# Place the model on the target device and switch it to evaluation mode.
model = model.to(device).eval()

In [11]:
# Collect the input structures found directly inside data_path.
# NOTE: recursive=True only changes the result when the pattern contains "**",
# which here would require data_path itself to include it.
pdb_filepaths = glob(os.path.join(data_path, "*.pdb"), recursive=True)

# Skip prediction files, which the inference cell writes as "<input>_p.pdb".
# Only the file name is tested: matching "_p" against the whole path would drop
# every input as soon as a directory name contains "_p"
# (e.g. "your_pdb_data_path"). Sorting gives a stable processing order.
pdb_filepaths = sorted(
    fp for fp in pdb_filepaths
    if not os.path.basename(fp).endswith("_p.pdb")
)

# with_preprocessing=True is forwarded unchanged to the project dataset class.
dataset = StructuresDataset(pdb_filepaths, with_preprocessing=True)
print(len(dataset))

1


In [12]:
# Score every structure with the model and write the result next to the input
# file as "<input>_p.pdb". Gradients are not needed for inference.
with pt.no_grad():
    for subunits, filepath in tqdm(dataset):
        # Merge the chains so the model sees one structure.
        structure = concatenate_chains(subunits)

        # Encode the structure, its features and its topology
        # (64 is the size argument passed to ext_topology).
        X, M = en_structure(structure)
        q = en_features(structure)[0]
        ids_topk, *_ = ext_topology(X, 64)

        # Collate a batch holding this single sample.
        X, ids_topk, q, M = col_batch([[X, ids_topk, q, M]])

        # Forward pass; M is cast to float before being moved to the device.
        inputs = (X, ids_topk, q, M.float())
        z = model(*(t.to(device) for t in inputs))
        p = pt.sigmoid(z)

        # Store the sigmoid outputs via encode_bfactor and save the structure
        # split back into chains.
        structure = encode_bfactor(structure, p.cpu().numpy())
        output_filepath = os.path.splitext(filepath)[0] + '_p.pdb'
        save_pdb(split_by_chain(structure), output_filepath)

100%|██████████| 1/1 [00:07<00:00,  7.50s/it]
