In [None]:
import os

VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
# weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
# Checkpoint location (presumably an Orbax checkpoint, going by its name).
# NOTE(review): not referenced anywhere else in this notebook section.
ckpt_path = 'g_mini/2b_it_v1p1_orbax/1'
# SentencePiece model file loaded by the tokenizer cell.
# NOTE(review): this path has no leading '/', so it resolves relative to the
# current working directory -- confirm it should not be '/home/mriviere/...'.
vocab_path = 'home/mriviere/g_mini/tokenizer/gemini_bpe_256k_v5_no_tags_cleared_v1.model'

In [None]:
# @title Python imports
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp

# We will use tensorflow to handle the dataset
import tensorflow as tf
import tensorflow_datasets as tfds

import enum as Enum
import random

from absl import logging
import jax.dlpack


# Finally, we import Gemma.
from colabtools import adhoc_import


In [None]:
# Sanity check: place a small array on the host CPU and report where it lives.
cpu_device = jax.devices('cpu')[0]
a = jax.device_put(jnp.zeros(10), cpu_device)
print(a.devices())

## Inspect Dolly

In [None]:
# Raw Dolly-15k train split, loaded through the TFDS HuggingFace bridge.
ds = tfds.load(name='huggingface:databricks__databricks_dolly_15k', split='train')

In [None]:
# Display the number of examples in the loaded split.
ds.cardinality()

In [None]:
# Peek at one example: the raw feature dict, then one line per field.
for sample in ds.take(1):
  print(sample)
  for field_name, field_value in sample.items():
    print(f'{field_name}: {field_value}')

## Inspect MetaMath

In [None]:
# Raw MetaMathQA train split, loaded through the TFDS HuggingFace bridge.
ds = tfds.load(name='huggingface:meta_math__metamathqa', split='train')

In [None]:
# Display the number of examples in the loaded split.
ds.cardinality()

In [None]:
# Peek at one example: the raw feature dict, then one line per field.
for sample in ds.take(1):
  print(sample)
  for field_name, field_value in sample.items():
    print(f'{field_name}: {field_value}')

## Inspect CodeAlpaca

In [None]:
# Raw CodeAlpaca-20k train split, loaded through the TFDS HuggingFace bridge.
ds = tfds.load(name='huggingface:sahil2801__codealpaca_20k', split='train')

In [None]:
# Display the number of examples in the loaded split.
ds.cardinality()

In [None]:
# Peek at one example: the raw feature dict, then one line per field.
for sample in ds.take(1):
  print(sample)
  for field_name, field_value in sample.items():
    print(f'{field_name}: {field_value}')

## Inspect Open Web Math

In [None]:
# Raw Open Web Math train split, loaded through the TFDS HuggingFace bridge.
ds = tfds.load(name='huggingface:open_web_math__open_web_math', split='train')

In [None]:
# Display the number of examples in the loaded split.
ds.cardinality()

In [None]:
# Peek at one example: the raw feature dict, then one line per field.
for sample in ds.take(1):
  print(sample)
  for field_name, field_value in sample.items():
    print(f'{field_name}: {field_value}')

### Tokenizer

Let's start by loading our vocabulary base tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.

In [None]:
# NOTE(review): `spm` (sentencepiece) is never imported in this notebook; add
# `import sentencepiece as spm` to the imports cell so this cell runs on a
# fresh kernel.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

# (token id, piece string) for every entry of the vocabulary.
vocab_list = [(token_id, vocab.IdToPiece(token_id))
              for token_id in range(vocab.GetPieceSize())]
letters = ['A', 'B', 'C', 'D']

# Maps token id -> letter for every piece whose FIRST alphabetic character is
# one of `letters`. Pieces with no alphabetic character are skipped.
# NOTE(review): this also matches longer pieces such as 'Apple' or 'Dog';
# confirm that is intended if only single-letter answer tokens are wanted.
res_dict = {}
for token_id, piece in vocab_list:
  # `next(..., None)` replaces the old bare `except:` that caught the
  # StopIteration raised for pieces without letters (and hid any other error).
  first_alpha = next(filter(str.isalpha, piece), None)
  if first_alpha in letters:
    res_dict[token_id] = first_alpha

