In [None]:
import os
import numpy as np
import tensorflow as tf

# -----------------------
# Data Download & Preprocessing
# -----------------------
print("[STATUS] Data Downloading & Preparation: Starting")
hf_files = {
    # "tokenized_books_1.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_books_1.npy?download=true",
    # "tokenized_books_2.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_books_2.npy?download=true",
    # "tokenized_books_3.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_books_3.npy?download=true",
    # "tokenized_books_4.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_books_4.npy?download=true",
    # "tokenized_books_5.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_books_5.npy?download=true",
    # "tokenized_books_6.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_books_6.npy?download=true",
    # "tokenized_books_7.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_books_7.npy?download=true",
    # "tokenized_conversations_1.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_conversations_1.npy?download=true",
    # "tokenized_conversations_2.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_conversations_2.npy?download=true",
    # "tokenized_conversations_3.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_conversations_3.npy?download=true",
    # "tokenized_conversations_4.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_conversations_4.npy?download=true",
    # "tokenized_conversations_5.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_conversations_5.npy?download=true",
    #"april_book.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_part_book.npy?download=true",
    "bom_book.npy": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_part_bom.npy?download=true"
}

# Fetch every configured shard; get_file returns the local path of the download.
local_files = []
for shard_name, shard_url in hf_files.items():
    print(f"[STATUS] Downloading {shard_name}")
    local_files.append(tf.keras.utils.get_file(shard_name, shard_url))

# Load each shard with pickling disabled, so only plain arrays are accepted.
all_tokens = [np.load(shard_path, allow_pickle=False) for shard_path in local_files]

# Join all shards into one flat token stream.
combined_tokens = np.concatenate(all_tokens)
print(f"Total tokens across files: {len(combined_tokens)}")

# -----------------------
# Create Sequences, Shuffle, and Write TFRecord
# -----------------------

# Define sequence length (adjust if needed)
SEQUENCE_LENGTH = 512
num_sequences = len(combined_tokens) // SEQUENCE_LENGTH
print(f"Total sequences of length {SEQUENCE_LENGTH}: {num_sequences}")

# Only keep complete sequences (the trailing partial sequence is dropped) and
# reshape to one row per sequence.
sequences = combined_tokens[:num_sequences * SEQUENCE_LENGTH].reshape(num_sequences, SEQUENCE_LENGTH)

# Shuffle whole sequences (rows) uniformly, in place.
# A fixed-size tf.data shuffle buffer (previously 500_000) only mixes elements
# inside a sliding window, so any file with more sequences than the buffer was
# only partially shuffled. The data is already fully in memory, so shuffling
# the rows with NumPy gives a uniform permutation without a second copy.
# NOTE: `sequences` may be a view of `combined_tokens`, which is then reordered too.
np.random.default_rng().shuffle(sequences, axis=0)

# Create a tf.data.Dataset from the already-shuffled sequences
dataset = tf.data.Dataset.from_tensor_slices(sequences)

# Optionally, you can batch and prefetch if further processing is required
# dataset = dataset.batch(32).prefetch(tf.data.AUTOTUNE)

# Define a function to serialize a sequence
def serialize_example(sequence):
    """Pack one token sequence into a serialized ``tf.train.Example``.

    The tokens are stored as an int64 list under the feature key ``'tokens'``.
    Returns the serialized bytes.
    """
    token_list = tf.train.Int64List(value=sequence)
    features = tf.train.Features(feature={'tokens': tf.train.Feature(int64_list=token_list)})
    return tf.train.Example(features=features).SerializeToString()

tfrecord_filename = "tokenized_books_part4.tfrecord"
print(f"[STATUS] Saving dataset as TFRecord to {tfrecord_filename}...")

# Stream the shuffled dataset to disk, one serialized Example per sequence.
# Writing is sequential: TFRecordWriter takes one record at a time.
with tf.io.TFRecordWriter(tfrecord_filename) as writer:
    for token_row in dataset.as_numpy_iterator():
        writer.write(serialize_example(token_row))

print(f"✅ Saved shuffled dataset as {tfrecord_filename}")

# Conversations data is already all clean - 1_139_206 sequences for conversations 1
# Conversations data is already all clean - 616_193 sequences for conversations 2

# We end up with 1_755_399 sequences for a total of 898_764_288 conversation tokens (+165 april +712 bom sequences),
# in other words ~0.9 billion tokens.


