##### Copyright 2024 Google LLC.


In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tuning Gemma using JAX and Flax


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/jax_finetune"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/doggy8088/generative-ai-docs/blob/main/site/zh/gemma/docs/jax_finetune.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/google/generative-ai-docs/main/site/en/gemma/docs/jax_finetune.ipynb"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/doggy8088/generative-ai-docs/blob/main/site/zh/gemma/docs/jax_finetune.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>


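## Overview

Gemma is a family of lightweight, state-of-the-art open large language models, based on the Google DeepMind Gemini research and technology. This tutorial demonstrates how to fine-tune the Gemma 2B Instruct model for an English-French translation task using [Google DeepMind's `gemma` library](https://github.com/google-deepmind/gemma), [JAX](https://jax.readthedocs.io) (a high-performance numerical computing library), [Flax](https://flax.readthedocs.io) (the JAX-based neural network library), [Chex](https://chex.readthedocs.io/en/latest/) (a library of utilities for writing reliable JAX code), [Optax](https://optax.readthedocs.io/en/latest/) (the JAX-based gradient processing and optimization library), and the [MTNT (Machine Translation of Noisy Text) dataset](https://arxiv.org/abs/1809.00388). Although Flax is not used directly in this notebook, Flax was used to create Gemma.

The `gemma` library was written with JAX, Flax, [Orbax](https://orbax.readthedocs.io/) (a JAX-based library for training utilities like checkpointing), and [SentencePiece](https://github.com/google/sentencepiece) (a tokenizer/detokenizer library).

**Note:** This notebook runs on an A100 GPU in Google Colab. Free Colab hardware acceleration is insufficient for this notebook, which needs a large amount of host memory: use an A100 GPU (available in Colab Pro) or at least a Google Cloud TPU v3-8. Alternatively, [Kaggle VM notebooks](https://www.kaggle.com/) provide free TPU v3-8 acceleration, and [Google Cloud TPU](https://cloud.google.com/tpu?hl=en) offers TPU v3 and newer. Google Colab currently provides TPU v2, which is insufficient for this tutorial.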


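## Setup


### 1. Set up Kaggle access for Gemma

To complete this tutorial, first follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do the following:

* Get access to Gemma on [kaggle.com](https://www.kaggle.com/models/google/gemma/).
* Select a Colab runtime with sufficient resources to run the Gemma model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

### 2. Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the "Grant access?" message, agree to provide secret access.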


In [1]:
import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
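
Outside Colab, `google.colab.userdata` is not available, so the two variables have to come from the shell or another secret store. The following is a minimal sketch for checking that they are present before downloading anything. The helper name `missing_kaggle_credentials` is hypothetical and is not part of `kagglehub` or any other library.

```python
import os

# The two variables this notebook sets for Kaggle access.
REQUIRED_VARS = ('KAGGLE_USERNAME', 'KAGGLE_KEY')

def missing_kaggle_credentials(env=None):
  """Return the names of required Kaggle variables that are unset or empty."""
  env = os.environ if env is None else env
  return [name for name in REQUIRED_VARS if not env.get(name)]

missing = missing_kaggle_credentials()
if missing:
  print('Missing environment variables:', ', '.join(missing))
else:
  print('Kaggle credentials found.')
```

Passing a plain dictionary as `env` makes the check easy to try without touching the real environment.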

### 3. Install the `gemma` library

Free Colab hardware acceleration is currently *insufficient* to run this notebook. If you are using [Colab Pay As You Go or Colab Pro](https://colab.research.google.com/signup), click **Edit** > **Notebook settings** > select **A100 GPU** > **Save** to enable hardware acceleration.

Next, install the Google DeepMind `gemma` library from [`github.com/google-deepmind/gemma`](https://github.com/google-deepmind/gemma). If you get an error about "pip's dependency resolver", you can usually ignore it.

**Note:** By installing `gemma`, you will also install [`flax`](https://flax.readthedocs.io), core [`jax`](https://jax.readthedocs.io), [`optax`](https://optax.readthedocs.io/en/latest/) (the JAX-based gradient processing and optimization library), [`orbax`](https://orbax.readthedocs.io/), and [`sentencepiece`](https://github.com/google/sentencepiece).


In [2]:
!pip install -q git+https://github.com/google-deepmind/gemma.git

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Building wheel for gemma (pyproject.toml) ... done
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-metadata 1.14.0 requires absl-py<2.0.0,>=0.9, but you have absl-py 2.1.0 which is incompatible.

### 4. Import libraries

This notebook uses [Flax](https://flax.readthedocs.io) (for neural networks), core [JAX](https://jax.readthedocs.io), [SentencePiece](https://github.com/google/sentencepiece) (for tokenization), [Chex](https://chex.readthedocs.io/en/latest/) (a library of utilities for writing reliable JAX code), and TensorFlow Datasets.


In [3]:
import os
import enum
import re
import string

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

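## Load the Gemma model

Load the Gemma model with [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/bddefc718182282882b72f814d407d89e5d178c4/src/kagglehub/models.py#L12), which takes three arguments:

- `handle`: The model handle from Kaggle
- `path`: (Optional string) The local path
- `force_download`: (Optional boolean) Forces a fresh download of the model

**Note:** The `gemma-2b-it` model is around 3.7 GB in size.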


In [4]:
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}

In [5]:
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:26<00:00, 147MB/s]
Extracting model files...


In [6]:
print('GEMMA_PATH:', GEMMA_PATH)

GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2


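**Tip:** The output path above is where the model weights and the tokenizer are saved locally. You will need them later.


Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer file is in the main directory where you downloaded the model, while the model weights are in a sub-directory. For example:

- The `tokenizer.model` file is in `/LOCAL/PATH/TO/gemma/flax/2b-it/2`.
- The model checkpoint is in `/LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it`.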


In [7]:
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)

CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model


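## Load and prepare the MTNT dataset and the Gemma tokenizer

You will use the [MTNT (Machine Translation of Noisy Text)](https://arxiv.org/abs/1809.00388) dataset, which is available from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/mtnt).

Download the English-to-French portion of the MTNT dataset, and then sample two examples. Each sample in the dataset contains two entries: `src`, the original English sentence, and `dst`, the corresponding French translation.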


In [8]:
ds = tfds.load("mtnt/en-fr", split="train")

ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()

Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...



Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.
Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'
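
The values printed above are raw `bytes`, which is why accented characters show up as escapes such as `\xc3\xa9`. A small sketch of decoding one example as UTF-8 for inspection follows. The helper `decode_example` is hypothetical, and the two byte strings are copied from the output above.

```python
# Byte strings copied from "Example 1" in the output above.
raw_src = b'Is Kameron a Little Salty About Her Lack of Air Time?'
raw_dst = b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"

def decode_example(example):
  """Decode every bytes value of an example dict as UTF-8 text."""
  return {key: val.decode('utf-8') for key, val in example.items()}

decoded = decode_example({'src': raw_src, 'dst': raw_dst})
print(decoded['src'])
print(decoded['dst'])
```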



Load the Gemma tokenizer, constructed using [`sentencepiece.SentencePieceProcessor`](https://github.com/google/sentencepiece/blob/4d6a1f41069c4636c51a5590f7578a0dbed83450/python/src/sentencepiece/__init__.py#L423):


In [9]:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)

True

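Customize [`SentencePieceProcessor`](https://github.com/google/sentencepiece/blob/4d6a1f41069c4636c51a5590f7578a0dbed83450/python/src/sentencepiece/__init__.py#L423) for the English-to-French translation task. Since you will be fine-tuning the English portion of the Gemma model, you need to make a few adjustments, such as:

- *The input prefix*: Adding a common prefix to each input signals the translation task. For example, you can use a prompt with a prefix like `Translate this into French: [INPUT_SENTENCE]`.

- *The translation start suffix*: Adding a suffix at the end of each prompt tells the Gemma model exactly when to begin translating. A newline character does the job.

- *Language model tokens*: Gemma models expect a "beginning of sequence" token at the start of each sequence, so adding an "end of sequence" token at the end of each training example is sufficient.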
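
The three adjustments above can be pictured with plain strings before any real tokenization happens. This is only an illustrative sketch: `<bos>` and `<eos>` are symbolic stand-ins for the tokenizer's actual special token IDs, and `format_source` and `format_destination` are hypothetical helper names. The sentence pair is taken from the dataset sample shown earlier.

```python
BOS, EOS = '<bos>', '<eos>'  # Symbolic stand-ins, not real token IDs.
PREFIX = 'Translate this into French:\n'
SUFFIX = '\n'

def format_source(sentence):
  """BOS, then the prefixed and suffixed English sentence, with no EOS."""
  return [BOS, PREFIX + sentence + SUFFIX]

def format_destination(translation):
  """BOS, then the French translation, then EOS."""
  return [BOS, translation, EOS]

example = (
    format_source('Is Kameron a Little Salty About Her Lack of Air Time?')
    + format_destination(
        "Kameron est-elle un peu aigrie de son manque de temps à l'écran ?"))
print(example)
```

Only the destination half ends with the end-of-sequence marker, so the model learns to stop once the translation is complete.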

Build a custom wrapper around the `SentencePieceProcessor` as follows:


In [10]:
class GemmaTokenizer:

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad ID."""
    return self._spm_processor.pad_id()

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_eos: bool = True) -> jax.Array:
    """
    The tokenization function.

    Args:
      example: Input string to tokenize.
      prefix:  Prefix to add to the input string.
      suffix:  Suffix to add to the input string.
      add_eos: If True, add an "end of sentence" token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_eos: bool = True) -> tf.Tensor:
    """A TensorFlow operator for the tokenize function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.DecodeIds(tokens.tolist())

Try it out by instantiating your new custom `GemmaTokenizer`, and then applying it to a small sample of the MTNT dataset:


In [11]:
tokenizer = GemmaTokenizer(vocab)

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  prefix='Translate this into French:\n',
                                  suffix='\n',
                                  add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  add_eos=True)

ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
                       'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()

Example 0:
src: [     2  49688    736   1280   6987 235292    108    651   2778    576
   1080 104745  11982   5736    832   8995    901    780   3547    665
    575    573   4589 235369   2778 235265    108]
dst: [     2   2025  29653    581    664  16298   1437  55563  41435   7840
    581    683 111452    581    533 235303   9776   4108   2459    679
    485 235303    479   6728    579   1806   2499    709  29653    581
    533 235303 101323  16054      1]

Example 1:
src: [     2  49688    736   1280   6987 235292    108   2437  87150    477
    476  11709 230461   8045   3636  40268    576   4252   4897 235336
    108]
dst: [     2 213606    477   1455 235290   3510    748   8268 191017   2809
    581   2032  69972    581  11495   1305    533 235303  65978   1654
      1]



Build a data loader for the entire MTNT dataset:


In [12]:
@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to the model.
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'

class MTNTDatasetBuilder:
  """The dataset builder for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692,
             DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GemmaTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(example,
                                          prefix=self.TRANSLATION_PREFIX,
                                          suffix=self.TRANSLATION_SUFFIX,
                                          add_eos=False)

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example,
                                          add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(input_tensor,
                  [[0, to_pad]],
                  mode='CONSTANT',
                  constant_values=pad_value,
                  )

  def _to_training_input(self,
                         src_tokens: jax.Array,
                         dst_tokens: jax.Array,
                         ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample.
    ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
                                                             self._tokenize_destination(x['dst'])))

    # Convert the samples to training inputs.
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples that are too long.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset.
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary.
    ds = ds.repeat(num_epochs)

    # Build batches.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same steps as in `get_train_dataset`, but without shuffling and no repetition.
    ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
                                                                  self._tokenize_destination(x['dst'])))
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
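
To make the concatenate, mask, and pad logic of `_to_training_input` easier to follow, here is a plain-Python sketch that uses lists instead of TensorFlow tensors. It is an illustration under the assumption that the pad ID is 0, and the function name `to_training_input` is hypothetical. The token IDs are borrowed from this notebook's sample outputs.

```python
def to_training_input(src_tokens, dst_tokens, max_seq_len, pad_id=0):
  """Concatenate source and destination, mask the source, pad to max_seq_len."""
  tokens = list(src_tokens) + list(dst_tokens)
  # Only destination tokens contribute to the loss.
  mask = [False] * len(src_tokens) + [True] * len(dst_tokens)
  # Sequences longer than max_seq_len are left as-is here; the builder
  # above drops them in a separate filter step.
  to_pad = max(max_seq_len - len(tokens), 0)
  return tokens + [pad_id] * to_pad, mask + [False] * to_pad

src = [2, 49688, 736, 1280, 6987, 235292, 108, 13835, 1517, 235265, 108]
dst = [2, 69875, 540, 19713, 235265, 1]
tokens, mask = to_training_input(src, dst, max_seq_len=20)
print(tokens)
print(mask)
```

With 11 source tokens, 6 destination tokens, and a maximum length of 20, the result ends with three pad tokens, and the mask is true only for the six destination positions.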

Try out the `MTNTDatasetBuilder` by instantiating the custom `GemmaTokenizer` again, applying it to the MTNT dataset, and sampling two examples:


In [13]:
tokenizer = GemmaTokenizer(vocab)

dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()



Example 0:
input_tokens: [[     2  49688    736   1280   6987 235292    108  10924    665  12302
  235341    108      2   4397  63011   1437  38696   1241      1      0]
 [     2  49688    736   1280   6987 235292    108  13835   1517 235265
     108      2  69875    540  19713 235265      1      0      0      0]
 [     2  49688    736   1280   6987 235292    108   6956   1586 235297
  235265    108      2  78368   1586 235297 235265      1      0      0]]
target_mask: [[False False False False False False False False False False False False
   True  True  True  True  True  True  True False]
 [False False False False False False False False False False False  True
   True  True  True  True  True False False False]
 [False False False False False False False False False False False False
   True  True  True  True  True  True False False]]

Example 1:
input_tokens: [[     2  49688    736   1280   6987 235292    108  18874 235341    108
       2 115905   6425   1241      1      0      0  

## Configure the model

Before you begin fine-tuning the Gemma model, you need to configure it.

First, load and format the Gemma model checkpoint with the [`gemma.params.load_and_format_params`](https://github.com/google-deepmind/gemma/blob/c6bd156c246530e1620a7c62de98542a377e3934/gemma/params.py#L27) method:


In [14]:
params = params_lib.load_and_format_params(CKPT_PATH)

To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `Transformer` cache. Afterwards, instantiate the Gemma model as `model_2b` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)).

**Note:** The vocabulary size is smaller than the number of input embeddings because of unused tokens in the current Gemma release.


In [15]:
config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30
)

model_2b = transformer_lib.Transformer(config=config_2b)

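## Fine-tune the model

In this section, you will:

- Use the `gemma.transformer.Transformer` class to create the forward pass and the loss function.
- Build the position and attention mask vectors for the tokens.
- Build a training step function with Flax.
- Build a validation step without the backward pass.
- Create the training loop.
- Fine-tune the Gemma model.


Define the forward pass and the loss function using the [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) class. The Gemma `Transformer` class inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) and offers two essential methods:

- `init`: Initializes the model's parameters.
- `apply`: Executes the model's `__call__` function using a given set of parameters.

Since you are working with pre-trained Gemma weights, you don't need to use the `init` function.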


In [16]:
def forward_and_loss_fn(params,
                        *,
                        model: transformer_lib.Transformer,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """The forward pass and the loss function.

  Args:
    params: Model's input parameters.
    model: The Gemma transformer model to call.
    input_tokens: Input tokens sequence, shape [B, L].
    input_mask: Mask of the tokens that contribute to the loss (True means
      the token is included), shape [B, L].
    positions: Relative position of each token, shape [B, L].
    attention_mask: Input attention mask, shape [B, L, L].

  Returns:
    The softmax cross-entropy loss for the next-token prediction task.
  """

  # The forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

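  # Exclude the last step as it does not appear in the targets.
  # Note: indexing with 0 keeps only the first sequence of the batch, which
  # matches the batch size of 1 used later in this tutorial.
  logits = logits[0, :-1]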

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[0, 1:]
  target_mask = input_mask[0, 1:]

  # Convert the target labels to one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Define the normalization factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the negative log likelihood (NLL) loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
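
The masked next-token loss can be checked by hand on a tiny case. The sketch below re-implements the same idea in plain Python for a single sequence. It is not the code used for training, and `log_softmax` and `masked_nll` are hypothetical helper names. With uniform logits over a vocabulary of size 4, the loss of any masked-in target should be close to `log(4)`.

```python
import math

def log_softmax(row):
  """Numerically stable log-softmax of one row of logits."""
  m = max(row)
  log_z = m + math.log(sum(math.exp(x - m) for x in row))
  return [x - log_z for x in row]

def masked_nll(logits, tokens, mask):
  """Mean negative log-likelihood of tokens[1:] where mask[1:] is True."""
  total, count = 0.0, 0
  # logits[t] predicts tokens[t + 1], so the last row of logits is unused
  # and the first token is never a target.
  for row, target, keep in zip(logits[:-1], tokens[1:], mask[1:]):
    if keep:
      total -= log_softmax(row)[target]
      count += 1
  # Same small constant as in the normalization factor above.
  return total / (count + 1e-8)

# Four positions, vocabulary of size 4, uniform logits everywhere.
logits = [[0.0] * 4 for _ in range(4)]
tokens = [2, 3, 1, 0]
mask = [False, False, True, False]
loss = masked_nll(logits, tokens, mask)
print(loss, math.log(4))
```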

The [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) class requires an `attention_mask` and a `positions` vector alongside each input. You can generate these with a custom function that uses [`build_positions_from_mask`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L48) and [`make_causal_attn_mask`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L29) from `gemma.transformer`:


In [17]:
def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask
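
For intuition about what the two vectors contain, here is a plain-Python illustration for a single padded sequence. It is a sketch of the idea, not the `gemma` library's implementation. It assumes that positions count real tokens from zero, that a causal (lower-triangular) mask is combined with the padding mask, and that the pad ID is 0. The names `build_positions` and `causal_attention_mask` are hypothetical.

```python
def build_positions(pad_mask):
  """Index of each real token from zero; pad slots repeat the last index."""
  positions, seen = [], 0
  for is_token in pad_mask:
    seen += int(is_token)
    positions.append(max(seen - 1, 0))
  return positions

def causal_attention_mask(pad_mask):
  """mask[i][j] is True when position i may attend to position j."""
  n = len(pad_mask)
  # A position sees itself and earlier positions, and never a pad slot.
  return [[j <= i and pad_mask[j] for j in range(n)] for i in range(n)]

tokens = [2, 49688, 736, 0]  # The last slot is padding (pad ID 0 assumed).
pad_mask = [t != 0 for t in tokens]
positions = build_positions(pad_mask)
attention_mask = causal_attention_mask(pad_mask)
print(positions)
for row in attention_mask:
  print(row)
```

For a batch, the same construction is applied to each sequence, which gives the `[B, L]` and `[B, L, L]` shapes expected by the loss function above.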

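Build the `train_step` function, which performs the backward pass and updates the model's parameters accordingly, where:

- [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html) evaluates the loss function and its gradients during the forward and backward passes.
- [`optax.apply_updates`](https://optax.readthedocs.io/en/latest/api/apply_updates.html#optax.apply_updates) updates the parameters.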


In [18]:
def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example: TrainingInput):
  """Train step.

  Args:
    model: The Gemma transformer model.
    params: The model's input parameters.
    optimizer: The Optax optimizer to use.
    opt_state: The input optimizer's state.
    pad_id: ID of the pad token.
    example: Input batch.

  Returns:
    The training loss, the updated parameters, and the updated optimizer state.
  """

  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)

  # The forward and backward passes.
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             model=model,
                                                             input_tokens=example.input_tokens,
                                                             input_mask=example.target_mask,
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state
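
With plain SGD and no momentum, the optimizer update followed by `optax.apply_updates` amounts to subtracting the learning rate times the gradient from every parameter. The sketch below spells that arithmetic out on a small nested dictionary of floats. It is for illustration only: `sgd_step` is a hypothetical helper, and the values are made up.

```python
def sgd_step(params, grads, learning_rate):
  """Return params - learning_rate * grads for a nested dict of floats."""
  if isinstance(params, dict):
    return {key: sgd_step(params[key], grads[key], learning_rate)
            for key in params}
  return params - learning_rate * grads

# Made-up parameter and gradient values for illustration.
params = {'layer': {'w': 1.0, 'b': -2.0}}
grads = {'layer': {'w': 0.5, 'b': -1.0}}
new_params = sgd_step(params, grads, learning_rate=0.1)
print(new_params)
```

The function returns a new structure instead of mutating its input, which mirrors how `train_step` returns updated parameters rather than changing them in place.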

Build the `validation_step` function without the backward pass:


In [19]:
def validation_step(model: transformer_lib.Transformer,
                    params,
                    pad_id: int,
                    example: TrainingInput,
                    ):
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
  val_loss = forward_and_loss_fn(params,
                                 model=model,
                                 input_tokens=example.input_tokens,
                                 input_mask=example.target_mask,
                                 positions=positions,
                                 attention_mask=attention_mask)
  return val_loss

Define the training loop, using [`optax.sgd`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd) as the SGD optimizer:


In [20]:
@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None

def train_loop(
    model: transformer_lib.Transformer,
    params,
    dataset_builder: MTNTDatasetBuilder,
    training_cfg: TrainingConfig):

  # Apply `jax.jit` on the training step, making the whole loop much more efficient.
  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])

  # Apply `jax.jit` on the validation step.
  compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])

  # To save memory, use the SGD optimizer instead of the usual Adam optimizer.
  # Note that for this specific example, SGD is more than enough.
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  # Build the training dataset.
  train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
                                               num_epochs=training_cfg.num_epochs)
  train_ds = train_ds.as_numpy_iterator()

  # Build the validation dataset, with a limited number of samples for this demo.
  validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
  validation_ds = validation_ds.take(50)

  n_steps = 0
  avg_loss=0

  # A first round of the validation loss.
  n_steps_eval = 0
  eval_loss = 0
  val_iterator = validation_ds.as_numpy_iterator()
  for val_example in val_iterator:
    eval_loss += compiled_validation_step(model,
                                          params,
                                          dataset_builder._tokenizer.pad_id,
                                          val_example)
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = compiled_train_step(model=model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=dataset_builder._tokenizer.pad_id,
                                                        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    if n_steps % training_cfg.eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += compiled_validation_step(model,
                                              params,
                                              dataset_builder._tokenizer.pad_id,
                                              val_example)
        n_steps_eval +=1
      avg_loss /= training_cfg.eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if training_cfg.max_steps is not None and n_steps >= training_cfg.max_steps:
      break
  return params
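
The reporting bookkeeping in the loop (accumulate the training loss, report its average every `eval_every_n` steps, then reset) can be tried in isolation on a list of made-up loss values. This is a simplified sketch with a hypothetical function name, `report_schedule`. It leaves out the model and the validation pass, and it stops after exactly `max_steps` updates.

```python
def report_schedule(train_losses, eval_every_n, max_steps=None):
  """Return (step, average training loss) for every reporting step."""
  reports, running, n_steps = [], 0.0, 0
  for loss in train_losses:
    n_steps += 1
    running += loss
    if n_steps % eval_every_n == 0:
      reports.append((n_steps, running / eval_every_n))
      running = 0.0  # Start a fresh average for the next window.
    if max_steps is not None and n_steps >= max_steps:
      break
  return reports

# Made-up loss values for illustration.
losses = [4.0, 2.0, 3.0, 1.0, 5.0]
print(report_schedule(losses, eval_every_n=2, max_steps=4))
```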

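Begin fine-tuning the Gemma model with a short maximum sequence length (`SEQ_SIZE`) and a limited number of steps (`max_steps`), so that the run fits in memory: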


In [21]:
SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=100)

params = train_loop(model=model_2b,
                    params={'params': params['transformer']},
                    dataset_builder=dataset_builder,
                    training_cfg=training_cfg)

Start, validation loss: 10.647212982177734
STEP 20 training loss: 3.3015992641448975 - eval loss: 2.686880111694336
STEP 40 training loss: 5.375057220458984 - eval loss: 2.6751961708068848
STEP 60 training loss: 2.6599338054656982 - eval loss: 2.663877010345459
STEP 80 training loss: 4.822389125823975 - eval loss: 2.3333375453948975
STEP 100 training loss: 2.0131142139434814 - eval loss: 2.360811948776245


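The training loss and the validation loss should both trend downward as training progresses. With a batch size of 1 the training loss is noisy, as the log above shows, so compare the evaluation losses rather than individual training values.

Create a `sampler` with [`gemma.sampler.Sampler`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/sampler.py#L88). It uses the Gemma model checkpoint and the tokenizer.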


In [22]:
sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['params'],
)

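Use the `sampler` to check whether your model can perform translation. The `total_generation_steps` argument in [`gemma.sampler.Sampler`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/sampler.py#L88) is the number of steps performed when generating a response. To make sure the input matches the training format, start the prompt with the prefix `Translate this into French:\n` and end it with a newline character. This signals the model to begin translating.

**Note:** Due to hardware constraints, the small training configuration used in this demonstration may not produce "stable" results.

**Note:** If you run out of memory, click **Runtime** > **Disconnect and delete runtime**, and then **Runtime** > **Run all**.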


In [23]:
sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=100,
    ).text

["C'est Bonjour, mon nom est Morgane.C'est Bonjour, mon nom est Morgane."]

## Learn more

- You can learn more about the Google DeepMind [`gemma` library](https://github.com/google-deepmind/gemma) on GitHub, which contains docstrings of the modules you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py), [`gemma.transformer`](https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py), and [`gemma.sampler`](https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py).
- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), [Chex](https://chex.readthedocs.io/en/latest/), [Optax](https://optax.readthedocs.io/en/latest/), and [Orbax](https://orbax.readthedocs.io/).
- For `sentencepiece` tokenizer/detokenizer documentation, check out [Google's `sentencepiece` GitHub repo](https://github.com/google/sentencepiece).
- For `kagglehub` documentation, check out `README.md` on [Kaggle's `kagglehub` GitHub repo](https://github.com/Kaggle/kagglehub).
- Learn how to [use Gemma models with Google Cloud Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
- If you are using Google Cloud TPUs (v3-8 and newer), make sure to also update to the latest `jax[tpu]` package (`!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`), restart the runtime, and check that the `jax` and `jaxlib` versions match (`!pip list | grep jax`). This can prevent the `RuntimeError` that arises from a `jaxlib` and `jax` version mismatch. For more JAX installation instructions, refer to the [JAX docs](https://jax.readthedocs.io/en/latest/tutorials/installation.html#install-google-tpu).
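
As an alternative to `pip list | grep jax`, the installed versions can also be read from Python with the standard library. This is a minimal sketch. The helper name `installed_version` is hypothetical, and it only reports what is installed, leaving the comparison to you.

```python
from importlib import metadata

def installed_version(package):
  """Return the installed version string of a package, or None if absent."""
  try:
    return metadata.version(package)
  except metadata.PackageNotFoundError:
    return None

for name in ('jax', 'jaxlib'):
  print(name, installed_version(name))
```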
