In [1]:
!pip install sentence_transformers faiss-cpu



In [2]:
import math
from pathlib import Path

import pandas as pd
import torch
from sentence_transformers import SentenceTransformer

In [3]:
# Load the simulation results that the rest of the notebook works on.
df = pd.read_csv("data/simulation_data.csv")

row_count = len(df)
display(f"Found {row_count} rows.")

'Found 50000 rows.'

In [4]:
# Text columns that receive a sentence embedding.
RELEVANT_COLUMNS = [
    "input_text",
    "target_txt",
    "output_meta-llama_Llama-3.2-3B-Instruct",
    "output_meta-llama_Llama-3.1-8B-Instruct",
    "output_meta-llama_Llama-3.3-70B-Instruct",
]

# Number of rows handed to the encoder per call, and the resulting number of
# calls needed to cover the whole frame (rounded up for the last partial batch).
MINIBATCH_SIZE = 1024
NUM_CHUNKS = math.ceil(len(df) / MINIBATCH_SIZE)


In [5]:
# Load the sentence encoder, preferring the GPU when one is available.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device="cuda" if torch.cuda.is_available() else "cpu")
display(f"Model running on {model.device}.")

# Encode every relevant text column and store the vectors in "<col>_embed".
for col in RELEVANT_COLUMNS:
    emb_list = []
    # Derive the batch boundaries from the frame itself so this cell does not
    # depend on a chunk counter that another cell may have rebound.
    for min_idx in range(0, len(df), MINIBATCH_SIZE):
        max_idx = min_idx + MINIBATCH_SIZE  # exclusive upper bound

        # Slice by position (iloc) so the batching is correct regardless of
        # what labels the frame's index carries.
        text_input = df[col].iloc[min_idx:max_idx].tolist()
        embeddings = model.encode(text_input)
        emb_list.extend(embeddings)

    display(len(emb_list))
    # Attach the frame's own index so each vector lands on the row it was
    # computed from instead of being aligned against a fresh 0..n-1 index.
    df[f"{col}_embed"] = pd.Series(emb_list, index=df.index).astype(object)

'Model running on cuda:0.'

50000

50000

50000

50000

50000

In [6]:
def compute_distance(row, col_1: str, col_2: str):
    """Return the Euclidean (L2) distance between two embeddings stored in ``row``.

    Args:
        row: Record (e.g. a DataFrame row) that holds the two embeddings under
            the keys ``"<col_1>_embed"`` and ``"<col_2>_embed"``.
        col_1: Base name of the first embedding column.
        col_2: Base name of the second embedding column.

    Returns:
        The distance as a plain Python number.
    """
    first = torch.tensor(row[f"{col_1}_embed"])
    second = torch.tensor(row[f"{col_2}_embed"])

    # sqrt(sum((a - b)^2)) over the embedding dimension.
    squared_diff = (first - second) ** 2
    return squared_diff.sum(dim=0).sqrt().item()

# Column pairs to compare, keyed by a short label; each label becomes the
# suffix of the resulting "dist_<label>" column.
DISTANCE_TUPLES = {
    # model output vs. model output
    "3b-8b": ("output_meta-llama_Llama-3.2-3B-Instruct", "output_meta-llama_Llama-3.1-8B-Instruct"),
    "3b-70b": ("output_meta-llama_Llama-3.2-3B-Instruct", "output_meta-llama_Llama-3.3-70B-Instruct"),
    "8b-70b": ("output_meta-llama_Llama-3.1-8B-Instruct", "output_meta-llama_Llama-3.3-70B-Instruct"),
    # model output vs. reference text
    "3b-target_txt": ("output_meta-llama_Llama-3.2-3B-Instruct", "target_txt"),
    "8b-target_txt": ("output_meta-llama_Llama-3.1-8B-Instruct", "target_txt"),
    "70b-target_txt": ("output_meta-llama_Llama-3.3-70B-Instruct", "target_txt"),
}

# Row-wise distance between the two embeddings of every pair; the column names
# are forwarded to compute_distance as keyword arguments by DataFrame.apply.
for k, dist_tuple in DISTANCE_TUPLES.items():
    first_col, second_col = dist_tuple
    df[f"dist_{k}"] = df.apply(compute_distance, axis=1, col_1=first_col, col_2=second_col)

