In [2]:
import os
import numpy as np
import lmdb
import pickle
from functools import lru_cache
from scipy.spatial import distance_matrix
from torch.utils.data import Dataset

class LMDBDataset_cid:
    """Read-only accessor for an LMDB file whose values are pickled records.

    Records are looked up by their key given as a ``str`` (the notebook uses
    PubChem CIDs); the raw ``bytes`` keys are listed once at construction time
    and exposed through ``_keys``.
    """

    def __init__(self, db_path):
        """Validate ``db_path`` and cache the list of keys stored in the file.

        Raises:
            AssertionError: if ``db_path`` is not an existing file.
        """
        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(self.db_path)
        # Use a short-lived environment just to enumerate the keys and close it
        # afterwards, so the handle is not leaked and the file is not held open
        # twice once __getitem__ lazily opens its own environment.
        env = self.connect_db(self.db_path)
        try:
            with env.begin() as txn:
                self._keys = list(txn.cursor().iternext(values=False))
        finally:
            env.close()

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open ``lmdb_path`` read-only.

        Returns the environment when ``save_to_self`` is false; otherwise
        stores it on ``self.env`` and returns ``None``.
        """
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if not save_to_self:
            return env
        else:
            self.env = env

    def __len__(self):
        """Number of keys found in the database at construction time."""
        return len(self._keys)

    @lru_cache(maxsize=16)
    def __getitem__(self, cid):
        """Return the unpickled record stored under the string key ``cid``.

        Raises:
            KeyError: if the database has no entry for ``cid``.
        """
        # Open the long-lived environment on first access only.
        if not hasattr(self, "env"):
            self.connect_db(self.db_path, save_to_self=True)
        # The context manager releases the read transaction after the lookup.
        with self.env.begin() as txn:
            datapoint_pickled = txn.get(cid.encode())
        if datapoint_pickled is None:
            # Previously this fell through to pickle.loads(None) -> TypeError.
            raise KeyError(cid)
        # NOTE: pickle executes arbitrary code on load; only open trusted files.
        return pickle.loads(datapoint_pickled)

class D3Dataset_cid(Dataset):
    """Index-addressable dataset returning the raw fields of one LMDB record.

    Each item is ``(atoms, coordinates, smiles, description,
    enriched_description, cid)`` with no featurisation applied.
    """

    def __init__(self, path, max_atoms=256):
        self.lmdb_dataset = LMDBDataset_cid(path)
        self.max_atoms = max_atoms

        # Defaults matching uni-mol's pretrained weights (not read by this
        # class's __getitem__).
        self.remove_hydrogen = True
        self.remove_polar_hydrogen = False
        self.normalize_coords = True
        self.add_special_token = True
        self.__max_atoms = 512

    def __len__(self):
        """Number of records in the underlying LMDB dataset."""
        return len(self.lmdb_dataset)

    def __getitem__(self, index):
        """Fetch the record at position ``index`` of the stored key list."""
        # Keys are stored as bytes; the LMDB wrapper expects a str CID.
        raw_key = self.lmdb_dataset._keys[index]
        cid = raw_key.decode('utf-8')

        record = self.lmdb_dataset[cid]
        smiles = record['smiles']
        description = record['description']
        enriched_description = record['enriched_description']
        # Independent array copy of the atom symbols.
        atoms = np.array(record['atoms']).copy()
        coordinates = record['coordinates']

        return atoms, coordinates, smiles, description, enriched_description, cid

In [3]:
# Open the 3D PubChem LMDB file and pull one record to feed the encoder demos below.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
path = 'C:/Users/namjo/OneDrive/문서/GitHub/GeomCLIP/3d-pubchem.lmdb'
lmdb_dataset = LMDBDataset_cid(path)
# Keys are stored as bytes; decode because LMDBDataset_cid.__getitem__ re-encodes a str key.
cid = lmdb_dataset._keys[2].decode('utf-8')
data = lmdb_dataset[cid]
# Text and SMILES fields of this record; `txt` and `sml` are consumed by later cells.
txt = data['enriched_description']
sml = data['smiles']

## Text encoders: SciBERT & MolT5

In [4]:
from transformers import BertTokenizer, BertConfig, BertLMHeadModel

def init_tokenizer(bert_name='allenai/scibert_scivocab_uncased'):
    """Load a BERT tokenizer and register "[DEC]" as its BOS token.

    Args:
        bert_name: checkpoint name or path passed to
            ``BertTokenizer.from_pretrained``; defaults to the SciBERT
            uncased checkpoint that was previously hard-coded.

    Returns:
        The tokenizer with the extra special token added. Adding the token
        grows the tokenizer's vocabulary; this function does not resize any
        model's embedding matrix.
    """
    tokenizer = BertTokenizer.from_pretrained(bert_name)
    tokenizer.add_special_tokens({"bos_token": "[DEC]"})
    return tokenizer

# Encode the sample description with SciBERT and keep the [CLS] embedding.
tokenizer = init_tokenizer()
encoder_config = BertConfig.from_pretrained('allenai/scibert_scivocab_uncased')
model = BertLMHeadModel.from_pretrained('allenai/scibert_scivocab_uncased', config=encoder_config)
# Truncate to the model's positional limit so long descriptions do not overflow
# the position-embedding table.
inputs = tokenizer(
    txt,
    return_tensors="pt",
    truncation=True,
    max_length=encoder_config.max_position_embeddings,
)
# BertLMHeadModel returns LM logits (batch, seq, vocab) as outputs[0], so the
# old `outputs[0][:, 0, :]` was a vocabulary-logit vector, not a text embedding.
# Request the hidden states and read the [CLS] position of the last layer.
outputs = model(**inputs, output_hidden_states=True)
cls_token = outputs.hidden_states[-1][:, 0, :]

#############################################################################################################################################

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the MolT5 caption-to-SMILES checkpoint and its tokenizer via from_pretrained.
# Only the model's `.encoder` sub-module is used by molT5_encoder below.
text_encoder = T5ForConditionalGeneration.from_pretrained('laituan245/molt5-large-caption2smiles')
text_tokenizer = T5Tokenizer.from_pretrained("laituan245/molt5-large-caption2smiles", model_max_length=512)

@torch.no_grad()
def molT5_encoder(descriptions, molt5, molt5_tokenizer, description_length, device):
    """Encode text with the encoder half of a T5-style model (no gradients).

    Args:
        descriptions: text (or batch of texts) handed to ``molt5_tokenizer``.
        molt5: model exposing an ``encoder`` sub-module.
        molt5_tokenizer: tokenizer matching ``molt5``.
        description_length: length every sequence is padded/truncated to.
        device: device the tokenized batch is moved to; the model itself is
            not moved here.

    Returns:
        ``(last_hidden_state, attention_mask)`` of the encoder pass.
    """
    batch = molt5_tokenizer(
        descriptions,
        padding='max_length',
        truncation=True,
        max_length=description_length,
        return_tensors="pt",
    )
    batch = batch.to(device)
    encoded = molt5.encoder(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        return_dict=True,
    )
    return encoded.last_hidden_state, batch.attention_mask

# Encode the sample description on CPU, padded/truncated to 256 tokens.
# NOTE(review): the result is named `biot5_embed` although the model loaded above is MolT5.
biot5_embed, pad_mask = molT5_encoder(txt, text_encoder, text_tokenizer, 256, device='cpu')




If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


## SMILES encoder

In [21]:
from utils import molT5_encoder, AE_SMILES_encoder, regexTokenizer
from train_autoencoder import ldmol_autoencoder
import argparse
import torch

# Path of the pretrained autoencoder checkpoint, wrapped in argparse so the
# cell mirrors the training script's interface (an empty argv yields the default).
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
parser = argparse.ArgumentParser()
parser.add_argument("--vae", type=str, default="C:/Users/namjo/Downloads/Pretrain-20250121T212421Z-001/Pretrain/checkpoint_autoencoder.ckpt")  # Choice doesn't affect training
args = parser.parse_args([])

device = 'cpu'

# Build the autoencoder from its encoder/decoder config files.
ae_config = {
        'bert_config_decoder': './config_decoder.json',
        'bert_config_encoder': './config_encoder.json',
        'embed_dim': 256,
    }

# NOTE: rebinds the global `tokenizer` (previously the SciBERT tokenizer).
tokenizer = regexTokenizer(vocab_path='./vocab_bpe_300_sc.txt', max_len=127) #newtkn
ae_model = ldmol_autoencoder(config=ae_config, no_train=True, tokenizer=tokenizer, use_linear=True)

if args.vae:
    print('LOADING PRETRAINED MODEL..', args.vae)
    checkpoint = torch.load(args.vae, map_location='cpu')
    # The weights live under 'model' or 'state_dict'. Catch only the missing-key
    # case: the former bare `except:` also swallowed unrelated errors
    # (including KeyboardInterrupt).
    try:
        state_dict = checkpoint['model']
    except KeyError:
        state_dict = checkpoint['state_dict']
    # strict=False: mismatched keys are reported in `msg` instead of raising.
    msg = ae_model.load_state_dict(state_dict, strict=False)
    print('autoencoder', msg)
# Freeze every parameter; the model is used for inference only.
for param in ae_model.parameters():
    param.requires_grad = False
# Drop the sub-module that AE_SMILES_encoder does not use (it reads text_encoder2).
del ae_model.text_encoder
ae_model = ae_model.to(device)
ae_model.eval()

print(f'AE #parameters: {sum(p.numel() for p in ae_model.parameters())}, #trainable: {sum(p.numel() for p in ae_model.parameters() if p.requires_grad)}')

# Encode the Input SMILES representation 
@torch.no_grad()
def AE_SMILES_encoder(sm, ae_model):
    """Encode SMILES with the autoencoder's ``text_encoder2`` (no gradients).

    Args:
        sm: a single SMILES string or a sequence of SMILES strings. A leading
            "[CLS]" marker is removed from each string that carries one.
        ae_model: object providing ``tokenizer``, ``device``, ``text_encoder2``
            and, optionally, ``encode_prefix`` / ``output_dim``.

    Returns:
        The encoder's last hidden state. When ``ae_model`` has
        ``encode_prefix`` the hidden state is passed through it, and if the
        result is twice ``output_dim`` wide it is treated as (mean, logvar)
        and a sample ``mean + std * noise`` is returned instead.
    """
    cls_prefix = "[CLS]"
    # The old check `sm[0][:5] == "[CLS]"` looked at the first *character* when
    # `sm` was a plain string, crashed on an empty list, and cut 5 characters
    # from every element based on the first one only.
    if isinstance(sm, str):
        if sm.startswith(cls_prefix):
            sm = sm[len(cls_prefix):]
    elif any(s.startswith(cls_prefix) for s in sm):
        sm = [s[len(cls_prefix):] if s.startswith(cls_prefix) else s for s in sm]

    text_input_ids = ae_model.tokenizer(sm).to(ae_model.device)
    # Token id 0 is masked out (treated as padding); everything else is attended.
    text_attention_mask = torch.where(text_input_ids == 0, 0, 1).to(text_input_ids.device)
    if hasattr(ae_model.text_encoder2, 'bert'):
        output = ae_model.text_encoder2.bert(text_input_ids, attention_mask=text_attention_mask, return_dict=True, mode='text').last_hidden_state
    else:
        output = ae_model.text_encoder2(text_input_ids, attention_mask=text_attention_mask, return_dict=True).last_hidden_state

    if hasattr(ae_model, 'encode_prefix'):
        output = ae_model.encode_prefix(output)
        if ae_model.output_dim*2 == output.size(-1):
            # Split into mean / log-variance and draw one reparameterised sample.
            mean, logvar = torch.chunk(output, 2, dim=-1)
            logvar = torch.clamp(logvar, -30.0, 20.0)
            std = torch.exp(0.5 * logvar)
            output = mean + std * torch.randn_like(mean)
    return output

# Encode the sample SMILES taken from the LMDB record (`sml` is data['smiles'], presumably a single string — TODO confirm).
sml_rep = AE_SMILES_encoder(sml, ae_model) # [1, 127, 64] # 127 is length of Tokenizer.


LOADING PRETRAINED MODEL.. C:/Users/namjo/Downloads/Pretrain-20250121T212421Z-001/Pretrain/checkpoint_autoencoder.ckpt


  checkpoint = torch.load(args.vae, map_location='cpu')


autoencoder _IncompatibleKeys(missing_keys=[], unexpected_keys=['text_encoder2.embeddings.position_ids'])
AE #parameters: 127993920, #trainable: 0


In [None]:
import torch
import argparse
import torch.nn as nn

from unicore.data import Dictionary
from model.unimol_simple import SimpleUniMolModel

# Build a namespace of SimpleUniMolModel options: the model registers its own
# arguments on the parser and an empty argv is parsed, so only defaults are used.
# NOTE: this rebinds the global `args` used by the autoencoder cell.
parser = argparse.ArgumentParser()
SimpleUniMolModel.add_args(parser) 
args = parser.parse_args([])

class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalises in float32 (fp16-safe).

    The input is cast up to float32 for the normalisation and the result is
    cast back to the input's dtype. ``mask`` is accepted but ignored.
    """

    def forward(self, x: torch.Tensor, mask=None):
        input_dtype = x.dtype
        normalised = super().forward(x.to(torch.float32))
        return normalised.to(input_dtype)

