##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/pytorch_gemma"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Gemma in PyTorch

This is a quick demo of running Gemma inference in PyTorch.
For more details, see the GitHub repository of the [official PyTorch implementation](https://github.com/google/gemma_pytorch).

**Note that**:
 * The free Colab CPU Python runtime and T4 GPU Python runtime are sufficient for running the Gemma 2B models and 7B int8 quantized models.
 * For advanced use cases for other GPUs or TPU, please refer to [README.md](https://github.com/google/gemma_pytorch/blob/main/README.md) in the official repo.

## Kaggle access

To login to Kaggle, you can either store your `kaggle.json` credentials file at
`~/.kaggle/kaggle.json` or run the following in a Colab environment. See the
[`kagglehub` package documentation](https://github.com/Kaggle/kagglehub#authenticate)
for more details.

In [None]:
import kagglehub

# Prompt for Kaggle credentials inside the notebook. They are required by
# `kagglehub.model_download` further down; skip this cell if your credentials are
# already stored in ~/.kaggle/kaggle.json.
kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


## Install dependencies

In [None]:
!pip install -q -U torch immutabledict sentencepiece

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m779.1/779.1 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.2/176.2 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.1/168.1 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.2.1+cu121 requires torch==2.2.1, but you have torch 2.3.0 which is incompatible.
torchtext 0.17.1 requires torch==2.2.1, but you have torch 2.3.0 which is incompatible.
torchvision 0.17.1+cu121 requires torch==2.2.1, but you have torch 2.3.0 which is incompatible.[0m[31m
[0m

## Download model weights

In [None]:
# Choose variant and machine type.
# VARIANT selects the Kaggle model handle (google/gemma/pyTorch/{VARIANT}) and the
# checkpoint file name. '-it' variants are instruction-tuned (they expect the chat
# template); '-quant' variants are the int8-quantized checkpoints.
# MACHINE_TYPE is passed to torch.device(); pick 'cpu' when no GPU runtime is attached.
VARIANT = '2b-it' #@param ['2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant']
MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']

In [None]:
import os

# Download the selected Gemma checkpoint from Kaggle; returns the local directory.
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Validate the download explicitly rather than with `assert` (which is skipped under
# `python -O`), and report the exact path that is missing.
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
if not os.path.isfile(tokenizer_path):
    raise FileNotFoundError(f'Tokenizer not found: {tokenizer_path}')

ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f'PyTorch checkpoint not found: {ckpt_path}')

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/pyTorch/2b-it/2/download...
100%|██████████| 3.75G/3.75G [00:34<00:00, 118MB/s]
Extracting model files...


## Download the model implementation

In [None]:
# NOTE: The "installation" is just cloning the repo.
# Only clone when no checkout exists yet, so the notebook can be re-run: git refuses to
# clone into an existing non-empty directory. check=True makes a real clone failure
# (e.g. no network) stop the notebook instead of being silently ignored.
import os
import subprocess

if not os.path.isdir('gemma_pytorch'):
    subprocess.run(
        ['git', 'clone', 'https://github.com/google/gemma_pytorch.git'],
        check=True,
    )

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 148, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (55/55), done.[K
remote: Total 148 (delta 46), reused 38 (delta 23), pack-reused 68[K
Receiving objects: 100% (148/148), 2.16 MiB | 12.40 MiB/s, done.
Resolving deltas: 100% (73/73), done.


In [None]:
import sys

# Make the cloned repo root importable (presumably so modules inside it can import the
# `gemma` package directly — TODO confirm). Guard against adding a duplicate entry
# every time this cell is re-run.
if 'gemma_pytorch' not in sys.path:
    sys.path.append('gemma_pytorch')

In [None]:
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

## Setup the model

In [None]:
import torch

# Set up model config: pick the 2B or 7B architecture from the variant name, point it at
# the downloaded tokenizer, and enable int8 weights for the '-quant' variants.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
# Parameters are created in the config's dtype by temporarily switching torch's global
# default dtype. The previous default is restored afterwards, even if loading fails, so
# later cells do not silently create tensors in the model's (possibly reduced) precision.
device = torch.device(MACHINE_TYPE)
previous_default_dtype = torch.get_default_dtype()
torch.set_default_dtype(model_config.get_dtype())
try:
    model = GemmaForCausalLM(model_config)
    model.load_weights(ckpt_path)
    model = model.to(device).eval()
finally:
    torch.set_default_dtype(previous_default_dtype)

## Run inference

Below are two examples: generating in chat mode from a multi-turn conversation, and
generating from a single one-off request.

The instruction-tuned Gemma models were trained with a specific formatter that
annotates instruction tuning examples with extra information, both during
training and inference. The annotations (1) indicate roles in a conversation,
and (2) delineate turns in a conversation. Below we show a sample code snippet
for formatting the model prompt using the user and model chat templates in a
multi-turn conversation. The relevant tokens are:

- `user`: user turn
- `model`: model turn
- `<start_of_turn>`: beginning of dialogue turn
- `<end_of_turn>`: end of dialogue turn

Read about the Gemma formatting for instruction tuning and system instructions
[here](https://ai.google.dev/gemma/docs/formatting).

In [None]:
# Generate with one request in chat mode

# Chat templates: each completed turn is wrapped in <start_of_turn>{role} ... <end_of_turn>.
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt: one finished user/model exchange, a new user turn, and an
# open model turn that Gemma will complete.
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

# `prompt` is already fully formatted, so it is passed as-is. Wrapping it in
# USER_CHAT_TEMPLATE again would nest the whole conversation inside an extra user turn
# and close that turn with <end_of_turn> right after the open model turn.
model.generate(
    prompt,
    device=device,
    output_len=100,
)

Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn>
<start_of_turn>model
California.<end_of_turn>
<start_of_turn>user
What can I do in California?<end_of_turn>
<start_of_turn>model



'California offers a wide variety of travel experiences, from exploring bustling cities to enjoying scenic natural beauty. Here are some popular destinations to consider:\n\n**Coastal Living:**\n- **Santa Monica Beach:** Relax on the pristine sands of this iconic beach.\n- **San Diego:** Explore vibrant nightlife, theme parks, and an abundance of museums and attractions.\n- ** Malibu:** Indulge in luxury shopping, fine dining, and stunning ocean views.\n- **Monterey Bay:** Discover rugged cliffs,'

In [None]:
# Generate sample
model.generate(
    '피보나치 수열에 대해서 설명해줘',
    device=device,
    output_len=1000,
)

', 그와 연관된 수열의 표현과 연관성을 설명해 봅니다.\n\n피보나치 수열은 어떤 종류의 수열인가, 무엇인가?\n\n피보나치 수열은 어떻게 만들어지는가?\n\n피보나치 수열의 특징은 무엇인가?\n\n피보나치 수열의 연관성을 설명해 보겠습니다.\n\n피보나치 수열이 가장 중요한 수열 중 하나인가, 아니라는가?\n\n피보나치 수열에 대한 많은 연구가 이루어졌는데, 특히 피보나치 수열의 존재를 부인하는 의견도 있지만, 피보나치 수열은 실제 상황에서 가장 중요한 수열 중 하나라고 여겨지고 있습니다.\n\n피보나치 수열의 표현과 연관성을 설명해 보겠습니다.\n\n피보나치 수열은 어떤 수열의 포함을 기반으로 결정되는가?\n\n피보나치 수열이 다양한 수열에서 발견되는가?\n\n피보나치 수열의 중요성을 설명해 보겠습니다.\n\n피보나치 수열은 어떻게 촉진되는가?\n\n피보나치 수열의 생성 이유와 기타 요인을 설명해 보겠습니다.'

In [None]:
!pip install gradio

Collecting gradio
  Downloading gradio-4.31.0-py3-none-any.whl (12.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m48.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)
Collecting fastapi (from gradio)
  Downloading fastapi-0.111.0-py3-none-any.whl (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.0/92.0 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpy (from gradio)
  Downloading ffmpy-0.3.2.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gradio-client==0.16.2 (from gradio)
  Downloading gradio_client-0.16.2-py3-none-any.whl (315 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m315.5/315.5 kB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting httpx>=0.24.1 (from gradio)
  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)
[2K     [90m━━━━━━━━━━━━━━━━

In [None]:
# Upgrade gradio to the latest release. This is effectively redundant if the previous
# cell already installed the newest version; pip then just reports it as satisfied.
!pip install --upgrade gradio



In [None]:
import gradio as gr

# Persona text used by gemma_chat to steer the model
# (Korean: "I am a kind and humble AI assistant.").
persona = "나는 친절하고 겸손한 AI 어시스턴트입니다."
# Chat history that gemma_chat appends to, one chat-template-formatted turn per entry.
# It is module-level, so it is shared by every caller (every visitor of the Gradio app).
conversation_history = []

def gemma_chat(user_input, output_len=100):
    """Send one user message to Gemma and return the model's reply.

    The conversation so far is replayed in Gemma's chat-template format
    (<start_of_turn>{role} ... <end_of_turn>), followed by the new user turn and an open
    model turn. Both new turns are then appended to the module-level
    `conversation_history`.

    Args:
        user_input: The user's message (from the Gradio textbox).
        output_len: Generation length passed to `model.generate` (default 100, as before).

    Returns:
        The text generated by the model for this turn.

    NOTE(review): `conversation_history` is global, so every visitor of the launched
    (public) Gradio link shares one conversation, and it grows without bound. Consider
    per-session `gr.State` and truncating old turns.
    """
    # Gemma's template only has user/model roles, so the persona is folded into the
    # first user turn. Storing that turn keeps the persona in every later prompt.
    user_text = f"{persona}\n\n{user_input}" if not conversation_history else user_input
    user_turn = f"<start_of_turn>user\n{user_text}<end_of_turn>"

    # Stored turns are already template-formatted, so replay them verbatim instead of
    # mixing in a second, plain-text "User:/Gemma:" format.
    prompt = "".join(f"{turn}\n" for turn in conversation_history)
    prompt += f"{user_turn}\n<start_of_turn>model\n"
    model_output = model.generate(prompt, device=device, output_len=output_len)

    conversation_history.append(user_turn)
    conversation_history.append(f"<start_of_turn>model\n{model_output}<end_of_turn>")

    return model_output

def humanized_conversation_history(conversation):
    """Render stored chat-template turns as plain "User: ..." / "Gemma: ..." lines.

    Each turn is expected to look like "<start_of_turn>{role}\\n{text}<end_of_turn>".
    The text between the role header and "<end_of_turn>" is extracted, so messages that
    themselves contain ">" survive intact. Turns containing neither role header are
    skipped; a turn containing both is treated as a user turn.

    Args:
        conversation: Iterable of template-formatted turn strings.

    Returns:
        One "Speaker: text" line per recognised turn, each ending in a newline
        ("" for an empty conversation).
    """
    speakers = (
        ("<start_of_turn>user", "User: "),
        ("<start_of_turn>model", "Gemma: "),
    )
    lines = []
    for turn in conversation:
        for header, label in speakers:
            if header in turn:
                # Splitting on ">" and taking the last piece would return the empty
                # string after "<end_of_turn>", dropping every message.
                text = turn.split(header, 1)[1].split("<end_of_turn>", 1)[0]
                lines.append(f"{label}{text.strip()}\n")
                break
    return "".join(lines)

# Minimal text-in / text-out Gradio UI around gemma_chat.
iface = gr.Interface(gemma_chat,
                     inputs=gr.Textbox(lines=2, placeholder="메시지를 입력하세요..."),  # placeholder: "Enter a message..."
                     outputs="text",
                     title="Gemma Chat")

# NOTE(review): in Colab this launch creates a public *.gradio.live share link (see the
# output below), so anyone with the URL can query the model, and every visitor's
# messages feed the same global conversation history. Pass share=False to launch()
# to turn the public link off.
iface.launch()

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://40ebd73a8ace7fa137.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




## Learn more

Now that you have learned how to use Gemma in PyTorch, you can explore the many
other things that Gemma can do in [ai.google.dev/gemma](https://ai.google.dev/gemma).
See also these other related resources:

- [Gemma model card](https://ai.google.dev/gemma/docs/model_card)
- [Gemma C++ Tutorial](https://ai.google.dev/gemma/docs/gemma_cpp)
- [Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting)