<br>

<b>Imports and Constants</b>

In [1]:
# !pip install -q --upgrade tokenizer-viz

# Regular imports (native python and pypi packages)
import os
import sys
import random
import numpy as np
import pandas as pd
from glob import glob
import tensorflow as tf
import sentencepiece as spm
from IPython.display import HTML, display
from tokenizer_viz import TokenVisualization
from tqdm.notebook import tqdm; tqdm.pandas()

# Put the parent of the current working directory (the project root when the
# notebook is launched from its own folder) at the front of the import path.
PROJECT_DIR = os.path.split(os.getcwd())[0]
sys.path[:0] = [PROJECT_DIR]

# Our project imports
from spearecode.utils.preprocessing_utils import (
    load_from_txt_file, preprocess_shakespeare, save_to_txt_file, print_check_speare, get_spm_assets
)
from spearecode.utils.general_utils import (
    tf_xla_jit, tf_set_memory_growth, seed_it_all, flatten_l_o_l, print_ln
)
from spearecode.utils.filtering_utils import (
    save_ds_version, drop_str_from_col_names, pad_truncate_centered,
    get_metadata_df, check_chunks, tokenize, get_n_tokens,
    get_n_lines, get_n_chars
)
from spearecode.utils.tfrecord_utils import write_tfrecords, load_tfrecord_dataset

# Dataset identifier, split on "_" into chunk style, tokenizer style and version.
TRAIN_STYLE = "rcts_bpe_v4"
CHUNK_STYLE, TOK_STYLE, DS_VERSION = TRAIN_STYLE.split("_")

### DEFINE PATHS --- [PROJECT_DIR="/home/paperspace/home/spearecode"] --- ###
# Top-level project folders
NBS_PATH = os.path.join(PROJECT_DIR, "nbs")
DATA_PATH = os.path.join(PROJECT_DIR, "data")
MODELS_DIR = os.path.join(PROJECT_DIR, "models")

# Raw corpus and its preprocessed sibling (same name with a "_preprocessed" suffix)
SS_TEXT_PATH = os.path.join(DATA_PATH, "t8.shakespeare.txt")
PREPROCESSED_FULL_TEXT_PATH = SS_TEXT_PATH.replace(".txt", "_preprocessed.txt")

# Dataset folders
DATASETS_PATH = os.path.join(DATA_PATH, "datasets")
META_DIR = os.path.join(DATASETS_PATH, "meta")
TFRECORD_DIR = os.path.join(DATASETS_PATH, "tfrecords", TRAIN_STYLE)

# Specific Paths
# The data CSV and the metadata CSV share one file name and differ only by folder.
_DS_CSV_NAME = f"{DS_VERSION}_{CHUNK_STYLE}_{TOK_STYLE}.csv"
SPM_MODEL_PATH = os.path.join(MODELS_DIR, f"spearecode_{TOK_STYLE}")
DATA_CSV_PATH = os.path.join(DATASETS_PATH, _DS_CSV_NAME)
META_CSV_PATH = os.path.join(META_DIR, _DS_CSV_NAME)

<br>

<b>Instantiate expected tools for the rest of the notebook</b>

In [2]:
# Tokenizer assets: the SentencePiece processor plus encode/decode callables.
sp, encoder, decoder = get_spm_assets(SPM_MODEL_PATH)

# NOTE(review): MASK_TOKEN_INT is whatever `encoder` returns for the mask
# string (possibly a list of ids rather than a scalar) -- confirm.
MASK_TOKEN_STR = "[MASK]"
MASK_TOKEN_INT = encoder(MASK_TOKEN_STR)

viz_tool = TokenVisualization(encoder, decoder, background_color="#FBFBFB", transparency=0.4)

# Chunked training text and its per-chunk metadata
train_df, meta_df = pd.read_csv(DATA_CSV_PATH), pd.read_csv(META_CSV_PATH)
display(train_df, meta_df)

# Sanity check: render the tokenization of one randomly drawn chunk
sample_text = train_df["content"].sample(1).iloc[0]
_ = viz_tool.visualize(sample_text, display_inline=True)

Unnamed: 0,content,token_content,n_tokens,n_chars,n_lines,valid_chunk
0,1\n From fairest creatures we desire increase...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",192,643,15,True
1,2\n When forty winters shall besiege thy brow...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",188,662,15,True
2,3\n Look in thy glass and tell the face thou ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",183,643,15,True
3,"4\n Unthrifty loveliness why dost thou spend,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",183,619,15,True
4,5\n Those hours that with gentle work did fra...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",176,652,15,True
...,...,...,...,...,...,...
7694,"'""Lo, this device was sent me from a nun,\n O...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",282,944,23,True
7695,"'""How mighty then you are, O hear me tell!\n ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",298,983,23,True
7696,"'""Now all these hearts that do on mine depend,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",283,977,23,True
7697,"'For lo, his passion, but an art of craft,\n ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",292,965,23,True


Unnamed: 0,n_tokens,n_chars,n_lines,valid_chunk
0,192,643,15,True
1,188,662,15,True
2,183,643,15,True
3,183,619,15,True
4,176,652,15,True
...,...,...,...,...
7694,282,944,23,True
7695,298,983,23,True
7696,283,977,23,True
7697,292,965,23,True


<br>

<b>Create Datasets</b>

In [3]:
# Collect every TFRecord shard for this dataset flavour.
# The list is sorted before shuffling because glob order depends on the
# filesystem, and the shuffle uses its own seeded RNG so the train/validation
# split is identical on every kernel restart. An unseeded shuffle would move
# validation shards into the training set when resuming from a checkpoint.
SPLIT_SEED = 42
ALL_TFRECORDS = sorted(glob(os.path.join(TFRECORD_DIR, "*.tfrec")))
if not ALL_TFRECORDS:
    raise FileNotFoundError(f"No '*.tfrec' files found in {TFRECORD_DIR}")
random.Random(SPLIT_SEED).shuffle(ALL_TFRECORDS)
N_TOTAL_RECS = len(ALL_TFRECORDS)

# presumably the number of examples written into each shard -- confirm against the TFRecord writer
EX_PER_TFREC = 100
# Fraction of shards (not examples) held out for validation
VAL_PCT = 0.125
N_VAL_RECS = int(VAL_PCT*N_TOTAL_RECS)

# The first N_VAL_RECS shuffled shards go to validation, the rest to training
VAL_TFRECORDS = ALL_TFRECORDS[:N_VAL_RECS]
TRAIN_TFRECORDS = ALL_TFRECORDS[N_VAL_RECS:]

train_ds = load_tfrecord_dataset(TRAIN_TFRECORDS)
val_ds = load_tfrecord_dataset(VAL_TFRECORDS)

(train_ds, val_ds)

(<MapDataset element_spec=TensorSpec(shape=(384,), dtype=tf.int64, name=None)>,
 <MapDataset element_spec=TensorSpec(shape=(384,), dtype=tf.int64, name=None)>)

<br>

<b>Training Configuration</b>

In [4]:
# Hyper-parameters shared by the input pipeline and the training loop.
train_config = {
    "batch_size": 32,
    "shuffle_buffer": 512,
    "encoder_context_len": 128,
    "decoder_context_len": 64,
    # first id produced when the tokenizer encodes the mask string
    "mask_token_id": sp.encode("[MASK]")[0],
    "vocab_size": sp.vocab_size(),
    "n_epochs": 100,
}

<br>

<b>TF.Data Pipeline</b>

In [5]:
from typing import Tuple

# --- Pipeline Steps ---
# 
# 1. Shuffle examples (shuffle_buffer)
# 2. Batch examples (batch_size, drop_remainder, AUTOTUNE)
# 3. Split sequence into encoder/decoder inputs [`split_on_pivot`]
# 4. Split encoder inputs into:
#       --> 'inputs' (masked sequence)
#       --> 'labels' (unaltered sequence)
#       --> 'weights' (sample weights; 1.0 for masked tokens and 0.0 for non-mask tokens)
# 5. Split decoder inputs into:
#       --> 'inputs' (unaltered sequence)
#       --> 'labels' (sequence shifted by 1)

