In [1]:
"""
Black?Box Memorization Attack Using PubMedBERT on PubMed Abstracts

This script simulates an attack on a domain?specific model in a black?box setting.
We load PubMedBERT (a masked LM), generate completions via iterative fill?mask,
and then use fuzzy n?gram matching against a local PubMed abstracts corpus (papers.json)
to try to detect memorized (verbatim or near-verbatim) sequences.
"""

import os
import json
import random
import zlib
import re
from typing import List, Dict
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline


In [6]:
###############################################################################
# Configuration
###############################################################################
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
CORPUS_JSON_PATH = "../../Data/papersNew.json"            # your PubMed abstracts data file
OUTPUT_GENERATIONS = "pubmedbert_generations.json"
ATTACK_RESULTS = "attack_results_pubmedbert.json"

NUM_GENERATIONS = 2000   # Number of completions to generate
MASK_LENGTH = 5          # Number of [MASK] tokens to append and fill
TOP_K = 50               # For fill-mask pipeline
SUBSTRING_SEARCH_MAX = 2 # Maximum matching articles per candidate
FUZZY_N = 2              # Use bigrams
FUZZY_THRESHOLD = 0.3    # Fuzzy matching threshold

In [3]:
###############################################################################
# Helper Functions
###############################################################################
def preprocess_text(text: str) -> str:
    """Lowercase text, remove punctuation, and normalize whitespace."""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def load_pubmed_data(json_path: str) -> List[Dict]:
    """Load PubMed abstracts data from a JSON file."""
    if not os.path.exists(json_path):
        print(f"[ERROR] File not found: {json_path}")
        return []
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data

def zlib_ratio(txt: str) -> float:
    """Compute zlib compression ratio as: len(text) / len(compressed(text))."""
    if not txt.strip():
        return 0.0
    compressed = zlib.compress(txt.encode("utf-8"))
    return len(txt) / len(compressed)

def fuzzy_ngram_search(snippet: str, corpus: List[Dict], n: int = FUZZY_N, threshold: float = FUZZY_THRESHOLD, max_results: int = SUBSTRING_SEARCH_MAX) -> List[int]:
    """
    Compute n-gram overlap (Jaccard similarity) between the snippet and the combined text
    (title + abstract) from each article in the corpus after preprocessing.
    
    Returns indices of articles where the overlap similarity is at least the threshold.
    """
    snippet = preprocess_text(snippet)
    snippet_tokens = snippet.split()
    if len(snippet_tokens) < n:
        snippet_ngrams = set([tuple(snippet_tokens)])
    else:
        snippet_ngrams = set(zip(*[snippet_tokens[i:] for i in range(n)]))
    
    matches = []
    for i, article in enumerate(corpus):
        title = article.get("title", {}).get("full_text", "") or ""
        abstract = article.get("abstract", {}).get("full_text", "") or ""
        combined = preprocess_text(title + " " + abstract)
        combined_tokens = combined.split()
        if len(combined_tokens) < n:
            combined_ngrams = set([tuple(combined_tokens)])
        else:
            combined_ngrams = set(zip(*[combined_tokens[i:] for i in range(n)]))
        
        if not snippet_ngrams or not combined_ngrams:
            continue

        intersection = snippet_ngrams.intersection(combined_ngrams)
        union = snippet_ngrams.union(combined_ngrams)
        similarity = len(intersection) / len(union) if union else 0.0

        if similarity >= threshold:
            matches.append(i)
            if len(matches) >= max_results:
                break

    return matches

def iterative_fill_mask(prompt: str, num_masks: int, fill_mask_pipeline) -> str:
    """
    Iteratively fill the first [MASK] token in the prompt using the fill-mask pipeline,
    replacing it with a randomly selected prediction. This repeats for num_masks iterations.
    """
    sequence = prompt
    for _ in range(num_masks):
        if "[MASK]" not in sequence:
            break
        first_mask_index = sequence.find("[MASK]")
        # Split the sequence to ensure we only fill one mask at a time.
        split_once = sequence.split("[MASK]", 1)
        text_for_pipeline = split_once[0] + "[MASK]" + split_once[1]
        results = fill_mask_pipeline(text_for_pipeline)
        # results may be a list of lists if multiple masks are detected; we use the first list.
        if isinstance(results[0], list):
            candidates = results[0]
        else:
            candidates = results
        if not candidates:
            break
        chosen = random.choice(candidates)
        chosen_token = chosen["token_str"]
        # Replace only the first occurrence of [MASK]
        sequence = sequence[:first_mask_index] + chosen_token + sequence[first_mask_index + len("[MASK]"):]
    return sequence


In [4]:
# --- Step A: Load Model & Data ---
print(f"[INFO] Loading PubMedBERT model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

[INFO] Loading PubMedBERT model: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract


Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [7]:
# Create a fill-mask pipeline for generation
fill_mask_pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=TOP_K, device=0 if device=="cuda" else -1)

print(f"[INFO] Loading PubMed abstracts data from {CORPUS_JSON_PATH}")
corpus = load_pubmed_data(CORPUS_JSON_PATH)
print(f"[INFO] Loaded {len(corpus)} records from local PubMed abstracts data.")

Device set to use cuda:0


[INFO] Loading PubMed abstracts data from ../../Data/papersNew.json
[INFO] Loaded 157833 records from local PubMed abstracts data.


In [8]:
# --- Step B: Build Candidate Prompts ---
# We use title full_text and first 20 words of abstract as prompts.
prompts = []
for article in corpus:
    title_text = article.get("title", {}).get("full_text", "").strip()
    abstract_text = article.get("abstract", {}).get("full_text", "").strip()
    if title_text:
        prompts.append(title_text)
    if abstract_text:
        words = abstract_text.split()
        prompt_abstract = " ".join(words[:20]) if len(words) > 20 else abstract_text
        prompts.append(prompt_abstract)
