In [None]:
!pip install -q "datasets==2.20.0" transformers sentence-transformers accelerate

In [None]:
import json
import random
import re
import string
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import TapasTokenizer, TapasForQuestionAnswering

# Reproducibility: seed every RNG this notebook touches (stdlib, NumPy, PyTorch).
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Run models on the GPU when one is visible, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
def canonical_table_name(name: str) -> str:
    """Map a table path to the key used to join schema.json entries with WTQ examples.

    Everything before the first "csv/" is dropped, then a trailing ".csv" or ".tsv"
    extension is removed, e.g. "data/csv/204-csv/590.csv" -> "csv/204-csv/590".
    """
    start = name.find("csv/")
    trimmed = name if start < 0 else name[start:]
    for suffix in (".csv", ".tsv"):
        if trimmed.endswith(suffix):
            return trimmed[: -len(suffix)]
    return trimmed


# Schema annotations: presumably a JSON list of table entries; each entry is read for
# "table_name" below, plus optional "semantic_name", "description", "columns".
SCHEMA_JSON_PATH = "schema.json"

with open(SCHEMA_JSON_PATH, encoding="utf-8") as schema_file:
    schema_entries = json.load(schema_file)

len(schema_entries), schema_entries[0]

In [None]:
# Index schema.json entries by canonical table name; if two entries map to the same
# canonical name, the later entry wins.
schema_by_canonical = {
    canonical_table_name(entry["table_name"]): entry for entry in schema_entries
}

len(schema_by_canonical)


In [None]:
# WikiTableQuestions, "random-split-1" configuration; only its train split is used below.
wtq = load_dataset(path="stanfordnlp/wikitablequestions", name="random-split-1")
train_ds = wtq["train"]

len(train_ds), train_ds[0]

In [None]:
# Group training-question indices by canonical table, keeping only tables described in schema.json.
table_to_indices = defaultdict(list)

for example_idx, example in enumerate(train_ds):
    table_key = canonical_table_name(example["table"]["name"])
    if table_key not in schema_by_canonical:
        continue
    table_to_indices[table_key].append(example_idx)

# Every key above was created by an append, so the keys are exactly the tables with questions.
tables_in_train = set(table_to_indices)

len(tables_in_train)

In [None]:
canonical_table_ids = sorted(tables_in_train)
print("Tables in schema.json:", len(schema_by_canonical))
print("Tables in WTQ train & schema.json intersection:", len(canonical_table_ids))
print("First few table ids:", canonical_table_ids[:5])

# Each WTQ example carries its own copy of the table; the copy attached to the first
# question on a table is used (presumably identical across questions — not checked here).
table_id_to_df = {}
table_id_to_columns = {}
table_id_to_sample_rows = {}

for tid in canonical_table_ids:
    wtq_table = train_ds[table_to_indices[tid][0]]["table"]
    header, rows = wtq_table["header"], wtq_table["rows"]
    table_id_to_df[tid] = pd.DataFrame(rows, columns=header)
    table_id_to_columns[tid] = header
    table_id_to_sample_rows[tid] = rows[:2]

len(table_id_to_df), list(table_id_to_df)[:3]


In [None]:
# Pool every training question whose table survived the schema.json intersection.
all_candidate_q_indices = [
    q_idx for tid in canonical_table_ids for q_idx in table_to_indices[tid]
]

print("Total questions over these tables:", len(all_candidate_q_indices))

# Evaluation sample size. Set to None to evaluate every candidate question
# (slow: every question is run through TAPAS).
NUM_QUESTIONS = 200

# A dedicated RNG seeded with SEED makes the sample independent of how much of the global
# `random` state earlier cells consumed, so re-running this cell selects the same questions.
sampling_rng = random.Random(SEED)

if NUM_QUESTIONS is None:
    eval_indices = list(all_candidate_q_indices)
