In [1]:
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, LlamaTokenizer

from model.modeling_demolta import DeMOLTaFeaturizer, DeMOLTaCollateFn, DeMOLTaConfig, MOLLA, MOLLACollateFn


In [55]:
# Hyperparameters for the DeMOLTa molecule encoder, collected in one place
# so they are easy to tweak before the config object is built.
# NOTE(review): ff_dim is 4 x hidden_dim and hidden_dim / num_heads is 64;
# presumably the usual transformer sizing -- confirm against DeMOLTaConfig.
demolta_settings = {
    'num_layers': 12,
    'hidden_dim': 384,
    'ff_dim': 1536,
    'num_heads': 6,
    'layer_dropout': 0.15,
}
demolta_config = DeMOLTaConfig(**demolta_settings)

In [56]:
# Hugging Face model id; the cells that follow pass it to
# AutoTokenizer.from_pretrained and to the MOLLA constructor.
text_model_name = "facebook/galactica-125m"

In [57]:
# Load the tokenizer that matches the text model named in `text_model_name`.
# This may download files on first use (network / local HF cache access).
tokenizer = AutoTokenizer.from_pretrained(text_model_name)

In [58]:
# Make sure the tokenizer has a padding token before batching.
# Preference order: keep an existing pad token, else reuse the EOS token,
# else fall back to a fixed token id.
pad_token_missing = not tokenizer.pad_token
if pad_token_missing and tokenizer.eos_token:
    tokenizer.pad_token = tokenizer.eos_token
elif pad_token_missing:
    # NOTE(review): id 2 is hard-coded; presumably the EOS id of the
    # galactica vocabulary -- confirm before using another text model.
    tokenizer.pad_token_id = 2

Using pad_token, but it is not set yet.
Using eos_token, but it is not set yet.


In [59]:
# Two toy (molecule, question, answer) triples used to smoke-test the
# featurizer, collate function and model forward pass.
# The three lists are parallel: index i of each belongs to the same example.
smiles = ['CCO', 'CC1=CC=CC=C1']
queries = ['describe the molecule', 'describe the molecule']
# 'CC1=CC=CC=C1' is a benzene ring carrying a methyl group, i.e. toluene
# (benzene itself would be 'C1=CC=CC=C1'), so the label must say toluene.
answers = ['ethanol', 'toluene']

In [60]:
# Turn the parallel toy lists into a list of per-example dicts, featurizing
# each SMILES string on the way. The dict keys are the ones handed to
# MOLLACollateFn through the DataLoader.
featurizer = DeMOLTaFeaturizer()
mol_feats = []  # NOTE(review): not used in the visible cells; kept as-is
dataset = []
for smi, query, answer in zip(smiles, queries, answers):
    example = dict(
        mol_feats=featurizer(smiles=smi),
        query=query,
        answer=answer,
    )
    dataset.append(example)
    

In [61]:
# Batch the toy examples; the collate function receives the tokenizer so it
# can tokenize queries/answers while assembling a batch.
dl = DataLoader(
    dataset=dataset,
    batch_size=2,
    collate_fn=MOLLACollateFn(tokenizer),
)

In [62]:
# Grab the first batch from the loader for a one-off forward-pass check.
# next(iter(...)) raises StopIteration if the loader is empty, instead of
# silently leaving `batch` undefined (or stale from an earlier run of this
# cell), which is what a `for ...: break` loop would do.
batch = next(iter(dl))

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [63]:
# Build the MOLLA model from the DeMOLTa config and the text model id.
# NOTE(review): presumably this loads pretrained text-model weights for
# `text_model_name` -- confirm in model/modeling_demolta.py.
mola = MOLLA(demolta_config, text_model_name)

In [64]:
# Single forward pass on the collated batch. Text inputs live at the top
# level of the batch dict, molecule inputs under the 'mols' sub-dict.
mol_inputs = batch['mols']
outputs = mola(
    input_ids=batch['input_ids'],
    input_attention_mask=batch['attention_mask'],
    atom_feats=mol_inputs['atom_feats'],
    bond_feats=mol_inputs['bond_feats'],
    attention_matrix_mask=mol_inputs['attention_mask'],
    labels=batch['labels'],
)

In [68]:
# Greedy-decode the first example in the batch.
# NOTE(review): outputs[1] is presumably the logits tensor laid out as
# (batch, sequence, vocab), hence argmax over axis 2 -- confirm in MOLLA.
predicted_token_ids = outputs[1].argmax(2)
tokenizer.decode(predicted_token_ids[0])

'a) name)Titleamine'

In [129]:
import os
import pandas as pd
import selfies as sf

from glob import glob
from tqdm.auto import tqdm
from datasets import load_dataset

In [130]:
# Fetch the molecule-oriented configuration of the Mol-Instructions dataset
# from the Hugging Face Hub (downloads and caches on first use).
mol_inst = load_dataset(
    path='zjunlp/Mol-Instructions',
    name='Molecule-oriented Instructions',
)

Downloading builder script:   0%|          | 0.00/7.32k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/19.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/55.1M [00:00<?, ?B/s]

Generating description_guided_molecule_design split: 0 examples [00:00, ? examples/s]

Generating forward_reaction_prediction split: 0 examples [00:00, ? examples/s]

Generating molecular_description_generation split: 0 examples [00:00, ? examples/s]

In [None]:
# Show which task splits the loaded dataset provides.
mol_inst.keys()

dict_keys(['description_guided_molecule_design', 'forward_reaction_prediction', 'molecular_description_generation', 'property_prediction', 'reagent_prediction', 'retrosynthesis'])

In [None]:
# Empty frame with the target schema, plus three parallel accumulator lists
# that the following cells append to (one entry per training example).
# Note: these names shadow the toy `smiles` list defined earlier.
pretrain_columns = ['smiles', 'query', 'answer']
pretrain_df = pd.DataFrame(columns=pretrain_columns)
smiles, query, answer = [], [], []

In [None]:
# Add every molecular-description example to the accumulators.
# The dataset stores molecules as SELFIES; decode them to SMILES first.
description_split = mol_inst['molecular_description_generation']
for data in tqdm(description_split):
    decoded_smiles = sf.decoder(data['input'])
    smiles.append(decoded_smiles)
    query.append(data['instruction'])
    answer.append(data['output'])

In [None]:
# Add every property-prediction example to the accumulators.
# The input molecule is decoded from SELFIES to SMILES; the answer text is
# stored unchanged.
property_split = mol_inst['property_prediction']
for data in tqdm(property_split):
    decoded_smiles = sf.decoder(data['input'])
    smiles.append(decoded_smiles)
    query.append(data['instruction'])
    answer.append(data['output'])

  0%|          | 0/401229 [00:00<?, ?it/s]

In [None]:
# Add every retrosynthesis example to the accumulators.
# Both sides are SELFIES: the input is a single molecule, while the output
# holds several '.'-separated molecules that are decoded one by one and
# re-joined with '.' as a multi-component SMILES string.
retro_split = mol_inst['retrosynthesis']
for data in tqdm(retro_split):
    smiles.append(sf.decoder(data['input']))
    query.append(data['instruction'])
    output_sfs = data['output'].split('.')
    answer.append('.'.join(sf.decoder(fragment) for fragment in output_sfs))

  0%|          | 0/143535 [00:00<?, ?it/s]

In [None]:
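# One-time download of the molecule_property_instruction dataset snapshot.
# Left commented out so re-running the notebook does not re-trigger it;
# uncomment to fetch the files again. The call returns the local snapshot
# directory, which the glob() cells further down refer to by absolute path.
# from huggingface_hub import snapshot_download
# snapshot_download(repo_id="haitengzhao/molecule_property_instruction", repo_type="dataset")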

Fetching 86 files:   0%|          | 0/86 [00:00<?, ?it/s]

Downloading (…)25b9608a26fe.parquet:   0%|          | 0.00/4.30M [00:00<?, ?B/s]

Downloading (…)e98748a9e08c.parquet:   0%|          | 0.00/26.8M [00:00<?, ?B/s]

Downloading (…)5153a36b8742.parquet:   0%|          | 0.00/26.8M [00:00<?, ?B/s]

Downloading (…)12965f685bf8.parquet:   0%|          | 0.00/27.0M [00:00<?, ?B/s]

Downloading (…)4e02cc715d5e.parquet:   0%|          | 0.00/27.0M [00:00<?, ?B/s]

Downloading (…)11b2bb06c049.parquet:   0%|          | 0.00/27.8M [00:00<?, ?B/s]

Downloading (…)3351f8a60b78.parquet:   0%|          | 0.00/27.5M [00:00<?, ?B/s]

Downloading (…)28deba87c4ba.parquet:   0%|          | 0.00/26.3M [00:00<?, ?B/s]

Downloading (…)3700c9b98345.parquet:   0%|          | 0.00/26.6M [00:00<?, ?B/s]

Downloading (…)a69868f95a9b.parquet:   0%|          | 0.00/28.1M [00:00<?, ?B/s]

Downloading (…)88d7611fe2eb.parquet:   0%|          | 0.00/28.0M [00:00<?, ?B/s]

Downloading (…)2784da21488e.parquet:   0%|          | 0.00/26.4M [00:00<?, ?B/s]

Downloading (…)1becf693b5d3.parquet:   0%|          | 0.00/26.3M [00:00<?, ?B/s]

Downloading (…)1632de3c820b.parquet:   0%|          | 0.00/26.7M [00:00<?, ?B/s]

Downloading (…)b9ccbcbe2515.parquet:   0%|          | 0.00/26.3M [00:00<?, ?B/s]

Downloading (…)508e383a0340.parquet:   0%|          | 0.00/27.4M [00:00<?, ?B/s]

Downloading (…)6d8b5f034e53.parquet:   0%|          | 0.00/27.5M [00:00<?, ?B/s]

Downloading (…)d5d7b17ca368.parquet:   0%|          | 0.00/27.4M [00:00<?, ?B/s]

Downloading (…)7ecc771c0271.parquet:   0%|          | 0.00/27.0M [00:00<?, ?B/s]

Downloading (…)5fb809c5bded.parquet:   0%|          | 0.00/28.1M [00:00<?, ?B/s]

Downloading (…)ba2407d9f51b.parquet:   0%|          | 0.00/27.4M [00:00<?, ?B/s]

Downloading (…)d849833388e1.parquet:   0%|          | 0.00/27.3M [00:00<?, ?B/s]

Downloading (…)c0db615c9374.parquet:   0%|          | 0.00/26.9M [00:00<?, ?B/s]

Downloading (…)eb7b755383db.parquet:   0%|          | 0.00/26.5M [00:00<?, ?B/s]

Downloading (…)d2038f937f55.parquet:   0%|          | 0.00/27.1M [00:00<?, ?B/s]

Downloading (…)dd0631ba428c.parquet:   0%|          | 0.00/2.37M [00:00<?, ?B/s]

Downloading (…)80b7dc9002d0.parquet:   0%|          | 0.00/11.8M [00:00<?, ?B/s]

Downloading (…)45b6733cfa31.parquet:   0%|          | 0.00/12.1M [00:00<?, ?B/s]

Downloading (…)cb7e7ea12d7b.parquet:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Downloading (…)8c034c039f6c.parquet:   0%|          | 0.00/12.9M [00:00<?, ?B/s]

'C:\\Users\\dust\\.cache\\huggingface\\hub\\datasets--haitengzhao--molecule_property_instruction\\snapshots\\aad5c7578e811e1614be9430095de0c431485cf3'

In [118]:
# Locate the BACE parquet shard inside the locally cached dataset snapshot.
# NOTE(review): the path is machine-specific (Windows user cache + snapshot
# hash); an IndexError here means no file matched the pattern.
bace_matches = glob(r'C:\Users\dust\.cache\huggingface\hub\datasets--haitengzhao--molecule_property_instruction\snapshots\aad5c7578e811e1614be9430095de0c431485cf3\data\bace*')
bace_parquet = bace_matches[0]

In [121]:
# Read the BACE shard into a DataFrame.
bace_df = pd.read_parquet(path=bace_parquet)

In [125]:
# Add the BACE training rows to the accumulators. Each row carries several
# prompt texts; every prompt becomes its own example, sharing the row's
# molecule ('graph') and label. Rows from other splits are ignored.
for idx, row in tqdm(bace_df.iterrows(), total=len(bace_df)):
    if row['split'] == 'train':
        for q in row['text']:
            smiles.append(row['graph'])
            query.append(q)
            answer.append(row['label'])

  0%|          | 0/1513 [00:00<?, ?it/s]

In [126]:
# Locate the BBBP parquet shard inside the locally cached dataset snapshot.
# NOTE(review): the path is machine-specific (Windows user cache + snapshot
# hash); an IndexError here means no file matched the pattern.
bbbp_matches = glob(r'C:\Users\dust\.cache\huggingface\hub\datasets--haitengzhao--molecule_property_instruction\snapshots\aad5c7578e811e1614be9430095de0c431485cf3\data\bbbp*')
bbbp_parquet = bbbp_matches[0]

In [127]:
# Read the BBBP shard into a DataFrame.
bbbp_df = pd.read_parquet(path=bbbp_parquet)

In [128]:
# Display the BBBP frame to inspect its columns
# (graph, text, label, dataset_name, task_index, molecule_index, split).
bbbp_df

Unnamed: 0,graph,text,label,dataset_name,task_index,molecule_index,split
0,[Cl].CC(C)NCC(O)COc1cccc2ccccc12,"[In general, molecules that passively diffuse ...",Yes,bbbp,0,0,train
1,C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl,"[In general, molecules that passively diffuse ...",Yes,bbbp,0,1,train
2,c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...,"[In general, molecules that passively diffuse ...",Yes,bbbp,0,2,train
3,C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C,"[In general, molecules that passively diffuse ...",Yes,bbbp,0,3,train
4,Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)...,"[In general, molecules that passively diffuse ...",Yes,bbbp,0,4,train
...,...,...,...,...,...,...,...
2034,C1=C(Cl)C(=C(C2=C1NC(=O)C(N2)=O)[N+](=O)[O-])Cl,"[In general, molecules that passively diffuse ...",Yes,bbbp,0,2034,train
2035,[C@H]3([N]2C1=C(C(=NC=N1)N)N=C2)[C@@H]([C@@H](...,"[In general, molecules that passively diffuse ...",Yes,bbbp,0,2035,train
2036,[O+]1=N[N](C=C1[N-]C(NC2=CC=CC=C2)=O)C(CC3=CC=...,"[In general, molecules that passively diffuse ...",Yes,bbbp,0,2036,train
2037,C1=C(OC)C(=CC2=C1C(=[N+](C(=C2CC)C)[NH-])C3=CC...,"[In general, molecules that passively diffuse ...",Yes,bbbp,0,2037,train