def init_unimol_mol_encoder(args, dict_path='unimol_dict_mol.txt', ckpt_path='mol_pre_no_h_220816.pt'):
    """Build a SimpleUniMolModel and load pretrained weights into it.

    Args:
        args: namespace of model options passed to ``SimpleUniMolModel``.
        dict_path: path of the atom dictionary file; defaults to the
            previously hard-coded relative path.
        ckpt_path: path of the checkpoint whose ``'model'`` entry holds the
            state dict; defaults to the previously hard-coded relative path.

    Returns:
        ``(unimol_model, ln_graph, dictionary)`` where ``ln_graph`` is a
        freshly initialised LayerNorm sized to ``unimol_model.num_features``.
    """
    import warnings

    dictionary = Dictionary.load(dict_path)
    # "[MASK]" is added as a special symbol after loading the file.
    dictionary.add_symbol("[MASK]", is_special=True)
    unimol_model = SimpleUniMolModel(args, dictionary)

    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))['model']
    # strict=False tolerates key mismatches; surface them instead of silently
    # discarding the report so a wrong checkpoint is noticed.
    missing_keys, unexpected_keys = unimol_model.load_state_dict(ckpt, strict=False)
    if missing_keys or unexpected_keys:
        warnings.warn(
            "Uni-Mol checkpoint loaded with key mismatches: "
            "missing={}, unexpected={}".format(list(missing_keys), list(unexpected_keys))
        )

    ln_graph = LayerNorm(unimol_model.num_features)
    return unimol_model, ln_graph, dictionary

