<a href="https://colab.research.google.com/github/dtim-upc/LOKI/blob/main/Initial-All-Pairs/Baseline_Cross_Encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
import numpy as np
from sentence_transformers import CrossEncoder, SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
import torch
from tqdm.notebook import tqdm

In [None]:
# Define folder paths for input and output.
# data_folder is the directory the main cell reads 'data.json' from;
# output_folder is the directory 'labeled_dataset.json' is written to.
# Both are absolute paths under /content (presumably the Colab runtime's
# working area, given the Colab badge above) — adjust them when running elsewhere.
data_folder = "/content/input_data"
output_folder = "/content/output_data"

In [None]:
# Make output directory if it does not exist.
# exist_ok=True keeps this cell safe to re-run: an already existing
# directory is not treated as an error.
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Load the cross-encoder that scores (row text, sentence) pairs.
# The model is placed on the GPU when CUDA is available and on the CPU otherwise.
cross_encoder_model_name = 'cross-encoder/stsb-distilroberta-base'
cross_encoder = CrossEncoder(
    cross_encoder_model_name,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
# Bi-encoder alternative, currently disabled:
# bi_encoder = SentenceTransformer('all-MiniLM-L6-v2', device='cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def load_data(file_path, num_samples=None):
    """Read a JSON file and return its parsed content.

    Args:
        file_path: Path of the UTF-8 encoded JSON file to read.
        num_samples: When not None, only the first ``num_samples`` items of
            the parsed content are returned (plain slicing, so the parsed
            content must be sliceable, e.g. a list).

    Returns:
        The parsed JSON content, optionally truncated to ``num_samples`` items.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        records = json.load(handle)
    return records if num_samples is None else records[:num_samples]

In [None]:
# Function to create row representation for table

def create_row_representation(row, headers):
    """Serialise one table row as space-separated ``header: value`` pairs.

    A warning is printed when any header is an empty string. The row is
    aligned to the header count first: extra cells are dropped and missing
    cells are filled with empty strings.

    Args:
        row: List of cell values for a single table row.
        headers: List of column headers for the table.

    Returns:
        A single string such as ``"name: Ada year: 1843"``.
    """
    # Report tables whose header row contains a blank column name.
    blank_headers = [name for name in headers if name == '']
    if blank_headers:
        print(f"Warning: Empty header found. Headers: {headers}")

    # Align the row with the headers: truncate extras, pad what is missing.
    width = len(headers)
    surplus = len(row) - width
    if surplus > 0:
        row = row[:width]
    elif surplus < 0:
        row = row + [''] * (-surplus)

    return ' '.join(f"{name}: {cell}" for name, cell in zip(headers, row))


In [None]:
# Function to prepare dataset

def prepare_dataset(data):
    """Build every (row_text, sentence) pair over all table-paragraph combinations.

    The table of each example is paired with the ``sentence_context`` of every
    example in ``data`` (its own included). Pairs are emitted in a fixed
    order — table, then paragraph, then row, then sentence — which is the
    order ``generate_labeled_dataset`` walks when it regroups the scores.

    Args:
        data: List of dicts, each with a ``'table'`` entry (list of rows, the
            first row being the headers) and a ``'sentence_context'`` entry
            (list of sentences).

    Returns:
        List of ``(row_text, sentence)`` tuples.
    """
    dataset = []
    failed_examples = []

    for table_example in data:
        table = table_example['table']
        # An empty table has neither headers nor rows and contributes no pairs
        # (consistent with generate_labeled_dataset, which counts len(table[1:])).
        headers = table[0] if table else []
        rows = table[1:]
        if any(header == '' for header in headers):
            failed_examples.append(headers)
            print(f"Warning: Empty header in example with headers: {headers}")

        # The row text does not depend on the paragraph, so serialise each row
        # once per table instead of once per table-paragraph combination.
        row_texts = [create_row_representation(row, headers) for row in rows]

        for paragraph_example in data:
            paragraph = paragraph_example['sentence_context']
            # Pair rows with sentences from all examples.
            dataset.extend(
                (row_text, sentence)
                for row_text in row_texts
                for sentence in paragraph
            )

    if failed_examples:
        print(f"Total examples with empty headers: {len(failed_examples)}")
    return dataset


In [None]:
# Scoring function using cross-encoder

def compute_similarity_cross_encoder(row_sentence_pairs, batch_size=16):
    """Score (row_text, sentence) pairs with the module-level cross-encoder.

    The pairs are sent to ``cross_encoder.predict`` in consecutive slices of
    ``batch_size`` items, with a tqdm progress bar over the slices.

    Args:
        row_sentence_pairs: Sequence of ``(row_text, sentence)`` pairs.
        batch_size: Number of pairs passed to each ``predict`` call.
            Defaults to 16.

    Returns:
        List with one score per input pair, in input order. Scores are
        converted with ``.tolist()`` so they are plain Python values that
        can be serialised to JSON.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # A negative step would make range() empty and silently return no scores.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    scores = []
    for start_idx in tqdm(range(0, len(row_sentence_pairs), batch_size), desc="Scoring with Cross-Encoder"):
        batch_pairs = row_sentence_pairs[start_idx:start_idx + batch_size]
        # Convert batch scores to a list to ensure compatibility with JSON serialization
        scores.extend(cross_encoder.predict(batch_pairs).tolist())
    return scores


In [None]:
# Function to generate labeled dataset

def generate_labeled_dataset(data):
    """Score every table-paragraph combination and average the pair scores.

    All (row, sentence) pairs are built by ``prepare_dataset`` and scored by
    ``compute_similarity_cross_encoder``. The flat score list is then read
    back in the same table -> paragraph -> row -> sentence order, and the
    scores belonging to one table-paragraph combination are averaged.

    Args:
        data: List of dicts with ``'table'`` (first row is the header row)
            and ``'sentence_context'`` (list of sentences).

    Returns:
        List of dicts with keys ``"table"``, ``"paragraph"`` and ``"score"``;
        one entry per (table, paragraph) combination. The score is the mean
        pair score rounded to 3 decimals, or 0 when the combination has no
        pairs.
    """
    all_pairs = prepare_dataset(data)
    all_scores = compute_similarity_cross_encoder(all_pairs)

    labeled = []
    cursor = 0  # position of the next unread entry in all_scores
    for table_example in data:
        table = table_example['table']
        row_count = len(table[1:])  # header row excluded

        for paragraph_example in data:
            paragraph = paragraph_example['sentence_context']

            # This combination owns the next row_count * len(paragraph) scores.
            pair_count = row_count * len(paragraph)
            block_scores = [all_scores[cursor + offset] for offset in range(pair_count)]
            cursor += pair_count

            # Mean of the block; a combination without pairs scores 0.
            if block_scores:
                mean_score = sum(block_scores) / len(block_scores)
            else:
                mean_score = 0

            labeled.append({
                "table": table,  # full table, header row included
                "paragraph": paragraph,  # full list of sentences
                "score": round(mean_score, 3),  # rounded only; the value is not clamped
            })

    return labeled

In [None]:
def save_labeled_dataset(labeled_data, output_path):
    """Save the labeled dataset to a JSON file.

    The data is serialised completely before the output file is opened, so a
    serialisation error (for example a value JSON cannot represent) leaves an
    already existing file at ``output_path`` untouched instead of truncating it.

    Args:
        labeled_data: JSON-serialisable object to write.
        output_path: Destination file path; the file is written as UTF-8 with
            4-space indentation and non-ASCII characters kept as-is.

    Raises:
        TypeError: If ``labeled_data`` contains a value that is not
            JSON-serialisable (raised by ``json.dumps``).
    """
    serialized = json.dumps(labeled_data, indent=4, ensure_ascii=False)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(serialized)

In [None]:
# Run the main pipeline in a notebook:
# load the examples, score every table-paragraph combination, save the result.
# All names assigned below live at notebook (module) level.
if __name__ == "__main__":
    # Load data from input folder (expects a file named 'data.json' in data_folder)
    data_file_path = os.path.join(data_folder, 'data.json')
    data = load_data(data_file_path)

    # Generate labeled dataset (one entry per table-paragraph combination)
    labeled_data = generate_labeled_dataset(data)
    # Verify that all table-paragraph pairs are included in the labeled data:
    # every example's table is paired with every example's paragraph, hence
    # len(data) * len(data) entries. A mismatch is only reported, not raised.
    expected_pairs = len(data) * len(data)
    if len(labeled_data) != expected_pairs:
        print(f"Warning: Mismatch in number of labeled entries ({len(labeled_data)}) and expected table-paragraph pairs ({expected_pairs})")

    # Save output dataset as 'labeled_dataset.json' inside output_folder
    output_path = os.path.join(output_folder, 'labeled_dataset.json')
    save_labeled_dataset(labeled_data, output_path)
    print(f"Labeled dataset saved to {output_path}")

Scoring with Cross-Encoder:   0%|          | 0/1 [00:00<?, ?it/s]

Labeled dataset saved to /content/output_data/labeled_dataset.json