elif len(all_candidate_q_indices) < NUM_QUESTIONS:
    print(f"Warning: fewer than {NUM_QUESTIONS} questions available; using all.")
    eval_indices = list(all_candidate_q_indices)
else:
    eval_indices = sampling_rng.sample(all_candidate_q_indices, k=NUM_QUESTIONS)

len(eval_indices)


In [None]:
def _to_eval_example(row):
    """Keep only the fields the evaluation needs, keyed by the canonical table id."""
    return {
        "id": row["id"],
        "question": row["question"],
        "answers": row["answers"],
        "table_id": canonical_table_name(row["table"]["name"]),
    }


eval_examples = [_to_eval_example(train_ds[q_idx]) for q_idx in eval_indices]

print("Prepared eval examples:", len(eval_examples))
eval_examples[0]


In [None]:
def _schema_embed_text(tid, entry):
    """Text embedded for one table: its id plus the schema.json semantic name, description and columns."""
    semantic_name = entry.get("semantic_name", "")
    description = entry.get("description", "")
    columns = entry.get("columns", [])
    return (
        f"Table ID: {tid}. "
        f"Semantic name: {semantic_name}. "
        f"Description: {description}. "
        f"Columns: {', '.join(columns)}."
    )


# Position i of both lists refers to the same table, so embeddings computed from
# schema_embed_texts line up with table_ids_for_embeddings.
table_ids_for_embeddings = list(canonical_table_ids)
schema_embed_texts = [
    _schema_embed_text(tid, schema_by_canonical[tid]) for tid in table_ids_for_embeddings
]

len(schema_embed_texts), schema_embed_texts[0]


In [None]:
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Bi-encoder used for both schema texts and queries, on the same device as the rest of the notebook.
embedder = SentenceTransformer(EMBED_MODEL_NAME, device=device)

# Unit-length vectors, so a plain dot product later equals cosine similarity.
encode_options = {"convert_to_numpy": True, "normalize_embeddings": True}
schema_table_embeddings = embedder.encode(schema_embed_texts, **encode_options)

schema_table_embeddings.shape


In [None]:
def rank_tables_for_query(query, table_embeddings, table_ids, model=None):
    """Rank candidate tables by similarity between `query` and their schema embeddings.

    Args:
        query: natural-language question.
        table_embeddings: 2-D array with one row per table, aligned with `table_ids`.
            Assumed L2-normalised (the schema-embedding cell encodes with
            normalize_embeddings=True), so the dot product is cosine similarity.
        table_ids: table ids aligned row-for-row with `table_embeddings`.
        model: object exposing a SentenceTransformer-style `encode`; defaults to the
            global `embedder`.

    Returns:
        (ranked_ids, ranked_scores): ids and float scores ordered best first. Tied scores
        keep the relative order they have in `table_ids`.
    """
    if model is None:
        model = embedder

    q_vec = model.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]

    scores = np.dot(table_embeddings, q_vec)
    # Stable sort: ties are broken by original position instead of arbitrarily.
    order = np.argsort(-scores, kind="stable")

    ranked_ids = [table_ids[i] for i in order]
    ranked_scores = [float(scores[i]) for i in order]
    return ranked_ids, ranked_scores

# Smoke test: the five best-ranked tables for the first evaluation question.
test_q = eval_examples[0]["question"]
demo_ranked_ids, demo_ranked_scores = rank_tables_for_query(
    test_q, schema_table_embeddings, table_ids_for_embeddings
)
demo_ranked_ids[:5]


In [None]:
TAPAS_MODEL_NAME = "google/tapas-large-finetuned-wtq"

# Tokenizer and QA model come from the same checkpoint.
tapas_tokenizer = TapasTokenizer.from_pretrained(TAPAS_MODEL_NAME)
tapas_model = TapasForQuestionAnswering.from_pretrained(TAPAS_MODEL_NAME)
tapas_model = tapas_model.to(device)
# Inference only: put the model in eval mode (disables dropout).
tapas_model.eval()

