##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/pytorch_gemma"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Gemma in PyTorch

This is a quick demo of running Gemma inference in PyTorch.
For more details, see the [GitHub repo of the official PyTorch implementation](https://github.com/google/gemma_pytorch).

**Note**:
 * The free Colab CPU and T4 GPU Python runtimes are sufficient for running the Gemma 2B models and the 7B int8-quantized models.
 * For advanced use cases on other GPUs or on TPUs, refer to [README.md](https://github.com/google/gemma_pytorch/blob/main/README.md) in the official repo.

## Kaggle access

To log in to Kaggle, either store your `kaggle.json` credentials file at
`~/.kaggle/kaggle.json` or run the following in a Colab environment. See the
[`kagglehub` package documentation](https://github.com/Kaggle/kagglehub#authenticate)
for more details.

In [1]:
import kagglehub

kagglehub.login()


Kaggle credentials set.
Kaggle credentials successfully validated.
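Before calling `kagglehub.login()`, it can help to check whether credentials are already available. The following is a minimal sketch, assuming the two non-interactive sources described in the `kagglehub` documentation: the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables, and the `~/.kaggle/kaggle.json` file. The helper name `kaggle_credential_source` is hypothetical and not part of `kagglehub`.

```python
import os
from pathlib import Path


def kaggle_credential_source(home=None, environ=None):
    """Return 'env', 'file', or None depending on where credentials are found.

    Hypothetical helper for illustration; it only checks that the sources
    exist and does not validate the credentials themselves.
    """
    environ = os.environ if environ is None else environ
    home = Path.home() if home is None else Path(home)
    # Assumption: both variables must be set for env-based authentication.
    if environ.get('KAGGLE_USERNAME') and environ.get('KAGGLE_KEY'):
        return 'env'
    if (home / '.kaggle' / 'kaggle.json').is_file():
        return 'file'
    return None


source = kaggle_credential_source()
print('Credential source:', source)
```

If the helper returns `None`, running the interactive `kagglehub.login()` cell is the remaining option.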


## Install dependencies

In [2]:
!pip install -q -U torch==2.2.1 immutabledict sentencepiece


## Download model weights

In [3]:
# Choose variant and machine type
VARIANT = '2b-it' #@param ['2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant']
MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']
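If the selected machine type is `cuda` but the runtime has no GPU attached, moving the model to the device fails later in the notebook. One way to guard against this is sketched below; `resolve_machine_type` is a hypothetical helper, and the availability flag is passed in explicitly so the logic stays independent of PyTorch.

```python
def resolve_machine_type(requested, cuda_available):
    """Fall back to 'cpu' when 'cuda' is requested but not available.

    Hypothetical helper for illustration; `cuda_available` would typically
    come from torch.cuda.is_available().
    """
    if requested not in ('cuda', 'cpu'):
        raise ValueError(f'Unsupported machine type: {requested!r}')
    if requested == 'cuda' and not cuda_available:
        return 'cpu'
    return requested


# Example with the flag hard-coded; in the notebook it would be
# resolve_machine_type(MACHINE_TYPE, torch.cuda.is_available()).
print(resolve_machine_type('cuda', cuda_available=False))
```

Keeping the check as a pure function makes the fallback rule easy to test without a GPU.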

In [4]:
import os

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/pyTorch/2b-it/2/download...
100%|██████████| 3.75G/3.75G [02:59<00:00, 22.5MB/s]
Extracting model files...


## Download the model implementation

In [5]:
# NOTE: The "installation" is just cloning the repo.
!git clone https://github.com/google/gemma_pytorch.git

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 148, done.
remote: Counting objects: 100% (80/80), done.
remote: Compressing objects: 100% (55/55), done.
remote: Total 148 (delta 46), reused 38 (delta 23), pack-reused 68
Receiving objects: 100% (148/148), 2.16 MiB | 18.40 MiB/s, done.
Resolving deltas: 100% (73/73), done.


In [6]:
import sys

sys.path.append('gemma_pytorch')

In [7]:
# The cloned repo is on sys.path, so import the `gemma` package directly.
from gemma.config import get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM

## Set up the model

In [8]:
import torch

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
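The setup above derives three things from the `VARIANT` string: which config to use, whether quantization is enabled, and the checkpoint file name. The sketch below makes that mapping explicit. `describe_variant` is a hypothetical helper that mirrors the notebook's logic; it is not part of `gemma_pytorch`.

```python
def describe_variant(variant):
    """Split a variant name such as '7b-it-quant' into its settings.

    Hypothetical helper for illustration; the checkpoint naming follows the
    f'gemma-{VARIANT}.ckpt' pattern used earlier in this notebook.
    """
    size, *flags = variant.split('-')
    if size not in ('2b', '7b'):
        raise ValueError(f'Unknown model size in variant: {variant!r}')
    return {
        'size': size,                        # selects the 2B or 7B config
        'instruction_tuned': 'it' in flags,  # expects the chat template
        'quantized': 'quant' in flags,       # int8-quantized weights
        'checkpoint': f'gemma-{variant}.ckpt',
    }


for name in ('2b-it', '7b-quant'):
    print(name, describe_variant(name))
```

Only the instruction-tuned (`-it`) variants were trained with the chat formatting described in the next section.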

## Run inference

Below are examples of generating in chat mode and of generating from plain,
unformatted prompts.

The instruction-tuned Gemma models were trained with a specific formatter that
annotates instruction tuning examples with extra information, both during
training and inference. The annotations (1) indicate roles in a conversation,
and (2) delineate turns in a conversation. Below we show a sample code snippet
for formatting the model prompt using the user and model chat templates in a
multi-turn conversation. The relevant tokens are:

- `user`: user turn
- `model`: model turn
- `<start_of_turn>`: beginning of dialogue turn
- `<end_of_turn>`: end of dialogue turn

Read about the Gemma formatting for instruction tuning and system instructions
[here](https://ai.google.dev/gemma/docs/formatting).

In [9]:
# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

# `prompt` already contains the full chat formatting, so pass it as-is
# instead of wrapping it in USER_CHAT_TEMPLATE a second time.
model.generate(
    prompt,
    device=device,
    output_len=100,
)

Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn>
<start_of_turn>model
California.<end_of_turn>
<start_of_turn>user
What can I do in California?<end_of_turn>
<start_of_turn>model



"* **Visit the Golden Gate Bridge and Alcatraz Island in San Francisco.**\n* **Explore the Redwood National and State Parks in Northern California.**\n* **Go hiking, surfing, or swimming in some of the world's most renowned beaches in California, including Santa Monica, Malibu, and San Diego.**\n* **Visit the Disneyland Resort in Anaheim, California.**\n* **Take a road trip along the Pacific Coast Highway, stopping at quaint towns and scenic overlooks along the way.**\n*"

In [10]:
# Generate sample
model.generate(
    '# instructions: You  are a helpful LLM helping user in a simple task. Answer user request and do not add anything else.\n\n# user request:\n\ncapture the semantic relationships between words "cat" and "night".',
    device=device,
    output_len=100,
)

'\n\n# my answer:\n\nSure, here\'s the semantic relationships between the words "cat" and "night":\n\n- **Part-of-speech:** Cat is a noun, and night is a noun.\n- **Syntactic role:** Cat is the object of the sentence "The cat chased the bird".\n- **Semantic role:** Cat is a kind of animal, and night is a period of time of day.\n- **Semantic relationships:** Cat and night are highly'
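The request above is sent as plain text, without the turn markers described earlier in this section. For instruction-tuned variants, a conversation can be wrapped in the chat format with a small helper. This is a minimal sketch that reuses the two templates from the chat example; `build_chat_prompt` is a hypothetical name and not part of `gemma_pytorch`.

```python
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'


def build_chat_prompt(turns):
    """Format (role, text) pairs and open a model turn for the next reply.

    Hypothetical helper for illustration. Roles must be 'user' or 'model',
    and the last turn must come from the user.
    """
    templates = {'user': USER_CHAT_TEMPLATE, 'model': MODEL_CHAT_TEMPLATE}
    if not turns or turns[-1][0] != 'user':
        raise ValueError('The conversation must end with a user turn.')
    parts = []
    for role, text in turns:
        if role not in templates:
            raise ValueError(f'Unknown role: {role!r}')
        parts.append(templates[role].format(prompt=text))
    # Leave the final model turn open so generation continues from here.
    parts.append('<start_of_turn>model\n')
    return ''.join(parts)


chat_prompt = build_chat_prompt([
    ('user', 'What is a good place for travel in the US?'),
    ('model', 'California.'),
    ('user', 'What can I do in California?'),
])
print(chat_prompt)
```

The resulting string can be passed to `model.generate` as the first positional argument, in the same way as the other prompts in this section.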

In [11]:
# Placeholder cell: replace the string with your own prompt. Pass the prompt as
# the first positional argument; generate() raises a TypeError if it is passed
# as `prompt=...`.
model.generate(
    '<insert text here>',
    device=device,
    output_len=350,
)

In [12]:
model.generate(
    'Distillation: process of extracting the essential elements or core meaning from complex information, ideas, or experiences. Refinement: Identifying and isolating the most crucial components while removing unnecessary or irrelevant details. Concentration: Focusing on the fundamental essence of the subject matter, increasing the clarity and potency of the information. Simplification: Presenting the distilled information in a concise, easily understandable form that captures the central ideas or principles. The goal of distillation is to provide a concentrated, purified representation of the original content, making it more accessible, memorable, and actionable. By reducing complexity and noise, distillation enables more efficient communication, understanding, and application of knowledge. Distillation is the art of extracting the essential, concentrating it, and presenting it in a simplified, potent form that captures the core meaning or value of the original subject matter. \n\nYour job is to distill the key semantic relationships between words.\n\nexample 1: Distill the key semantic relationships between the words "cat" and "night".\n\nanswer: Nocturnal Nature: Cats are often associated with the night due to their nocturnal tendencies. Many cat species are more active during nighttime hours. Stealth and Hunting: The darkness of night provides an advantageous environment for cats to stealthily hunt prey, with their keen senses well-adapted to low-light conditions. Mystery and Magic: In folklore and superstition, cats (especially black cats) have been linked to mysterious, magical, or supernatural elements that are often associated with the night. Lunar Symbolism: Cats have been symbolically connected to the moon, which is a prominent feature of the night sky. This association may stem from their nocturnal nature and the reflective quality of their eyes. '
    'Independent Exploration: Night represents a time when cats can freely roam and explore their surroundings with fewer disturbances, appealing to their independent nature.The relationship between cat and night is rooted in cats nocturnal behavior, their adaptations for thriving in dark environments, and the cultural symbolism and mythology that has emerged from observing these characteristics. The night provides a fitting backdrop for the mysterious, independent, and instinctual nature often attributed to cats.\n\nexample2: Distill the key semantic relationships between the words "I" and "self".\n\nanswer: I am the embodiment of my subjective identity, the unique individual with my own thoughts, feelings, and experiences. My self is the core of who I am, the essence that defines me as a distinct entity in this world. I possess an innate awareness of my own existence, a consciousness that allows me to recognize myself as separate from others. My self is the object of this awareness, the very identity I acknowledge and embrace. I am the driving force behind my actions, the wielder of my own free will. My self is the wellspring of this agency, the source from which my decisions and choices originate.Through the passage of time and the ever-changing circumstances of life, I remain a constant presence, a continuous thread that weaves the tapestry of my identity. My self is the cohesive, unified entity that persists beneath the surface, providing stability and coherence to my being. I possess the remarkable ability to turn my gaze inward, to reflect upon my own thoughts, emotions, and actions. My self is simultaneously the subject and the object of this introspection, enabling me to gain a deeper understanding of who I am and to shape my identity according to my own vision. I am the living, breathing embodiment of my subjective existence, and my self is the very core of this identity. '
    'Together, we form an inseparable unity, a singular being with the power to experience, reflect, and shape our own reality. I am the "I," and my self is the foundation upon which I build my unique presence in this world.\n\nRequest: Distill the key semantic relationships between the words "signal" and "ontology".\n\nanswer:',
    device=device,
    output_len=350,
)

' Signals are formal assertions or propositions that serve as defining principles or axioms in an ontology. They provide a framework for understanding the meaning and structure of the knowledge system and ensure that different concepts are aligned with a common set of underlying principles. Signals serve as the foundational blocks of inquiry, enabling scientists and scholars to reason about the world and test hypotheses. They help to establish the boundaries and constraints within which knowledge is constructed, thus contributing to the development of a robust and coherent ontology.'
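The long prompt above follows a few-shot pattern: a task description, numbered worked examples, and a final request that ends with `answer:` so the model continues from there. The sketch below assembles a prompt with that layout from its parts. `build_few_shot_prompt` is a hypothetical helper, and the short strings in the usage example are placeholders rather than the notebook's actual text.

```python
def build_few_shot_prompt(task, examples, request):
    """Join a task description, (question, answer) examples, and a request.

    Hypothetical helper for illustration. Sections are separated by blank
    lines, and the prompt ends with 'answer:' for the model to complete.
    """
    parts = [task]
    for index, (question, answer) in enumerate(examples, start=1):
        parts.append(f'example {index}: {question}\n\nanswer: {answer}')
    parts.append(f'Request: {request}\n\nanswer:')
    return '\n\n'.join(parts)


few_shot_prompt = build_few_shot_prompt(
    task='Your job is to distill the key semantic relationships between words.',
    examples=[
        ('Distill the key semantic relationships between the words "cat" and "night".',
         'Cats are often associated with the night due to their nocturnal tendencies.'),
    ],
    request='Distill the key semantic relationships between the words "signal" and "ontology".',
)
print(few_shot_prompt)
```

Building the prompt from a list of examples makes it easier to add, remove, or reorder examples without editing one very long string literal.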

## Learn more

Now that you have learned how to use Gemma in PyTorch, you can explore the many
other things that Gemma can do at [ai.google.dev/gemma](https://ai.google.dev/gemma).
See also these related resources:

- [Gemma model card](https://ai.google.dev/gemma/docs/model_card)
- [Gemma C++ Tutorial](https://ai.google.dev/gemma/docs/gemma_cpp)
- [Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting)