<a href="https://colab.research.google.com/github/dtim-upc/LOKI/blob/main/Reducing-Pairs/Dataset_Pair_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
import numpy as np
from sentence_transformers import SentenceTransformer, SimilarityFunction, util
from tqdm.notebook import tqdm
import torch
import random
from collections import Counter, defaultdict

In [None]:
# Input and output locations inside the Colab workspace. The output folder is
# created up front (no error if it already exists) so later cells can write to it.
data_folder, output_folder = "/content/input_data", "/content/output_data"
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Sentence encoder used for every paragraph; placed on the GPU when CUDA is available.
# Its similarity function is DOT_PRODUCT, so model.similarity() returns raw dot products.
# NOTE(review): a dot product equals cosine similarity only for unit-norm vectors. The
# paragraph vectors built later are means of sentence vectors and may not be unit-norm,
# which would make scores smaller in magnitude than the matching cosine values that the
# thresholds/bins seem to assume — confirm this is intended.
# Alternative encoder previously tried: "all-mpnet-base-v2".
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "all-MiniLM-L6-v2"
model = SentenceTransformer(
    model_name,
    similarity_fn_name=SimilarityFunction.DOT_PRODUCT,
    device=device,
)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [None]:
# Load the cleaned dataset; later cells iterate over `data` and read each entry's
# 'id' and 'sentence_context' fields.
file_path = os.path.join(data_folder, 'formatted_data_cleaned.json')
with open(file_path, encoding='utf-8') as dataset_file:
  data = json.load(dataset_file)

In [None]:
# Generate one embedding per paragraph: encode each sentence of the paragraph and
# mean-pool the sentence vectors into a single paragraph-level vector.
# Invariant: paragraph_ids[k] is the id of the paragraph whose vector is embeddings[k].
embeddings = []
paragraph_ids = []
skipped_ids = []
for entry in tqdm(data, desc="Generating paragraph embeddings"):
  sentence_context = entry.get('sentence_context', [])
  paragraph_id = entry.get('id', None)

  # A bare string would be encoded as a single vector, and mean(dim=0) would then
  # collapse that vector to a scalar; treat it as a one-sentence paragraph instead.
  if isinstance(sentence_context, str):
    sentence_context = [sentence_context]

  # An empty paragraph has no sentence vectors to average, so its pooled result would
  # not be a proper embedding and torch.stack in the next cell would fail.
  # NOTE(review): sentence-transformers returns an empty tensor for an empty batch,
  # whose mean is NaN. Skip such paragraphs (id is not recorded, keeping alignment).
  if not sentence_context:
    skipped_ids.append(paragraph_id)
    continue

  # Sentence embeddings -> mean-pool -> paragraph embedding
  sentence_embeddings = model.encode(sentence_context, convert_to_tensor=True)
  paragraph_embedding = sentence_embeddings.mean(dim=0)
  paragraph_ids.append(paragraph_id)
  embeddings.append(paragraph_embedding)

if skipped_ids:
  print(f"Skipped {len(skipped_ids)} paragraph(s) with no sentences: {skipped_ids[:10]}")

Generating paragraph embeddings:   0%|          | 0/3157 [00:00<?, ?it/s]

In [None]:
# Stack the per-paragraph vectors into one 2-D torch tensor (row k = paragraph k),
# then score every paragraph against every other paragraph with the model's
# configured similarity function. The result is indexed later as similarities[i, j].
paragraph_matrix = torch.stack(embeddings)
embeddings = paragraph_matrix
similarities = model.similarity(paragraph_matrix, paragraph_matrix)

In [None]:
# Containers for classified groups
positive_group = []
hard_negative_group = []

# Extreme-negative bins: 0.01-wide ranges from -1.00 up to 0.30, labelled
# "a to a+0.01" (e.g. "-1.00 to -0.99", ..., "-0.01 to 0.00", "0.00 to 0.01", ...,
# "0.29 to 0.30"). Labels are built from integer hundredths rather than a float
# np.arange so that no float drift or negative zero ("-0.00") can creep into a key.
all_possible_ranges = [f"{k / 100:.2f} to {(k + 1) / 100:.2f}" for k in range(-100, 30)]

# Plain dict (not a defaultdict): only the keys above are valid bins.
reserved_list_dict = {key: [] for key in all_possible_ranges}
range_counts = Counter({key: 0 for key in all_possible_ranges})

In [None]:
# Classification thresholds applied to each pair's similarity score in Step 2:
#   score <  negative_threshold                         -> extreme negative
#   negative_threshold <= score <= positive_threshold   -> hard negative
#   score >  positive_threshold                         -> positive / highly similar
# NOTE(review): positives are scores strictly above 0.69 (so e.g. 0.695 counts);
# if the intended cut-off is ">= 0.7", adjust positive_threshold accordingly.
positive_threshold, negative_threshold = 0.69, 0.3

In [None]:
# Step 1: Seed the positive group with the pairs implied by the dataset itself:
# every entry paired with itself, with similarity fixed at 1.0.
manual_positive_pairs = [
    {
        "paragraph_1": entry.get("id"),
        "paragraph_2": entry.get("id"),
        "similarity": 1.0,  # self-pair: treated as perfect similarity
    }
    for entry in data
]

positive_group.extend(manual_positive_pairs)

In [None]:
# Step 2: Add similarity-based pairs from computed similarities.
# Every unordered pair (i < j) is classified by its score:
#   score >  positive_threshold                        -> positive_group
#   negative_threshold <= score <= positive_threshold  -> hard_negative_group
#   score <  negative_threshold                        -> extreme negative, bucketed
#                                                         into reserved_list_dict
num_paragraphs = len(paragraph_ids)
total_comparisons = (num_paragraphs * (num_paragraphs - 1)) // 2

# Copy the score matrix to host memory once; calling .item() on a (possibly CUDA)
# tensor for each of the ~n^2/2 pairs is very slow.
similarity_matrix = similarities.detach().cpu().numpy()

# Bin keys in ascending order, each covering [a, a + 0.01). Indexing by
# floor(score * 100) puts every score into the bin whose label contains it. The
# previous round()-to-nearest key produced "-0.00 ..." for scores in (-0.005, 0) and
# "0.30 ..." for scores in [0.295, 0.3). Neither key exists, so those pairs were
# silently dropped.
bin_keys = list(reserved_list_dict.keys())
lowest_bin = round(float(bin_keys[0].split(" to ")[0]) * 100)  # first bin's lower edge, in hundredths

progress_bar = tqdm(total=total_comparisons, desc="Adding similarity-based pairs")
for i in range(num_paragraphs):
    for j in range(i + 1, num_paragraphs):
        similarity_score = float(similarity_matrix[i, j])
        pair = {
            "paragraph_1": paragraph_ids[i],
            "paragraph_2": paragraph_ids[j],
            "similarity": round(similarity_score, 3)
        }
        if similarity_score > positive_threshold:
            positive_group.append(pair)
        elif similarity_score >= negative_threshold:
            hard_negative_group.append(pair)
        else:
            # Extreme negative: file it under its 0.01-wide bin so representative
            # samples can be drawn per bin later.
            if np.isnan(similarity_score):
                continue  # unscorable pair; cannot be binned
            bin_index = int(np.floor(similarity_score * 100)) - lowest_bin
            # Float noise just below the lowest edge still belongs to the first bin.
            bin_index = max(bin_index, 0)
            if bin_index < len(bin_keys):
                range_key = bin_keys[bin_index]
                reserved_list_dict[range_key].append(pair)
                range_counts[range_key] += 1
    # One update per row (covers all j > i) instead of one per pair.
    progress_bar.update(num_paragraphs - i - 1)
progress_bar.close()

Adding similarity-based pairs:   0%|          | 0/4981746 [00:00<?, ?it/s]

In [None]:
# Step 3: Prune the extreme negative group.
# S: number of extreme-negative pairs collected across all bins.
# M: target size of the pruned group (positives + hard negatives combined).
S = sum(range_counts.values())
M = len(positive_group) + len(hard_negative_group)
remaining_target_count = M
# Working state filled in by the allocation cells that follow:
# per-bin allocations, bins capped at their availability, and pairs cut by capping.
allocations, capped_ranges, excess_pairs = {}, set(), 0

In [None]:
# Pre-compute allocations: each bin receives a share of the target M proportional
# to how many extreme-negative pairs it holds (s_i / S), rounded to an integer.
for range_key, s_i in range_counts.items():
    share_of_total = s_i / S if S > 0 else 0
    allocations[range_key] = round(share_of_total * M)

In [None]:
# Cap every allocation at the number of pairs actually available in its bin,
# accumulating the pairs that were cut so they can be redistributed afterwards.
for range_key, requested in allocations.items():
    available_in_bin = range_counts[range_key]
    if requested <= available_in_bin:
        continue
    allocations[range_key] = available_in_bin
    excess_pairs += requested - available_in_bin
    capped_ranges.add(range_key)

In [None]:
# Redistribute excess pairs: pairs that could not be drawn from capped bins are
# reassigned to bins that still have unused pairs, proportionally to each bin's
# spare capacity (range_counts - allocations).
if excess_pairs > 0:
    spare_capacity = {key: range_counts[key] - allocations[key] for key in allocations}
    total_capacity = sum(spare_capacity.values())
    if total_capacity > 0:
        # Freeze the amount being redistributed. Computing each share from the
        # shrinking running `excess_pairs` under-allocated later bins and lost pairs.
        excess_to_redistribute = excess_pairs
        for range_key, available in spare_capacity.items():
            if excess_pairs <= 0:
                break
            if available > 0:
                p_i = available / total_capacity
                additional_allocation = min(round(p_i * excess_to_redistribute), available, excess_pairs)
                allocations[range_key] += additional_allocation
                excess_pairs -= additional_allocation
        # Rounding can leave a few pairs unassigned; top up bins in order until the
        # excess is used up or no bin has spare pairs left.
        for range_key in allocations:
            if excess_pairs <= 0:
                break
            top_up = min(range_counts[range_key] - allocations[range_key], excess_pairs)
            if top_up > 0:
                allocations[range_key] += top_up
                excess_pairs -= top_up

In [None]:
# Collect pruned extreme negatives: draw each bin's allocated number of pairs
# uniformly at random, without replacement, from the pairs stored for that bin.
# A dedicated, seeded generator makes the sampled dataset reproducible across runs
# (the module-level random.sample was unseeded).
SAMPLING_SEED = 42
sampler = random.Random(SAMPLING_SEED)
pruned_extreme_negative_group = []
for range_key, num_pairs in allocations.items():
    if num_pairs > 0:
        available_pairs = reserved_list_dict[range_key]
        num_pairs = min(num_pairs, len(available_pairs))
        selected_pairs = sampler.sample(available_pairs, num_pairs)
        pruned_extreme_negative_group.extend(selected_pairs)

In [None]:
# Report how many pairs ended up in each group.
group_sizes = (
    ("Total Positive Pairs:", positive_group),
    ("Total Hard Negative Pairs:", hard_negative_group),
    ("Total Extreme Negative Pairs:", pruned_extreme_negative_group),
)
for label, group in group_sizes:
    print(label, len(group))

Total Positive Pairs: 3713
Total Hard Negative Pairs: 30756
Total Extreme Negative Pairs: 34469


In [None]:
# Persist each group as a pretty-printed UTF-8 JSON file (non-ASCII kept as-is)
# inside the output folder.
groups_to_save = {
    'positive_group.json': positive_group,
    'hard_negative_group.json': hard_negative_group,
    'extreme_negative_group.json': pruned_extreme_negative_group,
}
for file_name, group in groups_to_save.items():
    with open(os.path.join(output_folder, file_name), 'w', encoding='utf-8') as out_file:
        json.dump(group, out_file, indent=4, ensure_ascii=False)

print("Classification completed and saved to output folder.")

Classification completed and saved to output folder.
