In [1]:
%load_ext autoreload
%autoreload 2

## Load and preprocess data

In [2]:
import datasets
import safe as sf

In [3]:
data = datasets.load_dataset("alxfgh/ChEMBL_Drug_Instruction_Tuning", streaming=False)

Downloading readme:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

Downloading and preparing dataset csv/alxfgh--ChEMBL_Drug_Instruction_Tuning to /home/emmanuel/.cache/huggingface/datasets/alxfgh___csv/alxfgh--ChEMBL_Drug_Instruction_Tuning-6e653d1656fb1fb2/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /home/emmanuel/.cache/huggingface/datasets/alxfgh___csv/alxfgh--ChEMBL_Drug_Instruction_Tuning-6e653d1656fb1fb2/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
import pandas as pd
data = data["train"]
data = data.unique("SMILES")
df = pd.DataFrame({"smiles":data})
data = datasets.Dataset.from_pandas(df)

In [5]:
import datamol as dm

In [6]:
from functools import partial
ALLOWED_DESCRIPTORS = ["mw", "fsp3", "n_lipinski_hba", "n_lipinski_hbd", "n_rings", "n_heavy_atoms", "n_hetero_atoms", "n_rotatable_bonds", "tpsa"]
def apply_converter(row):
    row["inputs"] = sf.utils.convert_to_safe(row["smiles"], canonical=False, randomize=True, fraction_hs=0.4)
    descriptors_dict = dm.descriptors.compute_many_descriptors(dm.to_mol(row["smiles"]))
    row["descriptors"] = [descriptors_dict[x] for x in ALLOWED_DESCRIPTORS]
    return row

In [7]:
processed_data = data.map(apply_converter, batched=False, remove_columns=["smiles"], num_proc=4)

Map (num_proc=4):   0%|          | 0/3892 [00:00<?, ? examples/s]



In [8]:
processed_data = processed_data.filter(lambda x: x["inputs"] is not None)

Filter:   0%|          | 0/3892 [00:00<?, ? examples/s]

In [9]:
! rm -rf tmp_data/processed_data

In [10]:
# split dataset
processed_data = processed_data.train_test_split(
    test_size=0.2,  seed=42, shuffle=True
)

In [11]:
processed_data["validation"] = processed_data["test"]

In [12]:
processed_data.save_to_disk("tmp_data/proc_data")

Saving the dataset (0/1 shards):   0%|          | 0/3001 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/751 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/751 [00:00<?, ? examples/s]

## Learn an initial tokenizer

In [13]:
import datasets
from safe.tokenizer import SAFETokenizer
from safe.trainer.data_utils import batch_iterator

In [10]:
processed_data = datasets.load_from_disk("tmp_data/proc_data")

In [11]:
tokenizer = SAFETokenizer(trainer_args=dict(vocab_size=500), splitter=None)
tokenizer.train_from_iterator(batch_iterator(processed_data, column="inputs"))

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]






In [12]:
tokenizer.save("tmp_data/tokenizer-no-splitter")

In [151]:
processed_data

DatasetDict({
    train: Dataset({
        features: ['inputs', 'descriptors'],
        num_rows: 3001
    })
    test: Dataset({
        features: ['inputs', 'descriptors'],
        num_rows: 751
    })
    validation: Dataset({
        features: ['inputs', 'descriptors'],
        num_rows: 751
    })
})

In [152]:
pretrained_tokenizer = tokenizer.get_pretrained()

### Tokenize a version of the dataset

In [10]:
from safe.trainer.data_utils import get_dataset

In [153]:
tokenized_dataset = get_dataset("tmp_data/proc_data", tokenizer=tokenizer, streaming=False)

Map:   0%|          | 0/3001 [00:00<?, ? examples/s]

Map:   0%|          | 0/751 [00:00<?, ? examples/s]

Map:   0%|          | 0/751 [00:00<?, ? examples/s]

### Test the appropriate data collator

In [154]:
from safe.trainer.collator import SAFECollator
from torch.utils.data import DataLoader

In [17]:
data_collator = SAFECollator(tokenizer=tokenizer)
dataloader = DataLoader(tokenized_dataset["train"], collate_fn=data_collator, batch_size=4)
for batch in dataloader:
    break
batch

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


