<a href="https://colab.research.google.com/github/dtim-upc/THOR/blob/main/Reducing-Pairs/Data_Preprocessing_LLMs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install xformers

In [None]:
import os
import json
import numpy as np
from sentence_transformers import SentenceTransformer, SimilarityFunction, util
from tqdm.notebook import tqdm
import torch
import random
from collections import Counter, defaultdict

from accelerate import Accelerator

In [None]:
# Input and output locations; the output directory is created up front when it is missing.
data_folder, output_folder = "/content/input_data", "/content/output_data"
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Select the device for the SentenceTransformer model: CUDA when torch reports a GPU, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Alternative models (configured with the DOT_PRODUCT similarity function), kept for reference:
# model = SentenceTransformer("all-MiniLM-L6-v2", similarity_fn_name=SimilarityFunction.DOT_PRODUCT, device=device)
# model = SentenceTransformer("all-mpnet-base-v2", similarity_fn_name=SimilarityFunction.DOT_PRODUCT, device=device)

# Active model: stella_en_1.5B_v5, picked as one of the top models (no. 5) on the MTEB leaderboard.
# Unlike the alternatives above, no similarity_fn_name is passed here, so model.similarity()
# uses whatever similarity function the model itself is configured with.
# trust_remote_code=True executes Python files downloaded from the model repository.
# NOTE(review): consider pinning a revision so the downloaded code cannot change between runs.
# An Accelerator is created and the loaded model is passed through accelerator.prepare().
accelerator = Accelerator()
model = SentenceTransformer("dunzhang/stella_en_1.5B_v5", trust_remote_code=True, device=device)
model.use_xformers = True  # Sets a use_xformers flag on the model object - NOTE(review): confirm the model code actually reads it
model = accelerator.prepare(model)

modules.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/397 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/174k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/844 [00:00<?, ?B/s]

modeling_qwen.py:   0%|          | 0.00/65.3k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/dunzhang/stella_en_1.5B_v5:
- modeling_qwen.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/6.17G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.31k [00:00<?, ?B/s]

tokenization_qwen.py:   0%|          | 0.00/10.8k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/dunzhang/stella_en_1.5B_v5:
- tokenization_qwen.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/80.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/370 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/289 [00:00<?, ?B/s]

2_Dense_1024/config.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/6.30M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/6.30M [00:00<?, ?B/s]

In [None]:
# Read the formatted dataset (JSON) from the input folder into `data`.
file_path = os.path.join(data_folder, 'formatted_data.json')
with open(file_path, mode='r', encoding='utf-8') as f:
  data = json.loads(f.read())

In [None]:
# Prompt name handed to model.encode(prompt_name=...) when embedding the query side.
# NOTE(review): presumably one of the prompts defined by the loaded model's configuration - confirm.
query_prompt_name = "s2s_query"

In [None]:
# Generate query-side embeddings for the whole dataset - our sentences act as the queries.
# Each entry contributes one vector: its sentences are encoded with the query prompt and mean-pooled.
query_dataset_embeddings = []
paragraph_ids = []
for entry in tqdm(data, desc="Generating paragraph embeddings"):
  sentence_context = entry.get('sentence_context', [])
  paragraph_id = entry.get('id', None)

  # Fail fast with the offending id: the mean over zero sentences is not a usable paragraph
  # vector and would otherwise only surface later as an obscure stacking/shape error.
  if not sentence_context:
    raise ValueError(f"Entry {paragraph_id!r} has an empty or missing 'sentence_context'; cannot build its embedding.")
  paragraph_ids.append(paragraph_id)

  # Generate sentence embeddings and perform mean-pooling to get the paragraph-level query embedding
  query_sentence_embeddings = model.encode(sentence_context, prompt_name=query_prompt_name, convert_to_tensor=True)
  query_paragraph_embedding = query_sentence_embeddings.mean(dim=0)
  query_dataset_embeddings.append(query_paragraph_embedding)

Generating paragraph embeddings:   0%|          | 0/3157 [00:00<?, ?it/s]

In [None]:
# Generate paragraph-side embeddings for the whole dataset (encoded without a prompt name).
# Each entry contributes one vector: its sentence embeddings are mean-pooled.
paragraph_dataset_embeddings = []
paragraph_ids = []
for entry in tqdm(data, desc="Generating paragraph embeddings"):
  sentence_context = entry.get('sentence_context', [])
  paragraph_id = entry.get('id', None)

  # Fail fast with the offending id: the mean over zero sentences is not a usable paragraph
  # vector and would otherwise only surface later as an obscure stacking/shape error.
  if not sentence_context:
    raise ValueError(f"Entry {paragraph_id!r} has an empty or missing 'sentence_context'; cannot build its embedding.")
  paragraph_ids.append(paragraph_id)

  # Generate sentence embeddings and perform mean-pooling to get paragraph-level embedding
  sentence_embeddings = model.encode(sentence_context, convert_to_tensor=True)
  paragraph_embedding = sentence_embeddings.mean(dim=0)
  paragraph_dataset_embeddings.append(paragraph_embedding)

