In [1]:
# Re-import edited project modules (models/, utils/, datasets/) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
# Run from the repository root so relative paths such as 'configs/sampling.yml'
# and 'examples/...' resolve. Guarded so re-running this cell does not keep
# climbing one directory higher each time (a bare `cd ..` is not idempotent).
import os

if not os.path.isfile('configs/sampling.yml'):
    os.chdir('..')
print(os.getcwd())

/home/oem/8-project/crypogen/pocketgen


In [3]:
import pickle
import torch
from datasets import get_dataset
import utils.misc as misc
import utils.transforms as trans
from torch_geometric.transforms import Compose
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from models.molopt_score_model import ScorePosNet3D
from scripts.likelihood_est_diffusion import data_likelihood_estimation

In [5]:
# Sampling config; below it supplies the model checkpoint path
# (sampling_config.model.checkpoint). Keep the path and the loaded config under
# separate names instead of rebinding one name to two different kinds of value.
SAMPLING_CONFIG_PATH = 'configs/sampling.yml'
sampling_config = misc.load_config(SAMPLING_CONFIG_PATH)

In [6]:
# Load checkpoint.
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE: torch.load unpickles arbitrary Python objects (the checkpoint stores a
# config object under 'config'), so only load checkpoints from trusted sources.
# NOTE(review): newer PyTorch releases default to weights_only=True, which may
# refuse to load that config object — pass weights_only=False if that happens.
ckpt = torch.load(sampling_config.model.checkpoint, map_location=device)

# Transforms: featurize protein atoms, ligand atoms (using the atom mode the
# checkpoint was trained with) and ligand bonds.
protein_featurizer = trans.FeaturizeProteinAtom()
ligand_atom_mode = ckpt['config'].data.transform.ligand_atom_mode
ligand_featurizer = trans.FeaturizeLigandAtom(ligand_atom_mode)
transform = Compose([
    protein_featurizer,
    ligand_featurizer,
    trans.FeaturizeLigandBond(),
])

In [7]:
# Load model: build the network with feature sizes matching the featurizers,
# then restore the trained weights.
model = ScorePosNet3D(
    ckpt['config'].model,
    protein_atom_feature_dim=protein_featurizer.feature_dim,
    ligand_atom_feature_dim=ligand_featurizer.feature_dim,
).to(device)
model.load_state_dict(ckpt['model'], strict=True)
# Inference only: put layers such as dropout/batch-norm (if the model has any)
# into evaluation mode so the extracted embeddings are deterministic.
model.eval()
print(f'Successfully load the model! {sampling_config.model.checkpoint}')

Successfully load the model! ./logs_diffusion/training_2024_06_12__15_17_02/checkpoints/14100.pt


In [8]:
from utils.data import PDBProtein, parse_sdf_file
from datasets.pl_data import ProteinLigandData, torchify_dict
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch

In [9]:
def convert_data(pdb_path, ligand_path, transform, radius=10, pocket=False):
    """Build a (optionally featurized) ``ProteinLigandData`` sample from PDB + SDF files.

    Args:
        pdb_path: Path to the protein PDB file. When ``pocket`` is True the file
            is used in full as an already-extracted pocket.
        ligand_path: Path to the ligand SDF file, parsed with ``parse_sdf_file``.
        transform: Callable applied to the assembled sample (e.g. the featurizer
            ``Compose``), or ``None`` to return the raw sample.
        radius: Cutoff passed to ``PDBProtein.query_residues_ligand`` to pick
            residues around the ligand (presumably in Angstrom — confirm).
            Only used when ``pocket`` is False.
        pocket: If True, skip pocket extraction and use every atom of ``pdb_path``.

    Returns:
        The sample, after ``transform`` if one is given, with
        ``protein_filename`` / ``ligand_filename`` set to the input paths.

    Raises:
        ValueError: If the selected pocket contains no protein atoms.
    """
    ligand_dict = parse_sdf_file(ligand_path)
    if pocket:
        pocket_dict = PDBProtein(pdb_path).to_dict_atom()
    else:
        # Keep only residues near the ligand and re-parse them as a standalone pocket.
        protein = PDBProtein(pdb_path)
        pocket_residues = protein.query_residues_ligand(ligand_dict, radius)
        pdb_block_pocket = protein.residues_to_pdb_block(pocket_residues)
        pocket_dict = PDBProtein(pdb_block_pocket).to_dict_atom()

    data = ProteinLigandData.from_protein_ligand_dicts(
        protein_dict=torchify_dict(pocket_dict),
        ligand_dict=torchify_dict(ligand_dict),
    )
    data.protein_filename = pdb_path
    data.ligand_filename = ligand_path

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if data.protein_pos.size(0) == 0:
        raise ValueError(
            f'No protein atoms selected for {pdb_path!r} '
            f'(ligand {ligand_path!r}, radius={radius}, pocket={pocket}).'
        )
    if transform is not None:
        data = transform(data)
    return data

In [10]:
class InferenceDataset(Dataset):
    """Map-style dataset over an in-memory list of already-built samples.

    Items are served exactly as stored. Nothing is loaded or transformed here,
    so any featurization must happen before the list is passed in.
    """

    def __init__(self, data_list):
        super().__init__()
        # Stored by reference; the caller's list is not copied.
        self.data_list = data_list

    def __len__(self):
        n_samples = len(self.data_list)
        return n_samples

    def __getitem__(self, idx):
        return self.data_list[idx]

In [11]:
# Example complex used for inference (paths relative to the working directory
# set at the top of the notebook).
PROTEIN_PATH = 'examples/3ug2_protein.pdb'
LIGAND_PATH = 'examples/3ug2_ligand.sdf'

test_data = convert_data(PROTEIN_PATH, LIGAND_PATH, transform)
test_set = InferenceDataset([test_data])
# follow_batch adds `protein_element_batch` / `ligand_element_batch` (per-atom
# graph indices), which are passed to the model below.
test_loader = DataLoader(
    test_set,
    batch_size=1,
    shuffle=False,
    follow_batch=['protein_element', 'ligand_element'],
)

In [12]:
# The dataset holds a single complex, so the first batch is the whole dataset.
first_batch = next(iter(test_loader))
batch = first_batch.to(device)

In [13]:
# Inference only: disable autograd so no computation graph is built and the
# returned tensors can be converted to NumPy downstream.
with torch.no_grad():
    preds = model.fetch_embedding(
        protein_pos=batch.protein_pos,
        protein_v=batch.protein_atom_feature.float(),
        batch_protein=batch.protein_element_batch,
        ligand_pos=batch.ligand_pos,
        ligand_v=batch.ligand_atom_feature_full,
        batch_ligand=batch.ligand_element_batch,
    )

In [14]:
# Load the linear regression head (presumably maps pooled ligand embeddings to pK).
# NOTE(review): in the saved run this path was missing (FileNotFoundError) while
# later cells still used `lmodel`, i.e. it came from out-of-order execution.
# Confirm the correct location of the pickled model.
LINEAR_MODEL_PATH = 'pretrained_models/pk_reg_para.pkl'
try:
    with open(LINEAR_MODEL_PATH, 'rb') as f:
        # pickle.load can execute arbitrary code: only load files you trust.
        lmodel = pickle.load(f)
except FileNotFoundError as err:
    raise FileNotFoundError(
        f'Linear pK regressor not found at {LINEAR_MODEL_PATH!r}; '
        'place the pickled model there or update LINEAR_MODEL_PATH.'
    ) from err

FileNotFoundError: [Errno 2] No such file or directory: 'pretrained_models/pk_reg_para.pkl'

In [18]:
# Mean-pool the rows of the ligand embedding into one vector, kept 2-D (1, D)
# because the regressor's predict() takes a batch of samples.
# NOTE(review): rows are presumably ligand atoms; with batch_size > 1 this would
# average across different ligands — use a per-graph mean over
# batch.ligand_element_batch in that case.
# .detach() keeps .numpy() from failing if the tensor still tracks gradients.
final_ligand_h = preds['final_ligand_h'].detach().cpu().numpy().mean(axis=0, keepdims=True)

In [19]:
# Regress the pooled embedding with the linear head. Despite the name, the
# value is presumably a pK (-log10 affinity), not an acid pKa.
predict_pk = lmodel.predict
pka = predict_pk(final_ligand_h)

In [20]:
# Undo the log scale: affinity = 10 ** (-pK). Units are presumably molar —
# confirm against the labels the regressor was trained on.
neg_pk = np.negative(pka)
affinity = np.power(10, neg_pk)

In [21]:
# Predicted affinity, one entry per row of final_ligand_h (a single complex here).
affinity

array([1.0119374e-09], dtype=float32)