# Stateful AI Agent for PII Scanning in Colab

This notebook implements a stateful AI agent using LangGraph and Gemini to scan CSV files for Personally Identifiable Information (PII). The agent identifies PII, generates a report, and outputs a masked version of the CSV.

**Agent Workflow:**
1.  **Load CSV:** Ingests the source CSV file.
2.  **Regex Scan:** Performs an initial pass using regular expressions to flag potential PII.
3.  **LLM Classification:** Uses a Gemini model to analyze column names and data samples to adjudicate and enrich the findings.
4.  **Consolidate:** Merges the findings from the regex and LLM steps.
5.  **Mask & Save:** Applies masking rules to the identified PII columns and saves the masked CSV and a JSON report.
6.  **Generate Report:** Creates a human-readable Markdown summary of the findings.

### 1. Setup and Installation

This block installs the Python libraries the agent needs. Dependencies are kept minimal; `tabulate` is included only because `DataFrame.to_markdown`, used when reviewing the outputs, depends on it.

In [None]:
!pip install -U langgraph langchain langchain-google-genai pandas pydantic python-dotenv tabulate

### 2. Configure API Key

To use the Gemini model, you need a Google API key. 

**Instructions:**
1.  Click on the **🔑 (key) icon** in the left sidebar of Colab.
2.  Click **"Add a new secret"**.
3.  Name the secret `GOOGLE_API_KEY`.
4.  Paste your API key into the "Value" field.
5.  Make sure the "Notebook access" toggle is enabled.

The code below will securely access this key. If the secret is not found, it will fall back to checking for an environment variable (useful for local development).

In [None]:
import os

# google.colab only exists inside Colab; tolerate its absence for local runs
try:
    from google.colab import userdata
except ImportError:
    userdata = None

api_key = None
if userdata is not None:
    try:
        api_key = userdata.get('GOOGLE_API_KEY')
    except Exception as e:  # e.g. secret missing or notebook access disabled
        print(f"Could not read 'GOOGLE_API_KEY' from Colab secrets: {e}")

if api_key:
    os.environ['GOOGLE_API_KEY'] = api_key
    print("Successfully loaded GOOGLE_API_KEY from Colab secrets.")
elif 'GOOGLE_API_KEY' in os.environ:
    print("Using GOOGLE_API_KEY from environment variables.")
elif 'GEMINI_API_KEY' in os.environ:
    # Later cells read GOOGLE_API_KEY only, so mirror the value under that name
    os.environ['GOOGLE_API_KEY'] = os.environ['GEMINI_API_KEY']
    print("Using GEMINI_API_KEY from environment variables.")
else:
    print("ERROR: Please set up the GOOGLE_API_KEY secret in Colab.")


### 3. Imports and Pydantic Models

This cell imports all the required libraries and defines the Pydantic models. These models provide structured data validation and settings for the agent's state (`AgentState`), configuration (`Config`), and output reports (`PIIReport`, `ColumnFinding`). Using Pydantic ensures that data flowing through the agent is well-defined and type-safe.

In [None]:
from __future__ import annotations

import argparse
import hashlib
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple

import pandas as pd
from pydantic import BaseModel, Field

# LangGraph / LangChain + Gemini
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI

# -----------------------------
# Pydantic Models (State & I/O)
# -----------------------------
class PIISample(BaseModel):
    column: str
    row_index: int
    raw_value: str
    hashed_preview: str

class ColumnFinding(BaseModel):
    column: str
    pii_types: List[str] = Field(default_factory=list)  # e.g., ["EMAIL", "PHONE"]
    confidence: float = 0.0
    rationale: Optional[str] = None
    examples: List[PIISample] = Field(default_factory=list)

class PIIReport(BaseModel):
    columns_flagged: List[ColumnFinding] = Field(default_factory=list)
    total_rows: int = 0
    summary: Dict[str, Any] = Field(default_factory=dict)

class MaskRule(BaseModel):
    pii_type: str  # "EMAIL", "PHONE", "SSN", "CREDIT_CARD", "IP", "DOB", "ADDRESS", "NAME"
    strategy: Literal[
        "redact_all",        # replace with ****
        "partial_email",     # keep local first char & domain TLD
        "partial_phone",     # keep last 4 digits
        "hash_consistent",   # sha256 stable token
        "ipv4_subnet",       # zero last octet
        "year_only",         # for DOB
    ]

class Config(BaseModel):
    sample_rows_for_llm: int = 8
    sample_rows_for_regex: int = 200
    max_examples_per_column: int = 5
    mask_rules: List[MaskRule] = Field(default_factory=lambda: [
        MaskRule(pii_type="EMAIL", strategy="partial_email"),
        MaskRule(pii_type="PHONE", strategy="partial_phone"),
        MaskRule(pii_type="SSN", strategy="hash_consistent"),
        MaskRule(pii_type="CREDIT_CARD", strategy="hash_consistent"),
        MaskRule(pii_type="IP", strategy="ipv4_subnet"),
        MaskRule(pii_type="DOB", strategy="year_only"),
        MaskRule(pii_type="ADDRESS", strategy="redact_all"),
        MaskRule(pii_type="NAME", strategy="redact_all"),
    ])
    pii_regex: Dict[str, str] = Field(default_factory=lambda: {
        "EMAIL": r"(?i)\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b",
        "PHONE": r"(?:(?:\+?1[-.\s]?)?(?:\(?\d{3}\)?|\d{3})[-.\s]?\d{3}[-.\s]?\d{4})",
        "SSN": r"\b\d{3}-?\d{2}-?\d{4}\b",
        "CREDIT_CARD": r"\b(?:\d[ -]*?){13,16}\b",
        "IP": r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
        "DOB": r"\b(?:\d{4}[-/]\d{1,2}[-/]\d{1,2}|\d{1,2}[-/]\d{1,2}[-/]\d{2,4})\b",
        # ADDRESS/NAME are hard via regex; handled by LLM hints mostly
    })

class AgentState(BaseModel):
    # Inputs
    input_csv: str
    outdir: str
    model: str = "gemini-1.5-pro"
    config: Config = Field(default_factory=Config)

    # Working data
    df_head: Optional[List[Dict[str, Any]]] = None
    columns: List[str] = Field(default_factory=list)
    llm_column_analysis: List[ColumnFinding] = Field(default_factory=list)
    regex_column_analysis: List[ColumnFinding] = Field(default_factory=list)

    # Outputs
    report: Optional[PIIReport] = None
    masked_csv_path: Optional[str] = None
    findings_json_path: Optional[str] = None
    report_md_path: Optional[str] = None

    # Meta
    total_rows: int = 0
    errors: List[str] = Field(default_factory=list)
    logs: List[str] = Field(default_factory=list)

### 4. Utility Functions

These are helper functions used across different nodes of the agent. They handle tasks like creating secure hashes of data (for privacy when sending samples to the LLM) and ensuring the output directory exists.

In [None]:
def sha256_token(v: str) -> str:
    return hashlib.sha256(v.encode("utf-8", errors="ignore")).hexdigest()[:12]


def hashed_preview(v: Any) -> str:
    s = str(v)[:64]
    return sha256_token(s)


def ensure_outdir(path: str) -> Path:
    p = Path(path)
    p.mkdir(parents=True, exist_ok=True)
    return p
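
To make the behavior of the hashing helpers concrete, here is a small self-contained sketch that repeats the two functions and checks three properties: the same input always yields the same token, different inputs yield different tokens, and only the first 64 characters of a value influence the result. The sample strings are made up for illustration.

```python
import hashlib


def sha256_token(v: str) -> str:
    # First 12 hex characters of the SHA-256 digest
    return hashlib.sha256(v.encode("utf-8", errors="ignore")).hexdigest()[:12]


def hashed_preview(v) -> str:
    # Only the first 64 characters of the string form are hashed
    return sha256_token(str(v)[:64])


a = hashed_preview("alice@example.com")
b = hashed_preview("alice@example.com")
c = hashed_preview("bob@acme.co")
print(a == b)   # deterministic: same input, same token
print(a != c)   # different inputs give different tokens
print(len(a))   # token length

# Values that share their first 64 characters hash to the same preview
long_1 = "x" * 64 + "A"
long_2 = "x" * 64 + "B"
print(hashed_preview(long_1) == hashed_preview(long_2))
```

The truncation means previews are not unique identifiers for long free-text cells, which is worth keeping in mind when reading the hashed examples in the report.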

### 5. LangGraph Agent Nodes

Each function below represents a single node in our stateful graph. A node is a unit of work that receives the current `AgentState`, performs a task, and returns the updated state. We break down the agent's logic into these modular, testable, and reusable components.
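
The node contract described above can be illustrated without LangGraph at all. The following is a plain-Python analogy, not LangGraph's API: `DemoState`, the three demo nodes, and `run_pipeline` are hypothetical names used only to show the "take state, do work, record errors, return state" pattern that every node in this notebook follows.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class DemoState:
    values: List[int] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)
    logs: List[str] = field(default_factory=list)


def node_load(state: DemoState) -> DemoState:
    state.values = [1, 2, 3]
    state.logs.append("loaded 3 values")
    return state


def node_fail(state: DemoState) -> DemoState:
    # Errors are recorded on the state instead of being raised,
    # so the remaining nodes still run.
    try:
        raise ValueError("boom")
    except Exception as e:
        state.errors.append(f"fail: {e}")
    return state


def node_sum(state: DemoState) -> DemoState:
    state.logs.append(f"sum={sum(state.values)}")
    return state


def run_pipeline(state: DemoState, nodes: List[Callable[[DemoState], DemoState]]) -> DemoState:
    # Linear execution: each node receives the state returned by the previous one
    for node in nodes:
        state = node(state)
    return state


result = run_pipeline(DemoState(), [node_load, node_fail, node_sum])
print(result.logs)
print(result.errors)
```

One consequence of this pattern is that a failure in an early node does not stop the run; later nodes must tolerate missing inputs, and the caller has to inspect `errors` at the end.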

#### Node 1: `node_load_csv`

This is the entry point of the graph. It reads the CSV file specified in the initial state, extracts the row count and column names, and stores a 10-row sample in the state for later use. It also attaches the full DataFrame to a temporary attribute, `_df_cache`, so that later nodes can avoid passing the entire dataset through the state's declared fields. Because `_df_cache` is not a declared field, LangGraph does not track it, and it may be lost when the state is handed from one node to the next.

In [None]:
def node_load_csv(state: AgentState) -> AgentState:
    try:
        df = pd.read_csv(state.input_csv)
        state.total_rows = len(df)
        state.columns = list(df.columns)
        # store a small head snapshot (not full data) for reporting
        state.df_head = df.head(10).to_dict(orient="records")
        state.logs.append(f"Loaded CSV with {state.total_rows} rows and {len(state.columns)} columns.")
        state._df_cache = df  # type: ignore[attr-defined]
    except Exception as e:
        state.errors.append(f"load_csv: {e}")
    return state

#### Node 2: `node_regex_scan`

This node performs a fast, preliminary scan for PII using a predefined set of regular expressions. It operates on a sample of the data to remain efficient. For each column, it counts matches for different PII types (e.g., EMAIL, PHONE) and collects a few examples. This provides a baseline set of findings that can be used to guide the more sophisticated LLM analysis in the next step.

In [None]:
def node_regex_scan(state: AgentState) -> AgentState:
    try:
        df = getattr(state, "_df_cache", None)
        if df is None:
            # The cached attribute may not survive between nodes; re-read as a fallback
            df = pd.read_csv(state.input_csv)
        findings: List[ColumnFinding] = []
        sample_df = df.head(state.config.sample_rows_for_regex)
        n_sampled = max(1, len(sample_df))
        for col in state.columns:
            col_find = ColumnFinding(column=col)
            # Blank out missing values before stringifying, so NaN never becomes "nan"
            series = sample_df[col].map(lambda v: "" if pd.isna(v) else str(v)).astype(str)
            matches: Dict[str, int] = {}
            for pii_type, pattern in state.config.pii_regex.items():
                count = series.str.contains(pattern, regex=True).sum()
                if count > 0:
                    matches[pii_type] = int(count)
            if matches:
                # collect a few examples, checking only the patterns that matched
                examples: List[PIISample] = []
                for idx, val in series.items():
                    if len(examples) >= state.config.max_examples_per_column:
                        break
                    if any(re.search(state.config.pii_regex[t], val) for t in matches):
                        examples.append(
                            PIISample(
                                column=col,
                                row_index=int(idx),
                                raw_value=val[:64],
                                hashed_preview=hashed_preview(val),
                            )
                        )
                col_find.pii_types = sorted(matches.keys())
                # simple confidence heuristic: match count relative to rows actually sampled
                col_find.confidence = min(1.0, sum(matches.values()) / n_sampled)
                col_find.rationale = "Regex/heuristic match counts over sample"
                col_find.examples = examples
            findings.append(col_find)
            findings.append(col_find)
        state.regex_column_analysis = findings
        state.logs.append("Completed regex scan.")
    except Exception as e:
        state.errors.append(f"regex_scan: {e}")
    return state
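
The confidence heuristic in the regex scan boils down to a match rate: the share of sampled cells in a column that match a pattern. The dependency-free sketch below shows that idea on two made-up columns; `match_rate` is a hypothetical helper, and it uses plain `re` rather than pandas so it can run anywhere.

```python
import re
from typing import Dict, List, Optional

EMAIL = r"(?i)\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b"


def match_rate(values: List[Optional[str]], pattern: str) -> float:
    """Fraction of cells matching the pattern; missing cells count as non-matches."""
    if not values:
        return 0.0
    hits = sum(1 for v in values if v is not None and re.search(pattern, str(v)))
    return hits / len(values)


columns: Dict[str, List[Optional[str]]] = {
    "email": ["alice@example.com", "bob@acme.co", None, "n/a"],
    "notes": ["Call after 5pm", "VIP Customer", "Met at conference", "Follow up next week"],
}
rates = {name: match_rate(vals, EMAIL) for name, vals in columns.items()}
print(rates)
```

Dividing by the number of cells actually inspected, rather than by a configured sample size, keeps the rate meaningful for small files such as the four-row sample used later in this notebook.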

#### Node 3: `node_llm_classify`

This node uses a Gemini LLM for a more nuanced analysis. It sends the column names, the hints from the regex scan, and a small number of *hashed* data samples to the model, so raw cell values are not exposed through the samples. A truncated SHA-256 hash carries no information about a value's format, so the model's judgement rests mainly on the column names and the regex hints. The LLM is prompted to act as a security specialist and to return a structured JSON response that gives, for each column, the PII types, a confidence score, and a rationale.

In [None]:
def node_llm_classify(state: AgentState) -> AgentState:
    try:
        # Configure Gemini (prefers GOOGLE_API_KEY)
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise RuntimeError("Missing GOOGLE_API_KEY. Please configure it in Colab secrets.")
        llm = ChatGoogleGenerativeAI(model=state.model, google_api_key=api_key, temperature=0.1)

        # Prepare redacted samples per column (hash previews only)
        sample_rows = []
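        df = getattr(state, "_df_cache", None)
        if df is None:
            # The cached attribute may not survive between nodes; re-read as a fallback
            df = pd.read_csv(state.input_csv)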
        head = df.head(state.config.sample_rows_for_llm)
        for i, row in head.iterrows():
            sample_rows.append({c: hashed_preview(row[c]) for c in state.columns})

        system = SystemMessage(
            content=(
                "You are a Senior Security Engineer specializing in data privacy and compliance.\n"
                "Task: Given CSV schema and hashed examples, identify which columns likely contain PII.\n"
                "Return ONLY a JSON array of objects: {column, pii_types:[...], confidence:0..1, rationale}.\n"
                "PII types limited to: EMAIL, PHONE, SSN, CREDIT_CARD, IP, DOB, ADDRESS, NAME, NONE.\n"
                "Prefer precision over recall; do not hallucinate.\n"
            )
        )
        human = HumanMessage(
            content=json.dumps({
                "columns": state.columns,
                # Drop examples: they carry raw_value, which must not reach the LLM
                "regex_hints": [cf.model_dump(exclude={"examples"}) for cf in state.regex_column_analysis],
                "hashed_samples": sample_rows,
            })
        )
        resp = llm.invoke([system, human])
        text = resp.content if isinstance(resp, AIMessage) else str(resp)
        parsed: List[Dict[str, Any]] = []
        try:
            # Clean the text to extract only the JSON part
            json_text = text.strip().replace('`json', '').replace('`', '')
            parsed = json.loads(json_text)  # expect array
        except Exception:
            # attempt to extract JSON block with regex as a fallback
            m = re.search(r"\[\s*{[\s\S]*}\s*\]", text)
            if m:
                parsed = json.loads(m.group(0))
            else:
                raise
        findings = []
        for item in parsed:
            findings.append(
                ColumnFinding(
                    column=item.get("column", ""),
                    pii_types=[pt for pt in item.get("pii_types", []) if pt != "NONE"],
                    confidence=float(item.get("confidence", 0.0)),
                    rationale=item.get("rationale"),
                )
            )
        state.llm_column_analysis = findings
        state.logs.append("LLM classification completed.")
    except Exception as e:
        state.errors.append(f"llm_classify: {e}")
    return state
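
Model replies often wrap JSON in a Markdown code fence or add a sentence before it, which is why the node falls back to a regex when `json.loads` fails. The sketch below isolates that parsing step; `extract_json_array` is a hypothetical helper, and the reply text is invented for illustration rather than taken from a real Gemini response.

```python
import json
import re
from typing import Any, List

FENCE = "`" * 3  # built programmatically to avoid a literal fence inside this example


def extract_json_array(text: str) -> List[Any]:
    """Parse a JSON array from a model reply, tolerating surrounding prose or fences."""
    try:
        parsed = json.loads(text.strip())
    except json.JSONDecodeError:
        # Greedy match from the first '[' to the last ']'
        m = re.search(r"\[[\s\S]*\]", text)
        if m is None:
            raise ValueError("no JSON array found in model reply")
        parsed = json.loads(m.group(0))
    if not isinstance(parsed, list):
        raise ValueError("model reply is valid JSON but not an array")
    return parsed


fenced_reply = (
    "Here are the findings:\n"
    + FENCE + "json\n"
    + '[{"column": "email", "pii_types": ["EMAIL"], "confidence": 0.9}]'
    + "\n" + FENCE
)
findings = extract_json_array(fenced_reply)
print(findings[0]["column"], findings[0]["pii_types"])
```

The greedy match is a pragmatic choice: it works when the array is the only bracketed content in the reply, but stray brackets in surrounding prose would break it, so a failed parse should still be reported as an error.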

#### Node 4: `node_consolidate`

This node merges the results from the regex scan and the LLM classification. It combines the identified PII types for each column, taking the highest confidence score and appending the rationales. This creates a single, unified list of findings. It then populates the main `PIIReport` in the state with these consolidated findings and a summary of PII types found.

In [None]:
def _merge_findings(regex: List[ColumnFinding], llm: List[ColumnFinding]) -> List[ColumnFinding]:
    by_col: Dict[str, ColumnFinding] = {c.column: c for c in regex}
    for lf in llm:
        if lf.column in by_col:
            base = by_col[lf.column]
            merged_types = sorted(set(base.pii_types) | set(lf.pii_types))
            base.pii_types = merged_types
            base.confidence = max(base.confidence, lf.confidence)
            base.rationale = (base.rationale or "") + " | LLM: " + (lf.rationale or "")
        else:
            by_col[lf.column] = lf
    return list(by_col.values())


def node_consolidate(state: AgentState) -> AgentState:
    try:
        cols = _merge_findings(state.regex_column_analysis, state.llm_column_analysis)
        # Filter out columns with no detected PII types
        flagged_cols = [c for c in cols if c.pii_types]
        report = PIIReport(columns_flagged=flagged_cols, total_rows=state.total_rows)
        # summary counts by pii type
        summary: Dict[str, int] = {}
        for c in flagged_cols:
            for t in c.pii_types:
                summary[t] = summary.get(t, 0) + 1
        report.summary = {"columns_with_type": summary}
        state.report = report
        state.logs.append("Consolidated findings.")
    except Exception as e:
        state.errors.append(f"consolidate: {e}")
    return state
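
The merge rule (union of PII types, maximum of the two confidence scores) is easy to check on a toy input. This sketch uses plain dictionaries instead of `ColumnFinding` so it runs on its own; `merge_findings` and the sample findings are illustrative assumptions that mirror the logic of `_merge_findings`.

```python
from typing import Any, Dict, List


def merge_findings(regex: List[Dict[str, Any]], llm: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Start from copies of the regex findings, keyed by column name
    by_col = {f["column"]: dict(f) for f in regex}
    for lf in llm:
        base = by_col.get(lf["column"])
        if base is None:
            by_col[lf["column"]] = dict(lf)
            continue
        base["pii_types"] = sorted(set(base["pii_types"]) | set(lf["pii_types"]))
        base["confidence"] = max(base["confidence"], lf["confidence"])
    return list(by_col.values())


regex_findings = [
    {"column": "email", "pii_types": ["EMAIL"], "confidence": 0.02},
    {"column": "name", "pii_types": [], "confidence": 0.0},
]
llm_findings = [
    {"column": "email", "pii_types": ["EMAIL"], "confidence": 0.95},
    {"column": "name", "pii_types": ["NAME"], "confidence": 0.8},
]
merged = merge_findings(regex_findings, llm_findings)
for m in merged:
    print(m["column"], m["pii_types"], m["confidence"])
```

Taking the maximum means either source alone can flag a column; a stricter policy, such as requiring agreement between both sources, would trade recall for precision.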

#### Node 5: `node_mask_and_save`

This node takes the final, consolidated report and applies the configured masking rules to the original data. It iterates through the columns identified as containing PII and applies the appropriate masking strategy (e.g., redacting, hashing, or partial replacement). The resulting masked DataFrame is then written to `masked.csv`, and the detailed findings are saved to `findings.json` in the specified output directory.

In [None]:
def apply_mask(value: Any, pii_types: List[str], rules: List[MaskRule]) -> str:
    if pd.isna(value):
        return ""  # keep missing values empty instead of the literal string "nan"
    s = str(value)
    if not s or not pii_types:
        return s
    # apply first matching rule by preference order
    for t in pii_types:
        rule = next((r for r in rules if r.pii_type == t), None)
        if not rule:
            continue
        strat = rule.strategy
        if strat == "redact_all":
            return "████"
        if strat == "hash_consistent":
            return f"token_{sha256_token(s)}"
        if strat == "partial_email":
            parts = s.split("@")
            if len(parts) == 2:
                local, domain = parts
                tld = domain.split(".")[-1] if "." in domain else "dom"
                return f"{local[:1]}***@***.{tld}"
            return f"token_{sha256_token(s)}"
        if strat == "partial_phone":
            digits = ''.join(ch for ch in s if ch.isdigit())
            return f"***-***-{digits[-4:]}" if len(digits) >= 4 else "***-***-****"
        if strat == "ipv4_subnet":
            parts = s.split('.')
            if len(parts) == 4:
                return '.'.join(parts[:3] + ['0'])
            return f"ip_{sha256_token(s)}"
        if strat == "year_only":
            m = re.search(r"(\d{4})", s)
            return m.group(1) if m else "YYYY"
    # fallback
    return f"token_{sha256_token(s)}"


def node_mask_and_save(state: AgentState) -> AgentState:
    try:
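        df = getattr(state, "_df_cache", None)
        if df is None:
            # The cached attribute may not survive between nodes; re-read as a fallback
            df = pd.read_csv(state.input_csv)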
        # Build map column -> pii_types
        pii_map: Dict[str, List[str]] = {c.column: c.pii_types for c in state.report.columns_flagged} if state.report else {}
        rules = state.config.mask_rules
        masked = df.copy()
        for col, types in pii_map.items():
            if not types or col not in masked.columns:
                # skip empty findings and column names the LLM may have invented
                continue
            masked[col] = masked[col].apply(lambda v: apply_mask(v, types, rules))
        outdir = ensure_outdir(state.outdir)
        masked_path = outdir / "masked.csv"
        masked.to_csv(masked_path, index=False)
        state.masked_csv_path = str(masked_path)
        # Write findings.json
        findings_json = outdir / "findings.json"
        with open(findings_json, "w", encoding="utf-8") as f:
            json.dump(state.report.model_dump() if state.report else {}, f, indent=2)
        state.findings_json_path = str(findings_json)
        state.logs.append(f"Masked CSV saved to {masked_path}")
    except Exception as e:
        state.errors.append(f"mask_and_save: {e}")
    return state
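
To see what the partial masking strategies produce, the sketch below reimplements four of them as standalone functions and applies them to values from the sample data. The function names are hypothetical, and the fallbacks are simplified to fixed placeholders; `apply_mask` itself falls back to a hash token for malformed input.

```python
import re


def mask_email(s: str) -> str:
    # Keep the first character of the local part and the TLD
    parts = s.split("@")
    if len(parts) != 2:
        return "***"
    local, domain = parts
    tld = domain.split(".")[-1] if "." in domain else "dom"
    return f"{local[:1]}***@***.{tld}"


def mask_phone(s: str) -> str:
    # Keep only the last four digits
    digits = "".join(ch for ch in s if ch.isdigit())
    return f"***-***-{digits[-4:]}" if len(digits) >= 4 else "***-***-****"


def mask_ipv4(s: str) -> str:
    # Zero the last octet; anything that is not four dot-separated parts is redacted
    parts = s.split(".")
    return ".".join(parts[:3] + ["0"]) if len(parts) == 4 else "ip_redacted"


def year_only(s: str) -> str:
    # Keep the first run of four digits, if any
    m = re.search(r"(\d{4})", s)
    return m.group(1) if m else "YYYY"


print(mask_email("alice@example.com"))
print(mask_phone("+1-415-555-0199"))
print(mask_ipv4("192.168.1.10"))
print(mask_ipv4("2001:db8:85a3:8d3:1319:8a2e:370:7348"))
print(year_only("1990-04-12"))
print(year_only("04/12/90"))
```

Two-digit years fall through to the `YYYY` placeholder because the strategy looks for four consecutive digits, and IPv6 addresses are never partially preserved.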

#### Node 6: `node_generate_report`

This is the final node in the workflow. It creates a human-readable summary of the agent's findings in Markdown format. The report includes a summary of PII types discovered, details for each flagged column (including PII types, confidence, and hashed examples), and file paths. This report is saved as `report.md`.

In [None]:
def node_generate_report(state: AgentState) -> AgentState:
    try:
        outdir = ensure_outdir(state.outdir)
        md_path = outdir / "report.md"
        report = state.report or PIIReport()
        lines = [
            "# PII Detection Report",
            "",
            f"**Input:** `{state.input_csv}`",
            f"**Rows:** {state.total_rows}",
            "",
            "## Summary",
            "```json",
            json.dumps(report.summary, indent=2),
            "```",
            "",
            "## Flagged Columns Details",
        ]
        if not report.columns_flagged:
            lines.append("No PII was detected in any columns.")
        else:
            for c in report.columns_flagged:
                lines += [
                    f"### Column: `{c.column}`",
                    f"- **PII Types**: {', '.join(c.pii_types) if c.pii_types else 'None'}",
                    f"- **Confidence**: {c.confidence:.2f}",
                    f"- **Rationale**: {c.rationale or '-'}",
                ]
                if c.examples:
                    lines.append("- **Examples (hashed previews):**")
                    for ex in c.examples[:3]:
                        lines.append(f"  - `row {ex.row_index}`: {ex.hashed_preview}")
                lines.append("")
        with open(md_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
        state.report_md_path = str(md_path)
        state.logs.append(f"Report written to {md_path}")
    except Exception as e:
        state.errors.append(f"generate_report: {e}")
    return state

### 6. Build and Compile the Graph

This function assembles all the individual nodes into a coherent workflow using `StateGraph`. It defines the sequence of operations by adding edges between the nodes, creating a linear progression from loading the CSV to generating the final report. Finally, it compiles the graph into a runnable application.

In [None]:
def build_graph() -> StateGraph:
    graph = StateGraph(AgentState)
    graph.add_node("load_csv", node_load_csv)
    graph.add_node("regex_scan", node_regex_scan)
    graph.add_node("llm_classify", node_llm_classify)
    graph.add_node("consolidate", node_consolidate)
    graph.add_node("mask_and_save", node_mask_and_save)
    graph.add_node("generate_report", node_generate_report)

    graph.set_entry_point("load_csv")
    graph.add_edge("load_csv", "regex_scan")
    graph.add_edge("regex_scan", "llm_classify")
    graph.add_edge("llm_classify", "consolidate")
    graph.add_edge("consolidate", "mask_and_save")
    graph.add_edge("mask_and_save", "generate_report")
    graph.add_edge("generate_report", END)
    return graph

### 7. Prepare Data and Define Inputs

This cell creates a sample `customers.csv` file in the Colab environment to make the notebook self-contained and runnable without requiring a file upload. It also defines the input and output directories that the agent will use.

In [None]:
# Create dummy data directory and a sample CSV file
os.makedirs("./data", exist_ok=True)
os.makedirs("./out", exist_ok=True)

csv_content = """
name,email,phone,notes,ip_address
Alice,alice@example.com,+1-415-555-0199,Call after 5pm,192.168.1.10
Bob,bob@acme.co,4155550188,VIP Customer,203.0.113.45
Charlie,charlie+test@gmail.com,555-867-5309,Met at conference,198.51.100.2
Diana,diana@work.net,(415) 555-0122,Follow up next week,2001:db8:85a3:8d3:1319:8a2e:370:7348
"""

input_csv_path = "./data/customers.csv"
with open(input_csv_path, "w") as f:
    f.write(csv_content.strip())

print(f"Sample CSV created at: {input_csv_path}")

# Define agent inputs
INPUT_CSV = input_csv_path
OUTPUT_DIR = "./out"
MODEL_NAME = "gemini-1.5-pro-latest" # Using the latest 1.5 Pro model

### 8. Run the Agent

This is the main execution block. It initializes the `AgentState` with the input parameters, compiles the graph, and then runs the agent. We use `app.stream` to execute the graph, which allows us to see logs and errors from each node in real-time as it completes its task. After the run, a final summary of the output file paths is printed.

In [None]:
# Initialize the state with our inputs
initial_state = AgentState(
    input_csv=INPUT_CSV, 
    outdir=OUTPUT_DIR, 
    model=MODEL_NAME
)

# Compile the graph
app = build_graph().compile()

# Stream execution (yields state updates per node)
print("--- Running PII Agent ---")
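final_state = initial_state
seen_logs = seen_errors = 0
# With stream_mode="values" each update is the full state, typically a plain dict,
# so coerce it back into AgentState before reading attributes.
for update in app.stream(initial_state.model_dump(), stream_mode="values"):
    current = update if isinstance(update, AgentState) else AgentState(**update)
    for msg in current.logs[seen_logs:]:
        print(f"LOG: {msg}")
    for msg in current.errors[seen_errors:]:
        print(f"ERROR: {msg}")
    seen_logs, seen_errors = len(current.logs), len(current.errors)
    final_state = current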

# Final summary
print("\n=== Agent Run Complete ===")
print(f"Masked CSV:      {final_state.masked_csv_path}")
print(f"Findings JSON:   {final_state.findings_json_path}")
print(f"Report MD:       {final_state.report_md_path}")
if final_state.errors:
    print(f"Errors encountered: {final_state.errors}")

### 9. Review Outputs

The following cells display the contents of the files generated by the agent, allowing for immediate verification of the results.

#### Masked CSV (`masked.csv`)

In [None]:
masked_df = pd.read_csv(final_state.masked_csv_path)
print(masked_df.to_markdown(index=False))

#### Findings JSON (`findings.json`)

In [None]:
with open(final_state.findings_json_path, 'r') as f:
    print(f.read())

#### Markdown Report (`report.md`)

In [None]:
from IPython.display import display, Markdown

with open(final_state.report_md_path, 'r') as f:
    md_content = f.read()
    
display(Markdown(md_content))

### Self-Assessment & Suggestions

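-   **Security**: We avoid sending raw sample values to the LLM; only hashed previews and regex hints are shared. Note that `findings.json` stores truncated raw example values (`raw_value`), so treat the output directory as sensitive.
-   **Performance**: The regex scan is bounded to the first `sample_rows_for_regex` rows (200 by default); consider sampling strategies for very large CSVs.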
-   **Extensibility**: Add custom rules per jurisdiction (GDPR special categories). Add entity-level masking (named entities inside free text) with a local spaCy NER fallback if LLM is not allowed.
-   **Reliability**: Add unit tests with synthetic data. Optionally enable forced function calling via Vertex AI tool config when you migrate to Vertex.
-   **Observability**: Persist `state.logs` and timings; integrate LangSmith or OpenTelemetry as needed.
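
As one possible take on the sampling suggestion above, reservoir sampling draws a uniform random sample of fixed size from a CSV in a single pass, without loading the whole file. The sketch below is an assumption about how this could be done, not something the agent currently implements; `reservoir_sample` and the synthetic CSV are hypothetical.

```python
import csv
import io
import random
from typing import Dict, Iterable, List


def reservoir_sample(rows: Iterable[Dict[str, str]], k: int, seed: int = 0) -> List[Dict[str, str]]:
    """Uniformly sample up to k rows from an iterable of unknown length (Algorithm R)."""
    rng = random.Random(seed)
    sample: List[Dict[str, str]] = []
    for i, row in enumerate(rows):
        if i < k:
            sample.append(row)       # fill the reservoir first
        else:
            j = rng.randint(0, i)    # inclusive on both ends
            if j < k:
                sample[j] = row      # replace with probability k / (i + 1)
    return sample


# Synthetic CSV standing in for a large file
csv_text = "id,email\n" + "\n".join(f"{i},user{i}@example.com" for i in range(1000))
reader = csv.DictReader(io.StringIO(csv_text))
sample = reservoir_sample(reader, k=5)
print(len(sample))
print(sorted(int(r["id"]) for r in sample))
```

Unlike `df.head(n)`, which only inspects the first rows, a uniform sample is less likely to miss PII that appears only later in the file, at the cost of reading the file once in full.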