In [1]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from datasets import load_dataset
import pandas as pd

categorized_data = load_dataset("msaad02/categorized-data", split="train").to_pandas().dropna(subset=['category']).reset_index(drop=True)

def filter_out_small_strings(input_string):
    "There are many 'sentences' that are really just titles for sections. These don't add much value to the embeddings so we'll filter them out."
    split_string = pd.Series(input_string.split("\n")) # Split by newlines (which is how the data is formatted)
    tf_mask = [len(i) > 50 for i in split_string] # Filter out groups that are less than 50 characters long
    out = split_string[tf_mask].str.strip() # Apply filter and strip whitespace
    out = out.str.removeprefix("- ").str.removesuffix(".").to_list() # Remove leading bullets and trailing periods
    return ". ".join(out) # Join the list back into a string and add periods back in

cleaned = categorized_data['data'].apply(filter_out_small_strings)
cleaned = cleaned[cleaned.str.split(" ").str.len() > 100] # filter out data with less than 100 words

## Finds the average length of sentences in the dataset.
## We are aiming to have 2-3 sentences per chunk, so we need to found a character count to use. This tells us that. Our output was ~346 characters.
# cleaned.apply(lambda x: pd.Series([len(sentences) for sentences in x.split('. ')]).mean()).mean() * 2.5

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=350,
    chunk_overlap=25,
    separators=[".", "?", "!"]
)

raw_chunks = cleaned.apply(text_splitter.split_text)

def clean_chunks(chunks):
    "RecursiveCharacterTextSplitter has weird behavior... It puts the punctuation from previous chunks into the next chunk. This cleans that up."
    for i in range(len(chunks)-1):
        chunks[i] += chunks[i+1][:1]
        chunks[i+1] = chunks[i+1][2:]
    return chunks

chunked_data = pd.Series(raw_chunks.copy()).apply(clean_chunks)
categorized_data['chunked_data'] = chunked_data
categorized_data = categorized_data.dropna().reset_index(drop=True)
categorized_data = categorized_data.explode('chunked_data').loc[:, ['category', 'subcategory', 'chunked_data']].reset_index(drop=True)

# Filter to only chunks with less than 100 words
# Inspecting the data, most of the chunks with more than 100 words are just a bunch of bullet points of long lists about profs, programs, etc.
categorized_data = categorized_data[categorized_data['chunked_data'].str.split().str.len() < 100].reset_index(drop=True)

# # This is the distribution of the number of words per chunk. Should look nice and pretty (~Normal).
# categorized_data['chunked_data'].str.split().str.len().hist(bins=30)

In [2]:
df = categorized_data.rename(columns={'chunked_data': 'data'})

In [3]:

from sentence_transformers import SentenceTransformer
# Create all the embeddings for the categories and subcategories
model = SentenceTransformer('BAAI/bge-large-en-v1.5')

# Line of code from `train_subcat_classifier.py`
categories_with_subcategories = df['category'].unique()[df.groupby(["category"])['subcategory'].nunique() > 4]

# Create embeddings for all the categories and subcategories and store them in a dictionary
embeddings = {}
data = {}

for category in df['category'].unique():
    print("Writing embeddings for category: ", category, "...")
    category_df = df[df['category'] == category]

    if category in categories_with_subcategories:
        # first get the embeddings for the main category
        data_to_embed = category_df[category_df['subcategory'].isnull()]['data'].to_list()
        if data_to_embed != []:
            embeddings[category] = model.encode(data_to_embed, normalize_embeddings=True)
            data[category] = data_to_embed

        # split the data into subcategories
        subcategories = category_df[category_df['subcategory'].notnull()]['subcategory'].unique()
        for subcategory in subcategories:
            data_to_embed = category_df[category_df['subcategory'] == subcategory]['data'].to_list()
            if data_to_embed != []:
                embeddings[f"{category}-{subcategory}"] = model.encode(data_to_embed, normalize_embeddings=True)
                data[f"{category}-{subcategory}"] = data_to_embed
    else:
        data_to_embed = category_df['data'].to_list()
        if data_to_embed != []:
            embeddings[category] = model.encode(data_to_embed, normalize_embeddings=True)
            data[category] = data_to_embed

2023-12-30 20:26:11.883348: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-30 20:26:11.905844: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Writing embeddings for category:  about ...
Writing embeddings for category:  academics ...
Writing embeddings for category:  admissions ...
Writing embeddings for category:  alumni ...
Writing embeddings for category:  graduate ...
Writing embeddings for category:  library ...
Writing embeddings for category:  life ...
Writing embeddings for category:  live ...
Writing embeddings for category:  scholarships-aid ...
Writing embeddings for category:  support ...


In [5]:
embeddings

{'about-brockport-downtown': array([[ 1.23745971e-03, -3.89539413e-02, -1.61911938e-02, ...,
          1.93481576e-02, -2.95804795e-02, -1.51566425e-02],
        [ 1.00988550e-02, -4.17094380e-02, -2.87212748e-02, ...,
          1.50915398e-03, -3.21157500e-02, -1.05346795e-02],
        [-3.78557146e-02, -1.88044051e-03, -3.38566080e-02, ...,
          1.50390324e-05, -1.51441935e-02, -1.24610411e-02],
        ...,
        [ 7.16479635e-03,  1.11187631e-02, -6.86378311e-03, ...,
         -5.71413198e-03, -2.07590088e-02,  1.03724701e-02],
        [ 2.73780003e-02,  1.49657940e-02, -1.60304401e-02, ...,
         -3.59943099e-02, -2.13673823e-02,  8.04951706e-04],
        [-2.91582942e-02, -2.00760160e-02, -2.79658213e-02, ...,
         -4.14520921e-03,  1.67296361e-02, -1.07252318e-02]], dtype=float32),
 'about-diversity': array([[ 0.01750481, -0.04892975, -0.03209124, ..., -0.06007944,
         -0.0316372 ,  0.01592939],
        [-0.00348224,  0.00607599, -0.02567437, ..., -0.01857654,