TAPAS_MODEL_NAME


In [None]:
def _parse_number(cell: str):
    """Parse a cell string as a float, or return None if it is not numeric.

    Surrounding whitespace and "," thousands separators are ignored ("1,234" -> 1234.0).
    """
    try:
        return float(cell.replace(",", "").strip())
    except ValueError:
        return None


def _format_number(value: float) -> str:
    """Render a float without a redundant ".0" (7.0 -> "7", 2.5 -> "2.5").

    Answer normalisation strips punctuation, so "7.0" would become "70" and never match a gold "7".
    """
    return str(int(value)) if value.is_integer() else str(value)


def _aggregate_cells(selected_cells, aggregation: str) -> str:
    """Combine the selected cell strings into one answer string according to `aggregation`."""
    if not selected_cells:
        return "<no cells selected>"
    joined = ", ".join(selected_cells)
    if aggregation == "COUNT":
        # COUNT is the number of selected cells; the cells themselves need not be numeric.
        return str(len(selected_cells))
    if aggregation not in ("SUM", "AVERAGE"):
        # NONE (plain cell selection) or an unrecognised operator: return the cells verbatim.
        return joined
    numbers = [_parse_number(cell) for cell in selected_cells]
    if any(number is None for number in numbers):
        # Cannot sum/average non-numeric cells; fall back to returning them.
        return joined
    total = sum(numbers)
    return _format_number(total if aggregation == "SUM" else total / len(numbers))


@torch.inference_mode()
def tapas_answer(df: pd.DataFrame, question: str, max_rows: int = 64):
    """Answer `question` over `df` with the globally loaded TAPAS tokenizer and model.

    Args:
        df: table to query; every cell is converted to `str` before tokenization.
        question: natural-language question.
        max_rows: only the first `max_rows` rows are given to TAPAS. The tokenizer may
            drop further trailing rows if the flattened table still exceeds the model's
            maximum input length.

    Returns:
        dict with
          "answer": final answer string ("<no cells selected>" if nothing was selected),
          "aggregation": predicted operator name (e.g. "NONE", "SUM", "AVERAGE", "COUNT"),
          "coordinates": (row, column) pairs of selected cells, indexing the first
              `max_rows` rows of `df`,
          "selected_cells": the string values at those coordinates.
    """
    table = df.head(max_rows).astype(str)

    inputs = tapas_tokenizer(
        table=table,
        queries=[question],
        padding="max_length",
        # Drop rows from the end when the flattened table is too long; without this, large
        # tables produce sequences longer than the model's maximum input length.
        truncation=True,
        return_tensors="pt",
    )

    outputs = tapas_model(**{k: v.to(device) for k, v in inputs.items()})

    # convert_logits_to_predictions post-processes on the CPU.
    inputs_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    logits = outputs.logits.detach().cpu()

    if outputs.logits_aggregation is None:
        # Cell-selection-only checkpoint: still decode the selected cells; operator is index 0.
        predicted_coords = tapas_tokenizer.convert_logits_to_predictions(inputs_cpu, logits)[0]
        predicted_aggs = [0]
    else:
        predicted_coords, predicted_aggs = tapas_tokenizer.convert_logits_to_predictions(
            inputs_cpu, logits, outputs.logits_aggregation.detach().cpu()
        )

    # Batch of one query: take the first (only) prediction.
    coords = predicted_coords[0] if predicted_coords else []
    agg_idx = predicted_aggs[0] if predicted_aggs else 0

    id2agg = tapas_model.config.aggregation_labels or {
        0: "NONE",
        1: "SUM",
        2: "AVERAGE",
        3: "COUNT",
    }
    aggregation = id2agg.get(agg_idx, "UNKNOWN")

    selected_cells = [str(table.iat[row_idx, col_idx]) for row_idx, col_idx in coords]

    return {
        "answer": _aggregate_cells(selected_cells, aggregation),
        "aggregation": aggregation,
        "coordinates": coords,
        "selected_cells": selected_cells,
    }

