# In [2]:
#!/usr/bin/env python3
"""
Model Recovery Module for BantAI TravelAware
==========================================

This module provides functionality to analyze and fix potential model corruption issues.
"""

import os
import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler

def analyze_model_file(model_path="bantai_model.pkl"):
    """Check that a saved model bundle loads and has the expected layout.

    A bundle is considered valid when it is a dict that holds a
    ``RandomForestClassifier`` under ``"model"``, a ``StandardScaler``
    under ``"scaler"`` and a list of strings under ``"features"``.

    Args:
        model_path: Path of the joblib file to inspect.

    Returns:
        A ``(is_valid, issues)`` tuple. ``is_valid`` is True only when no
        issue was found; ``issues`` holds one human-readable message per
        problem detected.
    """
    issues = []

    # A missing file is reported on its own; nothing else can be checked.
    if not os.path.exists(model_path):
        issues.append(f"Model file {model_path} not found")
        return False, issues

    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only analyze files that come from a trusted source.
    try:
        model_data = joblib.load(model_path)
    except Exception as e:
        # Deliberately broad: a damaged file can fail to load in many
        # different ways, and every one of them is reported the same way.
        issues.append(f"Error loading model: {str(e)}")
        return False, issues

    # The file loaded but holds something other than a bundle dict. Report
    # that explicitly instead of failing on the ``.get`` calls below and
    # mislabelling the problem as a load error.
    if not isinstance(model_data, dict):
        issues.append(
            "Invalid model file format - expected dict, got "
            f"{type(model_data).__name__}"
        )
        return False, issues

    required_keys = ["model", "scaler", "features"]
    missing_keys = [key for key in required_keys if key not in model_data]
    if missing_keys:
        issues.append(f"Missing required keys in model file: {missing_keys}")

    # A missing key yields None from .get(), which also fails the type
    # checks below, so a missing component is reported by both messages.
    if not isinstance(model_data.get("model"), RandomForestClassifier):
        issues.append("Invalid model type - expected RandomForestClassifier")

    if not isinstance(model_data.get("scaler"), StandardScaler):
        issues.append("Invalid scaler type - expected StandardScaler")

    features = model_data.get("features")
    if not isinstance(features, list) or not all(isinstance(f, str) for f in features):
        issues.append("Invalid features list format")

    return len(issues) == 0, issues

def attempt_model_recovery(model_path="bantai_model.pkl"):
    """Try to repair a model bundle file in place.

    The file is loaded with joblib. If it holds a dict, every component
    that is missing or has the wrong type is replaced by a default and the
    dict is written back. If it holds anything else, the file is
    overwritten with a completely new default bundle.

    The replacement classifier and scaler are freshly constructed and are
    never fitted here, so a recovered bundle has the right structure but
    carries no training.

    Args:
        model_path: Path of the joblib file to repair. It is overwritten
            when any change is made.

    Returns:
        A ``(recovered, notes)`` tuple. ``recovered`` is True when the file
        was rewritten, and False when nothing needed fixing (``notes`` is
        then empty) or when recovery failed (``notes`` then holds a single
        "Recovery failed: ..." message). This function does not raise for
        load or save errors.
    """
    # Defined before the branches below because both the dict-repair path
    # and the full-rebuild path need it.
    default_features = [
        "distance_km", "time_gap_hours", "travel_speed_kmh", "impossible_travel",
        "night_access", "business_hours", "hour_norm", "day_norm", "mobile_device",
        "device_change", "attack_ip", "high_latency", "failed_login", "asn_change",
        "foreign_access", "metro_manila", "rtt_ms"
    ]

    try:
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only recover files that come from a trusted source.
        model_data = joblib.load(model_path)

        recovered = False
        recovery_notes = []

        if isinstance(model_data, dict):
            # 1. Model: .get() returns None for a missing key, which fails
            #    the isinstance check just like a wrong type does.
            if not isinstance(model_data.get("model"), RandomForestClassifier):
                model_data["model"] = RandomForestClassifier(n_estimators=200, random_state=42)
                recovery_notes.append("Replaced invalid model with new RandomForestClassifier")
                recovered = True

            # 2. Scaler
            if not isinstance(model_data.get("scaler"), StandardScaler):
                model_data["scaler"] = StandardScaler()
                recovery_notes.append("Replaced invalid scaler with new StandardScaler")
                recovered = True

            # 3. Features: must be a list of strings, the same rule that
            #    analyze_model_file applies, so that anything the analyzer
            #    flags is actually repaired here.
            features = model_data.get("features")
            if not isinstance(features, list) or not all(isinstance(f, str) for f in features):
                model_data["features"] = list(default_features)
                recovery_notes.append("Restored default feature list")
                recovered = True

            # Only touch the file on disk when something was changed.
            if recovered:
                joblib.dump(model_data, model_path)
                recovery_notes.append(f"Saved recovered model to {model_path}")

        else:
            # The payload is not a bundle dict at all: start from scratch.
            new_model_data = {
                "model": RandomForestClassifier(n_estimators=200, random_state=42),
                "scaler": StandardScaler(),
                "features": list(default_features)
            }
            joblib.dump(new_model_data, model_path)
            recovered = True
            recovery_notes.append("Created new model file with default configuration")

        return recovered, recovery_notes

    except Exception as e:
        # Deliberately broad: failures are reported through the return
        # value so that callers never have to handle load/save exceptions.
        return False, [f"Recovery failed: {str(e)}"]

def comprehensive_model_fix(model_path="bantai_model.pkl"):
    """Analyze a model file and attempt recovery if it is invalid.

    Args:
        model_path: Path of the model file to check and, if needed, repair.
            Defaults to ``"bantai_model.pkl"``.

    Returns:
        A dict with these keys:

        * ``"initial_status"``: ``"valid"`` or ``"invalid"`` before recovery.
        * ``"issues_found"``: issues reported by the first analysis.
        * ``"recovery_attempted"``: True when the first analysis failed.
        * ``"recovery_successful"``: the flag returned by the recovery step.
        * ``"recovery_notes"``: notes returned by the recovery step.
        * ``"final_status"``: ``"valid"`` or ``"invalid"`` after recovery.
        * ``"remaining_issues"``: present only when the file is still
          invalid after a recovery attempt.
    """
    # Step 1: analyze the current state of the file.
    is_valid, issues = analyze_model_file(model_path)
    status = "valid" if is_valid else "invalid"

    result = {
        "initial_status": status,
        "issues_found": issues,
        "recovery_attempted": False,
        "recovery_successful": False,
        "recovery_notes": [],
        "final_status": status
    }

    # Nothing to repair: report the file as it is.
    if is_valid:
        return result

    # Step 2: attempt recovery.
    result["recovery_attempted"] = True
    recovered, notes = attempt_model_recovery(model_path)
    result["recovery_successful"] = recovered
    result["recovery_notes"] = notes

    # Step 3: re-analyze, because the recovery flag alone does not prove
    # that the file on disk now passes validation.
    final_valid, final_issues = analyze_model_file(model_path)
    result["final_status"] = "valid" if final_valid else "invalid"
    if not final_valid:
        result["remaining_issues"] = final_issues

    return result