def split_on_pivot(tokens: tf.Tensor, 
                   encoder_context_len: int = 128, 
                   decoder_context_len: int = 64, 
                   seq_len: int = 384) -> Tuple[tf.Tensor, tf.Tensor]:
    """ Sample encoder and decoder windows around a random pivot in each sequence.

    For every row a pivot index is drawn uniformly. The encoder window is the
    `encoder_context_len` tokens immediately before the pivot and the decoder
    window is the `decoder_context_len + 1` tokens immediately after it. The
    token at the pivot position itself belongs to neither window.

    Args:
        tokens: A batch of token sequences with shape (batch_size, seq_len).
        encoder_context_len: The number of tokens sampled for each encoder window.
        decoder_context_len: The decoder context length. One extra token is
            sampled so the result can later be split into inputs and
            shifted-by-one labels.
        seq_len: The total length of each token sequence in the batch.

    Returns:
        encoder_inputs: int32 tensor with shape (batch_size, encoder_context_len).
        decoder_inputs: int32 tensor with shape (batch_size, decoder_context_len + 1).

    Raises:
        AssertionError: If the requested windows leave no valid pivot position,
            i.e. encoder_context_len + decoder_context_len + 1 is not strictly
            smaller than 2 * (seq_len // 2).
    """

    # One extra decoder token is needed for the autoregressive head (inputs/labels shift)
    decoder_window = decoder_context_len + 1

    # Pivot range [min_pivot, max_pivot): the encoder window must start at
    # index >= 0 and the decoder window must end at index <= seq_len - 1.
    min_pivot = encoder_context_len
    max_pivot = 2 * (seq_len // 2) - decoder_window

    # tf.random.uniform needs minval < maxval for integer dtypes, so an empty
    # range is rejected here with a readable message instead of inside TF.
    assert min_pivot < max_pivot, (
        f"No valid pivot: encoder_context_len ({encoder_context_len}) + decoder window "
        f"({decoder_window}) must be smaller than {2 * (seq_len // 2)} (seq_len={seq_len})"
    )
    batch_size = tf.shape(tokens)[0]

    # Sample one random pivot index per sequence -> shape (batch_size, 1)
    pivot_indices = tf.random.uniform((batch_size, 1), minval=min_pivot,
                                      maxval=max_pivot, dtype=tf.int32)

    # Offsets relative to the pivot: [-enc_len, ..., -1] and [1, ..., dec_window]
    before_offsets = tf.range(-encoder_context_len, 0, dtype=tf.int32)
    after_offsets = tf.range(1, decoder_window + 1, dtype=tf.int32)

    # Broadcast (batch_size, 1) + (window,) -> absolute indices of shape (batch_size, window)
    indices_before = pivot_indices + before_offsets
    indices_after = pivot_indices + after_offsets

    # Per-row gather along the sequence axis
    encoder_inputs = tf.gather(tokens, indices_before, axis=1, batch_dims=1)
    decoder_inputs = tf.gather(tokens, indices_after, axis=1, batch_dims=1)

    # Pin the window sizes explicitly on the result shapes
    encoder_inputs = tf.reshape(encoder_inputs, (batch_size, encoder_context_len))
    decoder_inputs = tf.reshape(decoder_inputs, (batch_size, decoder_window))

    return tf.cast(encoder_inputs, tf.int32), tf.cast(decoder_inputs, tf.int32)

def mask_sequence(sequence, vocab_size, mask_token_id, pct_to_mask=0.15, pct_to_random=0.1, pct_to_keep=0.1):
    """ Apply masked-language-model corruption to a batch of token ids.

    One uniform draw in [0, 1) is made per token. Tokens whose draw is below
    `pct_to_mask` are selected for prediction. The selected range is split
    into three disjoint bands:

        [0, p*r)           -> replaced with a random token id
        [p*r, p*(r+k))     -> left unchanged
        [p*(r+k), p)       -> replaced with `mask_token_id`

    where p = pct_to_mask, r = pct_to_random, k = pct_to_keep. The bands must
    be disjoint: if the "keep" band also started at 0 it would overwrite the
    random replacements and, with equal fractions, no token would ever be
    randomised.

    Args:
        sequence: Integer token ids. Expected to be int32 so it matches the
            int32 replacement tensors built here.
        vocab_size: Exclusive upper bound for the random replacement ids.
        mask_token_id: Id written at masked positions.
        pct_to_mask: Fraction of tokens selected for prediction.
        pct_to_random: Fraction of the selected tokens replaced by a random id.
        pct_to_keep: Fraction of the selected tokens left unchanged.

    Returns:
        masked_sequence: The corrupted sequence (same shape as `sequence`).
        sequence: The untouched input, used as labels.
        sample_weights: float32 tensor, 1.0 at every selected position
            (masked, randomised or kept) and 0.0 elsewhere.
    """
    # One uniform draw per token decides what happens to it
    draws = tf.random.uniform(shape=tf.shape(sequence), minval=0, maxval=1)

    # Band boundaries inside the selected range [0, pct_to_mask)
    random_cutoff = pct_to_mask * pct_to_random
    keep_cutoff = pct_to_mask * (pct_to_random + pct_to_keep)

    selected = draws < pct_to_mask
    to_random = draws < random_cutoff
    to_keep = (draws >= random_cutoff) & (draws < keep_cutoff)

    # Start by writing the mask token at every selected position
    masked_sequence = tf.where(selected, mask_token_id * tf.ones_like(sequence, dtype=tf.int32), sequence)

    # Overwrite the "random" band with uniformly sampled token ids
    random_tokens = tf.random.uniform(
        shape=tf.shape(sequence), minval=0, maxval=vocab_size, dtype=tf.int32
    )
    masked_sequence = tf.where(to_random, random_tokens, masked_sequence)

    # Restore the original token in the "keep" band
    masked_sequence = tf.where(to_keep, sequence, masked_sequence)

    # The loss is only counted on selected positions
    sample_weights = tf.cast(selected, tf.float32)

    return masked_sequence, sequence, sample_weights


def shift_and_split_decoder_inputs(x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """ Split a token window into next-token-prediction inputs and labels.

    Args:
        x: Token tensor whose second dimension has length `window_size + 1`.

    Returns:
        decoder_inputs: Columns 0 .. window_size - 1 of `x`.
        decoder_outputs: Columns 1 .. window_size of `x`, i.e. the same
            window shifted left by one position.
    """
    # The input holds one more token than the model consumes
    n_steps = tf.shape(x)[1] - 1

    # Column positions for the inputs and for the shifted-by-one labels
    input_cols = tf.range(n_steps, dtype=tf.int32)
    label_cols = tf.range(1, n_steps + 1, dtype=tf.int32)

    return tf.gather(x, input_cols, axis=-1), tf.gather(x, label_cols, axis=-1)
    
def transform_sequence(sequence, vocab_size, mask_token_id, encoder_context_len=128, decoder_context_len=64):
    """ Turn a batch of token sequences into (inputs, labels, sample_weights) for training.

    The batch is split around a random pivot into an encoder window and a
    decoder window. The encoder window is corrupted with `mask_sequence`
    and the decoder window is split into inputs and shifted-by-one labels.

    Args:
        sequence: Batch of token sequences, passed straight to `split_on_pivot`.
        vocab_size: Vocabulary size used when sampling random replacement tokens.
        mask_token_id: Id of the mask token written at masked encoder positions.
        encoder_context_len: Encoder window length forwarded to `split_on_pivot`.
        decoder_context_len: Decoder context length forwarded to `split_on_pivot`.

    Returns:
        A 3-tuple of 2-tuples, each ordered (encoder, decoder):
            inputs:  (masked encoder tokens, decoder input tokens)
            labels:  (original encoder tokens, decoder tokens shifted by one)
            weights: (1.0 at encoder positions selected for masking,
                      1.0 at every decoder position)
    """
    encoder_inputs, decoder_inputs = split_on_pivot(
        sequence,
        encoder_context_len=encoder_context_len,
        decoder_context_len=decoder_context_len,
    )

    # Encoder transform
    # Use the id supplied by the caller rather than a notebook-level global.
    encoder_inputs, encoder_labels, encoder_sample_wts = mask_sequence(
        encoder_inputs, vocab_size, tf.cast(mask_token_id, tf.int32)
    )

    # Decoder transform
    decoder_inputs, decoder_labels = shift_and_split_decoder_inputs(decoder_inputs)
    # Every decoder position contributes equally to the autoregressive loss
    decoder_sample_wts = tf.ones_like(decoder_labels, dtype=tf.float32)

    _inputs = (encoder_inputs, decoder_inputs)
    _labels = (encoder_labels, decoder_labels)
    _sample_wts = (encoder_sample_wts, decoder_sample_wts)
    return _inputs, _labels, _sample_wts
    
    
def _prepare_pipeline(ds):
    """ Shuffle, batch, transform and prefetch a dataset of token sequences. """
    def _to_model_inputs(batch):
        return transform_sequence(batch, train_config["vocab_size"], train_config["mask_token_id"])

    ds = ds.shuffle(train_config["shuffle_buffer"])
    ds = ds.batch(train_config["batch_size"], drop_remainder=True)
    ds = ds.map(_to_model_inputs, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)


# NOTE: both names are rebound to their transformed versions, so this cell
# must be run only once after the raw datasets are created.
train_ds = _prepare_pipeline(train_ds)
val_ds = _prepare_pipeline(val_ds)


train_ds, val_ds

(<PrefetchDataset element_spec=((TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.float32, name=None), TensorSpec(shape=(32, 64), dtype=tf.float32, name=None)))>,
 <PrefetchDataset element_spec=((TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.float32, name=None), TensorSpec(shape=(32, 64), dtype=tf.float32, name=None)))>)

