In [2]:
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="google/spiqa", filename="test-A/SPIQA_testA.json", repo_type="dataset", local_dir='.')
hf_hub_download(repo_id="google/spiqa", filename="test-A/SPIQA_testA_Images.zip", repo_type="dataset", local_dir='.')

SPIQA_testA_Images.zip:   0%|          | 0.00/121M [00:00<?, ?B/s]

'test-A/SPIQA_testA_Images.zip'

In [5]:
import json
import pandas as pd

with open('test-A/SPIQA_testA.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# 2. Extract rows with just the three fields
rows = []
for paper in data.values():
    pid      = paper.get('paper_id', '')
    all_figs = ','.join(paper.get('all_figures', {}))
    for qa in paper.get('qa', []):
        rows.append({
            'paper_id':        pid,
            'question':        qa.get('question', ''),
            'reference_figure':qa.get('reference', ''),
            'all_figures':     all_figs
        })

df = pd.DataFrame(rows, columns=[
    'paper_id','question','reference_figure','all_figures'
])

df.head()

Unnamed: 0,paper_id,question,reference_figure,all_figures
0,1611.04684v1,What are the main differences between the educ...,1611.04684v1-Table1-1.png,"1611.04684v1-Table3-1.png,1611.04684v1-Table2-..."
1,1611.04684v1,Which model performs the best for response sel...,1611.04684v1-Table4-1.png,"1611.04684v1-Table3-1.png,1611.04684v1-Table2-..."
2,1611.04684v1,Which model performs best on the Ubuntu datase...,1611.04684v1-Table5-1.png,"1611.04684v1-Table3-1.png,1611.04684v1-Table2-..."
3,1611.04684v1,What is the role of the knowledge gates in the...,1611.04684v1-Figure1-1.png,"1611.04684v1-Table3-1.png,1611.04684v1-Table2-..."
4,1611.04684v1,How does the average number of answers per que...,1611.04684v1-Table2-1.png,"1611.04684v1-Table3-1.png,1611.04684v1-Table2-..."


In [7]:
import os
import torch
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

In [12]:
def load_clip_model(model_dir: str, device: torch.device):
    """
    Load your fine-tuned CLIP model and processor from disk.
    """
    processor = CLIPProcessor.from_pretrained(model_dir)
    model     = CLIPModel.from_pretrained(model_dir).to(device)
    model.eval()
    return processor, model

def compute_image_embeddings(image_paths, processor, model, device):
    """
    Given a list of image file paths, load and preprocess them in a batch,
    run through CLIP, and return normalized image embeddings + their filenames.
    """
    images = [Image.open(p).convert("RGB") for p in image_paths]
    inputs = processor(images=images, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        img_feats = model.get_image_features(**inputs)
    # L2-normalize
    img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
    return img_feats, [os.path.basename(p) for p in image_paths]

def evaluate_retrieval(df: pd.DataFrame, image_root: str, processor, model, device):
    top1 = top3 = top5 = total = 0

    # Precompute image embeddings per paper_id (unchanged)
    paper_embeddings = {}
    for paper_id, group in df.groupby("paper_id"):
        figs = group["all_figures"].iloc[0].split(",")
        paths = [os.path.join(image_root, paper_id, fig) for fig in figs]
        feats, names = compute_image_embeddings(paths, processor, model, device)
        paper_embeddings[paper_id] = {"feats": feats, "names": names}

    # Loop over every question
    for _, row in tqdm(df.iterrows(), total=len(df)):
        q_text   = row["question"]
        ref      = row["reference_figure"]
        paper_id = row["paper_id"]

        # —– FIXED: add truncation and max_length here —–
        text_inputs = processor(
            text=[q_text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=processor.tokenizer.model_max_length
        ).to(device)

        with torch.no_grad():
            txt_feats = model.get_text_features(**text_inputs)
        txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)

        img_feats = paper_embeddings[paper_id]["feats"]
        img_names = paper_embeddings[paper_id]["names"]

        sims = (txt_feats @ img_feats.T).squeeze(0).cpu().numpy()
        ranked_idxs  = np.argsort(-sims)
        ranked_names = [img_names[i] for i in ranked_idxs]

        total += 1
        if ref in ranked_names[:1]: top1 += 1
        if ref in ranked_names[:3]: top3 += 1
        if ref in ranked_names[:5]: top5 += 1

    print(f"Total: {total}")
    print(f"Top-1 Acc: {top1/total:.4f}")
    print(f"Top-3 Acc: {top3/total:.4f}")
    print(f"Top-5 Acc: {top5/total:.4f}")


In [9]:
IMAGE_ROOT = "test-A/SPIQA_testA_Images"
MODEL_DIR  = "clip_finetuned"
df["all_figures"]= df["all_figures"].str.strip()
df["reference_figure"]  = df["reference_figure"].str.strip()

In [10]:
DEVICE     = torch.device("mps")
processor, model = load_clip_model(MODEL_DIR, DEVICE)

In [14]:
evaluate_retrieval(df, IMAGE_ROOT, processor, model, DEVICE)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 666/666 [00:08<00:00, 74.84it/s]

Total: 666
Top-1 Acc: 0.2943
Top-3 Acc: 0.5661
Top-5 Acc: 0.7162





In [25]:
sizes = []
for path in os.listdir("test-A/SPIQA_testA_Images"):
    images = os.listdir(os.path.join("test-A/SPIQA_testA_Images", path))
    print(images)
    break
    sizes.append(len(images))

['1901.00056v2-Figure2-1.png', '1901.00056v2-Figure1-1.png', '1901.00056v2-Figure3-1.png', '1901.00056v2-Table5-1.png', '1901.00056v2-Table1-1.png', '1901.00056v2-Table3-1.png', '1901.00056v2-Table4-1.png', '1901.00056v2-Table6-1.png', '1901.00056v2-Table2-1.png']


In [22]:
len(sizes)

118

In [23]:
max(sizes)

29

In [24]:
min(sizes)

1