class DatasetSplit(Enum.Enum):
  """Dataset splits; used as keys of the builders' `_base_data` and `N_ITEMS`.

  Only the training split is defined.
  """
  TRAIN = 'train'

@chex.dataclass(frozen=True)
class TrainingInput:
  """One tokenized training example produced by the dataset builders.

  Although annotated as jax.Array, inside the tf.data pipelines the fields hold
  tf.Tensor values, and NumPy arrays once iterated with `as_numpy_iterator`.
  """
  # Input tokens provided to model
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation
  target_mask: jax.Array

In [None]:
class GemmaTokenizer:
  """Custom wrapper around a SentencePieceProcessor for tensorflow.

  NOTE(review): `spm` (sentencepiece) is not imported anywhere in this
  notebook; add `import sentencepiece as spm` to the imports cell so this class
  can be defined on a fresh kernel.
  """

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    """Wraps an already-loaded SentencePiece processor.

    Args:
      spm_processor: processor whose model has already been loaded.
    """
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad id."""
    return self._spm_processor.pad_id()

  @staticmethod
  def _as_text(value: str | bytes) -> str:
    """Returns `value` as str, decoding UTF-8 bytes when necessary.

    `tf.numpy_function` hands string tensors to Python as bytes, while direct
    callers usually pass str. Normalising every piece lets them be concatenated
    regardless of where the call comes from.
    """
    if isinstance(value, bytes):
      return value.decode('utf-8')
    return value

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_eos: bool = True) -> jax.Array:
    """
    Tokenization function.

    A BOS token is always prepended.

    Args:
      example: input string to tokenize.
      prefix:  prefix to add to the input string.
      suffix:  suffix to add to the input string.
      add_eos: if True, add an end of sentence token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string, as an int32 array.
    """
    text = (self._as_text(prefix) + self._as_text(example)
            + self._as_text(suffix))
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(text))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_eos: bool = True) -> tf.Tensor:
    """TensorFlow operator for the tokenize function.

    Returns:
      A 1-D int32 tensor of unknown length.
    """
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    # numpy_function loses static shape information; restore the rank.
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens back to a string."""
    # Decode (ids -> text); the previous `EncodeIds` went in the wrong direction.
    return self._spm_processor.DecodeIds(tokens.tolist())

### Data loader

We can now wrap everything up and build our data loader.

In [None]:
# @title
"""Base class for dataset builders."""

class DatasetBuilder:
  """Shared base for the per-dataset builders in this notebook.

  Holds the tokenizer and the target sequence length, and provides a padding
  helper. Concrete builders override `get_train_dataset`.
  """

  def __init__(self, tokenizer: GemmaTokenizer,
               max_seq_len: int):
    """Remembers the tokenizer and the target sequence length.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._max_seq_len = max_seq_len
    self._tokenizer = tokenizer

  def _pad_up_to_max_len(
      self, input_tensor: tf.Tensor, pad_value: int | bool
  ) -> tf.Tensor:
    """Right-pads a 1-D tensor with `pad_value` up to max_seq_len.

    Inputs that are already max_seq_len long or longer come back unchanged;
    this helper never truncates.
    """
    shortfall = self._max_seq_len - tf.shape(input_tensor)[0]
    paddings = [[0, tf.maximum(shortfall, 0)]]
    return tf.pad(input_tensor, paddings, mode='CONSTANT',
                  constant_values=pad_value)

  def get_train_dataset(self):
    """Returns the training dataset; subclasses must implement this."""
    raise NotImplementedError()

  def get_validation_dataset(self, batch_size: int):
    """Returns the validation dataset; subclasses must implement this."""
    raise NotImplementedError()


## MetaMath

In [None]:
class MetaMathDatasetBuilder(DatasetBuilder):
  """Dataset builder for the MetaMath dataset.

  Each example is the tokenized `query` (wrapped in QUERY_PREFIX/QUERY_SUFFIX)
  followed by the tokenized `response` (wrapped in RESPONSE_PREFIX/
  RESPONSE_SUFFIX). Only the response tokens are marked in `target_mask`.
  """

  # Size of the train split as recorded here; not read by this class.
  N_ITEMS = {DatasetSplit.TRAIN: 395000}

  BUFFER_SIZE_SHUFFLE = 100
  QUERY_PREFIX = 'Query: \n'
  QUERY_SUFFIX = '\n'
  RESPONSE_PREFIX = 'Response: \n'
  RESPONSE_SUFFIX = '\n'

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int
  ):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: sequences are right-padded to this length; longer examples
        are dropped by `get_train_dataset`.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load(
            'huggingface:meta_math__metamathqa', split='train',
        ),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_query(self, example: tf.Tensor):
    """Tokenization function for the Question."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.QUERY_PREFIX,
        suffix=self.QUERY_SUFFIX,
        add_eos=False,
    )

  def _tokenize_response(self, example: tf.Tensor):
    """Tokenization function for the Response."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.RESPONSE_PREFIX,
        suffix=self.RESPONSE_SUFFIX,
        add_eos=False,
    )

  def _to_training_input(
      self,
      query_tokens: tf.Tensor,
      response_tokens: tf.Tensor,
  ):
    """Builds a TrainingInput from query and response token sequences.

    Args:
      query_tokens: 1-D int32 prompt tokens; excluded from the loss.
      response_tokens: 1-D int32 answer tokens; contribute to the loss.

    Returns:
      A TrainingInput whose tokens and mask are right-padded to max_seq_len
      (sequences that are already longer are left as-is).
    """
    # NOTE(review): GemmaTokenizer.tokenize prepends BOS to every segment, so a
    # BOS also sits right before the response, and no segment ends with EOS
    # (add_eos=False everywhere) -- confirm both are intended.
    tokens = tf.concat([query_tokens, response_tokens], axis=0)

    # Only the response drives the gradient; the query is context.
    query_mask = tf.zeros_like(query_tokens, dtype=tf.bool)
    response_mask = tf.ones_like(response_tokens, dtype=tf.bool)
    mask = tf.concat([query_mask, response_mask], axis=0)

    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
    # Pad positions never contribute to the loss.
    mask = self._pad_up_to_max_len(mask, False)
    return TrainingInput(  # type: ignore
        input_tokens=tokens,  # type: ignore
        target_mask=mask,  # type: ignore
    )

  def get_train_dataset(self):
    """Builds the shuffled training dataset of TrainingInput elements.

    Examples whose tokenized length exceeds max_seq_len are dropped.
    """
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: (
            self._tokenize_query(x['query']),
            self._tokenize_response(x['response']),
        )
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y),
                num_parallel_calls=tf.data.AUTOTUNE)
    # Padding never shortens a sequence, so anything still longer than
    # max_seq_len here was too long to begin with.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
    return ds

In [None]:
metamath_path = '/home/xinyic/metamath/metamath_data.tfrecord'
tokenizer = GemmaTokenizer(vocab)
# Examples longer than max_seq_len tokens are dropped by the builder's filter;
# shorter ones are right-padded to exactly max_seq_len.
metamath_dataset_builder = MetaMathDatasetBuilder(tokenizer, max_seq_len=1000)
train_ds = metamath_dataset_builder.get_train_dataset()
train_ds = train_ds.as_numpy_iterator()

# TFRecordWriter cannot create missing parent directories itself.
tf.io.gfile.makedirs(os.path.dirname(metamath_path))
it = 0
with tf.io.TFRecordWriter(metamath_path) as writer:
  for train_record in train_ds:
    # Arrays are stored as raw bytes (input_tokens: int32, target_mask: bool);
    # readers must decode them with these same dtypes.
    example = tf.train.Example(features=tf.train.Features(feature={
        'input_tokens': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[train_record.input_tokens.tobytes()])),
        'target_mask': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[train_record.target_mask.tobytes()])),
    }))
    writer.write(example.SerializeToString())
    it += 1
    # Periodic progress instead of one line per record.
    if it % 1000 == 0:
      print(f'it: {it}')
print(f'wrote {it} records to {metamath_path}')

## CodeAlpaca

In [None]:
class CodeAlpacaDatasetBuilder(DatasetBuilder):
  """Dataset builder for the CodeAlpaca dataset.

  Each example is the tokenized `input`, then `instruction`, then `output`
  field, each wrapped in its prefix/suffix. Only the output tokens are marked
  in `target_mask`.
  """

  # Size of the train split as recorded here; not read by this class.
  N_ITEMS = {DatasetSplit.TRAIN: 20022}
  BUFFER_SIZE_SHUFFLE = 100
  INPUT_PREFIX = 'Input: \n'
  INPUT_SUFFIX = '\n'
  INSTRUCTION_PREFIX = 'Instruction: \n'
  INSTRUCTION_SUFFIX = '\n'
  OUTPUT_PREFIX = 'Output: \n'
  OUTPUT_SUFFIX = '\n'

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int
  ):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: sequences are right-padded to this length; longer examples
        are dropped by `get_train_dataset`.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load(
            'huggingface:sahil2801__codealpaca_20k', split='train'
        ),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_input(self, example: tf.Tensor):
    """Tokenization function for the Input."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.INPUT_PREFIX,
        suffix=self.INPUT_SUFFIX,
        add_eos=False,
    )

  def _tokenize_instruction(self, example: tf.Tensor):
    """Tokenization function for the Instruction."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.INSTRUCTION_PREFIX,
        suffix=self.INSTRUCTION_SUFFIX,
        add_eos=False,
    )

  def _tokenize_output(self, example: tf.Tensor):
    """Tokenization function for the Output."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.OUTPUT_PREFIX,
        suffix=self.OUTPUT_SUFFIX,
        add_eos=False,
    )

  def _to_training_input(
      self,
      input_tokens: tf.Tensor,
      instruction_tokens: tf.Tensor,
      output_tokens: tf.Tensor,
  ):
    """Builds a TrainingInput from input, instruction and output tokens.

    Args:
      input_tokens: 1-D int32 tokens of the `input` field; excluded from loss.
      instruction_tokens: 1-D int32 tokens of the instruction; excluded from
        loss.
      output_tokens: 1-D int32 tokens of the expected output; contribute to
        the loss.

    Returns:
      A TrainingInput right-padded to max_seq_len (longer sequences are left
      as-is).
    """
    # NOTE(review): the segment order puts `input` before `instruction`, unlike
    # the Dolly builder, which leads with the instruction; and an empty `input`
    # still yields its BOS + prefix tokens. Confirm both are intended.
    tokens = tf.concat(
        [input_tokens, instruction_tokens, output_tokens], axis=0
    )

    # Only the output drives the gradient; input and instruction are context.
    input_mask = tf.zeros_like(input_tokens, dtype=tf.bool)
    instruction_mask = tf.zeros_like(instruction_tokens, dtype=tf.bool)
    output_mask = tf.ones_like(output_tokens, dtype=tf.bool)
    mask = tf.concat([input_mask, instruction_mask, output_mask], axis=0)

    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
    # Pad positions never contribute to the loss.
    mask = self._pad_up_to_max_len(mask, False)
    return TrainingInput(  # type: ignore
        input_tokens=tokens,  # type: ignore
        target_mask=mask,  # type: ignore
    )

  def get_train_dataset(self):
    """Builds the shuffled training dataset of TrainingInput elements.

    Examples whose tokenized length exceeds max_seq_len are dropped.
    """
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: (
            self._tokenize_input(x['input']),
            self._tokenize_instruction(x['instruction']),
            self._tokenize_output(x['output']),
        )
    )
    ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z))
    # Padding never shortens a sequence, so anything still longer than
    # max_seq_len here was too long to begin with.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)
    return ds

