In [1]:
%load_ext autoreload
%autoreload 2

## Model Inference

In [2]:
import os
from jax import random
from flax.core import FrozenDict

from lmkit.model import sampler, config as config_lib
from lmkit.tools import compat

repo = "meta-llama/Meta-Llama-3-8B-Instruct"
model_dir = "models/llama3"

if not os.path.exists(model_dir) or not os.listdir(model_dir):
    from dotenv import load_dotenv
    load_dotenv()

    compat.from_hf(repo, model_dir, token=os.environ["HF_API_TOKEN"])

params = compat.params_to_lmkit(compat.gather_for_jax(model_dir))
params = FrozenDict(params)
config = compat.load_lmkit_config(f"{model_dir}/config.json")
config = config_lib.extend_llama(config)

tokenizer = compat.to_lmkit_tokenizer(
    f"{model_dir}/tokenizer.json", f"{model_dir}/generation_config.json"
)


prompts = [
    "Question: What is a Josephson junction?\nAnswer:",
    "Question: What is the highest point of the Pamirs?\nAnswer:",
]

sampler.generate(
    inputs=prompts,
    max_new_tokens=1000,
    tokenizer=tokenizer,
    params=params,
    config=config,
    random_key=random.key(0),
    return_text=True,
    verbose=True,
)


Cuda processing allowed: True


Loading safetensors: 100%|██████████| 4/4 [02:47<00:00, 41.96s/it]


  0%|          | 0/1000 [00:00<?, ?it/s]