In [6]:
# Pull one batch from the validation pipeline for inspection; each element
# unpacks into (inputs, labels, sample weights).
_inputs, _labels, _wts = next(iter(val_ds))

In [7]:
# Render the first example of the batch: inputs first, then labels, and
# within each group the first entry followed by the second.
for _group in (_inputs, _labels):
    for _stream in _group:
        _ = viz_tool(_stream[0].numpy().tolist(), display_inline=True)

<div class="alert alert-block alert-warning" style="font-size: 12px;">
<br><center><b style="font-size: 18px;">TensorFloat-32 Warning:</b></center><br>This warning is related to the use of <b>TensorFloat-32</b> (<b>TF32</b>) in TensorFlow on NVIDIA Ampere architecture GPUs. <b>TensorFloat-32</b> is a new math mode in NVIDIA's A100 GPU for accelerating mixed-precision training in deep learning models. <b>TF32</b> combines the speed of lower-precision FP16 (half-precision) with the dynamic range of FP32 (single-precision).
<br><br>
The warning message you see is informing you that TensorFlow is using <b>TensorFloat-32</b> for matrix multiplication operations on the GPU. This is expected behavior and does not indicate a problem with your code or model. The warning message is logged only once to let you know that <b>TensorFloat-32</b> is being used for matrix multiplications.
<br><br>
<b>In most cases, using TensorFloat-32 can lead to significant speed improvements in training deep learning models without negatively impacting the model's accuracy or convergence.</b><br><br>

