# Fixed Basic Inference Example for Tunix

This notebook demonstrates how to:

- Initialize a Tunix tokenizer safely (with a mock fallback for local setups)
- Create a dummy Gemma model using `dummy_model_creator`
- Configure the `Sampler` for inference
- Run inference on multiple prompts

It runs locally without access to remote checkpoints, which makes it a
beginner-friendly, reproducible example.


In [1]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Fixed Basic Inference Example for Tunix (Notebook version).

This notebook demonstrates how to correctly initialize a Tunix model
(using dummy weights) and use the Sampler class for text generation.
"""

from flax import nnx
import jax
from tunix import Tokenizer
from tunix.generate import sampler
from tunix.models import dummy_model_creator
from tunix.models.gemma import model as gemma_model

In [None]:
print("Initializing Tokenizer...")

try:
  tokenizer = Tokenizer()
  print("Tokenizer initialized successfully.")
except Exception as e:
  print(
      f"Warning: Could not initialize real tokenizer ({e}). Using mock"
      " tokenizer."
  )

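  class MockTokenizer:
    """Minimal stand-in that returns fixed values regardless of input."""

    def encode(self, s):
      # Ignore the input and return fixed dummy token IDs.
      return [1, 2, 3]

    def decode(self, t):
      # Ignore the token IDs and return a fixed dummy string.
      return "dummy output"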

    def pad_id(self):
      return 0

    def bos_id(self):
      return 1

    def eos_id(self):
      return 2

  tokenizer = MockTokenizer()

Initializing Tokenizer...
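
The mock above returns fixed values, which is enough to exercise the pipeline but hides whether a prompt survives encoding and decoding. As a minimal sketch, a mock can also round-trip text. The class name `CharMockTokenizer` and its byte-offset scheme are hypothetical and not part of Tunix, and this sketch does not confirm which tokenizer methods `Sampler` actually calls.

```python
class CharMockTokenizer:
  """Hypothetical byte-level mock: each UTF-8 byte maps to one token ID."""

  # Reserve the lowest IDs for special tokens, mirroring the mock above.
  _PAD, _BOS, _EOS = 0, 1, 2
  _OFFSET = 3  # First ID used for ordinary bytes.

  def encode(self, s):
    # Shift every byte past the reserved special-token IDs.
    return [b + self._OFFSET for b in s.encode("utf-8")]

  def decode(self, ids):
    # Drop special tokens and any ID outside the byte range.
    lo, hi = self._OFFSET, self._OFFSET + 256
    data = bytes(i - lo for i in ids if lo <= i < hi)
    return data.decode("utf-8", errors="replace")

  def pad_id(self):
    return self._PAD

  def bos_id(self):
    return self._BOS

  def eos_id(self):
    return self._EOS


mock = CharMockTokenizer()
ids = mock.encode("3 apples")
print(ids[:3], "->", mock.decode(ids))
```

Because decoding inverts encoding, any text printed after generation reflects the token IDs that were actually produced, which makes shape or padding mistakes easier to spot than with a constant string.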


In [None]:
print("Initializing Dummy Model (Gemma 2B config)...")

config = gemma_model.ModelConfig.gemma_2b()

model = dummy_model_creator.create_dummy_model(
    gemma_model.Transformer,
    config,
    dtype=jax.numpy.float32,
)

print("Dummy model created.")

In [None]:
print("Initializing Sampler...")

cache_config = sampler.CacheConfig(
    cache_size=1024,  # Max KV-cache length (prompt plus generated tokens)
    num_layers=config.num_layers,
    num_kv_heads=config.num_kv_heads,
    head_dim=config.head_dim,
)

inference_sampler = sampler.Sampler(
    transformer=model,
    tokenizer=tokenizer,
    cache_config=cache_config,
)

print("Sampler initialized.")
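
The four `CacheConfig` fields together determine how much memory the key/value cache needs. The following rough estimate assumes one key and one value tensor per layer and 4-byte values. The helper `kv_cache_bytes` is hypothetical, the numeric values are assumed Gemma 2B-like settings (read the real ones from `config`), and the actual cache layout and dtype inside Tunix may differ.

```python
def kv_cache_bytes(cache_size, num_layers, num_kv_heads, head_dim,
                   batch_size=1, bytes_per_value=4):
  """Estimates KV-cache size, assuming one K and one V tensor per layer."""
  values_per_token = 2 * num_layers * num_kv_heads * head_dim
  return batch_size * cache_size * values_per_token * bytes_per_value


# Assumed values for illustration only; not read from the Tunix config.
size = kv_cache_bytes(
    cache_size=1024,
    num_layers=18,
    num_kv_heads=1,
    head_dim=256,
    batch_size=2,  # Two prompts, as in the generation cell.
)
print(f"{size / 2**20:.0f} MiB")  # prints "72 MiB"
```

Doubling `cache_size` or the batch size doubles this figure, so the cache length is worth sizing to the longest prompt plus `max_generation_steps` rather than far beyond it.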

In [None]:
prompts = [
    "If I have 3 apples and eat 1, how many remain?",
    "Write a short story about a robot.",
]

print("\n" + "=" * 50)
print("Starting Generation...")

output = inference_sampler(
    input_strings=prompts,
    max_generation_steps=50,
    temperature=0.7,
    echo=True,  # include prompt in output
)

for i, (prompt, text) in enumerate(zip(prompts, output.text), start=1):
  print("\n" + "-" * 50)
  print(f"Prompt {i}: {prompt}")
  print(f"Generated: {text}")
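
The `temperature` argument controls how sharply sampling favors high-scoring tokens. The sketch below shows the textbook formulation (divide the logits by the temperature, apply softmax, then sample) in plain Python. It is not Tunix's implementation, which may combine temperature with other strategies. The function names and the example logits are made up for illustration.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
  """Converts logits to probabilities after scaling by 1 / temperature."""
  if temperature <= 0:
    raise ValueError("temperature must be positive")
  scaled = [x / temperature for x in logits]
  peak = max(scaled)  # Subtract the max for numerical stability.
  exps = [math.exp(x - peak) for x in scaled]
  total = sum(exps)
  return [e / total for e in exps]


def sample_token(logits, temperature, rng):
  """Draws one token index from the temperature-scaled distribution."""
  probs = softmax_with_temperature(logits, temperature)
  return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]  # Made-up scores for a three-token vocabulary.
rng = random.Random(0)
for t in (0.7, 1.0, 2.0):
  probs = softmax_with_temperature(logits, t)
  print(t, [round(p, 3) for p in probs], "sampled:", sample_token(logits, t, rng))
```

Lower temperatures concentrate probability on the top-scoring token, while higher ones flatten the distribution. With dummy weights the logits carry no meaning, so the generated text will be noise at any temperature.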

## Notes

- This example uses **dummy model weights** via `dummy_model_creator` so that:
  - It does not require downloading large real checkpoints.
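  - It avoids the download and disk cost of real checkpoints. Memory use still
    depends on the config and dtype: a Gemma 2B config in `float32` can need
    several GB for parameters if the weights are fully materialized.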
- The tokenizer is wrapped in a **try/except**:
  - If a real tokenizer cannot be initialized (e.g., a missing Hugging Face model
    or an unreachable remote path), a lightweight mock tokenizer is used instead.
- The `Sampler` is configured explicitly via `CacheConfig`, which sets:
  - `cache_size`
  - `num_layers`
  - `num_kv_heads`
  - `head_dim`
- This notebook is intended as a **reference/example** for:
  - New contributors exploring the Tunix generation stack.
  - Testing changes to `Sampler`, `dummy_model_creator`, or Gemma configs.
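
Dummy weights remove the download, but the model still has as many parameters as its config dictates. As a back-of-the-envelope check, the sketch below counts parameters for a Gemma-style decoder with a gated MLP and a shared embedding table. The helper name and every numeric value are assumptions chosen to resemble Gemma 2B; they are not read from the Tunix config, and the `vocab_size` that Tunix uses may differ.

```python
def gemma_like_param_count(vocab_size, embed_dim, hidden_dim, num_layers,
                           num_heads, num_kv_heads, head_dim):
  """Rough parameter count for a Gemma-style decoder (assumed layout)."""
  embedding = vocab_size * embed_dim  # Shared input/output embedding table.
  attention = (
      embed_dim * num_heads * head_dim            # Query projection.
      + 2 * embed_dim * num_kv_heads * head_dim   # Key and value projections.
      + num_heads * head_dim * embed_dim          # Output projection.
  )
  mlp = 3 * embed_dim * hidden_dim  # Gate, up, and down projections.
  norms = 2 * embed_dim             # Two norm scales per layer.
  final_norm = embed_dim
  return embedding + num_layers * (attention + mlp + norms) + final_norm


# Assumed Gemma 2B-like values, for illustration only.
params = gemma_like_param_count(
    vocab_size=256_000,
    embed_dim=2048,
    hidden_dim=16_384,
    num_layers=18,
    num_heads=8,
    num_kv_heads=1,
    head_dim=256,
)
print(f"{params:,} parameters, about {params * 4 / 2**30:.1f} GiB in float32")
```

Under these assumptions the parameters alone take roughly 9.3 GiB in `float32`. On a memory-constrained machine, a smaller config or a narrower dtype is the more comfortable choice for quick tests.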
