<br>

<b>Imports and Constants</b>

In [1]:
# !pip install -q --upgrade tokenizer-viz

# Regular imports (native python and pypi packages)
import os
import sys
import random
import numpy as np
import pandas as pd
from glob import glob
import tensorflow as tf
import sentencepiece as spm
from IPython.display import HTML, display
from tokenizer_viz import TokenVisualization
from tqdm.notebook import tqdm; tqdm.pandas()

# Add project root into path so imports work
PROJECT_DIR = os.path.dirname(os.getcwd())
sys.path.insert(0, PROJECT_DIR) 

# Our project imports
from spearecode.preprocessing_utils import (
    load_from_txt_file, preprocess_shakespeare, save_to_txt_file, print_check_speare, get_spm_assets
)
from spearecode.general_utils import (
    tf_xla_jit, tf_set_memory_growth, seed_it_all, flatten_l_o_l, print_ln
)
from spearecode.filtering_utils import (
    save_ds_version, drop_str_from_col_names, pad_truncate_centered,
    get_metadata_df, check_chunks, tokenize, get_n_tokens,
    get_n_lines, get_n_chars
)
from spearecode.tfrecord_utils import write_tfrecords, load_tfrecord_dataset

# Training-style tag. Its three "_"-separated fields are unpacked, in order,
# into CHUNK_STYLE, TOK_STYLE and DS_VERSION.
TRAIN_STYLE = "rcts_bpe_v4"
CHUNK_STYLE, TOK_STYLE, DS_VERSION = TRAIN_STYLE.split("_")

### DEFINE PATHS --- every path below is derived from PROJECT_DIR --- ###
NBS_PATH = os.path.join(PROJECT_DIR, "nbs")
MODELS_DIR = os.path.join(PROJECT_DIR, "models")
DATA_PATH = os.path.join(PROJECT_DIR, "data")

# Source text file, plus the same path with a "_preprocessed" suffix before the extension
SS_TEXT_PATH = os.path.join(DATA_PATH, "t8.shakespeare.txt")
PREPROCESSED_FULL_TEXT_PATH = SS_TEXT_PATH.replace(".txt", "_preprocessed.txt")

# Dataset artefacts (CSVs, metadata, TFRecord shards for this training style)
DATASETS_PATH = os.path.join(DATA_PATH, "datasets")
META_DIR = os.path.join(DATASETS_PATH, "meta")
TFRECORD_DIR = os.path.join(DATASETS_PATH, "tfrecords", TRAIN_STYLE)

# Specific Paths (the data CSV and the metadata CSV share one file name)
_DS_CSV_NAME = f"{DS_VERSION}_{CHUNK_STYLE}_{TOK_STYLE}.csv"
SPM_MODEL_PATH = os.path.join(MODELS_DIR, f"spearecode_{TOK_STYLE}")
DATA_CSV_PATH = os.path.join(DATASETS_PATH, _DS_CSV_NAME)
META_CSV_PATH = os.path.join(META_DIR, _DS_CSV_NAME)

<br>

<b>Instantiate expected tools for the rest of the notebook</b>

In [2]:
# Tokenizer assets used throughout the rest of the notebook
sp, encoder, decoder = get_spm_assets(SPM_MODEL_PATH)

# Mask token as a string and as encoded by `encoder`
MASK_TOKEN_STR = "[MASK]"
MASK_TOKEN_INT = encoder(MASK_TOKEN_STR)

viz_tool = TokenVisualization(
    encoder,
    decoder,
    background_color="#FBFBFB",
    transparency=0.4,
)

# Chunked dataset and its per-chunk metadata
train_df = pd.read_csv(DATA_CSV_PATH)
meta_df = pd.read_csv(META_CSV_PATH)
display(train_df, meta_df)

# Render one randomly sampled chunk with the token visualizer
sample_text = train_df["content"].sample(1).values[0]
_ = viz_tool.visualize(sample_text, display_inline=True)

Unnamed: 0,content,token_content,n_tokens,n_chars,n_lines,valid_chunk
0,1\n From fairest creatures we desire increase...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",192,643,15,True
1,2\n When forty winters shall besiege thy brow...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",188,662,15,True
2,3\n Look in thy glass and tell the face thou ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",183,643,15,True
3,"4\n Unthrifty loveliness why dost thou spend,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",183,619,15,True
4,5\n Those hours that with gentle work did fra...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",176,652,15,True
...,...,...,...,...,...,...
7694,"'""Lo, this device was sent me from a nun,\n O...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",282,944,23,True
7695,"'""How mighty then you are, O hear me tell!\n ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",298,983,23,True
7696,"'""Now all these hearts that do on mine depend,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",283,977,23,True
7697,"'For lo, his passion, but an art of craft,\n ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",292,965,23,True


