# Get started with Gemma models - Keras (TensorFlow)

- https://ai.google.dev/gemma/docs/keras_inference

The Gemma family of open models includes a range of model sizes, capabilities, and task-specialized variations to help you build custom generative solutions.

Before you start, complete the [Gemma setup](https://ai.google.dev/gemma/docs/setup).

```python
!pip install -q -U keras-nlp
```

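```python
import os

# Select the backend before Keras is imported.
os.environ["KERAS_BACKEND"] = "tensorflow"
# Use TensorFlow's CUDA async allocator for GPU memory.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
```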

```python
import keras
import keras_nlp
```



## Create a model

In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.
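Formally, a causal language model factorizes the probability of a token sequence $x_1, \dots, x_T$ into a chain of next-token predictions, each conditioned only on the tokens that come before it:

$$
p(x_1, \dots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1})
$$

Generation applies this one step at a time: predict $x_t$ from $x_1, \dots, x_{t-1}$, append it to the sequence, and repeat.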

```python
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
```



```python
gemma_lm.summary()
```

As you can see from the summary, the model has 2.6 billion trainable parameters.

The "2B" in the model name does not count the parameters of the embedding layer, so it is smaller than the total shown in the summary.
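The gap between "2B" and 2.6 billion can be checked with a little arithmetic. This is a minimal sketch that assumes the `gemma2_2b_en` preset uses a 256,000-token vocabulary and a hidden size of 2,304; neither value is shown in this notebook, so treat them as assumptions. The embedding table has one row of `HIDDEN_DIM` values per vocabulary entry.

```python
# Assumed gemma2_2b_en configuration values (not printed in this notebook).
VOCABULARY_SIZE = 256_000
HIDDEN_DIM = 2_304
TOTAL_PARAMS = 2.6e9  # rounded total reported by gemma_lm.summary()

# One HIDDEN_DIM-sized vector per vocabulary entry.
embedding_params = VOCABULARY_SIZE * HIDDEN_DIM
non_embedding_params = TOTAL_PARAMS - embedding_params

print(f"embedding:     {embedding_params:,}")              # 589,824,000
print(f"non-embedding: {non_embedding_params / 1e9:.2f}B")  # 2.01B
```

Under these assumptions, roughly 0.59 billion parameters sit in the embedding table, leaving about 2.0 billion, which matches the "2B" label.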

## Generate text

The model has a `generate` method that generates text from a prompt. The optional `max_length` argument sets the maximum length, in tokens, of the returned sequence, including the prompt.
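To make the generation loop concrete, here is a toy sketch of greedy autoregressive decoding. The `NEXT_TOKEN` table and `toy_generate` function are hypothetical stand-ins for Gemma, and the sketch assumes, as described for `generate`, that `max_length` caps the whole sequence with the prompt included. A real causal LM conditions on all previous tokens; this toy looks only at the last one.

```python
# Hypothetical next-token table standing in for the model.
NEXT_TOKEN = {
    "The": "universe",
    "universe": "is",
    "is": "vast",
    "vast": "and",
    "and": "mysterious",
    "mysterious": ".",
}


def toy_generate(prompt_tokens, max_length):
    """Greedy autoregressive loop over a toy 'model'.

    max_length caps the whole sequence, prompt tokens included.
    """
    tokens = list(prompt_tokens)
    while len(tokens) < max_length:
        next_token = NEXT_TOKEN.get(tokens[-1])  # predict from the latest token
        if next_token is None:  # the toy table has no continuation
            break
        tokens.append(next_token)
    return tokens


print(toy_generate(["The", "universe"], max_length=5))
# ['The', 'universe', 'is', 'vast', 'and']
```

Because the prompt counts toward `max_length`, a two-token prompt with `max_length=5` leaves room for only three new tokens.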

```python
%%time
gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
```



CPU times: user 50.9 s, sys: 1.18 s, total: 52.1 s
Wall time: 25.2 s


'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'

```python
%%time
gemma_lm.generate("The universe is", max_length=64)
```

CPU times: user 3.23 s, sys: 2.13 ms, total: 3.23 s
Wall time: 3.33 s


'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'

On the JAX and TensorFlow backends, the second `generate` call runs much faster than the first (here, 3.33 s of wall time versus 25.2 s). The first call to `generate` for a given batch size and `max_length` compiles the generation function with XLA, which is expensive; later calls with the same batch size and `max_length` reuse the compiled function.

You can also provide batched prompts using a list as input:

```python
gemma_lm.generate(
    ["what is keras in 3 bullet points?",
     "The universe is"],
    max_length=64)
```


['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n',
 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']

## Try a different sampler

By default, `GemmaCausalLM` uses `"greedy"` sampling. To use a different sampling strategy, pass a different `sampler` to `compile()`:

```python
# Switch from the default greedy sampler to top-k sampling.
gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
```



'The universe is full of wonders and one of them is the fact that you can get a tattoo of a dog. Dogs are the most beloved pets of people, and the fact that they can get tattooed with a dog is just one of the many things that show the deep love that people have for dogs.\n\nA'

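The default greedy algorithm always picks the token with the highest probability. The top-k algorithm instead samples the next token at random from the k most probable tokens, in proportion to their probabilities, so repeated calls with the same prompt can produce different text.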
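The difference between the two strategies fits in a few lines. This sketch works on a hypothetical list of raw scores (logits) for a five-token vocabulary; `greedy` and `top_k` are illustrative helpers, not the `keras_nlp` sampler classes.

```python
import math
import random


def greedy(logits):
    """Return the index of the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])


def top_k(logits, k, rng=None):
    """Sample a token index from the k highest-scoring tokens.

    Softmax is applied over the top-k logits only, so every other
    token has zero chance of being picked.
    """
    if rng is None:
        rng = random.Random()
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    max_logit = logits[top[0]]  # subtract for numerical stability
    weights = [math.exp(logits[i] - max_logit) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # hypothetical scores for 5 tokens
print(greedy(logits))  # always 0
print(top_k(logits, k=3, rng=random.Random(0)))  # one of 0, 1, 2
```

With `k=1`, top-k sampling reduces to greedy decoding; larger `k` admits more candidate tokens and more varied output.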