<a href="https://colab.research.google.com/github/dtim-upc/LOKI/blob/main/Model-1/Model_1_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Import libraries
import os
import json
import numpy as np
from sentence_transformers import SentenceTransformer, SimilarityFunction, util
import torch
from tqdm.notebook import tqdm

In [3]:
# Folder locations (Colab defaults — change these to match your environment)
input_folder = "/content/"
output_folder = "/content/output/"

# Input files: the three contrastive pair subsets plus the full corpus they reference
positive_pairs_path, hard_negative_pairs_path, extreme_negative_pairs_path, full_data_path = (
    os.path.join(input_folder, file_name)
    for file_name in (
        "positive_group.json",
        "hard_negative_group.json",
        "extreme_negative_group.json",
        "formatted_data_cleaned.json",
    )
)

# Make sure the destination folder exists before any results are written
os.makedirs(output_folder, exist_ok=True)

In [4]:
# Load the bi-encoder, preferring the GPU when one is available
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

bi_encoder = SentenceTransformer(
    "all-mpnet-base-v2",
    device=device,
    similarity_fn_name=SimilarityFunction.DOT_PRODUCT,
)
print("Model loaded successfully")

Model loaded successfully


In [5]:
# Load the full corpus of table and paragraph entries
print("Loading Original Data...")
with open(full_data_path, 'r', encoding='utf-8') as corpus_file:
    full_data = json.load(corpus_file)

# Index every entry by its "id" so pair lookups are constant-time
data_by_id = dict((entry["id"], entry) for entry in full_data)
print("Data loaded and dictionary created successfully")

Loading Original Data...
Data loaded and dictionary created successfully


In [6]:
# Load the (table, paragraph) ID pairs for each contrastive subset
print("Loading Contrastive Subsets...")
loaded_subsets = []
for subset_path in (positive_pairs_path, hard_negative_pairs_path, extreme_negative_pairs_path):
    with open(subset_path, 'r', encoding='utf-8') as subset_file:
        loaded_subsets.append(json.load(subset_file))
positive_pairs, hard_negative_pairs, extreme_negative_pairs = loaded_subsets

Loading Contrastive Subsets...


In [7]:
def create_row_representation(row, headers, title=None, caption=None):
    """Flatten one table row into a single "header: value" string for embedding.

    Args:
        row: Sequence of cell values for a single table row.
        headers: Column names for ``row``. When falsy (``None`` or empty),
            generic names ``"Column 1"``, ``"Column 2"``, ... are generated
            from the row length.
        title: Optional table title; prepended as ``"Title: <title> "`` when truthy.
        caption: Optional table caption; placed after the title as
            ``"Caption: <caption> "`` when truthy.

    Returns:
        A string of the form ``"Title: T Caption: C H1: v1 H2: v2 ..."``. The row
        is truncated, or right-padded with empty strings, so there is exactly
        one cell per header. The input ``row`` is not modified.
    """
    column_names = headers if headers else [f"Column {n}" for n in range(1, len(row) + 1)]
    width = len(column_names)

    # Align the row with the header count: drop surplus cells, pad missing ones.
    cells = row[:width] if len(row) >= width else row + [''] * (width - len(row))

    body = ' '.join(f"{name}: {value}" for name, value in zip(column_names, cells))

    prefix = ""
    if title:
        prefix += f"Title: {title} "
    if caption:
        prefix += f"Caption: {caption} "
    return prefix + body

In [8]:
def save_results_to_json(results, subset_name):
    """Write one subset's similarity results to ``<output_folder>/<subset_name>_dot_product.json``.

    Args:
        results: JSON-serialisable mapping, e.g. the dict returned by
            ``compute_dot_product_for_subset``.
        subset_name: Subset label used to build the output file name.
    """
    file_name = f"{subset_name}_dot_product.json"
    destination = os.path.join(output_folder, file_name)
    with open(destination, mode='w', encoding='utf-8') as handle:
        json.dump(results, handle, ensure_ascii=False, indent=4)
    print(f"Dot product similarity for {subset_name} saved to: {destination}")

In [9]:
def visualize_first_pair(results, subset_name, pairs):
    """Print the first pair's table row, paragraph sentences and similarity scores.

    The row text and sentence list are rebuilt with the same rules used by
    ``compute_dot_product_for_subset``:
    - the first table row is the header when the table has 2+ rows;
    - blank row texts and blank sentences are dropped.
    This keeps "Sentence k" aligned with the k-th score column and the shown
    row aligned with row 0 of the stored matrix.

    Args:
        results: Mapping from ``"table_<id>_paragraph_<id>"`` to a nested-list
            similarity matrix (rows = table rows, columns = sentences).
        subset_name: Label used in the printed header.
        pairs: Sequence of dicts with ``"table"`` and ``"paragraph"`` ID keys;
            only ``pairs[0]`` is inspected.
    """
    print(f"\n--- Visualizing the First Result ({subset_name}) ---")
    if not pairs:
        print("No pairs to visualize.")
        return

    first_pair = pairs[0]
    table_id = first_pair["table"]
    paragraph_id = first_pair["paragraph"]

    result_key = f"table_{table_id}_paragraph_{paragraph_id}"
    if result_key not in results:
        # The scorer skips pairs whose rows or sentences are all blank.
        print(f"No similarity scores stored for {result_key} (pair was skipped during processing).")
        return

    # Retrieve the table and paragraph data
    table_entry = data_by_id[table_id]
    paragraph_entry = data_by_id[paragraph_id]
    table = table_entry["table"]
    headers = table[0] if len(table) > 1 else None
    rows = table[1:] if len(table) > 1 else table

    # First non-blank row text == row 0 of the stored matrix (a stored result
    # implies at least one non-blank row text exists).
    row_representation = next(
        text
        for text in (
            create_row_representation(row, headers, table_entry.get("table_title"), table_entry.get("caption"))
            for row in rows
        )
        if text.strip()
    )
    # Blank sentences were removed before scoring; remove them here too so the
    # printed numbering matches the score columns.
    sentences = [sentence for sentence in paragraph_entry["sentence_context"] if sentence.strip()]

    print(f"Row Representation:\n{row_representation}\n")
    print("Paragraph Sentences:")
    for idx, sentence in enumerate(sentences, start=1):
        print(f"Sentence {idx}: {sentence}")

    # Display similarity scores for the first row
    print("\nSimilarity Scores:")
    for idx, score in enumerate(results[result_key][0], start=1):
        print(f"Row -> Sentence {idx}: {score:.4f}")


