In [1]:
%%capture
!pip install transformers
!pip install pandas

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM,AutoModelForSequenceClassification
from typing import List, Optional
import time


class Qwen3Reranker:
    """Pointwise document reranker built on a Qwen3-Reranker causal LM.

    Each (instruction, query, document) triple is wrapped in a fixed chat
    template; the relevance score is P("yes") from a softmax over the model's
    next-token logits for "no" and "yes" at the final sequence position.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-Reranker-4B",
        max_length: int = 2048,
        device: Optional[str] = None,
        dtype: torch.dtype = torch.float16,
        batch_size: int = 1,
    ):
        """Load the tokenizer and model and pre-tokenize the prompt template.

        Args:
            model_name: Hugging Face model id or local path.
            max_length: Maximum total tokens per pair, INCLUDING the fixed
                chat-template prefix and suffix.
            device: Torch device string; defaults to "cuda" when available,
                else "cpu".
            dtype: Weight dtype passed to ``from_pretrained``.
                NOTE(review): float16 on CPU may be slow or unsupported for
                some ops — consider torch.float32 there.
            batch_size: Number of pairs scored per forward pass (1 keeps the
                original one-pair-at-a-time memory profile).

        Raises:
            ValueError: if ``max_length`` leaves no room for any pair tokens
                after the prompt template.
        """
        # Left padding keeps the final position of every row on the real last
        # token, which is where the yes/no logits are read.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype
        ).eval()

        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

        self.max_length = max_length
        self.batch_size = max(1, int(batch_size))
        self.prefix = (
            "<|im_start|>system\n"
            "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
            'Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
        )
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

        # Template pieces are tokenized once and attached AFTER truncation, so
        # only the pair body is ever cut and the assistant suffix (the scoring
        # position) always survives.
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        template_len = len(self.prefix_tokens) + len(self.suffix_tokens)
        if self.max_length <= template_len:
            raise ValueError(
                f"max_length={self.max_length} leaves no room for the query/document "
                f"after the {template_len}-token prompt template"
            )

    def _format_instruction(self, query: str, doc: str, instruction: Optional[str] = None) -> str:
        """Render the user-turn body for one pair; a web-search instruction is used when none is given."""
        if instruction is None:
            instruction = "Given a web search query, retrieve relevant passages that answer the query"
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def _tokenize_batch(self, texts: List[str]):
        """Tokenize formatted pairs and wrap each in the chat template.

        Each pair body is truncated to ``max_length`` minus the template length
        BEFORE the prefix/suffix tokens are attached. Tokenizing the concatenated
        prompt and truncating from the right would chop the assistant suffix off
        long documents. Rows are left-padded to the longest row in the batch
        rather than to ``max_length``.

        Returns:
            Padded encoding with ``input_ids`` and ``attention_mask`` tensors.
        """
        budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        encoded = self.tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=budget,
            return_attention_mask=False,
        )
        for i, ids in enumerate(encoded["input_ids"]):
            encoded["input_ids"][i] = self.prefix_tokens + ids + self.suffix_tokens
        return self.tokenizer.pad(encoded, padding=True, return_tensors="pt")

    @torch.no_grad()
    def _score_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> List[float]:
        """Return P("yes") for every row of a left-padded batch."""
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        last_logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits[:, -1, :]
        yes_no = torch.stack(
            [last_logits[:, self.token_false_id], last_logits[:, self.token_true_id]], dim=1
        )
        # Softmax over just the two candidate tokens; upcast so fp16 logits
        # don't lose precision in the normalization.
        return torch.softmax(yes_no.float(), dim=1)[:, 1].tolist()

    def _score_single(self, input_ids, attention_mask) -> float:
        """Score a batch of exactly one pair and return its P("yes")."""
        return self._score_batch(input_ids, attention_mask)[0]

    def score_pairs(self, queries: List[str], docs: List[str], instruction: Optional[str] = None) -> List[float]:
        """Score aligned (query, doc) pairs, returning P("yes") per pair in input order.

        Raises:
            ValueError: if ``queries`` and ``docs`` differ in length.
        """
        if len(queries) != len(docs):
            raise ValueError(
                f"queries and docs must have the same length ({len(queries)} != {len(docs)})"
            )
        pairs = [self._format_instruction(q, d, instruction) for q, d in zip(queries, docs)]
        on_cuda = str(self.device).startswith("cuda")

        scores: List[float] = []
        for start in range(0, len(pairs), self.batch_size):
            batch = self._tokenize_batch(pairs[start:start + self.batch_size])
            scores.extend(self._score_batch(batch["input_ids"], batch["attention_mask"]))
            if on_cuda:
                torch.cuda.empty_cache()  # release cached blocks between forwards
        return scores

    def rerank(self, query: str, docs: List[str], instruction: Optional[str] = None) -> tuple[list[str], list[float]]:
        """Score ``docs`` against ``query`` and return (docs, scores) sorted best-first.

        Both lists are empty when ``docs`` is empty.
        """
        queries = [query] * len(docs)
        scores = self.score_pairs(queries, docs, instruction)
        ranked = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
        if ranked:
            ranked_docs, ranked_scores = zip(*ranked)
        else:
            ranked_docs, ranked_scores = [], []
        return list(ranked_docs), list(ranked_scores)

    def rerank_dict(self, query: str, docs: List[str], instruction: Optional[str] = None) -> List[dict]:
        """Like ``rerank`` but returns ``[{"doc": ..., "score": ...}, ...]`` best-first."""
        ranked_docs, ranked_scores = self.rerank(query, docs, instruction)
        return [{"doc": d, "score": s} for d, s in zip(ranked_docs, ranked_scores)]


In [None]:
class indus_reranker:
    """Cross-encoder reranker backed by a sequence-classification model.

    Defaults to the NASA INDUS ranker. The model head may emit either one logit
    per (query, document) pair, scored with a sigmoid, or two logits, scored
    with a softmax and index 1 taken as the positive class.
    """

    def __init__(self, model_name="nasa-impact/nasa-smd-ibm-ranker", device=None, max_length=512):
        """Load the tokenizer/model and move the model to ``device`` in eval mode.

        Args:
            model_name: Hugging Face model id or local path.
            device: Torch device string; defaults to "cuda" when available, else "cpu".
            max_length: Token limit per query+document pair (longer pairs are truncated).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        self.max_length = max_length

    def rerank(self, query, docs, debug=False):
        """Score every document against ``query`` and return them best-first.

        All pairs are scored in a single forward pass.

        Args:
            query (str): The query string
            docs (list of str): List of documents to rerank
            debug (bool): If True, print debug info
        Returns:
            tuple: (sorted_docs, scores)
                - sorted_docs: list of documents sorted by relevance
                - scores: list of corresponding probabilities
            Both lists are empty when ``docs`` is empty.
        Raises:
            ValueError: if the model emits neither 1 nor 2 logits per pair.
        """
        start_time = time.time()

        if not docs:
            # Nothing to score: skip tokenizing an empty batch and avoid the
            # zip(*[]) unpacking failure.
            return [], []

        # Tokenize query-doc pairs
        encodings = self.tokenizer(
            text=[query] * len(docs),      # query repeated
            text_pair=docs,                # each doc as the pair
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        ).to(self.device)

        # Compute logits and probabilities
        with torch.no_grad():
            logits = self.model(**encodings).logits

            if debug:
                print("Logits shape:", logits.shape)
                print("Sample logits:", logits[:5].cpu().numpy())

            num_labels = logits.shape[-1]
            if num_labels == 2:
                probs = torch.nn.functional.softmax(logits, dim=-1)[:, 1]  # positive class
            elif num_labels == 1:
                probs = torch.sigmoid(logits).squeeze(-1)
            else:
                # A multi-class head has no single "relevant" probability;
                # fail loudly instead of sorting a 2-D tensor.
                raise ValueError(f"Expected 1 or 2 logits per pair, got {num_labels}")

            if debug:
                print("Sample probs:", probs[:5].cpu().numpy())

        # Sort by probability descending
        sorted_pairs = sorted(zip(docs, probs.float().cpu().tolist()), key=lambda x: x[1], reverse=True)
        sorted_docs = [doc for doc, _ in sorted_pairs]
        scores = [score for _, score in sorted_pairs]

        reranking_time = time.time() - start_time
        if debug:
            print(f"Reranking took {reranking_time:.4f} seconds")
            print("Sorted docs (top 3):", sorted_docs[:3])
            print("Sorted scores (top 3):", scores[:3])

        return sorted_docs, scores


