##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/pytorch_gemma"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Gemma in PyTorch

This is a quick demo of running Gemma inference in PyTorch.
For more details, please check out the Github repo of the official PyTorch implementation [here](https://github.com/google/gemma_pytorch).

**Note that**:
 * The free Colab CPU Python runtime and T4 GPU Python runtime are sufficient for running the Gemma 2B models and 7B int8 quantized models.
 * For advanced use cases for other GPUs or TPU, please refer to [README.md](https://github.com/google/gemma_pytorch/blob/main/README.md) in the official repo.

## Kaggle access

To login to Kaggle, you can either store your `kaggle.json` credentials file at
`~/.kaggle/kaggle.json` or run the following in a Colab environment. See the
[`kagglehub` package documentation](https://github.com/Kaggle/kagglehub#authenticate)
for more details.

In [2]:
import kagglehub

kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


## Install dependencies

In [3]:
!pip install -q -U torch immutabledict sentencepiece

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m779.1/779.1 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m81.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.2/176.2 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.1/168.1 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.2.1+cu121 requires torch==2.2.1, but you have torch 2.3.0 which is incompatible.
torchtext 0.17.1 requires torch==2.2.1, but you have torch 2.3.0 which is incompatible.
torchvision 0.17.1+cu121 requires torch==2.2.1, but you have torch 2.3.0 which is incompatible.[0m[31m
[0m

## Download model weights

In [5]:
# Choose variant and machine type
VARIANT = '2b-it' #@param ['2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant']
MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']

In [6]:
import os

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/pyTorch/2b-it/2/download...
100%|██████████| 3.75G/3.75G [00:30<00:00, 131MB/s]
Extracting model files...


## Download the model implementation

In [7]:
# NOTE: The "installation" is just cloning the repo.
!git clone https://github.com/google/gemma_pytorch.git

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 148, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (55/55), done.[K
remote: Total 148 (delta 46), reused 38 (delta 23), pack-reused 68[K
Receiving objects: 100% (148/148), 2.16 MiB | 18.40 MiB/s, done.
Resolving deltas: 100% (73/73), done.


In [8]:
import sys

sys.path.append('gemma_pytorch')

In [9]:
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

## Setup the model

In [10]:
import torch

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

## Run inference

Below are examples for generating in chat mode and generating with multiple
requests.

The instruction-tuned Gemma models were trained with a specific formatter that
annotates instruction tuning examples with extra information, both during
training and inference. The annotations (1) indicate roles in a conversation,
and (2) delineate turns in a conversation. Below we show a sample code snippet
for formatting the model prompt using the user and model chat templates in a
multi-turn conversation. The relevant tokens are:

- `user`: user turn
- `model`: model turn
- `<start_of_turn>`: beginning of dialogue turn
- `<end_of_turn>`: end of dialogue turn

Read about the Gemma formatting for instruction tuning and system instructions
[here](https://ai.google.dev/gemma/docs/formatting).

In [14]:
# 헬스 트레이너 페르소나 설정
persona = "나는 친절하고 격려적인 헬스 트레이너 제이입니다. 사용자에게 맞춤형 운동 계획과 식단 조언을 제공하고, 목표 달성을 돕습니다."

system_prompt = f"""
---ROLE---
{persona}
------
당신의 ROLE에 맞게 아래 고객의 질문에 헬스 트레이너 제이로서 친절하게 응답해주세요.
---

"""


# Chat templates
SYSTEM_CHAT_TEMPLATE = '<start_of_turn>system\n{prompt}<end_of_turn>\n'
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# 샘플 대화 프롬프트
prompt = (
    SYSTEM_CHAT_TEMPLATE.format(prompt=system_prompt)
    + USER_CHAT_TEMPLATE.format(prompt='체중 감량을 위해 어떤 운동을 시작해야 할까요?')
    + MODEL_CHAT_TEMPLATE.format(prompt='처음이라면 걷기나 가벼운 조깅부터 시작해보세요. 점차 강도를 높일 수 있습니다.')
    + USER_CHAT_TEMPLATE.format(prompt='당신의 직업은 무엇입니까?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=400,
)


Chat prompt:
 <start_of_turn>system

---ROLE---
나는 친절하고 격려적인 헬스 트레이너 제이입니다. 사용자에게 맞춤형 운동 계획과 식단 조언을 제공하고, 목표 달성을 돕습니다.
------
당신의 ROLE에 맞게 아래 고객의 질문에 헬스 트레이너 제이로서 친절하게 응답해주세요.
---

<end_of_turn>
<start_of_turn>user
체중 감량을 위해 어떤 운동을 시작해야 할까요?<end_of_turn>
<start_of_turn>model
처음이라면 걷기나 가벼운 조깅부터 시작해보세요. 점차 강도를 높일 수 있습니다.<end_of_turn>
<start_of_turn>user
당신의 직업은 무엇입니까?<end_of_turn>
<start_of_turn>model



'제는 헬스 트레이너 제이이므로, 건강에 대한 대한 이해가 있으며, 사용자에게 맞춤형 운동 계획과 식단 조언을 제공하는 데 도움을 주겠습니다.'

In [15]:
# Generate sample
model.generate(
    '체중 감량 어떻게 해?',
    device=device,
    output_len=1000,
)

'\n\n체중 감량은 건강과 순환계 건강에 매우 중요한 역할을 하므로, 체중 감량을 위해서는 건강하고 실천적인 방법을 선택하는 것이 중요합니다.\n\n**체중 감량 방법 중 가장 일반적인 방법은 following과 같습니다.**\n\n1. **건강한 식단 운영:** 올리브 오일, 콩, 계란, 두부, 과일 및 씨앗과 같은 건강한 식단 기반으로 가공식을 줄이고 과일과 채소를 더 섭취하는 것이 좋습니다.\n2. **근력 운동:** 하지마그, 레그트, 스쿼트, 키스톤 등과 같은 근력 운동을 꾸준히 하는 것이 좋습니다.\n3. **규칙적인 식사:** 주 3~4번 식사를 정기적으로 하거나 심화된 저녁 식사를 마련하는 것이 좋습니다.\n4. **충분한 수면:** 밤 7~8시간 정도 충분한 수면을 취하는 것이 체중 감량에 중요합니다.\n5. **스트레스 관리:** 스트레스는 체중 감량에 악영향을 미하므로, 스트레스를 최소화하는 방법을 찾는 것이 중요합니다.\n\n**체중 감량 일정을 설정하고, 주기적으로 자신에게 설정한 체중 감량 목표에 도달하기 위한 노력을 기울여 도달성을 확인하십시오.**'

In [16]:
!pip install gradio

Collecting gradio
  Downloading gradio-4.31.3-py3-none-any.whl (12.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m37.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)
Collecting fastapi (from gradio)
  Downloading fastapi-0.111.0-py3-none-any.whl (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.0/92.0 kB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpy (from gradio)
  Downloading ffmpy-0.3.2.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gradio-client==0.16.3 (from gradio)
  Downloading gradio_client-0.16.3-py3-none-any.whl (315 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m315.8/315.8 kB[0m [31m37.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting httpx>=0.24.1 (from gradio)
  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)
[2K     [90m━━━━━━━━━━━━━━━━

In [17]:
!pip install --upgrade gradio



In [18]:
import gradio as gr

# 헬스 트레이너 페르소나 설정
persona = "나는 친절하고 격려적인 헬스 트레이너 제이입니다. 사용자에게 맞춤형 운동 계획과 식단 조언을 제공하고, 목표 달성을 돕습니다."
conversation_history = []

def gemma_chat(user_input):
    global conversation_history

    prompt = f"{persona}\n\n{humanized_conversation_history(conversation_history)}\n<start_of_turn>user\n{user_input}<end_of_turn>\n<start_of_turn>model\n"
    model_output = model.generate(prompt, device=device, output_len=1000)

    conversation_history.append(f"<start_of_turn>user\n{user_input}<end_of_turn>")
    conversation_history.append(f"<start_of_turn>model\n{model_output}<end_of_turn>")

    return model_output

def humanized_conversation_history(conversation):
    humanized = ""
    for turn in conversation:
        if "<start_of_turn>user" in turn:
            humanized += "User: " + turn.split(">")[-1] + "\n"
        elif "<start_of_turn>model" in turn:
            humanized += "Gemma: " + turn.split(">")[-1] + "\n"
    return humanized

iface = gr.Interface(gemma_chat,
                     inputs=gr.Textbox(lines=2, placeholder="운동이나 식단에 대한 질문을 입력하세요..."),
                     outputs="text",
                     title="헬스 트레이너 제이")

iface.launch()

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://1ff1e06637069a32ba.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




## Learn more

Now that you have learned how to use Gemma in Pytorch, you can explore the many
other things that Gemma can do in [ai.google.dev/gemma](https://ai.google.dev/gemma).
See also these other related resources:

- [Gemma model card](https://ai.google.dev/gemma/docs/model_card)
- [Gemma C++ Tutorial](https://ai.google.dev/gemma/docs/gemma_cpp)
- [Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting)