In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from tqdm import tqdm
import pandas as pd

In [None]:
# -----------------------
# 1. Load XLM-R (large)
# -----------------------
def load_xlmr(model_name="FacebookAI/xlm-roberta-large", device=None):
    """Load a pretrained tokenizer/encoder pair and place the encoder on a device.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.
        device: Target device for the model. When ``None``, ``"cuda"`` is
            chosen if ``torch.cuda.is_available()`` reports a GPU, otherwise
            ``"cpu"``.

    Returns:
        A ``(tokenizer, model, device)`` tuple. The model has already been
        moved to ``device`` and switched to evaluation mode.
    """
    # Resolve the device first so the caller gets back what was actually used.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = AutoModel.from_pretrained(model_name)
    model = model.to(device)
    model.eval()  # inference only: switch the module to eval mode

    return tokenizer, model, device

In [None]:
# -----------------------
# 2. Get embeddings
# -----------------------
def get_batch_embeddings(texts, tokenizer, model, device, pooling="mean", max_length=128):
    """Encode a batch of texts and pool the final layer into one vector per text.

    Args:
        texts: Batch of input texts, passed straight to ``tokenizer``.
        tokenizer: Callable tokenizer; it is invoked with ``padding=True``,
            ``truncation=True`` and ``return_tensors="pt"`` and its result must
            support ``.to(device)``, ``**`` unpacking and an
            ``"attention_mask"`` entry.
        model: Callable encoder whose output exposes ``last_hidden_state``.
        device: Device the encoded inputs are moved to before the forward pass.
        pooling: ``"cls"`` takes the vector at sequence position 0;
            ``"mean"`` averages the positions where the attention mask is 1.
        max_length: Truncation length given to the tokenizer (default 128).

    Returns:
        A CPU tensor with one pooled embedding per input text.

    Raises:
        ValueError: If ``pooling`` is neither ``"cls"`` nor ``"mean"``.
    """
    # Validate up front so a typo fails before the expensive forward pass.
    if pooling not in ("cls", "mean"):
        raise ValueError("pooling must be 'cls' or 'mean'")

    encodings = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=max_length,
    ).to(device)

    with torch.no_grad():
        # Only the last layer is pooled, so do not ask the model to also
        # return (and keep in memory) the hidden states of every layer.
        outputs = model(**encodings)
        hidden_states = outputs.last_hidden_state

    if pooling == "cls":
        embeddings = hidden_states[:, 0, :]
    else:
        # (batch, seq) -> (batch, seq, 1); broadcasting covers the hidden dim.
        mask = encodings["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        # Clamp the token count so a fully masked row yields zeros, not NaN.
        token_counts = mask.sum(dim=1).clamp(min=1)
        embeddings = summed / token_counts

    return embeddings.cpu()

In [None]:
# -----------------------
# 3. Shift to strictly positive
# -----------------------
def shift_to_positive(embeddings, eps=1e-6):
    """Shift each column so that every value is strictly greater than zero.

    The minimum along axis 0 is subtracted from the input and ``eps`` is then
    added, so the smallest entry along axis 0 becomes ``eps``.

    The subtraction is done *before* adding ``eps`` on purpose: computing
    ``-min + eps`` first loses ``eps`` to floating-point rounding when
    ``|min|`` is large (in float32 already around 32), which would leave the
    minimum at exactly 0 instead of strictly positive.

    Args:
        embeddings: ``torch.Tensor`` or ``np.ndarray`` to shift.
        eps: Positive offset that the per-column minimum is mapped to.

    Returns:
        A new tensor/array of the same container type and shape. Empty input
        is returned as an (empty) copy.

    Raises:
        TypeError: If ``embeddings`` is neither a tensor nor an ndarray.
    """
    if isinstance(embeddings, torch.Tensor):
        if embeddings.numel() == 0:
            # Nothing to reduce over; min() would raise on an empty tensor.
            return embeddings.clone()
        min_vals = embeddings.min(dim=0).values
    elif isinstance(embeddings, np.ndarray):
        if embeddings.size == 0:
            return embeddings.copy()
        min_vals = embeddings.min(axis=0)
    else:
        raise TypeError("embeddings must be torch.Tensor or np.ndarray")

    # (x - min) is exactly 0 for the minimum, so adding eps cannot be absorbed.
    return (embeddings - min_vals) + eps

In [None]:
# -----------------------
# 4. Scale + Round to Integers
# -----------------------
def scale_and_round(embeddings, scale=15560):
    """Multiply the embeddings by ``scale`` and round them to integers.

    Args:
        embeddings: ``torch.Tensor`` or ``np.ndarray`` of values to quantise.
        scale: Multiplier applied before rounding.

    Returns:
        For a tensor input, a tensor converted with ``.int()``; for an ndarray
        input, an array cast with ``astype(int)``.

    Raises:
        TypeError: If ``embeddings`` is neither a tensor nor an ndarray.
    """
    is_tensor = isinstance(embeddings, torch.Tensor)
    if not is_tensor and not isinstance(embeddings, np.ndarray):
        raise TypeError("embeddings must be torch.Tensor or np.ndarray")

    scaled = embeddings * scale
    if is_tensor:
        rounded = torch.round(scaled)
        return rounded.int()

    rounded = np.round(scaled)
    return rounded.astype(int)

In [None]:
# -----------------------
# 5. Full Pipeline
# -----------------------
def embed_csv(
    csv_path,
    text_col="Sentence",
    batch_size=16,
    pooling="mean",
    save_path="embeddings_int.csv",
    model_name="FacebookAI/xlm-roberta-large",
    scale=15560,
):
    """Embed one text column of a CSV and save the result as integer features.

    Steps: read the CSV, embed ``text_col`` in batches, shift the embeddings
    with ``shift_to_positive``, quantise them with ``scale_and_round`` and
    write them to ``save_path`` with columns ``emb_0 .. emb_{d-1}``.

    Args:
        csv_path: Path of the input CSV.
        text_col: Name of the column holding the texts to embed.
        batch_size: Number of texts per forward pass; must be >= 1.
        pooling: Pooling mode forwarded to ``get_batch_embeddings``
            (``"cls"`` or ``"mean"``).
        save_path: Destination CSV for the integer embeddings.
        model_name: Model identifier forwarded to ``load_xlmr``.
        scale: Multiplier forwarded to ``scale_and_round``.

    Returns:
        The integer embeddings produced by ``scale_and_round``, one row per
        CSV row, in the original row order.

    Raises:
        ValueError: If ``pooling`` or ``batch_size`` is invalid, if
            ``text_col`` is missing, or if the CSV has no rows.
    """
    # Fail fast on cheap-to-check arguments before the model gets loaded.
    if pooling not in ("cls", "mean"):
        raise ValueError("pooling must be 'cls' or 'mean'")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    df = pd.read_csv(csv_path)

    if text_col not in df.columns:
        raise ValueError(f"CSV must have a '{text_col}' column.")

    # Missing cells would otherwise be stringified and embedded as the literal
    # word "nan"; keep the row (to preserve alignment) but embed an empty text.
    column = df[text_col]
    texts = column.astype(str).where(column.notna(), "").tolist()
    if not texts:
        raise ValueError(f"No rows to embed in '{csv_path}'.")

    tokenizer, model, device = load_xlmr(model_name)

    batch_embeddings = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Embedding batches"):
        batch = texts[start : start + batch_size]
        batch_embeddings.append(
            get_batch_embeddings(batch, tokenizer, model, device, pooling)
        )

    all_embeddings = torch.cat(batch_embeddings, dim=0)

    # Step 1: shift to positive (per-column minimum taken over the whole file)
    positive_embeddings = shift_to_positive(all_embeddings)

    # Step 2: scale + round to integers
    int_embeddings = scale_and_round(positive_embeddings, scale=scale)

    # Save to CSV, one column per embedding dimension
    emb_cols = [f"emb_{i}" for i in range(int_embeddings.shape[1])]
    int_emb_df = pd.DataFrame(int_embeddings.numpy(), columns=emb_cols)
    int_emb_df.to_csv(save_path, index=False)

    print(f"Saved integer embeddings to {save_path}")
    print("Integer embeddings shape:", int_embeddings.shape)
    print("Example row:", int_embeddings[0][:10].tolist())

    return int_embeddings

In [None]:
# -----------------------
# Usage
# -----------------------
if __name__ == "__main__":
    # Input CSV, addressed relative to the notebook's working directory.
    csv_file = "../../data/En-Ba-Dataset(20k_4)/dataset_cleaned.csv"

    # Options for this run, collected in one place so they are easy to tweak.
    run_options = {
        "text_col": "Sentence",
        "batch_size": 4,
        "pooling": "mean",
        "save_path": "embeddings_int.csv",
    }

    # `embeddings` is read again by the stats cell below.
    embeddings = embed_csv(csv_file, **run_options)

    print("Final embeddings shape:", embeddings.shape)

In [None]:
# -----------------------
# 6. Inspect embedding stats
# -----------------------
# `embeddings` is the value returned by embed_csv, i.e. the integer matrix
# obtained *after* shifting, scaling and rounding, so label it accordingly.
print("Embedding stats after scaling/rounding:")
print("Min:", embeddings.min().item())
print("Max:", embeddings.max().item())