
##### Copyright 2024 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/get_started"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
    <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/get_started.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/get_started.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Get started with Gemma using KerasNLP

This tutorial shows you how to get started with Gemma using [KerasNLP](https://keras.io/keras_nlp/). Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models. KerasNLP is a collection of natural language processing (NLP) models implemented in [Keras](https://keras.io/) and runnable on JAX, PyTorch, and TensorFlow.

In this tutorial, you'll use Gemma to generate text responses to several prompts. If you're new to Keras, you might want to read [Getting started with Keras](https://keras.io/getting_started/) before you begin, but you don't have to. You'll learn more about Keras as you work through this tutorial.

## Setup

### Gemma setup

To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

* Get access to Gemma on kaggle.com.
* Select a Colab runtime with sufficient resources to run
  the Gemma 2B model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.


### Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`.

In [2]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
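
Outside Colab, `userdata.get` is not available. One possible alternative, sketched below, reads the credentials from a Kaggle API token file, which is a JSON object with `username` and `key` fields. The helper name `load_kaggle_credentials` and its default path are assumptions made for illustration; they are not part of this tutorial or of any Kaggle library.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path="~/.kaggle/kaggle.json"):
    """Read a Kaggle API token file and export it as environment variables.

    Hypothetical helper: the name and default path are illustrative.
    """
    token_path = Path(path).expanduser()
    with token_path.open() as f:
        token = json.load(f)
    # The token file is a JSON object with "username" and "key" fields.
    os.environ["KAGGLE_USERNAME"] = token["username"]
    os.environ["KAGGLE_KEY"] = token["key"]
    return token["username"]
```

Call `load_kaggle_credentials()` once, before the model is downloaded, so that both variables are set when they are needed.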

In [3]:
import tensorflow as tf

# Locate the TPU cluster and connect this runtime to it.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# TPU initialization must run before any other TPU operations.
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))

All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]


### Install dependencies

Install Keras and KerasNLP.

In [4]:
# Install Keras 3 after KerasNLP so that the Keras 3 install is the one that
# takes effect. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -I -U keras==3
# Install JAX with TPU support.
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Collecting keras==3
  Using cached keras-3.0.0-py3-none-any.whl (997 kB)
Collecting absl-py (from keras==3)
  Using cached absl_py-2.1.0-py3-none-any.whl (133 kB)
Collecting numpy (from keras==3)
  Using cached numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
Collecting rich (from keras==3)
  Using cached rich-13.7.0-py3-none-any.whl (240 kB)
Collecting namex (from keras==3)
  Using cached namex-0.0.7-py3-none-any.whl (5.8 kB)
Collecting h5py (from keras==3)
  Using cached h5py-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)
Collecting dm-tree (from keras==3)
  Using cached dm_tree-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (152 kB)
Collecting markdown-it-py>=2.2.0 (from rich->keras==3)
  Using cached markdown_it_py-3.0.0-py3-none-any.whl (87 kB)
Collecting pygments<3.0.0,>=2.13.0 (from rich->keras==3)
  Using cached pygments-2.17.2-py3-none-any.whl (1.2 MB)
Collecting mdurl~=0.1 (from markdown-it-py>=2.2

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html


### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. [Keras 3](https://keras.io/keras_3) lets you choose the backend: TensorFlow, JAX, or PyTorch. All three will work for this tutorial. Set the `KERAS_BACKEND` environment variable before importing Keras, because the backend can't be changed once Keras has been imported.

In [7]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".

### Import packages

Import Keras and KerasNLP.

In [8]:
import keras
import keras_nlp
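
The backend choice only takes effect if `KERAS_BACKEND` is set before Keras is first imported in the session. As a guard against getting the order wrong, a small check along the following lines could be run at the top of a notebook. This is a minimal sketch: `select_backend` and `VALID_BACKENDS` are hypothetical names introduced here, not part of Keras.

```python
import os
import sys

# Backends named in this tutorial.
VALID_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name):
    """Set KERAS_BACKEND, refusing to do so once Keras is already imported.

    Hypothetical helper for illustration; not part of Keras.
    """
    if name not in VALID_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {VALID_BACKENDS}"
        )
    if "keras" in sys.modules:
        # Keras reads KERAS_BACKEND at import time, so setting it now
        # would have no effect on the already-imported package.
        raise RuntimeError("Keras is already imported; restart the runtime first.")
    os.environ["KERAS_BACKEND"] = name
```

Call `select_backend("jax")` before any `import keras` statement. Once Keras is imported, `keras.backend.backend()` reports which backend is actually active.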

## Create a model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [9]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/1/download/config.json...
100%|██████████| 555/555 [00:00<00:00, 1.01MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/1/download/model.weights.h5...
100%|██████████| 4.67G/4.67G [01:08<00:00, 73.0MB/s]


KeyError: '0'

`from_preset` instantiates the model from a preset architecture and weights. In the code above, the string `"gemma_2b_en"` specifies the preset architecture: a Gemma model with 2 billion parameters. (A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud.)

Use `summary` to get more info about the model:

In [None]:
gemma_lm.summary()

As you can see from the summary, the model has 2.5 billion trainable parameters.
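
The parameter count also gives a rough lower bound on the memory needed just to hold the weights. The sketch below is back-of-the-envelope arithmetic, assuming the rounded 2.5 billion figure above and the standard sizes of 4 bytes per parameter for float32 and 2 bytes for bfloat16. It ignores activations, caches, and framework overhead, so real usage will be higher.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Approximate memory needed to store the weights, in GiB."""
    return num_params * bytes_per_param / 2**30


NUM_PARAMS = 2_500_000_000  # Rounded figure from the model summary.

for dtype, size in [("float32", 4), ("bfloat16", 2)]:
    print(f"{dtype}: ~{weight_memory_gib(NUM_PARAMS, size):.1f} GiB")
```

Under these assumptions the weights alone take roughly 9.3 GiB in float32 and 4.7 GiB in bfloat16, which is why the choice of Colab runtime matters.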

## Generate text

Now it's time to generate some text! The model has a `generate` method that generates text based on a prompt. The optional `max_length` argument specifies the maximum length of the generated sequence.

Try it out with the prompt `"What is the meaning of life?"`.

In [None]:
gemma_lm.generate("What is the meaning of life?", max_length=64)

Try calling `generate` again with a different prompt.

In [None]:
gemma_lm.generate("How does the brain work?", max_length=64)

If you're running on JAX or TensorFlow backends, you'll notice that the second `generate` call returns nearly instantly. This is because each call to `generate` for a given batch size and `max_length` is compiled with XLA. The first run is expensive, but subsequent runs are much faster.
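
To observe the difference yourself, you could time the two calls. The helper below is a minimal sketch that uses only the Python standard library; `timed` is a hypothetical name introduced for illustration, and the commented usage assumes the `gemma_lm` model created earlier.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn with the given arguments and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Example usage with the model from this tutorial:
# _, first = timed(gemma_lm.generate, "What is the meaning of life?", max_length=64)
# _, second = timed(gemma_lm.generate, "How does the brain work?", max_length=64)
# print(f"first call: {first:.1f}s, second call: {second:.1f}s")
```

Both calls use a single prompt and the same `max_length`, so the second call can reuse the compiled function from the first.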

You can also provide batched prompts using a list as input:

In [None]:
gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)

### Optional: Try a different sampler

You can control the generation strategy for `GemmaCausalLM` by setting the `sampler` argument on `compile()`. By default, [`"greedy"`](https://keras.io/api/keras_nlp/samplers/greedy_sampler/#greedysampler-class) sampling will be used.

As an experiment, try setting a [`"top_k"`](https://keras.io/api/keras_nlp/samplers/top_k_sampler/) strategy:

In [None]:
gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)

While the default greedy algorithm always picks the token with the highest probability, the top-K algorithm randomly samples the next token from the K most probable tokens.
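
The difference between the two strategies can be illustrated on a toy next-token distribution. The following is a conceptual sketch in plain Python, not KerasNLP's implementation: the function names and the probabilities are made up for illustration, and it assumes that top-K sampling draws from the K most probable tokens in proportion to their probabilities.

```python
import random


def greedy_pick(probs):
    """Return the index of the most probable token."""
    return max(range(len(probs)), key=lambda i: probs[i])


def top_k_pick(probs, k, rng):
    """Sample a token index from the k most probable tokens."""
    # Indices of the k largest probabilities, most probable first.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # random.choices normalizes the weights, so this samples in
    # proportion to each candidate's probability among the top k.
    weights = [probs[i] for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


# Toy distribution over a five-token vocabulary (made-up numbers).
probs = [0.05, 0.40, 0.10, 0.30, 0.15]
rng = random.Random(0)

print(greedy_pick(probs))                                  # Always index 1.
print([top_k_pick(probs, k=2, rng=rng) for _ in range(5)])  # Only indices 1 or 3.
```

With `k=1`, top-K sampling reduces to greedy selection. Larger values of `k` allow more varied output at the cost of sometimes choosing less likely tokens.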

You don't have to specify a sampler, and you can ignore the last code snippet if it's not helpful to your use case. If you'd like to learn more about the available samplers, see [Samplers](https://keras.io/api/keras_nlp/samplers/).

## What's next

In this tutorial, you learned how to generate text using KerasNLP and Gemma. Here are a few suggestions for what to learn next:

* Learn how to [fine-tune a Gemma model](https://ai.google.dev/gemma/docs/lora_tuning).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).
* Learn how to [use Gemma models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).