In [10]:
def compute_dot_product_for_subset(subset_name, pairs):
    """Score every table row against every paragraph sentence for one subset.

    For each pair:
    - the table's rows are flattened with ``create_row_representation``; the
      first table row is used as headers when the table has more than one row;
    - the paragraph's ``sentence_context`` entries are used as sentences;
    - both are embedded with ``bi_encoder``;
    - their pairwise dot products are stored.

    Args:
        subset_name: Label used in log and progress-bar messages.
        pairs: Iterable of dicts with ``"table"`` and ``"paragraph"`` ID keys,
            resolved through ``data_by_id``.

    Returns:
        Dict mapping ``"table_<table_id>_paragraph_<paragraph_id>"`` to a nested
        list of floats with one inner list per (non-blank) row text and one
        value per (non-blank) sentence. Pairs whose row texts or sentences are
        all blank are skipped and get no entry.
    """
    print(f"\nProcessing subset: {subset_name}")
    dot_product_results = {}

    for pair in tqdm(pairs, desc=f"Processing {subset_name}"):
        table_id = pair["table"]
        paragraph_id = pair["paragraph"]
        table_entry = data_by_id[table_id]
        paragraph_entry = data_by_id[paragraph_id]

        table = table_entry["table"]
        title = table_entry.get("table_title")
        caption = table_entry.get("caption")

        # With 2+ rows the first row is the header; a single-row table is
        # treated as data and gets generic column names.
        headers = table[0] if len(table) > 1 else None
        rows = table[1:] if len(table) > 1 else table

        # Prepare texts, dropping blank row texts and blank sentences
        row_texts = [create_row_representation(row, headers, title, caption) for row in rows]
        row_texts = [text for text in row_texts if text.strip()]
        sentences = [sentence for sentence in paragraph_entry["sentence_context"] if sentence.strip()]

        if not row_texts or not sentences:
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(f"Skipping table {table_id} or paragraph {paragraph_id} due to empty content.")
            continue

        # Both embeddings come from identically configured encode() calls, so
        # they already share a device; no extra .to(device) is needed.
        row_embeddings = bi_encoder.encode(row_texts, convert_to_tensor=True, device=device)
        sentence_embeddings = bi_encoder.encode(sentences, convert_to_tensor=True, device=device)

        dot_product_matrix = util.dot_score(row_embeddings, sentence_embeddings)
        dot_product_results[f"table_{table_id}_paragraph_{paragraph_id}"] = dot_product_matrix.cpu().tolist()

    return dot_product_results


In [11]:
# Step 1: Process subsets and save results.
# Only the positive subset is processed by default. Add "hard_negative" and/or
# "extreme_negative" to SUBSETS_TO_RUN to score and save those subsets too.
SUBSETS_TO_RUN = ("positive",)

pairs_by_subset = {
    "positive": positive_pairs,
    "hard_negative": hard_negative_pairs,
    "extreme_negative": extreme_negative_pairs,
}

results_by_subset = {}
for subset_name in SUBSETS_TO_RUN:
    results_by_subset[subset_name] = compute_dot_product_for_subset(subset_name, pairs_by_subset[subset_name])
    save_results_to_json(results_by_subset[subset_name], subset_name)

# The Step 2 cell visualizes the positive subset under this name; it is None
# if "positive" was removed from SUBSETS_TO_RUN.
positive_results = results_by_subset.get("positive")


Processing subset: positive


Processing positive:   0%|          | 0/4585 [00:00<?, ?it/s]

Dot product similarity for positive saved to: /content/output/positive_dot_product.json


In [12]:
# Step 2: Show the first positive pair's row, sentences and similarity scores
visualize_first_pair(
    results=positive_results,
    subset_name="positive",
    pairs=positive_pairs,
)


--- Visualizing the First Result (positive) ---
Row Representation:
Title: Zeina Mina Caption: Heat 4 Rank: 1. Athlete: Valerie Brisco-Hooks (USA) Time: 51.42

Paragraph Sentences:
Sentence 1: Carlon Blackman (1): She competed in the women's 400 metres at the 1984 Summer Olympics in Los Angeles, finishing in sixth place in her first-round heat, with a time of 54.26.
Sentence 2: Carlon Blackman (2): Carlon Blackman (born 27 March 1965) is a Barbadian sprinter.
Sentence 3: Zeina Mina (1): Zeina Mina (Arabic: ; born January 1, 1963) is a Lebanese Olympic athlete.
Sentence 4: Zeina Mina (2): Athletics Women's 400 metres Round One
Sentence 5: Athletics at the 1984 Summer Olympics (1): At the 1984 Summer Olympics in Los Angeles, 41 events in athletics were contested.

Similarity Scores:
Row -> Sentence 1: 0.5005
Row -> Sentence 2: 0.2170
Row -> Sentence 3: 0.5300
Row -> Sentence 4: 0.7262
Row -> Sentence 5: 0.3266
