# Example: Using Pretrained Gemma

This Colab notebook is a detailed tutorial on using NNX to load a Gemma checkpoint and sample from it.

## Installation

In [1]:
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope

## Downloading the checkpoint

To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.

Then run the cell below.

In [2]:
import kagglehub
kagglehub.login()

If everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
```
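
If the interactive login widget is inconvenient (for example in a script), `kagglehub` can also read credentials from the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables. The sketch below assumes you saved the downloaded token as a `kaggle.json` file containing `username` and `key` fields; the helper name `load_kaggle_credentials` is hypothetical and not part of `kagglehub`.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path=None):
    """Export the credentials in a kaggle.json token file as environment variables.

    Hypothetical helper: returns True if the file was found and loaded,
    False if it does not exist.
    """
    # Assumed default location of the token file; adjust to where you saved it.
    token_path = Path(path) if path is not None else Path.home() / ".kaggle" / "kaggle.json"
    if not token_path.exists():
        return False
    token = json.loads(token_path.read_text())
    os.environ["KAGGLE_USERNAME"] = token["username"]
    os.environ["KAGGLE_KEY"] = token["key"]
    return True
```

Call `load_kaggle_credentials()` before downloading the checkpoint; if it returns `False`, fall back to `kagglehub.login()`.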

Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models.

In [3]:
from IPython.display import clear_output

VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'

clear_output()
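
Because `clear_output()` hides the download log, it can help to confirm that the expected files are actually on disk before going further. The following is a minimal sketch that assumes the layout used in the cell above (a `{VARIANT}` checkpoint directory and a `tokenizer.model` file inside `weights_dir`); the function name `resolve_checkpoint_paths` is hypothetical.

```python
from pathlib import Path


def resolve_checkpoint_paths(weights_dir, variant):
    """Return (ckpt_path, vocab_path) as strings, raising if either is missing."""
    weights_dir = Path(weights_dir)
    ckpt_path = weights_dir / variant
    vocab_path = weights_dir / "tokenizer.model"
    # Collect every missing path so the error message reports all of them at once.
    missing = [str(p) for p in (ckpt_path, vocab_path) if not p.exists()]
    if missing:
        raise FileNotFoundError(f"Missing checkpoint files: {missing}")
    return str(ckpt_path), str(vocab_path)
```

For example, `ckpt_path, vocab_path = resolve_checkpoint_paths(weights_dir, VARIANT)` could replace the two f-string assignments above.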

## Python imports

In [4]:
from flax import nnx
import sentencepiece as spm

Flax examples are not exposed as packages, so you need to use the workaround in the next cells to import from NNX's Gemma example.

In [5]:
import sys
import tempfile

with tempfile.TemporaryDirectory() as tmp:
  # Here we create a temporary directory and clone the flax repo
  # Then we append the examples/gemma folder to the path to load the gemma modules
  ! git clone https://github.com/google/flax.git {tmp}/flax
  sys.path.append(f"{tmp}/flax/examples/gemma")
  import params as params_lib
  import sampler as sampler_lib
  import transformer as transformer_lib
  sys.path.pop();

Cloning into '/tmp/tmp_68d13pv/flax'...
remote: Enumerating objects: 31912, done.
remote: Counting objects: 100% (605/605), done.
remote: Compressing objects: 100% (250/250), done.
remote: Total 31912 (delta 406), reused 503 (delta 352), pack-reused 31307 (from 1)
Receiving objects: 100% (31912/31912), 23.92 MiB | 18.17 MiB/s, done.
Resolving deltas: 100% (23869/23869), done.
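
The import workaround above appends a directory to `sys.path`, imports the modules, and then pops the entry again. Modules that are already imported stay available afterwards. One way to make this pattern more robust is a small context manager that removes the entry even if an import fails. This is a sketch only; `temporary_sys_path` is a hypothetical helper, not something provided by Flax.

```python
import contextlib
import sys


@contextlib.contextmanager
def temporary_sys_path(directory):
    """Temporarily append `directory` to sys.path for the duration of the block."""
    entry = str(directory)
    sys.path.append(entry)
    try:
        yield entry
    finally:
        # Remove one occurrence of the entry, even if an import in the block failed.
        with contextlib.suppress(ValueError):
            sys.path.remove(entry)


# Usage, assuming the flax repo was cloned into `tmp` as in the cell above:
# with temporary_sys_path(f"{tmp}/flax/examples/gemma"):
#     import params as params_lib
```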


## Start Generating with Your Model

Load and prepare your LLM's checkpoint for use with Flax.

In [6]:
# Load parameters
params = params_lib.load_and_format_params(ckpt_path)
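
To get a feel for what was loaded, you can walk the parameter tree and list the shape of each array. The sketch below assumes that `params` is a nested `dict` whose leaves expose a `.shape` attribute (as JAX and NumPy arrays do); both helper names are hypothetical.

```python
import math


def summarize_params(tree, prefix=""):
    """Return {path: shape} for every leaf of a nested dict of arrays."""
    shapes = {}
    for name, value in tree.items():
        path = f"{prefix}/{name}" if prefix else str(name)
        if isinstance(value, dict):
            # Recurse into sub-dicts, extending the slash-separated path.
            shapes.update(summarize_params(value, path))
        else:
            shapes[path] = tuple(value.shape)
    return shapes


def count_parameters(tree):
    """Total number of scalar entries across all leaves."""
    return sum(math.prod(shape) for shape in summarize_params(tree).values())
```

For instance, `print(f"{count_parameters(params):,} parameters")` gives a rough check that the checkpoint matches the variant you selected.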

Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.

In [7]:
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

True

Use the `transformer_lib.Transformer.from_params` function to build the model; it automatically infers the correct configuration from the checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.

In [8]:
transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)

Finally, build a sampler on top of your model and your tokenizer.

In [9]:
# Create a sampler with the right param shapes.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
)

You're ready to start sampling! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest results, keep your batch size consistent across calls.
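
One simple way to keep the batch size constant is to split your prompts into fixed-size batches and pad the last batch by repeating its final prompt, then discard the extra outputs. This is a minimal sketch; `fixed_size_batches` is a hypothetical helper, and it only controls the batch dimension (other shape-affecting inputs, such as the number of generation steps, are assumed to stay the same across calls).

```python
def fixed_size_batches(prompts, batch_size):
    """Yield (batch, n_real) pairs where every batch holds exactly batch_size prompts.

    The last batch is padded by repeating its final prompt; n_real says how
    many leading entries are genuine and should be kept from the output.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(prompts), batch_size):
        chunk = list(prompts[start:start + batch_size])
        n_real = len(chunk)
        chunk += [chunk[-1]] * (batch_size - n_real)
        yield chunk, n_real


# Usage with the sampler built above:
# for batch, n_real in fixed_size_batches(all_prompts, batch_size=4):
#     out = sampler(input_strings=batch, total_generation_steps=300)
#     texts = out.text[:n_real]  # drop outputs produced for padding prompts
```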

In [10]:
input_batch = [
  "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
]

out_data = sampler(
    input_strings=input_batch,
    total_generation_steps=300,  # number of steps performed when generating
)

for input_string, out_string in zip(input_batch, out_data.text):
  print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
  print()
  print(10*'#')

Prompt:

# Python program for implementation of Bubble Sort

def bubbleSort(arr):
Output:

    for i in range(len(arr)):
        for j in range(len(arr) - i - 1):
            if arr[j] > arr[j + 1]:
                swap(arr, j, j + 1)


def swap(arr, i, j):
    temp = arr[i]
    arr[i] = arr[j]
    arr[j] = temp


# Driver code
arr = [5, 2, 8, 3, 1, 9]
print("Unsorted array:")
print(arr)
bubbleSort(arr)
print("Sorted array:")
print(arr)


# Time complexity of Bubble sort O(n^2)
# where n is the length of the array


# Space complexity of Bubble sort O(1)
# as it only requires constant extra space for the swap operation


# This program uses the bubble sort algorithm to sort the given array in ascending order.

```python
# This program uses the bubble sort algorithm to sort the given array in ascending order.

def bubbleSort(arr):
    for i in range(len(arr)):
        for j in range(len(arr) - i - 1):
            if arr[j] > arr[j + 1]:
                swap(arr, j, j + 1)


def swap(

##########

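You should get an implementation of bubble sort. Generation runs for a fixed number of steps (300 here), so the output may be cut off mid-line, as it is above.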