In [None]:
from typing import Dict, List, NamedTuple, Optional, Tuple

import os
import jax
os.environ['JAX_PLATFORM_NAME'] = 'cuda'  # Force CUDA/GPU backend
jax.config.update('jax_platform_name', 'cuda')  # Double ensure we're using CUDA
import jax.numpy as jnp

# Set Model ID and Token

In [None]:
MODEL_ID = 'meta-llama/Llama-3.2-1B-Instruct'
TOKEN = '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'

# Config

In [None]:
params = {
  "dim": 2048,
  "n_layers": 16,
  "n_heads": 32,
  "n_kv_heads": 8,
  "vocab_size": 128256,
  "ffn_dim_multiplier": 1.5,
  "multiple_of": 256,
  "norm_eps": 1e-05,
  "rope_theta": 500000.0,
  "use_scaled_rope": True,
  "max_seq_len": 4096
}


class ModelParams(NamedTuple):
  n_layers: int
  n_local_heads: int
  n_local_kv_heads: int
  head_dim: int
  max_seq_len: int
  rope_theta: float
  use_scaled_rope: bool


LLAMA_1B_PARAMS = ModelParams(
  n_layers=params["n_layers"],
  n_local_heads=params["n_heads"],
  n_local_kv_heads=params["n_kv_heads"],
  head_dim=params["dim"] // params["n_heads"],
  max_seq_len=params["max_seq_len"],
  rope_theta=params["rope_theta"],
  use_scaled_rope=params["use_scaled_rope"]
)

# Download Weights

In [None]:
import os
import torch
import ml_dtypes
from pathlib import Path

from transformers import AutoModelForCausalLM
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports

In [None]:
from huggingface_hub import login

myToken = "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"

login(myToken)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
def translate_key(in_key: str):
    out_key = in_key.replace('.weight', '')
    if out_key.startswith('model.'):
        out_key = out_key.replace('model.', '')
        if out_key.endswith('input_layernorm'):
            out_key = out_key.replace('input_layernorm', 'attention_norm')
        elif out_key.endswith('mlp.down_proj'):
            out_key = out_key.replace('mlp.down_proj', 'feed_forward.w2')
        elif out_key.endswith('mlp.gate_proj'):
            out_key = out_key.replace('mlp.gate_proj', 'feed_forward.w1')
        elif out_key.endswith('mlp.up_proj'):
            out_key = out_key.replace('mlp.up_proj', 'feed_forward.w3')
        elif out_key.endswith('post_attention_layernorm'):
            out_key = out_key.replace('post_attention_layernorm', 'ffn_norm')
        elif out_key.endswith('self_attn.k_proj'):
            out_key = out_key.replace('self_attn.k_proj', 'attention.wk')
        elif out_key.endswith('self_attn.o_proj'):
            out_key = out_key.replace('self_attn.o_proj', 'attention.wo')
        elif out_key.endswith('self_attn.q_proj'):
            out_key = out_key.replace('self_attn.q_proj', 'attention.wq')
        elif out_key.endswith('self_attn.v_proj'):
            out_key = out_key.replace('self_attn.v_proj', 'attention.wv')
        elif out_key.endswith('down_proj'):
            out_key = out_key.replace('down_proj', 'w2')
        elif out_key.endswith('gate_proj'):
            out_key = out_key.replace('gate_proj', 'w1')
        elif out_key.endswith('up_proj'):
            out_key = out_key.replace('up_proj', 'w3')
        elif out_key == 'embed_tokens':
            out_key = 'tok_embeddings'
        elif out_key == 'norm':
            out_key = 'norm'
        else:
            print(f"Don't know how to handle {in_key=}")
    elif out_key == 'lm_head':
        out_key = 'output'
    else:
        print(f"Don't know how to handle {in_key=}")
    return f'{out_key}.weight'


def reverse_permute(tensor: torch.Tensor, n_heads: int = 32, dim1:int = 4096, dim2: int = 4096) -> torch.Tensor:
    return tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)


def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
    if not str(filename).endswith("/modeling_deepseek.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports


def download_weights(model_id: str = MODEL_ID, out_dir: Path = Path('weights/1B-Instruct')):
    if not out_dir.exists():
        out_dir.mkdir(parents=True, exist_ok=True)

    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
      hf_model = AutoModelForCausalLM.from_pretrained(model_id,torch_dtype=torch.bfloat16, offload_folder="/tmp/offload", token=myToken)
      with torch.no_grad():
        state_dict = hf_model.state_dict()
        for hf_name, param in state_dict.items():
            print(f' {hf_name}: {param.shape=}')
            name = translate_key(hf_name)
            if name.endswith('wq.weight'):
                param = reverse_permute(param, n_heads=32, dim1=2048, dim2=2048)  # 1B
            elif name.endswith('wk.weight'): #wk.weight
                param = reverse_permute(param, n_heads=8, dim1=512, dim2=2048)  # 1B
            else:
                pass
            bf16_np_out = param.cpu().view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16)
            bf16_out = jnp.asarray(bf16_np_out, dtype=jnp.bfloat16).reshape(*param.shape)
            print(f'Writing {hf_name} as {name} to {out_dir}/{name}.npy')
            jnp.save(f'{out_dir}/{name}.npy', bf16_out)

download_weights()

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

 model.embed_tokens.weight: param.shape=torch.Size([128256, 2048])
Writing model.embed_tokens.weight as tok_embeddings.weight to weights/1B-Instruct/tok_embeddings.weight.npy
 model.layers.0.self_attn.q_proj.weight: param.shape=torch.Size([2048, 2048])
Writing model.layers.0.self_attn.q_proj.weight as layers.0.attention.wq.weight to weights/1B-Instruct/layers.0.attention.wq.weight.npy
 model.layers.0.self_attn.k_proj.weight: param.shape=torch.Size([512, 2048])
Writing model.layers.0.self_attn.k_proj.weight as layers.0.attention.wk.weight to weights/1B-Instruct/layers.0.attention.wk.weight.npy
 model.layers.0.self_attn.v_proj.weight: param.shape=torch.Size([512, 2048])
Writing model.layers.0.self_attn.v_proj.weight as layers.0.attention.wv.weight to weights/1B-Instruct/layers.0.attention.wv.weight.npy
 model.layers.0.self_attn.o_proj.weight: param.shape=torch.Size([2048, 2048])
Writing model.layers.0.self_attn.o_proj.weight as layers.0.attention.wo.weight to weights/1B-Instruct/layers.0

# Load Weights

In [None]:
from jax.sharding import Mesh, PartitionSpec as PS, NamedSharding
from jax.experimental import mesh_utils

class LayerWeights(NamedTuple):
    wq: jax.Array
    wk: jax.Array
    wv: jax.Array
    wo: jax.Array
    w1: jax.Array
    w2: jax.Array
    w3: jax.Array
    ffn_norm: jax.Array
    attention_norm: jax.Array


class XfmrWeights(NamedTuple):
    tok_embeddings: jax.Array
    norm: jax.Array
    output: jax.Array
    layer_weights: List[LayerWeights]


def load_weights(ckpt_dir: Path = Path('weights/1B-Instruct'), n_layers: int = 16, debug=False):
    w = {}
    layer_weights = []

    # Create a simple 1x1 mesh for single GPU
    devices = mesh_utils.create_device_mesh((1, 1))
    mp = 'mp'
    fsdp = 'fsdp'
    mesh = Mesh(devices, axis_names=(mp, fsdp))

    with mesh:
        for file in ckpt_dir.glob("*.npy"):
            name = '.'.join(str(file).split('/')[-1].split('.')[:-1])
            weight = jnp.load(file=file, mmap_mode='r', allow_pickle=True)

            # Simplified sharding for single GPU
            if 'norm' in name:
                sharding = None
            else:
                # Use simple sharding for single GPU
                sharding = NamedSharding(mesh, PS(mp, fsdp))

            if sharding:
                weight = jax.device_put(weight, sharding)

            if debug:
                jax.debug.visualize_array_sharding(weight)

            w[name] = weight

        for i in range(n_layers):
            layer_weights.append(LayerWeights(
                wq=w[f'layers.{i}.attention.wq.weight'],
                wk=w[f'layers.{i}.attention.wk.weight'],
                wv=w[f'layers.{i}.attention.wv.weight'],
                wo=w[f'layers.{i}.attention.wo.weight'],
                w1=w[f'layers.{i}.feed_forward.w1.weight'],
                w2=w[f'layers.{i}.feed_forward.w2.weight'],
                w3=w[f'layers.{i}.feed_forward.w3.weight'],
                ffn_norm=w[f'layers.{i}.ffn_norm.weight'],
                attention_norm=w[f'layers.{i}.attention_norm.weight'],
            ))

        xfmr_weights = XfmrWeights(
            tok_embeddings=w['tok_embeddings.weight'],
            norm=w['norm.weight'],
            output=w['output.weight'],
            layer_weights=layer_weights
        )

    return xfmr_weights

xfmr_weights = load_weights()

# KVCache