Unnamed: 0,n_tokens,n_chars,n_lines,valid_chunk
0,192,643,15,True
1,188,662,15,True
2,183,643,15,True
3,183,619,15,True
4,176,652,15,True
...,...,...,...,...
7694,282,944,23,True
7695,298,983,23,True
7696,283,977,23,True
7697,292,965,23,True


<br>

<b>Create Datasets</b>

In [3]:
# Collect every TFRecord shard for this training style.
# glob() returns paths in filesystem-dependent order, so sort first to get a
# stable starting order before shuffling.
ALL_TFRECORDS = sorted(glob(os.path.join(TFRECORD_DIR, "*.tfrec")))
if not ALL_TFRECORDS:
    raise FileNotFoundError(f"No '*.tfrec' files found in {TFRECORD_DIR}")

# Shuffle with a dedicated, seeded RNG so the train/val split is identical on
# every kernel restart (otherwise validation shards can leak into training
# when work is resumed in a new session).
SPLIT_SEED = 42
random.Random(SPLIT_SEED).shuffle(ALL_TFRECORDS)
N_TOTAL_RECS = len(ALL_TFRECORDS)


EX_PER_TFREC = 100
VAL_PCT = 0.125
N_VAL_RECS = int(VAL_PCT*N_TOTAL_RECS)

# The first N_VAL_RECS shuffled shards are held out for validation; the rest are for training
VAL_TFRECORDS = ALL_TFRECORDS[:N_VAL_RECS]
TRAIN_TFRECORDS = ALL_TFRECORDS[N_VAL_RECS:]

train_ds = load_tfrecord_dataset(TRAIN_TFRECORDS)
val_ds = load_tfrecord_dataset(VAL_TFRECORDS)

(train_ds, val_ds)

(<MapDataset element_spec=TensorSpec(shape=(384,), dtype=tf.int64, name=None)>,
 <MapDataset element_spec=TensorSpec(shape=(384,), dtype=tf.int64, name=None)>)

<br>

<b>Training Configuration</b>

In [15]:
# Hyper-parameters consumed by the tf.data pipeline
train_config = {
    # Batching / shuffling
    "batch_size": 32,
    "shuffle_buffer": 512,
    # Window lengths for the encoder and decoder inputs
    "encoder_context_len": 128,
    "decoder_context_len": 64,
    # Tokenizer-derived values
    "mask_token_id": sp.encode("[MASK]")[0],
    "vocab_size": sp.vocab_size(),
}

<br>

<b>TF.Data Pipeline</b>

In [16]:
from typing import Tuple

# --- Pipeline Steps ---
# 
# 1. Shuffle examples (shuffle_buffer)
# 2. Batch examples (batch_size, drop_remainder, AUTOTUNE)
# 3. Split sequence into encoder/decoder inputs [`split_on_pivot`]
# 4. Split encoder inputs into:
#       --> 'inputs' (masked sequence)
#       --> 'labels' (unaltered sequence)
#       --> 'weights' (sample weights; 1.0 for masked tokens and 0.0 for non-mask tokens)
# 5. Split decoder inputs into:
#       --> 'inputs' (unaltered sequence)
#       --> 'labels' (sequence shifted by 1)