['Question: What is a Josephson junction?\nAnswer: A Josephson junction is a device that consists of two superconducting materials separated by a thin layer of insulating material. When the two superconducting materials are cooled to a temperature below their critical temperature, a current can flow through the junction even though there is a potential difference between the two materials. This phenomenon is known as the Josephson effect, and it has been used in a variety of applications, including superconducting quantum interference devices (SQUIDs) and superconducting quantum computing devices.\nExplanation: A Josephson junction is a device that consists of two superconducting materials separated by a thin layer of insulating material. When the two superconducting materials are cooled to a temperature below their critical temperature, a current can flow through the junction even though there is a potential difference between the two materials. This phenomenon is known as the Josephs

## Custom Training

In [2]:
model_dir = "models/demo"
tokenizer_path = f"{model_dir}/tokenizer.json"
generation_config_path = f"{model_dir}/generation_config.json"

In [None]:
import os
from lmkit.tools import data, trainer, compat

# 1. Extract dataset

datasource_file = "data/shakespeare.txt"
with open(datasource_file, "r") as f:
    text = f.read()
data_iter = text.split("\n")

batch_size=2048
dataset_dir = "data/dataset"

if not os.path.exists(dataset_dir) or not os.listdir(dataset_dir):
    data.to_arrayrecord(
        data_iter = data_iter,
        out_dir=dataset_dir,
        encode_fn=lambda x: x.encode("utf-8"),
    )

# 2. Train tokenizer
vocab_size = 2048
min_frequency = 2

tokenizer_dataset = data.grain_dataset_from(
    arrayrecord_dir=dataset_dir,
    batch_size=batch_size,
    map_fn=lambda x: x.decode("utf-8"),
)
data_iterator = iter(tokenizer_dataset)

tokenizer = trainer.train_tokenizer(
    iterator=data_iterator,
    vocab_size=vocab_size,
    save_dir=model_dir,
    generation_config={},
    min_frequency=min_frequency,
)

tokenizer = compat.load_tokenizer(
    tokenizer_path=tokenizer_path,
    mode="train",
    generation_config_file=generation_config_path,
)





Applied 'train' post-processor (BOS+EOS).


In [None]:
import jax
import jax.numpy as jnp
from flax.core import FrozenDict
import logging
from lmkit.tools import data

jax.config.update("jax_debug_nans", True)

config = FrozenDict({
    "num_layers": 12,
    "num_heads": 12,
    "num_kv_heads": 12,
    "hidden_size": 768,
    "intermediate_size": 3072,
    "act_fn": jax.nn.silu,
    "vocab_size": tokenizer.vocab_size,
    "max_position_embeddings": 2048,
    "rope_base": 100_000,
    "norm_eps": 1e-6,
    "io_tying": True,

})

batch_size = 2048
num_steps = 500
log_granularity = 50
save_granularity = 200
ckpt_dir = "checkpoints"
dataset_dir = "data/dataset"

def batch_map_fn(batch_text, min_final_len: int = 2):
    if tokenizer.pad_token_id is None:
        raise ValueError("Tokenizer must have pad_token_id set for padding.")

    encoded = tokenizer.encode_batch_fast(batch_text)
    if not encoded:
        return None

    logging.info("Decoded: {tokenizer.decode_batch(encoded, skip_special_tokens=False)}")

    max_len = max(len(item.ids) for item in encoded) if encoded else 0
    if max_len == 0:
        logging.debug("Skipping batch: max sequence length after tokenization is 0.")
        return None

    ids = [
        item.ids + [tokenizer.pad_token_id] * (max_len - len(item.ids))
        for item in encoded
    ]
    initial_batch_tokens = jnp.array(ids, dtype=jnp.int32)

    current_len = initial_batch_tokens.shape[1]
    pad_amount = 1 - (current_len % 2)
    paddings = ((0, 0), (0, pad_amount))

    padded_tokens_for_slicing = jnp.pad(
        initial_batch_tokens,
        paddings,
        mode="constant",
        constant_values=tokenizer.pad_token_id,
    )
    odd_len = padded_tokens_for_slicing.shape[1]

    final_len = odd_len - 1
    if final_len < min_final_len:
        logging.debug(
            f"Skipping batch: final sequence length ({final_len}) < min_final_len ({min_final_len})."
        )
        return None

    positions_for_slicing = jnp.where(
        padded_tokens_for_slicing != tokenizer.pad_token_id, jnp.arange(odd_len), -1
    )

    input_ids = padded_tokens_for_slicing[:, :-1]
    input_positions = positions_for_slicing[:, :-1]
    target_ids = padded_tokens_for_slicing[:, 1:]

    return FrozenDict(
        {
            "input_ids": input_ids.astype(jnp.int32),
            "positions": input_positions.astype(jnp.int32),
            "target_ids": target_ids.astype(jnp.int32),
        }
    )


dataset = data.grain_dataset_from(
    arrayrecord_dir=dataset_dir,
    batch_size=batch_size,
    map_fn=lambda x: x.decode("utf-8"),
    batch_map_fn=batch_map_fn,
)
dataset = dataset.repeat(50)

final_params, final_opt_state = trainer.train_model(
    config=config,
    data_iterator=iter(dataset.to_iter_dataset()),
    num_steps=num_steps,
    learning_rate=1e-2,
    log_every=log_granularity,
    save_every=save_granularity,
    checkpoint_dir=ckpt_dir,
)

Training:   0%|          | 0/500 [00:00<?, ?it/s]


Training completed. Final parameters and optimizer state returned.


In [46]:
from lmkit.model import sampler
from jax import random

sampling_tokenizer = compat.load_tokenizer(
    tokenizer_path, mode="inference", generation_config_file=generation_config_path
)

prompts = [
    "First"
]

generated = sampler.generate(
    inputs=prompts,
    max_new_tokens=100,
    tokenizer=sampling_tokenizer,
    params=final_params,
    config=config,
    random_key=random.key(0),
    return_text=False,
    verbose=True,
)

print(sampling_tokenizer.decode(generated[0], skip_special_tokens=False))

Applied 'inference' post-processor (BOS only).


  0%|          | 0/100 [00:00<?, ?it/s]

<bos>first citizen:<eos>is,<eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos>st.<eos> not not you have i'll not,<eos> be the father, and<eos><eos> to,<eos><eos> a,<eos>.<eos> his good,<eos>,<eos><eos>, and be the king<eos>,<eos> to not, the good a own ded,<eos>;<eos>.<eos>,<eos>,<eos><eos>,<eos> to the.<eos>,<eos>,<eos>.<eos>s,<eos>, and a star