Generating paragraph embeddings:   0%|          | 0/3157 [00:00<?, ?it/s]

In [None]:
# Stack the per-paragraph vectors into torch tensors (one row per paragraph).
# The list check keeps this cell re-runnable: once a name already holds the stacked tensor,
# calling torch.stack on it again would fail, so stacking is only done for the list form.
if isinstance(query_dataset_embeddings, list):
  query_dataset_embeddings = torch.stack(query_dataset_embeddings)
if isinstance(paragraph_dataset_embeddings, list):
  paragraph_dataset_embeddings = torch.stack(paragraph_dataset_embeddings)

# Compute pairwise similarities between the query-side and paragraph-side vectors
similarities = model.similarity(query_dataset_embeddings, paragraph_dataset_embeddings)

In [None]:
# Empty result buckets that the pair-classification step fills in.
positive_group, hard_negative_group = [], []
reserved_list_dict = defaultdict(list)

In [None]:
# Initialize all possible ranges: 0.01-wide similarity bins labelled "-1.00 to -0.99" ... "0.29 to 0.30".
# The bin edges are derived from integer hundredths, so the number of bins and their labels
# do not depend on floating-point step accumulation (and no "-0.00" label can appear).
all_possible_ranges = [f"{i / 100:.2f} to {(i + 1) / 100:.2f}" for i in range(-100, 30)]
# One (initially empty) list of pairs and one zero counter per range label
reserved_list_dict = {key: [] for key in all_possible_ranges}
range_counts = Counter({key: 0 for key in all_possible_ranges})

In [None]:
# Thresholds for classification, as applied by the comparisons in the classification loop:
#   similarity >  positive_threshold                        -> Positive / Highly Similar Items
#   negative_threshold <= similarity <= positive_threshold  -> Hard Negative
#   similarity <  negative_threshold                        -> Extreme Negative
# NOTE(review): the intended positive cut-off may be ">= 0.7"; the value actually used is 0.69
# with a strict ">" comparison - confirm which one is meant.
positive_threshold = 0.69
negative_threshold = 0.3

