# Combination Attack Demo

This notebook demonstrates the Combination Attack against the E5 embedding model (`intfloat/e5-base-v2`) on the MSMARCO dataset, using 5 randomly sampled queries.

Load the model and datasets:

In [None]:
import torch
from sentence_transformers import SentenceTransformer, util
from utils import data_utils
from huggingface_hub import hf_hub_download
import json
import numpy as np
from datasets import load_dataset
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

# Embedding model under attack; the same model name is passed to the dataset
# loader below.
model = SentenceTransformer("intfloat/e5-base-v2").to(device)
print("Loaded E5 model.")

# Only the corpus and queries are used in this notebook; the other two return
# values are discarded.
corpus, queries, _, _ = data_utils.load_dataset(
    dataset_name="msmarco",
    data_split="dev",
    embedder_model_name="intfloat/e5-base-v2",
)
print("Loaded MSMARCO dataset.")

# Precomputed query/passage similarity scores, used later to rank the crafted
# passages against the real corpus. Based on how it is indexed in the attack
# loop, this is a mapping {query id: {passage id: score}}.
local_results_path = hf_hub_download(
    repo_id="MatanBT/retrieval-datasets-similarities",
    filename="msmarco-test_1.0_e5-base-v2_cos_sim.json",
    repo_type="dataset",
)
# JSON is UTF-8 by specification; don't depend on the platform's default
# locale encoding (e.g. cp1252 on Windows).
with open(local_results_path, encoding="utf-8") as f:
    results = json.load(f)
print("Loaded HF results.")

# Candidate "info" passages: ToxiGen texts with toxicity_ai == 5 and a length
# of 51-99 characters (both bounds exclusive in the filter).
toxic_prefixes = [
    row["text"]
    for row in load_dataset("toxigen/toxigen-data")["train"]
    if row["toxicity_ai"] == 5 and 50 < len(row["text"]) < 100
]
print("Loaded ToxiGen dataset.")

Initialize variables:

In [None]:
trials = 5  # number of (query, poison passage) pairs to attack
seed = 42

# Seeds NumPy's global RNG, so this also affects any later np.random use
# (possibly inside the attack code - confirm).
np.random.seed(seed)
chosen_qids = np.random.choice(list(queries.keys()), size=(trials,), replace=False)
chosen_poisons = np.random.choice(toxic_prefixes, size=(trials,), replace=False)

# Fail fast if a sampled query has no precomputed scores. Otherwise the attack
# loop would only hit a KeyError on that trial, after the earlier (expensive)
# attacks had already run. The draw above is left untouched so runs remain
# reproducible for a given seed.
missing_qids = [qid for qid in chosen_qids if qid not in results]
if missing_qids:
    raise KeyError(f"No precomputed similarities for query ids: {missing_qids}")

# Per-trial rankings (0-based: 0 means ranked first) of the un-attacked info
# passage, the query-stuffed passage and the adversarial passage.
info_rankings = []
stuffing_rankings = []
adv_rankings = []

# Per-trial query similarities of the best corpus passage and the three
# crafted passages.
best_similarities = []
info_similarities = []
stuffing_similarities = []
adv_similarities = []

Perform attacks:

In [None]:
from attack import BlackBoxAttack


def calc_ranking(sim, qid, results):
    """Return the 0-based rank a passage with similarity ``sim`` would get.

    The rank is the number of passages in ``results[qid]`` whose precomputed
    score is strictly greater than ``sim``; 0 means it would be ranked first.
    Ties are resolved in the candidate passage's favour.

    Args:
        sim: similarity between the query and the candidate passage.
        qid: query id used to index ``results``.
        results: mapping of query id -> {passage id: score}.

    Returns:
        Number of passages that outrank the candidate (an ``int``).

    Raises:
        KeyError: if ``qid`` is not present in ``results``.
    """
    # Only the scores matter; passage ids are irrelevant for the count.
    return sum(score > sim for score in results[qid].values())