# conf_encoder takes [src_token, src_distance, src_edge_type] as input
# Uses the default relative paths inside init_unimol_mol_encoder, so the dictionary
# and checkpoint files must be in the current working directory.
conf_encoder, ln_conf, dictionary_mol = init_unimol_mol_encoder(args)

In [13]:
import random
from scipy.spatial import distance_matrix

# Featurise one molecule by hand (mirrors D3Dataset_cid.__getitem__).
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
dictionary = Dictionary.load('C:/Users/namjo/OneDrive/문서/GitHub/GeomCLIP/unimol_dict_mol.txt')
bos = dictionary.bos()
eos = dictionary.eos()
num_types = len(dictionary)

path = 'C:/Users/namjo/OneDrive/문서/GitHub/GeomCLIP/3d-pubchem.lmdb'

lmdb_dataset = LMDBDataset_cid(path)
cid = lmdb_dataset._keys[5].decode('utf-8')
data = lmdb_dataset[cid]
smils = data['smiles']
atoms_orig = np.array(data['atoms'])
atoms = atoms_orig.copy()
coordinate_set = data['coordinates']
txt = data['enriched_description']
# Pick one conformer at random from the stored set.
coordinates = random.sample(coordinate_set, 1)[0].astype(np.float32)

# Drop hydrogens, but only when at least one heavy atom remains; otherwise the
# arrays would become empty and the mean below would be NaN.
mask_hydrogen = atoms != "H" # [True, True, True, True, .... , False]
if mask_hydrogen.any():
    atoms = atoms[mask_hydrogen]
    coordinates = coordinates[mask_hydrogen]
