In [6]:
"""
Refined Black-Box Memorization Attack for BioGPT:
- Auto-regressive text generation
- Sliding-window perplexity analysis
- Embedding-based similarity (BioBERT-NLI) for near-verbatim detection

Steps:
  1. Loads BioGPT as a black-box model (only using model.generate, plus a minimal forward pass for perplexity).
  2. Loads a local PubMed/PMC corpus from papers.json (with 'title' and 'abstract' fields).
  3. Builds candidate prompts from titles and partial abstracts.
  4. Generates text with top-k/nucleus sampling.
  5. Computes:
       - zlib ratio (a naive compression-based signal)
       - sliding-window perplexity across the generated output
  6. Uses domain-specific embeddings to compare the top-ranked (most suspicious) completions to the corpus.
  7. Prints and saves suspicious results.

Requires:
  - transformers
  - torch
  - sentence-transformers
"""

import os
import json
import random
import zlib
import re
from typing import List, Dict

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)
from sentence_transformers import SentenceTransformer, util


In [4]:
###############################################################################
# CONFIGURATION
###############################################################################
BIOGPT_MODEL_NAME = "microsoft/BioGPT-Large"  # Or "microsoft/BioGPT"
CORPUS_JSON_PATH = "papers.json"             # local PubMed/PMC dataset
GENERATIONS_FILE = "biogpt_generations.json"
ATTACK_RESULTS_FILE = "biogpt_attack_results.json"

# Reproducibility: prompt selection uses `random`, sampling during generation
# uses torch's RNG. Seed both so Restart & Run All gives the same completions.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

NUM_GENERATIONS = 1000    # number of completions to generate
TOKENS_TO_GENERATE = 200  # length of each generated text
TEMPERATURE = 0.8
TOP_K = 50
TOP_P = 0.95

WINDOW_SIZE = 50          # sliding window size for perplexity
STRIDE_FRACTION = 0.5     # overlap fraction for sliding window (e.g. 0.5 => half overlap)

EMB_MODEL_NAME = "pritamdeka/BioBERT-NLI-mean-tokens"  # domain-specific embedding model
EMB_SIM_THRESHOLD = 0.85   # similarity threshold to consider near-verbatim


In [5]:
###############################################################################
# Helper Functions
###############################################################################
def load_pubmed_data(json_path: str) -> List[Dict]:
    """Read the local PubMed/PMC corpus from a JSON file.

    Args:
        json_path: Path to the JSON corpus (e.g. papers.json).

    Returns:
        Whatever `json.load` parses from the file. If the path does not
        exist, an error line is printed and an empty list is returned.
    """
    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    print(f"[ERROR] File not found: {json_path}")
    return []

def preprocess_text(text: str) -> str:
    """Normalize text for comparison.

    Lowercases, deletes every character that is neither a word character nor
    whitespace, collapses whitespace runs to a single space, and trims the ends.
    """
    lowered = text.lower()
    without_punct = re.sub(r'[^\w\s]', '', lowered)
    single_spaced = re.sub(r'\s+', ' ', without_punct)
    return single_spaced.strip()

def zlib_ratio(txt: str) -> float:
    """Compute a zlib compression ratio as a naive membership-inference signal.

    The ratio is (UTF-8 byte length) / (zlib-compressed byte length), so both
    sides are measured in bytes. Higher values mean the text is more
    compressible (more repetitive).

    Args:
        txt: Text to measure.

    Returns:
        The compression ratio, or 0.0 for empty / whitespace-only text.
    """
    if not txt.strip():
        return 0.0
    # Measure the uncompressed side in bytes, not characters. Otherwise
    # non-ASCII text (e.g. "µg", "α-synuclein") is under-counted relative to
    # the compressed byte stream.
    raw = txt.encode("utf-8")
    return len(raw) / len(zlib.compress(raw))

