<br>

<b>Imports and Constants</b>

In [1]:
# !pip install -q --upgrade tokenizer-viz

# Regular imports (native python and pypi packages)
import os
import sys
import random
import numpy as np
import pandas as pd
from glob import glob
import tensorflow as tf
import sentencepiece as spm
from IPython.display import HTML, display
from tokenizer_viz import TokenVisualization
from tqdm.notebook import tqdm; tqdm.pandas()

# Add project root into path so imports work
PROJECT_DIR = os.path.dirname(os.getcwd())
sys.path.insert(0, PROJECT_DIR) 

# Our project imports
from spearecode.utils.preprocessing_utils import (
    load_from_txt_file, preprocess_shakespeare, save_to_txt_file, print_check_speare, get_spm_assets
)
from spearecode.utils.general_utils import (
    tf_xla_jit, tf_set_memory_growth, seed_it_all, flatten_l_o_l, print_ln
)
from spearecode.utils.filtering_utils import (
    save_ds_version, drop_str_from_col_names, pad_truncate_centered,
    get_metadata_df, check_chunks, tokenize, get_n_tokens,
    get_n_lines, get_n_chars
)
from spearecode.utils.tfrecord_utils import write_tfrecords, load_tfrecord_dataset

TRAIN_STYLE = "rcts_bpe_v4"
CHUNK_STYLE, TOK_STYLE, DS_VERSION = TRAIN_STYLE.split("_")

### DEFINE PATHS --- [PROJECT_DIR="/home/paperspace/home/spearecode"] --- ###
NBS_PATH = os.path.join(PROJECT_DIR, "nbs")
DATA_PATH = os.path.join(PROJECT_DIR, "data")
SS_TEXT_PATH = os.path.join(DATA_PATH, "t8.shakespeare.txt")
PREPROCESSED_FULL_TEXT_PATH = SS_TEXT_PATH.replace(".txt", "_preprocessed.txt")

DATASETS_PATH = os.path.join(DATA_PATH, "datasets") 
META_DIR = os.path.join(DATASETS_PATH, "meta") 
TFRECORD_DIR = os.path.join(DATASETS_PATH, "tfrecords", TRAIN_STYLE)
MODELS_DIR = os.path.join(PROJECT_DIR, "models")

# Specific Paths
SPM_MODEL_PATH = os.path.join(MODELS_DIR, f"spearecode_{TOK_STYLE}")
DATA_CSV_PATH  = os.path.join(DATASETS_PATH, f"{DS_VERSION}_{CHUNK_STYLE}_{TOK_STYLE}.csv")
META_CSV_PATH  = os.path.join(META_DIR, f"{DS_VERSION}_{CHUNK_STYLE}_{TOK_STYLE}.csv")

2023-04-14 18:10:02.256938: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


<br>

<b>Instantiate expected tools for the reset of the notebook</b>

In [2]:
sp, encoder, decoder = get_spm_assets(SPM_MODEL_PATH)
MASK_TOKEN_STR = "[MASK]"
MASK_TOKEN_INT = encoder(MASK_TOKEN_STR)

viz_tool = TokenVisualization(encoder, decoder, background_color="#FBFBFB", transparency=0.4)
train_df = pd.read_csv(DATA_CSV_PATH)
meta_df  = pd.read_csv(META_CSV_PATH)

display(train_df)
display(meta_df)

_ = viz_tool.visualize(train_df.content.sample(1).values[0], display_inline=True)

Unnamed: 0,content,token_content,n_tokens,n_chars,n_lines,valid_chunk
0,1\n From fairest creatures we desire increase...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",192,643,15,True
1,2\n When forty winters shall besiege thy brow...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",188,662,15,True
2,3\n Look in thy glass and tell the face thou ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",183,643,15,True
3,"4\n Unthrifty loveliness why dost thou spend,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",183,619,15,True
4,5\n Those hours that with gentle work did fra...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",176,652,15,True
...,...,...,...,...,...,...
7694,"'""Lo, this device was sent me from a nun,\n O...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",282,944,23,True
7695,"'""How mighty then you are, O hear me tell!\n ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",298,983,23,True
7696,"'""Now all these hearts that do on mine depend,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",283,977,23,True
7697,"'For lo, his passion, but an art of craft,\n ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",292,965,23,True


Unnamed: 0,n_tokens,n_chars,n_lines,valid_chunk
0,192,643,15,True
1,188,662,15,True
2,183,643,15,True
3,183,619,15,True
4,176,652,15,True
...,...,...,...,...
7694,282,944,23,True
7695,298,983,23,True
7696,283,977,23,True
7697,292,965,23,True


<br>

<b>Create Datasets</b>

In [3]:
# Get all tfrecords and shuffle
ALL_TFRECORDS = glob(os.path.join(TFRECORD_DIR, "*.tfrec"))
random.shuffle(ALL_TFRECORDS)
N_TOTAL_RECS = len(ALL_TFRECORDS)


EX_PER_TFREC = 100
VAL_PCT = 0.125
N_VAL_RECS = int(VAL_PCT*N_TOTAL_RECS)

VAL_TFRECORDS = ALL_TFRECORDS[:N_VAL_RECS]
TRAIN_TFRECORDS = ALL_TFRECORDS[N_VAL_RECS:]

train_ds = load_tfrecord_dataset(TRAIN_TFRECORDS)
val_ds = load_tfrecord_dataset(VAL_TFRECORDS)

(train_ds, val_ds)

2023-04-14 18:10:04.636108: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-04-14 18:10:04.663823: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-04-14 18:10:04.664022: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysf

(<_MapDataset element_spec=TensorSpec(shape=(384,), dtype=tf.int64, name=None)>,
 <_MapDataset element_spec=TensorSpec(shape=(384,), dtype=tf.int64, name=None)>)

<br>

<b>Training Configuration</b>

In [4]:
train_config = dict(
    batch_size=32,
    shuffle_buffer=512,
    encoder_context_len=128,
    decoder_context_len=64,
    mask_token_id=sp.encode("[MASK]")[0],
    vocab_size=sp.vocab_size(),
    n_epochs=100
)

<br>

<b>TF.Data Pipeline</b>

In [5]:
from typing import Tuple

# --- Pipeline Steps ---
# 
# 1. Shuffle examples (shuffle_buffer)
# 2. Batch examples (batch_size, drop_remainder, AUTOTUNE)
# 3. Split sequence into encoder/decoder inputs [`split_on_pivot`]
# 4. Split encoder inputs into:
#       --> 'inputs' (masked sequence)
#       --> 'labels' (unaltered sequence)
#       --> 'weights' (sample weights; 1.0 for masked tokens and 0.0 for non-mask tokens)
# 5. Split decoder inputs into:
#       --> 'inputs' (unaltered sequence)
#       --> 'labels' (sequence shifted by 1)

def split_on_pivot(tokens: tf.Tensor, 
                   encoder_context_len: int = 128, 
                   decoder_context_len: int = 64, 
                   seq_len: int = 384) -> Tuple[tf.Tensor, tf.Tensor]:
    """ Sample encoder and decoder input sequences from a batch of tokens with random pivot indices.
    
    Args:
        tokens: A batch of token sequences with shape (batch_size, seq_len).
        encoder_context_len: The number of tokens to be sampled for the encoder input sequences.
        decoder_context_len: The number of tokens to be sampled for the decoder input sequences.
        seq_len: The total length of each token sequence in the batch.

    Returns:
        encoder_inputs: A tensor with shape (batch_size, encoder_context_len) containing the
                        sampled encoder input sequences.
        decoder_inputs: A tensor with shape (batch_size, decoder_context_len) containing the
                        sampled decoder input sequences.

    Raises:
        AssertionError: If the sum of encoder_context_len and decoder_context_len is greater than seq_len.
    """
    
    # Add one to our decoder context length as we need it for AR head
    decoder_context_len+=1
    
    assert encoder_context_len + decoder_context_len <= seq_len
    batch_size = tf.shape(tokens)[0]
    c_point = seq_len // 2

    # Sample random pivot indices for each sequence in the batch
    pivot_indices = tf.random.uniform((batch_size, 1), minval=c_point - (c_point - encoder_context_len),
                                      maxval=c_point + (c_point - decoder_context_len), dtype=tf.int32)

    # Extract indices for examples before and after the pivot
    indices_before = tf.range(-encoder_context_len, 0, dtype=tf.int32)
    indices_after = tf.range(1, decoder_context_len + 1, dtype=tf.int32)

    # Compute the final indices for sampling from the data
    indices_before = tf.expand_dims(pivot_indices, 1) + indices_before
    indices_after = tf.expand_dims(pivot_indices, 1) + indices_after

    # Gather the corresponding examples from the data
    encoder_inputs = tf.squeeze(tf.gather(tokens, indices_before, axis=1, batch_dims=1))
    decoder_inputs = tf.squeeze(tf.gather(tokens, indices_after, axis=1, batch_dims=1))

    # Reshape the encoder_inputs and decoder_inputs tensors
    encoder_inputs = tf.reshape(encoder_inputs, (batch_size, encoder_context_len))
    decoder_inputs = tf.reshape(decoder_inputs, (batch_size, decoder_context_len))

    return tf.cast(encoder_inputs, tf.int32), tf.cast(decoder_inputs, tf.int32)

