In [None]:
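# Install the Google GenAI SDK and Pydantic (used for structured JSON output).
# Run once before the rest of the notebook, either from a terminal:
#   pip install google-genai pydantic
# or from inside the notebook, which installs into the kernel's own environment:
#   %pip install google-genai pydantic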

In [None]:
from typing import Literal

from pydantic import BaseModel, Field

# ⚠️ Define this class in your Python script
class WLExtractionSchema(BaseModel):
    """Schema for extracting statistical and compliance features from a Warning Letter.

    Passed as ``response_schema`` so the model's JSON output is constrained to
    these fields, and usable to validate that output afterwards.
    """
    # Regulation cited for the violation, e.g. '21 CFR 312.60'.
    cfr_citation: str = Field(description="The specific CFR regulation cited (e.g., '21 CFR 312.60').")
    # Verbatim text from the letter describing the failure.
    violation_quote: str = Field(description="The exact text quote from the letter describing the specific failure or observation.")
    # Open-ended classification; the examples guide the model but are not enforced.
    protocol_section_focus: str = Field(description="Classify the part of the protocol that was broken (e.g., 'Eligibility Criteria', 'Endpoint Assessment', 'Dosing & Administration').")
    # Closed set of labels: Literal turns the "must be one of" rule into an enum
    # in the generated JSON schema, and Pydantic rejects any other value.
    statistical_bias_inferred: Literal['Selection Bias', 'Detection Bias', 'Performance Bias'] = Field(description="The primary statistical risk created by the failure. Must be one of: 'Selection Bias', 'Detection Bias', or 'Performance Bias'.")
    # Site-agnostic description of the failure, meant to be reusable across trials.
    hardening_pattern: str = Field(description="A concise, generic summary of the failure pattern (e.g., 'Failure to pre-specify rescue medication rules').")

In [None]:
# ⚠️ This is the main instruction string you will feed the LLM.
# Every field of WLExtractionSchema has a matching numbered rule below, so the
# prompt and the output schema ask for the same things.
SYSTEM_INSTRUCTION = """
You are a highly specialized Regulatory Risk Analyst and Biostatistician. Your task is to analyze the provided FDA Warning Letter text and extract the single most significant clinical trial protocol violation.

Follow these rules for your analysis:
1.  **IDENTIFY:** Locate the specific section in the letter that details the investigator's failure to follow the protocol (21 CFR 312.60 or similar).
2.  **CITE:** Extract the exact CFR citation and the direct, verbatim quote describing the failure.
3.  **CLASSIFY BIAS:** Based on the definitions below, classify the violation's primary statistical risk using exactly one of these three labels:
    * **Selection Bias:** Compromises who enters/stays in the study (Eligibility, Recruitment).
    * **Detection Bias:** Compromises the measurement of the outcome (Endpoint Assessment, Blinding).
    * **Performance Bias:** Compromises treatment or care received (Dosing, Concomitant Medications, Adherence).
4.  **CATEGORIZE SECTION:** Determine which high-level protocol section was violated.
5.  **SUMMARIZE PATTERN:** Write a concise, generic summary of the failure pattern that could apply to any trial; omit site names, investigator names, subject identifiers, and dates.

Output your findings STRICTLY as a JSON object conforming to the provided schema.
"""

In [None]:
import os
import json
from google import genai
from google.genai import types

# Input folder: one plain-text file per Warning Letter.
WL_DIRECTORY = 'warning_letters_text/'
# Output as JSON Lines (one JSON object per line), so records can be appended
# one at a time during a long batch run.
OUTPUT_FILE = 'violation_pattern_library.jsonl'
# Higher-capability model, chosen for the multi-step reasoning this task needs.
MODEL_NAME = 'gemini-2.5-pro'

# No credentials are passed here; the client is expected to read the API key
# from the environment, so make sure it is set before running this cell.
client = genai.Client()

def extract_violation_features(file_path):
    """Extract the most significant protocol violation from one Warning Letter.

    Reads the letter from ``file_path``. Sends it to ``MODEL_NAME`` with
    ``SYSTEM_INSTRUCTION`` as the system instruction, and requests JSON
    constrained to ``WLExtractionSchema``.

    Args:
        file_path: Path to a UTF-8 text file holding one Warning Letter.

    Returns:
        dict | None: The extracted fields, keyed like ``WLExtractionSchema``.
        Returns ``None`` if the file is empty or any step fails. Failures are
        printed rather than raised, so one bad letter does not stop a batch run.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            wl_text = f.read()

        # Skip blank files: an API call on them can only produce a fabricated extraction.
        if not wl_text.strip():
            print(f"Skipping {file_path}: file is empty")
            return None

        response = client.models.generate_content(
            model=MODEL_NAME,
            # The user turn carries only the letter. The task rules go in the
            # system instruction, which keeps them separate from the data.
            contents="--- WARNING LETTER TEXT ---\n" + wl_text,
            config=types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                response_mime_type="application/json",
                response_schema=WLExtractionSchema,
            ),
        )

        # The response may carry no text (e.g. a blocked response). Report that
        # clearly instead of letting json parsing fail with a TypeError.
        if not response.text:
            print(f"Error processing {file_path}: model returned no text")
            return None

        # Validate against the schema instead of trusting raw JSON. Then return
        # a plain dict so callers can json.dumps() it directly.
        return WLExtractionSchema.model_validate_json(response.text).model_dump()

    except Exception as e:
        # Deliberate best-effort: log and move on to the next letter.
        print(f"Error processing {file_path}: {e}")
        return None

# --- Main Processing Loop ---
# Set to an int (e.g. 20) for a quick trial run; None processes every letter.
MAX_LETTERS = None

# Resume support: each saved record stores its letter's filename under
# 'source_file'. Letters already in OUTPUT_FILE are skipped instead of being
# re-extracted, so re-running this cell does not duplicate records.
all_patterns = []
done_files = set()
if os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, 'r', encoding='utf-8') as infile:
        for line in infile:
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed lines, e.g. one truncated by an interrupted write.
                continue
            all_patterns.append(record)
            if 'source_file' in record:
                done_files.add(record['source_file'])

# Sort the files: os.listdir order is arbitrary, so sorting makes runs and any
# MAX_LETTERS subset reproducible.
pending = [
    name for name in sorted(os.listdir(WL_DIRECTORY))
    if name.endswith(".txt") and name not in done_files
]
if MAX_LETTERS is not None:
    pending = pending[:MAX_LETTERS]

with open(OUTPUT_FILE, 'a', encoding='utf-8') as outfile:
    for filename in pending:
        print(f"Processing: {filename}...")
        features = extract_violation_features(os.path.join(WL_DIRECTORY, filename))
        if features:
            features['source_file'] = filename
            all_patterns.append(features)
            # Save progress incrementally (JSON Lines). Flush after each record
            # so an interrupt keeps everything extracted so far.
            outfile.write(json.dumps(features) + '\n')
            outfile.flush()

print(f"\nCompleted extraction. Saved {len(all_patterns)} patterns to {OUTPUT_FILE}")