In [2]:
%load_ext autoreload
%autoreload 2

## Model Inference

In [None]:
import os
from jax import random
from flax.core import FrozenDict

from lmkit.model import sampler, config as config_lib
from lmkit.tools import compat

repo = "meta-llama/Meta-Llama-3-8B-Instruct"
model_dir = "models/llama3"

if not os.path.exists(model_dir) or not os.listdir(model_dir):
    from dotenv import load_dotenv
    load_dotenv()

    compat.from_hf(repo, model_dir, token=os.environ["HF_API_TOKEN"])

params = compat.params_to_lmkit(compat.gather_for_jax(model_dir))
params = FrozenDict(params)
config = compat.load_lmkit_config(f"{model_dir}/config.json")
config = config_lib.extend_llama(config)

tokenizer = compat.load_lmkit_tokenizer(
    f"{model_dir}/tokenizer.json", f"{model_dir}/generation_config.json"
)


prompts = [
    "Question: What is a Josephson junction?\nAnswer:",
    "Question: What is the highest point of the Pamirs?\nAnswer:",
]

sampler.generate(
    inputs=prompts,
    max_new_tokens=1000,
    tokenizer=tokenizer,
    params=params,
    config=config,
    random_key=random.key(0),
    return_text=True,
    verbose=True,
)


Cuda processing allowed: True


Loading safetensors: 100%|██████████| 4/4 [02:35<00:00, 38.78s/it]


## Custom Training

In [None]:
from lmkit.tools import data

datasource_file = "data/shakespeare.txt"
with open(datasource_file, "r") as f:
    text = f.read()
data_iter = text.split("\n")

batch_size=2048
dataset_dir = "data/dataset"

data.to_arrayrecord(
    data_iter = data_iter,
    out_dir=dataset_dir,
    encode_fn=lambda x: x.encode("utf-8"),
)

<grain._src.python.dataset.transformations.slice.SliceMapDataset at 0xf4fe7b5fc940>

In [None]:
import os
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.normalizers import NFD, Lowercase, StripAccents, Sequence
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from typing import Optional, List, Iterator
import logging

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def train_bpe(
    iterator: Iterator[str],
    vocab_size: int,
    save_path: str,
    min_frequency: int = 2,
    special_tokens: Optional[List[str]] = None,
    add_prefix_space: bool = False,
) -> Tokenizer:
    unk_token = "<unk>"
    if special_tokens is None:
        special_tokens = [unk_token, "<pad>", "<bos>", "<eos>"]
    elif unk_token not in special_tokens:
        # Add unk_token if user provided a list without it, as BPE model needs it
        special_tokens = [unk_token] + special_tokens
        logging.warning(f"'{unk_token}' not found in special_tokens, adding it.")

    tokenizer = Tokenizer(BPE(unk_token=unk_token))

    # 2. Setup Normalizer (optional but recommended)
    # Normalizes text before tokenization (e.g., unicode, lowercase, accents)
    tokenizer.normalizer = Sequence(
        [
            NFD(),  # Unicode normalization decomposes characters
            Lowercase(),
            StripAccents(),
        ]
    )

    # 3. Setup PreTokenizer (splits text into initial words/tokens)
    # ByteLevel handles all bytes, good for diverse languages/data
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=add_prefix_space)

    # 4. Setup Decoder (to convert IDs back to text)
    # Must match the pre_tokenizer
    tokenizer.decoder = ByteLevelDecoder()

    # 5. Setup Trainer
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=special_tokens,
        # initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), # Use the ByteLevel alphabet
        # Other options like `show_progress=True` can be added
    )

    # 6. Train the tokenizer
    logging.info(f"Starting BPE tokenizer training...")
    logging.info(f"Vocab size: {vocab_size}, Min frequency: {min_frequency}")
    logging.info(f"Special tokens: {special_tokens}")
    logging.info(f"Saving to: {save_path}")

    tokenizer.train_from_iterator(iterator, trainer=trainer)

    logging.info("Training complete.")

    # 7. Save the tokenizer
    # Create directory if it doesn't exist
    save_dir = os.path.dirname(save_path)
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        logging.info(f"Created directory: {save_dir}")

    tokenizer.save(save_path)
    logging.info(f"Tokenizer saved successfully to {save_path}")

    return tokenizer