def mask_sequence(sequence, vocab_size, mask_token_id, pct_to_mask=0.15, pct_to_random=0.1, pct_to_keep=0.1):
        """ Mask a sequence of tokens. """

        # Calculate the probability of masking each token
        masking_prob = tf.random.uniform(shape=tf.shape(sequence), minval=0, maxval=1)

        # Calculate the mask based on the masking probability
        mask = tf.cast(masking_prob < pct_to_mask, tf.int32)

        # Calculate the probability of replacing with a random token
        random_prob_mask = tf.cast(masking_prob < (pct_to_mask * pct_to_random), tf.int32)

        # Calculate the probability of keeping the original token
        keep_prob_mask = tf.cast(masking_prob < (pct_to_mask * pct_to_keep), tf.int32)

        # Replace the masked tokens with the mask_token_id
        masked_sequence = tf.where(mask == 1, mask_token_id * tf.ones_like(sequence, dtype=tf.int32), sequence)

        # Replace random_prob_mask tokens with random tokens
        random_tokens = tf.random.uniform(
            shape=tf.shape(sequence), minval=0, maxval=vocab_size, dtype=tf.int32
        )

        # Replace the masked tokens with the mask_token_id
        masked_sequence = tf.where(random_prob_mask == 1, random_tokens, masked_sequence)

        # Keep the original tokens for keep_prob_mask
        masked_sequence = tf.where(keep_prob_mask == 1, sequence, masked_sequence)

        # Generate sample weights for masked tokens
        sample_weights = tf.cast(mask, tf.float32)

        return masked_sequence, sequence, sample_weights


def shift_and_split_decoder_inputs(x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """ TBD """
    window_size = tf.shape(x)[1]-1
    
    # Get the indices for the first and second vectors
    input_indices = tf.range(0, window_size, dtype=tf.int32)
    output_indices = tf.range(1, window_size+1, dtype=tf.int32)

    # Gather the corresponding columns for the first and second vectors
    decoder_inputs = tf.gather(x, input_indices, axis=-1)
    decoder_outputs = tf.gather(x, output_indices, axis=-1)

    return decoder_inputs, decoder_outputs
    
def transform_sequence(sequence, vocab_size, mask_token_id):
    encoder_inputs, decoder_inputs = split_on_pivot(sequence)
    
    # Encoder transform
    encoder_inputs, encoder_labels, encoder_sample_wts = mask_sequence(
        encoder_inputs, vocab_size, tf.constant(MASK_TOKEN_INT, dtype=tf.int32)
    )
    
    # Decoder transform
    decoder_inputs, decoder_labels = shift_and_split_decoder_inputs(decoder_inputs)
    decoder_sample_wts = tf.ones_like(decoder_labels, dtype=tf.float32)
    
    _inputs = (encoder_inputs, decoder_inputs)
    _labels = (encoder_labels, decoder_labels)
    _sample_wts = (encoder_sample_wts, decoder_sample_wts)
    return _inputs, _labels, _sample_wts
    
    
train_ds = train_ds.shuffle(train_config["shuffle_buffer"])\
                   .batch(train_config["batch_size"], drop_remainder=True)\
                   .map(lambda x: transform_sequence(x, train_config["vocab_size"], train_config["mask_token_id"]), num_parallel_calls=tf.data.AUTOTUNE)\
                   .prefetch(tf.data.AUTOTUNE)

val_ds = val_ds.shuffle(train_config["shuffle_buffer"])\
               .batch(train_config["batch_size"], drop_remainder=True)\
               .map(lambda x: transform_sequence(x, train_config["vocab_size"], train_config["mask_token_id"]), num_parallel_calls=tf.data.AUTOTUNE)\
               .prefetch(tf.data.AUTOTUNE)


train_ds, val_ds

(<_PrefetchDataset element_spec=((TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.float32, name=None), TensorSpec(shape=(32, 64), dtype=tf.float32, name=None)))>,
 <_PrefetchDataset element_spec=((TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.int32, name=None), TensorSpec(shape=(32, 64), dtype=tf.int32, name=None)), (TensorSpec(shape=(32, 128), dtype=tf.float32, name=None), TensorSpec(shape=(32, 64), dtype=tf.float32, name=None)))>)