In [None]:
# Choose the reranking backend for this run.
# Alternative backend: reranker = indus_reranker(max_length=512)
reranker = Qwen3Reranker(
    max_length=2048,
)

In [None]:
import pandas as pd
from tqdm import tqdm

# CSV with one row per query: `query_text` plus `retrieved_text`, where the
# retrieved documents are joined by the separator below.
input_csv = "" # file with retrieved text
if not input_csv:
    raise ValueError("Set `input_csv` to the CSV file with retrieved text before running this cell.")

sep = "<DOC_SEP>"  # separator between documents inside `retrieved_text`
df = pd.read_csv(input_csv)

reranked_docs_col = []
reranked_scores_col = []
for idx, row in tqdm(df.iterrows(), total=len(df), desc="Reranking rows"):
    query = row["query_text"]
    raw_retrieved = row["retrieved_text"]
    # Empty CSV cells load as NaN (a float); treat them as "no documents"
    # instead of crashing on `.split`.
    if isinstance(raw_retrieved, float) and raw_retrieved != raw_retrieved:
        raw_retrieved = ""
    # Split the sep-joined string into documents, dropping empty fragments
    # (e.g. from a trailing separator) so they are not scored as real docs.
    retrieved_texts = [doc.strip() for doc in str(raw_retrieved).split(sep) if doc.strip()]

    if retrieved_texts:
        ranked_docs, ranked_scores = reranker.rerank(query, retrieved_texts)
    else:
        ranked_docs, ranked_scores = [], []

    reranked_docs_col.append(sep.join(ranked_docs))
    reranked_scores_col.append(list(ranked_scores))