In [7]:
# Report min / max / mean / standard deviation of every distance column,
# rounded to four decimals.
for dist in DISTANCE_TUPLES:
    distances = df[f'dist_{dist}']
    lowest = round(distances.min(), 4)
    highest = round(distances.max(), 4)
    average = round(distances.mean(), 4)
    spread = round(distances.std(), 4)
    display(f"Distance statistics between model {dist}: {lowest} (Min.) - {highest} (Max.) - {average} (Avg.) - {spread} (Std.Dev.).")


'Distance statistics between model 3b-8b: 0.0 (Min.) - 1.4643 (Max.) - 0.5181 (Avg.) - 0.2291 (Std.Dev.).'

'Distance statistics between model 3b-70b: 0.0 (Min.) - 1.4568 (Max.) - 0.5089 (Avg.) - 0.2283 (Std.Dev.).'

'Distance statistics between model 8b-70b: 0.0 (Min.) - 1.4267 (Max.) - 0.4431 (Avg.) - 0.2207 (Std.Dev.).'

'Distance statistics between model 3b-target_txt: 0.0 (Min.) - 1.4518 (Max.) - 0.6775 (Avg.) - 0.218 (Std.Dev.).'

'Distance statistics between model 8b-target_txt: 0.0 (Min.) - 1.431 (Max.) - 0.6452 (Avg.) - 0.2152 (Std.Dev.).'

'Distance statistics between model 70b-target_txt: 0.0 (Min.) - 1.4268 (Max.) - 0.5856 (Avg.) - 0.221 (Std.Dev.).'

In [8]:
# Export the enriched frame as parquet files of CHUNK_SIZE rows each.
CHUNK_SIZE = 1000
# Round up so a trailing partial chunk is exported too when len(df) is not an
# exact multiple of CHUNK_SIZE.
# NOTE: this rebinds NUM_CHUNKS, which the configuration cell also defines
# (there it counts encoder minibatches).
NUM_CHUNKS = math.ceil(len(df) / CHUNK_SIZE)

# to_parquet does not create missing directories, so make sure the target exists.
output_dir = Path("data/sentence_transformer_embeddings")
output_dir.mkdir(parents=True, exist_ok=True)

for idx in range(NUM_CHUNKS):
    min_idx = idx * CHUNK_SIZE
    max_idx = min_idx + CHUNK_SIZE  # exclusive upper bound

    # Positional slicing keeps the chunking independent of the index labels;
    # index=None leaves pandas' default index handling in place.
    df.iloc[min_idx:max_idx].to_parquet(output_dir / f"data_chunk_{idx}.parquet", index=None)


In [9]:
# Spot-check: read one exported chunk back from disk.
chunk_path = "data/sentence_transformer_embeddings/data_chunk_3.parquet"
df_test = pd.read_parquet(chunk_path)

In [10]:
# Look up the row labelled 3000 in the reloaded chunk and show its
# reference-text embedding.
sample_embedding = df_test.loc[3000, "target_txt_embed"]
display(sample_embedding)

array([-1.74004666e-03, -2.23796978e-03,  3.74957547e-02, -1.02234706e-02,
        7.76794106e-02,  3.98207754e-02, -1.05262930e-02,  1.29920319e-02,
       -1.04747124e-01, -8.24390631e-03,  2.05260087e-02, -8.63364805e-03,
       -1.71148051e-02,  2.39389371e-02, -1.03658875e-02, -1.16256149e-02,
       -2.88991015e-02, -3.32117788e-02, -1.26792818e-01, -8.64633545e-03,
       -1.03339665e-02, -1.83805656e-02,  2.60190237e-02, -4.86015752e-02,
       -4.20669280e-02, -3.68275195e-02,  6.24631196e-02, -6.46939799e-02,
        2.79593766e-02,  3.66472267e-02, -2.07238570e-02, -1.18935127e-02,
       -5.20533957e-02, -5.71092125e-03, -3.65023986e-02,  1.81146730e-02,
        5.75629696e-02, -1.20269656e-02, -1.02160620e-02,  1.51121290e-02,
        3.35738547e-02, -1.14121931e-02, -2.13363972e-02, -5.16641028e-02,
        6.94205612e-02, -1.47264125e-03, -4.67739627e-03, -8.15976858e-02,
       -1.00058995e-01, -9.14961025e-02,  9.46687534e-02, -5.55385686e-02,
       -2.12671887e-03, -