# Imports

In [None]:
import os
import gc
import json
import faiss
import torch
import pickle
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from dataclasses import dataclass
from collections import defaultdict
from torch_geometric.nn import SAGEConv
from transformers import CLIPProcessor, CLIPModel
from sentence_transformers import SentenceTransformer

# Silence every warning category for the rest of the session. This also
# hides library deprecation notices, so remove it while debugging.
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


# Config

In [None]:
@dataclass
class Config:
    """Paths, model names and hyper-parameters for the retrieval pipeline.

    Dataset files are derived from ``DATASET_PATH`` and precomputed
    embedding files from ``PRECOMPUTED_PATH`` in ``__post_init__``.
    ``Config(DATASET_PATH="/data")`` therefore relocates every dataset file.
    Any path passed explicitly is kept exactly as given.
    """

    # Dataset root; the empty-string paths below are derived from it.
    DATASET_PATH: str = "./Dataset"
    DATABASE_JSON: str = ""
    TRAIN_CSV: str = ""
    TEST_CSV: str = ""
    IMAGE_DIR: str = ""

    # Precomputed Qwen3 article embeddings; empty paths derived from the root.
    PRECOMPUTED_PATH: str = "./eventa_embeddings_Qwen3"
    ARTICLE_EMBEDDINGS_FILE: str = ""
    ARTICLE_IDS_FILE: str = ""

    GRAPH_SAGE_MODEL: str = "/storage32Tb/jay/mixtralModel/ADL Project/graphSageTrain/Results/best_model.pt"

    EMBEDDING_MODEL: str = "Qwen/Qwen3-Embedding-0.6B"
    CLIP_MODEL: str = "openai/clip-vit-large-patch14"

    # Evaluated once, when the class is defined.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    EMBEDDING_BATCH_SIZE: int = 16
    IMG_BATCH: int = 8
    TOP_K_ARTICLES: int = 100
    TOP_K_IMAGES: int = 10

    GRAPH_SAGE_HIDDEN_DIM: int = 1024
    GRAPH_SAGE_EMBED_DIM: int = 512

    TRAIN_VAL_SPLIT: float = 0.95
    RANDOM_SEED: int = 42

    def __post_init__(self):
        """Fill in every path left empty from its DATASET_PATH / PRECOMPUTED_PATH root."""
        derived = {
            "DATABASE_JSON": f"{self.DATASET_PATH}/database.json",
            "TRAIN_CSV": f"{self.DATASET_PATH}/train_set.csv",
            "TEST_CSV": f"{self.DATASET_PATH}/test_public.csv",
            "IMAGE_DIR": f"{self.DATASET_PATH}/database_images_compressed90",
            "ARTICLE_EMBEDDINGS_FILE": f"{self.PRECOMPUTED_PATH}/database_embeddings_Qwen3.npy",
            "ARTICLE_IDS_FILE": f"{self.PRECOMPUTED_PATH}/database_article_ids_Qwen3.npy",
        }
        for name, default_path in derived.items():
            if not getattr(self, name):
                setattr(self, name, default_path)

config = Config()
print("Device:", config.DEVICE)

# Utility Functions