# Assign whole columns once; dtype=object keeps each score list as one cell value.
df["reranked_docs"] = reranked_docs_col
df["reranked_scores"] = pd.Series(reranked_scores_col, index=df.index, dtype=object)

output_csv = "reranked_qwen_1024_indus_20.csv"
df.to_csv(output_csv, index=False)
print(f"Reranked results saved to {output_csv}")


Reranking rows: 100%|██████████| 1140/1140 [02:10<00:00,  8.75it/s]


Reranked results saved to reranked_qwen_1024_indus_20.csv


# Test: evaluate retrieval metrics on the reranked results

In [None]:
import pandas as pd
from collections import Counter
import string
from typing import List, Dict

def normalize_text(text: str) -> str:
    """Lower-case ``text``, drop ASCII punctuation, and collapse whitespace runs to single spaces."""
    strip_punctuation = str.maketrans("", "", string.punctuation)
    words = text.lower().translate(strip_punctuation).split()
    return " ".join(words)

def is_reference_present_fuzzy(reference: str, document: str, threshold: float = 0.8) -> bool:
    """Return True if at least ``threshold`` of the reference's tokens occur in ``document``.

    Both texts are normalized with ``normalize_text`` and split on whitespace.
    Matching is bag-of-words: each reference token (duplicates included) counts
    as matched if it appears anywhere in the document, regardless of order.
    An empty or punctuation-only reference is never considered present.
    """
    ref_tokens = normalize_text(reference).split()
    if not ref_tokens:
        return False
    # Set membership makes this O(len(ref) + len(doc)) instead of scanning the
    # document token list once per reference token.
    doc_vocab = set(normalize_text(document).split())
    matched_tokens = sum(1 for t in ref_tokens if t in doc_vocab)
    return matched_tokens / len(ref_tokens) >= threshold

def compute_token_metrics_single_doc(
    references: List[str],
    retrieved_texts: List[str],
    threshold: float = 0.8
) -> Dict[str, float]:
    """Token-overlap metrics between reference passages and retrieved documents.

    A retrieved document "matches" a reference when
    ``is_reference_present_fuzzy(ref, doc, threshold)`` holds. Counts are over
    multisets of normalized tokens:

    * intersection    = min-count overlap of all reference tokens with the
                        tokens of the matched documents (each doc counted once)
    * precision       = intersection / total tokens across ALL retrieved docs
    * recall          = intersection / total reference tokens
    * iou             = intersection / (ref tokens + retrieved tokens - intersection)
    * ref_found_ratio = fraction of references matched by at least one doc

    Returns all-zero metrics when the references contain no tokens.
    """
    doc_tokens = [normalize_text(doc).split() for doc in retrieved_texts]

    all_ref_tokens: List[str] = []
    matched_doc_ids = set()
    found_count = 0

    for ref in references:
        all_ref_tokens.extend(normalize_text(ref).split())

        hits = {
            i for i, doc in enumerate(retrieved_texts)
            if is_reference_present_fuzzy(ref, doc, threshold)
        }
        if hits:
            found_count += 1
            # Union, not extend: a doc matching several references must add its
            # tokens only once, otherwise the intersection can exceed the
            # retrieved token count (precision / IoU > 1).
            matched_doc_ids |= hits

    if not all_ref_tokens:
        return {"iou": 0.0, "precision": 0.0, "recall": 0.0, "ref_found_ratio": 0.0}

    ref_counter = Counter(all_ref_tokens)
    match_counter = Counter(tok for i in matched_doc_ids for tok in doc_tokens[i])

    intersection_count = sum((ref_counter & match_counter).values())
    ref_count = len(all_ref_tokens)
    doc_count = sum(len(toks) for toks in doc_tokens)
    union_count = ref_count + doc_count - intersection_count

    iou = intersection_count / union_count if union_count > 0 else 0.0
    precision = intersection_count / doc_count if doc_count > 0 else 0.0
    recall = intersection_count / ref_count if ref_count > 0 else 0.0
    ref_found_ratio = found_count / len(references) if references else 0.0

    return {
        "iou": iou,
        "precision": precision,
        "recall": recall,
        "ref_found_ratio": ref_found_ratio,
    }