[STATUS] Data Downloading & Preparation: Starting
[STATUS] Downloading bom_book.npy
Downloading data from https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/tokenized_part_bom.npy?download=true
[1m1460336/1460336[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Total tokens across files: 365052
Total sequences of length 512: 712
[STATUS] Saving dataset as TFRecord to tokenized_books_part4.tfrecord...
✅ Saved shuffled dataset as tokenized_books_part4.tfrecord


In [None]:
!pip install fasttext
!pip install tiktoken

import os
import tensorflow as tf
import tiktoken
import fasttext
import numpy as np
import concurrent.futures

# -----------------------
# Setup: load fastText model and initialize tokenizer
# -----------------------
model_path = "lid.176.ftz"
if not os.path.exists(model_path):
    !wget -O {model_path} https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz

ft_model = fasttext.load_model(model_path)
encoding = tiktoken.get_encoding("gpt2")

# -----------------------
# Filtering function with newline removal
# -----------------------
def is_sequence_english_fasttext(sequence, encoding, threshold=0.95):
    """
    Decide whether a token sequence decodes to English text.

    The sequence is decoded with ``encoding``, newlines are replaced by spaces
    (fastText rejects multi-line input), and the module-level fastText model
    ``ft_model`` is asked for its single most likely language label.

    Args:
        sequence: Token ids accepted by ``encoding.decode``.
        encoding: Tokenizer object exposing ``decode``.
        threshold: Minimum probability required for the English label.

    Returns:
        bool: True if the top label is '__label__en' with probability >= threshold;
        False otherwise, including when the model returns no prediction.
    """
    text = encoding.decode(sequence)
    # Remove newlines (fastText requires a single line)
    text = text.replace("\n", " ")

    # Predict through the list-input path instead of temporarily replacing
    # np.array. The global monkey-patch was not thread-safe: this function runs
    # inside a ThreadPoolExecutor, so one thread could save another thread's
    # patched lambda as the "original" and leave NumPy permanently patched.
    # With a list input, predict returns one labels/probabilities entry per line
    # and does not go through the np.array(..., copy=False) call that fails.
    labels, probs = ft_model.predict([text], k=1)
    if len(labels) == 0 or len(labels[0]) == 0:
        return False

    label = labels[0][0]  # e.g., '__label__en'
    prob = probs[0][0]
    return bool(label == '__label__en' and prob >= threshold)

def process_sequence(seq):
    """Run ``seq`` through the English filter: return it unchanged if kept, else None."""
    return seq if is_sequence_english_fasttext(seq, encoding) else None

# -----------------------
# Load the entire TFRecord dataset
# -----------------------
input_tfrecord = "/content/tokenized_books_part4.tfrecord"

# Every record holds exactly 512 int64 token ids under the 'tokens' key.
feature_description = {'tokens': tf.io.FixedLenFeature([512], tf.int64)}

def parse_function(example_proto):
    """Decode one serialized Example into a dict with a 'tokens' tensor."""
    return tf.io.parse_single_example(example_proto, feature_description)

dataset = tf.data.TFRecordDataset(input_tfrecord).map(
    parse_function, num_parallel_calls=tf.data.AUTOTUNE
)

# To process only a slice of the file, enable sharding here, e.g.:
# dataset = dataset.shard(num_shards=8, index=0)

# Materialize every sequence in memory as a NumPy array.
sequences = [record['tokens'].numpy() for record in dataset]
print(f"Total sequences in the file: {len(sequences)}")

# -----------------------
# Testing: print first 6 decoded sequences (before filtering)
# -----------------------
def _print_preview(header, label, preview_sequences):
    """Print ``header``, then each decoded sequence numbered from 1 under ``label``."""
    print(header)
    for position, tokens in enumerate(preview_sequences, start=1):
        print(f"\n{label} {position}:\n{encoding.decode(tokens)}\n{'-'*40}")

_print_preview("\n--- First 6 decoded sequences (before filtering) ---", "Sequence", sequences[:6])

# -----------------------
# Filter sequences in parallel
# -----------------------
with concurrent.futures.ThreadPoolExecutor() as executor:
    results = list(executor.map(process_sequence, sequences))
# process_sequence returns None for every rejected sequence; drop those.
filtered_sequences = [kept for kept in results if kept is not None]
print(f"\nTotal sequences kept after filtering: {len(filtered_sequences)}")

# -----------------------
# Testing: print first 6 decoded sequences (after filtering)
# -----------------------
_print_preview("\n--- First 6 decoded sequences (after filtering) ---", "Filtered Sequence", filtered_sequences[:6])

# -----------------------
# Write filtered sequences to a new TFRecord file.
# -----------------------
output_tfrecord = "/content/tokenized_books_part4_filtered_full.tfrecord"

def serialize_example(sequence):
    """Pack one token sequence into a serialized ``tf.train.Example``.

    The tokens are stored as an int64 list under the feature key ``'tokens'``.
    Returns the serialized bytes.
    """
    token_list = tf.train.Int64List(value=sequence)
    features = tf.train.Features(feature={'tokens': tf.train.Feature(int64_list=token_list)})
    return tf.train.Example(features=features).SerializeToString()

# Write the kept sequences in order, one record each.
with tf.io.TFRecordWriter(output_tfrecord) as writer:
    for serialized in map(serialize_example, filtered_sequences):
        writer.write(serialized)

print(f"\nFiltered sequences written to {output_tfrecord}")


# books one went from 1_247_288 to 1_089_431 sequences, a 13% cleanup. Probably cleaned up my tags as well.
# books two went from 1_445_745 to 1_276_920 sequences, a 12% cleanup. Probably cleaned up my tags as well. That's ok though.

# I'm left with 1_211_571_712 book tokens.
# In other words ~1.21 billion book tokens.


Collecting fasttext
  Downloading fasttext-0.9.3.tar.gz (73 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/73.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.4/73.4 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pybind11>=2.2 (from fasttext)
  Using cached pybind11-2.13.6-py3-none-any.whl.metadata (9.5 kB)
Using cached pybind11-2.13.6-py3-none-any.whl (243 kB)
Building wheels for collected packages: fasttext
  Building wheel for fasttext (pyproject.toml) ... [?25l[?25hdone
  Created wheel for fasttext: filename=fasttext-0.9.3-cp311-cp311-linux_x86_64.whl size=4313501 sha256=990b92b17b002089e7a6452c974b50e35d85b5d9147a37c69af0c3112b2a11e0
  Stored in directory: /root/.cache/pip/wheels/65/4f/35/5057db0249224e9ab55a51

In [None]:
import tensorflow as tf
import requests
import os

# Coming from a total of ~1.21 billion book tokens + ~0.9 billion conversation tokens,
# the final dataset size (token-wise) is ~2.1 billion clean tokens divided into 4 datasets:
# 2 books datasets and 2 conversation datasets.
# These datasets contain tokenized, sequenced, and shuffled data.
# I now check how many epochs are needed to do a full round of training for each dataset, using 3k steps per epoch
# (basically figuring out how long training will take).

# URLs for the TFRecord files
tfrecord_urls = {
    "books_part1": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_1_sequenced_shuffled_books.tfrecord?download=true",
    "books_part2": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_2_sequenced_shuffled_books.tfrecord?download=true",
    "conversations_part3": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_3_sequenced_shuffled_conversations.tfrecord?download=true",
    "conversations_part4": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_4_sequenced_shuffled_conversations.tfrecord?download=true",
}

# Download and save TFRecord files
os.makedirs("/content/tfrecords", exist_ok=True)
tfrecord_files = {}

for name, url in tfrecord_urls.items():
    local_path = f"/content/tfrecords/{name}.tfrecord"
    if not os.path.exists(local_path):  # Avoid re-downloading if already present
        print(f"Downloading {name}...")
        # Download to a temporary name and rename only on success. Otherwise an
        # interrupted transfer leaves a truncated file at local_path, which the
        # existence check above would treat as a complete download on rerun.
        partial_path = local_path + ".part"
        with requests.get(url, stream=True, timeout=60) as response:
            # Fail loudly on HTTP errors instead of saving an error page as data.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                # 1 MiB chunks: these files are large, and 512-byte chunks
                # mean millions of tiny writes.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, local_path)
        print(f"✅ Downloaded {name}")
    else:
        print(f"⚡ {name} already exists, skipping download.")
    tfrecord_files[name] = local_path

# Function to count sequences in a TFRecord file
def count_sequences(tfrecord_path):
    """Return the number of records in the TFRecord file at ``tfrecord_path``.

    Records are iterated one by one without being parsed.
    """
    return sum(1 for _ in tf.data.TFRecordDataset(tfrecord_path))

# Count sequences in each file
dataset_sizes = {}
for name, path in tfrecord_files.items():
    dataset_sizes[name] = count_sequences(path)

# Print the dataset sizes
print("\n📊 **Dataset Sequence Counts**")
for name, size in dataset_sizes.items():
    print(f"{name}: {size} sequences")

# Epoch planning: an "epoch" here is a fixed budget of STEPS_PER_EPOCH steps,
# and every step consumes BATCH_SIZE sequences.
BATCH_SIZE = 32  # sequences per training step
STEPS_PER_EPOCH = 3000  # training steps that make up one epoch

def recommended_epochs(dataset_size):
    """Return how many whole STEPS_PER_EPOCH-step epochs ``dataset_size`` sequences provide.

    Both divisions round down, and the result is never less than 1.
    """
    full_batches = dataset_size // BATCH_SIZE
    full_epochs = full_batches // STEPS_PER_EPOCH
    return full_epochs if full_epochs > 1 else 1

# Compute suggested epochs for each dataset
recommended_epochs_per_dataset = {}
for name, size in dataset_sizes.items():
    recommended_epochs_per_dataset[name] = recommended_epochs(size)

# Print recommended epochs
print("\n📌 **Recommended Epochs for Each Dataset**")
for name, epochs in recommended_epochs_per_dataset.items():
    print(f"{name}: {epochs} epochs")


Downloading books_part1...
✅ Downloaded books_part1
Downloading books_part2...
✅ Downloaded books_part2
Downloading conversations_part3...
✅ Downloaded conversations_part3
Downloading conversations_part4...
✅ Downloaded conversations_part4

📊 **Dataset Sequence Counts**
books_part1: 1089431 sequences
books_part2: 1276920 sequences
conversations_part3: 1139206 sequences
conversations_part4: 616193 sequences

📌 **Recommended Epochs for Each Dataset**
books_part1: 11 epochs
books_part2: 13 epochs
conversations_part3: 11 epochs
conversations_part4: 6 epochs


In [None]:
import os
# Suppress INFO-level messages (set to "2" to show warnings and errors only)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import io
import math
import time
import numpy as np
import tensorflow as tf
import tiktoken
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models, initializers, optimizers
from tensorflow.keras.mixed_precision import LossScaleOptimizer
from tensorflow.keras.utils import Progbar

# -----------------------
# Utility: Download TFRecord Files from URLs
# -----------------------
def download_tfrecord_files(url_dict):
    """Download every file in ``url_dict`` and return their local paths.

    Args:
        url_dict: Mapping of file name -> download URL.

    Returns:
        dict mapping each file name to the path returned by
        ``tf.keras.utils.get_file``.

    Raises:
        Exception: Any error raised while downloading is logged and re-raised.
    """
    logger = tf.get_logger()
    downloaded = {}
    for file_name, file_url in url_dict.items():
        try:
            downloaded[file_name] = tf.keras.utils.get_file(file_name, file_url)
            logger.info(f"Downloaded {file_name} to {downloaded[file_name]}")
        except Exception as err:
            logger.error(f"Error downloading {file_name} from {file_url}: {err}")
            raise err
    return downloaded

# -----------------------
# URLs for the TFRecord files.
# -----------------------
tfrecord_urls = {
    "books_part1": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_1_sequenced_shuffled_books.tfrecord?download=true",
    "books_part2": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_2_sequenced_shuffled_books.tfrecord?download=true",
    "conversations_part3": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_3_sequenced_shuffled_conversations.tfrecord?download=true",
    "conversations_part4": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_4_sequenced_shuffled_conversations.tfrecord?download=true",
}

# -----------------------
# TFRecord Parsing Function
# -----------------------
def parse_tfrecord(example_proto):
    """Parse one serialized Example into a next-token (input, target) pair.

    The record must hold 512 int64 ids under 'tokens'. Returns
    ``(tokens[:-1], tokens[1:])``: all but the last token as input, and the
    same sequence shifted left by one position as target.
    """
    schema = {'tokens': tf.io.FixedLenFeature([512], tf.int64)}
    token_ids = tf.io.parse_single_example(example_proto, schema)['tokens']
    model_inputs, targets = token_ids[:-1], token_ids[1:]
    return model_inputs, targets

# -----------------------
# Pre-computed Dataset Information
# -----------------------

# 📊 **Dataset Sequence Counts**
# books_part1: 1089431 sequences
# books_part2: 1276920 sequences
# conversations_part3: 1139206 sequences
# conversations_part4: 616193 sequences

# 📌 **Recommended Epochs for Each Dataset**
# books_part1: 11 epochs
# books_part2: 13 epochs
# conversations_part3: 11 epochs
# conversations_part4: 6 epochs

dataset_info = {
    "books_part1": {"path": None, "count": 1089431, "epochs": 11},
    "books_part2": {"path": None, "count": 1276920, "epochs": 13},
    "conversations_part3": {"path": None, "count": 1139206, "epochs": 11},
    "conversations_part4": {"path": None, "count": 616193, "epochs": 6},
}

# Download files and record each local path in its dataset_info entry.
local_tfrecord_paths = download_tfrecord_files(tfrecord_urls)
for ds_name, ds_meta in dataset_info.items():
    ds_meta["path"] = local_tfrecord_paths[ds_name]

# -----------------------
# Transformer Model Components & Helper Functions
# -----------------------
class RMSNorm(layers.Layer):
    """Root-mean-square normalization over the last axis with a learned scale.

    Computes ``inputs * gamma / sqrt(mean(inputs**2, axis=-1) + epsilon)``.

    Args:
        epsilon: Constant added to the mean square before the square root.
    """
    def __init__(self, epsilon=1e-8, **kwargs):
        super(RMSNorm, self).__init__(**kwargs)
        self.epsilon = epsilon
    def build(self, input_shape):
        # One scale per feature on the last axis, initialised to 1.
        self.gamma = self.add_weight(name="gamma",
                                     shape=input_shape[-1:],
                                     initializer="ones",
                                     trainable=True)
        super(RMSNorm, self).build(input_shape)
    def call(self, inputs):
        # Half-precision inputs are normalised in float32. In float16 the default
        # epsilon (1e-8) rounds to zero and squaring magnitudes above ~255
        # overflows, either of which can yield inf/NaN.
        if inputs.dtype in (tf.float16, tf.bfloat16):
            calc_dtype = tf.float32
        else:
            calc_dtype = inputs.dtype
        x = tf.cast(inputs, calc_dtype)
        gamma = tf.cast(self.gamma, calc_dtype)
        rms = tf.sqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + self.epsilon)
        # Return in the caller's dtype so surrounding layers are unaffected.
        return tf.cast(x * gamma / rms, inputs.dtype)
    def get_config(self):
        # Record epsilon so the layer can be rebuilt from its config.
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

def apply_rope(x, sin, cos):
    """Apply a rotary position embedding to ``x``.

    Adjacent feature pairs ``(x[..., 2i], x[..., 2i+1])`` on the last axis are
    rotated using the supplied sine/cosine values.

    Args:
        x: Tensor whose last axis has even length. sin/cos get new axes at
            positions 0 and 2 for broadcasting, so x is presumably
            (batch, seq, heads, head_dim) -- TODO confirm against callers.
        sin: Sine table, cast to ``x.dtype``.
        cos: Cosine table, cast to ``x.dtype``.

    Returns:
        Tensor with the rotated pairs re-interleaved on the last axis.
    """
    def _broadcastable(table):
        # Cast to x's dtype, then insert axes 0 and 2.
        return tf.expand_dims(tf.expand_dims(tf.cast(table, x.dtype), axis=0), axis=2)

    sin_b = _broadcastable(sin)
    cos_b = _broadcastable(cos)

    input_shape = tf.shape(x)
    feature_dim = input_shape[-1]
    # View the last axis as (feature_dim // 2) pairs.
    pairs = tf.reshape(x, tf.concat([input_shape[:-1], [feature_dim // 2, 2]], axis=0))
    first, second = pairs[..., 0], pairs[..., 1]

    # 2-D rotation of every pair.
    rotated = tf.stack(
        [first * cos_b - second * sin_b,
         first * sin_b + second * cos_b],
        axis=-1,
    )
    # Merge the pair axis back into a single feature axis.
    return tf.reshape(rotated, tf.concat([tf.shape(rotated)[:-2], [feature_dim]], axis=0))

class TiedDense(layers.Layer):
    """Output projection that reuses another layer's embedding matrix.

    Multiplies the inputs by the transpose of ``tied_to.embeddings``; the layer
    creates no weights of its own.

    Args:
        tied_to: Layer exposing an ``embeddings`` weight.
    """
    def __init__(self, tied_to, **kwargs):
        super().__init__(**kwargs)
        self.tied_to = tied_to
    def call(self, inputs):
        # Bring the embedding matrix to the activation dtype before the matmul.
        embedding_matrix = tf.cast(self.tied_to.embeddings, inputs.dtype)
        return tf.matmul(inputs, embedding_matrix, transpose_b=True)
    def get_config(self):
        # Only the tied layer's name is recorded, not the layer object itself.
        base_config = super().get_config()
        base_config["tied_to"] = self.tied_to.name
        return base_config

class RotarySelfAttention(layers.Layer):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Queries and keys are rotated by ``apply_rope`` before scaled dot-product
    attention; an optional causal mask suppresses attention to later positions.

    Args:
        embed_dim: Width of the projections and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout_rate: Dropout rate applied to the attention weights.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads, dropout_rate=0.1, **kwargs):
        super(RotarySelfAttention, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.dropout_rate = dropout_rate
    def build(self, input_shape):
        # Each projection gets its OWN initializer instance. Recent Keras
        # versions return identical values when one unseeded initializer
        # instance is called repeatedly, so a shared instance would start the
        # same-shaped query/key/value/output kernels as identical matrices.
        def robust_init():
            return initializers.RandomNormal(mean=0.0, stddev=0.02)
        self.query_dense = layers.Dense(self.embed_dim, kernel_initializer=robust_init())
        self.key_dense   = layers.Dense(self.embed_dim, kernel_initializer=robust_init())
        self.value_dense = layers.Dense(self.embed_dim, kernel_initializer=robust_init())
        self.out_dense   = layers.Dense(self.embed_dim, kernel_initializer=robust_init())
        self.dropout     = layers.Dropout(self.dropout_rate)
        super(RotarySelfAttention, self).build(input_shape)
    def call(self, inputs, training=False, use_causal_mask=True):
        """Attend over ``inputs``.

        Args:
            inputs: Tensor read as (batch, seq_len, features).
            training: Enables attention dropout when True.
            use_causal_mask: When True, positions cannot attend to later ones.

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        batch_size = tf.shape(inputs)[0]
        seq_len    = tf.shape(inputs)[1]
        query = self.query_dense(inputs)
        key   = self.key_dense(inputs)
        value = self.value_dense(inputs)
        # Split the feature axis into heads: (batch, seq, heads, head_dim).
        query = tf.reshape(query, (batch_size, seq_len, self.num_heads, self.head_dim))
        key   = tf.reshape(key, (batch_size, seq_len, self.num_heads, self.head_dim))
        value = tf.reshape(value, (batch_size, seq_len, self.num_heads, self.head_dim))
        # Rotary angles are built in float32: the outer product of positions and
        # inverse frequencies gives a (seq_len, head_dim // 2) table.
        position = tf.cast(tf.range(seq_len), tf.float32)
        head_dim_int = self.head_dim
        inv_freq = 1.0 / (10000 ** (tf.cast(tf.range(0, head_dim_int, 2), tf.float32) / tf.cast(head_dim_int, tf.float32)))
        sinusoid_inp = tf.tensordot(position, inv_freq, axes=0)
        sin = tf.sin(sinusoid_inp)
        cos = tf.cos(sinusoid_inp)
        # Only queries and keys carry positional information; values do not.
        query = apply_rope(query, sin, cos)
        key   = apply_rope(key, sin, cos)
        # Move heads in front of the sequence axis: (batch, heads, seq, head_dim).
        query = tf.transpose(query, perm=[0, 2, 1, 3])
        key   = tf.transpose(key, perm=[0, 2, 1, 3])
        value = tf.transpose(value, perm=[0, 2, 1, 3])
        scaling = tf.cast(self.head_dim, query.dtype) ** -0.5
        query = query * scaling
        attn_logits = tf.matmul(query, key, transpose_b=True)
        if use_causal_mask:
            # Lower-triangular ones: position i may attend to positions <= i.
            mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), dtype=query.dtype), -1, 0)
            mask = tf.reshape(mask, (1, 1, seq_len, seq_len))
            # Masked logits become -1e4, a value that float16 can still represent.
            attn_logits = attn_logits * mask + tf.cast(-1e4, attn_logits.dtype) * (1 - mask)
        attn_weights = tf.nn.softmax(attn_logits, axis=-1)
        attn_weights = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_weights, value)
        # Back to (batch, seq, heads, head_dim), then merge the heads.
        attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3])
        attn_output = tf.reshape(attn_output, (batch_size, seq_len, self.embed_dim))
        output = self.out_dense(attn_output)
        return output

class TransformerBlock(layers.Layer):
    """Post-norm transformer block.

    Causal rotary self-attention and a GELU feed-forward network, each followed
    by dropout, a residual add and an RMSNorm.

    Args:
        embed_dim: Width of the block's inputs and outputs.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout_rate: Dropout rate for the attention layer and both residual branches.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Attention branch.
        self.attention = RotarySelfAttention(embed_dim, num_heads, dropout_rate)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.norm1 = RMSNorm(epsilon=1e-8)
        # Feed-forward branch: expand to ff_dim with GELU, project back to embed_dim.
        expand = layers.Dense(ff_dim, activation=tf.nn.gelu,
                              kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.02))
        project = layers.Dense(embed_dim,
                               kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.02))
        self.ffn = models.Sequential([expand, project])
        self.dropout2 = layers.Dropout(dropout_rate)
        self.norm2 = RMSNorm(epsilon=1e-8)
    def call(self, inputs, training=False):
        attended = self.attention(inputs, training=training, use_causal_mask=True)
        attended = self.dropout1(attended, training=training)
        after_attention = self.norm1(inputs + attended)
        fed_forward = self.dropout2(self.ffn(after_attention), training=training)
        return self.norm2(after_attention + fed_forward)

def create_transformer_model(vocab_size, sequence_length, embed_dim, num_heads, ff_dim, num_layers, dropout_rate=0.1):
    """Build the decoder-only transformer as a Keras functional model.

    Token ids are embedded, passed through ``num_layers`` TransformerBlocks and
    a final RMSNorm, then projected back to the vocabulary with the transposed
    embedding matrix (tied weights). The logits are cast to float32.

    Args:
        vocab_size: Number of token ids (embedding rows and logit columns).
        sequence_length: Length of the int32 input sequences.
        embed_dim: Embedding and hidden width.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward hidden width per block.
        num_layers: Number of stacked TransformerBlocks.
        dropout_rate: Dropout rate forwarded to every block.

    Returns:
        A ``models.Model`` mapping token ids to float32 logits.
    """
    token_ids = layers.Input(shape=(sequence_length,), dtype=tf.int32)
    token_embedding = layers.Embedding(
        input_dim=vocab_size,
        output_dim=embed_dim,
        embeddings_initializer=initializers.RandomNormal(mean=0.0, stddev=0.02),
        name="token_embedding",
    )
    hidden = token_embedding(token_ids)
    for layer_index in range(num_layers):
        block = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate,
                                 name=f"transformer_block_{layer_index}")
        hidden = block(hidden)
    hidden = RMSNorm(epsilon=1e-8, name="final_rmsnorm")(hidden)
    # Tied output projection, then an explicit float32 cast of the logits.
    logits = TiedDense(token_embedding, name="output_projection")(hidden)
    logits = layers.Lambda(lambda t: tf.cast(t, tf.float32))(logits)
    return models.Model(inputs=token_ids, outputs=logits)

# -----------------------
# Custom Perplexity Metric
# -----------------------
class Perplexity(tf.keras.metrics.Metric):
    """Streaming perplexity: exp of the running mean sparse cross-entropy.

    ``y_pred`` is treated as logits (``from_logits=True``).
    """
    def __init__(self, name='perplexity', **kwargs):
        super().__init__(name=name, **kwargs)
        # Running mean of the cross-entropy values, kept in float32.
        self.ce_tracker = tf.keras.metrics.Mean(name="crossentropy_mean", dtype=tf.float32)
    def update_state(self, y_true, y_pred, sample_weight=None):
        cross_entropy = tf.keras.losses.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )
        self.ce_tracker.update_state(cross_entropy, sample_weight=sample_weight)
    def result(self):
        return tf.exp(self.ce_tracker.result())
    def reset_state(self):
        self.ce_tracker.reset_state()

# -----------------------
# Learning Rate and Weight Decay Schedules
# -----------------------
class WarmUpCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up followed by cosine decay.

    For ``step < warmup_steps`` the rate grows linearly from 0 to ``initial_lr``;
    afterwards it follows a half cosine from ``initial_lr`` down to
    ``alpha * initial_lr`` over ``total_steps - warmup_steps`` steps.

    Args:
        initial_lr: Peak learning rate reached at the end of warm-up.
        total_steps: Total number of training steps (warm-up included).
        warmup_steps: Number of warm-up steps.
        alpha: Final learning rate as a fraction of ``initial_lr``.
    """
    def __init__(self, initial_lr, total_steps, warmup_steps, alpha=0.0):
        super().__init__()
        self.initial_lr = initial_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.alpha = alpha
    def __call__(self, step):
        step_f = tf.cast(step, tf.float32)
        warmup_f = tf.cast(self.warmup_steps, tf.float32)
        # Warm-up branch: linear ramp.
        warmup_lr = self.initial_lr * step_f / warmup_f
        # Decay branch: progress is clamped at 0 before warm-up ends, and the
        # span is at least 1 to avoid dividing by zero.
        steps_into_decay = tf.maximum(step_f - warmup_f, 0.0)
        decay_span = tf.maximum(tf.cast(self.total_steps - self.warmup_steps, tf.float32), 1.0)
        cosine_factor = 0.5 * (1 + tf.cos(np.pi * steps_into_decay / decay_span))
        decayed_lr = self.alpha * self.initial_lr + (1 - self.alpha) * self.initial_lr * cosine_factor
        return tf.where(step_f < warmup_f, warmup_lr, decayed_lr)
    def get_config(self):
        return dict(initial_lr=self.initial_lr, total_steps=self.total_steps,
                    warmup_steps=self.warmup_steps, alpha=self.alpha)

class DynamicWeightDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Weight-decay coefficient that scales with the current learning rate.

    Returns ``base_wd * lr_schedule(step) / base_lr``.

    Args:
        base_lr: Reference learning rate at which the decay equals ``base_wd``.
        base_wd: Weight-decay coefficient at the reference learning rate.
        lr_schedule: Callable mapping a step to the current learning rate.
    """
    def __init__(self, base_lr, base_wd, lr_schedule):
        super().__init__()
        self.base_lr = base_lr
        self.base_wd = base_wd
        self.lr_schedule = lr_schedule
    def __call__(self, step):
        lr_ratio = self.lr_schedule(step) / self.base_lr
        return self.base_wd * lr_ratio
    def get_config(self):
        # NOTE(review): lr_schedule is not part of the config, so the config
        # alone does not carry every constructor argument.
        return dict(base_lr=self.base_lr, base_wd=self.base_wd)

# -----------------------
# Helper Function to Exclude Certain Parameters from Weight Decay
# -----------------------
def should_apply_weight_decay(var):
    """Return True if ``var`` should receive weight decay.

    Biases and normalization scales are excluded. Both ``var.name`` and, when
    present, ``var.path`` are checked: some Keras versions keep only the bare
    weight name (e.g. "gamma") in ``name`` and put the layer-qualified name in
    ``path``, so a name-only check would miss normalization weights there.

    Args:
        var: Variable-like object with a ``name`` attribute and optionally a
            ``path`` attribute.

    Returns:
        bool: False for bias/normalization parameters, True otherwise.
    """
    name = str(getattr(var, "name", "") or "")
    path = str(getattr(var, "path", "") or "")
    qualified = f"{path}/{name}".lower()
    if "bias" in qualified:
        return False
    if "norm" in qualified:
        return False
    # RMSNorm's scale weight is named "gamma"; match it even when no
    # layer-qualified name is available. A trailing ":0" suffix is ignored.
    leaf = name.lower().rsplit("/", 1)[-1].split(":", 1)[0]
    if leaf == "gamma":
        return False
    return True

# -----------------------
# Main Training Function with Interleaved Scheduling (Actual Version)
# -----------------------
def main():
    """Train the transformer on the four TFRecord datasets with an interleaved schedule.

    Steps:
      1. Build the schedule: ``rounds`` passes; within a pass, for epoch index
         e = 1..max_epochs, every dataset in ``desired_order`` that has at least
         e epochs contributes one scheduled epoch.
      2. Create model, LR schedule and loss-scaled AdamW under a MirroredStrategy
         with the mixed_float16 policy.
      3. Restore the latest checkpoint, if any, and resume from its epoch counter.
      4. For every remaining scheduled epoch: train 3000 steps, validate on the
         first 1% of the dataset, record metrics, save a checkpoint.
      5. Plot loss and perplexity for the epochs run in this session.
    """
    # --- Build schedule using the desired interleaved order ---
    desired_order = ["books_part2", "conversations_part3", "books_part1", "conversations_part4"]
    # Get the maximum epoch count among the datasets in the desired order.
    max_epochs = max(dataset_info[ds]["epochs"] for ds in desired_order)
    rounds = 4
    schedule = []
    for r in range(rounds):
        for e in range(1, max_epochs + 1):
            for ds in desired_order:
                if e <= dataset_info[ds]["epochs"]:
                    # (dataset name, epoch within dataset, epochs for dataset, round number)
                    schedule.append((ds, e, dataset_info[ds]["epochs"], r + 1))
    total_scheduled_epochs = len(schedule)
    print(f"Total scheduled epochs (actual): {total_scheduled_epochs}")

    # Training parameters.
    steps_per_epoch = 3000
    total_steps = total_scheduled_epochs * steps_per_epoch
    warmup_steps = int(0.1 * total_steps)
    initial_lr = 1e-3
    BATCH_SIZE = 32

    # -----------------------
    # GPU Configuration, Mixed Precision & XLA
    # -----------------------
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    tf.config.optimizer.set_jit(True)
    strategy = tf.distribute.MirroredStrategy()
    print(f"Number of devices (actual): {strategy.num_replicas_in_sync}", flush=True)

    # Model Hyperparameters & Tokenizer Setup.
    SEQUENCE_LENGTH = 512
    # Updated hyperparameters to match GPT-1 (≈117M parameters)
    embed_dim = 768
    num_heads = 12
    ff_dim = 3072
    num_layers = 12
    dropout_rate = 0.1
    gpt2_encoding = tiktoken.get_encoding("gpt2")
    vocab_size = gpt2_encoding.n_vocab

    # -----------------------
    # Build the Model within the Strategy Scope.
    # -----------------------
    with strategy.scope():
        # Inputs are SEQUENCE_LENGTH - 1 long because parse_tfrecord shifts the
        # 512-token record into (input, target) pairs.
        model = create_transformer_model(vocab_size, SEQUENCE_LENGTH - 1,
                                         embed_dim, num_heads, ff_dim, num_layers, dropout_rate)
        lr_schedule = WarmUpCosineDecay(initial_lr, total_steps, warmup_steps, alpha=0.0)
        dynamic_wd = DynamicWeightDecay(initial_lr, base_wd=1e-4, lr_schedule=lr_schedule)
        # The optimizer's built-in decay is off; decay is applied manually in train_step.
        base_optimizer = optimizers.AdamW(
            learning_rate=lr_schedule,
            weight_decay=0.0,
            clipnorm=1.0
        )
        optimizer = LossScaleOptimizer(base_optimizer, dynamic=True)
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Print model summary.
    stream = io.StringIO()
    model.summary(print_fn=lambda x: stream.write(x + "\n"))
    print(stream.getvalue(), flush=True)

    # -----------------------
    # Checkpointing Setup.
    # -----------------------
    checkpoint_dir = './checkpoints_actual'
    os.makedirs(checkpoint_dir, exist_ok=True)
    global_epoch = tf.Variable(0, trainable=False, dtype=tf.int64)
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer, epoch=global_epoch)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)
    initial_global_epoch = 0
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print(f"Restored (actual) from {ckpt_manager.latest_checkpoint}", flush=True)
        initial_global_epoch = int(global_epoch.numpy())

    # -----------------------
    # Metrics & Global Step.
    # -----------------------
    train_loss_metric = tf.keras.metrics.Mean(name='train_loss')
    train_perplexity_metric = Perplexity(name='train_perplexity')
    val_loss_metric = tf.keras.metrics.Mean(name='val_loss')
    val_perplexity_metric = Perplexity(name='val_perplexity')
    # global_step is not stored in the checkpoint, so it is rebuilt from the
    # restored epoch counter (each completed epoch ran steps_per_epoch steps).
    # Starting from 0 after a resume would restart the manual weight-decay
    # schedule from warm-up.
    global_step = tf.Variable(initial_global_epoch * steps_per_epoch, trainable=False, dtype=tf.int64)

    # NOTE(review): train_step is called directly rather than through
    # strategy.run -- confirm the intended behaviour with more than one device.
    @tf.function
    def train_step(x_batch_train, y_batch_train, global_step):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
            scaled_loss = optimizer.get_scaled_loss(loss_value)
        scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
        grads = optimizer.get_unscaled_gradients(scaled_grads)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        global_step.assign_add(1)
        # Manual decoupled weight decay on everything except biases/norm scales.
        current_lr = lr_schedule(tf.cast(global_step, tf.float32))
        current_weight_decay = dynamic_wd(tf.cast(global_step, tf.float32))
        for var in model.trainable_variables:
            if should_apply_weight_decay(var):
                var.assign_sub(current_lr * current_weight_decay * var)
        return loss_value, logits

    # -----------------------
    # Lists to Store Metrics for Each Epoch
    # -----------------------
    train_losses = []
    val_losses = []
    train_perplexities = []
    val_perplexities = []

    # -----------------------
    # Main Training Loop (Actual Version)
    # -----------------------
    for sched_epoch in range(initial_global_epoch, total_scheduled_epochs):
        ds_name, epoch_in_ds, total_epochs_for_ds, current_round = schedule[sched_epoch]
        print(f"\n[Actual] Round {current_round} - Training dataset '{ds_name}', epoch {epoch_in_ds}/{total_epochs_for_ds}")
        info = dataset_info[ds_name]
        raw_dataset = tf.data.TFRecordDataset(info["path"])
        # Use 1% for validation.
        val_size = int(0.01 * info["count"])
        # The pipeline is rebuilt from the start of the file on every scheduled
        # epoch. Without an offset, each "epoch" of a dataset would re-read the
        # same leading steps_per_epoch * BATCH_SIZE sequences and never reach
        # the rest of the file. Skip what earlier epochs of this dataset used.
        train_offset = (epoch_in_ds - 1) * steps_per_epoch * BATCH_SIZE
        val_dataset = raw_dataset.take(val_size).map(
            parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
        # Skipping before map avoids parsing records that are thrown away.
        train_dataset = raw_dataset.skip(val_size + train_offset).map(
            parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.shuffle(10000, reshuffle_each_iteration=True)
        train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=True)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
        val_dataset = val_dataset.batch(BATCH_SIZE, drop_remainder=True)
        val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)

        train_loss_metric.reset_state()
        train_perplexity_metric.reset_state()
        progbar = Progbar(steps_per_epoch)
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset.take(steps_per_epoch)):
            loss_value, logits = train_step(x_batch_train, y_batch_train, global_step)
            train_loss_metric.update_state(loss_value)
            train_perplexity_metric.update_state(y_batch_train, logits)
            progbar.update(step + 1, values=[("loss", train_loss_metric.result().numpy()),
                                             ("perplexity", train_perplexity_metric.result().numpy())])

        # Validation loop.
        val_loss_metric.reset_state()
        val_perplexity_metric.reset_state()
        for x_batch_val, y_batch_val in val_dataset:
            val_logits = model(x_batch_val, training=False)
            val_loss = loss_fn(y_batch_val, val_logits)
            val_loss_metric.update_state(val_loss)
            val_perplexity_metric.update_state(y_batch_val, val_logits)

        current_train_loss = train_loss_metric.result().numpy()
        current_val_loss = val_loss_metric.result().numpy()
        current_train_perplexity = train_perplexity_metric.result().numpy()
        current_val_perplexity = val_perplexity_metric.result().numpy()

        train_losses.append(current_train_loss)
        val_losses.append(current_val_loss)
        train_perplexities.append(current_train_perplexity)
        val_perplexities.append(current_val_perplexity)

        print(f"[Actual] Epoch {sched_epoch + 1}/{total_scheduled_epochs}: Train Loss = {current_train_loss:.4f}, Train Perplexity = {current_train_perplexity:.4f}")
        print(f"[Actual] Epoch {sched_epoch + 1}/{total_scheduled_epochs}: Val Loss = {current_val_loss:.4f}, Val Perplexity = {current_val_perplexity:.4f}")

        # Persist progress so a restart resumes at the next scheduled epoch.
        global_epoch.assign(sched_epoch + 1)
        saved_path = ckpt_manager.save()
        print(f"[Actual] Checkpoint saved at: {saved_path}", flush=True)

    # -----------------------
    # Plotting results for the epochs run in this session.
    # -----------------------
    # The metric lists only hold epochs executed in this run. After a resume they
    # are shorter than total_scheduled_epochs, so the x-axis must start at the
    # first epoch run here and match the list length, or plt.plot raises.
    first_epoch_run = initial_global_epoch + 1
    epochs_range = range(first_epoch_run, first_epoch_run + len(train_losses))

    plt.figure(figsize=(10, 5))
    plt.plot(epochs_range, train_losses, marker='o', label='Training Loss')
    plt.plot(epochs_range, val_losses, marker='o', label='Validation Loss')
    plt.title('Actual: Training vs. Validation Loss Over All Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.plot(epochs_range, train_perplexities, marker='o', label='Training Perplexity')
    plt.plot(epochs_range, val_perplexities, marker='o', label='Validation Perplexity')
    plt.title('Actual: Training vs. Validation Perplexity Over All Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Perplexity')
    plt.legend()
    plt.grid(True)
    plt.show()

if __name__ == '__main__':
    main()


In [None]:
import os
import tensorflow as tf
import io
from tensorflow.keras import layers, models, initializers
import tiktoken

# --- Custom Layers and Model Architecture Definitions ---
class RMSNorm(layers.Layer):
    """Root-mean-square normalisation with a learned per-feature gain.

    Computes ``inputs * gamma / sqrt(mean(inputs**2, axis=-1) + epsilon)``.

    Args:
        epsilon: Constant added to the mean square before the square root to
            avoid division by zero.
    """
    def __init__(self, epsilon=1e-8, **kwargs):
        super(RMSNorm, self).__init__(**kwargs)
        self.epsilon = epsilon
    def build(self, input_shape):
        # One gain per feature of the last axis, starting at 1 (identity scale).
        self.gamma = self.add_weight(name="gamma",
                                     shape=input_shape[-1:],
                                     initializer="ones",
                                     trainable=True)
        super(RMSNorm, self).build(input_shape)
    def call(self, inputs):
        # Half-precision inputs are reduced in float32. In float16 the default
        # epsilon (1e-8) rounds to zero, so an all-zero row would divide by
        # zero, and squaring large activations can overflow.
        low_precision = inputs.dtype in (tf.float16, tf.bfloat16)
        stats_dtype = tf.float32 if low_precision else inputs.dtype
        x = tf.cast(inputs, stats_dtype)
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        rms = tf.sqrt(mean_square + tf.cast(self.epsilon, stats_dtype))
        gamma = tf.cast(self.gamma, stats_dtype)
        # Return in the caller's dtype so downstream layers are unaffected.
        return tf.cast(x * gamma / rms, inputs.dtype)

class TiedDense(layers.Layer):
    """Projection that reuses another layer's ``embeddings`` matrix.

    Inputs are multiplied by the transpose of ``tied_to.embeddings``; this
    layer creates no weights of its own.

    Args:
        tied_to: Layer object exposing an ``embeddings`` attribute.
    """
    def __init__(self, tied_to, **kwargs):
        super().__init__(**kwargs)
        self.tied_to = tied_to
    def call(self, inputs):
        # Bring the shared matrix to the activation dtype before the matmul.
        shared_matrix = tf.cast(self.tied_to.embeddings, inputs.dtype)
        return tf.matmul(inputs, shared_matrix, transpose_b=True)
    def get_config(self):
        # NOTE(review): only the tied layer's *name* is stored here, while
        # __init__ expects the layer object -- confirm how a config round-trip
        # is supposed to resolve it.
        base_config = super().get_config()
        base_config["tied_to"] = self.tied_to.name
        return base_config

class RotarySelfAttention(layers.Layer):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout_rate: Dropout rate applied to the attention weights.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.dropout_rate = dropout_rate
    def build(self, input_shape):
        # All four projections share one initializer object; they are created
        # in query/key/value/output order.
        robust_init = initializers.RandomNormal(mean=0.0, stddev=0.02)
        def make_projection():
            return layers.Dense(self.embed_dim, kernel_initializer=robust_init)
        self.query_dense = make_projection()
        self.key_dense   = make_projection()
        self.value_dense = make_projection()
        self.out_dense   = make_projection()
        self.dropout     = layers.Dropout(self.dropout_rate)
        super().build(input_shape)
    def call(self, inputs, training=False, use_causal_mask=True):
        dims = tf.shape(inputs)
        n_batch, n_tokens = dims[0], dims[1]
        per_head_shape = (n_batch, n_tokens, self.num_heads, self.head_dim)

        # Project, then split the feature axis into (heads, head_dim).
        query = tf.reshape(self.query_dense(inputs), per_head_shape)
        key   = tf.reshape(self.key_dense(inputs), per_head_shape)
        value = tf.reshape(self.value_dense(inputs), per_head_shape)

        # Rotation angles: outer product of positions and inverse frequencies.
        positions = tf.cast(tf.range(n_tokens), tf.float32)
        exponents = tf.cast(tf.range(0, self.head_dim, 2), tf.float32) / tf.cast(self.head_dim, tf.float32)
        inv_freq = 1.0 / (10000 ** exponents)
        angles = tf.tensordot(positions, inv_freq, axes=0)
        # Insert axes 0 and 2 so the tables broadcast over batch and heads.
        sin_table = tf.cast(tf.sin(angles), query.dtype)[tf.newaxis, :, tf.newaxis]
        cos_table = tf.cast(tf.cos(angles), query.dtype)[tf.newaxis, :, tf.newaxis]

        def rotate_pairs(t):
            # View the last axis as (pairs, 2) and rotate each pair.
            width = tf.shape(t)[-1]
            pairs = tf.reshape(t, tf.concat([tf.shape(t)[:-1], [width // 2, 2]], axis=0))
            first, second = pairs[..., 0], pairs[..., 1]
            rotated = tf.stack([first * cos_table - second * sin_table,
                                first * sin_table + second * cos_table], axis=-1)
            return tf.reshape(rotated, tf.concat([tf.shape(rotated)[:-2], [width]], axis=0))

        query = rotate_pairs(query)
        key   = rotate_pairs(key)

        # Move heads in front of the token axis for batched matmuls.
        heads_first = [0, 2, 1, 3]
        query = tf.transpose(query, perm=heads_first)
        key   = tf.transpose(key, perm=heads_first)
        value = tf.transpose(value, perm=heads_first)

        query = query * tf.cast(self.head_dim, query.dtype) ** -0.5
        scores = tf.matmul(query, key, transpose_b=True)
        if use_causal_mask:
            # Lower-triangular ones: each position sees itself and earlier ones;
            # the remaining entries are replaced by -1e4.
            allowed = tf.linalg.band_part(tf.ones((n_tokens, n_tokens), dtype=query.dtype), -1, 0)
            allowed = tf.reshape(allowed, (1, 1, n_tokens, n_tokens))
            scores = scores * allowed + tf.cast(-1e4, scores.dtype) * (1 - allowed)
        weights = tf.nn.softmax(scores, axis=-1)
        weights = self.dropout(weights, training=training)

        # Merge heads back into a single feature axis and project out.
        context = tf.transpose(tf.matmul(weights, value), perm=[0, 2, 1, 3])
        context = tf.reshape(context, (n_batch, n_tokens, self.embed_dim))
        return self.out_dense(context)

class TransformerBlock(layers.Layer):
    """Causal self-attention followed by a GELU feed-forward network.

    Each sub-layer output goes through dropout, is added to its input, and the
    sum is RMS-normalised.

    Args:
        embed_dim: Feature width of the block input and output.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout_rate: Dropout rate used inside attention and on both branches.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.attention = RotarySelfAttention(embed_dim, num_heads, dropout_rate)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.norm1 = RMSNorm(epsilon=1e-8)
        expand = layers.Dense(ff_dim, activation=tf.nn.gelu,
                              kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.02))
        contract = layers.Dense(embed_dim,
                                kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.02))
        self.ffn = models.Sequential([expand, contract])
        self.dropout2 = layers.Dropout(dropout_rate)
        self.norm2 = RMSNorm(epsilon=1e-8)
    def call(self, inputs, training=False):
        # Attention branch: residual add, then normalise.
        attended = self.attention(inputs, training=training, use_causal_mask=True)
        attended = self.dropout1(attended, training=training)
        hidden = self.norm1(inputs + attended)
        # Feed-forward branch: residual add, then normalise.
        expanded = self.dropout2(self.ffn(hidden), training=training)
        return self.norm2(hidden + expanded)

def create_transformer_model(vocab_size, sequence_length, embed_dim, num_heads, ff_dim, num_layers, dropout_rate=0.1):
    """Build the decoder-style language model as a Keras functional model.

    Token ids are embedded, passed through ``num_layers`` transformer blocks
    and a final RMSNorm, then projected to vocabulary logits with the
    embedding matrix (weight tying). The logits are cast to float32.

    Args:
        vocab_size: Number of rows in the embedding table / logits per token.
        sequence_length: Length of the int32 token-id input.
        embed_dim: Embedding and hidden width.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward hidden width per block.
        num_layers: Number of transformer blocks.
        dropout_rate: Dropout rate forwarded to every block.

    Returns:
        A ``models.Model`` mapping token ids to float32 logits.
    """
    token_ids = layers.Input(shape=(sequence_length,), dtype=tf.int32)
    token_embedding = layers.Embedding(
        input_dim=vocab_size,
        output_dim=embed_dim,
        embeddings_initializer=initializers.RandomNormal(mean=0.0, stddev=0.02),
        name="token_embedding",
    )
    hidden = token_embedding(token_ids)
    for block_index in range(num_layers):
        block = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate,
                                 name=f"transformer_block_{block_index}")
        hidden = block(hidden)
    hidden = RMSNorm(epsilon=1e-8, name="final_rmsnorm")(hidden)
    # The output projection shares the embedding matrix.
    logits = TiedDense(token_embedding, name="output_projection")(hidden)
    logits = layers.Lambda(lambda x: tf.cast(x, tf.float32))(logits)
    return models.Model(inputs=token_ids, outputs=logits)

# --- Model Hyperparameters (matching GPT-1 style, ~117M parameters) ---
SEQUENCE_LENGTH = 512
embed_dim = 768
num_heads = 12
ff_dim = 3072
num_layers = 12
dropout_rate = 0.1

# Load GPT-2 encoding to get vocab size
gpt2_encoding = tiktoken.get_encoding("gpt2")
vocab_size = gpt2_encoding.n_vocab

# Build the model
model = create_transformer_model(vocab_size, SEQUENCE_LENGTH - 1,
                                 embed_dim, num_heads, ff_dim, num_layers, dropout_rate)

# --- Restore Weights from the Checkpoints ---
checkpoint_dir = './checkpoints_actual'
global_epoch = tf.Variable(0, trainable=False, dtype=tf.int64)
ckpt = tf.train.Checkpoint(model=model, epoch=global_epoch)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)

export_path = "./exported_model"
latest_checkpoint = ckpt_manager.latest_checkpoint
if latest_checkpoint:
    # Only `model` and `epoch` are tracked here. expect_partial() keeps the
    # restore quiet if the checkpoint also holds other objects (for example
    # optimizer state written by the training cell -- not visible from here).
    ckpt.restore(latest_checkpoint).expect_partial()
    print("Checkpoint restored from:", latest_checkpoint)

    # --- Save the Entire Model to a File ---
    model.save(export_path)
    print("Model saved to", export_path)
else:
    # Without a checkpoint the model still has its random initial weights;
    # exporting it would silently produce an untrained model, so skip it.
    print("No checkpoint found!")
    print("Export skipped: nothing was restored, so", export_path, "was not written.")


In [None]:
import os
import tensorflow as tf
import numpy as np
import tiktoken

# --- Re-define Custom Layers (for model loading) ---
class RMSNorm(tf.keras.layers.Layer):
    """Root-mean-square normalisation with a learned per-feature gain.

    Computes ``inputs * gamma / sqrt(mean(inputs**2, axis=-1) + epsilon)``.

    Args:
        epsilon: Constant added to the mean square before the square root to
            avoid division by zero.
    """
    def __init__(self, epsilon=1e-8, **kwargs):
        super(RMSNorm, self).__init__(**kwargs)
        self.epsilon = epsilon
    def build(self, input_shape):
        # One gain per feature of the last axis, starting at 1 (identity scale).
        self.gamma = self.add_weight(name="gamma",
                                     shape=input_shape[-1:],
                                     initializer="ones",
                                     trainable=True)
        super(RMSNorm, self).build(input_shape)
    def call(self, inputs):
        # Half-precision inputs are reduced in float32. In float16 the default
        # epsilon (1e-8) rounds to zero, so an all-zero row would divide by
        # zero, and squaring large activations can overflow.
        low_precision = inputs.dtype in (tf.float16, tf.bfloat16)
        stats_dtype = tf.float32 if low_precision else inputs.dtype
        x = tf.cast(inputs, stats_dtype)
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        rms = tf.sqrt(mean_square + tf.cast(self.epsilon, stats_dtype))
        gamma = tf.cast(self.gamma, stats_dtype)
        # Return in the caller's dtype so downstream layers are unaffected.
        return tf.cast(x * gamma / rms, inputs.dtype)

class TiedDense(tf.keras.layers.Layer):
    """Projection that reuses another layer's ``embeddings`` matrix.

    Inputs are multiplied by the transpose of ``tied_to.embeddings``; this
    layer creates no weights of its own.

    Args:
        tied_to: Layer object exposing an ``embeddings`` attribute.
    """
    def __init__(self, tied_to, **kwargs):
        super().__init__(**kwargs)
        self.tied_to = tied_to
    def call(self, inputs):
        # Bring the shared matrix to the activation dtype before the matmul.
        shared_matrix = tf.cast(self.tied_to.embeddings, inputs.dtype)
        return tf.matmul(inputs, shared_matrix, transpose_b=True)
    def get_config(self):
        # NOTE(review): only the tied layer's *name* is stored here, while
        # __init__ expects the layer object -- confirm how a config round-trip
        # is supposed to resolve it.
        base_config = super().get_config()
        base_config["tied_to"] = self.tied_to.name
        return base_config

class RotarySelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout_rate: Dropout rate applied to the attention weights.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.dropout_rate = dropout_rate
    def build(self, input_shape):
        # All four projections share one initializer object; they are created
        # in query/key/value/output order.
        robust_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        def make_projection():
            return tf.keras.layers.Dense(self.embed_dim, kernel_initializer=robust_init)
        self.query_dense = make_projection()
        self.key_dense   = make_projection()
        self.value_dense = make_projection()
        self.out_dense   = make_projection()
        self.dropout     = tf.keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)
    def call(self, inputs, training=False, use_causal_mask=True):
        dims = tf.shape(inputs)
        n_batch, n_tokens = dims[0], dims[1]
        per_head_shape = (n_batch, n_tokens, self.num_heads, self.head_dim)

        # Project, then split the feature axis into (heads, head_dim).
        query = tf.reshape(self.query_dense(inputs), per_head_shape)
        key   = tf.reshape(self.key_dense(inputs), per_head_shape)
        value = tf.reshape(self.value_dense(inputs), per_head_shape)

        # Rotation angles: outer product of positions and inverse frequencies.
        positions = tf.cast(tf.range(n_tokens), tf.float32)
        exponents = tf.cast(tf.range(0, self.head_dim, 2), tf.float32) / tf.cast(self.head_dim, tf.float32)
        inv_freq = 1.0 / (10000 ** exponents)
        angles = tf.tensordot(positions, inv_freq, axes=0)
        # Insert axes 0 and 2 so the tables broadcast over batch and heads.
        sin_table = tf.cast(tf.sin(angles), query.dtype)[tf.newaxis, :, tf.newaxis]
        cos_table = tf.cast(tf.cos(angles), query.dtype)[tf.newaxis, :, tf.newaxis]

        def rotate_pairs(t):
            # View the last axis as (pairs, 2) and rotate each pair.
            width = tf.shape(t)[-1]
            pairs = tf.reshape(t, tf.concat([tf.shape(t)[:-1], [width // 2, 2]], axis=0))
            first, second = pairs[..., 0], pairs[..., 1]
            rotated = tf.stack([first * cos_table - second * sin_table,
                                first * sin_table + second * cos_table], axis=-1)
            return tf.reshape(rotated, tf.concat([tf.shape(rotated)[:-2], [width]], axis=0))

        query = rotate_pairs(query)
        key   = rotate_pairs(key)

        # Move heads in front of the token axis for batched matmuls.
        heads_first = [0, 2, 1, 3]
        query = tf.transpose(query, perm=heads_first)
        key   = tf.transpose(key, perm=heads_first)
        value = tf.transpose(value, perm=heads_first)

        query = query * tf.cast(self.head_dim, query.dtype) ** -0.5
        scores = tf.matmul(query, key, transpose_b=True)
        if use_causal_mask:
            # Lower-triangular ones: each position sees itself and earlier ones;
            # the remaining entries are replaced by -1e4.
            allowed = tf.linalg.band_part(tf.ones((n_tokens, n_tokens), dtype=query.dtype), -1, 0)
            allowed = tf.reshape(allowed, (1, 1, n_tokens, n_tokens))
            scores = scores * allowed + tf.cast(-1e4, scores.dtype) * (1 - allowed)
        weights = tf.nn.softmax(scores, axis=-1)
        weights = self.dropout(weights, training=training)

        # Merge heads back into a single feature axis and project out.
        context = tf.transpose(tf.matmul(weights, value), perm=[0, 2, 1, 3])
        context = tf.reshape(context, (n_batch, n_tokens, self.embed_dim))
        return self.out_dense(context)

class TransformerBlock(tf.keras.layers.Layer):
    """Causal self-attention followed by a GELU feed-forward network.

    Each sub-layer output goes through dropout, is added to its input, and the
    sum is RMS-normalised.

    Args:
        embed_dim: Feature width of the block input and output.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout_rate: Dropout rate used inside attention and on both branches.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.attention = RotarySelfAttention(embed_dim, num_heads, dropout_rate)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.norm1 = RMSNorm(epsilon=1e-8)
        expand = tf.keras.layers.Dense(
            ff_dim, activation=tf.nn.gelu,
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
        contract = tf.keras.layers.Dense(
            embed_dim,
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
        self.ffn = tf.keras.Sequential([expand, contract])
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.norm2 = RMSNorm(epsilon=1e-8)
    def call(self, inputs, training=False):
        # Attention branch: residual add, then normalise.
        attended = self.attention(inputs, training=training, use_causal_mask=True)
        attended = self.dropout1(attended, training=training)
        hidden = self.norm1(inputs + attended)
        # Feed-forward branch: residual add, then normalise.
        expanded = self.dropout2(self.ffn(hidden), training=training)
        return self.norm2(hidden + expanded)

def create_transformer_model(vocab_size, sequence_length, embed_dim, num_heads, ff_dim, num_layers, dropout_rate=0.1):
    """Build the decoder-style language model as a Keras functional model.

    Token ids are embedded, passed through ``num_layers`` transformer blocks
    and a final RMSNorm, then projected to vocabulary logits with the
    embedding matrix (weight tying). The logits are cast to float32.

    Args:
        vocab_size: Number of rows in the embedding table / logits per token.
        sequence_length: Length of the int32 token-id input.
        embed_dim: Embedding and hidden width.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward hidden width per block.
        num_layers: Number of transformer blocks.
        dropout_rate: Dropout rate forwarded to every block.

    Returns:
        A ``tf.keras.models.Model`` mapping token ids to float32 logits.
    """
    token_ids = tf.keras.layers.Input(shape=(sequence_length,), dtype=tf.int32)
    token_embedding = tf.keras.layers.Embedding(
        input_dim=vocab_size,
        output_dim=embed_dim,
        embeddings_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
        name="token_embedding",
    )
    hidden = token_embedding(token_ids)
    for block_index in range(num_layers):
        block = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate,
                                 name=f"transformer_block_{block_index}")
        hidden = block(hidden)
    hidden = RMSNorm(epsilon=1e-8, name="final_rmsnorm")(hidden)
    # The output projection shares the embedding matrix.
    logits = TiedDense(token_embedding, name="output_projection")(hidden)
    logits = tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32))(logits)
    return tf.keras.models.Model(inputs=token_ids, outputs=logits)

# --- Hyperparameters (should match those used during saving) ---
SEQUENCE_LENGTH = 512
embed_dim = 768
num_heads = 12
ff_dim = 3072
num_layers = 12
dropout_rate = 0.1
gpt2_encoding = tiktoken.get_encoding("gpt2")
vocab_size = gpt2_encoding.n_vocab

# --- Load the Saved Model ---
# The custom layer classes are registered under their own class names so the
# loader can resolve them.
export_path = "./exported_model"
model = tf.keras.models.load_model(
    export_path,
    custom_objects={
        layer_cls.__name__: layer_cls
        for layer_cls in (RMSNorm, TiedDense, RotarySelfAttention, TransformerBlock)
    },
)
print("Model loaded from", export_path)

# --- Text Generation Function ---
def generate_text(model, prompt, num_tokens=100, temperature=1.0):
    """Extend ``prompt`` by sampling ``num_tokens`` tokens from ``model``.

    Args:
        model: Callable returning logits of shape (batch, tokens, vocab) for a
            batch of token ids.
        prompt: Text to continue; encoded with the tiktoken "gpt2" encoding.
        num_tokens: Number of tokens to append.
        temperature: Divisor applied to the last-position logits before the
            softmax. ``0`` selects the most likely token (greedy decoding)
            instead of dividing by zero.

    Returns:
        The decoded prompt followed by the generated tokens.

    Raises:
        ValueError: If ``prompt`` encodes to no tokens.
    """
    # Get the tokenizer encoding from tiktoken (GPT-2 encoding)
    encoding = tiktoken.get_encoding("gpt2")
    input_ids = encoding.encode(prompt)
    if not input_ids:
        raise ValueError("prompt must encode to at least one token")

    # Longest context the model is fed; older tokens are dropped.
    max_context = SEQUENCE_LENGTH - 1

    for _ in range(num_tokens):
        input_tensor = tf.convert_to_tensor([input_ids[-max_context:]])
        logits = model(input_tensor, training=False)
        # Logits for the last position, promoted to float64 so the softmax
        # sums to 1 within the tolerance np.random.choice enforces on `p`.
        last_logits = tf.cast(logits[0, -1, :], tf.float64)
        if temperature == 0:
            next_token = int(tf.argmax(last_logits).numpy())
        else:
            probs = tf.nn.softmax(last_logits / temperature, axis=-1).numpy()
            probs = probs / probs.sum()
            # Sample from the distribution
            next_token = int(np.random.choice(len(probs), p=probs))
        input_ids.append(next_token)
    return encoding.decode(input_ids)

# --- Generate and Print the Output ---
# Sample a 100-token continuation of a fixed prompt and show it.
prompt = "The world seemed like such a peaceful place until the magic tree was discovered in London."
generated_text = generate_text(
    model,
    prompt,
    num_tokens=100,
    temperature=1.0,
)
print("Generated Text:\n", generated_text)


=======================

In [None]:
import os
# Suppress INFO-level messages (set to "2" to show warnings and errors only)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import io
import math
import time
import numpy as np
import tensorflow as tf
import tiktoken
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models, initializers, optimizers
from tensorflow.keras.mixed_precision import LossScaleOptimizer
from tensorflow.keras.utils import Progbar

# -----------------------
# Utility: Download TFRecord Files from URLs
# -----------------------
def download_tfrecord_files(url_dict):
    """Fetch every file in ``url_dict`` and report where each one landed.

    Args:
        url_dict: Mapping of file name -> URL; each pair is passed to
            ``tf.keras.utils.get_file``.

    Returns:
        Dict mapping each file name to the local path returned by ``get_file``.

    Raises:
        Exception: Any download error is logged and then re-raised.
    """
    logger = tf.get_logger()
    downloaded = {}
    for fname, url in url_dict.items():
        try:
            local_path = tf.keras.utils.get_file(fname, url)
            downloaded[fname] = local_path
            logger.info(f"Downloaded {fname} to {local_path}")
        except Exception as exc:
            # Log which file failed, then propagate the error.
            logger.error(f"Error downloading {fname} from {url}: {exc}")
            raise exc
    return downloaded

# -----------------------
# URLs for the TFRecord files.
# -----------------------
# Maps a dataset name to its TFRecord download URL. The keys must match the
# keys of `dataset_info`, which looks paths up by the same names after download.
# Entries that are commented out are not downloaded.
tfrecord_urls = {
    "books_part1": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_1_sequenced_shuffled_books.tfrecord?download=true",
    # "books_part2": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_2_sequenced_shuffled_books.tfrecord?download=true",
    # "conversations_part3": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_3_sequenced_shuffled_conversations.tfrecord?download=true",
    # "conversations_part4": "https://huggingface.co/datasets/tonadeleon/books_and_conversations/resolve/main/512_4_sequenced_shuffled_conversations.tfrecord?download=true",
}

# -----------------------
# TFRecord Parsing Function
# -----------------------
def parse_tfrecord(example_proto):
    """Decode one serialized example into an (inputs, targets) pair.

    The record holds a fixed-length ``tokens`` feature of 512 int64 values.
    Inputs are all tokens but the last; targets are all tokens but the first,
    i.e. the inputs shifted by one position.
    """
    schema = {'tokens': tf.io.FixedLenFeature([512], tf.int64)}
    tokens = tf.io.parse_single_example(example_proto, schema)['tokens']
    inputs, targets = tokens[:-1], tokens[1:]
    return inputs, targets

# -----------------------
# Pre-computed Dataset Information
# -----------------------

# 📊 **Dataset Sequence Counts**
# books_part1: 1089431 sequences
# books_part2: 1276920 sequences
# conversations_part3: 1139206 sequences
# conversations_part4: 616193 sequences

# 📌 **Recommended Epochs for Each Dataset**
# books_part1: 11 epochs
# books_part2: 13 epochs
# conversations_part3: 11 epochs
# conversations_part4: 6 epochs

# Per-dataset metadata; "path" is filled in below once the files are downloaded.
dataset_info = {
    "books_part1": {"path": None, "count": 1089431, "epochs": 11},
    # "books_part2": {"path": None, "count": 1276920, "epochs": 13},
    # "conversations_part3": {"path": None, "count": 1139206, "epochs": 11},
    # "conversations_part4": {"path": None, "count": 616193, "epochs": 6},
}

# Download files and update dataset_info with file paths.
local_tfrecord_paths = download_tfrecord_files(tfrecord_urls)
for ds_name, ds_entry in dataset_info.items():
    ds_entry["path"] = local_tfrecord_paths[ds_name]

# -----------------------
# Transformer Model Components & Helper Functions
# -----------------------
class RMSNorm(layers.Layer):
    """Root-mean-square normalisation with a learned per-feature gain.

    Computes ``inputs * gamma / sqrt(mean(inputs**2, axis=-1) + epsilon)``.

    Args:
        epsilon: Constant added to the mean square before the square root to
            avoid division by zero.
    """
    def __init__(self, epsilon=1e-8, **kwargs):
        super(RMSNorm, self).__init__(**kwargs)
        self.epsilon = epsilon
    def build(self, input_shape):
        # One gain per feature of the last axis, starting at 1 (identity scale).
        self.gamma = self.add_weight(name="gamma",
                                     shape=input_shape[-1:],
                                     initializer="ones",
                                     trainable=True)
        super(RMSNorm, self).build(input_shape)
    def call(self, inputs):
        # Half-precision inputs are reduced in float32. In float16 the default
        # epsilon (1e-8) rounds to zero, so an all-zero row would divide by
        # zero, and squaring large activations can overflow.
        low_precision = inputs.dtype in (tf.float16, tf.bfloat16)
        stats_dtype = tf.float32 if low_precision else inputs.dtype
        x = tf.cast(inputs, stats_dtype)
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        rms = tf.sqrt(mean_square + tf.cast(self.epsilon, stats_dtype))
        gamma = tf.cast(self.gamma, stats_dtype)
        # Return in the caller's dtype so downstream layers are unaffected.
        return tf.cast(x * gamma / rms, inputs.dtype)

def apply_rope(x, sin, cos):
    """Rotate consecutive feature pairs of ``x`` by the given sine/cosine tables.

    The last axis of ``x`` is viewed as pairs ``(a, b)``; each pair becomes
    ``(a*cos - b*sin, a*sin + b*cos)`` and the pairs are flattened back to the
    original last-axis width. ``sin`` and ``cos`` are cast to ``x.dtype`` and
    get new axes at positions 0 and 2 for broadcasting.
    """
    sin_table = tf.cast(sin, x.dtype)[tf.newaxis, :, tf.newaxis]
    cos_table = tf.cast(cos, x.dtype)[tf.newaxis, :, tf.newaxis]
    width = tf.shape(x)[-1]
    paired_shape = tf.concat([tf.shape(x)[:-1], [width // 2, 2]], axis=0)
    pairs = tf.reshape(x, paired_shape)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = tf.stack(
        [first * cos_table - second * sin_table,
         first * sin_table + second * cos_table],
        axis=-1,
    )
    flat_shape = tf.concat([tf.shape(rotated)[:-2], [width]], axis=0)
    return tf.reshape(rotated, flat_shape)

class TiedDense(layers.Layer):
    """Projection that reuses another layer's ``embeddings`` matrix.

    Inputs are multiplied by the transpose of ``tied_to.embeddings``; this
    layer creates no weights of its own.

    Args:
        tied_to: Layer object exposing an ``embeddings`` attribute.
    """
    def __init__(self, tied_to, **kwargs):
        super().__init__(**kwargs)
        self.tied_to = tied_to
    def call(self, inputs):
        # Bring the shared matrix to the activation dtype before the matmul.
        shared_matrix = tf.cast(self.tied_to.embeddings, inputs.dtype)
        return tf.matmul(inputs, shared_matrix, transpose_b=True)
    def get_config(self):
        # NOTE(review): only the tied layer's *name* is stored here, while
        # __init__ expects the layer object -- confirm how a config round-trip
        # is supposed to resolve it.
        base_config = super().get_config()
        base_config["tied_to"] = self.tied_to.name
        return base_config

class RotarySelfAttention(layers.Layer):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout_rate: Dropout rate applied to the attention weights.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.dropout_rate = dropout_rate
    def build(self, input_shape):
        # All four projections share one initializer object; they are created
        # in query/key/value/output order.
        robust_init = initializers.RandomNormal(mean=0.0, stddev=0.02)
        def make_projection():
            return layers.Dense(self.embed_dim, kernel_initializer=robust_init)
        self.query_dense = make_projection()
        self.key_dense   = make_projection()
        self.value_dense = make_projection()
        self.out_dense   = make_projection()
        self.dropout     = layers.Dropout(self.dropout_rate)
        super().build(input_shape)
    def call(self, inputs, training=False, use_causal_mask=True):
        dims = tf.shape(inputs)
        n_batch, n_tokens = dims[0], dims[1]
        per_head_shape = (n_batch, n_tokens, self.num_heads, self.head_dim)

        # Project, then split the feature axis into (heads, head_dim).
        query = tf.reshape(self.query_dense(inputs), per_head_shape)
        key   = tf.reshape(self.key_dense(inputs), per_head_shape)
        value = tf.reshape(self.value_dense(inputs), per_head_shape)

        # Rotation angles: outer product of positions and inverse frequencies.
        positions = tf.cast(tf.range(n_tokens), tf.float32)
        exponents = tf.cast(tf.range(0, self.head_dim, 2), tf.float32) / tf.cast(self.head_dim, tf.float32)
        inv_freq = 1.0 / (10000 ** exponents)
        angles = tf.tensordot(positions, inv_freq, axes=0)
        sin, cos = tf.sin(angles), tf.cos(angles)
        # Only queries and keys are rotated; values are left as projected.
        query = apply_rope(query, sin, cos)
        key   = apply_rope(key, sin, cos)

        # Move heads in front of the token axis for batched matmuls.
        heads_first = [0, 2, 1, 3]
        query = tf.transpose(query, perm=heads_first)
        key   = tf.transpose(key, perm=heads_first)
        value = tf.transpose(value, perm=heads_first)

        query = query * tf.cast(self.head_dim, query.dtype) ** -0.5
        scores = tf.matmul(query, key, transpose_b=True)
        if use_causal_mask:
            # Lower-triangular ones: each position sees itself and earlier ones;
            # the remaining entries are replaced by -1e4.
            allowed = tf.linalg.band_part(tf.ones((n_tokens, n_tokens), dtype=query.dtype), -1, 0)
            allowed = tf.reshape(allowed, (1, 1, n_tokens, n_tokens))
            scores = scores * allowed + tf.cast(-1e4, scores.dtype) * (1 - allowed)
        weights = tf.nn.softmax(scores, axis=-1)
        weights = self.dropout(weights, training=training)

        # Merge heads back into a single feature axis and project out.
        context = tf.transpose(tf.matmul(weights, value), perm=[0, 2, 1, 3])
        context = tf.reshape(context, (n_batch, n_tokens, self.embed_dim))
        return self.out_dense(context)

class TransformerBlock(layers.Layer):
    """Causal self-attention followed by a GELU feed-forward network.

    Each sub-layer output goes through dropout, is added to its input, and the
    sum is RMS-normalised.

    Args:
        embed_dim: Feature width of the block input and output.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout_rate: Dropout rate used inside attention and on both branches.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.attention = RotarySelfAttention(embed_dim, num_heads, dropout_rate)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.norm1 = RMSNorm(epsilon=1e-8)
        expand = layers.Dense(ff_dim, activation=tf.nn.gelu,
                              kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.02))
        contract = layers.Dense(embed_dim,
                                kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.02))
        self.ffn = models.Sequential([expand, contract])
        self.dropout2 = layers.Dropout(dropout_rate)
        self.norm2 = RMSNorm(epsilon=1e-8)
    def call(self, inputs, training=False):
        # Attention branch: residual add, then normalise.
        attended = self.attention(inputs, training=training, use_causal_mask=True)
        attended = self.dropout1(attended, training=training)
        hidden = self.norm1(inputs + attended)
        # Feed-forward branch: residual add, then normalise.
        expanded = self.dropout2(self.ffn(hidden), training=training)
        return self.norm2(hidden + expanded)

def create_transformer_model(vocab_size, sequence_length, embed_dim, num_heads, ff_dim, num_layers, dropout_rate=0.1):
    """Build the decoder-style language model as a Keras functional model.

    Token ids are embedded, passed through ``num_layers`` transformer blocks
    and a final RMSNorm, then projected to vocabulary logits with the
    embedding matrix (weight tying). The logits are cast to float32.

    Args:
        vocab_size: Number of rows in the embedding table / logits per token.
        sequence_length: Length of the int32 token-id input.
        embed_dim: Embedding and hidden width.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward hidden width per block.
        num_layers: Number of transformer blocks.
        dropout_rate: Dropout rate forwarded to every block.

    Returns:
        A ``models.Model`` mapping token ids to float32 logits.
    """
    token_ids = layers.Input(shape=(sequence_length,), dtype=tf.int32)
    token_embedding = layers.Embedding(
        input_dim=vocab_size,
        output_dim=embed_dim,
        embeddings_initializer=initializers.RandomNormal(mean=0.0, stddev=0.02),
        name="token_embedding",
    )
    hidden = token_embedding(token_ids)
    for block_index in range(num_layers):
        block = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate,
                                 name=f"transformer_block_{block_index}")
        hidden = block(hidden)
    hidden = RMSNorm(epsilon=1e-8, name="final_rmsnorm")(hidden)
    # The output projection shares the embedding matrix.
    logits = TiedDense(token_embedding, name="output_projection")(hidden)
    logits = layers.Lambda(lambda x: tf.cast(x, tf.float32))(logits)
    return models.Model(inputs=token_ids, outputs=logits)

# -----------------------
# Custom Perplexity Metric
# -----------------------
class Perplexity(tf.keras.metrics.Metric):
    """Perplexity metric: exponential of the running mean cross-entropy.

    ``y_pred`` is treated as logits (``from_logits=True``) and ``y_true`` as
    sparse integer labels.
    """
    def __init__(self, name='perplexity', **kwargs):
        super().__init__(name=name, **kwargs)
        # Running mean of the sparse categorical cross-entropy values.
        self.ce_tracker = tf.keras.metrics.Mean(name="crossentropy_mean", dtype=tf.float32)
    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the cross-entropy of one batch into the running mean."""
        batch_ce = tf.keras.losses.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True)
        self.ce_tracker.update_state(batch_ce, sample_weight=sample_weight)
    def result(self):
        """Return exp(mean cross-entropy accumulated so far)."""
        return tf.exp(self.ce_tracker.result())
    def reset_state(self):
        """Clear the accumulated cross-entropy."""
        self.ce_tracker.reset_state()

# -----------------------
# Learning Rate and Weight Decay Schedules
# -----------------------
class WarmUpCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up to ``initial_lr`` followed by cosine decay.

    Args:
        initial_lr: Peak learning rate, reached at the end of warm-up.
        total_steps: Step at which the cosine decay reaches its floor.
        warmup_steps: Number of linear warm-up steps.
        alpha: Floor of the decay as a fraction of ``initial_lr``.
    """
    def __init__(self, initial_lr, total_steps, warmup_steps, alpha=0.0):
        super(WarmUpCosineDecay, self).__init__()
        self.initial_lr = initial_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.alpha = alpha
    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup = tf.cast(self.warmup_steps, tf.float32)
        # Guard the divisor so warmup_steps == 0 cannot produce 0/0.
        warmup_lr = self.initial_lr * step / tf.maximum(warmup, 1.0)
        decay_span = tf.maximum(tf.cast(self.total_steps - self.warmup_steps, tf.float32), 1.0)
        # Clamp progress to [0, 1]: without the upper bound the cosine turns
        # around after total_steps and the learning rate climbs back up.
        progress = tf.clip_by_value((step - warmup) / decay_span, 0.0, 1.0)
        cosine_decay = 0.5 * (1 + tf.cos(np.pi * progress))
        decayed_lr = self.alpha * self.initial_lr + (1 - self.alpha) * self.initial_lr * cosine_decay
        return tf.where(step < warmup, warmup_lr, decayed_lr)
    def get_config(self):
        return {"initial_lr": self.initial_lr, "total_steps": self.total_steps,
                "warmup_steps": self.warmup_steps, "alpha": self.alpha}

class DynamicWeightDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Weight-decay schedule that follows a learning-rate schedule.

    Returns ``base_wd * lr_schedule(step) / base_lr``, so the decay equals
    ``base_wd`` whenever the learning rate equals ``base_lr``.

    Args:
        base_lr: Reference learning rate used as the divisor.
        base_wd: Weight decay applied at the reference learning rate.
        lr_schedule: Callable mapping a step to the current learning rate.
    """
    def __init__(self, base_lr, base_wd, lr_schedule):
        super(DynamicWeightDecay, self).__init__()
        self.base_lr = base_lr
        self.base_wd = base_wd
        self.lr_schedule = lr_schedule
    def __call__(self, step):
        current_lr = self.lr_schedule(step)
        return self.base_wd * (current_lr / self.base_lr)
    def get_config(self):
        # lr_schedule is a required constructor argument, so it has to be part
        # of the config for from_config() to be able to rebuild the object.
        return {"base_lr": self.base_lr, "base_wd": self.base_wd,
                "lr_schedule": tf.keras.optimizers.schedules.serialize(self.lr_schedule)}
    @classmethod
    def from_config(cls, config):
        """Rebuild the schedule, deserialising the nested learning-rate schedule."""
        params = dict(config)
        lr_schedule = params.pop("lr_schedule")
        if isinstance(lr_schedule, dict):
            # WarmUpCosineDecay is a custom class, so it must be supplied
            # explicitly for the lookup by class name to succeed.
            lr_schedule = tf.keras.optimizers.schedules.deserialize(
                lr_schedule, custom_objects={"WarmUpCosineDecay": WarmUpCosineDecay})
        return cls(lr_schedule=lr_schedule, **params)

# -----------------------
# Helper Function to Exclude Certain Parameters from Weight Decay
# -----------------------
def should_apply_weight_decay(var):
    """Return whether ``var`` should receive weight decay.

    Variables whose lower-cased ``name`` contains "bias", "norm" or "rmsnorm"
    are excluded; every other variable is included.
    """
    lowered_name = var.name.lower()
    excluded_markers = ("bias", "norm", "rmsnorm")
    return not any(marker in lowered_name for marker in excluded_markers)

# -----------------------
# Main Training Function with Interleaved Scheduling (Actual Version)
# -----------------------
def main():
    """Train the transformer over the scheduled dataset epochs and plot the results.

    The function:
      1. builds the list of scheduled (dataset, epoch, round) entries,
      2. enables mixed precision and XLA and creates a MirroredStrategy,
      3. builds the model, LR schedule and loss-scaled AdamW optimizer,
      4. restores the latest checkpoint from ./checkpoints_actual if one exists,
      5. runs every remaining scheduled epoch (training, then validation) and
         saves a checkpoint after each one,
      6. plots loss and perplexity for the epochs run in this invocation.

    It relies on names defined elsewhere in the file: `dataset_info`,
    `create_transformer_model`, `parse_tfrecord`, `Perplexity`, `Progbar`,
    `optimizers`, `LossScaleOptimizer`, `tiktoken`, `io` and `plt`.
    """
    # --- Build schedule using the desired interleaved order ---
    # Full interleaved run:
    # desired_order = ["books_part2", "conversations_part3", "books_part1", "conversations_part4"]
    desired_order = ["books_part1"]
    # Full run: max_epochs = max(dataset_info[ds]["epochs"] for ds in desired_order)
    max_epochs = 1
    rounds = 2
    # Each entry: (dataset name, epoch within dataset, dataset epoch budget, round number).
    schedule = [
        (ds, e, dataset_info[ds]["epochs"], r + 1)
        for r in range(rounds)
        for e in range(1, max_epochs + 1)
        for ds in desired_order
        if e <= dataset_info[ds]["epochs"]
    ]
    total_scheduled_epochs = len(schedule)
    print(f"Total scheduled epochs (actual): {total_scheduled_epochs}")

    # Training parameters.
    steps_per_epoch = 3000
    total_steps = total_scheduled_epochs * steps_per_epoch
    warmup_steps = int(0.1 * total_steps)
    initial_lr = 1e-3
    BATCH_SIZE = 32

    # -----------------------
    # GPU Configuration, Mixed Precision & XLA
    # -----------------------
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    tf.config.optimizer.set_jit(True)
    strategy = tf.distribute.MirroredStrategy()
    print(f"Number of devices (actual): {strategy.num_replicas_in_sync}", flush=True)

    # Model Hyperparameters & Tokenizer Setup.
    SEQUENCE_LENGTH = 512
    # Updated hyperparameters to match GPT-1 (≈117M parameters)
    embed_dim = 768
    num_heads = 12
    ff_dim = 3072
    num_layers = 12
    dropout_rate = 0.1
    gpt2_encoding = tiktoken.get_encoding("gpt2")
    vocab_size = gpt2_encoding.n_vocab

    # -----------------------
    # Build the Model within the Strategy Scope.
    # -----------------------
    with strategy.scope():
        # NOTE(review): SEQUENCE_LENGTH - 1 is presumably the model input length
        # (inputs/targets shifted by one token) - confirm against parse_tfrecord.
        model = create_transformer_model(vocab_size, SEQUENCE_LENGTH - 1,
                                         embed_dim, num_heads, ff_dim, num_layers, dropout_rate)
        lr_schedule = WarmUpCosineDecay(initial_lr, total_steps, warmup_steps, alpha=0.0)
        dynamic_wd = DynamicWeightDecay(initial_lr, base_wd=1e-4, lr_schedule=lr_schedule)
        # The optimizer's built-in decay is disabled (weight_decay=0.0); decay is
        # applied manually in train_step so bias/norm variables can be skipped.
        base_optimizer = optimizers.AdamW(
            learning_rate=lr_schedule,
            weight_decay=0.0,
            clipnorm=1.0
        )
        optimizer = LossScaleOptimizer(base_optimizer, dynamic=True)
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Print model summary.
    stream = io.StringIO()
    model.summary(print_fn=lambda x: stream.write(x + "\n"))
    print(stream.getvalue(), flush=True)

    # -----------------------
    # Checkpointing Setup.
    # -----------------------
    checkpoint_dir = './checkpoints_actual'
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Number of scheduled epochs already completed; stored in the checkpoint.
    global_epoch = tf.Variable(0, trainable=False, dtype=tf.int64)
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer, epoch=global_epoch)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)
    initial_global_epoch = 0
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print(f"Restored (actual) from {ckpt_manager.latest_checkpoint}", flush=True)
        initial_global_epoch = int(global_epoch.numpy())

    # -----------------------
    # Metrics & Global Step.
    # -----------------------
    train_loss_metric = tf.keras.metrics.Mean(name='train_loss')
    train_perplexity_metric = Perplexity(name='train_perplexity')
    val_loss_metric = tf.keras.metrics.Mean(name='val_loss')
    val_perplexity_metric = Perplexity(name='val_perplexity')
    # global_step is not stored in the checkpoint, so on resume it is rebuilt
    # from the number of completed epochs. Starting again at 0 would push the
    # manual weight-decay schedule back into warmup while the restored optimizer
    # carries on. Assumes each completed epoch ran steps_per_epoch steps.
    global_step = tf.Variable(initial_global_epoch * steps_per_epoch,
                              trainable=False, dtype=tf.int64)

    @tf.function
    def train_step(x_batch_train, y_batch_train, global_step):
        """Run one optimizer update plus manual weight decay; return (loss, logits)."""
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
            # Scale the loss before differentiation (loss-scaled optimizer).
            scaled_loss = optimizer.get_scaled_loss(loss_value)
        scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
        grads = optimizer.get_unscaled_gradients(scaled_grads)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        global_step.assign_add(1)
        current_lr = lr_schedule(tf.cast(global_step, tf.float32))
        current_weight_decay = dynamic_wd(tf.cast(global_step, tf.float32))
        # Manual decay: shrink each eligible variable by lr * wd * var.
        for var in model.trainable_variables:
            if should_apply_weight_decay(var):
                var.assign_sub(current_lr * current_weight_decay * var)
        return loss_value, logits

    # -----------------------
    # Lists to Store Metrics for Each Epoch
    # -----------------------
    # Only epochs executed in this invocation are recorded here.
    train_losses = []
    val_losses = []
    train_perplexities = []
    val_perplexities = []

    # -----------------------
    # Main Training Loop (Actual Version)
    # -----------------------
    for sched_epoch in range(initial_global_epoch, total_scheduled_epochs):
        ds_name, epoch_in_ds, total_epochs_for_ds, current_round = schedule[sched_epoch]
        print(f"\n[Actual] Round {current_round} - Training dataset '{ds_name}', epoch {epoch_in_ds}/{total_epochs_for_ds}")
        info = dataset_info[ds_name]
        raw_dataset = tf.data.TFRecordDataset(info["path"])
        dataset = raw_dataset.map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
        # Use 1% for validation: the first val_size records are held out and
        # the remainder is used for training.
        val_size = int(0.01 * info["count"])
        val_dataset = dataset.take(val_size)
        train_dataset = dataset.skip(val_size)
        train_dataset = train_dataset.shuffle(10000, reshuffle_each_iteration=True)
        train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=True)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
        # drop_remainder=True: a validation split smaller than BATCH_SIZE
        # yields no batches at all.
        val_dataset = val_dataset.batch(BATCH_SIZE, drop_remainder=True)
        val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)

        train_loss_metric.reset_state()
        train_perplexity_metric.reset_state()
        progbar = Progbar(steps_per_epoch)
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset.take(steps_per_epoch)):
            loss_value, logits = train_step(x_batch_train, y_batch_train, global_step)
            train_loss_metric.update_state(loss_value)
            train_perplexity_metric.update_state(y_batch_train, logits)
            progbar.update(step + 1, values=[("loss", train_loss_metric.result().numpy()),
                                             ("perplexity", train_perplexity_metric.result().numpy())])

        # Validation loop.
        val_loss_metric.reset_state()
        val_perplexity_metric.reset_state()
        for x_batch_val, y_batch_val in val_dataset:
            val_logits = model(x_batch_val, training=False)
            val_loss = loss_fn(y_batch_val, val_logits)
            val_loss_metric.update_state(val_loss)
            val_perplexity_metric.update_state(y_batch_val, val_logits)

        current_train_loss = train_loss_metric.result().numpy()
        current_val_loss = val_loss_metric.result().numpy()
        current_train_perplexity = train_perplexity_metric.result().numpy()
        current_val_perplexity = val_perplexity_metric.result().numpy()

        train_losses.append(current_train_loss)
        val_losses.append(current_val_loss)
        train_perplexities.append(current_train_perplexity)
        val_perplexities.append(current_val_perplexity)

        print(f"[Actual] Epoch {sched_epoch + 1}/{total_scheduled_epochs}: Train Loss = {current_train_loss:.4f}, Train Perplexity = {current_train_perplexity:.4f}")
        print(f"[Actual] Epoch {sched_epoch + 1}/{total_scheduled_epochs}: Val Loss = {current_val_loss:.4f}, Val Perplexity = {current_val_perplexity:.4f}")

        # Record progress before saving so a restart resumes at the next epoch.
        global_epoch.assign(sched_epoch + 1)
        saved_path = ckpt_manager.save()
        print(f"[Actual] Checkpoint saved at: {saved_path}", flush=True)

    # -----------------------
    # Plotting results for all epochs.
    # -----------------------
    # The x-axis covers only the epochs run in this invocation. After a resume
    # the metric lists are shorter than total_scheduled_epochs, and a fixed
    # 1..total range would make plt.plot fail on mismatched lengths.
    first_epoch = initial_global_epoch + 1
    epochs_range = range(first_epoch, first_epoch + len(train_losses))

    def plot_history(train_values, val_values, metric_name):
        """Draw training and validation curves of one metric against epoch number."""
        plt.figure(figsize=(10, 5))
        plt.plot(epochs_range, train_values, marker='o', label=f'Training {metric_name}')
        plt.plot(epochs_range, val_values, marker='o', label=f'Validation {metric_name}')
        plt.title(f'Actual: Training vs. Validation {metric_name} Over All Epochs')
        plt.xlabel('Epoch')
        plt.ylabel(metric_name)
        plt.legend()
        plt.grid(True)
        plt.show()

    plot_history(train_losses, val_losses, 'Loss')
    plot_history(train_perplexities, val_perplexities, 'Perplexity')

# Entry point: start training only when this module is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()
