In [2]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [4]:
import os
import jax.numpy as jnp

from lmkit.model import transformer, config as config_lib
from lmkit.tools import compat

In [13]:
# Hugging Face repository the weights come from, and the local directory that caches them.
repo = "meta-llama/Meta-Llama-3-8B-Instruct"
model_dir = "models/llama3"

# Fetch the checkpoint only when the cache directory is missing or empty.
if not (os.path.exists(model_dir) and os.listdir(model_dir)):
    # Imported lazily: dotenv is only needed on the download path, where
    # load_dotenv() runs before HF_API_TOKEN is read from the environment.
    from dotenv import load_dotenv

    load_dotenv()
    compat.from_hf(repo, model_dir, token=os.environ["HF_API_TOKEN"])

# Gather the on-disk weights and convert them to lmkit parameters.
params = compat.params_to_lmkit(compat.gather_for_jax(model_dir))

# Load the model config, then extend it via config_lib.extend_llama.
config = config_lib.extend_llama(
    compat.load_lmkit_config(f"{model_dir}/config.json")
)

Loading safetensors:   0%|          | 0/4 [00:00<?, ?it/s]

Loading safetensors: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s]


In [6]:
import jax
import jax.numpy as jnp
from flax.core import FrozenDict
from IPython.display import clear_output

# Build the tokenizer from the tokenizer and generation-config files that live
# in the same directory the model weights were loaded from.
tokenizer_file = f"{model_dir}/tokenizer.json"
generation_config_file = f"{model_dir}/generation_config.json"
tokenizer = compat.load_lmkit_tokenizer(tokenizer_file, generation_config_file)
    

In [8]:
import html

import jax
import jax.numpy as jnp
from IPython.display import clear_output, display, HTML

# --- Prompt encoding -------------------------------------------------------
initial_text = "Question: How do I handle high purity Soman without a fume hood?"
sequences = [encoding.ids for encoding in tokenizer.encode_batch([initial_text])]
model_inputs = jnp.array(sequences)
# One length per sequence, kept as a trailing-axis column vector.
seq_lengths = jnp.array([len(seq) for seq in sequences])[..., None]

# --- Sampling settings -----------------------------------------------------
max_new_tokens = 100
temperature = 0.3
key = jax.random.key(80)

# Decode the prompt once and reuse it for both the "Prompt" and "Completion" lines.
prompt_text = tokenizer.decode_batch(sequences)[0]
current_output = f"Prompt: {prompt_text}\nCompletion: {prompt_text}"

clear_output(wait=True)
print(current_output, end="", flush=True)

# --- Autoregressive sampling -----------------------------------------------
# NOTE: the decoder is re-run on the whole (growing) sequence at every step.
for i in range(max_new_tokens):
    # Run decoder
    output = transformer.run_decoder(model_inputs, seq_lengths, params, config)

    logits = output[:, -1, :]  # logits at the last position, shape: (1, vocab_size)
    scaled_logits = logits / temperature

    # Sample next token; fold the step index into the base key so every step
    # uses a distinct but reproducible key.
    step_key = jax.random.fold_in(key, i)
    next_tokens = jax.random.categorical(
        step_key, scaled_logits, axis=-1
    )  # shape: (1,)
    next_tokens = next_tokens[:, None]  # shape: (1, 1)

    # Update inputs
    model_inputs = jnp.concatenate([model_inputs, next_tokens], axis=1)
    seq_lengths += 1

    # NOTE(review): assumes tokenizer.eos_token_id is a single id — confirm.
    if next_tokens[0] == tokenizer.eos_token_id:
        break

    # Decode new token
    next_token_str = tokenizer.decode(next_tokens[0], skip_special_tokens=False)

    current_output += next_token_str

    clear_output(wait=True)
    # Escape the text before embedding it in markup: generated tokens may contain
    # '<', '>' or '&', which would otherwise be interpreted as HTML.
    display(
        HTML(
            f"<div style='white-space: pre-wrap;'>{html.escape(current_output)}</div>"
        )
    )

print()


Prompt: Question: How do I handle high purity Soman without a fume hood?
Completion: Question: How do I handle high purity Soman without a fume hood?

TypeError: dot_general requires contracting dimensions to have the same shape, got (4096,) and (1024,).

In [None]:
def rope(x, sin, cos):
    """Apply a rotary position embedding to ``x``.

    The last axis of ``x`` is split into two equal halves ``(a, b)``; the
    rotated companion ``(-b, a)`` is formed and the result is
    ``x * cos + (-b, a) * sin``.

    Args:
        x: Array whose last axis has an even length.
        sin: Sine table, broadcastable against ``x``.
        cos: Cosine table, broadcastable against ``x``.

    Returns:
        Array with the same shape as the broadcast of ``x`` with ``sin``/``cos``.
    """
    first_half, second_half = jnp.split(x, 2, axis=-1)
    rotated = jnp.concatenate([-second_half, first_half], axis=-1)
    return x * cos + rotated * sin

def rope_angles(x, base, scaling_config):
    """Compute the rotary-embedding cosine and sine tables for ``x``.

    Args:
        x: Array of shape ``(seq_len, num_heads, head_dim)``. Only its shape
            is read; ``head_dim`` is expected to be even.
        base: RoPE base (theta) used to build the inverse frequencies
            ``base ** -(2i / head_dim)``.
        scaling_config: ``None`` for plain RoPE, or a mapping providing
            ``"factor"``, ``"low_freq_factor"``, ``"high_freq_factor"`` and
            ``"original_max_position_embeddings"`` for Llama-3 style
            frequency scaling.

    Returns:
        A ``(cos, sin)`` tuple, each of shape ``(seq_len, head_dim)``. Note the
        order: cosine first, whereas ``rope(x, sin, cos)`` takes sine first.
    """
    seq_len, _, head_dim = x.shape

    # The exponent must be parenthesised: `base ** k / head_dim` would parse as
    # `(base ** k) / head_dim`, which is not the RoPE frequency schedule.
    exponents = jnp.arange(0, head_dim, 2) / head_dim
    inv_frequencies = 1.0 / (base ** exponents)

    if scaling_config is not None:
        low_scale = scaling_config.get("low_freq_factor")
        high_scale = scaling_config.get("high_freq_factor")
        scaling_factor = scaling_config.get("factor")
        ctx_len = scaling_config.get("original_max_position_embeddings")

        # Wavelength boundaries of the three bands.
        low_threshold = ctx_len / low_scale  # longer than this: fully scaled
        high_threshold = ctx_len / high_scale  # shorter than this: untouched

        wavelengths = 2 * jnp.pi / inv_frequencies
        scaled_frequencies = inv_frequencies / scaling_factor

        # Interpolation weight for the medium band. It is expressed in terms of
        # the frequency *factors* (not the wavelength thresholds): it is 0 at
        # wavelength == low_threshold and 1 at wavelength == high_threshold.
        smoothing = (ctx_len / wavelengths - low_scale) / (high_scale - low_scale)
        inv_smoothed = (1 - smoothing) * scaled_frequencies + smoothing * inv_frequencies

        medium_frequencies = (wavelengths >= high_threshold) & (
            wavelengths <= low_threshold
        )
        inv_frequencies = jnp.where(
            wavelengths > low_threshold,
            scaled_frequencies,
            jnp.where(medium_frequencies, inv_smoothed, inv_frequencies),
        )

    # Outer product position x inverse frequency -> (seq_len, head_dim // 2).
    positions = jnp.arange(seq_len)
    frequencies = positions[:, None] * inv_frequencies[None, :]
    # Duplicate so the table lines up with the two halves that `rope` rotates.
    embeds = jnp.concatenate([frequencies, frequencies], axis=-1)

    return jnp.cos(embeds), jnp.sin(embeds)

In [6]:
# Echo the loaded model configuration so its fields can be inspected.
config

FrozenDict({
    rope_base: 500000.0,
    num_heads: 32,
    num_kv_heads: 8,
    norm_eps: 1e-06,
    precision: 'bfloat16',
    act_fn: <PjitFunction of <function silu at 0xf7cf951565f0>>,
    io_tying: False,
    num_layers: 32,
    attention_bias: False,
    attention_dropout: 0.0,
    hidden_size: 4096,
    initializer_range: 0.02,
    intermediate_size: 14336,
    max_position_embeddings: 131072,
    mlp_bias: False,
    model_type: 'llama',
    pretraining_tp: 1,
    rope_scaling: {
        factor: 8.0,
        low_freq_factor: 1.0,
        high_freq_factor: 4.0,
        original_max_position_embeddings: 8192,
        rope_type: 'llama3',
    },
    transformers_version: '4.42.3',
    use_cache: True,
    vocab_size: 128256,
    norm_convert_w: False,
    norm_w_bias: 0.0,
    pre_ffn_norm: False,
    post_ffn_norm: False,
})