# In [None]:
import pandas as pd
import psycopg2
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import base64
import json
import logging
import os
from concurrent.futures import ProcessPoolExecutor
from functools import partial

# Configure logging: every record goes both to ehr_pipeline.log and to the
# console, at INFO level and above.
_log_handlers = [
    logging.FileHandler('ehr_pipeline.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Initialize Presidio once at import time: `analyzer` detects PHI spans and
# `anonymizer` is the module-level anonymization engine.
analyzer, anonymizer = AnalyzerEngine(), AnonymizerEngine()

# Database configuration. Each setting can be overridden with an environment
# variable so real credentials never have to live in source control; the
# literals below are local-development fallbacks only.
DB_CONFIG = {
    'dbname': os.environ.get('EHR_DB_NAME', 'ehr_db'),
    'user': os.environ.get('EHR_DB_USER', 'secure_user'),
    'password': os.environ.get('EHR_DB_PASSWORD', 'secure_password'),
    'host': os.environ.get('EHR_DB_HOST', 'localhost'),
    'port': os.environ.get('EHR_DB_PORT', '5432'),
}


def _load_encryption_key(env_var='EHR_ENCRYPTION_KEY'):
    """Return the 32-byte AES-256 key used to encrypt PHI mappings.

    The key is read base64-encoded from ``env_var`` (in production, populate
    it from AWS KMS or a similar secret store). If the variable is unset, a
    random key is generated for this process only; rows encrypted with it
    cannot be decrypted by any later process.

    Raises:
        ValueError: if the variable is set but is not valid base64 or does not
            decode to exactly 32 bytes.
    """
    encoded = os.environ.get(env_var)
    if not encoded:
        logging.getLogger(__name__).warning(
            "%s is not set; using an ephemeral encryption key. Stored PHI "
            "mappings will not be decryptable after this process exits.",
            env_var,
        )
        return os.urandom(32)
    # validate=True rejects non-base64 characters instead of silently dropping them.
    key = base64.b64decode(encoded, validate=True)
    if len(key) != 32:
        raise ValueError(
            f"{env_var} must decode to 32 bytes for AES-256, got {len(key)}"
        )
    return key


ENCRYPTION_KEY = _load_encryption_key()  # AES-256 key

# Retained for backward compatibility only. An AES-GCM nonce must never be
# reused with the same key, so every encryption should use a fresh random
# nonce rather than this shared value.
NONCE = os.urandom(12)  # GCM nonce

# Mock language model API (replace with actual model API)
def mock_language_model(anonymized_text):
    """Stand in for a real language-model call.

    Logs the call and returns a fixed diagnosis sentence that embeds the
    (already anonymized) input text unchanged, so any PHI tokens in the input
    also appear in the returned string.
    """
    logger.info("Processing text with language model")
    response = f"Diagnosis: High risk of condition X for {anonymized_text}"
    return response

# Database setup
def setup_database():
    """Connect to PostgreSQL and make sure the phi_mapping table exists.

    Returns:
        (connection, cursor): the caller owns both and is responsible for
        closing them.

    Raises:
        Any error from connecting or running the DDL, after logging it. A
        connection opened before the failure is closed first so it does not
        leak.
    """
    conn = None
    try:
        conn = psycopg2.connect(**DB_CONFIG)
        cursor = conn.cursor()
        # Tokens have the form "<ENTITY>_<record_id>_<NNN>", so they are always
        # longer than the record_id itself; TEXT avoids the truncation errors a
        # VARCHAR(50) column raises for long IDs. (CREATE TABLE IF NOT EXISTS
        # leaves an already-existing table's schema untouched.)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS phi_mapping (
                id SERIAL PRIMARY KEY,
                record_id TEXT NOT NULL,
                token TEXT NOT NULL,
                encrypted_phi TEXT NOT NULL
            )
        """)
        # Mappings are looked up per record_id during deanonymization.
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS phi_mapping_record_id_idx "
            "ON phi_mapping (record_id)"
        )
        conn.commit()
        logger.info("Database table created or verified")
        return conn, cursor
    except Exception as e:
        logger.error(f"Database setup failed: {e}")
        if conn is not None:
            conn.close()
        raise

# Encrypt PHI
def encrypt_phi(phi, key, nonce=None):
    """Encrypt a PHI string with AES-GCM and return it as base64 text.

    Args:
        phi: plain-text value to protect.
        key: AES key bytes (the pipeline uses a 32-byte AES-256 key).
        nonce: 12-byte GCM nonce. It must never be reused with the same key:
            nonce reuse leaks the XOR of plaintexts and enables tag forgery.
            When omitted, a fresh random nonce is generated, which is the
            recommended way to call this function.

    Returns:
        base64(nonce || ciphertext || tag), the layout decrypt_phi expects.

    Raises:
        ValueError: if ``nonce`` is given but is not 12 bytes long, because
            decrypt_phi always reads exactly 12 nonce bytes back.
    """
    if nonce is None:
        nonce = get_random_bytes(12)
    elif len(nonce) != 12:
        raise ValueError(f"GCM nonce must be 12 bytes, got {len(nonce)}")
    cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
    ciphertext, tag = cipher.encrypt_and_digest(phi.encode('utf-8'))
    return base64.b64encode(nonce + ciphertext + tag).decode('utf-8')

# Decrypt PHI
def decrypt_phi(encrypted_phi, key):
    """Authenticate and decrypt a base64 AES-GCM blob produced by encrypt_phi.

    The decoded blob is laid out as nonce (first 12 bytes) || ciphertext ||
    tag (last 16 bytes).

    Args:
        encrypted_phi: base64 text as stored in phi_mapping.encrypted_phi.
        key: the AES key that was used for encryption.

    Returns:
        The decrypted PHI as a UTF-8 string.

    Raises:
        Whatever base64 decoding, GCM verification or UTF-8 decoding raises
        (for example on a wrong key or tampered data), after logging it.
    """
    try:
        blob = base64.b64decode(encrypted_phi)
        nonce, tag = blob[:12], blob[-16:]
        ciphertext = blob[12:-16]
        plaintext = AES.new(key, AES.MODE_GCM, nonce=nonce).decrypt_and_verify(
            ciphertext, tag
        )
        return plaintext.decode('utf-8')
    except Exception as e:
        logger.error(f"Decryption failed: {e}")
        raise

def _select_non_overlapping(results):
    """Drop overlapping analyzer results and return the rest ordered by start.

    Where spans overlap, the higher-scoring one (then the longer one) is kept,
    so every character of the note is covered by at most one PHI span.
    """
    ranked = sorted(results, key=lambda r: (-r.score, -(r.end - r.start), r.start))
    kept = []
    for candidate in ranked:
        if all(candidate.end <= k.start or candidate.start >= k.end for k in kept):
            kept.append(candidate)
    return sorted(kept, key=lambda r: r.start)


# Anonymize single record
def anonymize_record(record, record_id, cursor):
    """Replace PHI in a record's clinical note with per-value tokens.

    Each distinct PHI value gets its own token "<ENTITY>_<record_id>_<NNN>".
    The token and the AES-GCM-encrypted original are inserted into
    phi_mapping through ``cursor``. The inserts are not committed here; the
    caller commits.

    Args:
        record: row-like object with a 'clinical_notes' entry (e.g. a pandas
            Series from DataFrame.iterrows).
        record_id: identifier embedded in tokens and stored with each mapping.
        cursor: DB-API cursor used for the INSERTs.

    Returns:
        (anonymized_text, mapping), where mapping is {original_value: token}.
        Returns (None, None) if the note is missing, empty or not a string,
        or if any step fails (the failure is logged).
    """
    try:
        text = record.get('clinical_notes', '')
        # pandas reads empty CSV cells as NaN, a truthy float, so test the type too.
        if not isinstance(text, str) or not text:
            return None, None

        # Detect PHI.
        # NOTE(review): MEDICAL_RECORD is not, as far as I know, a built-in
        # Presidio entity; it only matches if a custom recognizer is registered.
        results = analyzer.analyze(
            text=text,
            entities=["PERSON", "DATE_TIME", "PHONE_NUMBER", "MEDICAL_RECORD"],
            language="en",
        )
        spans = _select_non_overlapping(results)

        # Build one token per distinct PHI value.
        mapping = {}
        rows = []
        for span in spans:
            original = text[span.start:span.end]
            if original in mapping:
                continue  # the same value repeated in the note reuses its token
            token = f"{span.entity_type}_{record_id}_{len(mapping) + 1:03d}"
            mapping[original] = token
            # A fresh nonce per encryption: reusing a GCM nonce under one key
            # breaks both confidentiality and authenticity.
            encrypted_phi = encrypt_phi(original, ENCRYPTION_KEY, get_random_bytes(12))
            rows.append((record_id, token, encrypted_phi))

        # Splice the tokens in. The spans are sorted and non-overlapping, so
        # every occurrence gets its own value's token.
        pieces = []
        position = 0
        for span in spans:
            pieces.append(text[position:span.start])
            pieces.append(mapping[text[span.start:span.end]])
            position = span.end
        pieces.append(text[position:])
        anonymized_text = "".join(pieces)

        # Persist only after the whole record has been processed, so a failure
        # above leaves no orphan mapping rows.
        if rows:
            cursor.executemany(
                "INSERT INTO phi_mapping (record_id, token, encrypted_phi) VALUES (%s, %s, %s)",
                rows,
            )

        logger.info(f"Anonymized record {record_id}")
        return anonymized_text, mapping
    except Exception as e:
        logger.error(f"Anonymization failed for record {record_id}: {e}")
        return None, None

# Process batch of records
def process_batch(batch, cursor):
    """Anonymize every row of a DataFrame batch.

    A row's record_id is str(patient_id). When the row has no 'patient_id'
    entry, the DataFrame index label is used instead. Rows for which
    anonymize_record returns a falsy text are left out.

    NOTE(review): a NaN patient_id becomes the string 'nan', and rows that
    share a patient_id share a record_id. Confirm that IDs are unique per note.

    Returns:
        list of (record_id, anonymized_text, mapping) tuples, in row order.
    """
    collected = []
    for idx, row in batch.iterrows():
        rid = str(row.get('patient_id', idx))
        text, token_map = anonymize_record(row, rid, cursor)
        if not text:
            continue
        collected.append((rid, text, token_map))
    return collected

# Deanonymize model output
def deanonymize_output(record_id, anonymized_text, cursor):
    """Restore original PHI values in text that contains this record's tokens.

    Args:
        record_id: record whose phi_mapping rows should be applied.
        anonymized_text: text (e.g. model output) that may contain tokens.
        cursor: DB-API cursor used to read phi_mapping.

    Returns:
        ``anonymized_text`` with every known token replaced by its decrypted
        original. Tokens whose row cannot be decrypted are left as they are.

    Raises:
        Database errors and other unexpected failures, after logging them.
    """
    import re

    try:
        # ORDER BY id: if a token was stored more than once (e.g. by an earlier
        # run over the same record_id), the most recent row wins below.
        cursor.execute(
            "SELECT token, encrypted_phi FROM phi_mapping WHERE record_id = %s ORDER BY id",
            (record_id,),
        )
        token_to_original = {}
        for token, encrypted_phi in cursor.fetchall():
            try:
                token_to_original[token] = decrypt_phi(encrypted_phi, ENCRYPTION_KEY)
            except ValueError:
                # Rows written under a different key (e.g. an earlier process's
                # key) fail GCM verification. Leave that token untouched rather
                # than abort the whole record.
                logger.warning(f"Skipping undecryptable mapping {token} for record {record_id}")

        deanonymized_text = anonymized_text
        if token_to_original:
            # A single pass with the longest tokens first: a token that is a
            # prefix of another (..._100 vs ..._1000) never matches inside it,
            # and substituted values are never rescanned.
            alternatives = sorted(token_to_original, key=len, reverse=True)
            pattern = re.compile("|".join(re.escape(t) for t in alternatives))
            deanonymized_text = pattern.sub(
                lambda match: token_to_original[match.group(0)], anonymized_text
            )

        logger.info(f"Deanonymized record {record_id}")
        return deanonymized_text
    except Exception as e:
        logger.error(f"Deanonymization failed for record {record_id}: {e}")
        raise

# Main pipeline
def ehr_anonymization_pipeline(input_file, output_file, batch_size=100):
    """Anonymize clinical notes, run them through the model, and re-identify the output.

    Steps: ensure the mapping table exists, read ``input_file`` (CSV), anonymize
    the rows in batches (storing encrypted PHI mappings), commit, pass each
    anonymized note to the language model, restore PHI in the model output,
    and write the results to ``output_file``.

    Args:
        input_file: CSV path with a 'clinical_notes' column and optionally a
            'patient_id' column.
        output_file: CSV path written with 'record_id' and 'output' columns.
        batch_size: number of rows per process_batch call (default 100).

    Returns:
        list of {'record_id': ..., 'output': ...} dicts in input order.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
        Any database, CSV or deanonymization error, after logging it. The
        cursor and connection are closed on every path.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    conn = cursor = None
    try:
        # Setup database
        conn, cursor = setup_database()

        # Load EHR data
        df = pd.read_csv(input_file)
        logger.info(f"Loaded {len(df)} records from {input_file}")

        # Batches run in this process. A psycopg2 cursor cannot be pickled
        # into ProcessPoolExecutor workers and must not be shared across
        # processes, so every insert goes through this single cursor.
        results = []
        for start in range(0, len(df), batch_size):
            results.extend(process_batch(df.iloc[start:start + batch_size], cursor))

        conn.commit()  # Commit anonymized mappings

        # Process with language model and deanonymize
        output_data = []
        for record_id, anonymized_text, _ in results:
            model_output = mock_language_model(anonymized_text)
            deanonymized_output = deanonymize_output(record_id, model_output, cursor)
            output_data.append({'record_id': record_id, 'output': deanonymized_output})

        # Save results. Explicit columns keep the header even when no record
        # produced output.
        output_df = pd.DataFrame(output_data, columns=['record_id', 'output'])
        output_df.to_csv(output_file, index=False)
        logger.info(f"Saved results to {output_file}")

        return output_data
    except Exception as e:
        logger.error(f"Pipeline failed: {e}")
        raise
    finally:
        # Clean up on every path. Closing without a commit discards any
        # pending inserts.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

# Example usage: write a two-record sample CSV and run the pipeline on it.
if __name__ == "__main__":
    # Sample CSV: columns = ['patient_id', 'clinical_notes']
    demo_ids = ['P001', 'P002']
    demo_notes = [
        "Patient John Doe was admitted on 2025-07-12 with phone 555-123-4567.",
        "Patient Jane Smith visited on 2025-07-10, MRN 123456.",
    ]
    input_path = 'ehr_input.csv'
    demo_frame = pd.DataFrame({'patient_id': demo_ids, 'clinical_notes': demo_notes})
    demo_frame.to_csv(input_path, index=False)

    ehr_anonymization_pipeline(input_path, 'ehr_output.csv')