<a href="https://colab.research.google.com/github/dtim-upc/LOKI/blob/main/Initial-All-Pairs/Baseline_Bi_Encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install xformers



In [None]:
import os
import json
import numpy as np
from sentence_transformers import SentenceTransformer, util
from accelerate import Accelerator
from sklearn.metrics.pairwise import cosine_similarity
import torch
from tqdm.notebook import tqdm

In [None]:
# Input and output locations (absolute paths under /content).
# `data_folder` is where the main cell looks for 'data.json';
# `output_folder` is where it writes 'labeled_dataset.json'.
data_folder = "/content/input_data"
output_folder = "/content/output_data"

In [None]:
# Create the output directory up front; exist_ok=True means re-running this
# cell does not raise when the directory is already there.
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Shared Accelerator instance. Other cells use it for `prepare()` on the model,
# the `autocast()` context around encoding, and `wait_for_everyone()`.
# NOTE(review): no arguments (e.g. a mixed-precision setting) are passed, so the
# default / environment configuration applies — confirm that is intended.
accelerator = Accelerator()

In [None]:
# Load the bi-encoder on the GPU when one is available, otherwise on the CPU.
bi_encoder_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Extra keyword arguments for the loader; left empty on GPU so that call is
# unchanged from before.
bi_encoder_load_kwargs = {}
if bi_encoder_device == 'cpu':
    # The model's default config enables xformers memory-efficient attention and
    # unpadded inputs, which need a GPU. The model card's CPU recipe switches
    # both off through `config_kwargs`; without this the CPU fallback cannot encode.
    # NOTE(review): confirm the key names against the current model card.
    bi_encoder_load_kwargs["config_kwargs"] = {
        "use_memory_efficient_attention": False,
        "unpad_inputs": False,
    }

# SECURITY: trust_remote_code=True executes Python code shipped in the model
# repository; only keep it for repositories you trust (ideally pin a revision).
bi_encoder = SentenceTransformer(
    "dunzhang/stella_en_400M_v5",
    trust_remote_code=True,
    device=bi_encoder_device,
    **bi_encoder_load_kwargs,
)
# NOTE(review): nothing in this notebook reads `use_xformers`; it is kept for
# backward compatibility. Confirm the model's remote code consults it — if not,
# this assignment has no effect and can be dropped.
bi_encoder.use_xformers = True  # Enable xformers for better memory management
bi_encoder = accelerator.prepare(bi_encoder)

Some weights of the model checkpoint at dunzhang/stella_en_400M_v5 were not used when initializing NewModel: ['new.pooler.dense.bias', 'new.pooler.dense.weight']
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def load_data(file_path, num_samples=None):
    """Read a JSON file and return its parsed content.

    Args:
        file_path: Path of the JSON file; it is opened as UTF-8 text.
        num_samples: When not None, only the ``[:num_samples]`` slice of the
            parsed content is returned.

    Returns:
        The parsed JSON content, or its leading slice when ``num_samples``
        is given.
    """
    with open(file_path, encoding='utf-8') as json_file:
        records = json.load(json_file)
    return records if num_samples is None else records[:num_samples]

In [None]:
# Function to create row representation for table

def create_row_representation(row, headers, title=None, caption=None):
    """Render one table row as "header: cell" text, optionally framed by title and caption.

    The row is first aligned to the number of headers: surplus cells are
    dropped and missing cells are filled with empty strings.

    Args:
        row: List of cell values for a single table row.
        headers: Column headers, one per expected cell.
        title: Optional table title; when truthy, "Title: <title> " is prepended.
        caption: Optional table caption; when truthy, " Caption: <caption>" is appended.

    Returns:
        The space-joined "header: cell" segments with any title/caption attached.
    """
    width = len(headers)
    if len(row) > width:
        row = row[:width]
    elif len(row) < width:
        row = row + [''] * (width - len(row))

    segments = []
    for position, cell in enumerate(row):
        segments.append(f"{headers[position]}: {cell}")
    text = ' '.join(segments)

    if title:
        text = f"Title: {title} " + text
    if caption:
        text = text + f" Caption: {caption}"
    return text