</div>

In [8]:
from spearecode.models.cllm_backbone import CLLM

# --- Model Steps ---
# 
# 1. Define Configurations
# 2. Load Model Architecture
# 3. Define Optimizer and Learning Rate Details
# 4. Define Callbacks
#       --> TBD
#       --> TBD
#       --> TBD
# 5. Define Loss Functions
#       --> MLM Loss
#       --> AR Loss
# 6. Define Metrics
#       --> TBD
#       --> TBD

# --- Encoder hyper-parameters ---
enc_vocab_size = sp.vocab_size()
enc_context_len = 128
enc_embed_dim = 128
enc_hidden_layers = 2
enc_attn_heads = 4
enc_ffn_act = "gelu"
enc_ffn_dropout = 0.1
enc_attn_dropout = 0.1
enc_use_bias = False
enc_expansion = 4

# --- Decoder hyper-parameters ---
dec_vocab_size = sp.vocab_size()
dec_context_len = 64
dec_embed_dim = 128
dec_hidden_layers = 2
dec_attn_heads = 4
dec_ffn_act = "gelu"
dec_ffn_dropout = 0.1
dec_attn_dropout = 0.1
dec_use_bias = False
dec_expansion = 4

# NOTE(review): enc_attn_dropout / dec_attn_dropout are defined above but are
# not passed into either config below; only the FFN dropout value is forwarded.
enc_config = {
    "vocab_size": enc_vocab_size,
    "context_length": enc_context_len,
    "embedding_size": enc_embed_dim,
    "n_heads": enc_attn_heads,
    "n_layers": enc_hidden_layers,
    "use_bias": enc_use_bias,
    "ffn_act": enc_ffn_act,
    "expansion_factor": enc_expansion,
    "dropout_rate": enc_ffn_dropout,
}

