In [25]:
import polars as pl
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm   
import torch
import pacmap

In [None]:
# Pick the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    backend = 'cuda'
elif torch.backends.mps.is_available():
    backend = 'mps'
else:
    backend = 'cpu'

device = torch.device(backend)
print('Device:', device)

In [3]:
# Directory that holds the dataset files (a relative path).
data_dir = Path("../../data")
# File name of the training-claims JSON; it is joined onto `data_dir` when the data is read.
train_claims_file = Path("train_claims_quantemp.json")

In [4]:
# Load the training claims with polars. Later cells read the "doc", "claim"
# and "label" columns of this frame.
train_claims = pl.read_json(data_dir / train_claims_file)

In [None]:
# Character-length summary statistics of the "doc" and "claim" columns, shown side by side.
doc_length_stats = train_claims['doc'].str.len_chars().describe().rename({"value": "docs"})
claim_length_stats = train_claims['claim'].str.len_chars().describe().rename({"value": "claim"})

# Bare last expression so the joined table is rendered as the cell output.
doc_length_stats.join(claim_length_stats, on='statistic')

In [42]:
# Turn the "label" column into integer class ids with scikit-learn's LabelEncoder.
LE = LabelEncoder()
raw_labels = train_claims["label"].to_list()
label_encoded = LE.fit_transform(raw_labels)

# Keep the ids next to the original labels as a "label_encoded" column.
train_claims = train_claims.with_columns(pl.Series("label_encoded", label_encoded))

In [6]:
# Hugging Face checkpoint to embed with. It has its own name so the identifier
# stays available after `model` is bound to the loaded network below.
model_name = "answerdotai/ModernBERT-base"

# Context window (in tokens) that every input is padded/truncated to.
# 8192 is max -> evidence. 128 is for claim.
if model_name in ["answerdotai/ModernBERT-base", "answerdotai/ModernBERT-large"]:
    context_window = 128
else:
    # Fail before the expensive model load instead of hitting a NameError on
    # `context_window` further down.
    raise ValueError(f"No context window configured for model {model_name!r}")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)


# Keyword arguments shared by every tokenizer call.
# `padding='max_length'` together with `max_length` already pads each input to
# `context_window`; the deprecated `pad_to_max_length` flag is therefore not set.
base_config = {
    'add_special_tokens': True,
    'padding': 'max_length',
    'truncation': True,
    'return_attention_mask': True,
    'return_tensors': 'pt',
    'max_length': context_window,
}

In [None]:
# Tokenize and encode every claim on its own, using the shared `base_config`,
# and move each encoding to `device`.
# The tokenizer is called directly: `encode_plus` is the legacy entry point and
# `tokenizer(...)` is its documented replacement for a single input string.
encoded = [
    tokenizer(sequence, **base_config).to(device)
    for sequence in tqdm(train_claims['claim'])
]

In [8]:
# Batch the per-claim encodings for the forward pass.
# Note: when embedding evidence, batch_size must stay at 1 because of the long context window.
batch_size = 20
# No shuffle/sampler argument is passed; the embeddings are later attached to
# `train_claims` by position, so the iteration order must match the row order.
dataloader = DataLoader(encoded, batch_size=batch_size)

In [None]:
# Inference only: put the model in evaluation mode.
model.eval()

# One NumPy array per batch, holding the first-token embedding of every row.
cls_embeddings = []

# Iterate through batches in DataLoader
for batch in tqdm(dataloader):
    # Flatten each collated field to (rows, context_window). No device transfer
    # happens here: the encodings were moved to `device` when they were tokenized.
    b_input_ids = batch['input_ids'].reshape(-1, context_window)
    b_input_mask = batch['attention_mask'].reshape(-1, context_window)
    # Forward pass without gradient tracking. Arguments are passed by keyword so
    # the mask cannot be bound to the wrong parameter if the architecture changes.
    with torch.no_grad():
        outputs = model(input_ids=b_input_ids, attention_mask=b_input_mask)
    # Extract the CLS (first position) embeddings from the last hidden state
    # and keep them on the CPU as NumPy.
    cls_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())


In [11]:
# Stack the per-batch arrays along the first axis into one NumPy array.
cls_embeddings_np = np.concatenate(cls_embeddings, axis=0)

# Wrap the vectors in a one-column frame and attach that column to the claims
# frame as "claim_embedding" (rows are matched by position).
embedding_df = pl.DataFrame({"claim_embedding": cls_embeddings_np.tolist()})
train_claims = train_claims.with_columns(embedding_df.get_column("claim_embedding"))

In [20]:
# Dimensionality reduction with PaCMAP (alternatives: UMAP, t-SNE).

# Fixed seed so the 2-D projection, and the plot built from it, is the same on
# every Restart & Run All.
PACMAP_SEED = 42

# Initialise the PaCMAP instance.
# n_neighbors may also be set to None to let PaCMAP choose it automatically.
embedding = pacmap.PaCMAP(
    n_components=2,
    n_neighbors=10,
    MN_ratio=0.5,
    FP_ratio=2.0,
    random_state=PACMAP_SEED,
)

# Project the CLS embeddings to two components, starting from a PCA initialisation.
claim_embedding_reduced = embedding.fit_transform(cls_embeddings_np, init="pca")

In [None]:
# Scatter plot of the 2-D projection, coloured by encoded label.
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
y = train_claims['label_encoded']
scatter = ax.scatter(
    claim_embedding_reduced[:, 0],
    claim_embedding_reduced[:, 1],
    c=y,
    cmap="Spectral",
    s=0.6,
)

# Legend: translate each encoded id back to its original label text.
encoded_to_label_map = dict(
    zip(train_claims["label_encoded"].to_list(), train_claims["label"].to_list())
)
handles, _ = scatter.legend_elements()
legend_labels = [encoded_to_label_map[int(code)] for code in np.unique(y)]
ax.legend(handles, legend_labels, title="Labels", loc="best", fontsize='small')

# Title and axis labels.
ax.set(
    title="Scatter Plot of Claim Embeddings",
    xlabel="Component 1",
    ylabel="Component 2",
)

plt.show()