<div align="center">

# 🚀 Spearecode Preprocessing 🚀

</div>

<br>

Welcome to the **Spearecode Preprocessing Notebook**! This notebook will guide you through the necessary preprocessing steps to prepare a toy dataset for Language Model training. We will focus on making the dataset more suitable for training by performing the following steps:

1. 📚 **Loading the dataset**: We'll start by importing the dataset from a file or external source.
2. 📦 **Chunking the text**: The dataset will be divided into smaller chunks or segments, making it easier to process during training.
3. 💬 **Tokenization**: Each chunk of text will be split into individual tokens (words or subwords), which are the basic units for language models.
4. 📊 **Basic Exploratory Data Analysis (EDA)**: We'll analyze the dataset's characteristics, such as token frequency, to gain insights and identify potential issues.

After completing the preprocessing and EDA, the toy dataset will be converted into `TFRecords` format. This efficient binary format is designed for use with TensorFlow and will enable seamless integration with your Language Model training pipeline.

Let's dive in and start preprocessing the dataset! 🎉


<br><br>

<div align="center">

# 🌟 Table of Contents 🌟

</div>

---

0. [**Setup**](#setup)
1. [**Loading the Dataset**](#loading-the-dataset)
2. [**Chunking the Text**](#chunking-the-text)
3. [**Tokenization**](#tokenization)
4. [**Basic Exploratory Data Analysis (EDA)**](#basic-eda)
5. [**Converting to TFRecords**](#converting-to-tfrecords)

---



<br>

<div align="center">

## 🛠️ Setup <a name="setup"></a>

</div>

<br>

In this section, we import the required libraries and helper functions from our utilities module. We also define the relevant paths and other high-level settings we may need later. Finally, we import helpers for a few basic TensorFlow setup steps (XLA JIT compilation, GPU memory growth, and seeding) to keep runs efficient and reproducible.

In [None]:
### IMPORTS ###
import os
import sys
import random
import numpy as np
import pandas as pd
from glob import glob
import tensorflow as tf
import sentencepiece as spm
PROJECT_DIR = os.path.dirname(os.getcwd())
sys.path.insert(0, PROJECT_DIR) # project root into path

from spearecode.preprocessing_utils import load_from_txt_file, preprocess_shakespeare, save_to_txt_file, print_check_speare
from spearecode.general_utils import tf_xla_jit, tf_set_memory_growth, seed_it_all, flatten_l_o_l, print_ln

### DEFINE PATHS ###
DATA_PATH = os.path.join(PROJECT_DIR, "data")
SS_TEXT_PATH = os.path.join(DATA_PATH, "t8.shakespeare.txt")
NBS_PATH = os.path.join(PROJECT_DIR, "nbs")
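
The helpers `seed_it_all`, `tf_set_memory_growth` and `tf_xla_jit` live in the project's `general_utils` module and are not shown here. The sketch below is an assumption about what such setup usually involves, not the project's implementation. Its hypothetical functions `seed_everything`, `enable_gpu_memory_growth` and `enable_xla_jit` use only standard Python, NumPy and TensorFlow calls.

```python
import random

import numpy as np
import tensorflow as tf


def seed_everything(seed=42):
    """Hypothetical helper: seed Python, NumPy and TensorFlow RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)


def enable_gpu_memory_growth():
    """Hypothetical helper: let TensorFlow allocate GPU memory on demand rather than all upfront."""
    for gpu in tf.config.list_physical_devices("GPU"):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError:
            # Memory growth can only be changed before the GPUs are initialized
            pass


def enable_xla_jit():
    """Hypothetical helper: turn on XLA JIT compilation (auto-clustering) globally."""
    tf.config.optimizer.set_jit(True)
```

Call these once near the top of the notebook, before building any models. If no GPU is present, `enable_gpu_memory_growth` simply does nothing. Seeding alone does not guarantee bitwise-identical results on GPU.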

<br>

<div align="center">

## 📚 Loading the Dataset <a name="loading-the-dataset"></a>

</div>

<br>

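In this section, we load the raw Shakespeare text from `data/t8.shakespeare.txt` and apply basic preprocessing with `preprocess_shakespeare`. We then save the result next to the original as `t8.shakespeare_preprocessed.txt` and print a quick sanity check. The preprocessed text stays in memory (`ss_text`) for the chunking and tokenization steps that follow.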


In [None]:
PREPROCESSED_FULL_TEXT_PATH = SS_TEXT_PATH.replace(".txt", "_preprocessed.txt")
raw_text = load_from_txt_file(SS_TEXT_PATH)
ss_text = preprocess_shakespeare(raw_text)
save_to_txt_file(ss_text, PREPROCESSED_FULL_TEXT_PATH)
print_check_speare(ss_text)

<br>

<div align="center">

## 📦 Chunking the Text <a name="chunking-the-text"></a>

</div>

<br>

Once the dataset is loaded, we divide it into smaller chunks or segments. This step makes the dataset resemble a collection of code files, which is the type of data used in the project's other parallel streams.

I implement two simple methods:
1. A basic double-newline split **(`\n\n`)**, which yields 6294 chunks
2. LangChain's `RecursiveCharacterTextSplitter`, which chunks the text to a target length
    * This lets us specify the desired chunk length and have adjacent chunks overlap.
        * Note that we allow a small amount of overlap, which may cause some leakage between chunks; for this toy dataset we accept that.
    * **We will use this method for our purposes.**
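
The following sketch makes `chunk_size` and `chunk_overlap` concrete with a deliberately simplified fixed-window splitter. It is *not* LangChain's algorithm. `RecursiveCharacterTextSplitter` tries to break on separators (by default paragraphs, then lines, then spaces) and merges the pieces up to `chunk_size`, so its overlap is approximate. The hypothetical `sliding_window_chunks` ignores separators and shows only how window size and overlap interact.

```python
def sliding_window_chunks(text, chunk_size=512, chunk_overlap=64):
    """Hypothetical helper: fixed-size character windows that overlap by chunk_overlap."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks
```

For example, `sliding_window_chunks("abcdefghij", chunk_size=4, chunk_overlap=1)` gives `['abcd', 'defg', 'ghij']`. Each chunk repeats the last character of the previous one, which is the kind of leakage that overlap introduces.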
    


In [None]:
def do_rcts_chunking(text, chunk_size=512, chunk_overlap=64, length_fn=len):
    """
    Perform Recursive Character Text Splitting (RCTS) chunking on the input text.
    
    Args:
        text (str): The input text to be chunked.
        chunk_size (int): The maximum size of each chunk.
        chunk_overlap (int): The number of overlapping characters between adjacent chunks.
        length_fn (callable, optional): Function to calculate the length of the text. Defaults to len.
    
    Returns:
        list: A list of chunked text segments.
    """
    # Import the RecursiveCharacterTextSplitter from langchain.text_splitter module
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    
    # Instantiate the text splitter with the specified parameters
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=length_fn,
    )
    
    # Split the input text into chunks
    docs = text_splitter.create_documents([text])
    
    # Return the list of chunked text segments
    return [x.page_content for x in docs]

def do_basic_chunking(text, chunk_delimeter="\n\n", add_delim_back=False):
    """
    Perform basic chunking on the input text using the specified delimiter.
    
    Args:
        text (str): The input text to be chunked.
        chunk_delimeter (str, optional): The delimiter used to split the text. Defaults to "\n\n".
        add_delim_back (bool, optional): Whether to add the delimiter back to the end of each chunk. Defaults to False.
    
    Returns:
        list: A list of chunked text segments.
    """
    # Strip delimiter characters from both ends (so the first/last chunks are not empty), then split.
    # Note: long runs of the delimiter mid-text (e.g. '\n\n\n\n') can still produce empty chunks.
    docs = text.strip(chunk_delimeter).split(chunk_delimeter)
    
    # If specified, add the delimiter back to the end of each chunk
    if add_delim_back:
        docs = [x + chunk_delimeter for x in docs]
    
    # Return the list of chunked text segments
    return docs

In [None]:
# Feel free to pass non-default kwargs
#    -- otherwise the RCTS chunks will be at most 512 characters long and overlap by up to 64
CHUNK_STYLE = "rcts" # one of ['basic' | 'rcts']; RCTS is the method chosen above
basic_chunks = do_basic_chunking(ss_text)
rcts_chunks = do_rcts_chunking(ss_text)

print("\n... FIRST BASIC CHUNK ...\n")
print(basic_chunks[0])

print("\n... FIRST RCTS CHUNK ...\n")
print(rcts_chunks[0])

print("\n... EXAMPLE RANDOM BASIC CHUNK ...\n")
print(random.sample(basic_chunks, 1)[0])

print("\n... EXAMPLE RANDOM RCTS CHUNK ...\n")
print(random.sample(rcts_chunks, 1)[0])

print("\n... LAST BASIC CHUNK ...\n")
print(basic_chunks[-1])

print("\n... LAST RCTS CHUNK ...\n")
print(rcts_chunks[-1])


<br>

<div align="center">

## 💬 Tokenization <a name="tokenization"></a>

</div>

<br>

In this section, we'll tokenize the text, which involves splitting the chunks into individual tokens (words or subwords). Tokenization is an essential step in preprocessing, as it helps the Language Model understand the basic units of the text and learn meaningful patterns.
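
Subword tokenizers such as BPE learn their vocabulary from data. The core of classic BPE is a loop that repeatedly merges the most frequent adjacent pair of symbols. The toy sketch below illustrates only that loop. The function `train_toy_bpe` is hypothetical and operates on a pre-split word list. SentencePiece's trainer works on raw text and is heavily optimized. Its unigram model takes a different approach, pruning a large candidate vocabulary rather than merging pairs.

```python
from collections import Counter


def train_toy_bpe(words, num_merges):
    """Hypothetical helper: learn BPE merges from a list of words (characters are the initial symbols)."""
    # Represent each word as a tuple of symbols, weighted by how often it occurs
    vocab = Counter(tuple(w) for w in words)
    merges = []
    for _ in range(num_merges):
        # Count every adjacent symbol pair across the corpus
        pairs = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break  # every word is already a single symbol
        best = max(pairs, key=pairs.get)  # ties go to the pair seen first
        merges.append(best)
        # Rewrite every word, replacing occurrences of the best pair with one merged symbol
        new_vocab = Counter()
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] += freq
        vocab = new_vocab
    return merges
```

On `["low", "low", "lower", "lowest"]`, the first two merges are `('l', 'o')` and then `('lo', 'w')`, because "low" is the most frequent shared prefix.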

* We will train our tokenizer on the non-chunked dataset (after basic preprocessing), however, we will 


In [None]:
# Set up the model directory if it does not already exist
MODEL_DIR = os.path.join(os.path.dirname(DATA_PATH), "models")
os.makedirs(MODEL_DIR, exist_ok=True)

# User-defined parameters (matching alphafold and code tokenization standards)
MODEL_PATH = os.path.join(MODEL_DIR, 'spearecode')  # prefix; the model type is appended below
VOCAB_SIZE = 8_000
CHAR_COVERAGE = 0.99995
TOKENIZER_STYLES = ["bpe", "unigram"]  # train both so they can be compared below
USER_DEFINED_SYMBOLS = ["\n","\t","\r","\f","\v"]

# Tokenizer parameters shared by both model types (and some defaults)
tokenizer_kwargs = dict(
    input=PREPROCESSED_FULL_TEXT_PATH,
    vocab_size=VOCAB_SIZE,
    character_coverage=CHAR_COVERAGE,
    pad_id=0, unk_id=1, bos_id=2, eos_id=3,
    remove_extra_whitespaces=False,
    allow_whitespace_only_pieces=True,
    add_dummy_prefix=False,
    user_defined_symbols=USER_DEFINED_SYMBOLS,
    normalization_rule_name="identity"
)

# Train one model per style; each writes <MODEL_PATH>_<style>.model and .vocab
for style in TOKENIZER_STYLES:
    spm.SentencePieceTrainer.Train(
        model_prefix=f"{MODEL_PATH}_{style}",
        model_type=style,
        **tokenizer_kwargs,
    )

# Load the trained models via the constructor
# (SentencePieceProcessor().load() returns a status flag, not the processor itself)
sp_uni = spm.SentencePieceProcessor(model_file=f'{MODEL_PATH}_unigram.model')
sp_bpe = spm.SentencePieceProcessor(model_file=f'{MODEL_PATH}_bpe.model')

uni_encoder = lambda x: sp_uni.encode(x)
uni_decoder = lambda x: sp_uni.decode(x)

bpe_encoder = lambda x: sp_bpe.encode(x)
bpe_decoder = lambda x: sp_bpe.decode(x)
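
The trainer above reserves `pad_id=0`, `unk_id=1`, `bos_id=2` and `eos_id=3`, but `encode` does not insert BOS/EOS by default. How the training pipeline frames and batches sequences is not specified here. As one assumed option, the hypothetical helper below wraps a chunk's IDs in BOS/EOS, then truncates or right-pads them to a fixed length and returns a matching attention-style mask.

```python
def add_special_and_pad(ids, max_len, bos_id=2, eos_id=3, pad_id=0):
    """Hypothetical helper: wrap token IDs with BOS/EOS, then truncate or right-pad to max_len."""
    ids = [bos_id] + list(ids) + [eos_id]
    ids = ids[:max_len]  # truncation drops the EOS token for long chunks
    mask = [1] * len(ids) + [0] * (max_len - len(ids))  # 1 = real token, 0 = padding
    ids = ids + [pad_id] * (max_len - len(ids))
    return ids, mask
```

In this notebook that could look like `add_special_and_pad(bpe_encoder(rcts_chunks[0]), max_len=256)`. For a short input, `add_special_and_pad([10, 11], max_len=6)` returns `([2, 10, 11, 3, 0, 0], [1, 1, 1, 1, 0, 0])`.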

In [None]:
from IPython.display import HTML, display
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors


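def replace_rightmost_newline(s, replacement='<br>'):
    """Replace the last escaped newline marker (the two characters '\\n') in `s` with `replacement`."""
    parts = s.rsplit('\\n', 1)
    return replacement.join(parts)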


def get_color(value, cmap='Pastel1', transparency=0.5):
    """
    Returns an HTML-formatted string representing the background color for a token.

    Args:
        value (int): The index of the token to color.
        cmap (str, optional): The name of the colormap to use. Defaults to 'Pastel1'.

    Returns:
        str: An HTML-formatted string representing the background color.
    """
    colormap = plt.get_cmap(cmap)
    return f"background-color: rgba{tuple([int(x*255) for x in colormap(value % colormap.N)[:-1]]+[transparency,])};"


def get_visualization1(tokens, decoder=None, cmap='Pastel1', font_family='Courier New',
                       transparency=0.75, font_size='1.1em', unk_token='???', font_weight=300, padding='0px',
                       margin_right='0px', border_radius='0px', display_inline=False):
    """
    Generates an HTML string to visualize the tokenization of a text.

    Args:
        tokens (list):
            – A list of integer tokens.
        decoder (function, optional):
            – A function that maps an integer to the representative string
            – If this is None, the tokens are assumed to be strings not integers
        cmap (str, optional):
            – The name of the colormap to use. Defaults to 'Pastel1'.
        font_family (str, optional):
            – The font family to use for tokens. Defaults to 'Courier New'.
        transparency (float, optional):
            – The alpha value of the token background. Defaults to 0.75.
        font_size (str, optional):
            – The font size to use for tokens. Defaults to '1.1em'.
        unk_token (str, optional):
            – The string shown when a token cannot be decoded. Defaults to '???'.
        font_weight (int, optional):
            – The font weight to use for tokens. Defaults to 300.
        padding (str, optional):
            – The padding to use for tokens. Defaults to '0px'.
        margin_right (str, optional):
            – The right margin to use for tokens. Defaults to '0px'.
        border_radius (str, optional):
            – The border radius to use for tokens. Defaults to '0px'.
        display_inline (bool, optional):
            – Whether to display the HTML inline. Defaults to False.

    Returns:
        str: An HTML string representing the tokenized text with styling.
    """
    html = f"<style>span.token {{font-family: {font_family} !important; font-size: {font_size} !important; font-weight: {font_weight} !important; " \
           f"padding: {padding} !important; margin-right: {margin_right} !important; border-radius: {border_radius} !important;}}</style>"

    html += "<div style='background-color: #F8F8F8; padding: 15px; border-radius: 5px;'>"
    
    for i, token in enumerate(tokens):
        color = get_color(i, cmap, transparency)
        try:
            # Tokens are already strings when no decoder is given
            token_str = token if decoder is None else decoder(token)
            html += f"<span class='token' style='{color}'>{token_str.replace(' ', '&nbsp;')}</span>"
        except TypeError:
            html += f"<span class='token' style='{color}'>{unk_token}</span>"

    html += "</div>"
    
    if display_inline:
        display(HTML(html))

    return html


def get_line_viz(token_lines, decoder, cmap='Pastel1', font_family='Courier New',
                 transparency=0.75, font_size='1.1em', unk_token='???', font_weight=300, padding='0px',
                 margin_right='0px', border_radius='0px', display_inline=False):
    """
    Generates an HTML string to visualize the tokenization of a text.

    Args:
        token_lines (list):
            – A list of lists of integer tokens.
        decoder (function):
            – A function that maps an integer token to its string
        cmap (str, optional):
            – The name of the colormap to use. Defaults to 'Pastel1'.
        font_family (str, optional):
            – The font family to use for tokens. Defaults to 'Courier New'.
        transparency (float, optional):
            – The alpha value of the token background. Defaults to 0.75.
        font_size (str, optional):
            – The font size to use for tokens. Defaults to '1.1em'.
        unk_token (str, optional):
            – The string shown when a token cannot be decoded. Defaults to '???'.
        font_weight (int, optional):
            – The font weight to use for tokens. Defaults to 300.
        padding (str, optional):
            – The padding to use for tokens. Defaults to '0px'.
        margin_right (str, optional):
            – The right margin to use for tokens. Defaults to '0px'.
        border_radius (str, optional):
            – The border radius to use for tokens. Defaults to '0px'.
        display_inline (bool, optional):
            – Whether to display the HTML inline. Defaults to False.

    Returns:
        str: An HTML string representing the tokenized text with styling.
    """
    html = f"<style>span.token {{font-family: {font_family} !important; font-size: {font_size} !important; font-weight: {font_weight} !important; " \
           f"padding: {padding} !important; margin-right: {margin_right} !important; border-radius: {border_radius} !important;}}</style>"

    html += "<div style='background-color: #F8F8F8; padding: 15px; border-radius: 5px;'>"
    
    for token_line in token_lines:
        for i, token in enumerate(token_line):
            color = get_color(i, cmap, transparency)
            try:
                html += f"<span class='token' style='{color}'>{decoder(token).replace(' ', '&nbsp;')}</span>".replace('\t', '\\t').replace('\n', '\\n').replace('\r', '\\r').replace('\f', '\\f').replace('\v', '\\v')
            except TypeError:
                html += f"<span class='token' style='{color}'>{unk_token}</span>"
        html = replace_rightmost_newline(html)
    html += "</div>"
    
    if display_inline:
        display(HTML(html))

    return html

def plot_tokenization(text, encoder, decoder, split_on="\n"):
    """Render `text` line by line, using cycling background colors to mark token boundaries."""
    display(HTML(get_line_viz([encoder(x+split_on) for x in text.split(split_on)], decoder)))

In [None]:
print("\n... BPE TOKENIZATION:")
plot_tokenization(basic_chunks[0], bpe_encoder, bpe_decoder)

print("\n... UNIGRAM TOKENIZATION:")
plot_tokenization(basic_chunks[0], uni_encoder, uni_decoder)
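
Beyond eyeballing the visualization, a quick automated check is useful. The trainer uses `normalization_rule_name="identity"`, `remove_extra_whitespaces=False` and `add_dummy_prefix=False`, so decoding an encoded chunk should reproduce it exactly. The exception is characters left out by `character_coverage=0.99995`, which map to the unknown ID (`unk_id=1`). The sketch below tests that expectation. The helper name `tokenizer_health` is hypothetical, and it accepts any encoder/decoder pair.

```python
def tokenizer_health(chunks, encoder, decoder, unk_id=1):
    """Hypothetical helper: find chunks that fail an encode/decode round trip and measure the unknown-token rate."""
    lossy_ids = []
    n_tokens = n_unk = 0
    for i, chunk in enumerate(chunks):
        ids = encoder(chunk)
        n_tokens += len(ids)
        n_unk += sum(1 for t in ids if t == unk_id)
        if decoder(ids) != chunk:
            lossy_ids.append(i)
    return {"lossy_chunk_ids": lossy_ids, "unk_rate": n_unk / max(n_tokens, 1)}
```

For example, run `tokenizer_health(rcts_chunks, bpe_encoder, bpe_decoder)` and the same call with `uni_encoder`/`uni_decoder` to compare the two models.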

<br>

<div align="center">

## 📊 Basic Exploratory Data Analysis (EDA) <a name="basic-eda"></a>

</div>

<br>

Here, we'll perform a basic EDA on the dataset to gain insights and identify potential issues. This analysis may include examining token frequency, distribution of chunk lengths, and other relevant characteristics. This information can be helpful in understanding the dataset's structure and guiding further preprocessing decisions.
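
As a starting point, the hypothetical helpers below compute chunk-length statistics (in characters and in tokens) and the most frequent token IDs. They accept any encoder callable, so the same code works for `bpe_encoder`, `uni_encoder`, or a stand-in such as `str.split`. A higher characters-per-token ratio means the tokenizer compresses the text more.

```python
from collections import Counter

import numpy as np


def chunk_length_stats(chunks, encoder):
    """Hypothetical helper: summarize chunk lengths in characters and tokens."""
    char_lens = np.array([len(c) for c in chunks])
    tok_lens = np.array([len(encoder(c)) for c in chunks])
    return {
        "n_chunks": len(chunks),
        "mean_chars": float(char_lens.mean()),
        "max_chars": int(char_lens.max()),
        "mean_tokens": float(tok_lens.mean()),
        "max_tokens": int(tok_lens.max()),
        # Overall compression: total characters per total tokens
        "chars_per_token": float(char_lens.sum() / tok_lens.sum()),
    }


def top_tokens(chunks, encoder, k=10):
    """Hypothetical helper: count token frequencies across all chunks and return the k most common."""
    counts = Counter()
    for c in chunks:
        counts.update(encoder(c))
    return counts.most_common(k)
```

For example, call `chunk_length_stats(rcts_chunks, bpe_encoder)` and `top_tokens(rcts_chunks, bpe_encoder)`. Passing the returned IDs to `bpe_decoder` makes the frequency list readable.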


<br>

<div align="center">

## 💾 Converting to TFRecords <a name="converting-to-tfrecords"></a>

</div>

<br>

Finally, after completing the preprocessing steps and EDA, we'll convert the toy dataset into the `TFRecords` format. This efficient binary format is designed for use with TensorFlow and will enable seamless integration with your Language Model training pipeline.
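
Below is a minimal sketch of the conversion. It assumes each chunk is stored as one `tf.train.Example` that holds its variable-length token IDs under a hypothetical feature key, `input_ids`. The actual schema, sharding, and padding strategy for the training pipeline are not specified here. The reader function is included to show the matching parse spec.

```python
import os
import tempfile

import tensorflow as tf


def _int64_feature(values):
    """Wrap a list of ints as a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def write_token_tfrecord(token_id_lists, path):
    """Hypothetical helper: serialize each list of token IDs as one tf.train.Example."""
    with tf.io.TFRecordWriter(path) as writer:
        for ids in token_id_lists:
            example = tf.train.Example(
                features=tf.train.Features(feature={"input_ids": _int64_feature(ids)})
            )
            writer.write(example.SerializeToString())


def read_token_tfrecord(path):
    """Hypothetical helper: parse the records back into variable-length int64 tensors."""
    feature_spec = {"input_ids": tf.io.VarLenFeature(tf.int64)}

    def _parse(serialized):
        parsed = tf.io.parse_single_example(serialized, feature_spec)
        return tf.sparse.to_dense(parsed["input_ids"])

    return tf.data.TFRecordDataset(path).map(_parse)


# Round-trip demo on toy IDs (BOS=2, EOS=3 as configured for the tokenizer)
demo_path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
write_token_tfrecord([[2, 45, 17, 3], [2, 9, 3]], demo_path)
for ids in read_token_tfrecord(demo_path):
    print(ids.numpy().tolist())  # [2, 45, 17, 3], then [2, 9, 3]
```

In this notebook, the input lists would come from something like `[bpe_encoder(c) for c in rcts_chunks]`. Storing variable-length IDs keeps the files compact and defers padding to the input pipeline, for example via `tf.data.Dataset.padded_batch`.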