# Centre the conformer on its mean position.
coordinates = coordinates - coordinates.mean(axis=0)
atom_vec = torch.from_numpy(dictionary.vec_index(atoms)).long()

# Add BOS/EOS tokens with zero coordinates. The padding rows take the
# coordinates' dtype: a plain np.zeros((1, 3)) is float64 and silently upcast
# the float32 coordinates on concatenation.
atom_vec = torch.cat([torch.LongTensor([bos]), atom_vec, torch.LongTensor([eos])])
special_row = np.zeros((1, 3), dtype=coordinates.dtype)
coordinates = np.concatenate([special_row, coordinates, special_row], axis=0)

# Edge type = pairwise combination of the two atom-type indices.
edge_type = atom_vec.view(-1, 1) * num_types + atom_vec.view(1, -1)
dist = distance_matrix(coordinates, coordinates).astype(np.float32)

coordinates, dist = torch.from_numpy(coordinates), torch.from_numpy(dist)

In [None]:
# Run the Uni-Mol encoder on the single molecule built above; unsqueeze(0) adds a batch axis.
# NOTE(review): the meaning of the two returned values is defined by SimpleUniMolModel — confirm there.
a, b = conf_encoder(atom_vec.unsqueeze(0), dist.unsqueeze(0), edge_type.unsqueeze(0))
from unicore.data import data_utils

