In [1]:
from model.blip2_llama_inference import Blip2Llama
from model.unimol import SimpleUniMolModel
from rdkit import Chem
from rdkit.Chem import AllChem
from unicore.data import Dictionary
import numpy as np
import torch
from scipy.spatial import distance_matrix

  from .autonotebook import tqdm as notebook_tqdm


[2024-02-25 17:09:09,062] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# Prefer bfloat16 when a CUDA device with native bf16 support is present,
# otherwise fall back to float16. The `is_available()` guard matters on
# CPU-only machines: on older torch releases `torch.cuda.is_bf16_supported()`
# queries the current CUDA device and raises when there is none.
tensor_type = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

In [3]:
import argparse


def get_args(argv=None):
    """Build the argument namespace used to construct ``Blip2Llama`` for inference.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) parses ``sys.argv``, as the
        original cell did. Inside Jupyter that holds the kernel's own flags,
        which ``parse_known_args`` tolerates and discards.

    Returns
    -------
    argparse.Namespace
        The recognised arguments; unknown ones are dropped.
    """
    parser = argparse.ArgumentParser()
    ### models
    parser.add_argument('--bert_name', type=str, default='all_checkpoints/scibert_scivocab_uncased')
    parser.add_argument('--llm_model', type=str, default='all_checkpoints/llama-2-7b-hf')

    ### flash attention
    # Off unless the flag is passed. The previous `store_false` combined with
    # default=False made it impossible to ever turn this option on.
    parser.add_argument('--enable_flash', action='store_true', default=False)

    ### lora settings
    parser.add_argument('--lora_r', type=int, default=8)
    parser.add_argument('--lora_alpha', type=int, default=32)
    # A dropout rate is fractional; with type=int, "--lora_dropout 0.05" would fail to parse.
    parser.add_argument('--lora_dropout', type=float, default=0.1)
    parser.add_argument('--lora_path', type=str, default='all_checkpoints/generalist/generalist.ckpt')

    ### q-former settings
    parser.add_argument('--cross_attention_freq', type=int, default=2)
    parser.add_argument('--num_query_token', type=int, default=8)

    # UniMol encoder options (these show up as the `unimol_*` fields of the namespace).
    parser = SimpleUniMolModel.add_args(parser)

    args, _unknown = parser.parse_known_args(argv)
    return args


args = get_args()
args

Namespace(bert_name='scibert', cross_attention_freq=2, enable_flash=False, llm_model='all_checkpoints/llama-2-7b-hf', lora_alpha=32, lora_dropout=0.1, lora_path='all_checkpoints/generalist/generalist.ckpt', lora_r=8, num_query_token=8, unimol_activation_dropout=0.0, unimol_activation_fn='gelu', unimol_attention_dropout=0.1, unimol_delta_pair_repr_norm_loss=-1.0, unimol_dropout=0.1, unimol_emb_dropout=0.1, unimol_encoder_attention_heads=64, unimol_encoder_embed_dim=512, unimol_encoder_ffn_embed_dim=2048, unimol_encoder_layers=15, unimol_max_atoms=256, unimol_max_seq_len=512)

In [4]:
# Build the model, cast it to `tensor_type`, then place it on the CPU.
# NOTE(review): the device is fixed to CPU while `tensor_type` may be float16;
# half-precision op coverage on CPU varies across torch versions — confirm.
device = torch.device("cpu")
model = Blip2Llama(args).to(dtype=tensor_type)
model.to(device=device)
# The LLM tokenizer is owned by the model.
tokenizer = model.llm_tokenizer

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.75s/it]


loaded model from all_checkpoints/generalist/generalist.ckpt


