In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# define basic config

import jax
from flax.core import FrozenDict

jax.config.update("jax_debug_nans", True)

raw_dataset_dir = "data/guacamol_raw"
dataset_dir = "data/guacamol"
model_dir = "./models/molecular_demo"
ckpt_dir = "./checkpoints/molecular_demo"

tokenizer_path = f"{model_dir}/tokenizer.json"
gen_cfg_path = f"{model_dir}/generation_config.json"

batch_size = 2048

vocab_size = 1024
min_token_frequency=2

save_granularity = 200
log_granularity = 50

In [11]:
import os
from lmkit.tools import trainer, compat, data

# train & load tokenizer

if not os.path.exists(tokenizer_path):
    data_iterator = data.dir_line_iterator(raw_dataset_dir, verbose=True)

    trainer.train_tokenizer(
        data_iterator,
        vocab_size=vocab_size,
        save_dir=model_dir,
        generation_config=dict(),
        min_frequency=2,
        normalize=False
    )

tokenizer = compat.load_tokenizer(
    tokenizer_path=tokenizer_path,
    mode="train",
    generation_config_file=gen_cfg_path,
)

# create & load TFRecord dataset

if not os.path.exists(dataset_dir) or not os.listdir(dataset_dir):
    data.to_tfrecord(
        data_iter=data.dir_line_iterator(raw_dataset_dir),
        encode_fn=lambda x: x.encode("utf-8"),
        out_dir=dataset_dir,
    )


dataset = data.from_tfrecords_dir(
    dataset_dir,
    tokenizer=tokenizer,
    batch_size=batch_size,
    shuffle_buffer_size=10_000,
)

Applied 'train' post-processor (BOS+EOS).


In [13]:
# define model configuration

model_config = FrozenDict(
    {
        "num_layers": 12,
        "num_heads": 12,
        "num_kv_heads": 12,
        "hidden_size": 768,
        "intermediate_size": 3072,
        "act_fn": jax.nn.silu,
        "vocab_size": tokenizer.vocab_size,
        "max_position_embeddings": 2048,
        "rope_base": 100_000,
        "norm_eps": 1e-6,
        "io_tying": True,
    }
)

# train model

trainer.train_model(
    config=model_config,
    data_iterator=dataset.as_numpy_iterator(),
    num_steps=1_500_000//batch_size,
    learning_rate=1e-3,
    log_every=100,
    save_every=250,
    wandb_project="molecular_interpretability",
    wandb_run_name="guacamol_first_run"
)

2025-04-09 17:38:42,642 - INFO - Starting training...
2025-04-09 17:38:42,643 - INFO - Config: FrozenDict({
    num_layers: 12,
    num_heads: 12,
    num_kv_heads: 12,
    hidden_size: 768,
    intermediate_size: 3072,
    act_fn: <PjitFunction of <function silu at 0xe80af6c44160>>,
    vocab_size: 1024,
    max_position_embeddings: 2048,
    rope_base: 100000,
    norm_eps: 1e-06,
    io_tying: True,
})
2025-04-09 17:38:42,645 - INFO - Initializing new model parameters.
2025-04-09 17:38:42,910 - INFO - Created transformer with 114838272, hidden size: 768, layers: 12
2025-04-09 17:38:43,189 - INFO - Optimizer: Adam (lr=0.001)
2025-04-09 17:38:43,190 - INFO - Total parameters: 114,838,272
2025-04-09 17:38:43,205 - ERROR - Failed to initialize wandb: config must be a dict or have a __dict__ attribute.. Disabling wandb.


Training:   0%|          | 0/732 [00:00<?, ?it/s]

2025-04-09 17:40:54,837 - INFO - Checkpoint saved at step 48 to checkpoints/checkpoint_48.pkl
2025-04-09 17:41:42,202 - INFO - Checkpoint saved at step 96 to checkpoints/checkpoint_96.pkl
2025-04-09 17:41:58,426 - INFO - [Step 100/732] Loss: 1.6562, Acc: 0.540, Grad Norm: 0.7227, Valid Tokens: 66439, Time/Step: 0.748s, Tokens/Sec: 88815.01
2025-04-09 17:41:59,900 - INFO - [Step 100] Global Param Norm: 390.0000
2025-04-09 17:41:59,901 - INFO - [Step 100] Norm(embed_table): 11.3502
2025-04-09 17:42:45,316 - INFO - Checkpoint saved at step 144 to checkpoints/checkpoint_144.pkl
2025-04-09 17:43:33,024 - INFO - Checkpoint saved at step 192 to checkpoints/checkpoint_192.pkl
2025-04-09 17:43:38,912 - INFO - [Step 200/732] Loss: 1.3828, Acc: 0.602, Grad Norm: 0.6172, Valid Tokens: 66291, Time/Step: 0.733s, Tokens/Sec: 90402.02
2025-04-09 17:43:39,350 - INFO - [Step 200] Global Param Norm: 392.0000
2025-04-09 17:43:39,350 - INFO - [Step 200] Norm(embed_table): 12.4153
2025-04-09 17:44:09,329 - 

