In [12]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
# Re-running this cell in a live kernel only prints an "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
import sys
from pathlib import Path

# Add ".." (resolved against the kernel's working directory) to the import
# search path. Presumably this is the repository root that holds the `src`
# package imported below — confirm against the project layout.
sys.path.append("..")  # noqa: E402

import torch
from torch.utils.data import DataLoader

from src.dataset.profsa import (
    ProFSADataset,
    ProFSADataModule,
    MolDataset,
    NextMolDataset,
    process_mol,
)
from src.dataset.components.lmdb import UniMolLMDBDataset

In [14]:
# Locations of the LMDB files used throughout this notebook.
lmdb_dir = Path("/data/prot_frag/train_ligand_pocket")
lmdb_path = lmdb_dir.joinpath("valid.lmdb")

# Open the same valid.lmdb file through three different dataset classes so the
# cells below can compare what each one returns for the same index.
lmdb_dataset = UniMolLMDBDataset(lmdb_path)
dataset = ProFSADataset(data_file="valid.lmdb", data_dir=lmdb_dir)
mol_dataset = MolDataset(lmdb_path=lmdb_path)

# A second molecule dataset backed by a different LMDB file.
mol2_dataset = NextMolDataset(lmdb_path="/data/screening/smilesdb/smilesdb.lmdb")

In [6]:
# Take the first sample from the ProFSA dataset and list the shape and dtype
# of every tensor-valued field it contains.
res1 = dataset[0]
for key, val in res1.items():
    # isinstance rather than an exact type comparison, so tensor subclasses
    # are reported as well.
    if isinstance(val, torch.Tensor):
        print(key, val.shape, val.dtype)

['C' 'C' 'O' 'N' 'C' 'C' 'O' 'C' 'C' 'C' 'N' 'C' 'C' 'O' 'C' 'C' 'C' 'C'
 'N' 'C' 'C' 'O' 'N' 'C' 'C' 'O' 'C' 'C' 'C' 'C' 'N' 'C' 'C' 'O' 'N']
net_input.mol_src_tokens torch.Size([37]) torch.int64
net_input.mol_src_distance torch.Size([37, 37]) torch.float32
net_input.mol_src_edge_type torch.Size([37, 37]) torch.int64
net_input.pocket_src_tokens torch.Size([215]) torch.int64
net_input.pocket_src_distance torch.Size([215, 215]) torch.float32
net_input.pocket_src_edge_type torch.Size([215, 215]) torch.int64
net_input.pocket_src_coord torch.Size([215, 3]) torch.float32


In [5]:
# Fetch the first raw LMDB record once and reuse it for both fields, rather
# than indexing the dataset twice.
record = lmdb_dataset[0]
atoms = record["lig_atoms_real"]
coordinates = record["lig_coord_real"]
# Show the atom symbols as one comma-separated line.
print(",".join(atoms))

C,C,O,N,C,C,O,C,C,C,N,C,C,O,C,C,C,C,N,C,C,O,N,C,C,O,C,C,C,C,N,C,C,O,N


In [6]:
# Run the raw atoms/coordinates through process_mol directly and list the
# shape and dtype of every tensor it returns.
res2 = process_mol(atoms, coordinates)
for key, val in res2.items():
    # isinstance rather than an exact type comparison, so tensor subclasses
    # are reported as well.
    if isinstance(val, torch.Tensor):
        print(key, val.shape, val.dtype)

atoms torch.Size([37]) torch.int64
distance torch.Size([37, 37]) torch.float32
edge_type torch.Size([37, 37]) torch.int64


In [7]:
# Compare the dataset sample against the process_mol result field by field.
# Each line prints tensor(False) when the two tensors agree element-wise.
for full_key, mol_key in (
    ("net_input.mol_src_tokens", "atoms"),
    ("net_input.mol_src_distance", "distance"),
    ("net_input.mol_src_edge_type", "edge_type"),
):
    print(torch.any(res1[full_key] != res2[mol_key]))

