<a href="https://colab.research.google.com/github/micah-shull/AI_Agents/blob/main/381_RiskScoring_Utils.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""Risk scoring utilities for Governance & Compliance Orchestrator

Calculates risk scores for agents and overall system based on violations,
bias signals, and drift signals.
"""

from typing import Dict, Any, List
from config import GovernanceComplianceOrchestratorConfig


def calculate_agent_risk_score(
    agent_name: str,
    compliance_events: List[Dict[str, Any]],
    bias_signals: List[Dict[str, Any]],
    drift_signals: List[Dict[str, Any]],
    events_lookup: Dict[str, Dict[str, Any]],
    config: GovernanceComplianceOrchestratorConfig
) -> Dict[str, Any]:
    """
    Score the compliance risk of a single agent.

    A compliance event is attributed to ``agent_name`` by resolving its
    ``event_id`` through ``events_lookup`` and comparing the looked-up
    event's ``agent_name``. Bias and drift signals are attributed by their
    own ``agent_name`` field.

    The score blends three components with fixed weights (60% / 25% / 15%):
      * violations - mean of ``config.severity_weights`` over the agent's
        counted violations (0.0 when there are none);
      * bias - 0.2 per bias signal, capped at 1.0;
      * drift - 0.15 per drift signal, capped at 1.0.
    The result lies in [0.0, 1.0] provided every configured severity weight
    does too.

    Args:
        agent_name: Name of the agent to score
        compliance_events: Compliance events (violations) for all agents
        bias_signals: Bias signals for all agents
        drift_signals: Drift signals for all agents
        events_lookup: Original events keyed by event_id
        config: Configuration exposing ``severity_weights``

    Returns:
        Dict with ``agent_name``, ``total_violations`` (counted violations
        only), ``severity_counts`` (critical/high/medium/low),
        ``bias_signals_count``, ``drift_signals_count`` and ``risk_score``
        (rounded to 3 decimals).
    """
    # A violation belongs to this agent when its event_id resolves, via
    # events_lookup, to a source event recorded under agent_name. Missing or
    # empty ids, and ids absent from the lookup, are skipped.
    def _owned_by_agent(violation: Dict[str, Any]) -> bool:
        source_id = violation.get("event_id")
        if not source_id or source_id not in events_lookup:
            return False
        return events_lookup[source_id].get("agent_name") == agent_name

    owned_violations = [v for v in compliance_events if _owned_by_agent(v)]

    # Tally by severity. A missing severity counts as "medium"; a severity
    # outside these four keys is not counted at all.
    by_severity = dict.fromkeys(("critical", "high", "medium", "low"), 0)
    for violation in owned_violations:
        level = violation.get("severity", "medium")
        if level in by_severity:
            by_severity[level] += 1

    n_bias = sum(1 for signal in bias_signals if signal.get("agent_name") == agent_name)
    n_drift = sum(1 for signal in drift_signals if signal.get("agent_name") == agent_name)

    # Violation component: mean configured weight of the counted violations.
    # A severity missing from config.severity_weights falls back to 0.5.
    n_violations = sum(by_severity.values())
    weighted = sum(
        count * config.severity_weights.get(level, 0.5)
        for level, count in by_severity.items()
    )
    violation_component = weighted / n_violations if n_violations > 0 else 0.0

    # Signal components grow linearly per signal and saturate at 1.0.
    bias_component = min(1.0, n_bias * 0.2)
    drift_component = min(1.0, n_drift * 0.15)

    blended = violation_component * 0.6 + bias_component * 0.25 + drift_component * 0.15

    return {
        "agent_name": agent_name,
        "total_violations": n_violations,
        "severity_counts": by_severity,
        "bias_signals_count": n_bias,
        "drift_signals_count": n_drift,
        "risk_score": round(blended, 3),
    }


def calculate_overall_risk_scores(
    compliance_events: List[Dict[str, Any]],
    bias_signals: List[Dict[str, Any]],
    drift_signals: List[Dict[str, Any]],
    agent_action_logs: List[Dict[str, Any]],
    events_lookup: Dict[str, Dict[str, Any]],
    config: GovernanceComplianceOrchestratorConfig
) -> Dict[str, Any]:
    """
    Calculate risk scores for all agents and for the system as a whole.

    The agents scored are exactly those named (with a truthy ``agent_name``)
    in ``agent_action_logs``. An agent that appears only in bias/drift
    signals or compliance events is not scored. Each agent is scored with
    ``calculate_agent_risk_score``.

    The overall score is a weighted mean of the per-agent ``risk_score``
    values, where each agent's weight is ``1 + total_violations``. Agents
    with more violations dominate, but violation-free agents still
    contribute their bias/drift risk. With no violations anywhere this is a
    plain average.

    Args:
        compliance_events: List of compliance events (violations)
        bias_signals: List of bias signals
        drift_signals: List of drift signals
        agent_action_logs: List of agent action log events; the source of
            the agent names to score
        events_lookup: Lookup dict for events by event_id, used to attribute
            compliance events to agents
        config: Configuration with severity weights

    Returns:
        Dict with:
          * ``agent_scores`` - per-agent score dicts keyed by agent name, in
            order of each agent's first appearance in ``agent_action_logs``;
          * ``overall_risk_score`` - weighted mean rounded to 3 decimals
            (0.0 when there are no agents);
          * ``total_violations``, ``total_bias_signals``,
            ``total_drift_signals`` - raw lengths of the input lists (not
            restricted to the scored agents or to recognized severities).
    """
    # Collect agent names in first-seen order. A dict (rather than a set)
    # keeps the output deterministic: str hashing is randomized per process,
    # so set iteration order -- and hence key order in agent_scores and the
    # float summation order below -- could change from run to run.
    agent_names: Dict[str, None] = {}
    for event in agent_action_logs:
        agent_name = event.get("agent_name")
        if agent_name:
            agent_names.setdefault(agent_name, None)

    agent_scores = {
        agent_name: calculate_agent_risk_score(
            agent_name,
            compliance_events,
            bias_signals,
            drift_signals,
            events_lookup,
            config
        )
        for agent_name in agent_names
    }

    # Weight each agent by 1 + its violation count. Weighting by the raw
    # violation count gave violation-free agents weight 0 as soon as any
    # agent had a violation, discarding their bias/drift risk entirely.
    # As a result, adding a single low-severity violation could *lower* the
    # overall score. The +1 baseline removes that discontinuity and reduces
    # to a simple average when no agent has violations.
    weighted_sum = 0.0
    total_weight = 0
    for score in agent_scores.values():
        weight = 1 + score.get("total_violations", 0)
        weighted_sum += score.get("risk_score", 0.0) * weight
        total_weight += weight

    # total_weight >= number of agents, so it is zero only with no agents.
    overall_risk_score = weighted_sum / total_weight if total_weight else 0.0

    return {
        "agent_scores": agent_scores,
        "overall_risk_score": round(overall_risk_score, 3),
        "total_violations": len(compliance_events),
        "total_bias_signals": len(bias_signals),
        "total_drift_signals": len(drift_signals)
    }