dec_config = {
    "vocab_size": dec_vocab_size,
    "context_length": dec_context_len,
    "embedding_size": dec_embed_dim,
    "n_heads": dec_attn_heads,
    "n_layers": dec_hidden_layers,
    "use_bias": dec_use_bias,
    "ffn_act": dec_ffn_act,
    "expansion_factor": dec_expansion,
    "dropout_rate": dec_ffn_dropout,
}

# Build the combined encoder/decoder model and print its layer summary
cllm = CLLM(encoder_kwargs=enc_config, decoder_kwargs=dec_config, batch_size=train_config["batch_size"])
cllm.summary()

# test predict
cllm.predict(val_ds.take(1))

Model: "cllm"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 transformer_encoder (Transf  multiple                 2835456   
 ormerEncoder)                                                   
                                                                 
 transformer_decoder (Transf  multiple                 3358720   
 ormerDecoder)                                                   
                                                                 
Total params: 6,194,176
Trainable params: 6,194,176
Non-trainable params: 0
_________________________________________________________________


(array([[[ 0.00122931, -0.00095357, -0.01183734, ..., -0.00827219,
          -0.00268872, -0.02040298],
         [ 0.00508671, -0.00096983, -0.00961313, ..., -0.01311184,
          -0.00764793, -0.02018817],
         [ 0.00446914,  0.0100721 , -0.01147282, ..., -0.00699435,
          -0.00724797, -0.00494835],
         ...,
         [-0.00913095, -0.01012567, -0.01302268, ...,  0.00353795,
          -0.00450482, -0.02274027],
         [-0.00487253, -0.01187996, -0.01218519, ..., -0.00332531,
          -0.00837998, -0.02055086],
         [-0.00235857, -0.00119126, -0.01370125, ...,  0.00077477,
          -0.01173878, -0.01417518]],
 
        [[-0.00517943,  0.00559564, -0.00861732, ..., -0.00070297,
          -0.0050701 , -0.002491  ],
         [ 0.00578697,  0.00692591, -0.009928  , ..., -0.01175569,
          -0.01015619, -0.00861843],
         [-0.00335376, -0.00390136, -0.01253616, ...,  0.00590277,
          -0.00479247,  0.01420439],
         ...,
         [-0.01419794, -0.0118983

In [9]:
from spearecode.optimizers import AdamWeightDecay, WarmUpCosineDecay

# NOTE(review): these values are (shards * examples per shard), i.e. example
# counts. They ignore the batch size and the number of epochs, so they are
# probably not optimizer-step counts -- confirm what WarmUpCosineDecay expects.
approx_total_steps = N_TOTAL_RECS*EX_PER_TFREC
approx_val_steps = N_VAL_RECS*EX_PER_TFREC
approx_train_steps = approx_total_steps-approx_val_steps

optimizer_config = dict(
    use_basic_adam=True,
    use_cdecay_lr=True,
    weight_decay_rate=0.1,
    # NOTE(review): clipnorm is a boolean here while gradient_clip_norm holds
    # the numeric threshold -- confirm this is what AdamWeightDecay expects.
    clipnorm=True,
    gradient_clip_norm=1.0,
    # beta_1 must be < 1: with 1.0 the first-moment estimate never moves away
    # from zero and its bias correction divides by zero.
    beta_1=0.9,
    beta_2=0.95,
    exclude_from_weight_decay = ['layer_normalization', 'bias'],
)

# Schedule fractions, named once so the *_portion and *_steps entries cannot drift apart
WARMUP_PORTION = 0.05
HOLD_PORTION = 0.01

lr_config = dict(
    init_lr=0.001,
    min_lr=5e-05,
    decay_portion=1.0,
    warmup_portion=WARMUP_PORTION,
    hold_portion=HOLD_PORTION,
    total_steps=approx_train_steps,
    alpha=0.0,
    decay_steps=approx_train_steps,
    warmup_steps=int(approx_train_steps*WARMUP_PORTION),
    hold_steps=int(approx_train_steps*HOLD_PORTION),
)

# The two "use_*" entries are notebook-level switches, not optimizer arguments.
# They are stripped from the kwargs regardless of their value so that neither
# can leak into AdamWeightDecay(**...), and optimizer_config is not mutated.
_SWITCH_KEYS = ("use_basic_adam", "use_cdecay_lr")
optimizer_kwargs = {k: v for k, v in optimizer_config.items() if k not in _SWITCH_KEYS}

# Instantiate our learning rate (or lr-schedule)
if optimizer_config["use_cdecay_lr"]:
    lr=WarmUpCosineDecay(**lr_config)
else:
    lr=lr_config["init_lr"]

# Instantiate our optimizer (AdamW or just vanilla Adam)
if not optimizer_config["use_basic_adam"]:
    optimizer = AdamWeightDecay(learning_rate=lr, **optimizer_kwargs)
else:
    # Vanilla Adam only receives the learning rate; the other entries of
    # optimizer_config are not used on this path.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    
optimizer

<keras.optimizers.optimizer_v2.adam.Adam at 0x7f0540246310>

In [10]:
from spearecode.callbacks import get_callbacks

# Checkpoint folder; created on first use, left untouched if it already exists
CKPT_DIR = os.path.join(MODELS_DIR, "ckpts")
os.makedirs(CKPT_DIR, exist_ok=True)

# Settings handed to the project's callback factory
cb_config = {
    "ckpt_dir": CKPT_DIR,
    "save_weights_only": True,
    "use_early_stopping": True,
    "es_patience": 10,
    "verbose": 1,
}

cb_list = get_callbacks(cb_config)
cb_list

[<keras.callbacks.ModelCheckpoint at 0x7f052c551400>,
 <keras.callbacks.EarlyStopping at 0x7f052c5517c0>]

In [11]:
# One loss per model output, in (encoder, decoder) order. Both heads emit raw
# logits over the vocabulary, hence from_logits=True.
_SparseCE = tf.keras.losses.SparseCategoricalCrossentropy
encoder_mlm_loss = _SparseCE(from_logits=True)  # ENCODER MLM LOSS
decoder_ar_loss = _SparseCE(from_logits=True)   # DECODER AR  LOSS
loss_fns = [encoder_mlm_loss, decoder_ar_loss]

metrics = []  # TBD

In [None]:
# Not wired in yet:
#   loss_weights=[0.5, 0.5]
#   metrics = TBD
cllm.compile(optimizer, loss=loss_fns)

# Train with the callbacks built earlier; validation runs on val_ds each epoch.
history = cllm.fit(
    train_ds,
    validation_data=val_ds,
    epochs=train_config["n_epochs"],
    callbacks=cb_list,
)

Epoch 1/100

Epoch 1: val_loss improved from inf to 5.20506, saving model to /home/paperspace/home/spearecode/models/ckpts
Epoch 2/100

Epoch 2: val_loss improved from 5.20506 to 4.68722, saving model to /home/paperspace/home/spearecode/models/ckpts
Epoch 3/100

Epoch 3: val_loss improved from 4.68722 to 4.58753, saving model to /home/paperspace/home/spearecode/models/ckpts
Epoch 4/100

Epoch 4: val_loss improved from 4.58753 to 4.45024, saving model to /home/paperspace/home/spearecode/models/ckpts
Epoch 5/100

Epoch 5: val_loss improved from 4.45024 to 4.31344, saving model to /home/paperspace/home/spearecode/models/ckpts
Epoch 6/100

Epoch 6: val_loss did not improve from 4.31344
Epoch 7/100

Epoch 7: val_loss improved from 4.31344 to 4.17539, saving model to /home/paperspace/home/spearecode/models/ckpts
Epoch 8/100

Epoch 8: val_loss improved from 4.17539 to 4.10638, saving model to /home/paperspace/home/spearecode/models/ckpts
Epoch 9/100

In [None]:
# Debug: grab one training batch as (inputs, labels, sample weights)
x,y,s = next(iter(train_ds))

In [None]:
# Debug: display the inputs of the batch fetched in the previous cell
x