##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/pytorch_gemma"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Gemma in PyTorch

This notebook is a quick demo of running Gemma inference in PyTorch.
For more details, see the GitHub repo of the [official PyTorch implementation](https://github.com/google/gemma_pytorch).

**Note that**:
 * The free Kaggle CPU Python runtime and GPU Python runtime are sufficient for running the Gemma 2B models and 7B int8 quantized models.
 * For advanced use cases on other GPUs or TPUs, refer to [README.md](https://github.com/google/gemma_pytorch/blob/main/README.md) in the official repo.

### Gemma setup

To complete this tutorial, you first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup).

Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle:

- Sign in or register at [kaggle.com](https://www.kaggle.com)
- Open the [Gemma 2 model card](https://www.kaggle.com/models/google/gemma-2) and select _"Request Access"_
- Complete the consent form and accept the terms and conditions


## Install dependencies

In [2]:
!pip install -q -U torch immutabledict sentencepiece

## Download model weights

In [3]:
# Choose variant and machine type
VARIANT = '2b-it' #@param ['2b', '2b-it', '9b', '9b-it', '27b', '27b-it']
MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']

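# Use the size label before the first hyphen ('2b-it' -> '2b', '27b-it' -> '27b').
# Slicing with VARIANT[:2] would turn '27b' into '27', which is not a valid config.
CONFIG = VARIANT.split('-')[0]
if CONFIG == '2b':
  CONFIG = '2b-v2'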
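
As a minimal sketch, the variant-to-config mapping can be wrapped in a small helper that takes the size label before the first hyphen. The name `variant_to_config` is hypothetical and not part of `gemma_pytorch`; the `'2b-v2'` config name is the one used in the cell above.

```python
def variant_to_config(variant: str) -> str:
    """Map a variant such as '2b-it' or '27b' to a config name."""
    # The size label is everything before the first hyphen.
    size = variant.split('-')[0]
    # This notebook maps the 2B size to the '2b-v2' config name.
    return '2b-v2' if size == '2b' else size


# Show the mapping for every variant offered in the dropdown above.
for v in ['2b', '2b-it', '9b', '9b-it', '27b', '27b-it']:
    print(v, '->', variant_to_config(v))
```

Splitting on the hyphen, rather than slicing a fixed number of characters, keeps three-character size labels such as `27b` intact.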

In [4]:
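# Path of the Gemma model attached to this notebook as a Kaggle input.
# It is hard-coded for the 2b-it variant; update it if you change VARIANT.
weights_dir = "/kaggle/input/gemma-2/pytorch/gemma-2-2b-it/1"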

In [5]:
import os
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, 'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
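
The two assertions above stop at the first missing file. As a hedged alternative, the sketch below reports every missing file at once. The helper name `missing_weight_files` is hypothetical, and the demo uses a temporary directory in place of the real weights directory.

```python
import tempfile
from pathlib import Path

# File names checked by the cell above.
REQUIRED_FILES = ('tokenizer.model', 'model.ckpt')


def missing_weight_files(weights_dir, required=REQUIRED_FILES):
    """Return the required file names that are absent from weights_dir."""
    root = Path(weights_dir)
    return [name for name in required if not (root / name).is_file()]


# Demo with a temporary directory that contains only the tokenizer.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / 'tokenizer.model').touch()
    print(missing_weight_files(tmp))  # ['model.ckpt']
```

In the notebook, you would call `missing_weight_files(weights_dir)` and raise an error if the returned list is not empty.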

## Download the model implementation

In [6]:
# NOTE: The "installation" is just cloning the repo.
!git clone https://github.com/google/gemma_pytorch.git

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 25.32 MiB/s, done.
Resolving deltas: 100% (135/135), done.


In [7]:
import sys

sys.path.append('gemma_pytorch')

In [8]:
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

## Set up the model

In [9]:
# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
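
If `MACHINE_TYPE` is set to `'cuda'` on a runtime without a GPU, moving the model to the device fails. A minimal sketch of a CPU fallback follows. The helper `resolve_machine_type` is hypothetical, and the availability flag is passed in as a plain boolean so the example runs without PyTorch.

```python
def resolve_machine_type(requested: str, cuda_available: bool) -> str:
    """Return 'cpu' when 'cuda' is requested but no GPU is available."""
    if requested == 'cuda' and not cuda_available:
        return 'cpu'
    return requested


print(resolve_machine_type('cuda', cuda_available=False))  # cpu
print(resolve_machine_type('cuda', cuda_available=True))   # cuda
```

In the notebook, the second argument could come from `torch.cuda.is_available()`, with the result passed to `torch.device(...)`.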

## Run inference

Below are examples of generating in chat mode and generating from a plain
prompt.

The instruction-tuned Gemma models were trained with a specific formatter that
annotates instruction tuning examples with extra information, both during
training and inference. The annotations (1) indicate roles in a conversation,
and (2) delineate turns in a conversation. Below we show a sample code snippet
for formatting the model prompt using the user and model chat templates in a
multi-turn conversation. The relevant tokens are:

- `user`: user turn
- `model`: model turn
- `<start_of_turn>`: beginning of dialogue turn
- `<end_of_turn><eos>`: end of dialogue turn

For details, see
[Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting).
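
The turn markers listed above can be assembled programmatically. The sketch below is one way to do it, assuming the conversation is given as a list of `(role, text)` pairs. The helper `build_chat_prompt` and the two constants are hypothetical names, and the end-of-turn string follows the token list in this notebook.

```python
START_OF_TURN = '<start_of_turn>'
END_OF_TURN = '<end_of_turn><eos>'  # end-of-turn marker as listed in this notebook


def build_chat_prompt(turns):
    """Format (role, text) pairs and open a final model turn.

    turns: list of (role, text) pairs, where role is 'user' or 'model'.
    """
    parts = []
    for role, text in turns:
        if role not in ('user', 'model'):
            raise ValueError(f'unknown role: {role!r}')
        parts.append(f'{START_OF_TURN}{role}\n{text}{END_OF_TURN}\n')
    # Leave the last model turn open so the model writes the reply.
    parts.append(f'{START_OF_TURN}model\n')
    return ''.join(parts)


conversation = [
    ('user', 'What is a good place for travel in the US?'),
    ('model', 'California.'),
    ('user', 'What can I do in California?'),
]
print(build_chat_prompt(conversation))
```

The returned string can be passed directly to `model.generate(...)` as the prompt.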

In [10]:
# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

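# The prompt is already formatted with the chat templates, so pass it as-is.
# Wrapping it in USER_CHAT_TEMPLATE again would nest the whole conversation
# inside an extra user turn.
results = model.generate(
    prompt,
    device=device,
    output_len=128,
)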
print(results)

Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model

California is bursting with incredible experiences, so to recommend something truly good for you, I need a bit more info!  

Tell me:

* **What type of trip are you looking for?**  Relaxing beach vacation? Adventure-filled hiking trip? Big city excitement? Historical and cultural exploration? Foodie tour? 
* **What time of year are you planning to visit?** This drastically impacts weather and activities. 
* **Who are you traveling with?**  Family? Friends? Solo?  
* **What's your budget?**  There are amazing options for all budgets, but it


In [11]:
# Generate from a plain prompt, without the chat templates
results = model.generate(
    'What is colour of Tangerine ?',
    device=device,
    output_len=128,
)
print(results)



Tangerines are the small, delicious, sweet orange. They are known for their vibrant orange colour.

So, the colour of a tangerine is **orange**. 
<end_of_turn>
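
The output above ends with a literal `<end_of_turn>` marker. A minimal sketch for trimming it before showing the text to a user follows. The helper `strip_end_of_turn` is hypothetical, and the sample string is copied from the output above.

```python
def strip_end_of_turn(text: str, marker: str = '<end_of_turn>') -> str:
    """Cut text at the first marker and trim surrounding whitespace."""
    return text.split(marker, 1)[0].strip()


raw = '\nSo, the colour of a tangerine is **orange**. \n<end_of_turn>'
print(strip_end_of_turn(raw))
```

Text without the marker is returned unchanged apart from the whitespace trim.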


## Learn more

Now that you have learned how to use Gemma in PyTorch, you can explore the many
other things that Gemma can do at [ai.google.dev/gemma](https://ai.google.dev/gemma).
See also these related resources:

- [Gemma model card](https://ai.google.dev/gemma/docs/model_card)
- [Gemma C++ Tutorial](https://ai.google.dev/gemma/docs/gemma_cpp)
- [Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting)