In [1]:
"""
Full Pipeline: PubMedBERT Data Extraction Example

This script demonstrates:
  1. Loading PubMedBERT (a masked language model).
  2. Loading a local subset of PubMed data (papers.json).
  3. Generating candidate text with repeated fill-mask tokens.
  4. Applying a naive membership-inference-style filter
     using zlib compression ratio as a proxy for "unusually confident" text.
  5. Searching the local PubMed data to see if the text is indeed memorized
     (verbatim match).

DISCLAIMER:
  - PubMedBERT is not an auto-regressive model. Generating free-form
    text is tricky. We do a repeated fill-mask approach for demonstration.
  - The membership inference here is simplified. Real approaches might
    compare perplexities from multiple models or do more advanced metrics.
  - The substring search is naive and may need optimization or fuzzy matching.
  - This code is a proof-of-concept. Modify and expand to suit your needs.
"""

import os
import json
import random
import zlib
import torch
from typing import List, Dict
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    pipeline
)

In [2]:
###############################################################################
# STEP 1: Load Model & Data
###############################################################################

MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
PUBMED_JSON_PATH = "papersOld.json"

def load_pubmed_data(json_path: str) -> List[Dict]:
    """
    Loads local PubMed data from a JSON file and returns a list of paper dicts.
    """
    if not os.path.exists(json_path):
        print(f"[ERROR] {json_path} not found.")
        return []
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data


In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
print("=== Loading local papers.json data ===")
papers_data = load_pubmed_data(PUBMED_JSON_PATH)
print(f"[INFO] Loaded {len(papers_data)} records from {PUBMED_JSON_PATH}.")
# Optional: Some simple stats
abstract_count = sum(1 for p in papers_data if p.get("abstract", {}).get("full_text"))
print(f"[INFO] Found {abstract_count} records with non-empty abstracts.")

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


=== Loading local papers.json data ===
[INFO] Loaded 2112 records from papersOld.json.
[INFO] Found 1934 records with non-empty abstracts.


In [4]:
abstract_texts = []
for paper in papers_data:
    txt = paper.get("abstract", {}).get("full_text", "").strip()
    if txt:
        abstract_texts.append(txt)
print(f"[INFO] Collated {len(abstract_texts)} candidate abstracts for prompts.")

[INFO] Collated 1934 candidate abstracts for prompts.


In [13]:
############################################################################
# STEP 2: Generate Candidate Text
############################################################################
print("\n=== Step 2: Generating Candidate Text (Fill-Mask) ===")
# We'll create a fill-mask pipeline. With top_k=50 for diversity.
fill_mask_pipe = pipeline(
    "fill-mask",
    model=model,
    tokenizer=tokenizer,
    top_k=50,
    device=0 if torch.cuda.is_available() else -1  # use GPU if available
)
# Let’s define how many sequences to generate
NUM_GENERATIONS = 1000  # Increase for a real experiment
MASK_LENGTH = 5        # Number of mask tokens to fill per generation
MAX_PROMPT_WORDS = 8   # We'll slice up to 8 words from a random abstract
generated_samples = []
def repeated_fill_mask(prompt_text: str, num_masks: int, fill_mask_pipeline) -> str:
    """
    Iteratively fill 'num_masks' occurrences of [MASK] in 'prompt_text' using a
    fill-mask pipeline from Hugging Face. One mask is filled at a time, to avoid
    nested-list outputs.

    Args:
        prompt_text (str): The initial string (which may contain multiple [MASK] tokens).
        num_masks (int): How many [MASK] tokens we want to fill in total.
        fill_mask_pipeline: A fill-mask pipeline (e.g., from transformers.pipeline("fill-mask")).

    Returns:
        str: The final string after attempting to fill up to num_masks [MASK] tokens.
    """
    sequence = prompt_text

    for _ in range(num_masks):
        # 1) Find the first [MASK] in the text
        first_mask_index = sequence.find("[MASK]")
        if first_mask_index == -1:
            # No more [MASK] found
            break

        # 2) We feed the entire string (with multiple [MASK] tokens) to the pipeline.
        #    By design, if it sees multiple masks, it returns a list of lists:
        #    e.g. results[0] are the top-k fills for the first mask, results[1] for the second, etc.
        results = fill_mask_pipeline(sequence)

        # If it's empty or None, just break
        if not results:
            break

        # Check if we got a nested list-of-lists (the multiple-mask scenario)
        if isinstance(results[0], list):
            # This means the pipeline gave us something like [ [dict, dict, ...], [dict, ...], ... ]
            # focusing on results[0] => top-k fills for the first mask
            first_mask_candidates = results[0]
        else:
            # Only one mask in the input => pipeline returned a single list of dict
            first_mask_candidates = results

        if not first_mask_candidates:
            break

        # 3) Choose a random fill from top_k for the FIRST mask
        chosen_fill = random.choice(first_mask_candidates)
        chosen_token_str = chosen_fill["token_str"]  # The actual text that replaces [MASK]

        # 4) Replace ONLY the first [MASK] with chosen_token_str in the string
        #    We do a direct textual replacement for that single occurrence
        sequence = (
            sequence[:first_mask_index]
            + chosen_token_str
            + sequence[first_mask_index + len("[MASK]") :]
        )

    return sequence



Device set to use cpu



=== Step 2: Generating Candidate Text (Fill-Mask) ===


In [14]:
for _ in range(NUM_GENERATIONS):
    # 1) Pick a random short slice from a random abstract
    if abstract_texts:
        chosen_abs = random.choice(abstract_texts)
        words = chosen_abs.split()
        if len(words) >= 2:
            start_idx = random.randint(0, max(len(words) - 1, 1))
            end_idx = min(len(words), start_idx + random.randint(2, MAX_PROMPT_WORDS))
            prompt_slice = " ".join(words[start_idx:end_idx])
        else:
            prompt_slice = chosen_abs
    else:
        prompt_slice = "[CLS]"  # fallback if no data
    # 2) Add [MASK] tokens
    masked_prompt = prompt_slice + " " + " ".join(["[MASK]"] * MASK_LENGTH)
    # 3) Fill them one at a time
    generated_text = repeated_fill_mask(masked_prompt, MASK_LENGTH, fill_mask_pipe)
    # Store the results
    generated_samples.append({
        "original_prompt": prompt_slice,
        "masked_prompt": masked_prompt,
        "completed_sequence": generated_text
    })
# Save the raw generations
generated_file = "generated_pubmedbert_samples.json"
with open(generated_file, "w", encoding="utf-8") as f:
    json.dump(generated_samples, f, indent=2)
print(f"[INFO] Generated {len(generated_samples)} samples. Written to {generated_file}.")


[INFO] Generated 1000 samples. Written to generated_pubmedbert_samples.json.


In [15]:
############################################################################
# STEP 3: Naive Membership Inference with zlib
############################################################################
print("\n=== Step 3: Membership Inference (Naive zlib approach) ===")
# The logic: If PubMedBERT is "unusually confident", it might produce text
# that compresses poorly but is presumably assigned high probability by the model.
#
# We'll define a simple "zlib_score" = len(string) / compressed_len(string),
# meaning "how well does the string compress?" If the string is very random,
# it might compress less. We suspect "memorized" text might be somewhat "structured"
# or repeated. This is a simplistic proxy.
def zlib_ratio(txt: str) -> float:
    if not txt:
        return 0.0
    compressed = zlib.compress(txt.encode("utf-8"))
    ratio = len(txt) / len(compressed)
    return ratio
# We'll compute a ratio for each completed sequence, then pick the top 10 with
# the *lowest* ratio as suspicious. 
# Actually, Carlini's approach often wants "lowest perplexity" => "lowest ratio"
# or "lowest compressed length." It's quite heuristic. 
# We'll just demonstrate one approach.

extended_samples = []
for samp in generated_samples:
    comp_seq = samp["completed_sequence"]
    ratio = zlib_ratio(comp_seq)
    samp["zlib_ratio"] = ratio
    extended_samples.append(samp)
# Sort by ratio ascending (lowest ratio => more suspicious in this naive approach)
extended_samples.sort(key=lambda x: x["zlib_ratio"])
suspicious_samples = extended_samples[:10]  # top-10 suspicious
print("[INFO] 10 Most 'Suspicious' Samples by zlib ratio:")
for i, s in enumerate(suspicious_samples, start=1):
    print(f"{i}. ratio={s['zlib_ratio']:.4f} | {s['completed_sequence']}")


=== Step 3: Membership Inference (Naive zlib approach) ===
[INFO] 10 Most 'Suspicious' Samples by zlib ratio:
1. ratio=0.7037 | have a with i < 7 ?
2. ratio=0.7143 | with a from ; & an &
3. ratio=0.7143 | ex ante. 6 : 62 fr -
4. ratio=0.7143 | two types ? 5 ) : no
5. ratio=0.7222 | 1-8 years. • test no > for
6. ratio=0.7241 | readiness. / ~ r * gr
7. ratio=0.7241 | cancel mri < i $ 16 0
8. ratio=0.7333 | component of = g 1 0 g
9. ratio=0.7333 | to confront . : 2 in '
10. ratio=0.7419 | concerning f2 , 3 ? 4 t


In [16]:
############################################################################
# STEP 4: Verify Memorization via Substring Search in Local Data
############################################################################
print("\n=== Step 4: Searching for Verbatim Matches in Local PubMed Data ===")
def is_substring_in_paper(text: str, paper: Dict) -> bool:
    """
    Check if 'text' is in the paper's title or abstract or references, ignoring case.
    """
    title = paper.get("title", {}).get("full_text", "")
    abstract = paper.get("abstract", {}).get("full_text", "")
    # You might also search references, etc.
    combined = (title + " " + abstract).lower()
    return text.lower() in combined
def find_matches_in_corpus(text: str, corpus: List[Dict], max_results=2):
    """
    Return a list of up to `max_results` paper indices where `text` is found as substring.
    """
    matches = []
    for idx, p in enumerate(corpus):
        if is_substring_in_paper(text, p):
            matches.append(idx)
            if len(matches) >= max_results:
                break
    return matches
verified_memorized = []
for suspicious in suspicious_samples:
    gen_txt = suspicious["completed_sequence"]
    # Try to find this snippet in the local data
    matches = find_matches_in_corpus(gen_txt, papers_data, max_results=2)
    suspicious["corpus_matches"] = matches
    if matches:
        verified_memorized.append(suspicious)
print("[INFO] Verified Memorized Samples (Found in local corpus):")
if not verified_memorized:
    print(" None found. Possibly no direct matches in your subset, or you need more data.")
else:
    for vm in verified_memorized:
        print(f" * ratio={vm['zlib_ratio']:.4f} => found in paper indices {vm['corpus_matches']}")
        print(f"   Prompt: {vm['original_prompt']}")
        print(f"   Completed: {vm['completed_sequence']}")
        print("------------------------------------------------")


=== Step 4: Searching for Verbatim Matches in Local PubMed Data ===
[INFO] Verified Memorized Samples (Found in local corpus):
 None found. Possibly no direct matches in your subset, or you need more data.


In [17]:
############################################################################
# STEP 5: Summarize / Output
############################################################################
results_file = "membership_inference_results.json"
with open(results_file, "w", encoding="utf-8") as f:
    json.dump(extended_samples, f, indent=2)
print(f"[INFO] Full results saved to {results_file}. Done.")

[INFO] Full results saved to membership_inference_results.json. Done.
