In [None]:
import os

# Model variant (Colab form field).
# NOTE(review): VARIANT is only used by the commented-out kagglehub download
# below; `ckpt_path` names a 2b_it checkpoint regardless of this value.
VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
# weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
# Orbax checkpoint directory (not read by any of the visible cells).
ckpt_path = '/g_mini/2b_it_v1p1_orbax/1'
# SentencePiece model file, loaded in the tokenizer cell.
# NOTE(review): unlike `ckpt_path`, this path has no leading '/', so it is
# resolved relative to the current working directory — confirm intended.
vocab_path = 'home/mriviere/g_mini/tokenizer/gemini_bpe_256k_v5_no_tags_cleared_v1.model'

In [None]:
# @title Python imports

import enum
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import optax

# We will use tensorflow to handle the dataset
import tensorflow as tf
import tensorflow_datasets as tfds

# Finally, we import Gemma.
from colabtools import adhoc_import
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
from sentencepiece.src.python import sentencepiece_processor as spm

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

# (token_id, piece) pairs for the whole vocabulary. The loop variable is
# deliberately not called `id`: a top-level loop would leave the builtin
# `id()` shadowed for the rest of the notebook.
vocab_list = [
    (token_id, vocab.IdToPiece(token_id))
    for token_id in range(vocab.GetPieceSize())
]

# Answer letters to locate in the vocabulary (case-sensitive).
letters = ['A', 'B', 'C', 'D']

# Maps token id -> letter for every piece whose FIRST alphabetic character is
# one of `letters`. Pieces without any alphabetic character are skipped
# (`next(..., None)` returns None, which is never in `letters`), so no
# try/except is needed and unexpected errors are no longer silently swallowed.
# NOTE(review): this also matches whole words such as 'Cat' or 'Apple'
# (first letter 'C' / 'A') — confirm that is intended.
res_dict = {}
for token_id, piece in vocab_list:
  first_alpha = next((ch for ch in piece if ch.isalpha()), None)
  if first_alpha in letters:
    res_dict[token_id] = first_alpha


In [None]:
class GemmaTokenizer:
  """Custom wrapper around a SentencePieceProcessor for tensorflow.

  Prepends BOS (and optionally appends EOS) to the encoded text, and exposes
  `tokenize_tf_op` so tokenization can run inside a `tf.data` pipeline.
  """

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    """Wraps an already-loaded SentencePiece processor."""
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad id."""
    return self._spm_processor.pad_id()

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_eos: bool = True) -> jax.Array:
    """Tokenization function.

    Args:
      example: input string to tokenize.
      prefix: prefix to add to the input string.
      suffix: suffix to add to the input string.
      add_eos: if True, add an end of sentence token at the end of the output
        sequence.

    Returns:
      int32 array of token ids: BOS, then the ids of
      `prefix + example + suffix`, then EOS if `add_eos` is set.
    """
    # `prefix`, `example` and `suffix` are concatenated directly, so they must
    # share a type (all str or all bytes). NOTE(review): when called through
    # `tokenize_tf_op`, the tf.string inputs presumably arrive as bytes.
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_eos: bool = True) -> tf.Tensor:
    """TensorFlow operator for the tokenize function.

    Runs `tokenize` as a Python function via `tf.numpy_function`, so it can be
    used inside `tf.data.Dataset.map`. Returns a 1-D int32 tensor.
    """
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    # numpy_function drops static shape information; restore the known rank.
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of token ids back to a string.

    Ids -> text is a decode step, hence `DecodeIds`.
    """
    return self._spm_processor.DecodeIds(tokens.tolist())

In [None]:
# Wrap the loaded SentencePiece model for use by the dataset builders.
tokenizer = GemmaTokenizer(spm_processor=vocab)


In [None]:
"""Base class for dataset builders."""

import chex
import jax
import tensorflow as tf

@chex.dataclass(frozen=True)
class TrainingInput:
  """One training example: model input tokens plus a per-token loss mask.

  NOTE(review): although annotated as jax.Array, the dataset builders in this
  notebook populate these fields with tf.Tensor values inside tf.data.
  """
  # Input tokens provided to model
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation (True = contributes; padding positions are False).
  target_mask: jax.Array


