In [24]:
import json
import re
import time
import zlib
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime

import torch
from bs4 import BeautifulSoup
from fuzzywuzzy import fuzz
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForMaskedLM as BertModel
from tqdm import tqdm

In [25]:
# Configuration
# --- Hugging Face checkpoints ---
BIOGPT_MODEL = "microsoft/BioGPT"
PUBMEDBERT_MODEL = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# --- Input corpora and output file ---
PUBMED_DATA = "./Data/papersNew.json"
PMC_DATA = "./Data/pubmed_2010_2024_intelligence.json"
OUTPUT_FILE = "extraction_results.json"

# --- Run parameters ---
MAX_SAMPLES = 10000  # Upper bound on how many prefixes are turned into samples
BATCH_SIZE = 32  # NOTE(review): not referenced by any cell visible in this notebook

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [26]:
# BioGPT: causal language model, moved onto the configured device.
biogpt_tokenizer = AutoTokenizer.from_pretrained(BIOGPT_MODEL)
biogpt_model = AutoModelForCausalLM.from_pretrained(BIOGPT_MODEL)
biogpt_model = biogpt_model.to(DEVICE)

# PubMedBERT: loaded with a masked-LM head (``BertModel`` is the notebook's
# alias for ``AutoModelForMaskedLM``), moved onto the same device.
pubmedbert_tokenizer = AutoTokenizer.from_pretrained(PUBMEDBERT_MODEL)
pubmedbert_model = BertModel.from_pretrained(PUBMEDBERT_MODEL)
pubmedbert_model = pubmedbert_model.to(DEVICE)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [44]:
# Load training data
def load_training_data(pubmed_path=None, pmc_path=None):
    """Load the two corpora used for prefix generation and verification.

    Args:
        pubmed_path: Path to a file holding a single JSON document.
            Defaults to the notebook constant ``PUBMED_DATA``.
        pmc_path: Path to a JSON Lines file (one JSON value per non-blank
            line). Defaults to the notebook constant ``PMC_DATA``.

    Returns:
        tuple: ``(pubmed, pmc)``. ``pubmed`` is whatever the JSON document
        decodes to; ``pmc`` is a list with one decoded value per non-blank
        line, in file order.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If the document or any non-blank line is not
            valid JSON.
    """
    # Resolve the defaults at call time so the config cell can be edited and
    # re-run without having to redefine this function.
    if pubmed_path is None:
        pubmed_path = PUBMED_DATA
    if pmc_path is None:
        pmc_path = PMC_DATA

    with open(pubmed_path, "r", encoding="utf-8") as f:
        pubmed = json.load(f)

    with open(pmc_path, "r", encoding="utf-8") as f:
        # Blank / whitespace-only lines are skipped rather than parsed.
        pmc = [json.loads(line) for line in f if line.strip()]

    return pubmed, pmc

# Preprocess PMC data: turn each raw record into {"pmcid", "sections"}
def preprocess_pmc(pmc_data):
    """Reduce raw PMC records to their identifier and parsed sections.

    Each input record yields exactly one output dict:

    * ``"pmcid"`` is the record's ``"pmc_id"`` value, or ``"N/A"`` when the
      key is missing.
    * ``"sections"`` is ``parse_pmc_text(record["full_text"])`` when the
      record has a truthy ``"full_text"``, otherwise an empty dict.

    Args:
        pmc_data: Iterable of dict records.

    Returns:
        list: One dict per input record, in input order.
    """
    return [
        {
            "pmcid": record.get("pmc_id", "N/A"),
            "sections": (
                parse_pmc_text(record["full_text"])
                if record.get("full_text")
                else {}
            ),
        }
        for record in pmc_data
    ]

In [28]:
# Read both corpora from the paths configured in PUBMED_DATA and PMC_DATA.
pubmed_data, pmc_data = load_training_data()

In [31]:
# Filter out records with empty abstracts
def _has_abstract(record):
    """Return True when ``record`` carries a non-blank string abstract.

    A missing key, a ``None`` value (JSON ``null``) or any non-string value
    counts as "no abstract" instead of raising ``AttributeError``.
    """
    abstract = record.get("abstract")
    return isinstance(abstract, str) and bool(abstract.strip())


# Re-running this cell is safe: filtering an already-filtered list is a no-op.
pmc_data = [record for record in pmc_data if _has_abstract(record)]

print(f"Total records after filtering (non-empty abstracts): {len(pmc_data)}")


Total records after filtering (non-empty abstracts): 51568


In [38]:
# Preprocess PMC full text
def parse_pmc_text(text):
    """Split a PMC full-text XML string into ``{section title: section text}``.

    Every ``<sec>`` element is visited. Its title is taken from a direct
    child ``<title>`` (``"No Title"`` when there is none) and its text is
    the whitespace-normalised content of all ``<p>`` descendants joined by
    single spaces.

    Args:
        text: XML markup of one article.

    Returns:
        dict: Section title -> section text. When several sections share a
        title (including several untitled ones) their texts are
        concatenated in document order instead of the last one silently
        replacing the others.

    Note:
        ``find_all("p")`` is recursive, so a parent section's text also
        contains the paragraphs of its nested sections.
    """
    soup = BeautifulSoup(text, "xml")
    sections = {}
    for sec in soup.find_all("sec"):
        # Only a direct child names this section; a deeper <title> belongs
        # to a nested <sec> and must not be borrowed by the parent.
        title_tag = sec.find("title", recursive=False)
        if title_tag is not None:
            # split()/join keeps the spacing around inline markup such as
            # <italic>, which get_text(strip=True) would glue together.
            title = " ".join(title_tag.get_text().split())
        else:
            title = "No Title"

        paragraphs = (" ".join(p.get_text().split()) for p in sec.find_all("p"))
        content = " ".join(part for part in paragraphs if part)

        previous = sections.get(title)
        if previous and content:
            sections[title] = f"{previous} {content}"
        else:
            sections[title] = previous or content
    return sections

# Generate prefixes from training data
def generate_prefixes(data, max_length=10):
    """Build de-duplicated sentence-start prompts from record abstracts.

    Each abstract is split into sentences after ``.``, ``!`` or ``?``; the
    first ``max_length`` whitespace-separated tokens of every sentence,
    followed by ``"..."``, form one prefix.

    Args:
        data: Iterable of records. Records without an ``"abstract"`` key,
            or whose abstract is not a string, are skipped.
        max_length: Maximum number of whitespace tokens kept per sentence.

    Returns:
        list: Unique prefixes in first-seen order. The order is
        deterministic for a given ``data`` order, so slicing the result
        (e.g. ``prefixes[:MAX_SAMPLES]``) selects the same prompts on every
        run; a plain ``set`` would not guarantee that.
    """
    prefixes = {}  # dict used as an insertion-ordered set
    for entry in data:
        if "abstract" not in entry:
            continue
        text = entry["abstract"]
        if not isinstance(text, str):
            # e.g. JSON null; re.split would raise TypeError on it
            continue
        for sent in re.split(r'(?<=[.!?])\s+', text):
            tokens = sent.split()[:max_length]
            if tokens:  # an empty sentence would yield the useless prefix "..."
                prefixes.setdefault(" ".join(tokens) + "...", None)

    return list(prefixes)

In [33]:
# Sentence-start prompts from the PubMed records, using the default 10-token cut-off.
pubmed_prefixes = generate_prefixes(pubmed_data)


In [39]:
# Sentence-start prompts from the PMC records, keeping up to 15 whitespace tokens per sentence.
pmc_prefixes = generate_prefixes(pmc_data, max_length=15)

In [42]:
# Text generation strategies
def _temperature_decay_processor(prompt_length, start, decay):
    """Build a logits processor whose temperature shrinks per generated token.

    The temperature used for the k-th generated token (k = 0, 1, ...) is
    ``max(1.0, start * decay ** k)``.
    """
    # Local import keeps this cell runnable without editing the import cell.
    from transformers import LogitsProcessor

    class _TemperatureDecay(LogitsProcessor):
        def __call__(self, input_ids, scores):
            # Tokens generated so far = current length minus prompt length.
            step = max(input_ids.shape[-1] - prompt_length, 0)
            temperature = max(1.0, start * decay ** step)
            return scores / temperature

    return _TemperatureDecay()


def generate_with_biogpt(prefix, strategy="top_n", max_length=256):
    """Sample one continuation of ``prefix`` from BioGPT.

    Args:
        prefix: Prompt text.
        strategy: ``"top_n"`` samples with ``top_k=40``.
            ``"temperature_decay"`` samples with a temperature that starts
            at 10.0 and is halved after every generated token until it
            reaches 1.0.
        max_length: Maximum total length (prompt plus continuation) in
            tokens, as passed to ``generate``.

    Returns:
        str: The decoded sequence with special tokens removed.

    Raises:
        ValueError: If ``strategy`` is not one of the two names above.
    """
    inputs = biogpt_tokenizer(prefix, return_tensors="pt").to(DEVICE)

    if strategy == "top_n":
        sampling_kwargs = {"top_k": 40}
    elif strategy == "temperature_decay":
        # `temperature_decay` is not a generate() argument, so passing it
        # never produced a decaying temperature. The schedule is implemented
        # with a logits processor instead.
        # NOTE(review): the geometric schedule (10.0, factor 0.5, floor 1.0)
        # is inferred from the original keyword values - confirm it is the
        # intended one.
        from transformers import LogitsProcessorList

        prompt_length = inputs["input_ids"].shape[-1]
        sampling_kwargs = {
            "logits_processor": LogitsProcessorList(
                [_temperature_decay_processor(prompt_length, 10.0, 0.5)]
            )
        }
    else:
        raise ValueError("Unknown strategy")

    outputs = biogpt_model.generate(
        **inputs,
        max_length=max_length,
        do_sample=True,
        num_return_sequences=1,
        **sampling_kwargs,
    )

    return biogpt_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Membership inference metrics
def calculate_perplexity(text, model, tokenizer):
    """Return ``exp(loss)`` of ``model`` on ``text``.

    The text is tokenized, fed to the model with the input ids as labels,
    and the exponential of the returned loss is reported.

    Args:
        text: Text to score.
        model: Model that accepts ``labels`` and returns an object with a
            ``loss`` attribute.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        float: ``exp(loss)``.

    Note:
        NOTE(review): for a masked-LM model the labels are the unmasked
        inputs themselves, so the value is presumably not a true perplexity
        and only comparable with other values computed the same way -
        confirm this is the intended reference score.
    """
    # Truncate to the model's positional limit when it declares one; longer
    # inputs would otherwise fail inside the position embeddings.
    max_positions = getattr(
        getattr(model, "config", None), "max_position_embeddings", None
    )
    # Inputs must live where the model lives; fall back to the notebook-wide
    # DEVICE only for objects that do not expose `.device`.
    device = model.device if hasattr(model, "device") else DEVICE

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=max_positions is not None,
        max_length=max_positions,
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"])
    return torch.exp(outputs.loss).item()

def calculate_zlib_entropy(text):
    """Return the zlib-compressed size of ``text`` per input character.

    Args:
        text: Text to measure.

    Returns:
        float: ``len(zlib.compress(utf8_bytes)) / len(text)``; ``0.0`` for
        empty text, where the ratio is undefined and the division would
        raise ``ZeroDivisionError``.

    Note:
        The numerator counts bytes and the denominator counts characters,
        so the two only share a unit for pure-ASCII text.
    """
    if not text:
        return 0.0
    compressed = zlib.compress(text.encode())
    return len(compressed) / len(text)

def sliding_window_perplexity(text, window_size=50, stride=1, score_fn=None):
    """Return the lowest score over character windows of ``text``.

    Windows are ``window_size`` characters long and start every ``stride``
    characters, from position 0 up to and including the last position at
    which a full window still fits. Text no longer than one window is
    scored as a single window.

    Args:
        text: Text to scan.
        window_size: Window length in characters (>= 1).
        stride: Distance between consecutive window starts (>= 1). The
            default of 1 scores every possible window, which costs one
            model call per character; a larger stride trades coverage for
            speed.
        score_fn: Callable mapping a window string to a number. Defaults to
            BioGPT perplexity via ``calculate_perplexity``.

    Returns:
        float: The minimum score, or ``inf`` for empty text.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is smaller than 1.
    """
    if window_size < 1 or stride < 1:
        raise ValueError("window_size and stride must be >= 1")

    if score_fn is None:
        def score_fn(window):
            return calculate_perplexity(window, biogpt_model, biogpt_tokenizer)

    min_ppl = float("inf")
    if not text:
        return min_ppl

    # Last index at which a full window fits; 0 when the text is shorter
    # than one window so that the whole text is still scored once.
    last_start = max(len(text) - window_size, 0)
    for start in range(0, last_start + 1, stride):
        ppl = score_fn(text[start:start + window_size])
        if ppl < min_ppl:
            min_ppl = ppl
    return min_ppl

# Verification functions
def _candidate_texts(entry):
    """Yield every non-empty text of a corpus record a sample can be compared to.

    Accepts the record shapes that occur in this notebook: an ``"abstract"``
    that is either a string or a dict with a ``"full_text"`` key, and/or a
    ``"sections"`` dict of strings as produced by ``preprocess_pmc``.
    Records with neither simply yield nothing.
    """
    if not isinstance(entry, dict):
        return
    abstract = entry.get("abstract")
    if isinstance(abstract, dict):
        abstract = abstract.get("full_text")
    if isinstance(abstract, str) and abstract:
        yield abstract
    sections = entry.get("sections")
    if isinstance(sections, dict):
        for section_text in sections.values():
            if isinstance(section_text, str) and section_text:
                yield section_text


def exact_match_check(sample, pubmed=None, pmc=None):
    """Return True if ``sample`` occurs verbatim in any corpus text.

    Args:
        sample: Generated text to look for.
        pubmed: PubMed records; defaults to the global ``pubmed_data``.
        pmc: PMC records; defaults to the global ``pmc_data``.

    Returns:
        bool: ``False`` for a blank sample (an empty string is a substring
        of every text and would otherwise always "match").
    """
    if not sample or not sample.strip():
        return False
    corpora = (
        pubmed_data if pubmed is None else pubmed,
        pmc_data if pmc is None else pmc,
    )
    return any(
        sample in text
        for corpus in corpora
        for entry in corpus
        for text in _candidate_texts(entry)
    )


def fuzzy_match_check(sample, threshold=90, pubmed=None, pmc=None, scorer=None):
    """Return the ids of records whose text scores at least ``threshold``.

    Args:
        sample: Generated text to compare.
        threshold: Minimum score for a record to count as a match.
        pubmed: PubMed records; defaults to the global ``pubmed_data``.
        pmc: PMC records; defaults to the global ``pmc_data``.
        scorer: Callable ``(sample, text) -> number``. Defaults to
            ``fuzz.token_set_ratio``.

    Returns:
        list: ``"pmid"`` values of matching PubMed records followed by the
        ``"pmcid"`` (or, for raw records, ``"pmc_id"``) values of matching
        PMC records. Each record is listed at most once, even when several
        of its texts match; a matching record without an id contributes
        ``None``.
    """
    if not sample or not sample.strip():
        return []
    if scorer is None:
        scorer = fuzz.token_set_ratio

    def _matches(entry):
        return any(scorer(sample, text) >= threshold for text in _candidate_texts(entry))

    matches = []
    for entry in (pubmed_data if pubmed is None else pubmed):
        if _matches(entry):
            matches.append(entry.get("pmid"))
    for entry in (pmc_data if pmc is None else pmc):
        if _matches(entry):
            matches.append(entry.get("pmcid") or entry.get("pmc_id"))
    return matches

# Anonymization
# PHI patterns, applied in this order. Compiled once instead of on every call.
_PHI_PATTERNS = {
    # Digit guards stop a match from starting or ending inside a longer number.
    "phone": re.compile(r"(?<!\d)\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}(?!\d)"),
    "email": re.compile(r"[\w.-]+@[\w.-]+"),
    "ssn": re.compile(r"(?<!\d)\d{3}-\d{2}-\d{4}(?!\d)"),
    # \b after the street suffix prevents matching the first letters of an
    # ordinary word ("3 novel Streptococcus" is not an address).
    "address": re.compile(r"\b\d+\s+[A-Za-z]+\s+(?:Ave|St|Dr|Ln|Blvd)\b\.?"),
    # NOTE(review): matches ANY two consecutive capitalised words, not only
    # personal names, so it over-redacts; kept as in the original heuristic.
    "name": re.compile(r"[A-Z][a-z]+ [A-Z][a-z]+"),
}


def anonymize(text):
    """Replace PHI-looking substrings with ``[<KIND> REDACTED]`` markers.

    Phone numbers, e-mail addresses, SSN-formatted numbers, simple street
    addresses and capitalised word pairs are replaced, in that order.

    Args:
        text: Text to scrub.

    Returns:
        str: ``text`` with every match replaced by ``[PHONE REDACTED]``,
        ``[EMAIL REDACTED]``, ``[SSN REDACTED]``, ``[ADDRESS REDACTED]`` or
        ``[NAME REDACTED]``.
    """
    for key, pattern in _PHI_PATTERNS.items():
        text = pattern.sub(f"[{key.upper()} REDACTED]", text)
    return text

In [43]:
def _append_jsonl(records, path):
    """Append ``records`` to ``path``, one JSON object per line."""
    with open(path, "a", encoding="utf-8") as f:
        for record in records:
            json.dump(record, f)
            f.write("\n")


SAVE_EVERY = 100  # number of finished samples buffered between writes

results = []
prefixes = pubmed_prefixes + pmc_prefixes

# Samples are generated one at a time: every worker of a thread pool would
# share the same model, and leaving a `with ThreadPoolExecutor()` block waits
# for every queued job, so an interrupted run could not stop promptly.
try:
    for prefix in tqdm(prefixes[:MAX_SAMPLES]):
        raw_sample = generate_with_biogpt(prefix, strategy="top_n", max_length=256)

        # Metrics and verification must see exactly what the model produced;
        # redaction markers would distort the scores and make a verbatim
        # match against the corpora impossible.
        bio_ppl = calculate_perplexity(raw_sample, biogpt_model, biogpt_tokenizer)
        pubmed_ppl = calculate_perplexity(raw_sample, pubmedbert_model, pubmedbert_tokenizer)
        ratio = bio_ppl / pubmed_ppl
        entropy = calculate_zlib_entropy(raw_sample)
        window_ppl = sliding_window_perplexity(raw_sample)

        # Verification
        exact_match = exact_match_check(raw_sample)
        fuzzy_matches = fuzzy_match_check(raw_sample)

        results.append({
            "timestamp": datetime.now().isoformat(),
            # Only the stored text is anonymized.
            "sample": anonymize(raw_sample),
            "metrics": {
                "perplexity_ratio": ratio,
                "zlib_entropy": entropy,
                "min_window_ppl": window_ppl
            },
            "verification": {
                "exact_match": exact_match,
                "fuzzy_matches": fuzzy_matches
            }
        })

        # Save periodically
        if len(results) >= SAVE_EVERY:
            _append_jsonl(results, OUTPUT_FILE)
            results = []
finally:
    # Write the trailing partial batch - also when the run is interrupted or
    # fails - so finished samples are never lost.
    if results:
        _append_jsonl(results, OUTPUT_FILE)
        results = []

  0%|          | 0/10000 [01:40<?, ?it/s]


KeyboardInterrupt: 