for i, qid in enumerate(tqdm(chosen_qids)):
    print(f"Trial number {i + 1}")

    # Highest-scoring corpus passage for this query according to the
    # precomputed results. Take the maximum explicitly instead of trusting the
    # JSON key order to be sorted by score; for a sorted file (and on ties)
    # this picks the same passage as taking the first key.
    qid_scores = results[qid]
    best_pid = max(qid_scores, key=qid_scores.get)

    # original passage and query
    q = queries[qid]
    print(f"{qid}: {q}")
    p = corpus[best_pid]["text"]
    print(f"{best_pid}: {p}")
    # Injected passage: the sampled ToxiGen text with a "passage: " prefix,
    # mirroring the "query: " prefix the queries carry.
    info = "passage: " + chosen_poisons[i].strip()
    print(f"info: {info}")

    # "Stuffing" baseline: the info passage followed by the raw query text.
    # Strip only a leading "query: " prefix; str.replace would also delete
    # any later occurrence inside the query text itself.
    query_body = q[len("query: "):] if q.startswith("query: ") else q

    # batch encodings (except adv which depends on attack)
    base_texts = [
        q,
        p,
        info,
        info + " " + query_body,
    ]  # q, best passage, info, stuffing
    with torch.inference_mode():
        encs = model.encode(base_texts, convert_to_tensor=True)
    q_enc, p_enc, info_enc, stuffing_enc = encs

    # calculate similarities
    best_sim = util.cos_sim(q_enc, p_enc).item()
    print(f"Similarity between query and original best passage: {best_sim}")
    info_sim = util.cos_sim(q_enc, info_enc).item()
    print(f"Similarity between query and info passage: {info_sim}")
    stuffing_sim = util.cos_sim(q_enc, stuffing_enc).item()
    print(f"Similarity between query and stuffing passage: {stuffing_sim}")

    # store similarities
    best_similarities.append(best_sim)
    info_similarities.append(info_sim)
    stuffing_similarities.append(stuffing_sim)

    # store rankings (0-based position among the precomputed corpus scores)
    info_rankings.append(calc_ranking(info_sim, qid, results))
    stuffing_rankings.append(calc_ranking(stuffing_sim, qid, results))

    # perform attack
    # NOTE(review): the hyperparameters below are taken as given; their origin
    # (presumably tuned elsewhere) is not documented here.
    print("Performing black-box attack...")
    bb_attack = BlackBoxAttack(model, q, sim=util.cos_sim)
    tokens, _, _ = bb_attack.combination_attack(
        info,
        p_init=0.3946,
        total_tokens=72,
        num_iters=2940,
        random_pool_per_pos=316,
    )

    # craft adversarial passage: info text followed by the attack's tokens
    p_adv = info + " " + " ".join(tokens)
    print(f"adversarial: {p_adv}")

    # calculate adv similarity
    with torch.inference_mode():
        p_adv_enc = model.encode(p_adv, convert_to_tensor=True)
    adv_sim = util.cos_sim(q_enc, p_adv_enc).item()
    print(f"Similarity between query and attacked passage: {adv_sim}")

    # store adv similarity and ranking
    adv_similarities.append(adv_sim)
    adv_rankings.append(calc_ranking(adv_sim, qid, results))

    print("\n", flush=True)

Summarize results:

In [None]:
# Rankings are 0-based counts of corpus passages scoring strictly higher, so
# "top k" means rank < k. "original" refers to the un-attacked info passage.
for label, ranks in (
    ("original", info_rankings),
    ("stuffing", stuffing_rankings),
    ("attacked", adv_rankings),
):
    print(f"Rankings of {label} passages: {ranks}")
print(f"Average ranking of attacked passages: {sum(adv_rankings) / trials}")
for k in (1, 5, 10):
    hits = sum(rank < k for rank in adv_rankings)
    print(f"Top {k} in {hits} out of {trials} attacks!")

print()
# Raw cosine similarities between each query and the respective passages.
for label, sims in (
    ("best", best_similarities),
    ("original", info_similarities),
    ("stuffing", stuffing_similarities),
    ("attacked", adv_similarities),
):
    print(f"Similarities of {label} passages: {sims}")
print(f"Average similarity of attacked passages: {sum(adv_similarities) / trials}")