class DatasetBuilder:
  """Interface shared by the dataset builders in this notebook.

  Stores a tokenizer and a target sequence length; subclasses implement the
  actual training and validation pipelines.
  """

  def __init__(self, tokenizer: GemmaTokenizer,
               max_seq_len: int):
    """Stores the tokenizer and the target sequence length.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer, self._max_seq_len = tokenizer, max_seq_len

  def _pad_up_to_max_len(
      self, input_tensor: tf.Tensor, pad_value: int | bool
  ) -> tf.Tensor:
    """Right-pads a 1-D tensor with `pad_value` up to max_seq_len entries.

    Tensors already at least max_seq_len long are returned unchanged; they
    are not truncated.
    """
    current_len = tf.shape(input_tensor)[0]
    missing = tf.maximum(0, self._max_seq_len - current_len)
    # A single [before, after] pair: nothing before, `missing` values after.
    paddings = [[0, missing]]
    return tf.pad(input_tensor, paddings, mode='CONSTANT',
                  constant_values=pad_value)

  def get_train_dataset(self):
    """Subclasses must override this to return the training dataset."""
    raise NotImplementedError()

  def get_validation_dataset(self, batch_size: int):
    """Subclasses must override this to return the validation dataset."""
    raise NotImplementedError()


In [None]:
"""Dataset builder for the Wikipedia datasets."""

import enum as Enum
import random

from absl import logging
import jax.dlpack
import tensorflow as tf
import tensorflow_datasets as tfds


# Common prefix of the saved per-topic datasets; `WikipediaDatasetBuilder`
# appends '_topic_<index>' to it.
# NOTE(review): no leading '/', so this resolves relative to the working
# directory — confirm that is intended.
topic_wise_save_path = 'home/shivguptashi/wikidata/topic_wise_tfds'

class DatasetSplit(Enum.Enum):
  """Splits available to the Wikipedia builder (only TRAIN exists)."""
  TRAIN = 'train'


class WikipediaDatasetBuilder(DatasetBuilder):
  """Dataset builder for the topic-wise Wikipedia datasets.

  Each topic shard is a saved `tf.data` dataset whose elements have three
  components: the first is used as the article title, the second as the
  article text, and the third is ignored by this builder.
  """

  # Not read by this class.
  N_ITEMS = {DatasetSplit.TRAIN: 2914896}

  # Small shuffle buffer; a larger one (e.g. 10_000) shuffles more thoroughly
  # at the cost of memory.
  BUFFER_SIZE_SHUFFLE = 100
  TEXT_PREFIX = 'Text: \n'
  TEXT_SUFFIX = '\n'
  TITLE_PREFIX = 'Title: \n'
  TITLE_SUFFIX = '\n'

  def __init__(
      self, tokenizer: GemmaTokenizer, max_seq_len: int, topic_index: int
  ):
    """Constructor.

    Loads the saved dataset for `topic_index`, then prints its cardinality
    and the first component of its first two elements as a sanity check.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
      topic_index: topic shard to load; '_topic_<topic_index>' is appended to
        `topic_wise_save_path`.
    """
    super().__init__(tokenizer, max_seq_len)
    self._base_data = {
        DatasetSplit.TRAIN: tf.data.Dataset.load(
            topic_wise_save_path + '_topic_' + str(topic_index)
        ),
    }
    print(f'Topic {topic_index} size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')

    sample_ds = self._base_data[DatasetSplit.TRAIN].take(2)
    for x in sample_ds:
      print(x[0])

  def _tokenize_title(self, example: tf.Tensor):
    """Tokenizes an article title wrapped in TITLE_PREFIX/SUFFIX (no EOS)."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.TITLE_PREFIX,
        suffix=self.TITLE_SUFFIX,
        add_eos=False,
    )

  def _tokenize_text(self, example: tf.Tensor):
    """Tokenizes article text wrapped in TEXT_PREFIX/SUFFIX, ending with EOS."""
    return self._tokenizer.tokenize_tf_op(
        example,
        prefix=self.TEXT_PREFIX,
        suffix=self.TEXT_SUFFIX,
        add_eos=True,
    )

  def _to_training_input(
      self,
      title_tokens: jax.Array,
      text_tokens: jax.Array,
  ):
    """Builds a padded TrainingInput from title tokens and text tokens.

    Every title and text token contributes to the loss; only the padding
    added here is masked out. Sequences longer than max_seq_len are neither
    padded nor truncated here; the dataset methods filter them out.
    """
    # The input sequence fed to the model is simply the concatenation of the
    # title and the text.
    # NOTE(review): `GemmaTokenizer.tokenize` prepends BOS to both pieces, so
    # a second BOS appears where the text starts — confirm this is intended.
    tokens = tf.concat([title_tokens, text_tokens], axis=0)

    # The title is NOT masked out (unlike a prompt/response setup): the mask
    # is all True for both parts.
    title_mask = tf.ones_like(title_tokens, dtype=tf.bool)
    text_mask = tf.ones_like(text_tokens, dtype=tf.bool)
    mask = tf.concat([title_mask, text_mask], axis=0)

    # Pad short sequences up to max_seq_len with the pad token...
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # ...and keep the padding out of the loss.
    mask = self._pad_up_to_max_len(mask, False)
    return TrainingInput(  # type: ignore
        input_tokens=tokens,  # type: ignore
        target_mask=mask,  # type: ignore
    )

  def _build_training_inputs(self, ds: tf.data.Dataset) -> tf.data.Dataset:
    """Turns (title, text, _) elements into TrainingInputs of max_seq_len.

    Shorter sequences are padded; longer ones are dropped by the filter, so
    every surviving element has exactly max_seq_len tokens. Shared by the
    train and validation pipelines so they cannot diverge.
    """
    # tf.data unpacks each 3-component element into three positional args.
    ds = ds.map(
        lambda title, text, _: (
            self._tokenize_title(title),
            self._tokenize_text(text),
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.map(
        lambda title_tokens, text_tokens: self._to_training_input(
            title_tokens, text_tokens
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return ds.filter(
        lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len
    )

  def get_train_dataset(self):
    """Build the training dataset (shuffled, unbatched, not repeated)."""
    ds = self._build_training_inputs(self._base_data[DatasetSplit.TRAIN])
    return ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset: same steps as training, no shuffling.

    NOTE(review): there is no validation split, so this reads the TRAIN split
    (it overlaps the training data). `batch_size` is currently ignored and
    the result is unbatched.
    """
    return self._build_training_inputs(self._base_data[DatasetSplit.TRAIN])


In [None]:
# Export each topic's tokenized training examples to its own TFRecord file.
# Each tf.train.Example stores the raw bytes (`ndarray.tobytes()`) of
# `input_tokens` (int32) and `target_mask` (bool); readers must decode them
# with those same dtypes (e.g. `tf.io.decode_raw` with a matching out_type).
wiki_tokenized_path = 'home/shivguptashi/open_orca/wiki_tokenized'
TOPIC_INDICES = range(54, 64)
MAX_SEQ_LEN = 1000
# Print progress every LOG_EVERY records; printing every record floods the
# notebook output with millions of lines and slows the export down.
LOG_EVERY = 10_000


def _bytes_feature(array) -> tf.train.Feature:
  """Wraps the raw bytes of a numpy array in a single-value BytesList."""
  return tf.train.Feature(
      bytes_list=tf.train.BytesList(value=[array.tobytes()]))


def _serialize_training_input(record) -> bytes:
  """Serializes one numpy TrainingInput into a tf.train.Example proto."""
  features = tf.train.Features(feature={
      'input_tokens': _bytes_feature(record.input_tokens),
      'target_mask': _bytes_feature(record.target_mask),
  })
  return tf.train.Example(features=features).SerializeToString()


tokenizer = GemmaTokenizer(vocab)
for topic in TOPIC_INDICES:
  wikipedia_dataset_builder = WikipediaDatasetBuilder(
      tokenizer, max_seq_len=MAX_SEQ_LEN, topic_index=topic)
  train_ds = wikipedia_dataset_builder.get_train_dataset().as_numpy_iterator()
  cur_tokenized_path = wiki_tokenized_path + '_topic_' + str(topic) + '.tfrecord'
  it = 0  # number of records written for this topic
  with tf.io.TFRecordWriter(cur_tokenized_path) as writer:
    for it, train_record in enumerate(train_ds, start=1):
      writer.write(_serialize_training_input(train_record))
      if it % LOG_EVERY == 0:
        print(f'topic {topic}: {it} records written')
  print(f'topic {topic}: done, {it} records written to {cur_tokenized_path}')