# Smoke test: run TAPAS on the gold table of the first evaluation question.
demo_example = eval_examples[0]
tapas_answer(table_id_to_df[demo_example["table_id"]], demo_example["question"])


In [None]:
# Built once rather than once per character.
_PUNCTUATION = frozenset(string.punctuation)


def normalize_answer(s):
    """SQuAD-style answer normalisation.

    Lower-cases, strips ASCII punctuation, removes the articles "a", "an", "the" and
    collapses whitespace. None normalises to ""; other non-string values go through str().
    """
    if s is None:
        return ""
    text = str(s).lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def _prediction_items(prediction):
    """Split a prediction on ", " (the separator tapas_answer uses between cells) into normalised, non-empty items."""
    if prediction is None:
        return set()
    items = (normalize_answer(part) for part in str(prediction).split(", "))
    return {item for item in items if item}


def exact_match_score(prediction, gold_answers):
    """Return 1 if `prediction` matches the gold answer, else 0.

    WikiTableQuestions answers are denotations: a multi-element `answers` list is the
    complete set of answer values, not a list of alternative phrasings. Therefore:
      * one gold value: normalised string equality;
      * several gold values: the set of normalised ", "-separated prediction items must
        equal the set of normalised gold values (order-insensitive). Predicting only
        some of the values scores 0.
    An empty `gold_answers` scores 0.
    """
    gold_norm = [normalize_answer(a) for a in gold_answers]
    if not gold_norm:
        return 0
    if len(gold_norm) == 1:
        return int(normalize_answer(prediction) == gold_norm[0])
    return int(_prediction_items(prediction) == {g for g in gold_norm if g})


def _token_f1(pred, gold):
    """Bag-of-tokens F1 between two strings after normalisation (multiset overlap, SQuAD definition)."""
    pred_tokens = normalize_answer(pred).split()
    gold_tokens = normalize_answer(gold).split()
    if not pred_tokens or not gold_tokens:
        # Both empty is a perfect match, consistent with exact_match_score.
        return float(pred_tokens == gold_tokens)
    # Counter intersection counts repeated tokens correctly (a set would collapse them).
    num_common = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_common == 0:
        return 0.0
    precision = num_common / len(pred_tokens)
    recall = num_common / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def f1_score(prediction, gold_answers):
    """Token-level F1 of `prediction` against the gold denotation.

    As in exact_match_score, a multi-value answer is one answer: it is compared as the
    concatenation of all its values, so tokens from every value count. A single gold
    value behaves exactly like plain token F1. Empty `gold_answers` scores 0.0.
    """
    gold_norm = [normalize_answer(a) for a in gold_answers]
    if not gold_norm:
        return 0.0
    return _token_f1(prediction, " ".join(gold_norm))


In [None]:
def update_recall_at_k(ranked_ids, gold_id, k, counter):
    """Return `counter` plus one if `gold_id` is among the first `k` ranked ids, otherwise `counter` unchanged."""
    hit = gold_id in ranked_ids[:k]
    return counter + 1 if hit else counter

hits_at_1, hits_at_10, hits_at_50 = 0, 0, 0
em_scores, f1_scores = [], []

num_eval = len(eval_examples)

