# Get started with Gemma models - Keras (JAX)

- https://ai.google.dev/gemma/docs/keras_inference

The Gemma family of open models includes a range of model sizes, capabilities, and task-specialized variations to help you build custom generative solutions.

[Gemma setup](https://ai.google.dev/gemma/docs/setup)

In [None]:
!pip install -q -U keras-nlp

In [None]:
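import os
# Select the backend before importing keras; the setting is read at import time.
os.environ["KERAS_BACKEND"] = "jax"
# Fraction of GPU memory that JAX may preallocate.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
# Allocate GPU memory on demand and release it when it is no longer needed.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Have TensorFlow use the CUDA async allocator.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"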

In [None]:
import keras
import keras_nlp


## Create a model

In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.
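To make "predicts the next token based on previous tokens" concrete, here is a minimal sketch of an autoregressive generation loop. It is not how `GemmaCausalLM` is implemented: the lookup table `NEXT_TOKEN`, the `<end>` marker, and the `toy_generate` function are all hypothetical stand-ins for a real model, which would score every token in its vocabulary at each step.

```python
# Hypothetical stand-in for a causal language model: a table that maps the
# full sequence of previous tokens to the predicted next token.
NEXT_TOKEN = {
    ("the",): "universe",
    ("the", "universe"): "is",
    ("the", "universe", "is"): "vast",
}

def toy_generate(prompt_tokens, max_length):
    """Append one predicted token at a time until max_length or <end>."""
    tokens = list(prompt_tokens)
    while len(tokens) < max_length:
        # The prediction depends on all tokens generated so far.
        next_token = NEXT_TOKEN.get(tuple(tokens), "<end>")
        if next_token == "<end>":
            break
        tokens.append(next_token)
    return tokens

short = toy_generate(["the"], max_length=3)
full = toy_generate(["the"], max_length=10)
print(short)  # ['the', 'universe', 'is']
print(full)   # ['the', 'universe', 'is', 'vast']
```

In this toy version the length limit counts the prompt tokens as well as the generated ones, and generation can stop early when the stand-in model predicts the end marker.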

In [None]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

2025-01-30 20:39:19.878722: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
I0000 00:00:1738237159.878768   87773 gpu_process_state.cc:201] Using CUDA malloc Async allocator for GPU: 0
I0000 00:00:1738237159.878906   87773 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060, pci bus id: 0000:0b:00.0, compute capability: 8.6
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [None]:
gemma_lm.summary()

As you can see from the summary, the model has 2.6 billion trainable parameters.

The "2B" in the model name does not count the embedding layer's parameters, which is why the name is smaller than the 2.6 billion shown in the summary.
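The gap between "2B" and 2.6 billion can be checked with rough arithmetic. The sketch below assumes a vocabulary of 256,000 tokens and an embedding width of 2,304; neither value is stated in this notebook, so verify them against the `summary()` output before relying on the result. The total is the rounded 2.6 billion figure quoted above.

```python
# Assumed values, not stated in this notebook: check them against summary().
VOCAB_SIZE = 256_000   # assumed tokenizer vocabulary size
HIDDEN_DIM = 2_304     # assumed embedding width
TOTAL_PARAMS = 2.6e9   # "2.6 billion" from the summary, rounded

# An embedding table stores one vector of HIDDEN_DIM values per vocabulary entry.
embedding_params = VOCAB_SIZE * HIDDEN_DIM
non_embedding_params = TOTAL_PARAMS - embedding_params

print(f"embedding:     {embedding_params / 1e9:.2f}B")
print(f"non-embedding: {non_embedding_params / 1e9:.2f}B")
```

Under these assumptions the embedding table accounts for roughly 0.6 billion parameters, leaving about 2 billion for the rest of the model.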

## Generate text

The model has a `generate` method that generates text from a prompt. The optional `max_length` argument sets the maximum length of the generated sequence in tokens, including the prompt.

In [None]:
%%time
gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)


CPU times: user 48.3 s, sys: 1.15 s, total: 49.5 s
Wall time: 29.8 s


'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'

In [None]:
%%time
gemma_lm.generate("The universe is", max_length=64)

CPU times: user 2.16 s, sys: 18.4 ms, total: 2.18 s
Wall time: 2.25 s


'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'

If you're running on the JAX or TensorFlow backend, you'll notice that the second `generate` call returns much faster than the first (about 2 s versus about 30 s of wall time in the runs above). This is because `generate` is compiled with XLA for each combination of batch size and `max_length`. The first call pays the compilation cost; later calls with the same batch size and `max_length` reuse the compiled function.

You can also provide batched prompts using a list as input:

In [None]:
gemma_lm.generate(
    ["what is keras in 3 bullet points?",
     "The universe is"],
    max_length=64)


['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n',
 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']

## Try a different sampler

By default, `generate` uses `"greedy"` sampling. To use a different strategy, recompile the model with another `sampler`:

In [None]:
gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)


'The universe is made of energy, but we cannot see or touch the energy. The energy is the source of all things, and it is constantly being transformed. The universe is a complex system, with a vast amount of energy and matter. The energy in the universe is constantly being transformed, and this transformation is what'

The default greedy algorithm always picks the token with the highest probability. Top-K sampling instead keeps only the K most probable tokens and randomly samples the next token from among them, so repeated calls can produce different text.
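The difference between the two strategies can be sketched in plain Python for a single decoding step. This is an illustration under stated assumptions, not the KerasNLP implementation: the `logits` values are made up, and `greedy_pick` and `top_k_pick` are hypothetical helper names.

```python
import math
import random

def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pick(logits):
    """Return the index of the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])

def top_k_pick(logits, k, rng):
    """Sample an index from the k highest-scoring tokens only."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] for i in top])  # renormalize over the top k
    return rng.choices(top, weights=probs, k=1)[0]

logits = [2.0, 1.5, 0.3, -1.0, -3.0]  # made-up scores for a 5-token vocabulary
rng = random.Random(0)

greedy = greedy_pick(logits)
samples = [top_k_pick(logits, k=2, rng=rng) for _ in range(10)]
print(greedy)   # always 0, the highest-scoring token
print(samples)  # a mix drawn only from indices 0 and 1
```

With K set to 1, top-K sampling reduces to greedy decoding. Larger values of K allow more varied output at the cost of sometimes choosing less likely tokens.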