In [None]:
# Function to prepare dataset

def prepare_dataset(data):
    """Build (row_text, sentence) pairs for every table-paragraph combination.

    Each example in ``data`` supplies a table (``example['table']``, whose first
    entry is the header row) and a paragraph (``example['sentence_context']``,
    an iterable of sentences). Every table is paired with the paragraph of every
    example, including its own.

    Pairs are appended in a fixed nesting order — table, paragraph, row,
    sentence — which is the order generate_labeled_dataset consumes scores in.

    Args:
        data: List of example dicts with 'table' and 'sentence_context' keys and
            optional 'title' / 'caption' keys.

    Returns:
        List of ``(row_text, sentence)`` tuples. ``row_text`` is the optional
        "Title: ... " / "Caption: ... " prefix followed by the row rendered by
        create_row_representation. A table with no rows (or an entirely empty
        table) contributes no pairs.
    """
    dataset = []
    failed_examples = []

    for table_example in data:
        table = table_example['table']
        # An empty table has neither a header row nor data rows; treat it as
        # "no rows" instead of failing on table[0].
        headers = table[0] if table else []
        rows = table[1:]
        title = table_example.get('title', None)
        caption = table_example.get('caption', None)

        # Title and caption are attached once as a prefix shared by all rows.
        table_representation = ""
        if title:
            table_representation += f"Title: {title} "
        if caption:
            table_representation += f"Caption: {caption} "

        if any(header == '' for header in headers):
            failed_examples.append(headers)
            print(f"Warning: Empty header in example with headers: {headers}")

        # Row texts depend only on the table, so render them once here rather
        # than once per paragraph.
        row_texts = [
            table_representation + create_row_representation(row, headers)
            for row in rows
        ]

        for paragraph_example in data:
            paragraph = paragraph_example['sentence_context']
            for row_text in row_texts:
                for sentence in paragraph:
                    dataset.append((row_text, sentence))  # Pair rows with sentences from all examples

    if failed_examples:
        print(f"Total examples with empty headers: {len(failed_examples)}")
    return dataset

In [None]:
# Scoring function using bi-encoder

def compute_similarity_bi_encoder(row_sentence_pairs):
    """Score each (row_text, sentence) pair with the bi-encoder's cosine similarity.

    Only the similarity of each pair with itself is computed (row i against
    sentence i); no N x N similarity matrix is built.

    Args:
        row_sentence_pairs: Iterable of ``(row_text, sentence)`` string pairs,
            as produced by prepare_dataset.

    Returns:
        List of floats, one cosine similarity per input pair, in input order.
        An empty input returns an empty list.
    """
    pairs = list(row_sentence_pairs)
    if not pairs:
        # zip(*[]) cannot be unpacked into two names; nothing to score anyway.
        return []

    rows, sentences = zip(*pairs)

    # The all-pairs dataset repeats each row text and each sentence many times,
    # so encode every distinct text once and look its embedding up per pair.
    unique_rows = list(dict.fromkeys(rows))
    unique_sentences = list(dict.fromkeys(sentences))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with accelerator.autocast():
        row_embeddings = bi_encoder.encode(unique_rows, convert_to_tensor=True, device=device)
        sentence_embeddings = bi_encoder.encode(unique_sentences, convert_to_tensor=True, device=device)

    # Map each pair back to the position of its texts in the unique lists.
    row_position = {text: position for position, text in enumerate(unique_rows)}
    sentence_position = {text: position for position, text in enumerate(unique_sentences)}
    row_index = torch.tensor(
        [row_position[text] for text in rows], dtype=torch.long, device=row_embeddings.device
    )
    sentence_index = torch.tensor(
        [sentence_position[text] for text in sentences], dtype=torch.long, device=sentence_embeddings.device
    )

    # Row-wise cosine similarity: O(N * dim) instead of the O(N^2) matrix whose
    # diagonal was previously extracted. Cast to float32 so the reduction does
    # not depend on whatever dtype autocast produced.
    scores = torch.nn.functional.cosine_similarity(
        row_embeddings[row_index].float(),
        sentence_embeddings[sentence_index].float(),
        dim=1,
    )
    return scores.tolist()

