Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
See the License for the specific language governing permissions and
limitations under the License.

---

# Fine-tuning the 2B Gemma model with Flax

In this tutorial you will learn how to fine-tune the 2B Gemma model for a simple translation task. To run this colab, you will need to use a TPU v4 runtime.

## Setup

In [None]:
# @title Installation
# ! pip install git+https://github.com/google-deepmind/gemma.git
# ! pip install --user kaggle

## Downloading the checkpoint

To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.

Then run the cell below.

In [None]:
# import kagglehub
# kagglehub.login()

If everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
```

Now select and download the checkpoint you want to try. On a single host, only the 2b model can fit in memory for fine-tuning.

In [None]:
import os

VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
# weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = '/tfhub/prod/g_mini/2b_it_v1p1_orbax/1'
vocab_path = '/home/mriviere/g_mini/tokenizer/gemini_bpe_256k_v5_no_tags_cleared_v1.model'

In [None]:
# @title Python imports

import enum
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import optax

# We will use tensorflow to handle the dataset
import tensorflow as tf
import tensorflow_datasets as tfds

# Finally, we import Gemma.
from colabtools import adhoc_import
with adhoc_import.SubmittedChangelist():
  from gemma.deprecated import params as params_lib
  from gemma.deprecated import sampler as sampler_lib
  from gemma.deprecated import transformer as transformer_lib
  from .third_party.sentencepiece.src.python import sentencepiece_processor as spm

In [None]:
a = jnp.zeros(10)
a.devices()
a = jax.device_put(a, jax.devices('cpu')[0])
print(a.devices())
#a

{CpuDevice(id=0)}


## Step 1: prepare the dataset

### The MTNT dataset

In this tutorial, we will use the MTNT dataset, from the paper [MTNT: A Testbed for Machine Translation of Noisy Text](https://arxiv.org/abs/1809.00388). This dataset is directly available in the [TensorFlow dataset catalog](https://www.tensorflow.org/datasets/catalog/mtnt).

More precisely, we will focus on English-to-French translation.

First, let's have a look at the data itself.

In [None]:
ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()

Example 0:
dst: b'Le groupe de " toutes les \xc3\xa9toiles potentielles de la conf\xc3\xa9rence de l\'Est mais qui ne s\'en sortent pas dans le groupe de l\'Ouest ".'
src: b'The group of \xe2\x80\x9ceastern conference potential all stars but not making it in the West\xe2\x80\x9d group.'

Example 1:
dst: b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"
src: b'Is Kameron a Little Salty About Her Lack of Air Time?'



Each sample in the dataset contains two entries:
- 'src': The original English sentence.
- 'dst': The corresponding French translation.
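
Both entries are returned as UTF-8 encoded byte strings, which is why accented characters show up as escape sequences such as `\xc3\xa9` in the output above. As a small sketch, the byte string below is copied from the `dst` field of Example 1 and decoded into a regular Python string (the variable names are ours, for illustration only):

```python
# Byte string copied from the `dst` field of Example 1 above.
raw_dst = b"Kameron est-elle un peu aigrie de son manque de temps \xc3\xa0 l'\xc3\xa9cran ?"

# Decode the UTF-8 bytes to recover the accented characters.
decoded_dst = raw_dst.decode("utf-8")
print(decoded_dst)
```

The tokenizer used later in this tutorial receives the raw bytes directly, so this decoding step is only needed when you want to read the samples yourself.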

### Tokenizer

Let's start by loading the tokenizer vocabulary, which we'll do with the [SentencePiece](https://github.com/google/sentencepiece) library.

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

vocab_list = [(id, vocab.IdToPiece(id)) for id in range(vocab.GetPieceSize())]
letters = ['A', 'B', 'C', 'D']
res_dict = {}
for id, piece in vocab_list:
  # Take the first alphabetic character of the piece, if there is one.
  letter = next(filter(str.isalpha, piece), None)
  if letter in letters:
    res_dict[id] = letter


Let's customize `SentencePieceProcessor` for our English-to-French translation task. Since we're fine-tuning the English-only Gemma 2B model, we need a few adjustments:

- **Input Prefix**: Adding a common prefix to each input signals the translation task. For example, we could use a prompt like `Translate this into French: [INPUT_SENTENCE]`.

- **Translation Start suffix**: Adding a suffix at the end of each prompt tells the model exactly when to begin the translation. A new line should do the job.

- **LM Tokens**: Gemma models expect a *beginning of sequence* token at the beginning of each sequence. Similarly, we need to add an *end of sequence* token at the end of each training example.
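
To make these three adjustments concrete, here is a minimal sketch of how a prompt and its target could be assembled. It does not use SentencePiece: `ToyTokenizer` is a hypothetical stand-in that assigns one id per distinct character, and the `BOS_ID` and `EOS_ID` values are assumptions chosen for illustration.

```python
BOS_ID = 2  # assumed beginning-of-sequence id, for illustration
EOS_ID = 1  # assumed end-of-sequence id, for illustration


class ToyTokenizer:
    """Hypothetical stand-in for SentencePiece: one id per distinct character."""

    def __init__(self):
        self._ids = {}

    def encode(self, text):
        # Ids start at 10 so they never collide with BOS_ID or EOS_ID.
        return [self._ids.setdefault(ch, len(self._ids) + 10) for ch in text]


def build_tokens(tok, example, prefix="", suffix="", add_eos=True):
    """Wraps the example with prefix/suffix, then adds the LM tokens."""
    tokens = [BOS_ID]
    tokens.extend(tok.encode(prefix + example + suffix))
    if add_eos:
        tokens.append(EOS_ID)
    return tokens


tok = ToyTokenizer()
# Source: prefix + sentence + newline suffix, no end-of-sequence token.
src = build_tokens(tok, "Hello", prefix="Translate this into French:\n",
                   suffix="\n", add_eos=False)
# Destination: the translation, closed by the end-of-sequence token.
dst = build_tokens(tok, "Bonjour", add_eos=True)
print(src)
print(dst)
```

Only the destination ends with the end-of-sequence token, because the source is a prompt that the model is expected to continue.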

In [None]:
class GemmaTokenizer:
  """Custom wrapper around a SentencePieceProcessor for tensorflow."""

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad id."""
    return self._spm_processor.pad_id()

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_eos: bool = True) -> jax.Array:
    """
    Tokenization function.

    Args:
      example: input string to tokenize.
      prefix:  prefix to add to the input string.
      suffix:  suffix to add to the input string.
      add_eos: if True, add an end of sentence token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_eos: bool = True) -> tf.Tensor:
    """TensorFlow operator for the tokenize function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.DecodeIds(tokens.tolist())

Now let's try our custom tokenizer on the MTNT dataset.

In [None]:
tokenizer = GemmaTokenizer(vocab)

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  prefix='Translate this into French:\n',
                                  suffix='\n',
                                  add_eos=False)
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example,
                                  add_eos=True)

ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),
                       'dst': tokenize_destination(tokenizer, x['dst'])})
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()

Example 0:
src: [     2  49688    736   1280   6987 235292    108    651   2778    576
   1080 104745  11982   5736    832   8995    901    780   3547    665
    575    573   4589 235369   2778 235265    108]
dst: [     2   2025  29653    581    664  16298   1437  55563  41435   7840
    581    683 111452    581    533 235303   9776   4108   2459    679
    485 235303    479   6728    579   1806   2499    709  29653    581
    533 235303 101323  16054      1]

Example 1:
src: [     2  49688    736   1280   6987 235292    108   2437  87150    477
    476  11709 230461   8045   3636  40268    576   4252   4897 235336
    108]
dst: [     2 213606    477   1455 235290   3510    748   8268 191017   2809
    581   2032  69972    581  11495   1305    533 235303  65978   1654
      1]



### Data loader

We can now wrap everything up and build our data loader.

In [None]:
@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens given to the model
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'


class MTNTDatasetBuilder:
  """Data loader for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692,
             DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GemmaTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(example,
                                          prefix=self.TRANSLATION_PREFIX,
                                          suffix=self.TRANSLATION_SUFFIX,
                                          add_eos=False)

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example,
                                          add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(input_tensor,
                  [[0, to_pad]],
                  mode='CONSTANT',
                  constant_values=pad_value,
                  )

  def _to_training_input(self,
                         src_tokens: jax.Array,
                         dst_tokens: jax.Array,
                         ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # We want to prevent the model from updating based on the source (input)
    # tokens. To achieve this, we add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then we pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # We don't want to perform the backward on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample
    ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
                                                             self._tokenize_destination(x['dst'])))

    # Convert them to training inputs
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples which are too long
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary
    ds = ds.repeat(num_epochs)

    # Build batches
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same as the training dataset, but no shuffling and no repetition
    ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
                                                                  self._tokenize_destination(x['dst'])))
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
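
The masking and padding logic of `_to_training_input` can be sketched in plain Python, without TensorFlow. The token ids and `PAD_ID` below are made up for illustration; only the structure mirrors the builder above.

```python
PAD_ID = 0  # assumed pad id, for illustration only


def to_training_input(src_tokens, dst_tokens, max_seq_len):
    """Concatenates source and destination, then builds the loss mask."""
    tokens = list(src_tokens) + list(dst_tokens)
    # Source tokens are masked out; destination tokens contribute to the loss.
    mask = [False] * len(src_tokens) + [True] * len(dst_tokens)
    # Pad up to max_seq_len; padding never contributes to the loss.
    to_pad = max(max_seq_len - len(tokens), 0)
    tokens += [PAD_ID] * to_pad
    mask += [False] * to_pad
    return tokens, mask


src = [2, 11, 12, 13]  # made-up prompt tokens
dst = [2, 21, 22, 1]   # made-up translation tokens
tokens, mask = to_training_input(src, dst, max_seq_len=10)
print(tokens)  # [2, 11, 12, 13, 2, 21, 22, 1, 0, 0]
print(mask)    # [False, False, False, False, True, True, True, True, False, False]

# Sequences longer than max_seq_len are not truncated here; the builder
# drops them afterwards with a filter on the sequence length.
long_tokens, _ = to_training_input(src, dst, max_seq_len=6)
print(len(long_tokens) <= 6)  # False
```

Padding every kept sample to the same length is what allows the dataset to be batched into fixed-shape arrays.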

In [None]:
"""Base class for dataset builders."""

import chex
import jax
import tensorflow as tf

@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens provided to model
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation
  target_mask: jax.Array


class DatasetBuilder:
  """Base class for dataset builders.

  This class provides the interface for dataset builders.
  """

  def __init__(self, tokenizer: GemmaTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._max_seq_len = max_seq_len

  def _pad_up_to_max_len(
      self, input_tensor: tf.Tensor, pad_value: int | bool
  ) -> tf.Tensor:
    """Pads the given tensor up to max_seq_len."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(0, self._max_seq_len - seq_len)
    return tf.pad(
        input_tensor,
        [[0, to_pad]],
        mode='CONSTANT',
        constant_values=pad_value
    )

  def get_train_dataset(self):
    raise NotImplementedError()

  def get_validation_dataset(self, batch_size: int):
    raise NotImplementedError()


In [None]:

"""Dataset builder for the SciQ dataset."""

import enum as Enum
import random
import numpy as np

from absl import logging
import jax.dlpack
import tensorflow as tf
import tensorflow_datasets as tfds


class DatasetSplit(Enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'validation'


class SciQDatasetBuilder(DatasetBuilder):
  """Dataset builder for the SciQ dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 13679}

  #BUFFER_SIZE_SHUFFLE = 10_000
  BUFFER_SIZE_SHUFFLE = 100
  QUESTION_PREFIX = 'Question: \n'
  QUESTION_SUFFIX = '\n'
  OPTION_PREFIXES = ["(a) ", " (b) ", "(c) ", "(d) "]
  OPTION_SUFFIX = "\n"
  #TRANSLATION_PREFIX = 'Translate this into French:\n'
  #TRANSLATION_SUFFIX = '\n'

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int
  ):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load(
            'huggingface:sciq', split='train',
        ),
        DatasetSplit.VALIDATION: tfds.load('huggingface:sciq', split='validation')
    }
    logging.info(f'sciq size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')
    self._max_seq_len = max_seq_len

  def _tokenize_question(self, example: tf.Tensor):
    """Tokenization function for the Question."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.QUESTION_PREFIX,
        suffix=self.QUESTION_SUFFIX,
        add_eos=False,
    )

  def _tokenize_option(self, example: tf.Tensor, option_prefix: str):
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=option_prefix,
        suffix=self.OPTION_SUFFIX,
        add_eos=False,
    )
  def _tokenize_correct_answer(self, correct_answer: tf.Tensor):
    return self._tokenizer.tokenize_tf_op(correct_answer, prefix='Correct answer: ', suffix='\n', add_eos=False)

  def _tokenize_support(self, support: tf.Tensor):
    return self._tokenizer.tokenize_tf_op(support, prefix='Support: ', suffix='\n', add_eos=False)

  def _to_training_input(
      self,
      question_tokens: jax.Array,
      option1: jax.Array,
      option2: jax.Array,
      option3: jax.Array,
      option4: jax.Array,
      correct_answer: jax.Array,
      support: jax.Array,
  ):
    """Build a training input from question, option, answer and support tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat(
        [question_tokens, option1, option2, option3, option4, correct_answer, support], axis=0
    )

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    question_mask = tf.zeros_like(question_tokens, dtype=tf.bool)
    option1_mask = tf.zeros_like(option1, dtype=tf.bool)
    option2_mask = tf.zeros_like(option2, dtype=tf.bool)
    option3_mask = tf.zeros_like(option3, dtype=tf.bool)
    option4_mask = tf.zeros_like(option4, dtype=tf.bool)
    correct_answer_mask = tf.ones_like(correct_answer, dtype=tf.bool)
    support_mask = tf.ones_like(support, dtype=tf.bool)
    mask = tf.concat([question_mask, option1_mask, option2_mask, option3_mask, option4_mask, correct_answer_mask, support_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)
    return TrainingInput( #type: ignore
        input_tokens=tokens, #type:ignore
        target_mask=mask,  #type:ignore
    )# type: ignore

  def _tokenize_example(self, example):
    options = [example['distractor1'], example['distractor2'], example['distractor3'], example['correct_answer']]
    options_permutation = np.random.permutation(4)
    return (self._tokenize_question(example['question']),
            self._tokenize_option(options[options_permutation[0]], self.OPTION_PREFIXES[0]),
            self._tokenize_option(options[options_permutation[1]], self.OPTION_PREFIXES[1]),
            self._tokenize_option(options[options_permutation[2]], self.OPTION_PREFIXES[2]),
            self._tokenize_option(options[options_permutation[3]], self.OPTION_PREFIXES[3]),
            self._tokenize_correct_answer(example['correct_answer']),
            self._tokenize_support(example['support']))

  def get_train_dataset(self):
    """Build the training dataset."""

    ds = self._base_data[DatasetSplit.TRAIN].map(
        self._tokenize_example,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.map(lambda x, y, z, a, b, c, d: self._to_training_input(x, y, z, a, b, c, d),
                num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
    #ds = ds.repeat(num_epochs)
    #ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self):
    """Build the validation dataset."""
    ds = self._base_data[DatasetSplit.VALIDATION].map(
        self._tokenize_example,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.map(lambda x, y, z, a, b, c, d: self._to_training_input(x, y, z, a, b, c, d),
                num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    return ds
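
The option-shuffling idea in `_tokenize_example` can be sketched in plain Python. One caveat worth checking: `tf.data` traces the mapped function into a graph, so a NumPy call such as `np.random.permutation` inside it is likely evaluated once at trace time rather than once per example. The sketch below shuffles per example instead; the question options are made up, and the prefixes are a cleaned-up, illustrative version of `OPTION_PREFIXES`.

```python
import random

# Illustrative prefixes, one per answer slot.
OPTION_PREFIXES = ["(a) ", "(b) ", "(c) ", "(d) "]


def format_options(distractors, correct_answer, rng):
    """Shuffles the four options and attaches a prefix and newline to each."""
    options = list(distractors) + [correct_answer]
    rng.shuffle(options)  # a fresh permutation on every call
    return [prefix + option + "\n"
            for prefix, option in zip(OPTION_PREFIXES, options)]


rng = random.Random(0)  # seeded only to make the sketch reproducible
formatted = format_options(["mitosis", "osmosis", "fusion"],
                           "photosynthesis", rng)
for line in formatted:
    print(line, end="")
```

Shuffling matters because the correct answer is always listed last in the raw `options` list; without a per-example permutation the model could learn the position instead of the content.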

In [None]:
#sciq_path = '/home/shivguptashi/sciq/sciq_train.tfrecord'
sciq_validation_path = '/home/shivguptashi/sciq/sciq_validation.tfrecord'
tokenizer = GemmaTokenizer(vocab)
sciq_dataset_builder = SciQDatasetBuilder(tokenizer, max_seq_len=1000)
validation_ds = sciq_dataset_builder.get_validation_dataset()
validation_ds = validation_ds.as_numpy_iterator()
with tf.io.TFRecordWriter(sciq_validation_path) as writer:
  for it, record in enumerate(validation_ds):
    # Store the raw bytes of each array; dtype and shape are not saved.
    example = tf.train.Example(features=tf.train.Features(feature={
        'input_tokens': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[record.input_tokens.tobytes()])),
        'target_mask': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[record.target_mask.tobytes()])),
    }))
    writer.write(example.SerializeToString())
    print(f'it: {it}')

it: 0
it: 1
it: 2
...
it: 137
...
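
Because each record stores only the raw bytes produced by `.tobytes()`, the element type has to be known when the data is read back. The sketch below illustrates that round trip with the standard library's `array` module as a stand-in for NumPy; the token values are made up, and the 4-byte integer and 1-byte boolean layouts are assumptions about how the arrays were stored.

```python
from array import array

# Made-up tokens stored as C ints ('i'), 4 bytes each on common platforms.
input_tokens = array("i", [2, 49688, 736, 1])
# One byte per boolean, mimicking a boolean mask array.
target_mask = bytes([0, 0, 1, 1])

# Writing side: only the raw bytes are kept.
raw_tokens = input_tokens.tobytes()

# Reading side: the element type must be supplied again.
restored_tokens = array("i")
restored_tokens.frombytes(raw_tokens)
restored_mask = [bool(b) for b in target_mask]

print(list(restored_tokens))  # [2, 49688, 736, 1]
print(restored_mask)          # [False, False, True, True]
```

With NumPy, the reading side would typically use `np.frombuffer` with the matching `dtype`.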

In [None]:
"""Dataset builder for the SciTail dataset."""

import enum as Enum
import random
import numpy as np

from absl import logging
import jax.dlpack
import tensorflow as tf
import tensorflow_datasets as tfds


class DatasetSplit(Enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'validation'


class SciTailDatasetBuilder(DatasetBuilder):
  """Dataset builder for the SciTail dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 13679}

  #BUFFER_SIZE_SHUFFLE = 10_000
  BUFFER_SIZE_SHUFFLE = 100
  PREMISE_PREFIX = 'Premise: \n'
  PREMISE_SUFFIX = '\n'
  HYPOTHESIS_PREFIX = 'Hypothesis: \n'
  HYPOTHESIS_SUFFIX = '\n'
  LABEL_PREFIX = 'Label: \n'
  LABEL_SUFFIX = '\n'
  OPTION_PREFIXES = ["(a) ", " (b) ", "(c) ", "(d) "]
  OPTION_SUFFIX = "\n"
  #TRANSLATION_PREFIX = 'Translate this into French:\n'
  #TRANSLATION_SUFFIX = '\n'

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int
  ):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load(
            'huggingface:scitail/tsv_format', split='train',
        ),
        DatasetSplit.VALIDATION: tfds.load('huggingface:scitail/tsv_format', split='validation')
    }
    self._max_seq_len = max_seq_len

  def _tokenize_premise(self, example: tf.Tensor):
    """Tokenization function for the premise."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.PREMISE_PREFIX,
        suffix=self.PREMISE_SUFFIX,
        add_eos=False,
    )

  def _tokenize_hypothesis(self, example: tf.Tensor):
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.HYPOTHESIS_PREFIX,
        suffix=self.HYPOTHESIS_SUFFIX,
        add_eos=False,
    )

  def _tokenize_label(self, example: tf.Tensor):
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.LABEL_PREFIX,
        suffix=self.LABEL_SUFFIX,
        add_eos=False,
    )


  def _to_training_input(
      self,
      premise_tokens: jax.Array,
      hypothesis_tokens: jax.Array,
      label_tokens: jax.Array,
  ):
    """Build a training input from premise, hypothesis and label tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # premise, the hypothesis and the label.
    tokens = tf.concat(
        [premise_tokens, hypothesis_tokens, label_tokens], axis=0
    )

    # To prevent the model from updating based on the premise and hypothesis
    # (input) tokens, add a target mask to each input.
    premise_mask = tf.zeros_like(premise_tokens, dtype=tf.bool)
    hypothesis_mask = tf.zeros_like(hypothesis_tokens, dtype=tf.bool)
    label_mask = tf.ones_like(label_tokens, dtype=tf.bool)
    mask = tf.concat([premise_mask, hypothesis_mask, label_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)
    return TrainingInput( #type: ignore
        input_tokens=tokens, #type:ignore
        target_mask=mask,  #type:ignore
    )# type: ignore

  def _tokenize_example(self, example):
    # The field names below assume the `premise`, `hypothesis` and `label`
    # features of the SciTail `tsv_format` config.
    return (self._tokenize_premise(example['premise']),
            self._tokenize_hypothesis(example['hypothesis']),
            self._tokenize_label(example['label']))

  def get_train_dataset(self):
    """Build the training dataset."""

    ds = self._base_data[DatasetSplit.TRAIN].map(
        self._tokenize_example,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z),
                num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
    #ds = ds.repeat(num_epochs)
    #ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self):
    """Build the validation dataset."""
    ds = self._base_data[DatasetSplit.VALIDATION].map(
        self._tokenize_example,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z),
                num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    return ds

In [None]:
"""Dataset builder for the Open Orca dataset."""

import enum as Enum
import random

from absl import logging
import jax.dlpack
import tensorflow as tf
import tensorflow_datasets as tfds


class DatasetSplit(Enum.Enum):
  TRAIN = 'train'


class OpenOrcaDatasetBuilder(DatasetBuilder):
  """Dataset builder for the Open Orca dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 2914896}

  #BUFFER_SIZE_SHUFFLE = 10_000
  BUFFER_SIZE_SHUFFLE = 100
  SYSTEM_PREFIX = 'System: \n'
  SYSTEM_SUFFIX = '\n'
  QUESTION_PREFIX = 'Question: \n'
  QUESTION_SUFFIX = '\n'
  #TRANSLATION_PREFIX = 'Translate this into French:\n'
  #TRANSLATION_SUFFIX = '\n'

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int
  ):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load(
            'huggingface:open_orca__openorca', split='train'
        ),
    }
    logging.info(f'open orca size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')
    self._max_seq_len = max_seq_len

  def _tokenize_system(self, example: tf.Tensor) -> tf.Tensor:
    """Tokenization function for the system prompt."""
    res = self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.SYSTEM_PREFIX,
        suffix=self.SYSTEM_SUFFIX,
        add_eos=False,
    )
    return res

  def _tokenize_question(self, example: tf.Tensor):
    """Tokenization function for the Question."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.QUESTION_PREFIX,
        suffix=self.QUESTION_SUFFIX,
        add_eos=False,
    )

  def _tokenize_response(self, example: tf.Tensor):
    """Tokenization function for the Response."""
    return self._tokenizer.tokenize_tf_op(
        example,
        add_eos=True,
    )

  def _to_training_input(
      self,
      system_tokens: jax.Array,
      question_tokens: jax.Array,
      response_tokens: jax.Array,
  ):
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat(
        [system_tokens, question_tokens, response_tokens], axis=0
    )

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    system_mask = tf.zeros_like(system_tokens, dtype=tf.bool)
    question_mask = tf.zeros_like(question_tokens, dtype=tf.bool)
    response_mask = tf.ones_like(response_tokens, dtype=tf.bool)
    mask = tf.concat([system_mask, question_mask, response_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)
    return TrainingInput( #type: ignore
        input_tokens=tokens, #type:ignore
        target_mask=mask,  #type:ignore
    )# type: ignore

  def get_train_dataset(self):
    """Build the training dataset."""

    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: (
            self._tokenize_system(x['system_prompt']),
            self._tokenize_question(x['question']),
            self._tokenize_response(x['response'])
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z),
                num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
    #ds = ds.repeat(num_epochs)
    #ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same steps as in `get_train_dataset`, but without shuffling and
    # repetition.
    # ds = self._base_data[DatasetSplit.VALIDATION].map(
    #    lambda x: (self._tokenize_source(x['src']),
    #               self._tokenize_destination(x['dst'])))
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: (
            self._tokenize_system(x['system_prompt']),
            self._tokenize_question(x['question']),
            self._tokenize_response(x['response']),
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.map(
        lambda x, y, z: self._to_training_input(x, y, z),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    #ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    # ds = ds.batch(batch_size, drop_remainder=True)
    return ds
    # ds = [self._to_training_input(x, y) for x, y in ds]
    # print('here3:', ds)
    # ds = [x for x in ds if tf.shape(x.input_tokens)[0] <= self._max_seq_len]
    # ds = [ds[i : i + batch_size] for i in range(0, len(ds), batch_size)]
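
To visualise what `OpenOrcaDatasetBuilder` feeds to the model, the sketch below renders the three text segments in concatenation order, together with a flag saying whether the segment contributes to the loss. The prefixes and suffixes are copied from the class above; `render_segments` and the sample texts are hypothetical and not taken from the dataset.

```python
SYSTEM_PREFIX = "System: \n"
SYSTEM_SUFFIX = "\n"
QUESTION_PREFIX = "Question: \n"
QUESTION_SUFFIX = "\n"


def render_segments(system_prompt, question, response):
    """Returns (text, contributes_to_loss) pairs in concatenation order."""
    return [
        (SYSTEM_PREFIX + system_prompt + SYSTEM_SUFFIX, False),
        (QUESTION_PREFIX + question + QUESTION_SUFFIX, False),
        (response, True),
    ]


# Made-up example, not taken from the dataset.
segments = render_segments(
    "You are a helpful assistant.", "What is 2 + 2?", "2 + 2 equals 4.")
for text, in_loss in segments:
    print(f"in_loss={in_loss}: {text!r}")
```

Note that the builder tokenizes each segment separately, so each one starts with its own beginning-of-sequence token, and only the response is followed by an end-of-sequence token.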


In [None]:
open_orca_path = '/home/shivguptashi/open_orca/open_orca_data.tfrecord'
tokenizer = GemmaTokenizer(vocab)
open_orca_dataset_builder = OpenOrcaDatasetBuilder(tokenizer, max_seq_len=1000)
train_ds = open_orca_dataset_builder.get_train_dataset()
train_ds = train_ds.as_numpy_iterator()
it = 0
with tf.io.TFRecordWriter(open_orca_path) as writer:
  for train_record in train_ds:
    # Store the tokens and the mask as raw bytes features.
    feature = {
        name: tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[array.tobytes()]))
        for name, array in (('input_tokens', train_record.input_tokens),
                            ('target_mask', train_record.target_mask))
    }
    record_bytes = tf.train.Example(
        features=tf.train.Features(feature=feature)).SerializeToString()
    writer.write(record_bytes)
    print(f'it: {it}')
    it += 1

Streaming output truncated to the last 5000 lines.
it: 287710
...
it: 287794


KeyboardInterrupt: 

In [None]:

import enum
import random

from absl import logging
import jax.dlpack
import tensorflow as tf
import tensorflow_datasets as tfds


class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  TEST = 'test'


class GSM8KDatasetBuilder(DatasetBuilder):
  """Dataset builder for the GSM8k dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 7473}

  #BUFFER_SIZE_SHUFFLE = 10_000
  BUFFER_SIZE_SHUFFLE = 100
  ANSWER_PREFIX = 'A: '
  ANSWER_SUFFIX = '\n'
  QUESTION_PREFIX = 'Q: '
  QUESTION_SUFFIX = '\n'
  #TRANSLATION_PREFIX = 'Translate this into French:\n'
  #TRANSLATION_SUFFIX = '\n'

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int
  ):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load(
            'huggingface:gsm8k/main', split='train'
        ),
        DatasetSplit.TEST: tfds.load(
            'huggingface:gsm8k/main', split='test'
        ),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_question(self, example: tf.Tensor):
    """Tokenization function for the Question."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.QUESTION_PREFIX,
        suffix=self.QUESTION_SUFFIX,
        add_eos=False,
    )

  def _tokenize_answer(self, example: tf.Tensor):
    """Tokenization function for the Answer."""
    return self._tokenizer.tokenize_tf_op(
        example,
        add_eos=True,
    )

  def _to_training_input(
      self,
      question_tokens: tf.Tensor,
      answer_tokens: tf.Tensor,
  ):
    """Build a training input from a pair of question and answer tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # question and the answer.
    tokens = tf.concat(
        [question_tokens, answer_tokens], axis=0
    )

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    question_mask = tf.zeros_like(question_tokens, dtype=tf.bool)
    answer_mask = tf.ones_like(answer_tokens, dtype=tf.bool)
    mask = tf.concat([question_mask, answer_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)
    return TrainingInput( #type: ignore
        input_tokens=tokens, #type:ignore
        target_mask=mask,  #type:ignore
    )# type: ignore

  def get_train_dataset(self, batch_size: int, num_epochs: int):
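    """Build the training dataset.

    `batch_size` and `num_epochs` are currently unused: batching and
    repetition are disabled below, so unbatched examples are returned.
    """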

    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: (
            self._tokenize_question(x['question']),
            self._tokenize_answer(x['answer']),
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y),
                num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
    #ds = ds.repeat(num_epochs)
    #ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same steps as in `get_train_dataset`, but on the TEST split and without
    # shuffling.
    ds = self._base_data[DatasetSplit.TEST].map(
        lambda x: (
            self._tokenize_question(x['question']),
            self._tokenize_answer(x['answer']),
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.map(
        lambda x, y: self._to_training_input(x, y),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    # ds = ds.batch(batch_size, drop_remainder=True)
    return ds
    # ds = [self._to_training_input(x, y) for x, y in ds]
    # print('here3:', ds)
    # ds = [x for x in ds if tf.shape(x.input_tokens)[0] <= self._max_seq_len]
    # ds = [ds[i : i + batch_size] for i in range(0, len(ds), batch_size)]

  def get_question_answer_dataset(self):
    #ds = self._base_data[DatasetSplit.TEST]
    return self._base_data[DatasetSplit.TEST]
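
The masking and padding done in `_to_training_input` can be easier to follow on a toy example. The sketch below reproduces the idea in plain NumPy. The pad id of 0, the token values and the `pad_up_to_max_len` helper are assumptions made for illustration; the real `_pad_up_to_max_len` lives in the base `DatasetBuilder` class and may differ in detail.

```python
import numpy as np

PAD_ID = 0  # Assumed pad id, for illustration only.


def pad_up_to_max_len(values: np.ndarray, pad_value, max_len: int) -> np.ndarray:
  """Right-pads a 1-D array to `max_len`; longer inputs are returned unchanged."""
  n_pad = max(max_len - values.shape[0], 0)
  padding = np.full((n_pad,), pad_value, dtype=values.dtype)
  return np.concatenate([values, padding])


def to_training_input(question_tokens, answer_tokens, max_len):
  """Concatenates question and answer, and masks everything but the answer."""
  tokens = np.concatenate([question_tokens, answer_tokens])
  mask = np.concatenate([
      np.zeros_like(question_tokens, dtype=bool),  # No loss on the question.
      np.ones_like(answer_tokens, dtype=bool),     # Loss on the answer.
  ])
  # Pad positions are masked out as well.
  return (pad_up_to_max_len(tokens, PAD_ID, max_len),
          pad_up_to_max_len(mask, False, max_len))


# Hypothetical token ids: 2 stands for BOS and 1 for EOS.
question = np.array([2, 11, 12, 13], dtype=np.int32)
answer = np.array([2, 21, 22, 1], dtype=np.int32)
tokens, mask = to_training_input(question, answer, max_len=10)
print(tokens)
print(mask.astype(int))
```

Because sequences longer than `max_len` are left unpadded rather than truncated in this sketch, a length filter applied after padding still has something to remove.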

In [None]:
gsm8k_path = '/home/shivguptashi/gsm8k_train/gsm8k_train.tfrecord'
tokenizer = GemmaTokenizer(vocab)
gsm8k_dataset_builder = GSM8KDatasetBuilder(tokenizer, max_seq_len=1000)
train_ds = gsm8k_dataset_builder.get_train_dataset(batch_size=100, num_epochs=1)
train_ds = train_ds.as_numpy_iterator()
it = 0
with tf.io.TFRecordWriter(gsm8k_path) as writer:
  for train_record in train_ds:
    # Store the tokens and the mask as raw bytes features.
    feature = {
        name: tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[array.tobytes()]))
        for name, array in (('input_tokens', train_record.input_tokens),
                            ('target_mask', train_record.target_mask))
    }
    record_bytes = tf.train.Example(
        features=tf.train.Features(feature=feature)).SerializeToString()
    writer.write(record_bytes)
    print(f'it: {it}')
    it += 1

Streaming output truncated to the last 5000 lines.
it: 2473
...
it: 2576

In [None]:
def decode_fn(record_bytes):
  feature_spec = {
      'input_tokens': tf.io.FixedLenFeature((), tf.string),
      'target_mask': tf.io.FixedLenFeature((), tf.string),
  }
  parsed_features = tf.io.parse_example(record_bytes, feature_spec)
  return {
      'input_tokens': tf.io.decode_raw(
          parsed_features['input_tokens'], out_type=tf.int32),
      'target_mask': tf.io.decode_raw(
          parsed_features['target_mask'], out_type=tf.bool),
  }

dataset = tf.data.TFRecordDataset([open_orca_path])
dataset = dataset.map(decode_fn)
for record in dataset:
  print(record)
  #print(record)

Streaming output truncated to the last 5000 lines.
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False,  True,  True,  True,  True,
        True, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False])>}
{'input_tokens': <tf.Tensor: shape=(100,), dtype=int32, numpy=
array([     2,   2622, 235292, 235248,    108,   2045,    708,    476,
        10055,  20409, 235269,   1064,   2593,   3658,  15844, 235265,
        24174,   1154,    692,    708,  39534,    577,    476,   4105,
         1162,   2187, 235265,    108,      2,   9413, 235292, 235248,
          108,  49688,    774,  13035,    577,   4645, 235292,    109,
        28
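
The records above are written with `tobytes()` and read back with `tf.io.decode_raw`, which only works if the reader supplies the same dtype that the writer used. A minimal NumPy sketch of that round trip follows, with `np.frombuffer` standing in for `tf.io.decode_raw`. The token values are made up for illustration.

```python
import numpy as np

# Toy stand-ins for one training record (values are illustrative only).
input_tokens = np.array([2, 651, 3448, 603, 586, 1, 0, 0], dtype=np.int32)
target_mask = np.array([0, 0, 0, 1, 1, 1, 0, 0], dtype=np.bool_)

# Writing: `tobytes()` keeps only the raw buffer, not the dtype or the shape.
token_bytes = input_tokens.tobytes()
mask_bytes = target_mask.tobytes()

# Reading: the dtype has to be supplied again, which is the role `out_type`
# plays in `decode_fn`.
decoded_tokens = np.frombuffer(token_bytes, dtype=np.int32)
decoded_mask = np.frombuffer(mask_bytes, dtype=np.bool_)
print(np.array_equal(decoded_tokens, input_tokens))
print(np.array_equal(decoded_mask, target_mask))

# Decoding with the wrong dtype does not fail; it silently yields another array.
wrong_tokens = np.frombuffer(token_bytes, dtype=np.int64)
print(len(input_tokens), len(wrong_tokens))
```

If the tokenizer ever produced a different integer width, the `out_type` passed when reading would need to change accordingly.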

In [None]:

import enum
import random

from absl import logging
import jax.dlpack
import tensorflow as tf
import tensorflow_datasets as tfds


class DatasetSplit(enum.Enum):
  TRAIN = 'train'


class OrcaMathDatasetBuilder(DatasetBuilder):
  """Dataset builder for the Orca Math dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 200035}

  #BUFFER_SIZE_SHUFFLE = 10_000
  BUFFER_SIZE_SHUFFLE = 100
  QUESTION_PREFIX = 'Question: \n'
  QUESTION_SUFFIX = '\n'
  #TRANSLATION_PREFIX = 'Translate this into French:\n'
  #TRANSLATION_SUFFIX = '\n'

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int
  ):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load(
            'huggingface:microsoft__orca_math_word_problems_200k', split='train'
        ),
    }
    logging.info(f'orca math size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')
    self._max_seq_len = max_seq_len

  def _tokenize_question(self, example: tf.Tensor):
    """Tokenization function for the Question."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.QUESTION_PREFIX,
        suffix=self.QUESTION_SUFFIX,
        add_eos=False,
    )

  def _tokenize_response(self, example: tf.Tensor):
    """Tokenization function for the Response."""
    return self._tokenizer.tokenize_tf_op(
        example,
        add_eos=True,
    )

  def _to_training_input(
      self,
      question_tokens: tf.Tensor,
      answer_tokens: tf.Tensor,
  ):
    """Build a training input from a pair of question and answer tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # question and the answer.
    tokens = tf.concat(
        [question_tokens, answer_tokens], axis=0
    )

    # To prevent the model from updating based on the source (input)
    # tokens, add a target mask to each input.
    question_mask = tf.zeros_like(question_tokens, dtype=tf.bool)
    answer_mask = tf.ones_like(answer_tokens, dtype=tf.bool)
    mask = tf.concat([question_mask, answer_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # Don't want to perform the backward pass on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)
    return TrainingInput( #type: ignore
        input_tokens=tokens, #type:ignore
        target_mask=mask,  #type:ignore
    )# type: ignore

  def get_train_dataset(self):
    """Build the training dataset."""

    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: (
            self._tokenize_question(x['question']),
            self._tokenize_response(x['answer'])
        )
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
    #ds = ds.repeat(num_epochs)
    #ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same steps as in `get_train_dataset`, but without shuffling. Note that
    # this reads the TRAIN split, which is the only split loaded in the
    # constructor.
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: (
            self._tokenize_question(x['question']),
            self._tokenize_response(x['answer'])
        )
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    #ds = ds.batch(batch_size, drop_remainder=True)
    return ds
    #ds = [self._to_training_input(x, y) for x, y in ds]
    #print('here3:', ds)
    #ds = [x for x in ds if tf.shape(x.input_tokens)[0] <= self._max_seq_len]
    #ds = [ds[i : i + batch_size] for i in range(0, len(ds), batch_size)]


In [None]:
orca_math_path = '/home/shivguptashi/orca_math/orca_math_data.tfrecord'
tokenizer = GemmaTokenizer(vocab)
orca_math_dataset_builder = OrcaMathDatasetBuilder(tokenizer, max_seq_len=1000)
train_ds = orca_math_dataset_builder.get_train_dataset()
train_ds = train_ds.as_numpy_iterator()
it = 0
with tf.io.TFRecordWriter(orca_math_path) as writer:
  for train_record in train_ds:
    # Store the tokens and the mask as raw bytes features.
    feature = {
        name: tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[array.tobytes()]))
        for name, array in (('input_tokens', train_record.input_tokens),
                            ('target_mask', train_record.target_mask))
    }
    record_bytes = tf.train.Example(
        features=tf.train.Features(feature=feature)).SerializeToString()
    writer.write(record_bytes)
    print(f'it: {it}')
    it += 1

Streaming output truncated to the last 5000 lines.
it: 6475
...
it: 6578

KeyboardInterrupt: 

Let's give it a try.

In [None]:
tokenizer = GemmaTokenizer(vocab)
tokenized_str = tokenizer.tokenize('The answer is A')
print(tokenized_str)
for token in tokenized_str:
  if int(token) in res_dict:
    print(res_dict[int(token)])
    break
dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()

[   2  651 3448  603  586    1]
A
Example 0:
input_tokens: [[     2  49688    736   1280   6987 235292    108   2728  92820    604
   35546    108      2  39614  67032   1982    683 227484      1      0]
 [     2  49688    736   1280   6987 235292    108  17099  15531   7404
     955    108      2  13590   2360  21536  15845 235265      1      0]
 [     2  49688    736   1280   6987 235292    108 132380   3646 235265
     108      2  87006  15845  35624 235303  22863 235265      1      0]]
target_mask: [[False False False False False False False False False False False False
   True  True  True  True  True  True  True False]
 [False False False False False False False False False False False False
   True  True  True  True  True  True  True False]
 [False False False False False False False False False False False  True
   True  True  True  True  True  True  True False]]

Example 1:
input_tokens: [[     2  49688    736   1280   6987 235292    108 235393  66660 235393
     108      2 23

## Fine-tuning the Gemma model

### Getting started

First, let's load the model.

In [None]:
# Load parameters

# TODO: change once the downloading url is known
params = params_lib.load_and_format_params(ckpt_path)

# We use the `transformer_lib.TransformerConfig.from_params` function to
# automatically load the correct configuration from a checkpoint. Note that the
# vocabulary size is smaller than the number of input embeddings due to unused
# tokens in this release.
config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30  # Number of time steps in the transformer's cache
)
model_2b = transformer_lib.Transformer(config=config_2b)

Can our model translate into French? Well, let's try it out!

In [None]:
sampler_old = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['transformer'],
)

In [None]:
print(sampler_old(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    # number of steps performed when generating
    total_generation_steps=30,
  ).text)

As expected, it didn't work. Let's see if we can get better results by fine-tuning.

Before moving further, don't forget to clear the memory if necessary.

In [None]:
del sampler_old

### Model forward and loss function

The Gemma `Transformer` class inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html). It offers two essential methods:

- `init`: Initializes the model's parameters.

- `apply`: Executes the model's `__call__` function using a given set of parameters.

Since we are working with pre-trained weights, we won't use the `init` function.

We define a `forward_and_loss_fn` as follows:

In [None]:
def forward_and_loss_fn(params,
                        *,
                        model: transformer_lib.Transformer,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: gemma transformer model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: mask of the tokens that contribute to the loss (True for
      target tokens), shape [B, L].
    positions: relative position of each token, shape [B, L].
    attention_mask: input attention mask, shape [B, L, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """

  # Forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

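  # Exclude the last step as it does not appear in the targets.
  # Note: indexing with 0 keeps only the first sequence of the batch, which is
  # sufficient for the batch size of 1 used in this tutorial.
  logits = logits[0, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[0, 1:]
  target_mask = input_mask[0, 1:]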

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Normalisation factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the nll loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
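
To make the shifting and masking concrete, here is a small NumPy sketch of the same masked next-token loss, written so that it averages over every row of a batch. It is an illustration under assumptions (toy vocabulary size, made-up token ids, a hypothetical `masked_next_token_loss` helper), not the function used for training.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
  """Numerically stable log-softmax over the last axis."""
  x = x - x.max(axis=-1, keepdims=True)
  return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def masked_next_token_loss(logits, input_tokens, input_mask):
  """logits: [B, L, V]; input_tokens and input_mask: [B, L]."""
  # The prediction at step t is scored against the token at step t + 1.
  logits = logits[:, :-1]
  targets = input_tokens[:, 1:]
  mask = input_mask[:, 1:].astype(logits.dtype)

  one_hot = np.eye(logits.shape[-1], dtype=logits.dtype)[targets]
  token_nll = -(log_softmax(logits) * one_hot).sum(axis=-1)  # [B, L - 1]

  # Average over the unmasked target tokens only.
  return (token_nll * mask).sum() / (mask.sum() + 1e-8)


vocab_size = 5
tokens = np.array([[2, 3, 4, 1]])
mask = np.array([[False, False, True, True]])

# With uniform logits, every target has probability 1 / vocab_size.
uniform_logits = np.zeros((1, 4, vocab_size))
loss = masked_next_token_loss(uniform_logits, tokens, mask)
print(loss, np.log(vocab_size))
```

With uniform logits the loss should be close to `log(vocab_size)`, which is a convenient sanity check for any implementation of this loss.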

The Gemma transformer requires an attention mask and position vector alongside each input. We can conveniently generate these using the following function:

In [None]:
def get_attention_mask_and_positions(
    example: jax.Array,
    pad_id: int,
) -> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask
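
For intuition, the NumPy sketch below shows what the two returned arrays would look like on a short padded sequence. It assumes that `build_positions_from_mask` numbers the non-pad tokens from 0 and that `make_causal_attn_mask` combines the pad mask with a lower-triangular mask; the library source remains the reference for the exact behaviour. The pad id of 0 and the token ids are made up.

```python
import numpy as np

PAD_ID = 0  # Assumed pad id, for illustration only.

tokens = np.array([[2, 17, 23, 1, 0, 0]])  # Shape [B=1, L=6], two pad tokens.
pad_mask = tokens != PAD_ID

# Positions: 0-based index of each token among the non-pad tokens.
cumulative = np.cumsum(pad_mask, axis=-1)
positions = cumulative - (cumulative >= 1)

# Attention mask: query i may attend to key j only if j <= i (causal) and
# key j is not a pad token.
seq_len = tokens.shape[-1]
causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
attention_mask = pad_mask[:, None, :] & causal[None, :, :]  # Shape [B, L, L].

print(positions)
print(attention_mask[0].astype(int))
```

Each row of the printed mask corresponds to one query position, and the columns of the pad tokens stay at 0 for every query.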

We can now build the `train_step` function, which performs the backward pass and updates the model's parameters accordingly.

In [None]:
def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example: TrainingInput):
  """Train step.

  Args:
    model: gemma transformer model.
    params: model's input parameters.
    optimizer: optax optimizer to use.
    opt_state: input optimizer's state.
    pad_id: id of the pad token.
    example: input batch.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """

  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)

  # Forward and backward passes
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             model=model,
                                                             input_tokens=example.input_tokens,
                                                             input_mask=example.target_mask,
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state
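
The parameter update itself is simple: plain SGD subtracts the learning rate times the gradient from every parameter leaf. The sketch below shows that rule in NumPy on a toy linear model, with a hand-derived gradient standing in for `jax.value_and_grad`. The model, the data and the helper names are all assumptions made for illustration.

```python
import numpy as np


def loss_and_grads(params, x, y):
  """Mean squared error of a linear model, with hand-derived gradients."""
  err = x @ params['w'] + params['b'] - y
  loss = np.mean(err ** 2)
  grads = {'w': 2 * x.T @ err / len(y), 'b': 2 * np.mean(err)}
  return loss, grads


def sgd_step(params, grads, learning_rate):
  """Applies `param - learning_rate * grad` to every leaf of the dict."""
  return {name: params[name] - learning_rate * grads[name] for name in params}


# Synthetic regression data generated from known weights.
rng = np.random.default_rng(0)
x = rng.normal(size=(32, 3))
y = x @ np.array([1.0, -2.0, 0.5]) + 0.3

params = {'w': np.zeros(3), 'b': 0.0}
losses = []
for _ in range(200):
  loss, grads = loss_and_grads(params, x, y)
  params = sgd_step(params, grads, learning_rate=0.1)
  losses.append(loss)

print(losses[0], losses[-1])
```

Plain SGD keeps no per-parameter statistics, which is why it needs less memory than Adam.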

Similarly, we build a `validation_step` function without the backward pass.

In [None]:
def validation_step(model: transformer_lib.Transformer,
                    params,
                    pad_id: int,
                    example: TrainingInput,
                    ):
  """Computes the loss on a validation batch, without any parameter update."""
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
  val_loss = forward_and_loss_fn(params,
                                 model=model,
                                 input_tokens=example.input_tokens,
                                 input_mask=example.target_mask,
                                 positions=positions,
                                 attention_mask=attention_mask)
  return val_loss

And now the training loop itself.

In [None]:
@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None


def train_loop(
    model: transformer_lib.Transformer,
    params,
    dataset_builder: MTNTDatasetBuilder,
    training_cfg: TrainingConfig):


  # We jit the train step, making the whole loop much more efficient
  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])

  # We do the same with the validation step
  compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])

  # To save memory, we use an SGD optimizer instead of the usual Adam. Note that
  # for this specific example SGD is more than enough.
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  # Build the training dataset
  train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
                                               num_epochs=training_cfg.num_epochs)
  train_ds = train_ds.as_numpy_iterator()

  # Build the validation dataset, with a limited number of samples for this demo
  validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
  validation_ds = validation_ds.take(50)

  n_steps = 0
  avg_loss=0

  # A first round of validation loss
  n_steps_eval = 0
  eval_loss = 0
  val_iterator = validation_ds.as_numpy_iterator()
  for val_example in val_iterator:
    eval_loss += compiled_validation_step(model,
                                          params,
                                          dataset_builder._tokenizer.pad_id,
                                          val_example)
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = compiled_train_step(model=model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=dataset_builder._tokenizer.pad_id,
                                                        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    if n_steps % training_cfg.eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += compiled_validation_step(model,
                                              params,
                                              dataset_builder._tokenizer.pad_id,
                                              val_example)
        n_steps_eval +=1
      avg_loss /= training_cfg.eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if training_cfg.max_steps is not None and n_steps >= training_cfg.max_steps:
      break
  return params
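
The bookkeeping in this loop (a running average that is reset at every report, and an early stop on a step budget) can be checked in isolation. The pure-Python sketch below replays it on a list of made-up loss values. `run_loop` is a hypothetical helper, and it stops after exactly `max_steps` updates by comparing with `>=`.

```python
def run_loop(train_losses, eval_every_n, max_steps=None):
  """Replays the loop bookkeeping on precomputed per-step losses.

  Returns the number of steps run and a list of (step, window_average) pairs.
  """
  reports = []
  running_loss = 0.0
  n_steps = 0
  for loss in train_losses:
    n_steps += 1
    running_loss += loss
    if n_steps % eval_every_n == 0:
      # Report the average over the last `eval_every_n` steps, then reset it.
      reports.append((n_steps, running_loss / eval_every_n))
      running_loss = 0.0
    if max_steps is not None and n_steps >= max_steps:
      break
  return n_steps, reports


# Made-up losses 1.0, 2.0, ..., 1000.0, so window averages are easy to verify.
fake_losses = [float(i) for i in range(1, 1001)]
steps, reports = run_loop(fake_losses, eval_every_n=20, max_steps=100)
print(steps)
print(reports)
```

With `eval_every_n=20` and `max_steps=100`, the loop reports five window averages and never consumes the remaining 900 values.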

We can now fine-tune our model for a limited number of steps.

In [None]:
# Small seq size so that everything fits in memory
SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder = MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=100)

params = train_loop(model=model_2b,
                    params={'params': params['transformer']},
                    dataset_builder=dataset_builder,
                    training_cfg=training_cfg)

Both the training loss and the validation loss are going down. But is it working? Let's try again with our previous example:

In [None]:
sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['params'],
)

To ensure our input matches the training format, remember to use the prefix 'Translate this into French:\n' and a newline character at the end. This signals the model to begin translation.

In [None]:
sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=30,
    ).text