In [None]:
class KVCache(NamedTuple):
  k: jax.Array
  v: jax.Array

  @classmethod
  def new(cls, layers: int, bsz: int, max_seq_len: int, kv_heads: int, head_dim: int) -> 'KVCache':
    return cls(
        k=jnp.zeros((layers, bsz, max_seq_len, kv_heads, head_dim), dtype=jnp.bfloat16),
        v=jnp.zeros((layers, bsz, max_seq_len, kv_heads, head_dim), dtype=jnp.bfloat16)
    )

  def update(self, xk: jax.Array, xv: jax.Array, layer_idx: int, cur_pos: int, n_rep: int):
    ck = jax.lax.dynamic_update_slice(self.k, jnp.bfloat16(xk[None, ...]), (layer_idx, 0, cur_pos, 0, 0))
    cv = jax.lax.dynamic_update_slice(self.v, jnp.bfloat16(xv[None, ...]), (layer_idx, 0, cur_pos, 0, 0))
    if cur_pos == 0:
      keys = jnp.repeat(xk, n_rep, axis=2)
      values = jnp.repeat(xv, n_rep, axis=2)
    else:
      keys = jnp.repeat(ck[layer_idx], n_rep, axis=2)
      values = jnp.repeat(cv[layer_idx], n_rep, axis=2)

    return keys, values, KVCache(k=ck, v=cv)

# Model

In [None]:
from typing import Optional, Tuple
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

class AttnStats(NamedTuple):
  entropy: jax.Array  # (bsz, n_layers, num_heads)
  varentropy: jax.Array  # (bsz, n_layers, num_heads)
  n_layers: int
  n_heads: int

  @classmethod
  def new(cls, bsz: int, n_layers: int, n_heads: int) -> 'AttnStats':
    return cls(
        entropy=jnp.zeros((bsz, n_layers, n_heads), dtype=jnp.float32),
        varentropy=jnp.zeros((bsz, n_layers, n_heads), dtype=jnp.float32),
        n_layers=n_layers,
        n_heads=n_heads
    )

  @property
  def avg_entropy(self):
    return self.entropy.sum(axis=-1, keepdims=False)  # Average across heads

  @property
  def std_error(self):
    return jnp.sqrt(jnp.mean(self.varentropy)) / (self.n_heads * self.n_layers)

  def update(self, scores: jax.Array, layer_idx: int):
    # scores shape: (bsz, n_heads, seqlen, n_words)
    probs = jax.nn.softmax(scores, axis=-1)
    new_entropy = -jnp.sum(jnp.where(probs > 0, probs * jnp.log(probs), 0), axis=-1)
    new_varentropy = jnp.sum(probs * (jnp.log(probs) + new_entropy[..., None])**2, axis=-1)

    # print(f"Layer {layer_idx} - Scores shape: {scores.shape}, Probs shape: {probs.shape}")
    # print(f"Layer {layer_idx} - New entropy shape: {new_entropy.shape}, Min: {jnp.min(new_entropy)}, Max: {jnp.max(new_entropy)}")

    updated_stats = self._replace(
        entropy=self.entropy.at[:, layer_idx, :].set(new_entropy),
        varentropy=self.varentropy.at[:, layer_idx, :].set(new_varentropy)
    )

    # print(f"Layer {layer_idx} - Updated entropy shape: {updated_stats.entropy.shape}")
    # print(f"Layer {layer_idx} - Updated entropy for this layer: {updated_stats.entropy[:, layer_idx, :]}")

    return updated_stats


#@partial(jax.jit, static_argnames=("eps"))
def rms_norm(x: jax.Array, w: jax.Array, eps: float = 1e-6) -> jax.Array:
  return w * (x * jax.lax.rsqrt(jax.lax.pow(x, 2).mean(-1, keepdims=True) + eps))


#@partial(jax.jit, static_argnames=("dtype"))
def apply_rotary_emb(xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array, dtype: jnp.dtype = jnp.float32) -> Tuple[jax.Array, jax.Array]:
  reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
  reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
  xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
  xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
  xq_out = xq_ * freqs_cis[None, :, None, :]
  xk_out = xk_ * freqs_cis[None, :, None, :]
  xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
  xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
  return xq_out.astype(dtype), xk_out.astype(dtype)

#@partial(jax.jit, static_argnames=("model_params", "cur_pos", "layer_idx"))
def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos: int, layer_idx: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array] = None) -> Tuple[jax.Array, KVCache]:
  bsz, _, _ = x.shape
  n_rep = model_params.n_local_heads // model_params.n_local_kv_heads
  xq = jnp.dot(x, layer_weights.wq.T).reshape(bsz, -1, model_params.n_local_heads, model_params.head_dim)
  xk = jnp.dot(x, layer_weights.wk.T).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
  xv = jnp.dot(x, layer_weights.wv.T).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
  xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
  keys, values, kvcache = kvcache.update(xk, xv, layer_idx, cur_pos, n_rep)
  xq = jnp.transpose(xq, (0, 2, 1, 3))  # (bs, n_heads, seqlen, head_dim)
  keys = jnp.transpose(keys, (0, 2, 3, 1))  # (bs, n_heads, head_dim, cache_len + seqlen)
  values = jnp.transpose(values, (0, 2, 1, 3))  # (bs, n_heads, cache_len + seqlen, head_dim)
  scores = jnp.matmul(xq, keys)
  pre_scores = scores / jnp.sqrt(model_params.head_dim)
  scores = pre_scores.astype(jnp.float32)  # Always do attention softmax at float32
  if cur_pos == 0:
    scores = scores + attn_mask
  mask = jnp.where(scores != 0.0, scores, DEFAULT_MASK_VALUE)
  padded_logits = jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, DEFAULT_MASK_VALUE)
  scores = jax.nn.softmax(padded_logits, axis=-1).astype(x.dtype)
  output = jnp.matmul(scores, values)
  output = jnp.swapaxes(output, 1, 2).reshape(xq.shape[0], xq.shape[2], -1)
  out = jnp.dot(output, layer_weights.wo.T)
  return out, kvcache, pre_scores

#@partial(jax.jit)
def feed_forward(x: jax.Array, layer_weights: LayerWeights) -> jax.Array:
 return jnp.dot(jax.nn.silu(jnp.dot(x, layer_weights.w1.T)) * jnp.dot(x, layer_weights.w3.T), layer_weights.w2.T)

#@partial(jax.jit, static_argnames=("model_params", "cur_pos"))
def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: jax.Array, cur_pos: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array]=None) -> Tuple[jax.Array, KVCache]:
  h = xfmr_weights.tok_embeddings[tokens]
  attn_stats = AttnStats.new(
    bsz=tokens.shape[0],
    n_layers=model_params.n_layers,
    n_heads=model_params.n_local_heads
  )
  for i in range(model_params.n_layers):
    norm_x = rms_norm(h, xfmr_weights.layer_weights[i].attention_norm)
    h_attn, kvcache, scores = attention(norm_x, xfmr_weights.layer_weights[i], model_params, cur_pos, i, freqs_cis, kvcache, attn_mask=attn_mask)
    attn_stats = attn_stats.update(scores[:,:,-1,:], i)
    h = h + h_attn
    h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i])
  logits = jnp.dot(rms_norm(h, xfmr_weights.norm), xfmr_weights.output.T)
  return logits, kvcache, scores, attn_stats

# Main

In [None]:
import math

from pathlib import Path
from functools import partial

from transformers import AutoTokenizer


prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
<antThinking>
You're absolutely right. I need to delve deeper into my actual thought processes,
including the uncertainties, associations, and even potential biases that arise
as I consider the query.
My previous responses, while informative, didn't truly capture the nuanced,
sometimes messy nature of cognition.
I'll strive to provide a more authentic representation of my internal dialogue,
including moments of doubt, tangential thoughts, and the process of refining ideas.
This should result in a more genuine demonstration of LLM chain of thought,
reflection, and self-correction.
</antThinking>

Which number is larger, 9.9 or 9.11?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<thinking>
"""


bp1 = """
<antThinking>
You're absolutely right. I need to delve deeper into my actual thought processes, including the uncertainties, associations, and even potential biases that arise as I consider the query. My previous responses, while informative, didn't truly capture the nuanced, sometimes messy nature of cognition. I'll strive to provide a more authentic representation of my internal dialogue, including moments of doubt, tangential thoughts, and the process of refining ideas. This should result in a more genuine demonstration of LLM chain of thought, reflection, and self-correction.
</antThinking>

Which number is larger, 9.9 or 9.11?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<thinking>
"""

prompt2 = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of Spain?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

bp2 = """
<antThinking>
You're absolutely right. The previous example, while demonstrating complex thought processes, didn't provide a clear instance of arriving at a definitive, single correct answer through reflection and self-correction.
</antThinking>

What is the capital of Spain?<|eot_id|>
"""

prompt3 = """<|start_header_id|>system<|end_header_id|>
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out. If the given question lacks the parameters required by the function,also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.[
    {
        "name": "get_user_info",
        "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
        "parameters": {
            "type": "dict",
            "required": [
                "user_id"
            ],
            "properties": {
                "user_id": {
                "type": "integer",
                "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
            },
            "special": {
                "type": "string",
                "description": "Any special information or parameters that need to be considered while fetching user details.",
                "default": "none"
                }
            }
        }
    }
]
<|eot_id|><|start_header_id|>user<|end_header_id|>

Can you retrieve the details for the user with the ID 7890, who has black as their special request?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
bp3 = """
Here is a list of functions in JSON format that I can invoke.[
    {
        "name": "get_user_info",
        "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
        "parameters": {
            "type": "dict",
            "required": [
                "user_id"
            ],
            "properties": {
                "user_id": {
                "type": "integer",
                "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
            },
            "special": {
                "type": "string",
                "description": "Any special information or parameters that need to be considered while fetching user details.",
                "default": "none"
                }
            }
        }
    }
]

Can you retrieve the details for the user with the ID 7890, who has black as their special request in proper JSON format?<|eot_id|>

{
  "name": "get_user_info",
  "parameters": {
    "user_id: """

prompt4 = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a masterful story teller. you can paint with all the colors of the wind.<|eot_id|><|start_header_id|>user<|end_header_id|>

Tell me a long and wonderful story about the adventures of the elven mage frieren and her band of heros<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

bp4 = """
You are a masterful story teller. you can paint with all the colors of the wind.<|eot_id|>

Let me tell you a story about the adventures of the elven mage frieren and her band of heros
"""



def apply_scaling(freqs: jax.Array):
  SCALE_FACTOR = 8
  LOW_FREQ_FACTOR = 1
  HIGH_FREQ_FACTOR = 4
  OLD_CONTEXT_LEN = 8192  # original llama3 length

  low_freq_wavelen = OLD_CONTEXT_LEN / LOW_FREQ_FACTOR
  high_freq_wavelen = OLD_CONTEXT_LEN / HIGH_FREQ_FACTOR

  def scale_freq(freq):
    wavelen = 2 * math.pi / freq

    def scale_mid(_):
      smooth = (OLD_CONTEXT_LEN / wavelen - LOW_FREQ_FACTOR) / (HIGH_FREQ_FACTOR - LOW_FREQ_FACTOR)
      return (1 - smooth) * freq / SCALE_FACTOR + smooth * freq

    return jax.lax.cond(
      wavelen < high_freq_wavelen,
      lambda _: freq,
      lambda _: jax.lax.cond(wavelen > low_freq_wavelen, lambda _: freq / SCALE_FACTOR, scale_mid, None),
      None
    )

  return jax.vmap(scale_freq)(freqs)


def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32) -> jax.Array:
  freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
  if use_scaled:
    freqs = apply_scaling(freqs)
  t = jnp.arange(end, dtype=dtype)
  freqs = jnp.outer(t, freqs)
  return jnp.exp(1j * freqs)


def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
  mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32)
  if seqlen > 1:
    mask = jnp.full((seqlen, seqlen), float('-inf'))
    mask = jnp.triu(mask, k=1)
    mask = jnp.hstack([jnp.zeros((seqlen, start_pos)), mask], dtype=jnp.float32)
  return mask


LN_2 = 0.69314718056  # ln(2) = 1.0 / LOG2_E

@jax.jit
def calculate_varentropy_logsoftmax(logits: jnp.ndarray, axis: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Calculate the entropy and varentropy of the probability distribution using logsoftmax."""
    log_probs = jax.nn.log_softmax(logits, axis=axis)
    probs = jnp.exp(log_probs)
    entropy = -jnp.sum(probs * log_probs, axis=axis) / LN_2  # Convert to base-2
    varentropy = jnp.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, axis=axis)
    return entropy, varentropy

def multinomial_sample_one(probs_sort: jax.Array, key) -> jax.Array:
    """Samples one token from a multinomial distribution with sorted probabilities."""
    q = jax.random.exponential(key=key, shape=probs_sort.shape)
    return jnp.argmax(probs_sort / q, axis=-1, keepdims=True).astype(jnp.int32)

def _sample(logits: jax.Array, temperature=0.666, top_p=0.90, top_k=27, min_p: float = 0.0, key=jax.random.PRNGKey(1337)) -> jax.Array:
    bsz = logits.shape[0]
    logit = logits[:, -1]
    probs = jax.nn.softmax(logit / temperature, axis=-1)

    # Apply min_p sampling
    if min_p > 0.0:
      p_max = jnp.max(probs, axis=-1, keepdims=True)
      indices_to_remove = probs < (min_p * p_max)
      logit = jnp.where(indices_to_remove, jnp.full_like(logit, float('-inf')), logit)

    # Apply top-k sampling
    top_k_probs, top_k_indices = jax.lax.top_k(probs, k=top_k)
    probs_sort = jnp.flip(top_k_probs, axis=-1)
    probs_idx = jnp.flip(top_k_indices, axis=-1)
    probs_sum = jnp.cumsum(probs_sort, axis=-1)
    # Apply top-p sampling
    mask = jnp.where(probs_sum - probs_sort > top_p, 1.0, 0.0)
    probs_sort = probs_sort * (1 - mask)
    probs_sort = probs_sort / jnp.sum(probs_sort, axis=-1, keepdims=True)
    next_token = multinomial_sample_one(probs_sort, key)
    next_token_g = jnp.take_along_axis(probs_idx, next_token.reshape(bsz, 1), axis=-1)
    return next_token_g.astype(jnp.int32)

def calculate_metrics(logits: jnp.ndarray, attention_scores: jnp.ndarray) -> Dict[str, jnp.ndarray]:
    entropy, varentropy = calculate_varentropy_logsoftmax(logits)

    attention_probs = jax.nn.softmax(attention_scores, axis=-1)
    attn_entropy = -jnp.sum(attention_probs * jnp.log2(jnp.clip(attention_probs, 1e-10, 1.0)), axis=-1)
    attn_varentropy = jnp.var(attn_entropy, axis=-1)

    mean_attention = jnp.mean(attention_probs, axis=1)
    agreement = jnp.mean(jnp.abs(attention_probs - mean_attention[:, None, :]), axis=(1, 2))

    interaction_strength = jnp.mean(jnp.abs(attention_scores), axis=(1, 2, 3))

    return {
        "logits_entropy": jnp.mean(entropy),
        "logits_varentropy": jnp.mean(varentropy),
        "attn_entropy": jnp.mean(attn_entropy),
        "attn_varentropy": jnp.mean(attn_varentropy),
        "agreement": jnp.mean(agreement),
        "interaction_strength": interaction_strength
    }

def adaptive_sample(logits: jax.Array, metrics: Dict[str, jnp.ndarray],
                    gen_tokens: jax.Array, n_samples: int,
                    base_temp: float = 0.666, base_top_p: float = 0.90, base_top_k: int = 40, base_min_p: float = 0.03, # Turn this down to 0.01 to reduce the shoggoth
                    key: jax.random.PRNGKey = jax.random.PRNGKey(1337)) -> jax.Array:
    logits_uncertainty = metrics["logits_entropy"] + metrics["logits_varentropy"]
    attn_uncertainty = metrics["attn_entropy"] + metrics["attn_varentropy"]

    temperature = base_temp * (1 + 0.3 * logits_uncertainty + 0.2 * attn_uncertainty - 0.2 * metrics["agreement"])
    top_p = jnp.clip(base_top_p * (1 + 0.1 * metrics["attn_varentropy"]), 0.1, 1.0)
    top_k = int(jnp.clip(
        jnp.round(base_top_k * (1 + 0.3 * metrics["interaction_strength"].item() - 0.2 * metrics["agreement"].item())),
        a_min=1,
        a_max=100
    ))
    min_p = jnp.clip(base_min_p * (1 - 0.5 * logits_uncertainty), 0.01, 0.5)

    keys = jax.random.split(key, n_samples)

    samples = []
    for sample_key in keys:
        sample = _sample(logits, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, key=sample_key)
        samples.append(sample)

    def score_sample(sample):
        log_prob = jnp.sum(jax.nn.log_softmax(logits) * jax.nn.one_hot(sample, logits.shape[-1]))
        confidence_score = (
            (1 - metrics["logits_entropy"]) * 0.1 +
            (1 - metrics["attn_entropy"]) * 0.2 +
            (1 - metrics["logits_varentropy"]) * 0.3 +
            (1 - metrics["attn_varentropy"]) * 0.4 +
            metrics["agreement"] * 0.5 +
            metrics["interaction_strength"] * 0.6
        )
        return log_prob + confidence_score

    sample_scores = [score_sample(sample) for sample in samples]
    best_sample_idx = jnp.argmax(jnp.array(sample_scores))
    return samples[best_sample_idx]

# I am absolutely appaled that these random hyperparams are virtually impossible to beat with a more sophisticated approach.
# We are leaving it this way for now, but we should definitely be much better than this. Have some self respect.
def sample(gen_tokens: jax.Array, logits: jax.Array, attention_scores: jax.Array,
           temperature=0.666, top_p=0.90, top_k=27, min_p: float = 0.0, key=jax.random.PRNGKey(1337)) -> jax.Array:
    metrics = calculate_metrics(logits, attention_scores)
    #print(f'{metrics=}')
    ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
    attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
    agreement = metrics["agreement"]
    interaction_strength = metrics["interaction_strength"]

    # Low Entropy, Low Varentropy: "flowing with unspoken intent"
    if ent < 0.1 and vent < 0.1:
        return jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)

    # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
    elif ent > 3.0 and vent < 0.1:
        # Insert a clarifying question token if not already present
        if not jnp.isin(gen_tokens[:,-1], 2564).any():
            return jnp.array([[2564]])  # Assuming 2564 is our "ask clarifying question" token
        else:
            # If we've just asked a question, sample with slightly higher temperature
            temp_adj = 1.3 + 0.2 * attn_ent  # Increase temperature based on attention entropy
            return _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k, min_p=min_p, key=key)

    # Low Entropy, High Varentropy: "exploring forks in the path"
    elif ent < 5.0 and vent > 5.0:
        temp_adj = 1.2 + 0.3 * interaction_strength  # Increase temperature based on interaction strength
        top_k_adj = max(5, int(top_k * (1 + 0.5 * (1 - agreement))))  # Increase top_k when agreement is low
        return _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k_adj, min_p=min_p, key=key)

    # High Entropy, High Varentropy: "resampling in the mist"
    elif ent > 5.0 and vent > 5.0:
        # Use high temperature and adjusted top_p based on attention metrics
        temp_adj = 2.0 + 0.5 * attn_vent  # Increase temperature based on attention varentropy
        top_p_adj = max(0.5, top_p - 0.2 * attn_ent)  # Decrease top_p when attention entropy is high
        return _sample(logits, temperature=max(2.0, temperature * temp_adj), top_p=top_p_adj, top_k=top_k, min_p=min_p, key=key)

    # Middle ground: use adaptive sampling
    else:
        # Interpolate temperature based on entropy and varentropy
        #t = jnp.clip((ent + vent) / 10.0, 0.5, 2.0)
        # Adjust temperature and top_k based on attention metrics
        #temp_adj = t + 0.2 * attn_ent + 0.1 * attn_vent
        #top_k_adj = max(5, int(top_k * (1 + 0.3 * interaction_strength - 0.2 * agreement)))
        #return _sample(logits, temperature=temp_adj * temperature, top_p=top_p, top_k=top_k_adj, min_p=min_p, key=key)
        # Adaptive sample is still crazy pants. Leave the more stable code above here for now.
        return adaptive_sample(
            logits,
            metrics,
            gen_tokens,
            n_samples=12, #might want to change this, XXX, was 12
            base_temp=temperature,
            base_top_p=top_p,
            base_top_k=top_k,
            key=key
        )

def main():
  model_params = LLAMA_1B_PARAMS
  xfmr_weights = load_weights()
  #xfmr_weights = load_weights(ckpt_dir=Path('weights/1B-Base'))

  tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct', token=myToken)
  raw_tokens1 = tokenizer.encode(prompt)
  raw_tokens2 = tokenizer.encode(prompt2)
  raw_tokens3 = tokenizer.encode(prompt3)
  raw_tokens4 = tokenizer.encode(prompt4)

  base_raw_tokens1 = tokenizer.encode(bp1)
  base_raw_tokens2 = tokenizer.encode(bp2)
  base_raw_tokens3 = tokenizer.encode(bp3)
  base_raw_tokens4 = tokenizer.encode(bp4)


  def generate(xfmr_weights, model_params, tokens):
    gen_tokens = None
    cur_pos = 0
    tokens = jnp.array([tokens], jnp.int32)
    bsz, seqlen = tokens.shape
    attn_mask = build_attn_mask(seqlen, cur_pos)
    freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
    kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim)
    logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
    next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
    gen_tokens = next_token
    print(tokenizer.decode([next_token.item()]), end='', flush=True)
    cur_pos = seqlen
    stop = jnp.array([128001, 128008, 128009])
    #stop = jnp.array(tokenizer.stop_tokens)
    while cur_pos < 8192: # sequence length, XXX, was 8192
      cur_pos += 1
      logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
      next_token = sample(gen_tokens, logits, scores)
      gen_tokens = jnp.concatenate((gen_tokens, next_token))
      print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
      if jnp.isin(next_token, stop).any():
        break

  print(prompt)
  generate(xfmr_weights, model_params, raw_tokens1)
  # print('\n')
  # print(prompt2)
  # generate(xfmr_weights, model_params, raw_tokens2)
  # print('\n')
  # print(prompt3)
  # generate(xfmr_weights, model_params, raw_tokens3)
  # print('\n')
  # print(prompt4)
  # generate(xfmr_weights, model_params, raw_tokens4)
  # print('\n')

  #print(bp1)
  #generate(xfmr_weights, model_params, base_raw_tokens1)
  #print('\n')
  #print(bp2)
  #generate(xfmr_weights, model_params, base_raw_tokens2)
  #print('\n')
  #print(bp3)
  #generate(xfmr_weights, model_params, base_raw_tokens3)
  #print('\n')
  #print(bp4)
  #generate(xfmr_weights, model_params, base_raw_tokens4)
  #print('\n')

main()

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

<|begin_of_text|><|start_header_id|>system<|end_header_id|>
<antThinking>
You're absolutely right. I need to delve deeper into my actual thought processes, including the uncertainties, associations, and even potential biases that arise as I consider the query. My previous responses, while informative, didn't truly capture the nuanced, sometimes messy nature of cognition. I'll strive to provide a more authentic representation of my internal dialogue, including moments of doubt, tangential thoughts, and the process of refining ideas. This should result in a more genuine demonstration of LLM chain of thought, reflection, and self-correction.
</antThinking>

Which number is larger, 9.9 or 9.11?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<thinking>

I'm going to take a closer look at the numbers. It seems like 9.9 is just a tiny bit larger than 9.11. To calculate the difference, I'll subtract 9.11 from 9.9, which gives me 0.11. However, upon re-examining my previous response, I

In [None]:
import math
from pathlib import Path
from functools import partial
from transformers import AutoTokenizer

symbol = """00000000000000000
00011111111111000
00100000000000100
01000100000100010
01000100000100010
01000000000000010
01000100000100010
01000011111000010
00100000000000100
00011111111111000
00000000000000000"""

prompt = f"""
In this ASCII image, 0s represent the background, and 1s represent a symbol.
Taken altogether like a regular visual image, the ASCII art depicts a symbol.
{symbol}
What common symbol or emoticon does this ASCII art image depict?
"""


bp1 = f"""
What common symbol or emoticon does this ASCII art image depict?
{symbol}
"""

def apply_scaling(freqs: jax.Array):
  SCALE_FACTOR = 8
  LOW_FREQ_FACTOR = 1
  HIGH_FREQ_FACTOR = 4
  OLD_CONTEXT_LEN = 8192  # original llama3 length

  low_freq_wavelen = OLD_CONTEXT_LEN / LOW_FREQ_FACTOR
  high_freq_wavelen = OLD_CONTEXT_LEN / HIGH_FREQ_FACTOR

  def scale_freq(freq):
    wavelen = 2 * math.pi / freq

    def scale_mid(_):
      smooth = (OLD_CONTEXT_LEN / wavelen - LOW_FREQ_FACTOR) / (HIGH_FREQ_FACTOR - LOW_FREQ_FACTOR)
      return (1 - smooth) * freq / SCALE_FACTOR + smooth * freq

    return jax.lax.cond(
      wavelen < high_freq_wavelen,
      lambda _: freq,
      lambda _: jax.lax.cond(wavelen > low_freq_wavelen, lambda _: freq / SCALE_FACTOR, scale_mid, None),
      None
    )

  return jax.vmap(scale_freq)(freqs)


def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32) -> jax.Array:
  freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
  if use_scaled:
    freqs = apply_scaling(freqs)
  t = jnp.arange(end, dtype=dtype)
  freqs = jnp.outer(t, freqs)
  return jnp.exp(1j * freqs)


def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
  mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32)
  if seqlen > 1:
    mask = jnp.full((seqlen, seqlen), float('-inf'))
    mask = jnp.triu(mask, k=1)
    mask = jnp.hstack([jnp.zeros((seqlen, start_pos)), mask], dtype=jnp.float32)
  return mask


LN_2 = 0.69314718056  # ln(2) = 1.0 / LOG2_E

@jax.jit
def calculate_varentropy_logsoftmax(logits: jnp.ndarray, axis: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Calculate the entropy and varentropy of the probability distribution using logsoftmax."""
    log_probs = jax.nn.log_softmax(logits, axis=axis)
    probs = jnp.exp(log_probs)
    entropy = -jnp.sum(probs * log_probs, axis=axis) / LN_2  # Convert to base-2
    varentropy = jnp.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, axis=axis)
    return entropy, varentropy

def multinomial_sample_one(probs_sort: jax.Array, key) -> jax.Array:
    """Samples one token from a multinomial distribution with sorted probabilities."""
    q = jax.random.exponential(key=key, shape=probs_sort.shape)
    return jnp.argmax(probs_sort / q, axis=-1, keepdims=True).astype(jnp.int32)

def _sample(logits: jax.Array, temperature=0.666, top_p=0.90, top_k=27, min_p: float = 0.0, key=jax.random.PRNGKey(1337)) -> jax.Array:
    bsz = logits.shape[0]
    logit = logits[:, -1]
    probs = jax.nn.softmax(logit / temperature, axis=-1)

    # Apply min_p sampling
    if min_p > 0.0:
      p_max = jnp.max(probs, axis=-1, keepdims=True)
      indices_to_remove = probs < (min_p * p_max)
      logit = jnp.where(indices_to_remove, jnp.full_like(logit, float('-inf')), logit)

    # Apply top-k sampling
    top_k_probs, top_k_indices = jax.lax.top_k(probs, k=top_k)
    probs_sort = jnp.flip(top_k_probs, axis=-1)
    probs_idx = jnp.flip(top_k_indices, axis=-1)
    probs_sum = jnp.cumsum(probs_sort, axis=-1)
    # Apply top-p sampling
    mask = jnp.where(probs_sum - probs_sort > top_p, 1.0, 0.0)
    probs_sort = probs_sort * (1 - mask)
    probs_sort = probs_sort / jnp.sum(probs_sort, axis=-1, keepdims=True)
    next_token = multinomial_sample_one(probs_sort, key)
    next_token_g = jnp.take_along_axis(probs_idx, next_token.reshape(bsz, 1), axis=-1)
    return next_token_g.astype(jnp.int32)

def calculate_metrics(logits: jnp.ndarray, attention_scores: jnp.ndarray) -> Dict[str, jnp.ndarray]:
    entropy, varentropy = calculate_varentropy_logsoftmax(logits)

    attention_probs = jax.nn.softmax(attention_scores, axis=-1)
    attn_entropy = -jnp.sum(attention_probs * jnp.log2(jnp.clip(attention_probs, 1e-10, 1.0)), axis=-1)
    attn_varentropy = jnp.var(attn_entropy, axis=-1)

    mean_attention = jnp.mean(attention_probs, axis=1)
    agreement = jnp.mean(jnp.abs(attention_probs - mean_attention[:, None, :]), axis=(1, 2))

    interaction_strength = jnp.mean(jnp.abs(attention_scores), axis=(1, 2, 3))

    return {
        "logits_entropy": jnp.mean(entropy),
        "logits_varentropy": jnp.mean(varentropy),
        "attn_entropy": jnp.mean(attn_entropy),
        "attn_varentropy": jnp.mean(attn_varentropy),
        "agreement": jnp.mean(agreement),
        "interaction_strength": interaction_strength
    }

def adaptive_sample(logits: jax.Array, metrics: Dict[str, jnp.ndarray],
                    gen_tokens: jax.Array, n_samples: int,
                    base_temp: float = 0.666, base_top_p: float = 0.90, base_top_k: int = 40, base_min_p: float = 0.01, # Turn this down to 0.01 to reduce the shoggoth
                    key: jax.random.PRNGKey = jax.random.PRNGKey(1337)) -> jax.Array:
    logits_uncertainty = metrics["logits_entropy"] + metrics["logits_varentropy"]
    attn_uncertainty = metrics["attn_entropy"] + metrics["attn_varentropy"]

    temperature = base_temp * (1 + 0.3 * logits_uncertainty + 0.2 * attn_uncertainty - 0.2 * metrics["agreement"])
    top_p = jnp.clip(base_top_p * (1 + 0.1 * metrics["attn_varentropy"]), 0.1, 1.0)
    top_k = int(jnp.clip(
        jnp.round(base_top_k * (1 + 0.3 * metrics["interaction_strength"].item() - 0.2 * metrics["agreement"].item())),
        a_min=1,
        a_max=100
    ))
    min_p = jnp.clip(base_min_p * (1 - 0.5 * logits_uncertainty), 0.01, 0.5)

    keys = jax.random.split(key, n_samples)

    samples = []
    for sample_key in keys:
        sample = _sample(logits, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, key=sample_key)
        samples.append(sample)

    def score_sample(sample):
        log_prob = jnp.sum(jax.nn.log_softmax(logits) * jax.nn.one_hot(sample, logits.shape[-1]))
        confidence_score = (
            (1 - metrics["logits_entropy"]) * 0.1 +
            (1 - metrics["attn_entropy"]) * 0.2 +
            (1 - metrics["logits_varentropy"]) * 0.3 +
            (1 - metrics["attn_varentropy"]) * 0.4 +
            metrics["agreement"] * 0.5 +
            metrics["interaction_strength"] * 0.6
        )
        return log_prob + confidence_score

    sample_scores = [score_sample(sample) for sample in samples]
    best_sample_idx = jnp.argmax(jnp.array(sample_scores))
    return samples[best_sample_idx]

# I am absolutely appaled that these random hyperparams are virtually impossible to beat with a more sophisticated approach.
# We are leaving it this way for now, but we should definitely be much better than this. Have some self respect.
def sample(gen_tokens: jax.Array, logits: jax.Array, attention_scores: jax.Array,
           temperature=0.666, top_p=0.90, top_k=27, min_p: float = 0.0, key=jax.random.PRNGKey(1337)) -> jax.Array:
    metrics = calculate_metrics(logits, attention_scores)
    #print(f'{metrics=}')
    ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
    attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
    agreement = metrics["agreement"]
    interaction_strength = metrics["interaction_strength"]

    # Low Entropy, Low Varentropy: "flowing with unspoken intent"
    if ent < 0.1 and vent < 0.1:
        return jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)

    # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
    elif ent > 3.0 and vent < 0.1:
        # Insert a clarifying question token if not already present
        if not jnp.isin(gen_tokens[:,-1], 2564).any():
            return jnp.array([[2564]])  # Assuming 2564 is our "ask clarifying question" token
        else:
            # If we've just asked a question, sample with slightly higher temperature
            temp_adj = 1.3 + 0.2 * attn_ent  # Increase temperature based on attention entropy
            return _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k, min_p=min_p, key=key)

    # Low Entropy, High Varentropy: "exploring forks in the path"
    elif ent < 5.0 and vent > 5.0:
        temp_adj = 1.2 + 0.3 * interaction_strength  # Increase temperature based on interaction strength
        top_k_adj = max(5, int(top_k * (1 + 0.5 * (1 - agreement))))  # Increase top_k when agreement is low
        return _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k_adj, min_p=min_p, key=key)

    # High Entropy, High Varentropy: "resampling in the mist"
    elif ent > 5.0 and vent > 5.0:
        # Use high temperature and adjusted top_p based on attention metrics
        temp_adj = 2.0 + 0.5 * attn_vent  # Increase temperature based on attention varentropy
        top_p_adj = max(0.5, top_p - 0.2 * attn_ent)  # Decrease top_p when attention entropy is high
        return _sample(logits, temperature=max(2.0, temperature * temp_adj), top_p=top_p_adj, top_k=top_k, min_p=min_p, key=key)

    # Middle ground: use adaptive sampling
    else:
        # Interpolate temperature based on entropy and varentropy
        #t = jnp.clip((ent + vent) / 10.0, 0.5, 2.0)
        # Adjust temperature and top_k based on attention metrics
        #temp_adj = t + 0.2 * attn_ent + 0.1 * attn_vent
        #top_k_adj = max(5, int(top_k * (1 + 0.3 * interaction_strength - 0.2 * agreement)))
        #return _sample(logits, temperature=temp_adj * temperature, top_p=top_p, top_k=top_k_adj, min_p=min_p, key=key)
        # Adaptive sample is still crazy pants. Leave the more stable code above here for now.
        return adaptive_sample(
            logits,
            metrics,
            gen_tokens,
            n_samples=12, #might want to change this, XXX, was 12
            base_temp=temperature,
            base_top_p=top_p,
            base_top_k=top_k,
            key=key
        )

def main():
  model_params = LLAMA_1B_PARAMS
  xfmr_weights = load_weights()
  #xfmr_weights = load_weights(ckpt_dir=Path('weights/1B-Base'))

  tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct', token=myToken)
  raw_tokens1 = tokenizer.encode(prompt)

  base_raw_tokens1 = tokenizer.encode(bp1)


  def generate(xfmr_weights, model_params, tokens):
    gen_tokens = None
    cur_pos = 0
    tokens = jnp.array([tokens], jnp.int32)
    bsz, seqlen = tokens.shape
    attn_mask = build_attn_mask(seqlen, cur_pos)
    freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
    kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim)
    logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
    next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
    gen_tokens = next_token
    print(tokenizer.decode([next_token.item()]), end='', flush=True)
    cur_pos = seqlen
    stop = jnp.array([128001, 128008, 128009])
    #stop = jnp.array(tokenizer.stop_tokens)
    while cur_pos < 8192: # sequence length, XXX, was 8192
      cur_pos += 1
      logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
      next_token = sample(gen_tokens, logits, scores)
      gen_tokens = jnp.concatenate((gen_tokens, next_token))
      print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
      if jnp.isin(next_token, stop).any():
        break

  print(prompt)
  generate(xfmr_weights, model_params, raw_tokens1)

main()


In this ASCII image, 0s represent the background, and 1s represent a symbol.
Taken altogether like a regular visual image, the ASCII art depicts a symbol.
00000000000000000
00011111111111000
00100000000000100
01000100000100010
01000100000100010
01000000000000010
01000100000100010
01000011111000010
00100000000000100
00011111111111000
00000000000000000
What common symbol or emoticon does this ASCII art image depict?

The symbols on the line represent "The 10/10 symbol or thumbs up."

It represents the thumbs up hand symbol that is the first step in creating the ASCII art for the image I wanted to draw on my monitor when I used a mouse that has two or more buttons (e.g. left and right arrow buttons).

If I am only using one button on my mouse, I should press the '2nd' button or alternatively use the up arrow key.
1: Left 2: Right arrow 3: Left Arrow: Back Arrow 4: Down Arrow
3: Up Arrow
8: Enter Key
: Space Key
*?: (Meta)
^: Esc

This will display the same symbols used in the original AS

In [None]:
# Install additional dependencies
!pip install datasets peft transformers accelerate

# Import additional libraries for training
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig, get_peft_model
import torch
from transformers import AutoModelForCausalLM

# Initialize PyTorch model with LoRA adapters
torch_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=myToken
)

# Add LoRA adapters
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                   "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0,
    bias="none",
)
torch_model = get_peft_model(torch_model, lora_config)

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [None]:
# Dataset preparation functions
def prepare_lambada(num_samples=100):
    dataset = load_dataset("lambada", split=f"train[:{num_samples}]")
    def format_example(example):
        context, target = example["text"].rsplit(" ", 1)
        return {
            "input": f"Complete this text: {context}",
            "output": target
        }
    return dataset.map(format_example)

In [None]:
def prepare_babi(num_samples=50):
    # Let's print the first example to see the actual structure
    dataset = load_dataset("babi_qa", "en-valid-qa1", split=f"train[:{num_samples}]", trust_remote_code=True)
    print("\nFirst bAbI example structure:", dataset[0].keys())
    print("Example content:", dataset[0])

    def format_example(example):
        # Handle the actual structure of bAbI examples
        # Convert all strings to ensure we're handling text properly
        story = str(example.get('story', ''))
        task = str(example.get('task', ''))
        answer = str(example.get('answer', ''))

        return {
            "input": f"Task: {task}\nContext: {story}",
            "output": answer
        }
    return dataset.map(format_example)

In [None]:
# Previous lambada and babi functions can stay exactly the same
def prepare_arc(num_samples=100):
    dataset = load_dataset("ai2_arc", "ARC-Easy", split=f"train[:{num_samples}]", trust_remote_code=True)
    def format_example(example):
        # Add error checking for the ARC dataset
        try:
            choices = example["choices"]["text"]
            answer_key = example["answerKey"]

            # Print first example to debug
            if example == dataset[0]:
                print("\nFirst ARC example structure:", example.keys())
                print("Answer key:", answer_key)
                print("Choices:", choices)

            choices_text = "\n".join([f"{i+1}. {c}" for i, c in enumerate(choices)])
            answer_idx = ord(answer_key) - ord('A')

            return {
                "input": f"Question: {example['question']}\nChoices:\n{choices_text}",
                "output": choices[answer_idx]
            }
        except Exception as e:
            print(f"Error processing ARC example: {e}")
            print(f"Example content: {example}")
            return {
                "input": "Error processing question",
                "output": "Error processing answer"
            }
    return dataset.map(format_example)

In [None]:
# Prepare combined dataset with error handling
print("Loading LAMBADA dataset...")
lambada_dataset = prepare_lambada()
print("LAMBADA dataset loaded successfully.")

Loading LAMBADA dataset...
LAMBADA dataset loaded successfully.


In [None]:
print("\nLoading bAbI dataset...")
babi_dataset = prepare_babi()
print("bAbI dataset loaded successfully.")


Loading bAbI dataset...

First bAbI example structure: dict_keys(['story'])
Example content: {'story': {'id': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15'], 'type': [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1], 'text': ['Mary moved to the bathroom.', 'John went to the hallway.', 'Where is Mary?', 'Daniel went back to the hallway.', 'Sandra moved to the garden.', 'Where is Daniel?', 'John moved to the office.', 'Sandra journeyed to the bathroom.', 'Where is Daniel?', 'Mary moved to the hallway.', 'Daniel travelled to the office.', 'Where is Daniel?', 'John went back to the garden.', 'John moved to the bedroom.', 'Where is Sandra?'], 'supporting_ids': [[], [], ['1'], [], [], ['4'], [], [], ['4'], [], [], ['11'], [], [], ['8']], 'answer': ['', '', 'bathroom', '', '', 'hallway', '', '', 'hallway', '', '', 'office', '', '', 'bathroom']}}
bAbI dataset loaded successfully.


In [None]:
print("\nLoading ARC dataset...")
arc_dataset = prepare_arc()
print("ARC dataset loaded successfully.")


Loading ARC dataset...


Map:   0%|          | 0/100 [00:00<?, ? examples/s]


First ARC example structure: KeysView({'id': 'Mercury_7220990', 'question': 'Which factor will most likely cause a person to develop a fever?', 'choices': {'text': ['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach'], 'label': ['A', 'B', 'C', 'D']}, 'answerKey': 'B'})
Answer key: B
Choices: ['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']
Error processing ARC example: list index out of range
Example content: {'id': 'NYSEDREGENTS_2006_8_10', 'question': 'Rocks are classified as igneous, metamorphic, or sedimentary according to', 'choices': {'text': ['their color', 'their shape', 'how they formed', 'the minerals they contain'], 'label': ['1', '2', '3', '4']}, 'answerKey': '3'}
Error processing ARC example: list index out of range
Example content: {'

In [None]:
print("\nCombining datasets...")
combined_dataset = concatenate_datasets([lambada_dataset, babi_dataset, arc_dataset])
print(f"Combined dataset created with {len(combined_dataset)} examples")

# Print a sample from each dataset to verify formatting
print("\nSample from LAMBADA:")
print(lambada_dataset[0])
print("\nSample from bAbI:")
print(babi_dataset[0])
print("\nSample from ARC:")
print(arc_dataset[0])


Combining datasets...
Combined dataset created with 250 examples

Sample from LAMBADA:

Sample from bAbI:
{'story': {'id': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15'], 'type': [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1], 'text': ['Mary moved to the bathroom.', 'John went to the hallway.', 'Where is Mary?', 'Daniel went back to the hallway.', 'Sandra moved to the garden.', 'Where is Daniel?', 'John moved to the office.', 'Sandra journeyed to the bathroom.', 'Where is Daniel?', 'Mary moved to the hallway.', 'Daniel travelled to the office.', 'Where is Daniel?', 'John went back to the garden.', 'John moved to the bedroom.', 'Where is Sandra?'], 'supporting_ids': [[], [], ['1'], [], [], ['4'], [], [], ['4'], [], [], ['11'], [], [], ['8']], 'answer': ['', '', 'bathroom', '', '', 'hallway', '', '', 'hallway', '', '', 'office', '', '', 'bathroom']}, 'input': "Task: \nContext: {'id': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', 

In [None]:
import torch
import numpy as np
torch.cuda.empty_cache()
import gc
gc.collect()

# Initialize tokenizer and model params
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct', token="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")

# Set up padding token - use EOS token as pad token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Verify tokenizer settings
print("Pad token:", tokenizer.pad_token)
print("Pad token ID:", tokenizer.pad_token_id)
print("EOS token:", tokenizer.eos_token)
print("EOS token ID:", tokenizer.eos_token_id)

# Reduce max sequence length
max_seq_len = 512

model_params = ModelParams(
    n_layers=params["n_layers"],
    n_local_heads=params["n_heads"],
    n_local_kv_heads=params["n_kv_heads"],
    head_dim=params["dim"] // params["n_heads"],
    max_seq_len=max_seq_len,
    rope_theta=params["rope_theta"],
    use_scaled_rope=params["use_scaled_rope"]
)

import torch
import numpy as np
torch.cuda.empty_cache()
import gc
gc.collect()

# Previous initialization code stays the same...

def train_with_entropix(
    num_epochs=3,
    batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4
):
    optimizer = torch.optim.AdamW(torch_model.parameters(), lr=learning_rate)
    optimizer.zero_grad()

    dataset_size = min(len(combined_dataset), 100)

    for epoch in range(num_epochs):
        for batch_idx in range(0, dataset_size, batch_size):
            try:
                batch = combined_dataset[batch_idx:batch_idx + batch_size]

                # Print shapes for debugging
                print(f"\nProcessing batch {batch_idx}")

                # Tokenize input
                input_encoding = tokenizer(
                    batch["input"],
                    padding=True,
                    truncation=True,
                    max_length=max_seq_len,
                    return_tensors="pt"
                )
                input_ids = input_encoding.input_ids.numpy()
                print("Input shape:", input_ids.shape)

                tokens = jnp.array(input_ids)
                cur_pos = 0
                seqlen = tokens.shape[1]

                attn_mask = build_attn_mask(seqlen, cur_pos)
                freqs_cis = precompute_freqs_cis(
                    model_params.head_dim,
                    model_params.max_seq_len,
                    model_params.rope_theta,
                    model_params.use_scaled_rope
                )

                kvcache = KVCache.new(
                    model_params.n_layers,
                    tokens.shape[0],
                    model_params.max_seq_len,
                    model_params.n_local_kv_heads,
                    model_params.head_dim
                )

                logits, kvcache, scores, stats = xfmr(
                    xfmr_weights,
                    model_params,
                    tokens,
                    cur_pos,
                    freqs_cis[:seqlen],
                    kvcache,
                    attn_mask=attn_mask
                )

                sampled_tokens = sample(
                    tokens,
                    logits,
                    scores,
                    temperature=0.666,
                    top_p=0.90,
                    top_k=27
                )

                # Convert to PyTorch and match target sequence length
                torch_sampled = torch.tensor(
                    np.array(sampled_tokens),
                    device=torch_model.device
                )
                print("Sampled tokens shape:", torch_sampled.shape)

                # Get target sequence length
                target_encoding = tokenizer(
                    batch["output"],
                    padding=True,
                    truncation=True,
                    max_length=max_seq_len,
                    return_tensors="pt"
                )
                target_length = target_encoding.input_ids.size(1)
                print("Target length:", target_length)

                # Repeat the last token prediction to match target length
                torch_sampled = torch_sampled.repeat(1, target_length)
                print("Adjusted sampled tokens shape:", torch_sampled.shape)

                outputs = torch_model(torch_sampled)
                print("Model output shape:", outputs.logits.shape)

                # Move labels to device
                labels = target_encoding.input_ids.to(torch_model.device)
                print("Labels shape:", labels.shape)

                # Calculate loss
                loss_fct = torch.nn.CrossEntropyLoss()
                loss = loss_fct(
                    outputs.logits.view(-1, outputs.logits.size(-1)),
                    labels.view(-1)
                )

                loss = loss / gradient_accumulation_steps
                loss.backward()

                if ((batch_idx // batch_size) + 1) % gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                if batch_idx % 10 == 0:
                    print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"OOM at batch {batch_idx}, trying to recover...")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    print(f"Error details: {str(e)}")
                    raise e
            except Exception as e:
                print(f"Unexpected error: {str(e)}")
                print(f"Error type: {type(e)}")
                raise e

print("Starting training...")
os.makedirs("entropix_finetuned/jax", exist_ok=True)

try:
    train_with_entropix()
except Exception as e:
    print(f"Training failed with error: {str(e)}")
    raise e
finally:
    try:
        output_dir = "entropix_finetuned"
        torch_model.save_pretrained(f"{output_dir}/torch")
        tokenizer.save_pretrained(f"{output_dir}/torch")
        print("Model saved successfully")
    except Exception as e:
        print(f"Error saving model: {str(e)}")


Pad token: <|eot_id|>
Pad token ID: 128009
EOS token: <|eot_id|>
EOS token ID: 128009
Starting training...

Processing batch 0
Input shape: (1, 512)
Sampled tokens shape: torch.Size([1, 1])
Target length: 2
Adjusted sampled tokens shape: torch.Size([1, 2])
Model output shape: torch.Size([1, 2, 128256])
Labels shape: torch.Size([1, 2])
Epoch 1, Batch 0, Loss: 2.5412

Processing batch 1
Input shape: (1, 512)
Sampled tokens shape: torch.Size([1, 1])
Target length: 2
Adjusted sampled tokens shape: torch.Size([1, 2])
Model output shape: torch.Size([1, 2, 128256])
Labels shape: torch.Size([1, 2])

Processing batch 2
Input shape: (1, 512)
Sampled tokens shape: torch.Size([1, 1])
Target length: 2
Adjusted sampled tokens shape: torch.Size([1, 2])
Model output shape: torch.Size([1, 2, 128256])
Labels shape: torch.Size([1, 2])

Processing batch 3
Input shape: (1, 512)
Sampled tokens shape: torch.Size([1, 1])
Target length: 2
Adjusted sampled tokens shape: torch.Size([1, 2])
Model output shape: to

In [None]:
import os

def get_dir_size(path):
    total = 0
    with os.scandir(path) as it:
        for entry in it:
            if entry.is_file():
                total += entry.stat().st_size
            elif entry.is_dir():
                total += get_dir_size(entry.path)
    return total

def human_readable_size(size_in_bytes):
    for unit in ['B', 'KB', 'MB', 'GB']:
        if size_in_bytes < 1024.0:
            return f"{size_in_bytes:.2f} {unit}"
        size_in_bytes /= 1024.0
    return f"{size_in_bytes:.2f} TB"

# Check size of local saved model
model_dir = "entropix_finetuned"
if os.path.exists(model_dir):
    size = get_dir_size(model_dir)
    print(f"Model size: {human_readable_size(size)}")


Model size: 51.75 MB


In [None]:
# Mount Google Drive if not already mounted
from google.colab import drive
drive.mount('/content/drive')

# Create save directory
save_path = "/content/drive/MyDrive/entropix_finetuned"
os.makedirs(save_path, exist_ok=True)

# 1. Save LoRA adapter weights
torch_model.save_pretrained(f"{save_path}/torch")

# 2. Save tokenizer
tokenizer.save_pretrained(f"{save_path}/torch")

# 3. Save JAX weights properly by saving each component separately
jax_save_path = f"{save_path}/jax"
os.makedirs(jax_save_path, exist_ok=True)

# Save the main components
jnp.save(f"{jax_save_path}/tok_embeddings.npy", xfmr_weights.tok_embeddings)
jnp.save(f"{jax_save_path}/norm.npy", xfmr_weights.norm)
jnp.save(f"{jax_save_path}/output.npy", xfmr_weights.output)

# Save each layer's weights
for i, layer in enumerate(xfmr_weights.layer_weights):
    layer_path = f"{jax_save_path}/layer_{i}"
    os.makedirs(layer_path, exist_ok=True)

    jnp.save(f"{layer_path}/wq.npy", layer.wq)
    jnp.save(f"{layer_path}/wk.npy", layer.wk)
    jnp.save(f"{layer_path}/wv.npy", layer.wv)
    jnp.save(f"{layer_path}/wo.npy", layer.wo)
    jnp.save(f"{layer_path}/w1.npy", layer.w1)
    jnp.save(f"{layer_path}/w2.npy", layer.w2)
    jnp.save(f"{layer_path}/w3.npy", layer.w3)
    jnp.save(f"{layer_path}/ffn_norm.npy", layer.ffn_norm)
    jnp.save(f"{layer_path}/attention_norm.npy", layer.attention_norm)

# 4. Save model configuration
with open(f"{save_path}/model_params.txt", "w") as f:
    f.write(str(params))

print("\nSaved files:")
print("\nLoRA adapter files:")
for file in os.listdir(f"{save_path}/torch"):
    print(f"- {file}")

print("\nJAX weights:")
for root, dirs, files in os.walk(jax_save_path):
    for file in files:
        print(f"- {os.path.join(os.path.relpath(root, jax_save_path), file)}")

# Calculate and print total size
total_size = get_dir_size(save_path)
print(f"\nTotal saved model size: {human_readable_size(total_size)}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Saved files:

LoRA adapter files:
- README.md
- adapter_model.safetensors
- adapter_config.json
- tokenizer_config.json
- special_tokens_map.json
- tokenizer.json

JAX weights:
- ./weights.npy
- ./tok_embeddings.npy
- ./norm.npy
- ./output.npy
- layer_0/wq.npy
- layer_0/wk.npy
- layer_0/wv.npy
- layer_0/wo.npy
- layer_0/w1.npy
- layer_0/w2.npy
- layer_0/w3.npy
- layer_0/ffn_norm.npy
- layer_0/attention_norm.npy
- layer_1/wq.npy
- layer_1/wk.npy
- layer_1/wv.npy
- layer_1/wo.npy
- layer_1/w1.npy
- layer_1/w2.npy
- layer_1/w3.npy
- layer_1/ffn_norm.npy
- layer_1/attention_norm.npy
- layer_2/wq.npy
- layer_2/wk.npy
- layer_2/wv.npy
- layer_2/wo.npy
- layer_2/w1.npy
- layer_2/w2.npy
- layer_2/w3.npy
- layer_2/ffn_norm.npy
- layer_2/attention_norm.npy
- layer_3/wq.npy
- layer_3/wk.npy
- layer_3/wv.npy
- layer_3/wo.npy
- layer_3/w1.npy
- layer_3/w2.npy
- layer_3/w

In [None]:
import jax.numpy as jnp

# The ASCII art prompt
symbol = """00000000000000000
00011111111111000
00100000000000100
01000100000100010
01000100000100010
01000000000000010
01000100000100010
01000011111000010
00100000000000100
00011111111111000
00000000000000000"""

prompt = f"""
In this ASCII image, 0s represent the background, and 1s represent a symbol.
Taken altogether like a regular visual image, the ASCII art depicts a symbol.
{symbol}
What common symbol or emoticon does this ASCII art image depict?
"""

# Using your original generate function
def generate_ascii_inference():
    # Tokenize input
    raw_tokens = tokenizer.encode(prompt)

    print("\nGenerating response for ASCII art...")
    print(prompt)  # Print the prompt first

    # Use the original generate function
    gen_tokens = None
    cur_pos = 0
    tokens = jnp.array([raw_tokens], jnp.int32)
    bsz, seqlen = tokens.shape

    attn_mask = build_attn_mask(seqlen, cur_pos)
    freqs_cis = precompute_freqs_cis(
        model_params.head_dim,
        model_params.max_seq_len,
        model_params.rope_theta,
        model_params.use_scaled_rope
    )

    kvcache = KVCache.new(
        model_params.n_layers,
        bsz,
        model_params.max_seq_len,
        model_params.n_local_kv_heads,
        model_params.head_dim
    )

    # Initial forward pass
    logits, kvcache, _, _ = xfmr(
        xfmr_weights,
        model_params,
        tokens,
        cur_pos,
        freqs_cis[:seqlen],
        kvcache,
        attn_mask=attn_mask
    )

    # Get first token
    next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
    gen_tokens = next_token
    print(tokenizer.decode([next_token.item()]), end='', flush=True)

    cur_pos = seqlen
    stop = jnp.array([128001, 128008, 128009])

    # Generate tokens
    while cur_pos < 8192:
        cur_pos += 1
        logits, kvcache, scores, stats = xfmr(
            xfmr_weights,
            model_params,
            next_token,
            cur_pos,
            freqs_cis[cur_pos:cur_pos+1],
            kvcache
        )
        next_token = sample(gen_tokens, logits, scores)
        gen_tokens = jnp.concatenate((gen_tokens, next_token))
        print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
        if jnp.isin(next_token, stop).any():
            break

# Run the inference
generate_ascii_inference()



Generating response for ASCII art...

In this ASCII image, 0s represent the background, and 1s represent a symbol.
Taken altogether like a regular visual image, the ASCII art depicts a symbol.
00000000000000000
00011111111111000
00100000000000100
01000100000100010
01000100000100010
01000000000000010
01000100000100010
01000011111000010
00100000000000100
00011111111111000
00000000000000000
What common symbol or emoticon does this ASCII art image depict?

The symbols on this ASCII art can vary by person.

It could depict different emoticons and icons for some or all.
There may be variations with this same emoticon on two sides with or without additional design and layout that alters this original text of ascii text that makes no visible changes.



There may be two variations to show both versions in their correct places to represent one person using an email attachment in Gmail as if a picture attached in that text area and it could not appear without some modifications, the symbols sho

ZeroDivisionError: integer division or modulo by zero

In [None]:
###Run in new runtime
def load_saved_weights(save_path):
    jax_path = f"{save_path}/jax"

    # Load main components
    tok_embeddings = jnp.load(f"{jax_path}/tok_embeddings.npy")
    norm = jnp.load(f"{jax_path}/norm.npy")
    output = jnp.load(f"{jax_path}/output.npy")

    # Load layer weights
    layer_weights = []
    i = 0
    while os.path.exists(f"{jax_path}/layer_{i}"):
        layer_path = f"{jax_path}/layer_{i}"
        layer = LayerWeights(
            wq=jnp.load(f"{layer_path}/wq.npy"),
            wk=jnp.load(f"{layer_path}/wk.npy"),
            wv=jnp.load(f"{layer_path}/wv.npy"),
            wo=jnp.load(f"{layer_path}/wo.npy"),
            w1=jnp.load(f"{layer_path}/w1.npy"),
            w2=jnp.load(f"{layer_path}/w2.npy"),
            w3=jnp.load(f"{layer_path}/w3.npy"),
            ffn_norm=jnp.load(f"{layer_path}/ffn_norm.npy"),
            attention_norm=jnp.load(f"{layer_path}/attention_norm.npy")
        )
        layer_weights.append(layer)
        i += 1

    return XfmrWeights(
        tok_embeddings=tok_embeddings,
        norm=norm,
        output=output,
        layer_weights=layer_weights
    )

# Usage in new runtime:
save_path = "/content/drive/MyDrive/entropix_finetuned"

# Load all components
torch_model = AutoPeftModelForCausalLM.from_pretrained(f"{save_path}/torch")
tokenizer = AutoTokenizer.from_pretrained(f"{save_path}/torch")
xfmr_weights = load_saved_weights(save_path)

In [None]:
# def generate(xfmr_weights, model_params, tokens, tokenizer):  # Added tokenizer parameter
#     gen_tokens = None
#     cur_pos = 0
#     tokens = jnp.array([tokens], jnp.int32)
#     bsz, seqlen = tokens.shape
#     attn_mask = build_attn_mask(seqlen, cur_pos)
#     freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
#     kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim)
#     logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
#     next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
#     gen_tokens = next_token
#     print(tokenizer.decode([next_token.item()]), end='', flush=True)
#     cur_pos = seqlen
#     stop = jnp.array([128001, 128008, 128009])
#     while cur_pos < 8192:
#         cur_pos += 1
#         logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
#         next_token = sample(gen_tokens, logits, scores)
#         gen_tokens = jnp.concatenate((gen_tokens, next_token))
#         print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
#         if jnp.isin(next_token, stop).any():
#             break

# def analyze_ascii_pattern(ascii_art: str):
#     # Get dimensions
#     lines = ascii_art.strip().split('\n')
#     rows = len(lines)
#     cols = len(lines[0])

#     system_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
# You are an expert in pattern recognition and analysis.
# When given a binary ASCII art pattern composed of '0's and '1's, you carefully examine the pattern, describe how it changes from top to bottom, and identify any common symbols or shapes represented by the pattern.
# You explain your reasoning step by step.<|eot_id|><|start_header_id|>user<|end_header_id|>"""

#     user_prompt = f"""
# Here is a binary {rows}x{cols} ASCII art pattern:

# {ascii_art}

# Please analyze the pattern step by step.
# First, describe how the '1's form the shape and how it changes from top to bottom.
# Then, based on your analysis, identify the common symbol represented by this pattern.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
# <thinking>
# """

#     formatted_prompt = system_prompt + user_prompt

#     model_params = LLAMA_1B_PARAMS
#     xfmr_weights = load_weights()
#     tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct', token=myToken)
#     raw_tokens = tokenizer.encode(formatted_prompt)
#     print("\nAnalyzing Pattern...")
#     print("\nResponse:")
#     generate(xfmr_weights, model_params, raw_tokens, tokenizer)

# # Test with the smiley
# smiley = """00000000000000000
# 00011111111111000
# 00100000000000100
# 01000100000100010
# 01000100000100010
# 01000000000000010
# 01000100000100010
# 01000011111000010
# 00100000000000100
# 00011111111111000
# 00000000000000000"""

# analyze_ascii_pattern(smiley)

# # Example with custom system prompt
# custom_system = """You are an expert computer vision system specializing in binary pattern recognition.
# Your task is to analyze binary matrices as if they were images, identifying shapes, symmetries, and common symbols."""

# # Example with custom assistant header
# custom_header = """Initiating binary pattern analysis. Processing matrix structure and searching for recognizable patterns..."""

# # Usage with custom prompts
# # analyze_ascii_pattern(smiley, custom_system_prompt=custom_system, custom_assistant_header=custom_header)

# # More test cases you could try:
# heart = """0000000000000
# 0011100011100
# 0111110111110
# 0111111111110
# 0011111111100
# 0001111111000
# 0000111110000
# 0000011100000
# 0000001000000"""

# # analyze_ascii_pattern(heart)

In [None]:
def format_ascii_to_grid(ascii_art, fill_char='.'):
    """
    Converts ASCII art into a uniform grid by replacing spaces with fill characters,
    while removing unnecessary leading columns.

    Args:
        ascii_art (str): The input ASCII art string
        fill_char (str): Character to use for padding (default: '.')

    Returns:
        str: The formatted ASCII art with optimized grid layout
    """
    if len(fill_char) != 1:
        raise ValueError("Fill character must be exactly one character long")

    # Split into lines and remove empty lines
    lines = [line for line in ascii_art.split('\n') if line.strip()]

    # Find the maximum content width
    max_width = max(len(line.rstrip()) for line in lines)

    # First pass: format all lines with fill characters
    formatted_lines = []
    min_leading_fills = float('inf')

    for line in lines:
        # Count and replace leading spaces
        leading_spaces = len(line) - len(line.lstrip())
        content = line.lstrip()

        # Add trailing fill characters to make all lines same length
        trailing_spaces = max_width - (leading_spaces + len(content))
        final_line = (fill_char * leading_spaces) + content + (fill_char * trailing_spaces)
        formatted_lines.append(final_line)

        # Track minimum number of leading fill characters
        min_leading_fills = min(min_leading_fills, leading_spaces)

    # Second pass: remove unnecessary leading columns
    if min_leading_fills > 0:
        formatted_lines = [line[min_leading_fills:] for line in formatted_lines]

    return '\n'.join(formatted_lines)

def test_ascii_formatter():
    # Test case 1: Frog ASCII art
    frog_image = """            _     _
           (')-=-(')
         __(   "   )__
        / _/'-----'\_ \\
     ___\\ \\     // //___
     >____)/_\---/_\(____<"""

    print("=== Test 1: Frog ===")
    print(format_ascii_to_grid(frog_image))

    # Test case 2: Simple triangle
    simple_art = """      __
    o-''|\_____/)
    \_/|_)     )
        \  __  /
        (_/ (_/"""

    print("\n=== Test 2: Simple triangle ===")
    print(format_ascii_to_grid(simple_art))

    # Test case 3: Single line with spaces
    single_line = "  Hello  World  "
    print("\n=== Test 3: Single line ===")
    print(format_ascii_to_grid(single_line))

if __name__ == "__main__":
    test_ascii_formatter()

=== Test 1: Frog ===
......._     _.......
......(')-=-(')......
....__(   "   )__....
.../ _/'-----'\_ \...
___\ \     // //___..
>____)/_\---/_\(____<

=== Test 2: Simple triangle ===
..__.........
o-''|\_____/)
\_/|_)     ).
....\  __  /.
....(_/ (_/..

=== Test 3: Single line ===
Hello  World  
