##### Copyright 2024 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This is a quick demo of Gemma running on KerasNLP. To run this you will need:
- To be added to a private GitHub repo for Gemma.
- To be added to a private Kaggle model for weights.

Note that you will need a large GPU (e.g. A100) to run this as well.

General Keras reading:
- [Getting started with Keras](https://keras.io/getting_started/)
- [Getting started with KerasNLP](https://keras.io/guides/keras_nlp/getting_started/)
- [Generation and fine-tuning guide for GPT2](https://keras.io/examples/generative/gpt2_text_generation_with_kerasnlp/)

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/Keras_Gemma_2_Quickstart.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

## Access

In [1]:
import os
from google.colab import userdata

# Copy the Kaggle credentials stored as Colab secrets into the process
# environment, one secret at a time and in the same order as listed here.
for secret_name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    os.environ[secret_name] = userdata.get(secret_name)

## Installation

In [2]:
# Install all deps
!pip install keras
!pip install keras-nlp

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m72.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.2/311.2 kB[0m [31m35.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m589.8/589.8 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.3/5.3 MB[0m [31m95.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m76.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m107.5 MB/s[0m eta [36m

## Quickstart

In [3]:
import os

# Select the Keras backend. This assignment is intentionally placed before the
# `keras_nlp` / `keras` imports below, since Keras picks up KERAS_BACKEND when
# it is imported.
os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".

import keras_nlp
import keras

# Run at half precision: make bfloat16 the default float type for Keras.
keras.config.set_floatx("bfloat16")

In [9]:
# Load the Gemma 2 9B causal LM from the `gemma2_9b_en` preset. Hugging Face
# weights (`hf://google/gemma-2-9b-keras`) can be passed to `from_preset` instead.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_9b_en")
# Print an overview of the loaded model.
gemma_lm.summary()

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/metadata.json...
100%|██████████| 143/143 [00:00<00:00, 179kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/task.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/config.json...
100%|██████████| 780/780 [00:00<00:00, 895kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/model.weights.h5...
100%|██████████| 17.2G/17.2G [18:34<00:00, 16.6MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/preprocessor.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/tokenizer.json...
100%|██████████| 315/315 [00:00<00:00, 431kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/assets/tokenizer/voc

In [11]:
# Generate a short completion for a sample prompt. The call is the last
# expression in the cell, so the returned string (which begins with the
# prompt itself) is displayed as the cell output.
gemma_lm.generate(
    "What is the meaning of life?",
    max_length=32,
)

'What is the meaning of life?\n\n[Answer 1]\n\nThe meaning of life is to live it.\n\n[Answer 2]\n\nThe'