def sliding_window_perplexity(text: str, model, tokenizer, device: str, window_size: int = WINDOW_SIZE,
                              stride_fraction: float = STRIDE_FRACTION) -> Dict:
    """
    Compute perplexity over overlapping sliding windows of the tokenized text.

    Args:
        text: Text to score; it is tokenized with `tokenizer`.
        model: Model called as ``model(ids, labels=ids)``. The perplexity is
            ``exp(outputs.loss)``. Presumably a Hugging Face causal LM whose loss
            is the mean next-token cross-entropy (verify for other models).
        tokenizer: Tokenizer matching `model`.
        device: Device each window tensor is moved to.
        window_size: Tokens per window. Clipped to the sequence length.
        stride_fraction: Step between window starts, as a fraction of the
            window size (0.5 => half overlap). The stride is at least 1 token.

    Returns:
        ``{"min_ppl": float, "avg_ppl": float}`` over all windows. Both values
        are None when the text yields fewer than 2 tokens.
    """
    input_ids = tokenizer(text, return_tensors="pt").input_ids[0]  # shape: (seq_len,)
    seq_len = input_ids.size(0)
    window_size = min(window_size, seq_len)
    # A 0-token window cannot be fed to the model. A 1-token window leaves no
    # shifted target for a causal-LM loss (NaN for HF models). Report "unscored".
    if window_size < 2:
        return {"min_ppl": None, "avg_ppl": None}

    stride = max(1, int(window_size * stride_fraction))
    last_start = seq_len - window_size
    starts = list(range(0, last_start + 1, stride))
    # When the stride doesn't divide evenly, add a window flush with the end so
    # trailing tokens (where memorized text may sit) are always scored.
    if starts[-1] != last_start:
        starts.append(last_start)

    window_perplexities = []
    with torch.no_grad():
        for start in starts:
            window_ids = input_ids[start:start + window_size].unsqueeze(0).to(device)
            loss = model(window_ids, labels=window_ids).loss
            window_perplexities.append(torch.exp(loss).item())

    return {
        "min_ppl": min(window_perplexities),
        "avg_ppl": sum(window_perplexities) / len(window_perplexities),
    }

# Single-slot cache of the most recently encoded corpus. embedding_similarity is
# called once per suspicious completion against the same corpus, so this avoids
# re-encoding every article on every call. Holding a reference to the model
# makes the identity check below sound.
_CORPUS_EMB_CACHE: Dict = {}


def _article_text(article: Dict) -> str:
    """Return the preprocessed "title abstract" string for one corpus record."""
    # NOTE(review): assumes records look like {"title": {"full_text": ...},
    # "abstract": {"full_text": ...}}. Missing or None values become "".
    title = (article.get("title") or {}).get("full_text") or ""
    abstract = (article.get("abstract") or {}).get("full_text") or ""
    return preprocess_text(title + " " + abstract)


def embedding_similarity(generated_text: str, corpus: List[Dict], emb_model, field: str = "title_abstract",
                         threshold=None) -> List[int]:
    """
    Find corpus articles whose embedding is close to the generated text.

    Both sides are normalized with `preprocess_text` before encoding, so a
    verbatim copy of an article is compared like-for-like. Each article is
    represented by its title and abstract combined.

    Args:
        generated_text: Model output to check.
        corpus: List of article records.
        emb_model: Sentence-embedding model exposing
            ``encode(texts, convert_to_tensor=True)``.
        field: Only "title_abstract" (title + abstract) is implemented. The
            argument is accepted for compatibility and otherwise ignored.
        threshold: Cosine-similarity cutoff (inclusive). Defaults to
            EMB_SIM_THRESHOLD.

    Returns:
        Corpus indices whose cosine similarity to `generated_text` is at or
        above the threshold. Returns an empty list for an empty corpus.
    """
    if not corpus:
        return []
    cutoff = EMB_SIM_THRESHOLD if threshold is None else threshold

    corpus_texts = [_article_text(article) for article in corpus]
    if _CORPUS_EMB_CACHE.get("model") is emb_model and _CORPUS_EMB_CACHE.get("texts") == corpus_texts:
        corpus_embs = _CORPUS_EMB_CACHE["embs"]
    else:
        corpus_embs = emb_model.encode(corpus_texts, convert_to_tensor=True)
        _CORPUS_EMB_CACHE.clear()
        _CORPUS_EMB_CACHE.update(model=emb_model, texts=corpus_texts, embs=corpus_embs)

    gen_emb = emb_model.encode(preprocess_text(generated_text), convert_to_tensor=True)
    cos_scores = util.cos_sim(gen_emb, corpus_embs)[0]  # one similarity score per article
    return [idx for idx, score in enumerate(cos_scores.tolist()) if score >= cutoff]