In [None]:
def load_json(filepath):
    """Parse the UTF-8 encoded JSON document at *filepath* and return it."""
    with open(filepath, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload

def normalize_embeddings(E):
    """Return an L2-row-normalised float32 copy of the 2-D matrix *E*.

    The input is never modified; a new array is always allocated. The copy is
    forced to C (row-major) order because ``faiss.normalize_L2`` rewrites a
    contiguous row-major float32 buffer in place. ``E.astype(np.float32)`` keeps
    the layout of Fortran-ordered or transposed inputs, which FAISS cannot
    handle. Array-likes such as nested lists are also accepted.
    """
    E = np.array(E, dtype=np.float32, order="C", copy=True)
    faiss.normalize_L2(E)
    return E

def train_val_split(df, split_ratio=0.9, seed=42):
    """Randomly split *df* into ``(train, val)``, each with a fresh RangeIndex.

    ``split_ratio`` of the rows is drawn into train with
    ``DataFrame.sample(frac=split_ratio, random_state=seed)``. All remaining rows
    form val, in their original order.

    Sampling runs on a positionally re-indexed copy. ``drop`` removes rows by
    label, so duplicate index labels in *df* would otherwise also delete rows
    that were never sampled. The sampled positions depend only on the row count
    and ``seed``, so the split is unchanged for frames with a unique index.
    """
    positional = df.reset_index(drop=True)
    train = positional.sample(frac=split_ratio, random_state=seed)
    val = positional.drop(train.index)
    return train.reset_index(drop=True), val.reset_index(drop=True)

# Metric Functions

In [None]:
def recall_at_k(pred, gt, k):
    """Return the fraction of queries whose ground truth is in the top-*k* predictions.

    Args:
        pred: one ranked list of predicted ids per query.
        gt: the ground-truth id for each query, aligned with ``pred``.
        k: cut-off rank (only ``p[:k]`` is searched).

    Returns:
        ``hits / len(gt)`` as a float, or 0.0 when ``gt`` is empty (which would
        otherwise raise ZeroDivisionError).
    """
    if len(gt) == 0:
        return 0.0
    hits = sum(1 for ranked, target in zip(pred, gt) if target in ranked[:k])
    return hits / len(gt)


# Cut-offs reported by both evaluators unless the caller overrides ``ks``.
_DEFAULT_RECALL_KS = (1, 5, 10, 20)


def _recall_report(pred, gt, ks):
    """Return ``{"Recall@k": recall}`` for each cut-off in *ks*, in that order."""
    return {f"Recall@{k}": recall_at_k(pred, gt, k) for k in ks}


def evaluate_article_retrieval(pred, gt, ks=_DEFAULT_RECALL_KS):
    """Recall@k of ranked article-id predictions against ground-truth article ids."""
    return _recall_report(pred, gt, ks)


def evaluate_image_retrieval(pred, gt, ks=_DEFAULT_RECALL_KS):
    """Recall@k of ranked image-id predictions against ground-truth image ids."""
    return _recall_report(pred, gt, ks)

# Dataset Loading and Article → Image Map

In [None]:
database = load_json(config.DATABASE_JSON)
train_df, val_df = train_val_split(
    pd.read_csv(config.TRAIN_CSV), config.TRAIN_VAL_SPLIT, config.RANDOM_SEED
)

# Keys probed, in priority order, when a database image entry is a dict.
_IMAGE_ENTRY_KEYS = ("image_id", "id", "file_name", "filename", "path", "file")


def _image_stem(entry):
    """Return the extension-less basename named by an image entry, or None.

    A string entry is treated as a path. A dict entry uses the value of the
    first key from ``_IMAGE_ENTRY_KEYS`` that it contains. Other entries yield None.
    """
    if isinstance(entry, str):
        source = entry
    elif isinstance(entry, dict):
        present = [key for key in _IMAGE_ENTRY_KEYS if key in entry]
        if not present:
            return None
        source = entry[present[0]]
    else:
        return None
    return os.path.splitext(os.path.basename(source))[0]


# article id -> image ids (file stems), in database order. Articles whose
# entries yield no id get no key at all.
article_image_map = defaultdict(list)
for article_id, record in database.items():
    for entry in record.get("images", []) or []:
        stem = _image_stem(entry)
        if stem is not None:
            article_image_map[article_id].append(stem)

print("Train:", len(train_df), "Val:", len(val_df))

# GraphSAGE Model Definition

In [None]:
class GraphSAGEEncoder(nn.Module):
    """Article encoder: per-node projection, one SAGEConv layer, then an MLP head.

    The attribute names and the layer order inside ``mlp`` determine the
    state_dict keys, so they must stay in sync with the checkpoint that
    ``load_graphsage_model`` loads.
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Project raw node features to the hidden width before message passing.
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        # Single GraphSAGE layer: one round of neighbour aggregation.
        self.conv = SAGEConv(hidden_dim, hidden_dim)
        # Per-node head mapping hidden features to the output embedding width.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.Linear(hidden_dim, output_dim)
        )
    def forward(self, x, edge_index):
        """Encode node features *x* over the graph *edge_index*.

        x: node feature matrix, assumed (num_nodes, input_dim) — TODO confirm.
        edge_index: PyG connectivity tensor passed straight to SAGEConv
            (presumably shape (2, num_edges)).
        Returns one embedding per node, each row scaled to unit L2 norm.
        """
        h = F.relu(self.input_proj(x))
        h = self.conv(h, edge_index)
        h = self.mlp(h)
        return F.normalize(h, dim=-1)


class DualEncoderWithGraph(nn.Module):
    """Two-tower model mapping articles (graph-aware) and captions into one space.

    Articles go through ``GraphSAGEEncoder``; captions go through a plain MLP.
    Both outputs are L2-normalised ``embed_dim`` vectors, so inner product equals
    cosine similarity. Attribute names define checkpoint keys — do not rename.
    """
    def __init__(self, input_dim, hidden_dim, embed_dim):
        super().__init__()
        self.article_encoder = GraphSAGEEncoder(input_dim, hidden_dim, embed_dim)
        # Caption tower: the same input width as articles, with no graph component.
        self.caption_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.Linear(hidden_dim, embed_dim)
        )
    def encode_article(self, x, edge_index):
        """Return unit-norm article embeddings for node features *x* over *edge_index*."""
        return self.article_encoder(x, edge_index)
    def encode_caption(self, x):
        """Project caption embeddings *x* and L2-normalise each row."""
        return F.normalize(self.caption_proj(x), dim=-1)

# Load Precomputed Article Embeddings and IDs

In [None]:
# Precomputed Qwen3 article embeddings (cast to float32 for FAISS/torch) and
# the matching article ids. Later cells assume row i of the matrix belongs
# to database_article_ids[i].
database_embeddings, database_article_ids = (
    np.load(config.ARTICLE_EMBEDDINGS_FILE).astype(np.float32),
    np.load(config.ARTICLE_IDS_FILE, allow_pickle=True).tolist(),
)

print("Article embeddings shape:", database_embeddings.shape)
print("Number of article ids:", len(database_article_ids))

# Load GraphSAGE Model + Edge Index

In [None]:
def _check_graph_alignment(edge_index, saved_article_ids):
    """Print a warning when the checkpoint graph does not match the loaded embeddings.

    ``print`` is used rather than ``warnings.warn`` because this notebook
    silences all warnings globally.
    """
    ids = saved_article_ids.tolist() if torch.is_tensor(saved_article_ids) else saved_article_ids
    if [str(a) for a in ids] != [str(a) for a in database_article_ids]:
        print(
            "WARNING: checkpoint article_ids differ from database_article_ids; "
            "GraphSAGE neighbours may not correspond to the embedding rows."
        )
    num_nodes = database_embeddings.shape[0]
    if edge_index.numel() > 0:
        max_node = int(edge_index.max())
        if max_node >= num_nodes:
            print(
                f"WARNING: edge_index references node {max_node} but only "
                f"{num_nodes} article embeddings are loaded."
            )


def load_graphsage_model():
    """Load the trained DualEncoderWithGraph checkpoint and its article graph.

    Returns:
        ``(model, edge_index, saved_article_ids)``:
        - ``model``: in eval mode on ``config.DEVICE``.
        - ``edge_index``: the checkpoint's ``article_edge_index``, moved to that device.
        - ``saved_article_ids``: the checkpoint's ``article_ids``.

    The edge index addresses articles by row position. This notebook uses
    ``database_embeddings`` / ``database_article_ids`` as the node table, so any
    mismatch with the checkpoint's id order is reported. Otherwise neighbour
    aggregation would silently mix the wrong articles.
    """
    ckpt = torch.load(config.GRAPH_SAGE_MODEL, map_location=config.DEVICE)

    edge_index = ckpt["article_edge_index"].to(config.DEVICE)
    saved_article_ids = ckpt["article_ids"]
    _check_graph_alignment(edge_index, saved_article_ids)

    # The input width comes from the precomputed embeddings; the other widths come from config.
    input_dim = database_embeddings.shape[1]
    hidden_dim = config.GRAPH_SAGE_HIDDEN_DIM
    embed_dim = config.GRAPH_SAGE_EMBED_DIM

    model = DualEncoderWithGraph(input_dim, hidden_dim, embed_dim).to(config.DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    return model, edge_index, saved_article_ids


# Model (eval mode), article graph and checkpoint article ids, loaded once.
(
    graphsage_model,
    graphsage_edge_index,
    graphsage_article_ids,
) = load_graphsage_model()

print("Loaded GraphSAGE model. Edge index:", graphsage_edge_index.shape)

In [None]:
# Sanity check: checkpoint node count, the largest node index used by the
# edges, and the number of precomputed embedding rows.
for _sanity_value in (
    len(graphsage_article_ids),
    graphsage_edge_index.max(),
    database_embeddings.shape[0],
):
    print(_sanity_value)

# Compute GraphSAGE Article Embeddings

In [None]:
def encode_articles_in_batches(model, all_emb, edge_index, device, batch_size=512):
    """Run ``model.encode_article`` over every node, one batch of targets at a time.

    Only part of the graph is sent to ``device`` for each batch of consecutive
    node ids:
    - the batch nodes;
    - their in-neighbours (sources of edges whose target is in the batch);
    - exactly those edges.
    For a model with a single message-passing layer (as GraphSAGEEncoder is),
    this reproduces full-graph outputs for the batch nodes while bounding
    device memory. Building the sub-graph from out-neighbours instead would
    drop incoming edges on directed graphs.

    Args:
        model: object exposing ``eval()`` and ``encode_article(x, edge_index)``.
        all_emb: NumPy node-feature matrix, one row per node.
        edge_index: tensor whose row 0 holds edge sources and row 1 edge targets
            (PyG ``source_to_target`` convention, as used by ``SAGEConv``).
        device: torch device for the per-batch sub-graph.
        batch_size: number of target nodes encoded per forward pass.

    Returns:
        float32 array ``(N, D)``, where D is the width of the model output.
    """
    model.eval()
    num_nodes = all_emb.shape[0]

    # Host copy of the edge list, made once rather than per batch.
    src = edge_index[0].detach().cpu().numpy()
    dst = edge_index[1].detach().cpu().numpy()
    # Sort by target so each batch's incoming edges form one contiguous slice.
    order = np.argsort(dst, kind="stable")
    src, dst = src[order], dst[order]

    out = None
    for start in range(0, num_nodes, batch_size):
        end = min(start + batch_size, num_nodes)
        lo, hi = np.searchsorted(dst, (start, end))
        batch_src, batch_dst = src[lo:hi], dst[lo:hi]

        # Sorted, unique node set: batch nodes plus their in-neighbours.
        nodes = np.union1d(np.arange(start, end), batch_src)
        # Relabel global node ids to their positions inside `nodes`.
        local_edges = np.stack(
            (np.searchsorted(nodes, batch_src), np.searchsorted(nodes, batch_dst))
        )
        sub_edges = torch.from_numpy(local_edges).long().to(device)
        x_sub = torch.from_numpy(all_emb[nodes]).float().to(device)

        with torch.no_grad():
            z = model.encode_article(x_sub, sub_edges).cpu().numpy()

        if out is None:
            out = np.zeros((num_nodes, z.shape[1]), dtype=np.float32)
        out[start:end] = z[np.searchsorted(nodes, np.arange(start, end))]

        del x_sub, z, sub_edges
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if out is None:
        # No nodes at all: fall back to the configured embedding width.
        out = np.zeros((0, config.GRAPH_SAGE_EMBED_DIM), dtype=np.float32)
    return out

# Encode every database article with the GraphSAGE article tower.
graphsage_article_emb = encode_articles_in_batches(
    model=graphsage_model,
    all_emb=database_embeddings,
    edge_index=graphsage_edge_index,
    device=config.DEVICE,
)

print("GraphSAGE article embeddings:", graphsage_article_emb.shape)

# Build FAISS Index from GraphSAGE Article Embeddings

In [None]:
# Normalise rows in place so the inner-product index performs cosine search.
faiss.normalize_L2(graphsage_article_emb)
faiss_index = faiss.IndexFlatIP(config.GRAPH_SAGE_EMBED_DIM)
faiss_index.add(graphsage_article_emb)

print("FAISS index built with:", faiss_index.ntotal, "articles")

# Load Qwen Caption Encoder + GraphSAGE Caption Projection

In [None]:
# Qwen3 sentence encoder for captions/queries; half precision on GPU only.
caption_encoder = SentenceTransformer(
    config.EMBEDDING_MODEL,
    trust_remote_code=True,
    device=config.DEVICE,
)
if config.DEVICE == "cuda":
    caption_encoder.half()


def encode_captions_graphsage(texts, batch=64, encoder=None, model=None):
    """Embed caption strings into the GraphSAGE retrieval space.

    Texts are processed in chunks of ``batch``. Each chunk is encoded by the
    sentence encoder and cast to float32, because the encoder may run in half
    precision. It is then passed through the model's ``encode_caption``
    projection, which for DualEncoderWithGraph L2-normalises each row.

    Args:
        texts: list of caption/query strings.
        batch: number of texts per encoder + projection call.
        encoder: object with ``encode(texts, convert_to_numpy=True)``; defaults
            to the module-level ``caption_encoder``.
        model: torch module with ``encode_caption``; defaults to
            ``graphsage_model``. Inputs are moved to the device of its parameters.

    Returns:
        NumPy array with one row per text. For empty ``texts``, returns a
        ``(0, config.GRAPH_SAGE_EMBED_DIM)`` float32 array, since ``np.vstack``
        of no chunks would raise.
    """
    if len(texts) == 0:
        return np.zeros((0, config.GRAPH_SAGE_EMBED_DIM), dtype=np.float32)

    encoder = caption_encoder if encoder is None else encoder
    model = graphsage_model if model is None else model
    device = next(model.parameters()).device

    outs = []
    for i in range(0, len(texts), batch):
        raw = encoder.encode(texts[i:i + batch], convert_to_numpy=True)
        raw_t = torch.from_numpy(raw).float().to(device)
        with torch.no_grad():
            outs.append(model.encode_caption(raw_t).cpu().numpy())
    return np.vstack(outs)

# CLIP Image Reranker

In [None]:
# CLIP ViT-L/14 used only for inference-time image re-ranking.
clip_processor = CLIPProcessor.from_pretrained(config.CLIP_MODEL)
clip_model = CLIPModel.from_pretrained(config.CLIP_MODEL)
clip_model = clip_model.to(config.DEVICE)
clip_model.eval()

# Extensions treated as "already a file name"; anything else gets ".jpg".
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp", ".tif", ".tiff")


def path_for(img_id, image_dir=None):
    """Return the file path of image *img_id*.

    Ids come from ``os.path.splitext`` stems and may legitimately contain dots
    (e.g. ``"photo.v2"``). Checking for any "." would leave those without a
    suffix, so ``.jpg`` is appended unless the id already ends in a known
    image extension (case-insensitive).

    Args:
        img_id: image stem or file name.
        image_dir: directory to resolve against; defaults to ``config.IMAGE_DIR``.
    """
    directory = config.IMAGE_DIR if image_dir is None else image_dir
    if os.path.splitext(img_id)[1].lower() not in _IMAGE_EXTENSIONS:
        img_id = f"{img_id}.jpg"
    return os.path.join(directory, img_id)

@torch.no_grad()
def rerank_images_clip(query, image_ids, top_k):
    """Rank candidate images for *query* by CLIP text-image cosine similarity.

    Args:
        query: caption text. Only its first 200 characters are passed to the
            processor, which truncates further to 77 tokens.
        image_ids: candidate ids, resolved to files via ``path_for``. Duplicates
            (the same image listed under several articles) are scored once,
            in first-seen order, so they cannot fill several returned slots.
        top_k: maximum number of ids returned.

    Returns:
        Up to ``top_k`` ids, best first. Images that cannot be opened or decoded
        are skipped. If none can be opened, returns ``["#"] * top_k``.
    """
    images, valid = [], []
    for iid in dict.fromkeys(image_ids):
        try:
            # `with` closes the file; convert() returns a fully loaded copy.
            with Image.open(path_for(iid)) as im:
                rgb = im.convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError):
            # Best-effort: skip missing/corrupt files without hiding Ctrl-C
            # or programming errors the way a bare `except:` did.
            continue
        images.append(rgb)
        valid.append(iid)

    if not images:
        return ["#"] * top_k

    text_inputs = clip_processor(
        text=[query[:200]],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    ).to(config.DEVICE)
    text_emb = clip_model.get_text_features(**text_inputs)

    img_chunks = []
    for i in range(0, len(images), config.IMG_BATCH):
        batch = clip_processor(
            images=images[i:i + config.IMG_BATCH], return_tensors="pt", padding=True
        ).to(config.DEVICE)
        img_chunks.append(clip_model.get_image_features(**batch))
    img_emb = torch.cat(img_chunks, dim=0)

    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)

    sims = (img_emb @ text_emb.T).squeeze(-1).cpu().numpy()
    ranked = sorted(zip(valid, sims), key=lambda pair: pair[1], reverse=True)
    return [iid for iid, _ in ranked[:top_k]]

# Validation Inference (GraphSAGE → FAISS → CLIP)

In [None]:

# Validation: caption -> GraphSAGE caption embedding -> FAISS article search
# -> CLIP re-ranking of the retrieved articles' images.
val_queries = val_df["caption"].tolist()
val_gt_articles = val_df["retrieved_article_id"].astype(str).tolist()
val_gt_images = val_df["retrieved_image_id"].astype(str).tolist()

Q = encode_captions_graphsage(val_queries, batch=config.EMBEDDING_BATCH_SIZE)
faiss.normalize_L2(Q)

_, idx = faiss_index.search(Q, config.TOP_K_ARTICLES)
# FAISS pads result rows with -1 when fewer than TOP_K_ARTICLES vectors exist,
# and a raw -1 would silently index the last article. Ids are converted to
# str so they match the str ground truth and the JSON keys of article_image_map.
val_article_preds = [
    [str(database_article_ids[i]) for i in row if i >= 0] for row in idx
]

val_image_preds = []
for q, arts in tqdm(zip(val_queries, val_article_preds), total=len(val_queries)):
    candidates = [iid for a in arts for iid in article_image_map.get(a, [])]
    val_image_preds.append(rerank_images_clip(q, candidates, config.TOP_K_IMAGES))

article_metrics = evaluate_article_retrieval(val_article_preds, val_gt_articles)
image_metrics = evaluate_image_retrieval(val_image_preds, val_gt_images)

print("Article Retrieval:", article_metrics)
print("Image Retrieval:", image_metrics)

# Test Inference + Submission

In [None]:
# Test inference with the same pipeline as validation, then write submission.csv.
test_df = pd.read_csv(config.TEST_CSV)

test_queries = test_df["query_text"].tolist()
test_ids = test_df["query_index"].tolist()

Q_test = encode_captions_graphsage(test_queries, batch=config.EMBEDDING_BATCH_SIZE)
faiss.normalize_L2(Q_test)

_, idx = faiss_index.search(Q_test, config.TOP_K_ARTICLES)
# Drop FAISS's -1 padding (it would index the last article). Ids are converted
# to str so they match the JSON keys of article_image_map.
test_article_preds = [
    [str(database_article_ids[i]) for i in row if i >= 0] for row in idx
]

test_image_preds = []
for q, arts in tqdm(zip(test_queries, test_article_preds), total=len(test_queries)):
    candidates = [iid for a in arts for iid in article_image_map.get(a, [])]
    test_image_preds.append(rerank_images_clip(q, candidates, config.TOP_K_IMAGES))

# One row per query: the query id followed by exactly TOP_K_IMAGES image ids,
# right-padded with "#" when fewer candidates were found.
k = config.TOP_K_IMAGES
rows = []
for qid, imgs in zip(test_ids, test_image_preds):
    rows.append([qid] + (list(imgs) + ["#"] * k)[:k])

sub = pd.DataFrame(rows, columns=["query_id"] + [f"image_id_{i+1}" for i in range(k)])
sub.to_csv("submission.csv", index=False)

print("Saved submission.csv")