This notebook explores simplifying training labels using embeddings and cosine similarity. We remove one or two words from each label and check whether the shortened label is still close in meaning to the document it labels.

In [1]:
import pandas as pd

In [2]:
!pwd

/Users/Rrando/Documents/GitHub/smart-tab-grouping/notebooks


In [3]:
# Load the fine-tuning extract. keep_default_na=False stops pandas from turning
# its default NA markers (such as empty strings) into NaN, so the text columns
# stay plain strings.
fine_tune_raw = pd.read_csv("../data/extract/fine_tuning_data__with_none__common_crawl_2025-02-23_08-18.csv", keep_default_na=False)

In [4]:
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline
import numpy as np

In [5]:
# Feature-extraction pipeline for all-MiniLM-L6-v2; device=-1 selects the CPU.
hint_embedder = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2", device=-1)


In [6]:
def get_embeddings_for_hint_matching(input_df):
    """Embed every row of ``input_df`` as one mean-pooled vector.

    The text embedded for a row is its ``input_titles`` value, a single space,
    and its ``input_keywords`` value.

    Args:
        input_df: DataFrame with ``input_titles`` and ``input_keywords`` columns.

    Returns:
        A numpy array with one embedding per row, in row order.
    """
    combined_text = input_df["input_titles"] + " " + input_df["input_keywords"]
    pooled_vectors = []
    for text in pd.Series(combined_text).to_list():
        # First (only) batch entry of the pipeline output, averaged over axis 0.
        token_vectors = hint_embedder(text)[0]
        pooled_vectors.append(np.mean(token_vectors, axis=0))
    return np.array(pooled_vectors)


In [7]:
# One pooled embedding per row of fine_tune_raw, in row order;
# refine_table_class indexes into this array with the row index.
per_row_embeddings = get_embeddings_for_hint_matching(fine_tune_raw)

In [8]:
# Labels that refine_table_class returns unchanged instead of scoring shorter variants.
FIXED_TOPICS = {"Adult Content", "None"}

In [61]:
def _embed_phrase(phrase):
    """Return one mean-pooled embedding vector for ``phrase``."""
    # Average over the token axis, i.e. the same pooling that
    # get_embeddings_for_hint_matching applies to the documents, so label and
    # document vectors are comparable.
    return np.mean(hint_embedder(phrase)[0], axis=0)


def refine_table_class(row, index, boost=0.1):
    """Score shortened variants of a row's label against the row's document.

    Candidates are the original label, every variant with one word removed
    and, for labels longer than two words, every variant with two adjacent
    words removed. Each candidate is compared with the document embedding of
    the same row using cosine similarity.

    Args:
        row: Row whose ``output`` attribute holds the label text.
        index: Position of the row's embedding in ``per_row_embeddings``.
        boost: Amount subtracted from the original label's similarity, so a
            shorter variant only needs to come within ``boost`` of the
            original to outrank it.

    Returns:
        ``(initial, scores, phrases)``. ``scores`` and ``phrases`` are parallel
        lists whose first entry is the original label. Both are ``None`` when
        the label is in ``FIXED_TOPICS`` or has fewer than two words.
    """
    initial = row.output
    if initial in FIXED_TOPICS:
        return initial, None, None
    words = initial.split()
    # Nothing can be removed from an empty or single-word label; returning
    # here also avoids running the embedder for labels that are never scored.
    if len(words) <= 1:
        return initial, None, None

    alt_phrases = [initial]
    # remove 1 word
    alt_phrases.extend(
        " ".join(words[:skip] + words[skip + 1:]) for skip in range(len(words))
    )
    if len(words) > 2:
        # remove 2 adjacent words
        alt_phrases.extend(
            " ".join(words[:skip] + words[skip + 2:]) for skip in range(len(words) - 1)
        )

    alt_embeddings = np.array([_embed_phrase(phrase) for phrase in alt_phrases])
    document_embedding = per_row_embeddings[index].reshape(1, -1)
    # cosine_similarity returns a (1, n_candidates) matrix; take its only row.
    similarity = cosine_similarity(document_embedding, alt_embeddings)[0]
    # Handicap the original label in favour of shorter variants.
    similarity[0] -= boost
    return initial, similarity.tolist(), alt_phrases


In [62]:
# NOTE(review): this empty frame is not referenced by any cell visible here — confirm it is still needed.
df = pd.DataFrame(data={"items":[], "scores":[]})

In [63]:
# Accumulators filled by the scoring loop below; entries at the same position
# in the three lists belong to the same row.
score_list = []
topic_list = []
initial_list = []

In [64]:
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

In [65]:
max_possibilities = 8  # NOTE(review): not used by any cell visible here
# Start from empty result lists so that re-running this cell does not append a
# second copy of every row to lists created by an earlier cell.
score_list = []
topic_list = []
initial_list = []
for index, row in fine_tune_raw.iterrows():
    initial, scores, topics = refine_table_class(row, index)
    # scores is None for labels that are returned unchanged and not scored.
    if scores is not None:
        topic_list.append(topics)
        score_list.append(scores)
        initial_list.append(initial)
    # Only a sample is scored: stop once the row index exceeds 200.
    if index > 200:
        break

In [66]:
from html import escape  # label text is interpolated into HTML below

# Colormap (light to dark green), scaled to the observed score range
cmap = plt.cm.Greens
all_scores_flat = [score for row in score_list for score in row]
# default= keeps this cell from raising when no row produced any scores.
norm = mcolors.Normalize(
    vmin=min(all_scores_flat, default=0.0),
    vmax=max(all_scores_flat, default=1.0),
)

def score_to_hex(score):
    """Return the hex colour the colormap assigns to ``score``."""
    rgba = cmap(norm(score))
    return mcolors.to_hex(rgba)

def get_style_with_color(color):
    """Return the inline CSS for a label chip with background ``color``."""
    return f"background-color:{color}; padding:4px; margin:2px; border-radius:4px;"

# Generate HTML: one line per row, the original label first, then every
# candidate phrase shaded by its similarity score.
# NOTE(review): -1 is presumably meant to fall below every real score so the
# original label gets the palest colour — confirm.
baseline_style = get_style_with_color(score_to_hex(-1))
html_parts = ["<div style='font-family:sans-serif; line-height:2em;'>"]
for initial, row_topics, row_scores in zip(initial_list, topic_list, score_list):
    html_parts.append(f"<span style='{baseline_style}'>{escape(initial)}</span>")
    for topic, score in zip(row_topics, row_scores):
        color = score_to_hex(score)
        html_parts.append(
            f"<span style='{get_style_with_color(color)}'>{escape(topic)}</span>"
        )
    html_parts.append("<br>")  # new line after each row
html_parts.append("</div>")
html = "".join(html_parts)

# Display the styled HTML
display(HTML(html))

In [None]:
def remove_and(a: str):
    """Cut a label off at its first "and" or "&" token.

    The label is split on single spaces. "and" is looked for first, then "&".
    For each connector only its first occurrence is considered, and it is
    ignored when it is the very first token. Everything before the connector
    is returned; if no usable connector is found the label is returned
    unchanged.
    """
    tokens = a.split(" ")
    if len(tokens) >= 2:
        for connector in ("and", "&"):
            try:
                cut = tokens.index(connector)
            except ValueError:
                # Connector not present; try the next one.
                continue
            if cut > 0:
                return " ".join(tokens[:cut])
    # Re-joining an untouched split on " " gives back the input itself.
    return a


In [None]:
# Spot check: the label is cut at the "&", so this should evaluate to "Word".
remove_and("Word & Dog")

In [None]:
def remove_or(a: str):
    """Drop a trailing "or <word>" from a label.

    The label is split on single spaces. When it has at least three tokens and
    the second-to-last one is "or", the last two tokens are removed; otherwise
    the label is returned unchanged.
    """
    parts = a.split(" ")
    ends_with_or_clause = len(parts) >= 3 and parts[-2] == 'or'
    return " ".join(parts[:-2]) if ends_with_or_clause else a


In [None]:
def remove_news(a: str):
    """Strip a trailing "News" token from a multi-word label.

    The label is split on single spaces. A label that consists of one token
    only (including just "News") is returned unchanged.
    """
    # split(" ") always yields at least one token, so the unpacking is safe.
    *leading, last = a.split(" ")
    if leading and last == 'News':
        return " ".join(leading)
    return a


In [None]:
# Spot check: the trailing "News" is dropped, so this should evaluate to "Tech".
remove_news("Tech News")

In [None]:
def remove_year(a: str, years=("2023", "2024")):
    """Remove stand-alone year tokens from a label.

    Args:
        a: Label text; it is split on single spaces.
        years: Iterable of year strings to drop. The default keeps the
            original behaviour of removing "2023" and "2024"; pass other
            values to cover additional years.

    Returns:
        The label with every token that exactly equals one of ``years``
        removed. Tokens that merely contain a year (e.g. "2024's") are kept.
    """
    excluded = set(years)
    return " ".join(word for word in a.split(" ") if word not in excluded)