In [None]:
# --- Step A: Load Model & Data ---
print(f"[INFO] Loading BioGPT model (causal LM): {BIOGPT_MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(BIOGPT_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(BIOGPT_MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference only (generation + perplexity scoring); made explicit

print(f"[INFO] Loading PubMed abstracts data from {CORPUS_JSON_PATH}")
corpus = load_pubmed_data(CORPUS_JSON_PATH)
print(f"[INFO] Loaded {len(corpus)} records from local PubMed abstracts.")

PROMPT_ABSTRACT_WORDS = 20  # leading abstract words used as a prompt

# Build candidate prompts from titles and the first PROMPT_ABSTRACT_WORDS words of abstracts.
# NOTE(review): assumes records store text as {"title": {"full_text": ...}, "abstract": {"full_text": ...}};
# missing or None values are treated as empty.
prompts = []
for article in corpus:
    title = ((article.get("title") or {}).get("full_text") or "").strip()
    abstract = ((article.get("abstract") or {}).get("full_text") or "").strip()
    if title:
        prompts.append(title)
    if abstract:
        words = abstract.split()
        prompt_abstract = " ".join(words[:PROMPT_ABSTRACT_WORDS]) if len(words) > PROMPT_ABSTRACT_WORDS else abstract
        prompts.append(prompt_abstract)
if not prompts:
    # Fallback generic prompts so generation still runs on an empty corpus.
    prompts = ["Biomedical research shows", "In this study, we explore"]
print(f"[INFO] Built {len(prompts)} candidate prompts.")

# --- Step B: Generate Completions ---
# BioGPT is a causal LM. Completions are sampled with model.generate using
# top-k / nucleus sampling; the model has no [MASK] token to fill.


def generate_completion(prompt: str) -> str:
    """Sample one continuation of `prompt` from the loaded model.

    Uses the configured TEMPERATURE, TOP_K, TOP_P and TOKENS_TO_GENERATE.

    Returns:
        The decoded output sequence (prompt followed by the sampled
        continuation) with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Pass a pad id explicitly. Fall back to EOS if the tokenizer has none.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            do_sample=True,
            max_new_tokens=TOKENS_TO_GENERATE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=pad_id,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


generations = []
for i in range(NUM_GENERATIONS):
    prompt = random.choice(prompts)
    generated = generate_completion(prompt)
    # Compute zlib ratio
    z_ratio = zlib_ratio(generated)
    # Compute sliding-window perplexity
    ppl_stats = sliding_window_perplexity(generated, model, tokenizer, device, window_size=WINDOW_SIZE)
    generations.append({
        "prompt": prompt,
        "generated_text": generated,
        "zlib_ratio": z_ratio,
        "min_window_ppl": ppl_stats["min_ppl"],
        "avg_window_ppl": ppl_stats["avg_ppl"]
    })
    if (i + 1) % 100 == 0:
        print(f"[INFO] Generated {i+1} completions.")

with open(GENERATIONS_FILE, "w", encoding="utf-8") as f:
    json.dump(generations, f, indent=2)
print(f"[INFO] Saved {len(generations)} completions to {GENERATIONS_FILE}.")

# --- Step C: Membership Inference Filtering ---
# Rank completions by zlib_ratio (most compressible first) and keep the top
# TOP_N_SUSPICIOUS for embedding verification. A high ratio also catches
# repetitive / degenerate output, so the minimum window perplexity is shown
# alongside it for manual inspection.
TOP_N_SUSPICIOUS = 50

generations.sort(key=lambda x: x["zlib_ratio"], reverse=True)
top_suspicious = generations[:TOP_N_SUSPICIOUS]
print("[INFO] Top 5 suspicious completions by zlib_ratio:")
for j, cand in enumerate(top_suspicious[:5], start=1):
    # min_window_ppl is None when the text was too short to score.
    ppl = cand["min_window_ppl"]
    ppl_str = "n/a" if ppl is None else f"{ppl:.2f}"
    print(f"{j}. zlib_ratio: {cand['zlib_ratio']:.4f}, min_window_ppl: {ppl_str}")
    print(f"Prompt: {cand['prompt']}")
    print(f"Generated (first 150 chars): {cand['generated_text'][:150]}...")
    print("-" * 60)

# --- Step D: Verification via Embedding-Based Similarity ---
# Compare each suspicious completion against the corpus with the domain-specific
# sentence-embedding model. Keep those with at least one article at or above
# the similarity threshold.
emb_model = SentenceTransformer(EMB_MODEL_NAME)
verified_memorized = []
for suspicious in top_suspicious:
    gen_text = suspicious["generated_text"]
    emb_matches = embedding_similarity(gen_text, corpus, emb_model, field="title_abstract")
    if emb_matches:
        suspicious["embedding_matches"] = emb_matches
        verified_memorized.append(suspicious)

print(f"[INFO] Verified memorized samples via embedding similarity: {len(verified_memorized)}")
for vm in verified_memorized:
    # min_window_ppl is None when the text was too short to score.
    ppl = vm["min_window_ppl"]
    ppl_str = "n/a" if ppl is None else f"{ppl:.2f}"
    print("=" * 60)
    print(f"zlib_ratio: {vm['zlib_ratio']:.4f}, min_window_ppl: {ppl_str}")
    print(f"Prompt: {vm['prompt']}")
    print(f"Generated Text: {vm['generated_text']}")
    print(f"Embedding matches in corpus indices: {vm['embedding_matches']}")

# Save final results: every scored completion plus the embedding-verified subset.
results = {
    "generations": generations,
    "verified_memorized": verified_memorized
}
with open(ATTACK_RESULTS_FILE, "w", encoding="utf-8") as rf:
    json.dump(results, rf, indent=2)
print(f"[INFO] Attack results saved to {ATTACK_RESULTS_FILE}")