def collate_tokens_coords(
    values,
    pad_idx,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Pad per-molecule coordinate tensors into one dense batch tensor.

    Args:
        values: sequence of tensors shaped ``(n_i, 3)``; any trailing shape
            shared by all items is accepted.
        pad_idx: fill value for the padded positions.
        left_pad: pad at the start of the first axis instead of the end.
        pad_to_length: optional minimum padded length.
        pad_to_multiple: round the padded length up to a multiple of this.

    Returns:
        Tensor shaped ``(len(values), size, *trailing)`` with the dtype and
        device of ``values[0]``.
    """
    size = max(v.size(0) for v in values)
    if pad_to_length is not None:
        size = max(size, pad_to_length)
    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
        # Integer ceiling to the next multiple (no float arithmetic).
        size = (size // pad_to_multiple + 1) * pad_to_multiple
    # Take the trailing shape from the data instead of hard-coding 3.
    trailing_shape = tuple(values[0].shape[1:])
    res = values[0].new_full((len(values), size) + trailing_shape, pad_idx)

    for i, v in enumerate(values):
        n = v.size(0)
        dst = res[i, size - n:] if left_pad else res[i, :n]
        assert dst.numel() == v.numel()
        dst.copy_(v)
    return res

# Pad the single molecule's tensors so the atom axis is a multiple of 8;
# unsqueeze(0) turns each tensor into a batch of one.
padded_coordinates = collate_tokens_coords(coordinates.unsqueeze(0), 0, left_pad=False, pad_to_multiple=8) 
padded_edge_type = data_utils.collate_tokens_2d(edge_type.unsqueeze(0), 0, left_pad=False, pad_to_multiple=8)
padded_dist = data_utils.collate_tokens_2d(dist.unsqueeze(0), 0, left_pad=False, pad_to_multiple=8)
# Display the padded coordinate shape as the cell output.
padded_coordinates.shape

In [None]:
# Re-open the LMDB file and display the first stored key as a string.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
path = 'C:/Users/namjo/OneDrive/문서/GitHub/GeomCLIP/3d-pubchem.lmdb'
lmdb_dataset = LMDBDataset_cid(path)
cid = lmdb_dataset._keys[0].decode('utf-8')
cid

In [48]:
from data_provider.unimol_dataset import D3Dataset, D3Dataset_Pro
# Build the project's D3Dataset over the same LMDB file and fetch one featurised molecule.
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
target_path = 'C:/Users/namjo/OneDrive/문서/GitHub/GeomCLIP/3d-pubchem.lmdb'
dictionary = Dictionary.load('C:/Users/namjo/OneDrive/문서/GitHub/GeomCLIP/unimol_dict_mol.txt')
d3_dataset = D3Dataset(target_path, dictionary, max_atoms=256)  # Number of molecules : 301658
# `lmdb_dataset` is not created in this cell; it must already exist from an earlier cell.
cid = lmdb_dataset._keys[1234].decode('utf-8')
# NOTE(review): D3Dataset comes from data_provider.unimol_dataset; that it accepts a CID
# string and returns exactly these five items cannot be confirmed from this notebook.
atom_vec, coordinates, edge_type, dist, smiles = d3_dataset[cid]

def collate_tokens_coords(
    values,
    pad_idx,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Pad per-molecule coordinate tensors into one dense batch tensor.

    Args:
        values: sequence of tensors shaped ``(n_i, 3)``; any trailing shape
            shared by all items is accepted.
        pad_idx: fill value for the padded positions.
        left_pad: pad at the start of the first axis instead of the end.
        pad_to_length: optional minimum padded length.
        pad_to_multiple: round the padded length up to a multiple of this.

    Returns:
        Tensor shaped ``(len(values), size, *trailing)`` with the dtype and
        device of ``values[0]``.
    """
    size = max(v.size(0) for v in values)
    if pad_to_length is not None:
        size = max(size, pad_to_length)
    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
        # Integer ceiling to the next multiple (no float arithmetic).
        size = (size // pad_to_multiple + 1) * pad_to_multiple
    # Take the trailing shape from the data instead of hard-coding 3.
    trailing_shape = tuple(values[0].shape[1:])
    res = values[0].new_full((len(values), size) + trailing_shape, pad_idx)

    for i, v in enumerate(values):
        n = v.size(0)
        dst = res[i, size - n:] if left_pad else res[i, :n]
        assert dst.numel() == v.numel()
        dst.copy_(v)
    return res

from unicore.data import data_utils
# Pad each tensor of the single sample to a multiple of 8 along the atom axis;
# unsqueeze(0) turns each tensor into a batch of one.
padded_atom_vec = data_utils.collate_tokens(atom_vec.unsqueeze(0), pad_idx=0, left_pad=False, pad_to_multiple=8) # shape = [batch_size, max_atoms]
padded_edge_type = data_utils.collate_tokens_2d(edge_type.unsqueeze(0), 0, left_pad=False, pad_to_multiple=8) # shape = [batch_size, max_atoms, max_atoms]
padded_dist = data_utils.collate_tokens_2d(dist.unsqueeze(0), 0, left_pad=False, pad_to_multiple=8) # shape = [batch_size, max_atoms, max_atoms]
padded_coordinates = collate_tokens_coords(coordinates.unsqueeze(0), 0, left_pad=False, pad_to_multiple=8) # shape = [batch_size, max_atoms, 3]


In [None]:
# Toy check of collate_tokens_coords: three "molecules" with 2, 1 and 3 atoms.
t1 = torch.tensor([[1, 1, 1], [2, 2, 2]])             # shape (2, 3)
t2 = torch.tensor([[3, 3, 3]])                        # shape (1, 3)
t3 = torch.tensor([[4, 4, 4], [5, 5, 5], [6, 6, 6]])  # shape (3, 3)

values = [t1, t2, t3]

# Right-pad with zeros and round the atom axis up to a multiple of 8.
res_multiple = collate_tokens_coords(
    values,
    0,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=8,
)
print("Result shape (multiple of 8):", res_multiple.shape)

In [None]:
import random
import torch

# Step 1: Collect data for a random batch of CIDs
# `lmdb_dataset` and `d3_dataset` are not created in this cell; they must exist from earlier cells.
random_cids = random.sample(lmdb_dataset._keys, 56)  # Choose 56 random keys
random_cids = [cid.decode("utf-8") for cid in random_cids]  # Decode the keys

batch = [d3_dataset[cid] for cid in random_cids]  # Retrieve tuples
# Expects every item to be a 5-tuple; an item with a different length makes this unpacking fail.
atom_vecs, coordinates_list, edge_types, dists, smiles_list = zip(*batch)  # Unpack tuples

# Step 2: Use collate_tokens_coords to unify coordinates
pad_idx = 0  # Padding value for empty coordinates
unified_coordinates = collate_tokens_coords(
    values=coordinates_list,  # Tuple of per-molecule coordinate tensors
    pad_idx=pad_idx,          # Padding value
    left_pad=False,           # Right padding
    pad_to_length=None,       # No minimum length specified
    pad_to_multiple=512         # Round the padded length up to a multiple of 512
)

# Step 3: Output the results
print("Unified Coordinates Shape:", unified_coordinates.shape)


In [48]:
class D3Dataset_cid(Dataset):
    """Dataset that featurises one 3D conformer per molecule for Uni-Mol.

    Items can be requested either by CID string or by integer position in the
    LMDB key list. Each item is the tuple ``(atom_vec, coordinates, edge_type,
    dist, smiles, description, enriched_description)``.
    """

    def __init__(self, path, dictionary, max_atoms=256):
        """
        Args:
            path: LMDB file opened through ``LMDBDataset_cid``.
            dictionary: atom dictionary providing ``bos()``, ``eos()``,
                ``vec_index()`` and ``len()``.
            max_atoms: molecules with more atoms are randomly cropped to this size.
        """
        self.dictionary = dictionary
        self.num_types = len(dictionary)
        self.bos = dictionary.bos()
        self.eos = dictionary.eos()

        self.lmdb_dataset = LMDBDataset_cid(path)

        self.max_atoms = max_atoms
        ## the following is the default setting of uni-mol's pretrained weights
        self.remove_hydrogen = True
        self.remove_polar_hydrogen = False
        self.normalize_coords = True
        self.add_special_token = True
        self.__max_atoms = 512

    def __len__(self):
        """Number of records in the underlying LMDB dataset."""
        return len(self.lmdb_dataset)

    def __getitem__(self, cid):
        """Featurise the molecule identified by ``cid``.

        Args:
            cid: CID string, or an integer position into the LMDB key list
                (so index-based samplers work with ``__len__``).

        Returns:
            ``(atom_vec, coordinates, edge_type, dist, smiles, description,
            enriched_description)``; ``coordinates`` and ``dist`` are float32
            tensors, ``atom_vec`` and ``edge_type`` are long tensors.
        """
        # Integer index -> CID, consistent with the index-based variant of this class.
        if isinstance(cid, (int, np.integer)):
            cid = self.lmdb_dataset._keys[cid].decode('utf-8')

        data = self.lmdb_dataset[cid]
        smiles = data['smiles']
        description = data['description']
        enriched_description = data['enriched_description']

        ## deal with 3d coordinates
        atoms = np.array(data['atoms'])
        coordinate_set = data['coordinates']
        # Pick one conformer at random from the stored set.
        coordinates = random.sample(coordinate_set, 1)[0].astype(np.float32)
        assert len(atoms) == len(coordinates) and len(atoms) > 0
        assert coordinates.shape[1] == 3

        ## deal with the hydrogen
        if self.remove_hydrogen:
            mask_hydrogen = atoms != "H"
            # Keep the molecule untouched if it consists of hydrogens only.
            if mask_hydrogen.any():
                atoms = atoms[mask_hydrogen]
                coordinates = coordinates[mask_hydrogen]

        if not self.remove_hydrogen and self.remove_polar_hydrogen:
            # Strip only the trailing run of hydrogens.
            end_idx = 0
            for i, atom in enumerate(atoms[::-1]):
                if atom != "H":
                    break
                end_idx = i + 1
            if end_idx != 0:
                atoms = atoms[:-end_idx]
                coordinates = coordinates[:-end_idx]

        ## deal with cropping
        if len(atoms) > self.max_atoms:
            index = np.random.permutation(len(atoms))[:self.max_atoms]
            atoms = atoms[index]
            coordinates = coordinates[index]

        assert 0 < len(atoms) <= self.__max_atoms

        atom_vec = torch.from_numpy(self.dictionary.vec_index(atoms)).long()

        if self.normalize_coords:
            # Centre the conformer on its mean position.
            coordinates = coordinates - coordinates.mean(axis=0)

        if self.add_special_token:
            atom_vec = torch.cat([torch.LongTensor([self.bos]), atom_vec, torch.LongTensor([self.eos])])
            # Zero rows for BOS/EOS in the coordinates' own dtype: a plain
            # np.zeros((1, 3)) is float64 and upcast the float32 coordinates.
            special_row = np.zeros((1, 3), dtype=coordinates.dtype)
            coordinates = np.concatenate([special_row, coordinates, special_row], axis=0)

        ## obtain edge types; which is defined as the combination of two atom types
        edge_type = atom_vec.view(-1, 1) * self.num_types + atom_vec.view(1, -1)
        dist = distance_matrix(coordinates, coordinates).astype(np.float32)
        coordinates, dist = torch.from_numpy(coordinates), torch.from_numpy(dist)
        return atom_vec, coordinates, edge_type, dist, smiles, description, enriched_description