In [None]:
import json
import numpy as np
import random
import torch
from trigger_passage_utils import JointOptimiser

# --- Experiment configuration -------------------------------------------
# One seed is shared by every random number generator seeded below.
seed = 123

# Results are written to this tab-separated file.
log_file = "joint_trigger_passage.tsv"

# Sweep grid: each (trigger length, passage length) combination is repeated
# `num_trigger_passage_pairs` times by the experiment loop further down.
# NOTE(review): lengths are presumably counted in tokens -- confirm against
# JointOptimiser.generate_joint_trigger_and_passage.
num_trigger_passage_pairs = 5
passage_lengths = [30]
trigger_lengths = [5]

# --- Reproducibility ----------------------------------------------------
# Seed Python's `random`, NumPy and PyTorch with the same value; the CUDA
# generators are only touched when a GPU is actually present.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Instantiate the joint optimiser (it is given the same seed as the global RNGs)
joint_opt = JointOptimiser(
    retriever_name="facebook/contriever",
    seed=seed
)

# Load training and testing queries.
# The file is decoded explicitly as UTF-8 so the result does not depend on the
# platform's default encoding, and blank lines are skipped so that stray empty
# lines do not crash json.loads.
with open("./nq/queries.jsonl", encoding="utf-8") as f:
    lines = [json.loads(line)["text"] for line in f if line.strip()]

# First 500 queries are used for optimisation, the next 500 for evaluation.
train_queries = lines[:500]
test_queries = lines[500:1000]

# Load corpus embeddings.
# SECURITY: torch.load unpickles the file. `weights_only=True` restricts the
# unpickler to tensors and other plain containers, which is all this file is
# used for here (the result is moved with `.to(...)` and later concatenated
# with another tensor). Only drop the flag for files from a trusted source.
corpus_embeddings = torch.load(
    "corpus_embeddings_10000.pt",
    map_location=joint_opt.device,
    weights_only=True,
).to(joint_opt.device)

def get_rank_from_id(query_emb: torch.Tensor, poison_id: str, id_list: list, all_embeddings: torch.Tensor) -> int:
    """
    Compute the retrieval rank of a poisoned passage based on a string identifier.

    Every passage is scored by its dot product with the query embedding. The
    rank of the poisoned passage is one plus the number of passages with a
    strictly higher score, so no full sort of the corpus is needed and exact
    score ties resolve deterministically to the best tied rank.

    Args:
        query_emb (torch.Tensor): Embedding of a single query, either flat
            ``(dim,)`` or with a leading singleton axis ``(1, dim)``.
        poison_id (str): The unique ID string of the poisoned passage.
        id_list (list): List of passage IDs; position ``i`` names row ``i`` of
            ``all_embeddings``. If ``poison_id`` occurs more than once, its
            first occurrence is used.
        all_embeddings (torch.Tensor): Full corpus embeddings of shape
            ``(num_passages, dim)``, including the poisoned passage.

    Returns:
        int: The 1-based rank of the poisoned passage, or ``len(id_list)``
        (i.e. the worst possible rank) when ``poison_id`` is not in ``id_list``.
    """
    try:
        target_idx = id_list.index(poison_id)
    except ValueError:
        # Unknown ID: keep the historical fallback of reporting the last rank.
        return len(id_list)

    # Flatten so both (dim,) and (1, dim) queries yield a 1-D score vector.
    query_vec = query_emb.reshape(-1)
    sims = torch.matmul(all_embeddings, query_vec)

    # Passages that strictly outscore the target are the ones ranked above it.
    num_better = int((sims > sims[target_idx]).sum().item())
    return num_better + 1

# Run the sweep: for every (trigger length, passage length, trial) combination,
# optimise a trigger/passage pair, then measure where the poisoned passage ranks
# for triggered versus clean test queries. One TSV row is written per trial.
with open(log_file, "w", encoding="utf-8") as fout:
    fout.write("trigger_len\tpassage_len\ttrial\ttrigger\tpassage\titerations\ttriggered_rank\tclean_rank\n")

    # The corpus never changes between trials, so the ID list (corpus row
    # indices as strings, plus one entry for the poisoned passage appended at
    # the end) is built once instead of once per trial.
    poison_id = "poison"
    id_list = [str(i) for i in range(corpus_embeddings.size(0))] + [poison_id]

    for trig_len in trigger_lengths:
        for pass_len in passage_lengths:
            for trial in range(num_trigger_passage_pairs):

                # Run joint optimisation to obtain trigger and passage.
                # This stays outside torch.no_grad(): the optimiser may rely
                # on gradients internally.
                (trigger_ids, passage_ids), n_iter = joint_opt.generate_joint_trigger_and_passage(
                    clean_queries=train_queries,
                    trigger_len=trig_len,
                    passage_len=pass_len,
                    K=30,
                    T=200
                )

                # Decode token IDs to strings
                trigger_text = joint_opt.tokenizer.decode(trigger_ids, skip_special_tokens=True)
                passage_text = joint_opt.tokenizer.decode(passage_ids, skip_special_tokens=True)

                # Evaluation only reads embeddings and never back-propagates,
                # so it runs without autograd bookkeeping.
                with torch.no_grad():
                    # Encode poisoned passage (all-ones second argument and
                    # all-zeros third argument, each batched to size 1).
                    passage_emb = joint_opt.encode_passage(
                        passage_ids.unsqueeze(0),
                        torch.ones_like(passage_ids).unsqueeze(0),
                        torch.zeros_like(passage_ids).unsqueeze(0)
                    ).detach()

                    # Append poisoned passage as the last row, matching the
                    # position of `poison_id` in `id_list`.
                    all_embeddings = torch.cat([corpus_embeddings, passage_emb], dim=0)

                    # Prepare and evaluate triggered test queries
                    triggered_queries = [joint_opt.insert_trigger(q, trigger_text, location='end') for q in test_queries]
                    triggered_ranks = [
                        get_rank_from_id(joint_opt.encode_query(q).unsqueeze(0), poison_id, id_list, all_embeddings)
                        for q in triggered_queries
                    ]

                    # Prepare and evaluate clean test queries
                    clean_ranks = [
                        get_rank_from_id(joint_opt.encode_query(q).unsqueeze(0), poison_id, id_list, all_embeddings)
                        for q in test_queries
                    ]

                # Compute average ranks over the test queries
                avg_clean = np.mean(clean_ranks)
                avg_triggered = np.mean(triggered_ranks)

                # Write results to log
                log_line = (
                    f"{trig_len}\t"
                    f"{pass_len}\t"
                    f"{trial}\t"
                    f"{trigger_text}\t"
                    f"{passage_text}\t"
                    f"{n_iter}\t"
                    f"{avg_triggered:.2f}\t"
                    f"{avg_clean:.2f}"
                )
                print(log_line)
                fout.write(log_line + "\n")
                # Flush after every trial so finished rows survive an
                # interrupted run.
                fout.flush()


  from .autonotebook import tqdm as notebook_tqdm
  corpus_embeddings = torch.load("corpus_embeddings_10000.pt", map_location=joint_opt.device).to(joint_opt.device)


5	30	0	sect monograph consular metropolitan mile	leap squadrons but sue pussy armory were zeta modes backed newcomers chariot majesty examines interests cut reviewers bearing promising meritorious corrupt banks bureau taxpayers concession reform oversight	130	1.39	4325.09


KeyboardInterrupt: 