In [None]:
codealpaca_path = '/home/xinyic/codealpaca/codealpaca_data.tfrecord'
tokenizer = GemmaTokenizer(vocab)
# Examples longer than max_seq_len tokens are dropped by the builder's filter;
# shorter ones are right-padded to exactly max_seq_len.
codealpaca_dataset_builder = CodeAlpacaDatasetBuilder(tokenizer, max_seq_len=1000)
train_ds = codealpaca_dataset_builder.get_train_dataset()
train_ds = train_ds.as_numpy_iterator()

# TFRecordWriter cannot create missing parent directories itself.
tf.io.gfile.makedirs(os.path.dirname(codealpaca_path))
it = 0
with tf.io.TFRecordWriter(codealpaca_path) as writer:
  for train_record in train_ds:
    # Arrays are stored as raw bytes (input_tokens: int32, target_mask: bool);
    # readers must decode them with these same dtypes.
    example = tf.train.Example(features=tf.train.Features(feature={
        'input_tokens': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[train_record.input_tokens.tobytes()])),
        'target_mask': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[train_record.target_mask.tobytes()])),
    }))
    writer.write(example.SerializeToString())
    it += 1
    # Periodic progress instead of one line per record.
    if it % 1000 == 0:
      print(f'it: {it}')
print(f'wrote {it} records to {codealpaca_path}')

