In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Load and preprocess data

In [2]:
import datasets
import safe as sf

In [10]:
# Load the ChEMBL drug instruction-tuning dataset eagerly (streaming disabled).
data = datasets.load_dataset(
    "alxfgh/ChEMBL_Drug_Instruction_Tuning",
    streaming=False,
)

Found cached dataset csv (/Users/manu/.cache/huggingface/datasets/alxfgh___csv/alxfgh--ChEMBL_Drug_Instruction_Tuning-6e653d1656fb1fb2/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


  0%|          | 0/1 [00:00<?, ?it/s]

In [11]:
import pandas as pd
# Keep only the distinct SMILES strings of the train split and wrap them in a
# fresh single-column ("smiles") dataset.
# NOTE: this rebinds `data`, so the loading cell must be re-run before this
# cell can be executed a second time.
unique_smiles = data["train"].unique("SMILES")
df = pd.DataFrame({"smiles": unique_smiles})
data = datasets.Dataset.from_pandas(df)

In [27]:
import datamol as dm

In [30]:
from functools import partial
# Keys of the descriptor dict that are kept for each molecule, in the order
# they appear in the "descriptors" vector.
ALLOWED_DESCRIPTORS = [
    "mw",
    "fsp3",
    "n_lipinski_hba",
    "n_lipinski_hbd",
    "n_rings",
    "n_heavy_atoms",
    "n_hetero_atoms",
    "n_rotatable_bonds",
    "tpsa",
]


def apply_converter(row):
    """Attach a SAFE string and a descriptor vector to one dataset row.

    Args:
        row: Mapping with a "smiles" entry holding the molecule's SMILES string.

    Returns:
        The same ``row`` object, mutated in place, with two added entries:

        * ``"inputs"``: whatever ``convert_to_safe`` returns for the SMILES
          (rows where this is ``None`` are dropped by a later filter cell).
        * ``"descriptors"``: the descriptor values selected and ordered by
          ``ALLOWED_DESCRIPTORS``.

        If the SMILES is missing or cannot be parsed into a molecule, both
        entries are set to ``None`` instead of raising.
    """
    smiles = row["smiles"]

    # Parse once up front. `dm.to_mol` yields None for an unparsable SMILES,
    # and non-string values (e.g. a missing entry) are treated the same way.
    mol = dm.to_mol(smiles) if isinstance(smiles, str) else None
    if mol is None:
        # Computing descriptors on None would raise and abort the whole
        # `map` call; mark the row as failed so the `inputs is not None`
        # filter removes it instead.
        row["inputs"] = None
        row["descriptors"] = None
        return row

    row["inputs"] = sf.trainer.safe_utils.convert_to_safe(
        smiles, canonical=False, randomize=True, fraction_hs=0.4
    )
    descriptors_dict = dm.descriptors.compute_many_descriptors(mol)
    row["descriptors"] = [descriptors_dict[name] for name in ALLOWED_DESCRIPTORS]
    return row

In [31]:
# Run the converter row by row across 4 worker processes; the raw "smiles"
# column is removed from the result.
processed_data = data.map(
    apply_converter,
    batched=False,
    remove_columns=["smiles"],
    num_proc=4,
)

Map (num_proc=4):   0%|          | 0/3892 [00:00<?, ? examples/s]



In [38]:
# Drop every row whose SAFE conversion produced no string.
processed_data = processed_data.filter(
    lambda example: example["inputs"] is not None
)

Filter:   0%|          | 0/3892 [00:00<?, ? examples/s]

In [52]:
# Remove any previously saved copy of the processed dataset so the later
# `save_to_disk` writes into a clean location. Done with the standard library
# rather than a shell escape so the cell also works where `rm` is unavailable;
# `ignore_errors=True` mirrors `rm -f` (a missing directory is not an error).
import shutil

shutil.rmtree("tmp_data/processed_data", ignore_errors=True)

In [53]:
# Shuffled 80/20 train/test split with a fixed seed.
# NOTE: this rebinds `processed_data` from a single dataset to the mapping of
# splits returned by `train_test_split`, so the cell is not re-runnable as is.
processed_data = processed_data.train_test_split(test_size=0.2, shuffle=True, seed=42)

In [102]:
# Register the test split under the "validation" key as well.
# NOTE(review): "validation" and "test" are the very same rows, so validation
# metrics are not independent of test metrics — confirm this is intended.
processed_data["validation"] = processed_data["test"]

In [103]:
# Write all splits to disk; the same path is passed to `get_dataset` later in
# this notebook.
processed_data.save_to_disk("tmp_data/processed_data")

Saving the dataset (0/1 shards):   0%|          | 0/3001 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/751 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/751 [00:00<?, ? examples/s]

## Learn an initial tokenizer

In [88]:
from safe.tokenizer import SAFETokenizer
from safe.trainer.data_utils import batch_iterator

In [97]:
# Build a SAFE tokenizer with a 500-entry vocabulary budget and fit it on the
# "inputs" column of the processed data.
tokenizer = SAFETokenizer(splitter="safe", trainer_args={"vocab_size": 500})
training_batches = batch_iterator(processed_data, column="inputs")
tokenizer.train_from_iterator(training_batches)

0it [00:00, ?it/s]

0it [00:00, ?it/s]






In [98]:
# Save the trained tokenizer under tmp_data/ for later reuse.
tokenizer.save("tmp_data/tokenizer-splitter")

In [125]:
# Display the available splits with their features and row counts.
processed_data

DatasetDict({
    train: Dataset({
        features: ['inputs', 'descriptors'],
        num_rows: 3001
    })
    test: Dataset({
        features: ['inputs', 'descriptors'],
        num_rows: 751
    })
    validation: Dataset({
        features: ['inputs', 'descriptors'],
        num_rows: 751
    })
})

In [154]:
# Get the tokenizer object that exposes `batch_encode_plus` (used in the next
# cell). NOTE(review): presumably a Hugging Face-compatible wrapper — confirm.
pretrained_tokenizer = tokenizer.get_pretrained()

In [157]:
# Sanity check: batch-encode the first five training SAFE strings and display
# the resulting encoding.
pretrained_tokenizer.batch_encode_plus(processed_data["train"]["inputs"][0:5])

{'input_ids': [[1, 34, 83, 63, 11, 34, 21, 78, 11, 34, 80, 27, 11, 34, 20, 22, 11, 27, 63, 27, 58, 11, 34, 60, 27, 11, 27, 68, 27, 19, 11, 27, 62, 27, 80, 11, 27, 22, 27, 84, 11, 27, 74, 27, 83, 11, 34, 19, 62, 11, 27, 21, 27, 20, 11, 34, 58, 68, 11, 34, 84, 74, 11, 57, 14, 78, 27, 27, 27, 15, 7, 34, 8, 57, 16, 27, 42, 17, 42, 42, 42, 60, 42, 18, 42, 17, 71, 15, 7, 27, 27, 33, 16, 27, 8, 57, 14, 34, 18, 2], [1, 42, 14, 18, 42, 66, 42, 15, 42, 42, 42, 42, 16, 42, 14, 15, 11, 34, 16, 35, 7, 23, 34, 8, 7, 34, 8, 34, 11, 27, 17, 27, 18, 11, 33, 17, 7, 27, 8, 27, 2], [1, 33, 14, 18, 27, 27, 33, 20, 27, 27, 14, 11, 42, 14, 21, 42, 42, 42, 15, 42, 7, 42, 14, 8, 34, 27, 7, 28, 8, 7, 28, 8, 34, 15, 11, 27, 19, 18, 23, 34, 11, 42, 14, 22, 42, 47, 42, 42, 42, 14, 61, 11, 27, 20, 21, 11, 33, 19, 22, 2], [1, 42, 14, 42, 42, 42, 42, 7, 27, 23, 20, 61, 8, 42, 14, 11, 27, 7, 34, 8, 7, 27, 27, 7, 23, 34, 8, 34, 8, 7, 27, 27, 7, 23, 34, 8, 34, 8, 27, 7, 23, 34, 8, 34, 11, 33, 18, 21, 19, 11, 27, 21, 27,

In [150]:
# Encode the same five training strings with the SAFETokenizer itself.
# NOTE(review): `token_ids` is not used anywhere in the visible part of the
# notebook — confirm whether this cell is still needed.
token_ids = tokenizer.encode(processed_data["train"]["inputs"][0:5])

### Tokenize a version of the dataset

In [158]:
from safe.trainer.data_utils import get_dataset

In [168]:
# Read the processed dataset back from disk through `get_dataset`, passing the
# trained tokenizer and disabling streaming.
tokenized_dataset = get_dataset(
    "tmp_data/processed_data",
    tokenizer=tokenizer,
    streaming=False,
)

Map:   0%|          | 0/3001 [00:00<?, ? examples/s]

Map:   0%|          | 0/751 [00:00<?, ? examples/s]

Map:   0%|          | 0/751 [00:00<?, ? examples/s]

In [169]:
# Persist the tokenized splits next to the processed ones under tmp_data/.
tokenized_dataset.save_to_disk("tmp_data/tokenized_data")

Saving the dataset (0/1 shards):   0%|          | 0/3001 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/751 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/751 [00:00<?, ? examples/s]

### Test the appropriate data collator

In [None]:
from