In [6]:
_inputs, _labels, _wts = next(iter(val_ds))

2023-04-14 18:10:05.991892: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [9]
	 [[{{node Placeholder/_0}}]]
2023-04-14 18:10:05.992187: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [9]
	 [[{{node Placeholder/_0}}]]


In [7]:
_ = viz_tool(_inputs[0][0].numpy().tolist(), display_inline=True)
_ = viz_tool(_inputs[1][0].numpy().tolist(), display_inline=True)

_ = viz_tool(_labels[0][0].numpy().tolist(), display_inline=True)
_ = viz_tool(_labels[1][0].numpy().tolist(), display_inline=True)

<div class="alert alert-block alert-warning" style="font-size: 12px;">
<br><center><b style="font-size: 18px;">TensorFloat-32 Warning:</b></center><br>This warning is related to the use of <b>TensorFloat-32</b> (<b>TF32</b>) in TensorFlow on NVIDIA Ampere architecture GPUs. <b>TensorFloat-32</b> is a new math mode in NVIDIA's A100 GPU for accelerating mixed-precision training in deep learning models. <b>TF32</b> combines the speed of lower-precision FP16 (half-precision) with the dynamic range of FP32 (single-precision).
<br><br>
The warning message you see is informing you that TensorFlow is using <b>TensorFloat-32</b> for matrix multiplication operations on the GPU. This is expected behavior and does not indicate a problem with your code or model. The warning message is logged only once to let you know that <b>TensorFloat-32</b> is being used for matrix multiplications.
<br><br>
<b>In most cases, using TensorFloat-32 can lead to significant speed improvements in training deep learning models without negatively impacting the model's accuracy or convergence.</b><br><br>

</div>

In [8]:
from spearecode.models.cllm_backbone import CLLM

# --- Model Steps ---
# 
# 1. Define Configurations
# 2. Load Model Architecture
# 3. Define Optimizer and Learning Rate Details
# 4. Define Callbacks
#       --> TBD
#       --> TBD
#       --> TBD
# 5. Define Loss Functions
#       --> MLM Loss
#       --> AR Loss
# 6. Define Metrics
#       --> TBD
#       --> TBD

enc_vocab_size, dec_vocab_size       = sp.vocab_size(), sp.vocab_size()
enc_context_len, dec_context_len     = 128, 64
enc_embed_dim, dec_embed_dim         = 128, 128
enc_hidden_layers, dec_hidden_layers = 2, 2
enc_attn_heads, dec_attn_heads       = 4, 4
enc_ffn_act, dec_ffn_act             = "gelu", "gelu"
enc_ffn_dropout, dec_ffn_dropout     = 0.1, 0.1
enc_attn_dropout, dec_attn_dropout   = 0.1, 0.1
enc_use_bias, dec_use_bias           = False, False
enc_expansion, dec_expansion         = 4, 4

enc_config = dict(
    vocab_size=enc_vocab_size,
    context_length=enc_context_len,
    embedding_size=enc_embed_dim,
    n_heads=enc_attn_heads,
    n_layers=enc_hidden_layers,
    use_bias=enc_use_bias,
    ffn_act=enc_ffn_act,
    expansion_factor=enc_expansion,
    dropout_rate=enc_ffn_dropout,
)

dec_config = dict(
    vocab_size=dec_vocab_size,
    context_length=dec_context_len,
    embedding_size=dec_embed_dim,
    n_heads=dec_attn_heads,
    n_layers=dec_hidden_layers,
    use_bias=dec_use_bias,
    ffn_act=dec_ffn_act,
    expansion_factor=dec_expansion,
    dropout_rate=dec_ffn_dropout,
)

cllm = CLLM(encoder_kwargs=enc_config, decoder_kwargs=dec_config, batch_size=train_config["batch_size"])
cllm.summary()