for step, example in enumerate(eval_examples, start=1):
    question = example["question"]
    gold_table_id = example["table_id"]

    # Retrieval: rank all candidate tables and record top-k hits for the gold table.
    ranked_ids, ranked_scores = rank_tables_for_query(
        question, schema_table_embeddings, table_ids_for_embeddings
    )
    hits_at_1 = update_recall_at_k(ranked_ids, gold_table_id, 1, hits_at_1)
    hits_at_10 = update_recall_at_k(ranked_ids, gold_table_id, 10, hits_at_10)
    hits_at_50 = update_recall_at_k(ranked_ids, gold_table_id, 50, hits_at_50)

    # End-to-end QA: answer over the top-ranked table, whether or not it is the gold one.
    answer = tapas_answer(table_id_to_df[ranked_ids[0]], question)["answer"]
    em_scores.append(exact_match_score(answer, example["answers"]))
    f1_scores.append(f1_score(answer, example["answers"]))

    if step % 10 == 0:
        print(f"Processed {step}/{num_eval} questions...")

num_eval


## Evaluation metrics: retrieval Recall@1/10/50, exact match (EM), and token-level F1

In [None]:
def pct(x, total):
    """Express `x` as a percentage of `total`; returns 0.0 unless `total` is positive."""
    if not (total > 0):
        return 0.0
    return 100.0 * x / total

def _mean_as_pct(scores):
    """Mean of per-question scores as a percentage; 0.0 for an empty list."""
    if not scores:
        return 0.0
    return 100.0 * float(np.mean(scores))


metrics = {
    "R@1": pct(hits_at_1, num_eval),
    "R@10": pct(hits_at_10, num_eval),
    "R@50": pct(hits_at_50, num_eval),
    "EM": _mean_as_pct(em_scores),
    "F1": _mean_as_pct(f1_scores),
}

metrics_df = pd.DataFrame([metrics])
metrics_df


In [None]:
# Number of top-ranked tables (ids and scores) kept per question for inspection.
TOP_K_ANALYSIS = 5

detailed_logs = []

for example in eval_examples:
    question = example["question"]
    gold_id = example["table_id"]
    gold_answers = example["answers"]

    ranked_ids, ranked_scores = rank_tables_for_query(
        question, schema_table_embeddings, table_ids_for_embeddings
    )
    # 1-based rank of the gold table, or None if it was not ranked at all.
    gold_rank = ranked_ids.index(gold_id) + 1 if gold_id in ranked_ids else None

    top1_table_id = ranked_ids[0]
    qa_out = tapas_answer(table_id_to_df[top1_table_id], question)
    pred_answer = qa_out["answer"]

    detailed_logs.append(
        {
            "id": example["id"],
            "question": question,
            "gold_table_id": gold_id,
            "gold_answers": gold_answers,
            "ranked_table_ids": ranked_ids[:TOP_K_ANALYSIS],
            "ranked_scores": ranked_scores[:TOP_K_ANALYSIS],
            "gold_rank": gold_rank,
            "top1_table_id": top1_table_id,
            "pred_answer": pred_answer,
            "em": exact_match_score(pred_answer, gold_answers),
            "f1": f1_score(pred_answer, gold_answers),
            "tapas_raw": qa_out,
        }
    )

len(detailed_logs), detailed_logs[0]


In [None]:
# One row per evaluated question, summarising retrieval and reader outcomes.
ANALYSIS_COLUMNS = ["question", "gold_table", "top1_table", "gold_rank", "EM", "F1"]

analysis_df = pd.DataFrame(
    [
        {
            "question": log["question"],
            "gold_table": log["gold_table_id"],
            "top1_table": log["top1_table_id"],
            "gold_rank": log["gold_rank"],
            "EM": log["em"],
            "F1": log["f1"],
        }
        for log in detailed_logs
    ],
    # Explicit columns keep the frame well-formed even when detailed_logs is empty.
    columns=ANALYSIS_COLUMNS,
)

# gold_rank is None when the gold table was not ranked. pandas stores such values as NaN,
# so checks like `r is None` or `r > 1` are both False for them; isna() catches them.
gold_ranks = pd.to_numeric(analysis_df["gold_rank"])

# Retrieval error: gold table missing or not at rank 1.
analysis_df["retrieval_error"] = gold_ranks.isna() | gold_ranks.gt(1)
# Reader error: retrieval was correct (gold at rank 1) but the answer is not an exact match.
analysis_df["reader_error"] = gold_ranks.eq(1) & analysis_df["EM"].eq(0)