## Test the training framework


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
! pip install wandb
import wandb
wandb.login()



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmaclandrol[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
import uuid
import safe
import os
import torch
import transformers
from transformers import (
    AutoConfig,
    set_seed,
)
from loguru import logger
from safe.trainer.model import SAFEDoubleHeadsModel
from safe.tokenizer import SAFETokenizer
from safe.trainer.data_utils import get_dataset
from safe.trainer.collator import SAFECollator
from safe.trainer.trainer_utils import SAFETrainer
CURRENT_DIR = os.path.join(safe.__path__[0], "trainer")

In [4]:
%env WANDB_LOG_MODEL=end
%env WANDB_WATCH=all
%env WANDB_PROJECT=safe-project

env: WANDB_LOG_MODEL=end
env: WANDB_WATCH=all
env: WANDB_PROJECT=safe-project


In [5]:
# params
config = None
tokenizer_path = "tmp_data/tokenizer-splitter"
dataset_path = "tmp_data/proc_data"
model_path = None
is_tokenized=False
prop_loss_coeff =  1
dtype = "auto"
ddp = False
gradient_accumulation_steps = 1
wandb_watch = None
wandb_run_name = f"safe-model-{uuid.uuid4().hex[:8]}"
batch_size = 32
warmup_steps = 10
num_epochs = 10
learning_rate = 1e-5

num_labels = 9
logging_steps = 1
output_dir = "tmp_data/training/"
num_workers = 4
max_steps = 10
cache_dir = None

In [6]:
tokenizer = SAFETokenizer.load(tokenizer_path)
training_args = transformers.TrainingArguments(
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        logging_steps=logging_steps,
        optim="adamw_torch",
        output_dir=output_dir,
        report_to="wandb",
        run_name=wandb_run_name,
        dataloader_num_workers=num_workers,
        #save_safetensors=True,
        #torch_compile=False,
        max_steps=max_steps,
)

# load dataset
with training_args.main_process_first():
    dataset = get_dataset(
        dataset_path, tokenizer=(None if is_tokenized else tokenizer), streaming=False
    )

data_collator = SAFECollator(tokenizer=tokenizer)

Loading cached processed dataset at /Users/manu/Code/safe/expts/notebook/tmp_data/proc_data/train/cache-a924afbc58f161eb.arrow
Loading cached processed dataset at /Users/manu/Code/safe/expts/notebook/tmp_data/proc_data/test/cache-6a6761b8806d6d7e.arrow
Loading cached processed dataset at /Users/manu/Code/safe/expts/notebook/tmp_data/proc_data/validation/cache-6a6761b8806d6d7e.arrow


In [None]:
if config is None:
    config = os.path.join(CURRENT_DIR, "configs/default_config.json")
config = AutoConfig.from_pretrained(config, cache_dir=cache_dir)
setattr(config, 'num_labels', num_labels or 0 )

config.vocab_size = len(tokenizer)
config.bos_token_id = tokenizer.bos_token_id
config.eos_token_id = tokenizer.eos_token_id
config.pad_token_id = tokenizer.pad_token_id
torch_dtype = dtype if dtype in ["auto", None] else getattr(torch, dtype)

if model_path is not None:
    model = SAFEDoubleHeadsModel.from_pretrained(
        model_path,
        config=config,
        cache_dir=cache_dir,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
else:
    model = SAFEDoubleHeadsModel(config)

# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))

n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")


trainer = SAFETrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    args=training_args,
    data_collator=data_collator,
    prop_loss_coeff=prop_loss_coeff,
)

trainer.train()
trainer.save_state()

In [7]:
wandb.finish()

0,1
train/epoch,▁▂▂▃▄▄▅▇▇██
train/global_step,▁▂▃▃▄▅▆▆▇██
train/learning_rate,▂▃▃▄▅▆▆▇█▁
train/loss,▂▃▃▆▂▁▂▂█▁
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁
train/train_samples_per_second,▁
train/train_steps_per_second,▁

0,1
train/epoch,0.11
train/global_step,10.0
train/learning_rate,0.0
train/loss,20946.1562
train/total_flos,42396433250688.0
train/train_loss,41157.02891
train/train_runtime,683.6371
train/train_samples_per_second,0.468
train/train_steps_per_second,0.015


In [13]:
# %%bash

# safe-train --tokenizer  "tmp_data/tokenizer-splitter" \
#     --dataset "tmp_data/proc_data" \
#     --num_labels 9 \
#     --torch_compile False \
#     --optim "adamw_torch" \
#     --learning_rate 1e-5 \
#     --prop_loss_coeff 1e-3 \
#     --gradient_accumulation_steps 1 \
#     --is_tokenized False \
#     --output_dir "tmp_data/test/" \
#     --wandb_project "safe-project" \
#     --max_steps 5

    