def split_on_pivot(tokens: tf.Tensor,
                   encoder_context_len: int = 128,
                   decoder_context_len: int = 64,
                   seq_len: int = 384) -> Tuple[tf.Tensor, tf.Tensor]:
    """ Sample encoder and decoder windows around a random pivot in each token sequence.

    For every sequence in the batch a pivot index is drawn uniformly at random. The encoder
    window is the `encoder_context_len` tokens immediately before the pivot; the decoder
    window is the `decoder_context_len + 1` tokens immediately after the pivot. The token at
    the pivot index itself belongs to neither window.

    Args:
        tokens: A batch of token sequences with shape (batch_size, seq_len).
        encoder_context_len: The number of tokens to be sampled for the encoder input sequences.
        decoder_context_len: The number of tokens the decoder will consume. One extra token is
                             sampled so the window can later be split into inputs and
                             shifted-by-one labels.
        seq_len: The total length of each token sequence in the batch.

    Returns:
        encoder_inputs: An int32 tensor with shape (batch_size, encoder_context_len) containing
                        the sampled encoder input sequences.
        decoder_inputs: An int32 tensor with shape (batch_size, decoder_context_len + 1)
                        containing the sampled decoder input sequences.

    Raises:
        AssertionError: If the encoder window, the pivot token and the (extended) decoder
                        window do not fit in seq_len, i.e. if
                        encoder_context_len + decoder_context_len + 2 > seq_len.
    """

    # Add one to our decoder context length as we need it for AR head
    decoder_span = decoder_context_len + 1

    # Valid pivots lie in [min_pivot, max_pivot): the encoder window must not run off the
    # start of the sequence and the decoder window must not run off the end.
    min_pivot = encoder_context_len
    max_pivot = seq_len - decoder_span

    # tf.random.uniform requires minval < maxval for integer dtypes, so the bound is strict.
    assert min_pivot < max_pivot, (
        f"encoder_context_len ({encoder_context_len}) + decoder_context_len "
        f"({decoder_context_len}) + 2 must be <= seq_len ({seq_len})"
    )
    batch_size = tf.shape(tokens)[0]

    # Sample one random pivot index per sequence -> shape (batch_size, 1)
    pivot_indices = tf.random.uniform((batch_size, 1), minval=min_pivot, maxval=max_pivot, dtype=tf.int32)

    # Offsets relative to the pivot: [-enc_len, -1] before it and [1, dec_span] after it
    offsets_before = tf.range(-encoder_context_len, 0, dtype=tf.int32)
    offsets_after = tf.range(1, decoder_span + 1, dtype=tf.int32)

    # Broadcasting (batch_size, 1) + (window,) gives absolute indices of shape (batch_size, window)
    indices_before = pivot_indices + offsets_before
    indices_after = pivot_indices + offsets_after

    # Gather per-sequence windows; batch_dims=1 pairs each row of indices with its own sequence
    encoder_inputs = tf.gather(tokens, indices_before, axis=1, batch_dims=1)
    decoder_inputs = tf.gather(tokens, indices_after, axis=1, batch_dims=1)

    # Make the output shapes explicit (no change to the values)
    encoder_inputs = tf.reshape(encoder_inputs, (batch_size, encoder_context_len))
    decoder_inputs = tf.reshape(decoder_inputs, (batch_size, decoder_span))

    return tf.cast(encoder_inputs, tf.int32), tf.cast(decoder_inputs, tf.int32)

def mask_sequence(sequence, vocab_size, mask_token_id, pct_to_mask=0.15, pct_to_random=0.1, pct_to_keep=0.1):
    """ Corrupt a sequence of tokens for masked-token prediction.

    Each position is selected independently with probability `pct_to_mask`. Of the selected
    positions, a fraction `pct_to_random` is replaced with a uniformly random token id,
    a further fraction `pct_to_keep` is left unchanged, and the remainder is replaced with
    `mask_token_id`.

    Args:
        sequence: Integer tensor of token ids (any shape).
        vocab_size: Exclusive upper bound for the randomly drawn replacement token ids.
        mask_token_id: Id of the mask token (Python int or integer tensor broadcastable
                       against `sequence`).
        pct_to_mask: Probability that a position is selected.
        pct_to_random: Fraction of selected positions replaced with a random token.
        pct_to_keep: Fraction of selected positions that keep their original token.

    Returns:
        masked_sequence: `sequence` with the corruption applied.
        sequence: The unaltered input, to be used as labels.
        sample_weights: float32 tensor, 1.0 at every selected position and 0.0 elsewhere.
    """

    # A single uniform draw per token decides both selection and the kind of corruption
    masking_prob = tf.random.uniform(shape=tf.shape(sequence), minval=0, maxval=1)

    # Split [0, pct_to_mask) into three DISJOINT intervals:
    #   [0, random_cutoff)            -> random token
    #   [random_cutoff, keep_cutoff)  -> keep original token
    #   [keep_cutoff, pct_to_mask)    -> mask token
    # (Using the same lower bound for the random and keep intervals would let the keep
    #  step overwrite every random replacement.)
    random_cutoff = pct_to_mask * pct_to_random
    keep_cutoff = pct_to_mask * (pct_to_random + pct_to_keep)

    selected = masking_prob < pct_to_mask
    use_random = masking_prob < random_cutoff
    keep_original = (masking_prob >= random_cutoff) & (masking_prob < keep_cutoff)

    # Candidate replacements, matched to the dtype of the incoming sequence
    mask_tokens = tf.ones_like(sequence) * tf.cast(mask_token_id, sequence.dtype)
    random_tokens = tf.cast(
        tf.random.uniform(shape=tf.shape(sequence), minval=0, maxval=vocab_size, dtype=tf.int32),
        sequence.dtype,
    )

    # Start with every selected position masked, then overwrite the random / keep subsets
    masked_sequence = tf.where(selected, mask_tokens, sequence)
    masked_sequence = tf.where(use_random, random_tokens, masked_sequence)
    masked_sequence = tf.where(keep_original, sequence, masked_sequence)

    # Only selected positions contribute to the loss
    sample_weights = tf.cast(selected, tf.float32)

    return masked_sequence, sequence, sample_weights