({'embed_table': Array([[0.0017395, 0.012085, -0.0205078, ..., -0.0144653, -0.0218506,
          0.00537109],
         [0.00759888, -0.00469971, -0.015625, ..., -0.00976562, -0.0149536,
          -0.0257568],
         [0.000453949, -0.010376, -0.00201416, ..., 0.00482178,
          -0.00485229, 0.0194092],
         ...,
         [0.0043335, 0.00457764, -0.0258789, ..., -0.0203857, -0.0251465,
          0.000362396],
         [0.00582886, 0.0184326, -0.0169678, ..., 0.000843048,
          -0.000862122, -0.000478745],
         [-0.0108643, 0.00485229, -0.00842285, ..., -0.0147705, -0.0206299,
          -0.0108032]], dtype=bfloat16),
  'layers': [{'attn': {'W_k': Array([[-0.00521851, -0.00921631, 0.0101318, ..., -0.0290527, 0.0644531,
             -0.0267334],
            [0.0112305, -0.0185547, -0.0106201, ..., -0.022583, 0.0463867,
             -0.0546875],
            [0.0045166, -0.0583496, 0.0422363, ..., -0.00363159, -0.0539551,
             -0.00328064],
            ...,
          

In [29]:
from lmkit.model import sampler
from jax import random
import jax.numpy as jnp

params, *others = trainer.load_checkpoint("checkpoints/checkpoint_732.pkl")

inputs = jnp.array([tokenizer.bos_token_id for _ in range(128)]).astype(jnp.int32)[:, None]

sequences = sampler.generate(
    tokenized_inputs=inputs,
    max_new_tokens=128,
    tokenizer=tokenizer,
    params=params,
    config=model_config,
    random_key=random.key(0),
    verbose=True,
    return_text=False,
)

eos_token = tokenizer.id_to_token(tokenizer.eos_token_id)
bos_token = tokenizer.id_to_token(tokenizer.bos_token_id)

decoded_sequences = [
    tokenizer.decode(sequence, skip_special_tokens=False)
    .split(bos_token)[1]
    .split(eos_token)[0]
    for sequence in sequences
]

2025-04-09 18:24:52,732 - INFO - Loaded checkpoint from step 732: checkpoints/checkpoint_732.pkl


  0%|          | 0/128 [00:00<?, ?it/s]

In [31]:
from lmkit.tools import stem

stats = stem.molstats(decoded_sequences)
stem.print_molstats(stats)

2025-04-09 18:27:10,492 - INFO - Processing batch of 128 SMILES...


2025-04-09 18:27:10,625 - INFO - Processed batch. Validity: 88.28%



--- Molecular Statistics ---
Total Molecules Processed: 128
Valid Molecules Found:     113
Validity Percentage:       88.28%
----------------------------
QED:
  Mean: 0.711
  Std:  0.178
  Min:  0.043
  Max:  0.945
Molecular Weight (MW):
  Mean: 366.192
  Std:  101.400
  Min:  120.060
  Max:  1059.228
LogP (Crippen):
  Mean: 3.560
  Std:  1.307
  Min:  -1.484
  Max:  6.539
Topological Polar Surface Area (TPSA):
  Mean: 67.427
  Std:  34.104
  Min:  17.070
  Max:  336.860
H-Bond Acceptors (HBA):
  Mean: 3.867
  Std:  1.772
  Min:  1.000
  Max:  12.000
H-Bond Donors (HBD):
  Mean: 1.150
  Std:  1.291
  Min:  0.000
  Max:  11.000
Rotatable Bonds:
  Mean: 4.044
  Std:  2.922
  Min:  1.000
  Max:  30.000
Number of Rings:
  Mean: 3.204
  Std:  0.942
  Min:  0.000
  Max:  6.000
Fraction Csp3:
  Mean: 0.261
  Std:  0.154
  Min:  0.000
  Max:  0.737
Lipinski Violations:
  Mean: 0.195
  Std:  0.513
  Min:  0.000
  Max:  3.000
----------------------------
Drug-Likeness (Lipinski Rule of 5):
  % 