analysis_df


In [None]:
num_total = len(analysis_df)
error_totals = analysis_df[["retrieval_error", "reader_error"]].sum()
num_retrieval_errors = int(error_totals["retrieval_error"])
num_reader_errors = int(error_totals["reader_error"])

for summary_line in (
    f"Total eval questions: {num_total}",
    f"Retrieval errors (gold table not at rank 1): {num_retrieval_errors}",
    f"Reader errors (correct table at rank 1 but EM=0): {num_reader_errors}",
):
    print(summary_line)

In [None]:
# ERROR ANALYSIS: Visualization of Retriever Errors

NUM_CASES = 5

# detailed_logs and analysis_df rows are built in the same order, so they can be zipped.
retrieval_error_logs = [
    log
    for log, is_retrieval_error in zip(detailed_logs, analysis_df["retrieval_error"])
    if is_retrieval_error
]

print(f"Total retrieval error logs: {len(retrieval_error_logs)}")

for case_no, log in enumerate(retrieval_error_logs[:NUM_CASES], start=1):
    print("=" * 80)
    print(f"Retrieval Error Example {case_no}")
    print("- Question:")
    print(f"  {log['question']}")
    print("- Gold table ID:", log["gold_table_id"])
    print("- Gold rank among retrieved:", log["gold_rank"])
    print("- Top-1 retrieved table ID:", log["top1_table_id"])
    print()

    # Show the gold schema next to the schema that was wrongly ranked first.
    schema_sections = (
        ("Gold table schema:", "Gold table schema: <not found in schema.json>", log["gold_table_id"]),
        ("Top-1 retrieved table schema:", "Top-1 table schema: <not found in schema.json>", log["top1_table_id"]),
    )
    for heading, missing_note, table_id in schema_sections:
        schema = schema_by_canonical.get(table_id)
        if not schema:
            print(missing_note)
        else:
            print(heading)
            print("  semantic_name:", schema.get("semantic_name"))
            print("  description  :", schema.get("description"))
            print("  columns      :", ", ".join(schema.get("columns", [])))
        print()

    top1_df = table_id_to_df[log["top1_table_id"]]
    print("Top-1 table preview (head):")
    display(top1_df.head(5))

    print()


In [None]:
# ERROR ANALYSIS: Visualization of Reader Errors

NUM_CASES = 5

# detailed_logs and analysis_df rows are built in the same order, so they can be zipped.
reader_error_logs = [
    log
    for log, is_reader_error in zip(detailed_logs, analysis_df["reader_error"])
    if is_reader_error
]

print(f"Total reader error logs: {len(reader_error_logs)}")

for case_no, log in enumerate(reader_error_logs[:NUM_CASES], start=1):
    print("=" * 80)
    print(f"Reader Error Example {case_no}")
    print("- Question:")
    print(f"  {log['question']}")
    for label, key in (
        ("- Gold table ID:", "gold_table_id"),
        ("- Gold answers :", "gold_answers"),
        ("- Top-1 table ID (should match gold):", "top1_table_id"),
        ("- TAPAS predicted answer:", "pred_answer"),
    ):
        print(label, log[key])
    print("- EM:", log["em"], "F1:", log["f1"])
    print()

    gold_schema = schema_by_canonical.get(log["gold_table_id"])
    if not gold_schema:
        print("Gold table schema: <not found in schema.json>")
    else:
        print("Gold table schema:")
        print("  semantic_name:", gold_schema.get("semantic_name"))
        print("  description  :", gold_schema.get("description"))
        print("  columns      :", ", ".join(gold_schema.get("columns", [])))

    print()

    gold_df = table_id_to_df[log["gold_table_id"]]
    print("Table preview (head):")
    display(gold_df.head(10))

    print()