# Get started with Gemma models - JAX

- https://ai.google.dev/gemma/docs/jax_inference

The Gemma family of open models includes a range of model sizes, capabilities, and task-specialized variations to help you build custom generative solutions.

Before you begin, complete the [Gemma setup](https://ai.google.dev/gemma/docs/setup) steps, which cover getting access to Gemma on Kaggle.

Gemma was built with Flax, although this notebook does not use Flax directly.

## Install the Google DeepMind `gemma` library

```python
!pip install -q git+https://github.com/google-deepmind/gemma.git
```

## Load and prepare the Gemma model

```python
GEMMA_VARIANT = 'gemma2-2b-it'
```

```python
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma-2/flax/{GEMMA_VARIANT}')
```

```python
import os

print('GEMMA_PATH:', GEMMA_PATH)
```

GEMMA_PATH: /home/mgj/.cache/kagglehub/models/google/gemma-2/flax/gemma2-2b-it/1


```python
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
```

CKPT_PATH: /home/mgj/.cache/kagglehub/models/google/gemma-2/flax/gemma2-2b-it/1/gemma2-2b-it
TOKENIZER_PATH: /home/mgj/.cache/kagglehub/models/google/gemma-2/flax/gemma2-2b-it/1/tokenizer.model
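Both paths are built by plain string joining, so a mistyped `GEMMA_VARIANT` or an incomplete download only surfaces later, when loading fails. A minimal sketch of a hypothetical helper, `resolve_gemma_paths`, that builds the same two paths and checks that they exist up front; it assumes the directory layout printed above (a checkpoint folder named after the variant, next to `tokenizer.model`).

```python
import os


def resolve_gemma_paths(gemma_path: str, variant: str) -> tuple[str, str]:
    """Return (ckpt_path, tokenizer_path), failing early if either is missing.

    Hypothetical helper. Assumes the Kaggle download contains a checkpoint
    directory named after the variant plus a `tokenizer.model` file.
    """
    ckpt_path = os.path.join(gemma_path, variant)
    tokenizer_path = os.path.join(gemma_path, 'tokenizer.model')
    # Check the checkpoint directory first, since it is the larger download.
    if not os.path.isdir(ckpt_path):
        raise FileNotFoundError(f'Checkpoint directory not found: {ckpt_path}')
    if not os.path.isfile(tokenizer_path):
        raise FileNotFoundError(f'Tokenizer file not found: {tokenizer_path}')
    return ckpt_path, tokenizer_path
```

In the notebook this would be called as `CKPT_PATH, TOKENIZER_PATH = resolve_gemma_paths(GEMMA_PATH, GEMMA_VARIANT)`.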


## Perform sampling/inference

```python
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
```
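`load_and_format_params` returns the checkpoint as a nested dictionary of arrays; the sampler later receives its `'transformer'` subtree. To sanity-check what was loaded, here is a sketch of a hypothetical `count_parameters` helper that walks any nested mapping and multiplies out each leaf's shape. It assumes every leaf exposes a `.shape` tuple, as JAX and NumPy arrays do; where JAX is available, `jax.tree_util.tree_leaves` could replace the manual recursion.

```python
import math
from collections.abc import Mapping
from typing import Any


def count_parameters(tree: Any) -> int:
    """Count scalar entries in a nested mapping of array-like leaves.

    Sketch only: assumes each non-mapping leaf has a `.shape` tuple.
    A scalar leaf (shape ()) counts as one entry.
    """
    if isinstance(tree, Mapping):
        # Recurse into sub-dictionaries and sum their counts.
        return sum(count_parameters(value) for value in tree.values())
    return math.prod(tree.shape)
```

For example, `print(f"{count_parameters(params['transformer']):,} parameters")` prints the size of the loaded transformer.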

```python
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
```

True

```python
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
```
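`cache_size` sets the length of the key/value cache. Assuming the usual KV-cached decoding scheme, where every prompt or generated token occupies one cache slot, it bounds how many tokens a single sampling call can hold. Below is a rough pre-flight check under that assumption; `check_cache_budget` is a hypothetical helper, and its counts may be off by a few tokens from what the sampler actually uses (for example, if a beginning-of-sequence token is added).

```python
from typing import Sequence


def check_cache_budget(
    prompt_token_counts: Sequence[int],
    total_generation_steps: int,
    cache_size: int,
) -> int:
    """Return the number of cache slots needed, or raise if it exceeds cache_size.

    Assumption: prompts in a batch are padded to the longest one, and each
    position (prompt or generated) uses one cache slot.
    """
    needed = max(prompt_token_counts) + total_generation_steps
    if needed > cache_size:
        raise ValueError(
            f'Need {needed} cache slots but cache_size is {cache_size}; '
            'shorten the prompt, lower total_generation_steps, or raise cache_size.'
        )
    return needed
```

With the tokenizer loaded above, a call might look like `check_cache_budget([len(vocab.EncodeAsIds(p)) for p in prompt], total_generation_steps=128, cache_size=1024)`.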

```python
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
    cache_length=1024,
)
```
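A sampler turns the transformer into a text generator by running it autoregressively: each step predicts one next token, appends it, and feeds the longer sequence back in, for at most `total_generation_steps` steps. The toy loop below illustrates that pattern with greedy (argmax) selection and a hypothetical `toy_model` standing in for the transformer. It is not the gemma `Sampler` implementation, whose decoding strategy and stopping rules may differ.

```python
from typing import Callable, Sequence


def greedy_generate(
    next_token_logits: Callable[[Sequence[int]], Sequence[float]],
    prompt_ids: Sequence[int],
    total_generation_steps: int,
    eos_id: int,
) -> list[int]:
    """Toy autoregressive loop: append the highest-scoring token each step."""
    tokens = list(prompt_ids)
    generated = []
    for _ in range(total_generation_steps):
        logits = next_token_logits(tokens)
        # Greedy choice: index of the largest logit.
        next_id = max(range(len(logits)), key=logits.__getitem__)
        if next_id == eos_id:
            break  # stop early once the end-of-sequence token is chosen
        tokens.append(next_id)
        generated.append(next_id)
    return generated


def toy_model(tokens: Sequence[int]) -> list[float]:
    """Hypothetical stand-in model over a 4-token vocabulary.

    It always favors the token after the last one, wrapping around to 0.
    """
    logits = [0.0] * 4
    logits[(tokens[-1] + 1) % 4] = 1.0
    return logits


print(greedy_generate(toy_model, [1], total_generation_steps=10, eos_id=0))  # [2, 3]
```

Starting from token 1, the toy model picks 2, then 3, then 0; since 0 is the end-of-sequence id, generation stops before the step budget is used up.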

```python
prompt = [
    "what is JAX in 3 bullet points?",
]

reply = sampler(
    input_strings=prompt,
    total_generation_steps=128,
)

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
```

Prompt:
what is JAX in 3 bullet points?
Output:


* **High-performance numerical computation:** JAX leverages the power of GPUs and TPUs to accelerate complex mathematical operations, making it ideal for scientific computing, machine learning, and data analysis.
* **Automatic differentiation:** JAX provides automatic differentiation capabilities, allowing you to compute gradients and optimize models efficiently. This simplifies the process of training deep learning models.
* **Functional programming:** JAX embraces functional programming principles, promoting code readability and maintainability. It offers features like vectorized operations and lazy evaluation, enhancing code efficiency. 


<end_of_turn>
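The reply ends with `<end_of_turn>`, the marker Gemma's instruction-tuned variants use to close a conversational turn. Those variants are trained on a turn-based format (`<start_of_turn>user`, the message, `<end_of_turn>`, then `<start_of_turn>model`), whereas this notebook passes the raw question. A minimal sketch of two hypothetical helpers follows: one wraps a prompt in that format and one trims the reply at the first `<end_of_turn>`. Check the documentation for your `gemma` version in case its sampler already applies the template.

```python
END_OF_TURN = '<end_of_turn>'


def format_user_turn(prompt: str) -> str:
    """Wrap one user message in Gemma chat-turn markers and open the model turn."""
    return f'<start_of_turn>user\n{prompt}{END_OF_TURN}\n<start_of_turn>model\n'


def strip_end_of_turn(text: str) -> str:
    """Drop everything from the first <end_of_turn> onward and trim whitespace."""
    return text.split(END_OF_TURN, 1)[0].strip()
```

These could be used as `chat_prompt = [format_user_turn(p) for p in prompt]` before calling `sampler(input_strings=chat_prompt, total_generation_steps=128)`, then `strip_end_of_turn(out_string)` on each entry of `reply.text`.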