tensor(False)
tensor(False)
tensor(False)


In [8]:
# Take the first sample from the molecule-only dataset and list the shape and
# dtype of every tensor-valued field it contains.
res3 = mol_dataset[0]
for key, val in res3.items():
    # isinstance rather than an exact type comparison, so tensor subclasses
    # are reported as well.
    if isinstance(val, torch.Tensor):
        print(key, val.shape, val.dtype)

mol_src_tokens torch.Size([37]) torch.int64
mol_src_distance torch.Size([37, 37]) torch.float32
mol_src_edge_type torch.Size([37, 37]) torch.int64


In [9]:
# Compare the ProFSA sample against the MolDataset sample field by field.
# Each line prints tensor(False) when the two tensors agree element-wise.
for full_key, mol_key in (
    ("net_input.mol_src_tokens", "mol_src_tokens"),
    ("net_input.mol_src_distance", "mol_src_distance"),
    ("net_input.mol_src_edge_type", "mol_src_edge_type"),
):
    print(torch.any(res1[full_key] != res3[mol_key]))

tensor(False)
tensor(False)
tensor(False)


In [18]:
# Build one unshuffled loader per dataset with the same batch size, so the
# first batch from each covers the same leading indices.
batch_size = 4
data_loader = DataLoader(
    dataset,
    collate_fn=dataset.dataset.collater,  # collater lives on the wrapped dataset
    batch_size=batch_size,
    shuffle=False,
)
mol_data_loader = DataLoader(
    mol_dataset,
    collate_fn=mol_dataset.collater,
    batch_size=batch_size,
    shuffle=False,
)

# Pull only the first batch from each loader.
batch = next(iter(data_loader))
mol_batch = next(iter(mol_data_loader))

# List shape and dtype of every tensor in each batch. isinstance is used
# instead of an exact type comparison so tensor subclasses are reported too.
# NOTE(review): the recorded output shows pocket_src_coord padded to 258 while
# the other pocket tensors are padded to 264 — presumably the collater pads
# them differently; confirm this is intended.
for key, val in batch["net_input"].items():
    if isinstance(val, torch.Tensor):
        print(key, val.shape, val.dtype)
print()
for key, val in mol_batch.items():
    if isinstance(val, torch.Tensor):
        print(key, val.shape, val.dtype)

# Field-by-field comparison of the molecule tensors in the two batches.
# Each line prints tensor(False) when the two tensors agree element-wise.
print()
for name in ("mol_src_tokens", "mol_src_distance", "mol_src_edge_type"):
    print(torch.any(batch["net_input"][name] != mol_batch[name]))

mol_src_tokens torch.Size([4, 48]) torch.int64
mol_src_distance torch.Size([4, 48, 48]) torch.float32
mol_src_edge_type torch.Size([4, 48, 48]) torch.int64
pocket_src_tokens torch.Size([4, 264]) torch.int64
pocket_src_distance torch.Size([4, 264, 264]) torch.float32
pocket_src_edge_type torch.Size([4, 264, 264]) torch.int64
pocket_src_coord torch.Size([4, 258, 3]) torch.float32
mol_len torch.Size([4]) torch.int64
pocket_len torch.Size([4]) torch.int64

mol_src_tokens torch.Size([4, 48]) torch.int64
mol_src_distance torch.Size([4, 48, 48]) torch.float32
mol_src_edge_type torch.Size([4, 48, 48]) torch.int64

tensor(False)
tensor(False)
tensor(False)


In [11]:
# Fetch the first item from the second molecule dataset and display which
# fields it provides (the bare last expression is rendered as the cell output).
data = mol2_dataset[0]
data.keys()

dict_keys(['mol_src_tokens', 'mol_src_distance', 'mol_src_edge_type'])

In [None]:
# Unshuffled loader over the second molecule dataset, batched with its own
# collater. `batch_size` is the value set in the earlier DataLoader cell.
mol2_data_loader = DataLoader(
    mol2_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=mol2_dataset.collater,
)