##### Copyright 2024 Google LLC.


In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/pytorch_gemma"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/doggy8088/generative-ai-docs/blob/main/site/zh/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/doggy8088/generative-ai-docs/blob/main/site/zh/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>


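# Gemma in PyTorch

This is a quick demo of running Gemma inference in PyTorch.
For more details, see the GitHub repository of the official PyTorch implementation [here](https://github.com/google/gemma_pytorch).

**Note**:
* The free Colab CPU Python runtime and T4 GPU Python runtime are sufficient for running the Gemma 2B models and the 7B int8 quantized models.
* For advanced use cases with other GPUs or TPUs, see the [README.md](https://github.com/google/gemma_pytorch/blob/main/README.md) in the official repository.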
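
Rough arithmetic can help show why quantization matters on a small accelerator. The sketch below estimates the memory needed for the weights alone. The helper `weight_memory_gib` is hypothetical, and the parameter counts are nominal values read off the model names, so they are assumptions rather than measured figures. Activations, the KV cache, and framework overhead are not included.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Approximate memory needed to hold the weights alone, in GiB."""
    return num_params * bytes_per_param / 2**30


# Assumption: nominal sizes taken from the model names; the real
# parameter counts differ.
NOMINAL_PARAMS = {'2b': 2e9, '7b': 7e9}
# Assumption: 2 bytes per weight for a 16-bit float, 1 byte for int8.
BYTES_PER_PARAM = {'16-bit float': 2, 'int8': 1}

for size, num_params in NOMINAL_PARAMS.items():
    for dtype, num_bytes in BYTES_PER_PARAM.items():
        estimate = weight_memory_gib(num_params, num_bytes)
        print(f'{size} {dtype}: ~{estimate:.1f} GiB for weights')
```

Halving the bytes per weight halves the estimate, which is the main reason an int8 checkpoint fits where a 16-bit one may not.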


## Kaggle access

To log in to Kaggle, either store your `kaggle.json` credentials file at
`~/.kaggle/kaggle.json` or run the following in a Colab environment. For details, see the
[`kagglehub` package documentation](https://github.com/Kaggle/kagglehub#authenticate).


In [None]:
import kagglehub

kagglehub.login()
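
Before calling `kagglehub.login()`, it can be handy to check whether credentials are already available. The sketch below is a minimal, hypothetical helper (`find_kaggle_credentials` is not part of `kagglehub`). It assumes the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables and a `kaggle.json` file with `username` and `key` fields, as described in the `kagglehub` documentation.

```python
import json
import os
from pathlib import Path


def find_kaggle_credentials(home=None, environ=None):
    """Return 'env', 'file', or None depending on where credentials are found."""
    environ = os.environ if environ is None else environ
    home = Path.home() if home is None else Path(home)

    # Assumption: both environment variables must be set and non-empty.
    if environ.get('KAGGLE_USERNAME') and environ.get('KAGGLE_KEY'):
        return 'env'

    cred_file = home / '.kaggle' / 'kaggle.json'
    if cred_file.is_file():
        try:
            data = json.loads(cred_file.read_text())
        except (OSError, ValueError):
            # Unreadable or malformed file: treat as missing.
            return None
        if isinstance(data, dict) and 'username' in data and 'key' in data:
            return 'file'
    return None


print('Credentials source:', find_kaggle_credentials())
```

If this prints `None`, running the interactive login is probably the simplest option.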

## Install dependencies


In [None]:
!pip install -q -U torch immutabledict sentencepiece

## Download model weights


In [None]:
# Choose variant and machine type
VARIANT = '2b-it' #@param ['2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant']
MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']

In [None]:
import os

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

## Download the model implementation


In [None]:
# NOTE: The "installation" is just cloning the repo.
!git clone https://github.com/google/gemma_pytorch.git

In [None]:
import sys

sys.path.append('gemma_pytorch')

In [None]:
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

## Set up the model


In [None]:
import torch

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
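
The configuration cell above derives the model size and the quantization flag from substrings of `VARIANT`. The sketch below spells out that mapping for every option in the dropdown. The helper `describe_variant` is hypothetical, and the `instruction_tuned` field is shown for illustration only, since the configuration code does not use it.

```python
def describe_variant(variant):
    """Split a variant name such as '7b-it-quant' into its parts."""
    return {
        # Same substring test as the config cell: anything not '2b' is '7b'.
        'size': '2b' if '2b' in variant else '7b',
        'instruction_tuned': '-it' in variant,
        'quantized': 'quant' in variant,
    }


for name in ['2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant']:
    print(name, describe_variant(name))
```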

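## Run inference

Below are examples of generating in chat mode and generating with multiple requests.

The instruction-tuned Gemma models were trained with a specific formatter that annotates instruction-tuning examples with extra information, both during training and at inference. The annotations (1) indicate roles in a conversation and (2) delineate turns in a conversation. The code snippet below formats a model prompt using the user and model chat templates in a multi-turn conversation. The relevant tokens are:

- `user`: user turn
- `model`: model turn
- `<start_of_turn>`: beginning of a dialogue turn
- `<end_of_turn>`: end of a dialogue turn

Read about the Gemma formatting for instruction tuning and system instructions [here](https://ai.google.dev/gemma/docs/formatting).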


In [None]:
# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

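# The prompt already carries the chat markup and ends with an open model
# turn, so pass it as-is instead of wrapping it in the user template again.
model.generate(
    prompt,
    device=device,
    output_len=100,
)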

Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn>
<start_of_turn>model
California.<end_of_turn>
<start_of_turn>user
What can I do in California?<end_of_turn>
<start_of_turn>model



"* **Visit the Golden Gate Bridge and Alcatraz Island in San Francisco.**\n* **Head to Yosemite National Park and marvel at nature's beauty.**\n* **Explore the bustling metropolis of Los Angeles.**\n* **Relax on the pristine beaches of Santa Monica or Malibu.**\n* **Go whale watching in Monterey Bay.**\n* **Discover the charming coastal towns of Monterey Bay and Carmel-by-the-Sea.**\n* **Visit Disneyland and Disney California Adventure in Anaheim.**\n*"
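
The manual concatenation in the chat example can be wrapped in a small helper. The sketch below is one possible way to do it. `build_chat_prompt` is a hypothetical name, and the two templates are copied from the cell above so the block runs on its own.

```python
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'


def build_chat_prompt(turns):
    """Format (role, text) pairs and open a final model turn.

    `turns` is a list of ('user' | 'model', text) tuples in conversation order.
    """
    templates = {'user': USER_CHAT_TEMPLATE, 'model': MODEL_CHAT_TEMPLATE}
    parts = []
    for role, text in turns:
        if role not in templates:
            raise ValueError(f'Unknown role: {role!r}')
        parts.append(templates[role].format(prompt=text))
    # Leave the last model turn open so the model writes the reply.
    parts.append('<start_of_turn>model\n')
    return ''.join(parts)


chat_prompt = build_chat_prompt([
    ('user', 'What is a good place for travel in the US?'),
    ('model', 'California.'),
    ('user', 'What can I do in California?'),
])
print(chat_prompt)
```

The resulting string has the same shape as the chat prompt printed above and could be passed to `model.generate` together with `device` and `output_len`.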

In [None]:
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=60,
)

['\n\nThe fingers dance on the keys,\nA symphony of thoughts and dreams.\nThe mind, a canvas yet uncouth,\nScribbling its secrets in the night.\n\nThe ink, a whispered voice from deep,\nA language ancient, never to sleep.\nEach stroke an echo of']

## Learn more

Now that you have learned how to use Gemma in PyTorch, you can explore the many other things Gemma can do at [ai.google.dev/gemma](https://ai.google.dev/gemma).
See also these related resources:

- [Gemma model card](https://ai.google.dev/gemma/docs/model_card)
- [Gemma C++ Tutorial](https://ai.google.dev/gemma/docs/gemma_cpp)
- [Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting)
