<a href="https://colab.research.google.com/github/jenriver/bonsai/blob/qwen3/bonsai/models/qwen3/qwen3_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Suggested runtime: TPU v2-8

## 1. Set up the environment.

In [None]:
!pip install -q setuptools
!pip install -q ml-dtypes
!pip install -q kagglehub
!pip install -q tensorboardX
!pip install -q grain
!pip install -q git+https://github.com/google/tunix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for flax (pyproject.toml) ... [?25l[?25hdone


In [None]:
!pip uninstall -q -y jax-bonsai
!pip install -q git+https://github.com/jenriver/jax-bonsai@qwen3

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for jax-bonsai (pyproject.toml) ... [?25l[?25hdone

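Because the cells above swap the PyPI `flax` for a GitHub build and reinstall `jax-bonsai` from a branch, it can help to confirm which versions actually ended up installed. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The distribution names are the ones passed to pip above, and `installed_versions` is a hypothetical helper, not part of any library used in this notebook.

```python
from importlib import metadata


def installed_versions(dists):
    """Map each distribution name to its installed version, or None if it is absent."""
    versions = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names as used in the pip commands above.
for name, version in installed_versions(
    ["flax", "jax-bonsai", "grain", "kagglehub", "ml-dtypes", "tensorboardX"]
).items():
    print(f"{name:>14}: {version or 'NOT INSTALLED'}")
```
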

## 2. Download the pretrained weights.

In [None]:
from huggingface_hub import snapshot_download

model_name = "Qwen/Qwen3-0.6B"

MODEL_CP_PATH = "./qwen3-0.6b-weights"  # Specify your desired download directory

# Download all files from the repository
snapshot_download(repo_id=model_name, local_dir=MODEL_CP_PATH)

print(f"Model weights and files downloaded to: {MODEL_CP_PATH}")

Model weights and files downloaded to: ./qwen3-0.6b-weights

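Before building the model, you can quickly check that the download directory contains what the loader and tokenizer will need. The expected file names here are an assumption based on the usual layout of Hugging Face transformer checkpoints: a `config.json`, a `tokenizer_config.json`, and one or more `*.safetensors` weight files. `missing_checkpoint_files` is a hypothetical helper.

```python
from pathlib import Path


def missing_checkpoint_files(ckpt_dir):
    """Return the expected checkpoint files that are absent from ckpt_dir."""
    root = Path(ckpt_dir)
    missing = [
        name
        for name in ("config.json", "tokenizer_config.json")
        if not (root / name).is_file()
    ]
    # Weights may be a single file or several shards, so accept any *.safetensors file.
    if not any(root.glob("*.safetensors")):
        missing.append("*.safetensors")
    return missing


print(missing_checkpoint_files("./qwen3-0.6b-weights") or "checkpoint directory looks complete")
```
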

## 3. Create the model.

In [None]:
from bonsai.models.qwen3 import params
from bonsai.models.qwen3 import model
from flax import nnx

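# MODEL_CP_PATH is set in step 2. To load weights from a mounted Google Drive folder instead, override it, e.g.:
# MODEL_CP_PATH = "/content/drive/MyDrive/colab-mount/qwen-3-transformers-0.6b-v1"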

config = model.ModelConfig.qwen3_0_6_b()  # pick the config that matches the model version
qwen3 = params.create_model_from_safe_tensors(MODEL_CP_PATH, config)
nnx.display(qwen3)

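The `CacheConfig` in step 4 hard-codes `num_layers=28, num_kv_heads=8, head_dim=128` for Qwen3-0.6B. The following sketch reads those values from the checkpoint's `config.json` instead, which keeps them in sync if you switch to another Qwen3 size. It assumes the Hugging Face Qwen3 key names (`num_hidden_layers`, `num_key_value_heads`, `head_dim`), and `kv_cache_dims` is a hypothetical helper.

```python
import json
from pathlib import Path


def kv_cache_dims(ckpt_dir):
    """Read KV-cache dimensions from a Hugging Face-style config.json."""
    cfg = json.loads((Path(ckpt_dir) / "config.json").read_text())
    return {
        "num_layers": cfg["num_hidden_layers"],
        "num_kv_heads": cfg["num_key_value_heads"],
        # Qwen3 configs list head_dim explicitly, so it is read directly rather than derived.
        "head_dim": cfg["head_dim"],
    }
```

With the weights from step 2, `sampler.CacheConfig(cache_size=256, **kv_cache_dims(MODEL_CP_PATH))` should produce the same cache configuration as the hard-coded call, assuming `CacheConfig` accepts exactly these keyword arguments.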
In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_CP_PATH)

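`apply_chat_template`, which `templatize` uses in step 4, renders each prompt with the chat template stored alongside the tokenizer. For a single user turn with `add_generation_prompt=True`, the result has the ChatML shape visible in the echoed generations: `<|im_start|>user`, the message, `<|im_end|>`, and then an open `<|im_start|>assistant` turn. `chatml_prompt` is a hypothetical, hand-written approximation meant only for inspecting that shape. The real template also handles system messages, multi-turn history, and options such as `enable_thinking`, so keep using `apply_chat_template` for actual inputs.

```python
def chatml_prompt(user_message):
    """Approximate the chat template for one user turn plus the generation prompt."""
    return (
        "<|im_start|>user\n"
        f"{user_message}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


print(chatml_prompt("Why is the sky blue?"))
```

For this simple case, the string is expected to match `templatize(["Why is the sky blue?"])[0]`. Comparing the two is a cheap way to see what the template adds.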
## 4. Generate from the model.
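`sampler.CacheConfig` preallocates a key/value cache sized by `cache_size`. The following back-of-the-envelope sketch estimates that cache's memory footprint and checks whether a request fits. It rests on three assumptions: one K and one V tensor per layer, each of shape `[batch, cache_size, num_kv_heads, head_dim]`; storage in bfloat16 (2 bytes per element); and prompt and generated tokens sharing the same cache. Both helpers are hypothetical and not part of `bonsai`.

```python
def kv_cache_bytes(batch_size, cache_size, num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    """Estimate KV-cache memory, assuming one K and one V tensor per layer."""
    per_layer = 2 * batch_size * cache_size * num_kv_heads * head_dim  # K and V
    return num_layers * per_layer * bytes_per_elem


def fits_in_cache(prompt_len, total_generation_steps, cache_size):
    """True if the prompt plus all generated tokens fit in the preallocated cache."""
    return prompt_len + total_generation_steps <= cache_size


# CacheConfig values from the cell below; 3 prompts in the batch, bfloat16 assumed.
size = kv_cache_bytes(batch_size=3, cache_size=256, num_layers=28, num_kv_heads=8, head_dim=128)
print(f"{size / 2**20:.1f} MiB")  # 84.0 MiB
# prompt_len=20 is a placeholder, not a measured value.
print(fits_in_cache(prompt_len=20, total_generation_steps=128, cache_size=256))
```

To replace the placeholder with real prompt lengths, tokenize the templated inputs, e.g. `max(len(tokenizer(p)["input_ids"]) for p in inputs)`.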

In [None]:
from bonsai.generate import sampler


def templatize(prompts):
    """Wrap each raw prompt in the tokenizer's chat template as a single user turn."""
    out = []
    for p in prompts:
        out.append(
            tokenizer.apply_chat_template(
                [
                    {"role": "user", "content": p},
                ],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
        )
    return out


inputs = templatize(
    [
        "which is larger 9.9 or 9.11?",
        "Why is the sky blue?",
        "How do you say cheese in French?",
    ]
)

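# Use a new name so the `sampler` module is not shadowed and the cell can be re-run.
qwen3_sampler = sampler.Sampler(
    qwen3,
    tokenizer,
    # KV-cache dimensions for Qwen3-0.6B: 28 layers, 8 KV heads, head_dim 128.
    sampler.CacheConfig(cache_size=256, num_layers=28, num_kv_heads=8, head_dim=128),
)
# echo=True returns the prompt together with the generated text.
out = qwen3_sampler(inputs, total_generation_steps=128, echo=True)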

for t in out.text:
    print(t)
    print("*" * 30)

<|im_start|>user
which is larger 9.9 or 9.11?<|im_end|>
<|im_start|>assistant
<think>
Okay, so I need to figure out which number is larger between 9.9 and 9.11. Let me think. Both numbers are in decimal form, right? 9.9 and 9.11. Hmm, decimal numbers can be tricky sometimes, but I remember that when comparing decimals, you can look at the digits from left to right, starting with the first non-zero digit. 

First, let me write them down to visualize better: 9.9 and 9.11. Both start with a 9. So, the first digit after the decimal is the tenths place
******************************
<|im_start|>user
Why is the sky blue?<|im_end|>
<|im_start|>assistant
<think>
Okay, the user is asking why the sky is blue. I need to explain this in a clear and accurate way. First, I should mention the basic reason, which is the scattering of sunlight by the Earth's atmosphere. Then, I should explain the Rayleigh scattering theory. Maybe include some examples, like how different colors are scattered in differe