In [None]:
# Iterate over the upper triangular part of the similarity matrix (excluding the diagonal)
# and sort every pair (i, j), i < j, into the positive, hard-negative or extreme-negative bucket.
num_paragraphs = len(paragraph_ids)
progress_bar = tqdm(total=(num_paragraphs * (num_paragraphs - 1)) // 2, desc="Classifying paragraph pairs")

# Copy the matrix to host memory once; rows are then converted to plain Python floats in one
# call each instead of issuing a separate tensor read for every single pair.
similarities_cpu = similarities.detach().cpu()

for i in range(num_paragraphs):
  row_scores = similarities_cpu[i, i + 1:num_paragraphs].tolist()
  for j, similarity_score in enumerate(row_scores, start=i + 1):
    pair = {
      "paragraph_1": paragraph_ids[i],
      "paragraph_2": paragraph_ids[j],
      "similarity": round(similarity_score, 3)
    }

    if similarity_score > positive_threshold:
      positive_group.append(pair)
    elif similarity_score >= negative_threshold:
      hard_negative_group.append(pair)
    elif np.isfinite(similarity_score):
      # Extreme negative: file the pair under its 0.01-wide range so representative samples
      # can be drawn later. The range is found by flooring to hundredths: rounding to the
      # nearest hundredth would shift scores into the neighbouring range, and would produce
      # labels such as "0.30 to 0.31" or "-0.00 to 0.01" that do not exist, silently
      # dropping those pairs.
      bin_index = int(np.floor(similarity_score * 100))
      range_key = f"{bin_index / 100:.2f} to {(bin_index + 1) / 100:.2f}"
      if range_key in reserved_list_dict:
        reserved_list_dict[range_key].append(pair)
        range_counts[range_key] += 1
    # Non-finite scores (NaN/inf) match no range and are skipped.
  progress_bar.update(len(row_scores))

progress_bar.close()

Classifying paragraph pairs:   0%|          | 0/4981746 [00:00<?, ?it/s]

In [None]:
# S: number of extreme negative pairs collected over all ranges
S = sum(range_counts[key] for key in range_counts)

# M: target count for extreme negatives = positives and hard negatives combined
M = len(hard_negative_group) + len(positive_group)
remaining_target_count = M

In [None]:
# Fresh bookkeeping for the allocation steps: per-range allocations, the set of ranges that
# were capped, and the running count of pairs that could not be placed.
allocations = dict()
capped_ranges, excess_pairs = set(), 0

In [None]:
# Step 1: give every range an initial allocation proportional to its share of all
# extreme negative pairs, scaled to the target count M.
for range_key, s_i in range_counts.items():  # s_i: pairs available in this range
  if S > 0:
    p_i = s_i / S  # share of the total
  else:
    p_i = 0
  a_i = round(p_i * M)  # initial allocation
  allocations[range_key] = a_i

In [None]:
# Step 2: cap each allocation at the number of pairs its range actually holds and
# accumulate the shortfall in excess_pairs.
for range_key, a_i in allocations.items():
  s_i = range_counts[range_key]
  if a_i <= s_i:
    continue  # the range can satisfy its allocation
  excess = a_i - s_i
  excess_pairs += excess
  allocations[range_key] = s_i
  capped_ranges.add(range_key)

In [None]:
# Step 3: Redistribute excess pairs over the ranges that still have unallocated pairs.
# Shares are computed against FIXED totals; computing them against excess_pairs while it is
# being decremented would hand later ranges a fraction of a shrinking remainder and leave
# part of the excess undistributed.
# NOTE(review): Step 1 scales every range by the same ratio M / S, so it looks like excess can
# only arise when every range is already capped (no spare capacity) - confirm whether this
# branch is reachable with the current allocation rule.
if excess_pairs > 0:
  # Spare capacity per range: pairs still available beyond the current allocation
  spare_capacity = {
    range_key: range_counts[range_key] - allocations[range_key]
    for range_key in allocations
    if range_counts[range_key] > allocations[range_key]
  }
  total_capacity = sum(spare_capacity.values())

  # If there's no capacity left, we cannot redistribute
  if total_capacity == 0:
    excess_pairs = 0  # Cannot distribute further
  else:
    # Never hand out more than the ranges can absorb
    to_distribute = min(excess_pairs, total_capacity)
    leftover = to_distribute

    # Pass 1: proportional share of the fixed amount, floored with integer arithmetic
    # (a floored share can never exceed the range's spare capacity).
    for range_key, available in spare_capacity.items():
      additional_allocation = available * to_distribute // total_capacity
      allocations[range_key] += additional_allocation
      spare_capacity[range_key] = available - additional_allocation
      leftover -= additional_allocation

    # Pass 2: flooring leaves a small remainder; hand it out one pair at a time,
    # starting with the ranges that have the most spare capacity left.
    for range_key in sorted(spare_capacity, key=spare_capacity.get, reverse=True):
      if leftover <= 0:
        break
      if spare_capacity[range_key] > 0:
        allocations[range_key] += 1
        spare_capacity[range_key] -= 1
        leftover -= 1

    # Whatever could not be placed (excess larger than total capacity) stays in excess_pairs
    excess_pairs -= to_distribute - leftover

In [None]:
# Collect selected pairs based on final allocations: draw the allocated number of pairs at
# random (without replacement) from every range.
SAMPLING_SEED = 42  # fixed seed so that re-running the notebook selects the same pairs
sampler = random.Random(SAMPLING_SEED)  # dedicated generator; leaves the global random state untouched

pruned_extreme_negative_group = []
progress_bar = tqdm(total=len(allocations), desc="Pruning extreme negative pairs")

for range_key, num_pairs in allocations.items():
  if num_pairs > 0:
    available_pairs = reserved_list_dict[range_key]
    # Ensure we do not sample more pairs than available
    num_pairs = min(num_pairs, len(available_pairs))
    selected_pairs = sampler.sample(available_pairs, num_pairs)
    pruned_extreme_negative_group.extend(selected_pairs)
  progress_bar.update(1)

progress_bar.close()

Pruning extreme negative pairs:   0%|          | 0/130 [00:00<?, ?it/s]

In [None]:
# Write each classified group to its own JSON file in the output folder.
for output_name, group in (
  ('positive_group.json', positive_group),
  ('hard_negative_group.json', hard_negative_group),
  ('extreme_negative_group.json', pruned_extreme_negative_group),
):
  with open(os.path.join(output_folder, output_name), 'w', encoding='utf-8') as f:
    json.dump(group, f, indent=4, ensure_ascii=False)

print("Classification completed and saved to output folder.")

Classification completed and saved to output folder.