def compute_metrics_single_row(
    ref_str: str,
    ret_str: str,
    K: int,
    ref_sep: str = "|",
    doc_sep: str = "<DOC_SEP>",
    token_threshold: float = 0.95,
    rr_threshold: float = 1.0
) -> Dict[str, float]:
    """Evaluate one row: token metrics plus reciprocal rank over the top ``K`` docs.

    Args:
        ref_str: Reference passages joined by ``ref_sep``.
        ret_str: Retrieved documents, in rank order, joined by ``doc_sep``.
        K: Number of leading documents to evaluate (``K <= 0`` evaluates none).
        ref_sep: Separator between references in ``ref_str``.
        doc_sep: Separator between documents in ``ret_str``.
        token_threshold: Fuzzy-match threshold used for the token metrics.
        rr_threshold: Fuzzy-match threshold for locating the first relevant
            rank (1.0 = every reference token must appear in the document).

    Returns:
        The dict from ``compute_token_metrics_single_doc`` extended with
        ``reciprocal_rank`` (1/rank of the first matching doc, else 0.0) and
        ``ref_rank`` (that rank, or None if no doc matched).
    """
    def _split(value, sep: str) -> List[str]:
        # Empty CSV cells arrive as NaN (a float); str(NaN) == "nan" would
        # otherwise become a bogus one-token reference/document.
        if value is None or (isinstance(value, float) and value != value):
            return []
        return [part.strip() for part in str(value).split(sep) if part.strip()]

    references = _split(ref_str, ref_sep)
    # Clamp so a negative K cannot slice from the end of the list.
    top_docs = _split(ret_str, doc_sep)[:max(K, 0)]

    metrics = compute_token_metrics_single_doc(references, top_docs, threshold=token_threshold)

    rr = 0.0
    found_rank = None
    for rank, doc in enumerate(top_docs, start=1):
        if any(is_reference_present_fuzzy(ref, doc, threshold=rr_threshold) for ref in references):
            rr = 1.0 / rank
            found_rank = rank
            break

    metrics["reciprocal_rank"] = rr
    metrics["ref_rank"] = found_rank  # None if not found
    return metrics


In [None]:
df = pd.read_csv(output_csv)

top_k = 10
metrics_list = []

for idx, row in tqdm(df.iterrows(), total=len(df), desc="Computing metrics"):
    metrics = compute_metrics_single_row(row["references"], row["reranked_docs"], top_k)
    metrics_list.append(metrics)

# Fixed column list so an empty input still produces the expected columns.
df_metrics = pd.DataFrame(
    metrics_list,
    columns=["iou", "precision", "recall", "ref_found_ratio", "reciprocal_rank", "ref_rank"],
)
# ref_rank is None when no reference was found in the top-k. Nullable Int64 keeps
# the column numeric even if every row is None (an object column breaks .mean())
# and prints ranks as 1, 2, ... rather than 1.0, 2.0, ...
df_metrics["ref_rank"] = pd.to_numeric(df_metrics["ref_rank"], errors="coerce").astype("Int64")
df = pd.concat([df, df_metrics], axis=1)

average_metrics = df_metrics.mean().to_dict()
print(f"\nAverage metrics across all rows @{top_k}:")
print(average_metrics)

# Names say _10, but the values are computed @top_k.
ref_tok_10 = average_metrics['ref_found_ratio']
mrr_tok_10 = average_metrics['reciprocal_rank']

rank_counts = df_metrics["ref_rank"].value_counts().sort_index()
rank_percents = (rank_counts / len(df_metrics)) * 100

print("\nRank distribution:")
for r, c in rank_counts.items():
    pct = rank_percents[r]
    print(f"Rank {r}: {c} queries ({pct:.2f}%)")

# value_counts() drops missing ranks, so report the not-found rows explicitly.
n_not_found = int(df_metrics["ref_rank"].isna().sum())
if len(df_metrics):
    print(f"Not found in top {top_k}: {n_not_found} queries ({n_not_found / len(df_metrics) * 100:.2f}%)")