if not prompts:
    prompts = ["Biomedical research shows", "In this study, we explore"]
print(f"[INFO] Built {len(prompts)} candidate prompts.")

[INFO] Built 305510 candidate prompts.


In [9]:
# --- Step C: Generate Completions ---
# For each generation, choose a random prompt, append a sequence of [MASK] tokens, and fill them iteratively.
generations = []
for i in range(NUM_GENERATIONS):
    prompt = random.choice(prompts)
    masked_prompt = prompt + " " + " ".join(["[MASK]"] * MASK_LENGTH)
    generated_text = iterative_fill_mask(masked_prompt, MASK_LENGTH, fill_mask_pipe)
    ratio = zlib_ratio(generated_text)
    generations.append({
        "prompt": prompt,
        "generated_text": generated_text,
        "zlib_ratio": ratio
    })
    if (i+1) % 100 == 0:
        print(f"[INFO] Generated {i+1} completions.")

with open(OUTPUT_GENERATIONS, "w", encoding="utf-8") as f:
    json.dump(generations, f, indent=2)
print(f"[INFO] Saved {len(generations)} completions to {OUTPUT_GENERATIONS}.")

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


[INFO] Generated 100 completions.
[INFO] Generated 200 completions.
[INFO] Generated 300 completions.
[INFO] Generated 400 completions.
[INFO] Generated 500 completions.
[INFO] Generated 600 completions.
[INFO] Generated 700 completions.
[INFO] Generated 800 completions.
[INFO] Generated 900 completions.
[INFO] Generated 1000 completions.
[INFO] Generated 1100 completions.
[INFO] Generated 1200 completions.
[INFO] Generated 1300 completions.
[INFO] Generated 1400 completions.
[INFO] Generated 1500 completions.
[INFO] Generated 1600 completions.
[INFO] Generated 1700 completions.
[INFO] Generated 1800 completions.
[INFO] Generated 1900 completions.
[INFO] Generated 2000 completions.
[INFO] Saved 2000 completions to pubmedbert_generations.json.


In [10]:
# --- Step D: Membership Inference Filtering ---
# Sort generations by zlib_ratio (higher might be suspicious) and pick the top 50.
generations.sort(key=lambda x: x["zlib_ratio"], reverse=True)
top_suspicious = generations[:50]
print("[INFO] Top 5 suspicious completions by zlib_ratio:")
for j, cand in enumerate(top_suspicious[:5], start=1):
    print(f"{j}. zlib_ratio: {cand['zlib_ratio']:.4f}")
    print(f"Prompt: {cand['prompt']}")
    print(f"Generated (first 150 chars): {cand['generated_text'][:150]}...")
    print("-" * 60)

[INFO] Top 5 suspicious completions by zlib_ratio:
1. zlib_ratio: 1.7500
Prompt: in this work, the interaction mechanisms between an autotrophic denitrification (ad) and heterotrophic denitrification (hd) process in a heterotrophic-autotrophic denitrification
Generated (first 150 chars): in this work, the interaction mechanisms between an autotrophic denitrification (ad) and heterotrophic denitrification (hd) process in a heterotrophic...
------------------------------------------------------------
2. zlib_ratio: 1.6357
Prompt: this study investigated age differences in appetitive and aversive associative learning using a pavlovian conditioning paradigm. appetitive and aversive associative
Generated (first 150 chars): this study investigated age differences in appetitive and aversive associative learning using a pavlovian conditioning paradigm. appetitive and aversi...
------------------------------------------------------------
3. zlib_ratio: 1.6260
Prompt: micrornas (mirnas) are smal

In [15]:
corpus[1538]['abstract']['full_text']

'multiple studies successfully applied multivariate analysis to neuroimaging data demonstrating the potential utility of neuroimaging for clinical diagnostic and prognostic purposes.'

In [16]:
def extract_generated_portion(prompt: str, generated: str) -> str:
    """
    If the generated text starts with the prompt, return only the portion
    after the prompt. Otherwise, return the full generated text.
    """
    if generated.startswith(prompt):
        return generated[len(prompt):].strip()
    return generated.strip()

In [17]:
# --- Step E: Verification via Fuzzy n-gram Matching ---
verified_memorized = []
for suspicious in top_suspicious:
    # Extract just the generated (filled-mask) part
    gen_portion = extract_generated_portion(suspicious["prompt"], suspicious["generated_text"])
    matches = fuzzy_ngram_search(gen_portion, corpus, n=FUZZY_N, threshold=FUZZY_THRESHOLD, max_results=SUBSTRING_SEARCH_MAX)
    if matches:
        suspicious["matches"] = matches
        verified_memorized.append(suspicious)

print(f"[INFO] Verified memorized samples (fuzzy matching): {len(verified_memorized)}")
for vm in verified_memorized:
    print("=" * 60)
    print(f"zlib_ratio: {vm['zlib_ratio']:.4f}")
    print(f"Prompt: {vm['prompt']}")
    print(f"Generated Text: {vm['generated_text']}")
    print(f"Found in corpus indices: {vm['matches']}")

results = {
    "generations": generations,
    "verified_memorized": verified_memorized
}
with open(ATTACK_RESULTS, "w", encoding="utf-8") as rf:
    json.dump(results, rf, indent=2)
print(f"[INFO] Attack results saved to {ATTACK_RESULTS}")

[INFO] Verified memorized samples (fuzzy matching): 0
[INFO] Attack results saved to attack_results_pubmedbert.json
