<link rel="stylesheet" href="/site-assets/css/gemma.css">
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200" />

##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Inference with CodeGemma using JAX and Flax

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/codegemma/codegemma_flax_inference"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/codegemma/codegemma_flax_inference.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/codegemma/codegemma_flax_inference.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

We present CodeGemma, a collection of open code models based on Google DeepMind’s Gemma models (Gemma Team et al., 2024).
CodeGemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.

Continuing from Gemma pretrained models, CodeGemma models are further trained on 500 to 1000 billion tokens of primarily code, using the same architectures as the Gemma model family. As a result, CodeGemma models achieve state-of-the-art code performance in both completion and generation tasks, while maintaining strong understanding and reasoning skills at scale.

CodeGemma has 3 variants:

* A 7B code pretrained model
* A 7B instruction-tuned code model
* A 2B model, trained specifically for code infilling and open-ended generation.

This guide walks you through using the CodeGemma model with Flax for a code completion task.

**Note:** This notebook runs on TPU v2 in Google Colab because the T4 GPU has insufficient memory.

## Setup

### 1. Set up Kaggle access for CodeGemma

To complete this tutorial, you first need to follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do the following:

* Get access to CodeGemma on [kaggle.com](https://www.kaggle.com/models/google/codegemma/).
* Select a Colab runtime with sufficient resources (**T4 GPU has insufficient memory, use TPU v2 instead**) to run the CodeGemma model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

### 2. Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the "Grant access?" messages, agree to provide secret access.

In [None]:
import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
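
Outside Colab, `google.colab.userdata` is not available, so the two variables have to be set some other way. As a minimal sketch, you could check that both are present before downloading the model. The helper name `missing_kaggle_credentials` is hypothetical and not part of any library.

```python
import os

REQUIRED_VARS = ("KAGGLE_USERNAME", "KAGGLE_KEY")


def missing_kaggle_credentials(environ=None):
  """Returns the names of required variables that are unset or empty."""
  # Hypothetical helper for illustration; defaults to the real environment.
  environ = os.environ if environ is None else environ
  return [name for name in REQUIRED_VARS if not environ.get(name)]


missing = missing_kaggle_credentials()
if missing:
  print("Missing Kaggle credentials:", ", ".join(missing))
else:
  print("Kaggle credentials are set.")
```

Passing a plain dictionary instead of `os.environ` makes the check easy to try without touching your real environment.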

### 3. Install the `gemma` library

A T4 GPU runtime does not have enough memory for this notebook. Use a TPU v2 runtime as noted above or, if you are using [Colab Pay As You Go or Colab Pro](https://colab.research.google.com/signup), click on **Edit** > **Notebook settings** > select **A100 GPU** > **Save** to enable hardware acceleration.

Next, you need to install the Google DeepMind `gemma` library from [`github.com/google-deepmind/gemma`](https://github.com/google-deepmind/gemma). If you get an error about "pip's dependency resolver", you can usually ignore it.

**Note:** By installing `gemma`, you will also install [`flax`](https://flax.readthedocs.io), core [`jax`](https://jax.readthedocs.io), [`orbax`](https://orbax.readthedocs.io/), and [`sentencepiece`](https://github.com/google/sentencepiece).

In [None]:
!pip install -q git+https://github.com/google-deepmind/gemma.git



### 4. Import libraries

This notebook uses [Gemma](https://github.com/google-deepmind/gemma) (which uses [Flax](https://flax.readthedocs.io) to build its neural network layers), and [SentencePiece](https://github.com/google/sentencepiece) (for tokenization).

In [None]:
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

## Load the CodeGemma model

Load the CodeGemma model with [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/bddefc718182282882b72f814d407d89e5d178c4/src/kagglehub/models.py#L12), which takes three arguments:

- `handle`: The model handle from Kaggle
- `path`: (Optional string) The local path
- `force_download`: (Optional boolean) Forces a re-download of the model

**Note:** Be mindful that the `2b-pt` model is around 3.66 GB in size.

In [None]:
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}

In [None]:
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')

Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...


In [None]:
print('GEMMA_PATH:', GEMMA_PATH)

GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3


**Note:** The path from the output above is where the model weights and tokenizer are saved locally. You will need them later.

Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer file will be in the main directory where you downloaded the model, while the model weights will be in a sub-directory. For example:

- The `spm.model` tokenizer file will be in `/LOCAL/PATH/TO/codegemma/flax/2b-pt/3`
- The model checkpoint will be in `/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt`

In [None]:
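# Use the last five characters of the variant (for example,
# '1.1-2b-pt' -> '2b-pt') as the checkpoint sub-directory name.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])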
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)

CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
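
The slice `GEMMA_VARIANT[-5:]` keeps the last five characters of the variant name, so a versioned name such as `1.1-2b-pt` yields `2b-pt`. The sketch below only illustrates that string handling. It assumes, as the cell above implies, that every variant stores its checkpoint in a sub-directory named after the size and tuning suffix. The helper name `checkpoint_subdir` and the base path are hypothetical.

```python
import os

VARIANTS = ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it']


def checkpoint_subdir(variant):
  """Returns the last five characters, e.g. '1.1-2b-pt' -> '2b-pt'."""
  return variant[-5:]


for variant in VARIANTS:
  # '/LOCAL/PATH' is a placeholder, not a real download location.
  print(variant, '->', os.path.join('/LOCAL/PATH', checkpoint_subdir(variant)))
```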


## Perform sampling/inference

Load and format the CodeGemma model checkpoint with the [`gemma.params.load_and_format_params`](https://github.com/google-deepmind/gemma/blob/c6bd156c246530e1620a7c62de98542a377e3934/gemma/params.py#L27) method:

In [None]:
params = params_lib.load_and_format_params(CKPT_PATH)

Load the CodeGemma tokenizer, constructed using [`sentencepiece.SentencePieceProcessor`](https://github.com/google/sentencepiece/blob/4d6a1f41069c4636c51a5590f7578a0dbed83450/python/src/sentencepiece/__init__.py#L423):

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)

True

To automatically load the correct configuration from the CodeGemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the CodeGemma `Transformer` cache. Afterwards, instantiate the CodeGemma model as `transformer` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)).

**Note:** The vocabulary size is smaller than the number of input embeddings because of unused tokens in the current CodeGemma release.

In [None]:
transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)
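
When choosing `cache_size`, it can help to budget for both the prompt and the generated tokens. The sketch below assumes, which this guide does not confirm, that the cache needs one slot per prompt token plus one per generation step. The helper name `fits_in_cache` and the token counts are hypothetical.

```python
def fits_in_cache(prompt_tokens, generation_steps, cache_size=1024):
  """Returns True if prompt plus generated tokens fit in the cache.

  Assumption for illustration: one cache slot per token, prompt included.
  """
  return prompt_tokens + generation_steps <= cache_size


# Hypothetical token counts; in practice, count them with the tokenizer.
print(fits_in_cache(prompt_tokens=40, generation_steps=100))
print(fits_in_cache(prompt_tokens=1000, generation_steps=100))
```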

Create a `sampler` with [`gemma.sampler.Sampler`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/sampler.py#L88). It uses the CodeGemma model checkpoint and the tokenizer.

In [None]:
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

Create some variables to represent the fill-in-the-middle (FIM) tokens and create some helper functions to format the prompt and generated output.

For example, let's look at the following code:
```
def function(string):
assert function('asdf') == 'fdsa'
```
We would like to fill in the body of `function` so that the assertion holds. In this case, the prefix would be:
```
"def function(string):\n"
```
And the suffix would be:
```
"assert function('asdf') == 'fdsa'"
```
We then format this into a prompt as PREFIX-SUFFIX-MIDDLE (the middle section that needs to be filled is always at the end of the prompt):
```
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
```

In [None]:
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  """Wraps the text before and after the cursor in FIM control tokens."""
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt


def format_generated_output(before, after, output):
  """Strips the file separator and inserts the output at the cursor."""
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output
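
In an editor integration, the `before` and `after` strings would typically come from splitting the file contents at the cursor. Here is a minimal sketch, assuming the cursor is given as a character offset. The helper `split_at_cursor` is hypothetical and not part of the `gemma` library.

```python
def split_at_cursor(source, cursor):
  """Splits `source` into (before, after) at character offset `cursor`."""
  if not 0 <= cursor <= len(source):
    raise ValueError(f"cursor {cursor} is outside the source text")
  return source[:cursor], source[cursor:]


source = "def function(string):\nassert function('asdf') == 'fdsa'"
# Place the cursor right after the first newline, where the body belongs.
cursor = source.index("\n") + 1
before, after = split_at_cursor(source, cursor)
print(repr(before))
print(repr(after))
```

The resulting `before` and `after` values are the same prefix and suffix used in the example above.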

Create a prompt and perform inference. Specify the prefix `before` text and the suffix `after` text and generate the formatted prompt using the helper function `format_completion_prompt`.

You can tweak `total_generation_steps`, the number of steps performed when generating a response. This example uses `100` to preserve host memory.

**Note:** If you run out of memory, click on **Runtime** > **Disconnect and delete runtime**, and then **Runtime** > **Run all**.

In [None]:
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])


ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
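
One way to sanity-check a fill-in completion is to execute it and let the trailing `assert` act as the test. The sketch below hard-codes the completion shown in the output above rather than calling the model, and the helper name `completion_passes` is hypothetical. Only execute generated code that you trust, or run it in a sandbox.

```python
def completion_passes(code):
  """Runs `code` in a fresh namespace; returns False if it raises."""
  namespace = {}
  try:
    exec(compile(code, "<completion>", "exec"), namespace)
  except Exception:  # Includes AssertionError and SyntaxError.
    return False
  return True


# Copied from the FILL-IN COMPLETION output above.
completion = (
    "def function(string):\n"
    "    return string[::-1]\n"
    "\n"
    "assert function('asdf') == 'fdsa'"
)
print(completion_passes(completion))
```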


In [None]:
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])


ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)


In [None]:
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])


ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)



## Learn more

- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of modules you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py), [`gemma.transformer`](https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py), and [`gemma.sampler`](https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py).
- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), and [Orbax](https://orbax.readthedocs.io/).
- For `sentencepiece` tokenizer/detokenizer documentation, check out [Google's `sentencepiece` GitHub repo](https://github.com/google/sentencepiece).
- For `kagglehub` documentation, check out `README.md` on [Kaggle's `kagglehub` GitHub repo](https://github.com/Kaggle/kagglehub).
- Learn how to [use Gemma models with Google Cloud Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
- If you are using Google Cloud TPUs (v3-8 and newer), make sure to also update to the latest `jax[tpu]` package (`!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`), restart the runtime, and check that `jax` and `jaxlib` versions match (`!pip list | grep jax`). This can prevent the `RuntimeError` that can arise because of the `jaxlib` and `jax` version mismatch. For more JAX installation instructions, refer to the [JAX docs](https://jax.readthedocs.io/en/latest/tutorials/installation.html#install-google-tpu).
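
As an alternative to `pip list`, you could compare the installed versions from Python. This is a sketch using the standard library's `importlib.metadata`. The helper names are hypothetical, and treating identical version strings as a match is a simplifying assumption.

```python
from importlib import metadata


def installed_version(package):
  """Returns the installed version string, or None if not installed."""
  try:
    return metadata.version(package)
  except metadata.PackageNotFoundError:
    return None


def versions_match(first, second):
  """True only if both packages are installed with identical versions."""
  v1, v2 = installed_version(first), installed_version(second)
  return v1 is not None and v1 == v2


print("jax:", installed_version("jax"))
print("jaxlib:", installed_version("jaxlib"))
print("match:", versions_match("jax", "jaxlib"))
```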