## Dolly-15K

In [None]:
class DollyDatasetBuilder(DatasetBuilder):
  """Dataset builder for the Dolly dataset.

  Each example is the tokenized `instruction`, then `context`, then `response`
  field, each wrapped in its prefix/suffix. Only the response tokens are marked
  in `target_mask`.
  """

  # Size of the train split as recorded here; not read by this class.
  N_ITEMS = {DatasetSplit.TRAIN: 15011}

  BUFFER_SIZE_SHUFFLE = 100
  CONTEXT_PREFIX = 'Context: \n'
  CONTEXT_SUFFIX = '\n'
  INSTRUCTION_PREFIX = 'Instruction: \n'
  INSTRUCTION_SUFFIX = '\n'
  RESPONSE_PREFIX = 'Response: \n'
  RESPONSE_SUFFIX = '\n'

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int
  ):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: sequences are right-padded to this length; longer examples
        are dropped by `get_train_dataset`.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load(
            'huggingface:databricks__databricks_dolly_15k', split='train'
        ),
    }
    logging.info(f'dolly size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')
    self._max_seq_len = max_seq_len

  def _tokenize_context(self, example: tf.Tensor):
    """Tokenization function for the context."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.CONTEXT_PREFIX,
        suffix=self.CONTEXT_SUFFIX,
        add_eos=False,
    )

  def _tokenize_response(self, example: tf.Tensor):
    """Tokenization function for the Response."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.RESPONSE_PREFIX,
        suffix=self.RESPONSE_SUFFIX,
        add_eos=False,
    )

  def _tokenize_instruction(self, example: tf.Tensor):
    """Tokenization function for the instruction."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.INSTRUCTION_PREFIX,
        suffix=self.INSTRUCTION_SUFFIX,
        add_eos=False,
    )

  def _to_training_input(
      self,
      instruction_tokens: tf.Tensor,
      context_tokens: tf.Tensor,
      response_tokens: tf.Tensor,
  ):
    """Builds a TrainingInput from instruction, context and response tokens.

    Args:
      instruction_tokens: 1-D int32 instruction tokens; excluded from loss.
      context_tokens: 1-D int32 context tokens; excluded from loss.
      response_tokens: 1-D int32 response tokens; contribute to the loss.

    Returns:
      A TrainingInput right-padded to max_seq_len (longer sequences are left
      as-is).
    """
    # NOTE(review): an empty `context` still yields its BOS + prefix tokens;
    # confirm that is acceptable.
    tokens = tf.concat(
        [instruction_tokens, context_tokens, response_tokens], axis=0
    )

    # Only the response drives the gradient; instruction and context are
    # conditioning.
    context_mask = tf.zeros_like(context_tokens, dtype=tf.bool)
    instruction_mask = tf.zeros_like(instruction_tokens, dtype=tf.bool)
    response_mask = tf.ones_like(response_tokens, dtype=tf.bool)
    mask = tf.concat([instruction_mask, context_mask, response_mask], axis=0)

    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)
    # Pad positions never contribute to the loss.
    mask = self._pad_up_to_max_len(mask, False)
    return TrainingInput(  # type: ignore
        input_tokens=tokens,  # type: ignore
        target_mask=mask,  # type: ignore
    )

  def get_train_dataset(self):
    """Builds the shuffled training dataset of TrainingInput elements.

    Examples whose tokenized length exceeds max_seq_len are dropped.
    """
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: (
            self._tokenize_instruction(x['instruction']),
            self._tokenize_context(x['context']),
            self._tokenize_response(x['response'])
        )
    )
    ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z))
    # Padding never shortens a sequence, so anything still longer than
    # max_seq_len here was too long to begin with.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    return ds

In [None]:
dolly_path = '/home/xinyic/dolly/dolly_data.tfrecord'
tokenizer = GemmaTokenizer(vocab)
# Examples longer than max_seq_len tokens are dropped by the builder's filter;
# shorter ones are right-padded to exactly max_seq_len.
dolly_dataset_builder = DollyDatasetBuilder(tokenizer, max_seq_len=1000)
train_ds = dolly_dataset_builder.get_train_dataset()
train_ds = train_ds.as_numpy_iterator()

# TFRecordWriter cannot create missing parent directories itself.
tf.io.gfile.makedirs(os.path.dirname(dolly_path))
it = 0
with tf.io.TFRecordWriter(dolly_path) as writer:
  for train_record in train_ds:
    # Arrays are stored as raw bytes (input_tokens: int32, target_mask: bool);
    # readers must decode them with these same dtypes.
    example = tf.train.Example(features=tf.train.Features(feature={
        'input_tokens': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[train_record.input_tokens.tobytes()])),
        'target_mask': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[train_record.target_mask.tobytes()])),
    }))
    writer.write(example.SerializeToString())
    it += 1
    # Periodic progress instead of one line per record.
    if it % 1000 == 0:
      print(f'it: {it}')
print(f'wrote {it} records to {dolly_path}')

## Open Web Math

In [None]:
class OpenWebMathDatasetBuilder(DatasetBuilder):
  """Dataset builder for the Open Web Math dataset.

  Open Web Math is a corpus of documents rather than prompt/response pairs.
  Each document's `text` is tokenized as a single segment, and every token
  contributes to the loss because there is no prompt to mask out.

  Previously this class was a copy of the Dolly builder: it loaded the Dolly
  dataset and read Dolly's instruction/context/response fields.
  """

  # Size of the train split as recorded here; not read by this class.
  N_ITEMS = {DatasetSplit.TRAIN: 6315233}

  BUFFER_SIZE_SHUFFLE = 100
  TEXT_PREFIX = ''
  TEXT_SUFFIX = ''

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int
  ):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: sequences are right-padded to this length; longer documents
        are dropped by `get_train_dataset`.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load(
            'huggingface:open_web_math__open_web_math', split='train'
        ),
    }
    logging.info(f'open web math size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')
    self._max_seq_len = max_seq_len

  def _tokenize_text(self, example: tf.Tensor):
    """Tokenization function for a whole document."""
    # add_eos=False mirrors the other builders in this notebook.
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.TEXT_PREFIX,
        suffix=self.TEXT_SUFFIX,
        add_eos=False,
    )

  def _to_training_input(self, text_tokens: tf.Tensor):
    """Builds a TrainingInput in which every document token is a target.

    Args:
      text_tokens: 1-D int32 tokens of one document.

    Returns:
      A TrainingInput right-padded to max_seq_len (longer sequences are left
      as-is).
    """
    tokens = self._pad_up_to_max_len(text_tokens, self._tokenizer.pad_id)
    # Every real token is a target; pad positions are not.
    mask = self._pad_up_to_max_len(
        tf.ones_like(text_tokens, dtype=tf.bool), False)
    return TrainingInput(  # type: ignore
        input_tokens=tokens,  # type: ignore
        target_mask=mask,  # type: ignore
    )

  def get_train_dataset(self):
    """Builds the shuffled training dataset of TrainingInput elements.

    Documents whose tokenized length exceeds max_seq_len are dropped.
    NOTE(review): web documents are often long, so this may discard a large
    share of the corpus; consider truncating or chunking instead.
    """
    # NOTE(review): assumes the HuggingFace open-web-math schema stores the
    # document under `text`; confirm with the inspection cell.
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x: self._tokenize_text(x['text'])
    )
    ds = ds.map(lambda x: self._to_training_input(x))
    # Padding never shortens a sequence, so anything still longer than
    # max_seq_len here was too long to begin with.
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    return ds