In [None]:
# Function to generate labeled dataset

def generate_labeled_dataset(data):
    """Score every table-paragraph combination and return one record per combination.

    All (row, sentence) pairs are built by prepare_dataset and scored in one
    call to compute_similarity_bi_encoder. The flat score list is then walked in
    the same table -> paragraph -> row -> sentence order, and the scores that
    belong to each (table, paragraph) combination are averaged.

    Args:
        data: List of example dicts with 'table' and 'sentence_context' keys and
            optional 'id' / 'title' / 'caption' keys.

    Returns:
        List of dicts with keys "id", "title", "caption", "table" (taken from
        the table's example), "paragraph" (taken from the paragraph's example)
        and "score": the mean pair score rounded to 3 decimals, or 0 when the
        combination has no row-sentence pairs.
    """
    pair_scores = compute_similarity_bi_encoder(prepare_dataset(data))

    labeled_data = []
    cursor = 0  # position of the next unread entry in pair_scores

    for table_example in tqdm(data, desc="Processing tables"):
        table = table_example['table']
        data_row_count = len(table[1:])  # header row excluded

        for paragraph_example in data:
            paragraph = paragraph_example['sentence_context']

            # This combination owns the next rows x sentences scores.
            block_size = data_row_count * len(paragraph)
            scores = [pair_scores[cursor + offset] for offset in range(block_size)]
            cursor += block_size

            # Mean over the block; an empty block scores 0.
            avg_score = sum(scores) / len(scores) if scores else 0

            record = {
                "id": table_example.get('id', None),
                "title": table_example.get('title', None),
                "caption": table_example.get('caption', None),
                "table": table,
                "paragraph": paragraph,
                "score": round(avg_score, 3),  # rounded to 3 decimals
            }
            labeled_data.append(record)

    return labeled_data

In [None]:
# Function to save labeled dataset

def save_labeled_dataset(labeled_data, output_path):
    """Write the labeled dataset to ``output_path`` as indented UTF-8 JSON.

    The data is serialized in memory before the file is opened, so a
    serialization error (for example a value that is not JSON-serializable)
    is raised without truncating or partially overwriting an existing file.

    Args:
        labeled_data: JSON-serializable object (the list of labeled records).
        output_path: Destination file path; the file is created or overwritten.

    Raises:
        TypeError, ValueError: If ``labeled_data`` cannot be serialized.
        OSError: If the file cannot be opened or written.
    """
    # ensure_ascii=False keeps non-ASCII text readable in the output file.
    serialized = json.dumps(labeled_data, indent=4, ensure_ascii=False)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(serialized)

In [None]:
# Run the main pipeline in a notebook
# In a notebook kernel `__name__` is "__main__", so this body runs whenever the
# cell is executed. The names assigned below become notebook-level globals.
if __name__ == "__main__":
    # NOTE(review): presumably a barrier across Accelerator processes — confirm
    # it is needed for a single-process notebook run.
    accelerator.wait_for_everyone()
    # Load data from input folder.
    # Only the first 3 examples are loaded (num_samples=3); load_data returns
    # the whole file when num_samples is None.
    data_file_path = os.path.join(data_folder, 'data.json')
    data = load_data(data_file_path, num_samples=3)

    # Generate labeled dataset: one record per (table, paragraph) combination.
    labeled_data = generate_labeled_dataset(data)
    # Verify that all table-paragraph pairs are included in the labeled data.
    # Every table is paired with every paragraph, so len(data) squared records
    # are expected; a mismatch only prints a warning and the run continues.
    expected_pairs = len(data) * len(data)
    if len(labeled_data) != expected_pairs:
        print(f"Warning: Mismatch in number of labeled entries ({len(labeled_data)}) and expected table-paragraph pairs ({expected_pairs})")

    # Save output dataset as JSON inside output_folder.
    output_path = os.path.join(output_folder, 'labeled_dataset.json')
    save_labeled_dataset(labeled_data, output_path)
    print(f"Labeled dataset saved to {output_path}")

Processing tables:   0%|          | 0/3 [00:00<?, ?it/s]

Labeled dataset saved to /content/output_data/labeled_dataset.json