# test predict
cllm.predict(val_ds.take(1))

Model: "cllm"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 transformer_encoder (Transf  multiple                 2835456   
 ormerEncoder)                                                   
                                                                 
 transformer_decoder (Transf  multiple                 3358720   
 ormerDecoder)                                                   
                                                                 
Total params: 6,194,176
Trainable params: 6,194,176
Non-trainable params: 0
_________________________________________________________________


2023-04-14 18:10:06.984130: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:637] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-04-14 18:10:07.129503: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [9]
	 [[{{node Placeholder/_0}}]]
2023-04-14 18:10:07.129780: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [9]
	 [[{{node Placeholder/_0}}]]




(array([[[ 2.33853399e-03, -1.05445618e-02,  1.57787912e-02, ...,
          -1.89837590e-02,  1.69920251e-02,  2.07580067e-03],
         [ 1.81162730e-02, -2.89396918e-03,  1.18544949e-02, ...,
           7.24603003e-03, -6.87501905e-03,  2.85925041e-03],
         [ 2.03935243e-03, -1.11114727e-02,  2.24376712e-02, ...,
          -7.11629540e-03, -5.36899548e-03, -1.48997931e-02],
         ...,
         [ 3.96782812e-03, -2.57622148e-03,  1.71221849e-02, ...,
          -1.81406215e-02,  5.58015425e-03,  1.77714471e-02],
         [ 7.10261520e-03, -4.50161286e-03,  1.51792038e-02, ...,
          -1.91542562e-02,  9.30690207e-03,  1.23434113e-02],
         [ 8.44348315e-03,  1.58557901e-03,  2.57380493e-02, ...,
          -1.00576598e-02,  1.01305009e-03,  2.15468369e-02]],
 
        [[ 8.41643475e-03, -1.52913788e-02,  1.05650434e-02, ...,
          -2.31477879e-02,  1.35068120e-02, -9.32093151e-03],
         [ 1.14050889e-02, -9.51392762e-03,  1.29568614e-02, ...,
          -1.71991643

In [20]:
from spearecode.optimizers import AdamWeightDecay, WarmUpCosineDecay

approx_total_steps = N_TOTAL_RECS*100
approx_val_steps = N_VAL_RECS*100
approx_train_steps = approx_total_steps-approx_val_steps

optimizer_config = dict(
    use_basic_adam=True,
    use_cdecay_lr=True,
    weight_decay_rate=0.1,
    clipnorm=True,
    gradient_clip_norm=1.0,
    beta_1=1.0,
    beta_2=0.95,
    exclude_from_weight_decay = ['layer_normalization', 'bias'],
)

lr_config = dict(
    init_lr=0.001,
    min_lr=5e-05,
    decay_portion=1.0,
    warmup_portion=0.05,
    hold_portion=0.01,
    total_steps=approx_train_steps,
    alpha=0.0,
    decay_steps=approx_train_steps,
    warmup_steps=int(approx_train_steps*0.05),
    hold_steps=int(approx_train_steps*0.01),
)

# Instantiate our learning rate (or lr-schedule)
if optimizer_config["use_cdecay_lr"]:
    optimizer_config.pop("use_cdecay_lr")
    lr=WarmUpCosineDecay(**lr_config)
else:
    lr=lr_config["init_lr"]

# Instantiate our optimizer (AdamW or just vanilla Adam)
if not optimizer_config["use_basic_adam"]:
    optimizer_config.pop("use_basic_adam")
    optimizer = AdamWeightDecay(learning_rate=lr, **optimizer_config)
else:
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    
optimizer

<keras.optimizers.adam.Adam at 0x7f9c364a4d90>

In [21]:
from spearecode.callbacks import get_callbacks

CKPT_DIR = os.path.join(MODELS_DIR, "ckpts")
if not os.path.isdir(CKPT_DIR): os.makedirs(CKPT_DIR, exist_ok=True)

cb_config = dict(
    ckpt_dir=CKPT_DIR,
    save_weights_only=True,
    use_early_stopping=True,
    es_patience=10,
    verbose=1,
)

cb_list = get_callbacks(cb_config)
cb_list

[<keras.callbacks.ModelCheckpoint at 0x7f9c2c142430>,
 <keras.callbacks.EarlyStopping at 0x7f9c2c1427f0>]

In [22]:
loss_fns = [
    tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # ENCODER MLM LOSS
    tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # DECODER AR  LOSS
]

metrics = [
    #TBD
]

In [23]:
# loss_weights=[0.5, 0.5]
# metrics = TBD
cllm.compile(optimizer, loss=loss_fns)
history = cllm.fit(train_ds, validation_data=val_ds, epochs=train_config["n_epochs"], callbacks=cb_list)

Epoch 1/100


2023-04-14 18:14:49.408932: I tensorflow/compiler/xla/service/service.cc:169] XLA service 0x5ab60d10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-04-14 18:14:49.408951: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): NVIDIA RTX A4000, Compute Capability 8.6
2023-04-14 18:14:49.412920: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-04-14 18:14:49.686645: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
2023-04-14 18:14:49.710886: E ten

InternalError: Graph execution error:

Detected at node 'StatefulPartitionedCall_21' defined at (most recent call last):
    File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/paperspace/.local/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/paperspace/.local/lib/python3.9/site-packages/traitlets/config/application.py", line 1041, in launch_instance
      app.start()
    File "/home/paperspace/.local/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/home/paperspace/.local/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/home/paperspace/.local/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/home/paperspace/.local/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/home/paperspace/.local/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/home/paperspace/.local/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "/home/paperspace/.local/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 383, in do_execute
      res = shell.run_cell(
    File "/home/paperspace/.local/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/paperspace/.local/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2885, in run_cell
      result = self._run_cell(
    File "/home/paperspace/.local/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2940, in _run_cell
      return runner(coro)
    File "/home/paperspace/.local/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/paperspace/.local/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3139, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/paperspace/.local/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3318, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/paperspace/.local/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_8549/1517078499.py", line 4, in <module>
      history = cllm.fit(train_ds, validation_data=val_ds, epochs=train_config["n_epochs"], callbacks=cb_list)
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/engine/training.py", line 1685, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/engine/training.py", line 1284, in train_function
      return step_function(self, iterator)
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/engine/training.py", line 1268, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/engine/training.py", line 1249, in run_step
      outputs = model.train_step(data)
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/engine/training.py", line 1054, in train_step
      self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/optimizers/optimizer.py", line 543, in minimize
      self.apply_gradients(grads_and_vars)
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/optimizers/optimizer.py", line 1174, in apply_gradients
      return super().apply_gradients(grads_and_vars, name=name)
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/optimizers/optimizer.py", line 650, in apply_gradients
      iteration = self._internal_apply_gradients(grads_and_vars)
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/optimizers/optimizer.py", line 1200, in _internal_apply_gradients
      return tf.__internal__.distribute.interim.maybe_merge_call(
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/optimizers/optimizer.py", line 1250, in _distributed_apply_gradients_fn
      distribution.extended.update(
    File "/home/paperspace/.local/lib/python3.9/site-packages/keras/optimizers/optimizer.py", line 1245, in apply_grad_to_update_var
      return self._update_step_xla(grad, var, id(self._var_key(var)))
Node: 'StatefulPartitionedCall_21'
RET_CHECK failure (tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:618) dnn != nullptr 
	 [[{{node StatefulPartitionedCall_21}}]] [Op:__inference_train_function_16869]

In [15]:
x,y,s = next(iter(train_ds))

In [19]:
x

(<tf.Tensor: shape=(32, 128), dtype=int32, numpy=
 array([[   0,    0,    0, ...,  549, 1124, 2804],
        [7938,   92,    9, ..., 7938,   48,   83],
        [   9,  460, 2031, ...,  707, 3554, 7938],
        ...,
        [ 155, 4992,   29, ..., 7938,   44, 1248],
        [   0,    0,    0, ..., 1102, 1255,    4],
        [5005, 2379, 7970, ..., 7953, 7930, 4513]], dtype=int32)>,
 <tf.Tensor: shape=(32, 64), dtype=int32, numpy=
 array([[  72,   46, 1895, ..., 1204,   65, 4443],
        [   4, 7977, 2402, ...,    0,    0,    0],
        [7483,  353,  105, ..., 7981,    4,   11],
        ...,
        [1351,  117, 7969, ..., 2633, 7953, 7934],
        [ 198, 1502,  275, ..., 7939,    4,   11],
        [   4,   11, 3593, ...,    4,   11, 1121]], dtype=int32)>)