def shift_and_split_decoder_inputs(x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """ Split a batch of sequences into inputs and labels shifted by one position.

    Args:
        x: Tensor whose second dimension has length `n + 1`; columns are gathered along
           the last axis.

    Returns:
        decoder_inputs: Columns 0 .. n-1 of `x` (everything but the last column).
        decoder_outputs: Columns 1 .. n of `x` (everything but the first column).
    """
    # Number of columns in each output: one fewer than the input has
    n_steps = tf.shape(x)[1]-1

    # Column positions of the inputs and of the shifted-by-one labels
    head_cols = tf.range(0, n_steps, dtype=tf.int32)
    tail_cols = tf.range(1, n_steps+1, dtype=tf.int32)

    return tf.gather(x, head_cols, axis=-1), tf.gather(x, tail_cols, axis=-1)
    
def transform_sequence(sequence, vocab_size, mask_token_id, encoder_context_len=128, decoder_context_len=64):
    """ Turn a batch of token sequences into (inputs, labels, sample_weights) for training.

    Args:
        sequence: Batch of token sequences, passed to `split_on_pivot`.
        vocab_size: Vocabulary size, used when drawing random replacement tokens.
        mask_token_id: Id of the mask token used to corrupt the encoder inputs.
        encoder_context_len: Number of tokens in the encoder window.
        decoder_context_len: Number of tokens in the decoder inputs/labels.

    Returns:
        A 3-tuple of (encoder, decoder) pairs:
            inputs:  (masked encoder sequence, decoder inputs)
            labels:  (unaltered encoder sequence, decoder inputs shifted by one)
            weights: (1.0 at masked encoder positions else 0.0, all-ones for the decoder)
    """
    encoder_inputs, decoder_inputs = split_on_pivot(
        sequence,
        encoder_context_len=encoder_context_len,
        decoder_context_len=decoder_context_len,
    )

    # Encoder transform (uses the caller-supplied mask token id, not a notebook global)
    encoder_inputs, encoder_labels, encoder_sample_wts = mask_sequence(
        encoder_inputs, vocab_size, tf.cast(mask_token_id, tf.int32)
    )

    # Decoder transform
    decoder_inputs, decoder_labels = shift_and_split_decoder_inputs(decoder_inputs)
    decoder_sample_wts = tf.ones_like(decoder_labels, dtype=tf.float32)

    _inputs = (encoder_inputs, decoder_inputs)
    _labels = (encoder_labels, decoder_labels)
    _sample_wts = (encoder_sample_wts, decoder_sample_wts)
    return _inputs, _labels, _sample_wts
    
    
def _to_model_batches(ds):
    """ Shuffle, batch, transform and prefetch a dataset using the values in `train_config`. """
    def _transform(batch):
        return transform_sequence(batch, train_config["vocab_size"], train_config["mask_token_id"])

    ds = ds.shuffle(train_config["shuffle_buffer"])
    ds = ds.batch(train_config["batch_size"], drop_remainder=True)
    ds = ds.map(_transform, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)


# NOTE: both names are rebound to their own transformed versions, so re-running
# this cell without re-creating the datasets applies the pipeline a second time.
train_ds = _to_model_batches(train_ds)
val_ds = _to_model_batches(val_ds)


train_ds, val_ds

(<PrefetchDataset element_spec=((TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.float32, name=None), TensorSpec(shape=(32, 64), dtype=tf.float32, name=None)))>,
 <PrefetchDataset element_spec=((TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.float32, name=None), TensorSpec(shape=(32, 64), dtype=tf.float32, name=None)))>)

In [17]:
# Pull a single batch from the validation pipeline for inspection.
# Each of the three values is an (encoder, decoder) pair of tensors.
_inputs, _labels, _wts = next(iter(val_ds))

In [25]:
# Visualize the first example of the batch: encoder then decoder inputs,
# followed by encoder then decoder labels.
for _pair in (_inputs, _labels):
    for _tensor in _pair:
        _ = viz_tool(_tensor[0].numpy().tolist(), display_inline=True)

In [34]:
from spearecode.models.cllm_backbone import CLLM



# Keyword arguments intended for the CLLM constructor (its instantiation is
# still commented out below).
# NOTE(review): vocab_size, n_heads and n_layers are set to -1, which looks like
# a placeholder -- fill in real values before building the model (confirm).
model_config = dict(
    vocab_size=-1,
    context_length=128,
    embedding_size=256,
    n_heads=-1,
    n_layers=-1,
    use_bias=True,
    ffn_act="gelu",
    expansion_factor=6,
    dropout_rate=0.3,
)

# cllm = CLLM(**model_config)
# cllm.summary()