In [5]:
def smiles2graph(smiles):
    """Convert a SMILES string into a UniMol-style 3D graph.

    A conformer is embedded with RDKit on the hydrogen-augmented molecule.
    MMFF optimization is attempted, and a failure there is ignored. Hydrogens
    are removed before the graph tensors are built.

    Parameters
    ----------
    smiles : str
        Input molecule.

    Returns
    -------
    tuple or None
        ``(atom_vec, dist, edge_type, smiles)``, where:

        - ``atom_vec`` is a 1-D LongTensor of dictionary indices (from
          ``vec_index``) wrapped in BOS/EOS.
        - ``dist`` is the float32 pairwise distance matrix over those
          positions. The BOS/EOS entries sit at the origin; the heavy atoms
          are centred on it.
        - ``edge_type`` encodes each pair as ``type_i * len(dictionary) + type_j``.

        Returns ``None`` if the molecule consists only of hydrogens.

    Raises
    ------
    ValueError
        If ``smiles`` cannot be parsed, or no 3D conformer can be embedded.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; report it
        # here instead of letting AddHs fail with an unrelated error.
        raise ValueError("Invalid SMILES: {}".format(smiles))
    mol = AllChem.AddHs(mol)
    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]
    if (np.asarray(atoms) == 'H').all():
        return None

    if AllChem.EmbedMolecule(mol) == 0:
        conf_mol = mol
    else:
        # Embedding with explicit hydrogens failed: embed the heavy-atom graph
        # with many more attempts, then add hydrogens with coordinates.
        # NOTE(review): relies on AddHs yielding the same atom order as for
        # `mol`; the assert below only checks the atom count.
        conf_mol = Chem.MolFromSmiles(smiles)
        if AllChem.EmbedMolecule(conf_mol, maxAttempts=5000) == -1:
            raise ValueError("Could not embed a 3D conformer for {}".format(smiles))
        conf_mol = AllChem.AddHs(conf_mol, addCoords=True)

    # Optimization is best-effort: keep the raw embedding if MMFF raises.
    try:
        AllChem.MMFFOptimizeMolecule(conf_mol)
    except Exception:
        pass
    coordinates = conf_mol.GetConformer().GetPositions().astype(np.float32)
    assert len(atoms) == len(coordinates), "coordinates shape is not align with {}".format(smiles)
    assert coordinates.shape[1] == 3

    atoms = np.asarray(atoms)
    ## remove the hydrogen
    mask_hydrogen = atoms != "H"
    if mask_hydrogen.any():
        atoms = atoms[mask_hydrogen]
        coordinates = coordinates[mask_hydrogen]

    ## atom vectors
    dictionary = Dictionary.load('data_provider/unimol_dict.txt')
    dictionary.add_symbol("[MASK]", is_special=True)
    atom_vec = torch.from_numpy(dictionary.vec_index(atoms)).long()

    ## centre coordinates on the origin (translation only, no scaling)
    coordinates = coordinates - coordinates.mean(axis=0)

    ## add_special_token: BOS/EOS get zero coordinates
    atom_vec = torch.cat([torch.LongTensor([dictionary.bos()]), atom_vec, torch.LongTensor([dictionary.eos()])])
    coordinates = np.concatenate([np.zeros((1, 3)), coordinates, np.zeros((1, 3))], axis=0)

    ## edge type = combination of the two endpoint atom types
    edge_type = atom_vec.view(-1, 1) * len(dictionary) + atom_vec.view(1, -1)
    dist = torch.from_numpy(distance_matrix(coordinates, coordinates).astype(np.float32))

    return atom_vec, dist, edge_type, smiles


In [6]:
def sdf2graph(sdf_file):
    """Convert the first molecule of an SDF file into a UniMol-style 3D graph.

    No conformer is generated: the first conformer stored in the file is used
    as-is. Only the first record is read (the previous ``for`` loop also
    returned on its first iteration).

    Parameters
    ----------
    sdf_file : str
        Path to the SDF file.

    Returns
    -------
    tuple
        ``(atom_vec, dist, edge_type, smiles)``, where:

        - ``smiles`` is the canonical SMILES of the molecule.
        - ``atom_vec`` is a 1-D LongTensor of dictionary indices wrapped in BOS/EOS.
        - ``dist`` is the float32 pairwise distance matrix.
        - ``edge_type`` is ``type_i * len(dictionary) + type_j``.

    Raises
    ------
    ValueError
        If the first record is missing or unparseable, or it has no conformer.
    """
    records = iter(Chem.SDMolSupplier(sdf_file))
    # The supplier yields None for a record it cannot parse.
    molecule = next(records, None)
    if molecule is None:
        raise ValueError("No readable molecule at the start of {}".format(sdf_file))
    if molecule.GetNumConformers() == 0:
        raise ValueError("First molecule in {} has no stored conformer".format(sdf_file))

    smiles = Chem.MolToSmiles(molecule, canonical=True)

    # Atoms and coordinates of the first conformer.
    atoms = [atom.GetSymbol() for atom in molecule.GetAtoms()]
    coordinates = molecule.GetConformer(0).GetPositions().astype(np.float32)

    assert len(atoms) == len(coordinates), "coordinates shape is not align with {}".format(smiles)
    assert coordinates.shape[1] == 3

    atoms = np.asarray(atoms)
    ## remove the hydrogen
    # NOTE(review): SDMolSupplier may already strip hydrogens by default, in
    # which case this mask is a no-op — confirm.
    mask_hydrogen = atoms != "H"
    if mask_hydrogen.any():
        atoms = atoms[mask_hydrogen]
        coordinates = coordinates[mask_hydrogen]

    ## atom vectors
    dictionary = Dictionary.load('data_provider/unimol_dict.txt')
    dictionary.add_symbol("[MASK]", is_special=True)
    atom_vec = torch.from_numpy(dictionary.vec_index(atoms)).long()

    ## centre coordinates on the origin (translation only)
    coordinates = coordinates - coordinates.mean(axis=0)

    ## add_special_token: BOS/EOS get zero coordinates
    atom_vec = torch.cat([torch.LongTensor([dictionary.bos()]), atom_vec, torch.LongTensor([dictionary.eos()])])
    coordinates = np.concatenate([np.zeros((1, 3)), coordinates, np.zeros((1, 3))], axis=0)

    ## edge type = combination of the two endpoint atom types
    edge_type = atom_vec.view(-1, 1) * len(dictionary) + atom_vec.view(1, -1)
    dist = torch.from_numpy(distance_matrix(coordinates, coordinates).astype(np.float32))

    return atom_vec, dist, edge_type, smiles
    

In [7]:
def get_3d_graph(smiles=None, sdf_file=None):
    """Build a 3D molecular graph from an SDF file or a SMILES string.

    ``sdf_file`` takes precedence when both are given. The return value is
    whatever ``sdf2graph`` / ``smiles2graph`` returns.

    Raises
    ------
    ValueError
        If neither ``smiles`` nor ``sdf_file`` is provided.
    """
    if smiles is None and sdf_file is None:
        raise ValueError('Either smiles or sdf_file must be provided')
    if sdf_file is not None:
        return sdf2graph(sdf_file)
    return smiles2graph(smiles)

In [8]:
def tokenize(tokenizer, text, num_mol_tokens=8):
    """Tokenize a prompt and mark its molecule-placeholder tokens.

    Parameters
    ----------
    tokenizer : callable
        Tokenizer exposing ``mol_token_id``. It is called with
        ``return_tensors='pt'``.
    text : str
        Prompt containing the molecule placeholder tokens.
    num_mol_tokens : int, default 8
        Expected number of ``mol_token_id`` occurrences in ``text``.
        NOTE(review): presumably this should equal ``--num_query_token``
        (both are 8 in this notebook) — confirm.

    Returns
    -------
    The tokenizer's encoding, with an extra boolean ``'is_mol_token'`` entry
    that is True wherever ``input_ids == tokenizer.mol_token_id``.

    Raises
    ------
    AssertionError
        If the number of molecule tokens differs from ``num_mol_tokens``.
    """
    text_tokens = tokenizer(text,
                            add_special_tokens=True,
                            return_tensors='pt',
                            return_attention_mask=True,
                            return_token_type_ids=True)
    is_mol_token = text_tokens.input_ids == tokenizer.mol_token_id
    text_tokens['is_mol_token'] = is_mol_token
    num_found = int(torch.sum(is_mol_token).item())
    assert num_found == num_mol_tokens, \
        "expected {} mol tokens in the prompt, found {}".format(num_mol_tokens, num_found)

    return text_tokens


In [11]:
# To use a conformer stored in an SDF file instead, call e.g.
# get_3d_graph(sdf_file='path/to/Conformer3D_COMPOUND_CID_1.sdf')
atom_vec, dist, edge_type, smiles = get_3d_graph(smiles='CC(=O)OC(CC(=O)[O-])C[N+](C)(C)C')
# Add a leading batch dimension and move everything to `device`; the distance
# matrix is additionally cast to the model's dtype.
atom_vec = atom_vec.unsqueeze(0).to(device)
dist = dist.unsqueeze(0).to(device=device, dtype=tensor_type)
edge_type = edge_type.unsqueeze(0).to(device)
graph = (atom_vec, dist, edge_type)
prompt = "Below is an instruction that describes a task, paired with an input molecule. Write a response that appropriately completes the request.\n" \
         "Instruction: {}\n" \
         "Input molecule: {} <mol><mol><mol><mol><mol><mol><mol><mol>.\n" \
         "Response: "
instruction = "I need to know the LogP of this molecule, could you please provide it? If uncertain, provide an estimate. Respond with the numerical value only."
# Named `input_text` rather than `input` so the builtin is not shadowed in the kernel.
input_text = prompt.format(instruction, smiles)
input_tokens = tokenize(tokenizer, input_text)
input_tokens.to(device)

{'input_ids': tensor([[    2, 13866,   338,   385, 15278,   393, 16612,   263,  3414, 29892,
          3300,  2859,   411,   385,  1881, 13206, 29883,  1297, 29889, 14350,
           263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009, 29889,
            13,  3379,  4080, 29901,   306,   817,   304,  1073,   278, 13206,
         29883,  1297,  7688,   310,   445, 13206, 29883,  1297, 29892,  1033,
           366,  3113,  3867,   372, 29973,   960, 17999, 29892,  3867,   385,
         12678, 29889,  2538,  2818,   411,   278, 16259,   995,   871, 29889,
            13,  4290, 13206, 29883,  1297, 29901, 19178, 29898, 29922, 29949,
         29897, 29949, 29961, 29907, 25380, 29950,   850,  4174, 29898, 29922,
         29949,  9601, 29949, 29899,  2314, 29907, 29961, 29940, 29974,   850,
         29907,  5033, 29907, 29897, 29907, 29871, 32001, 32001, 32001, 32001,
         32001, 32001, 32001, 32001,   869,    13,  5103, 29901, 29871]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 

In [12]:
# Generate a response from the molecule graph plus the tokenized prompt,
# then print the returned list of decoded strings.
output = model.generate(
    graph,
    input_tokens,
)
print(output)

['(S)-O-(acetylmalonyl)-carnitine has a molecular weight of 287.33 g/mol. It is a relatively small molecule with a total of 206 atoms. The presence of the acetyl group contributes to its lipophilicity, allowing it to easily cross cellular membranes. The positive charge on the quaternary ammonium group enhances its solubility in aqueous environments. (S)-O-(Acetylmalonyl)-carnitine is a natural product found in Spar']
