In [None]:
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix
!pip uninstall -q -y flax
!pip install --no-cache-dir git+https://github.com/google/flax.git
!pip install -q huggingface_hub
!pip install -q humanize

In [None]:
# Authenticate with the Hugging Face Hub so the model snapshot can be
# downloaded. `login()` prompts for a token from Python itself, which avoids
# relying on interactive stdin through a `!` shell escape.
from huggingface_hub import login

login()

In [None]:
import functools
from flax import nnx
from huggingface_hub import snapshot_download
import humanize
import jax
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma import model as model_lib
from tunix.models.gemma import params_safetensors as params_lib

In [None]:
def show_hbm_usage():
  """Prints memory usage for every local JAX device.

  For each device returned by `jax.local_devices()` this prints one line with
  the bytes in use, the byte limit and the usage percentage, framed by a
  header and a footer line.

  The HBM-specific keys (`device:0:HBM0:bytes_in_use` /
  `device:0:HBM0:bytes_limit`) are preferred when present; otherwise the
  generic `bytes_in_use` / `bytes_limit` entries are used. Devices that
  report no statistics at all are shown as 0 / 0 (0.00%).
  """
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  print("\n--- TPU HBM Usage ---")
  for i, d in enumerate(jax.local_devices()):
    # memory_stats() may return None on backends that do not expose memory
    # statistics (e.g. CPU); treat that as an empty stats dict instead of
    # failing with AttributeError on `.get`.
    stats = d.memory_stats() or {}
    used = stats.get("bytes_in_use", 0)
    limit = stats.get("bytes_limit", 0)

    # NOTE(review): these key names hard-code device index 0 and HBM0 —
    # confirm which runtimes actually emit them.
    hbm_used = stats.get("device:0:HBM0:bytes_in_use", used)
    hbm_limit = stats.get("device:0:HBM0:bytes_limit", limit)

    # Fallback if the HBM-specific limit is present but reported as zero.
    if hbm_limit == 0:
      hbm_used = used
      hbm_limit = limit

    # Guard against division by zero when no limit is known.
    percentage = (hbm_used / hbm_limit * 100) if hbm_limit > 0 else 0

    print(
        f"Device {i} ({d.device_kind}): Using {fmt_size(hbm_used)} /"
        f" {fmt_size(hbm_limit)} ({percentage:.2f}%)"
    )

  print("--- End HBM Usage ---")

# Download the weights and load them onto the TPU

In [None]:
# Hugging Face repository to fetch.
model_id = "google/gemma-2-2b-it"

# Files matching these glob patterns are skipped during the download:
# "*.pth" excludes PyTorch .pth weight files.
ignore_patterns = ["*.pth"]

print(f"Downloading {model_id} from Hugging Face...")
# snapshot_download returns the local directory holding the downloaded files;
# later cells use it as the checkpoint path.
local_model_path = snapshot_download(
    repo_id=model_id,
    ignore_patterns=ignore_patterns,
)
print(f"Model successfully downloaded to: {local_model_path}")

In [None]:
# Baseline reading: report per-device memory usage before the model is
# created, for comparison with the reading taken after the load.
print("\n--- HBM Usage BEFORE Model Load ---")
show_hbm_usage()

In [None]:
# The downloaded snapshot directory doubles as the checkpoint location.
MODEL_CP_PATH = local_model_path
config = model_lib.TransformerConfig.gemma2_2b()

# Mesh description as (axis sizes, axis names).
# Update this based on your # TPU devices.
# NOTE(review): "fsdp"/"tp" presumably denote the fully-sharded-data-parallel
# and tensor-parallel axes — confirm against the tunix sharding docs.
MESH = [(1, 1), ("fsdp", "tp")]
mesh_axis_sizes, mesh_axis_names = MESH
mesh = jax.make_mesh(mesh_axis_sizes, mesh_axis_names)

# Build the model from the safetensors checkpoint inside the mesh context and
# render its structure in the notebook.
with mesh:
  gemma = params_lib.create_model_from_safe_tensors(MODEL_CP_PATH, config, mesh)
  nnx.display(gemma)

In [None]:
# Second reading: report per-device memory usage now that the model has been
# created, to compare against the pre-load baseline.
print("\n--- HBM Usage AFTER Model Load ---")
show_hbm_usage()

# Run inference

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer from MODEL_CP_PATH, the same local snapshot directory
# that was passed to the model loader.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CP_PATH)

In [None]:
from tunix.generate import sampler


def templatize(prompts):
  """Wraps each prompt in the tokenizer's chat template.

  Every prompt becomes a single-turn conversation with one "user" message and
  is rendered through the module-level `tokenizer.apply_chat_template` with
  `tokenize=False` and `add_generation_prompt=True`.

  Args:
    prompts: Iterable of user prompt strings.

  Returns:
    A list with one rendered result per prompt, in input order.
  """

  def _render(prompt):
    # NOTE(review): `enable_thinking` is forwarded as-is; whether the loaded
    # chat template actually reads it is not visible here — confirm.
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )

  return [_render(prompt) for prompt in prompts]


# Render the demo prompts through the chat template.
inputs = templatize([
    "which is larger 9.9 or 9.11?",
    "如何制作月饼?",
    "tell me your name, respond in Spanish",
])

# Build the sampler. The module is referenced through the `sampler_lib` alias
# (imported with the other tunix modules) so that this statement does not read
# the module through the very name it rebinds to the Sampler instance.
# NOTE(review): cache_size presumably has to cover prompt tokens plus
# max_generation_steps generated tokens — confirm against the tunix Sampler
# docs before using longer prompts.
sampler = sampler_lib.Sampler(
    gemma,
    tokenizer,
    sampler_lib.CacheConfig(
        cache_size=256,
        num_layers=config.num_layers,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
    ),
)

# Generate up to 128 steps per prompt.
# NOTE(review): echo=True presumably includes the prompt in the returned
# text — confirm.
out = sampler(inputs, max_generation_steps=128, echo=True)

# Print each generated text followed by a separator line.
for t in out.text:
  print(t)
  print("*" * 30)