# --- Example Usage ---
if __name__ == "__main__":
    tokenizer_dataset = data.grain_dataset_from(
        arrayrecord_dir=dataset_dir,
        batch_size=batch_size,
        map_fn=lambda x: x.decode("utf-8"),
    )
    data_iterator = iter(tokenizer_dataset)

    # 2. Define parameters
    VOCAB_SIZE = 2048  # Small vocab size for demo purposes
    MIN_FREQ = 2
    SAVE_PATH = "trained_bpe_tokenizer.json"

    # 3. Train the tokenizer
    trained_tokenizer = train_bpe_tokenizer(
        iterator=data_iterator,
        vocab_size=VOCAB_SIZE,
        save_path=SAVE_PATH,
        min_frequency=MIN_FREQ,
    )


2025-04-07 11:27:05,397 - INFO - Starting BPE tokenizer training...
2025-04-07 11:27:05,397 - INFO - Vocab size: 2048, Min frequency: 2
2025-04-07 11:27:05,397 - INFO - Special tokens: ['<unk>', '<pad>', '<bos>', '<eos>']
2025-04-07 11:27:05,398 - INFO - Saving to: trained_bpe_tokenizer.json






2025-04-07 11:27:08,009 - INFO - Training complete.
2025-04-07 11:27:08,017 - INFO - Tokenizer saved successfully to trained_bpe_tokenizer.json




--- Testing the trained tokenizer ---
Tokenizer loaded successfully from trained_bpe_tokenizer.json

Original text: 'This is a test sentence.'
Encoded IDs: [490, 112, 44, 57, 139, 1214, 204, 10]
Tokens: ['this', 'Ġis', 'Ġa', 'Ġt', 'est', 'Ġsent', 'ence', '.']
Decoded text: 'this is a test sentence.'

Vocabulary size: 2048


In [None]:
import jax.numpy as jnp
from flax.core import FrozenDict

from lmkit.model import trainer
from lmkit.tools import data

config = FrozenDict({
    "num_layers": 12,
    "num_heads": 12,
    "num_kv_heads": 12,
    "hidden_size": 768,
    "intermediate_size": 3072,
    "act_fn": "silu",
    "vocab_size": ...,
    "max_position_embeddings": 2048,
    "rope_base": 100_000,
    "io_tying": True,

})

batch_size = 4
num_steps = 500  # Small number for demo
log_granularity = 50
save_granularity = 200
ckpt_dir = "checkpoints"
dataset_dir = "data/dataset"

def batch_fn(x):
    encoded = tokenizer.encode_batch_fast(x)
    ids = [item.ids for item in encoded]
    batch_tokens = jnp.asarray(ids).astype(jnp.int32)
    positions = jnp.where(
        batch_tokens != tokenizer.pad_token_id, jnp.arange(batch_tokens.shape[1]), 0
    )
    return FrozenDict({
        "input_ids": batch_tokens[:, :-1],
        "positions": positions[:, :-1], # what if slicing gets into the picture
        "target_ids": batch_tokens[:, 1:],
    })


loaded_dataset = data.grain_dataset_from(
    arrayrecord_dir=dataset_dir,
    batch_size=batch_size,
    map_fn=lambda x: x.decode("utf-8"),
    batch_map_fn=batch_fn,
)

final_params, final_opt_state = trainer.train(
    config=config,
    data_iterator=data_iter,
    num_steps=num_steps,
    learning_rate=5e-4,
    log_every=log_granularity,
    save_every=save_granularity,
    checkpoint_dir=ckpt_dir,
)

print("\nTraining completed. Final parameters and optimizer state returned.")

NameError: name 'data_iter' is not defined