# In [None]:
#!/usr/bin/env python3
"""
BantAI TravelAware - Advanced Fraud Detection System
====================================================

A comprehensive fraud detection system that combines:
- Geographic travel analysis with impossible travel detection
- Machine learning behavioral pattern recognition
- Technical indicators analysis
- Unified risk scoring system

Author: BantAI Team
Version: 1.0
"""

import os
import json
import time
import math
import joblib
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, roc_auc_score, precision_recall_curve, f1_score
from geopy.geocoders import Nominatim
from geopy.distance import geodesic
from sklearn.preprocessing import StandardScaler

# Optional libraries
try:
    from imblearn.over_sampling import SMOTE
    IMB_AVAILABLE = True
except Exception:
    IMB_AVAILABLE = False

# tqdm fallback
try:
    from tqdm import tqdm
except Exception:
    tqdm = lambda iterable, **kwargs: iterable

# plotting libs (optional display)
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")


class BantAI_TravelAware:
    def __init__(self,
                 cache_file="geocache.json",
                 geocode_delay=1.0,
                 ml_model_path="bantai_model.pkl",
                 ml_model_feature_columns=None):
        """
        Set up the detector: configuration, empty ML slots, geocoder and cache.

        cache_file: JSON file used to persist geocoding results
        geocode_delay: pause in seconds applied around geocode requests
        ml_model_path: default path for saving/loading the model bundle
        ml_model_feature_columns: optional list of ML feature names; it is
            replaced after a model is trained or loaded
        """
        # Plain configuration first (the cache path is needed by _load_cache below).
        self.cache_file = cache_file
        self.geocode_delay = geocode_delay
        self.ml_model_path = ml_model_path
        self.ml_feature_columns = ml_model_feature_columns

        # ML artefacts stay empty until train_model_from_csv() or load_model() runs.
        self.model, self.scaler = None, None

        # Geocoding client, then whatever an earlier run left in the cache file.
        self.geolocator = Nominatim(user_agent="bantai_ai")
        self.geo_cache = self._load_cache()

    # ------------------------
    # Cache handling
    # ------------------------
    def _load_cache(self):
        if os.path.exists(self.cache_file):
            try:
                with open(self.cache_file, "r", encoding="utf-8") as f:
                    return json.load(f)
            except Exception:
                return {}
        return {}

    def _save_cache(self):
        with open(self.cache_file, "w", encoding="utf-8") as f:
            json.dump(self.geo_cache, f)

    # ------------------------
    # Geocoding with cache + rate limit
    # ------------------------
    def get_coordinates(self, city, country):
        """Return (lat, lon) or None. Skips invalid inputs."""
        if pd.isna(city) or pd.isna(country):
            return None
        if str(city).strip() == "-" or str(country).strip() == "-":
            return None

        key = f"{city.strip()},{country.strip()}"
        if key in self.geo_cache:
            val = self.geo_cache[key]
            return tuple(val) if val is not None else None

        # Not in cache -> geocode
        try:
            location = self.geolocator.geocode(key, timeout=10)
            time.sleep(self.geocode_delay)  # rate limit
            if location:
                coords = (location.latitude, location.longitude)
                self.geo_cache[key] = coords
                self._save_cache()
                return coords
        except Exception as e:
            # store None to avoid repeated attempts
            print(f"⚠️ Geocoding error for {key}: {e}")
        self.geo_cache[key] = None
        self._save_cache()
        return None

    def precompute_all_coords(self, df):
        """Precompute coordinates for all unique (city,country) pairs in df."""
        if not {"city", "country"}.issubset(set(df.columns)):
            return
        unique_locations = df[["city", "country"]].drop_duplicates()
        print(f"🔍 Precomputing coordinates for {len(unique_locations)} unique city-country pairs...")
        for _, row in tqdm(unique_locations.iterrows(), total=len(unique_locations)):
            self.get_coordinates(row["city"], row["country"])
        print("✅ All coordinates cached!")

    # ------------------------
    # Travel & distance helpers
    # ------------------------
    def compute_distance_km(self, city1, country1, city2, country2):
        c1 = self.get_coordinates(city1, country1)
        c2 = self.get_coordinates(city2, country2)
        if c1 and c2:
            try:
                return geodesic(c1, c2).kilometers
            except Exception:
                return 0.0
        return 0.0

    def travel_plausibility(self, last_time, last_city, last_country, curr_time, curr_city, curr_country):
        """
        Returns dict with distance_km, time_gap_hours, travel_speed_kmh, plausible(bool), travel_risk (0-1)
        """
        distance_km = self.compute_distance_km(last_city, last_country, curr_city, curr_country)
        time_gap_hours = (curr_time - last_time).total_seconds() / 3600.0 if curr_time and last_time else np.inf

        max_speed = 900.0  # km/h for commercial aircraft
        min_travel_time = distance_km / max_speed if max_speed > 0 else np.inf
        buffer = 4.0  # hours for airport/layovers
        required = min_travel_time + buffer

        if distance_km == 0:
            plausible = True
            travel_risk = 0.0
        elif time_gap_hours >= required:
            plausible = True
            travel_risk = 0.0
        else:
            plausible = False
            # risk modifier proportional to how impossible
            ratio = (required - time_gap_hours) / (required + 1e-6)
            travel_risk = min(1.0, 0.6 + 0.4 * ratio)  # base high risk for impossible travel

        return {
            "distance_km": distance_km,
            "time_gap_hours": time_gap_hours,
            "travel_speed_kmh": (distance_km / time_gap_hours) if time_gap_hours > 0 else 0.0,
            "plausible": plausible,
            "travel_risk": travel_risk,
            "required_hours": required
        }

    # ------------------------
    # Behavioral consistency
    # ------------------------
    def behavioral_consistency_score(self, user_history_df, current_row):
        """
        user_history_df: DataFrame of previous logins for the same user (sorted by time)
        current_row: a dict/Series for current login
        Returns score in [0,1] where 1 = fully consistent.
        """
        # Default score
        score = 1.0
        factors = []

        # Device consistency
        prev_devices = user_history_df["Device Type"].dropna().unique().tolist() if "Device Type" in user_history_df else []
        curr_device = current_row.get("Device Type", None)
        if curr_device in prev_devices:
            factors.append("Known device")
        else:
            score -= 0.25
            factors.append("New device")

        # Time-of-day consistency
        if "Login Timestamp" in user_history_df:
            hist_hours = pd.to_datetime(user_history_df["Login Timestamp"]).dt.hour.dropna().tolist()
            if len(hist_hours) > 0:
                curr_hour = pd.to_datetime(current_row["Login Timestamp"]).hour
                avg = np.mean(hist_hours)
                if abs(curr_hour - avg) <= 3:
                    factors.append("Time consistent")
                else:
                    score -= 0.15
                    factors.append("Unusual hour")
        else:
            factors.append("No time history")

        # Location familiarity
        prev_countries = user_history_df["Country"].dropna().unique().tolist() if "Country" in user_history_df else []
        if current_row.get("Country") in prev_countries:
            factors.append("Known country")
            score += 0.05
        else:
            factors.append("New country")

        # Device/browser fingerprint approximation (based on UA string change)
        prev_ua = user_history_df["User Agent String"].dropna().astype(str).unique().tolist() if "User Agent String" in user_history_df else []
        curr_ua = str(current_row.get("User Agent String", ""))
        if any(curr_ua == ua for ua in prev_ua):
            factors.append("Same UA")
        else:
            score -= 0.1
            factors.append("UA changed")

        # Clamp
        score = max(0.0, min(1.0, score))
        return {"consistency_score": score, "factors": factors}

    # ------------------------
    # Technical indicator scoring
    # ------------------------
    def technical_score(self, row):
        """
        Build technical risk score in [0,1] based on available indicators.
        Use Round-Trip Time (ms) as high_latency, Is Attack IP, Login Successful.
        """
        score = 0.0
        factors = []

        # Attack IP
        attack_ip = False
        if "Is Attack IP" in row:
            val = row["Is Attack IP"]
            # Support both boolean and strings
            if isinstance(val, str):
                attack_ip = val.strip().upper() in ["TRUE", "1", "YES"]
            else:
                attack_ip = bool(val)
        if attack_ip:
            score += 0.6
            factors.append("Known attack IP")

        # High latency
        if "Round-Trip Time [ms]" in row:
            try:
                rtt = float(row["Round-Trip Time [ms]"])
                if rtt > 500:  # threshold for suspiciously high latency
                    score += 0.25
                    factors.append("High RTT")
            except Exception:
                pass

        # Failed login
        if "Login Successful" in row:
            ok = row["Login Successful"]
            # handle string forms
            if isinstance(ok, str):
                success = ok.strip().upper() in ["TRUE", "1", "YES"]
            else:
                success = bool(ok)
            if not success:
                score += 0.35
                factors.append("Failed login")

        # normalize to [0,1]
        score = min(1.0, score)
        return {"technical_score": score, "factors": factors}

    # ------------------------
    # Feature engineering for ML
    # ------------------------
    def build_ml_features_for_df(self, df):
        """
        Build ML feature matrix from chronological df.
        Expects df to contain: time (parsed), Country, City, Device Type, Round-Trip Time [ms], Is Attack IP, Login Successful, User Agent String, ASN
        Returns X (DataFrame), y (Series)
        Each sample corresponds to a login event with features computed w.r.t. previous event of same user.
        """
        rows = []
        labels = []
        groupby_user = "User ID" if "User ID" in df.columns else None

        if groupby_user:
            for user, group in df.groupby(groupby_user):
                group = group.sort_values("time")
                if len(group) < 2:
                    continue
                prev = None
                for idx, row in group.iterrows():
                    if prev is None:
                        prev = row
                        continue
                    curr = row
                    # Features: travel-based
                    dist = self.compute_distance_km(prev.get("City"), prev.get("Country"), curr.get("City"), curr.get("Country"))
                    time_gap = (pd.to_datetime(curr["time"]) - pd.to_datetime(prev["time"])).total_seconds() / 3600.0
                    travel_speed = dist / time_gap if time_gap > 0 else 0.0
                    impossible_flag = 1 if (not self.travel_plausibility(pd.to_datetime(prev["time"]), prev.get("City"), prev.get("Country"), pd.to_datetime(curr["time"]), curr.get("City"), curr.get("Country"))["plausible"]) else 0

                    # Temporal features
                    hour = pd.to_datetime(curr["time"]).hour
                    night_access = 1 if (hour >= 22 or hour <= 5) else 0
                    business_hours = 1 if (8 <= hour <= 17) else 0
                    hour_norm = hour / 23.0
                    day_norm = pd.to_datetime(curr["time"]).weekday() / 6.0

                    # Device features
                    mobile_device = 1 if isinstance(curr.get("Device Type", ""), str) and "mobile" in curr.get("Device Type", "").lower() else 0
                    device_change = 1 if prev.get("Device Type") != curr.get("Device Type") else 0

                    # Technical features
                    attack_ip = 1 if str(curr.get("Is Attack IP", False)).strip().upper() in ["TRUE", "1", "YES"] else 0
                    try:
                        rtt = float(curr.get("Round-Trip Time [ms]") if curr.get("Round-Trip Time [ms]") is not None else 0.0)
                    except Exception:
                        rtt = 0.0
                    high_latency = 1 if rtt > 500 else 0
                    failed_login = 0
                    if "Login Successful" in curr:
                        v = curr["Login Successful"]
                        if isinstance(v, str):
                            failed_login = 0 if v.strip().upper() in ["TRUE", "1", "YES"] else 1
                        else:
                            failed_login = 0 if bool(v) else 1

                    # ASN change (proxy suspicion)
                    asn_change = 1 if prev.get("ASN") != curr.get("ASN") else 0

                    # Foreign access
                    foreign_access = 0 if str(curr.get("Country", "")).strip().upper() in ["PH", "PHL", "PHILIPPINES"] else 1

                    # Simple metro_manila flag
                    metro_manila = 1 if str(curr.get("City", "")).strip().lower() in ["manila", "quezon city", "makati", "taguig", "pasig"] else 0

                    features = {
                        "distance_km": dist,
                        "time_gap_hours": time_gap,
                        "travel_speed_kmh": travel_speed,
                        "impossible_travel": impossible_flag,
                        "night_access": night_access,
                        "business_hours": business_hours,
                        "hour_norm": hour_norm,
                        "day_norm": day_norm,
                        "mobile_device": mobile_device,
                        "device_change": device_change,
                        "attack_ip": attack_ip,
                        "high_latency": high_latency,
                        "failed_login": failed_login,
                        "asn_change": asn_change,
                        "foreign_access": foreign_access,
                        "metro_manila": metro_manila,
                        "rtt_ms": rtt
                    }

                    rows.append(features)
                    labels.append(1 if str(curr.get("label", False)).strip().upper() in ["TRUE", "1", "YES"] else 0)
                    prev = curr
        else:
            # If no user id, fallback to consecutive global rows (less ideal)
            df = df.sort_values("time")
            prev = None
            for idx, curr in df.iterrows():
                if prev is None:
                    prev = curr
                    continue
                dist = self.compute_distance_km(prev.get("City"), prev.get("Country"), curr.get("City"), curr.get("Country"))
                time_gap = (pd.to_datetime(curr["time"]) - pd.to_datetime(prev["time"])).total_seconds() / 3600.0
                travel_speed = dist / time_gap if time_gap > 0 else 0.0
                impossible_flag = 1 if (not self.travel_plausibility(pd.to_datetime(prev["time"]), prev.get("City"), prev.get("Country"), pd.to_datetime(curr["time"]), curr.get("City"), curr.get("Country"))["plausible"]) else 0
                hour = pd.to_datetime(curr["time"]).hour
                night_access = 1 if (hour >= 22 or hour <= 5) else 0
                business_hours = 1 if (8 <= hour <= 17) else 0
                hour_norm = hour / 23.0
                day_norm = pd.to_datetime(curr["time"]).weekday() / 6.0
                mobile_device = 1 if isinstance(curr.get("Device Type", ""), str) and "mobile" in curr.get("Device Type", "").lower() else 0
                device_change = 1 if prev.get("Device Type") != curr.get("Device Type") else 0
                attack_ip = 1 if str(curr.get("Is Attack IP", False)).strip().upper() in ["TRUE", "1", "YES"] else 0
                try:
                    rtt_val = curr.get("Round-Trip Time [ms]", 0.0)
                    rtt = float(rtt_val) if rtt_val is not None else 0.0
                except Exception:
                    rtt = 0.0
                high_latency = 1 if rtt > 500 else 0
                failed_login = 0
                if "Login Successful" in curr:
                    v = curr["Login Successful"]
                    if isinstance(v, str):
                        failed_login = 0 if v.strip().upper() in ["TRUE", "1", "YES"] else 1
                    else:
                        failed_login = 0 if bool(v) else 1
                asn_change = 1 if prev.get("ASN") != curr.get("ASN") else 0
                foreign_access = 0 if str(curr.get("Country", "")).strip().upper() in ["PH", "PHL", "PHILIPPINES"] else 1
                metro_manila = 1 if str(curr.get("City", "")).strip().lower() in ["manila", "quezon city", "makati", "taguig", "pasig"] else 0

                features = {
                    "distance_km": dist,
                    "time_gap_hours": time_gap,
                    "travel_speed_kmh": travel_speed,
                    "impossible_travel": impossible_flag,
                    "night_access": night_access,
                    "business_hours": business_hours,
                    "hour_norm": hour_norm,
                    "day_norm": day_norm,
                    "mobile_device": mobile_device,
                    "device_change": device_change,
                    "attack_ip": attack_ip,
                    "high_latency": high_latency,
                    "failed_login": failed_login,
                    "asn_change": asn_change,
                    "foreign_access": foreign_access,
                    "metro_manila": metro_manila,
                    "rtt_ms": rtt
                }
                rows.append(features)
                labels.append(1 if str(curr.get("label", False)).strip().upper() in ["TRUE", "1", "YES"] else 0)
                prev = curr

        X = pd.DataFrame(rows)
        y = pd.Series(labels)
        X = X.fillna(0)
        return X, y

    # ------------------------
    # Enhanced Training / evaluation
    # ------------------------
    def _plot_training_diagnostics(self, clf, feature_names, y_test, y_proba, y_pred,
                                   thresholds, f1_scores, optimal_threshold):
        """Print and plot feature importances, ROC / PR / threshold curves and the confusion matrix."""
        # Feature importance analysis
        if hasattr(clf, 'feature_importances_'):
            feature_importance = pd.DataFrame({
                'feature': feature_names,
                'importance': clf.feature_importances_
            }).sort_values('importance', ascending=False)

            print("\n🔍 Top 10 Most Important Features:")
            print(feature_importance.head(10))

            top10 = feature_importance.head(10)
            plt.figure(figsize=(10, 6))
            plt.barh(top10['feature'][::-1], top10['importance'][::-1])
            plt.title('Top 10 Feature Importance')
            plt.xlabel('Importance')
            plt.tight_layout()
            plt.show()

        # Three-panel layout: ROC, precision-recall, threshold sweep
        fpr, tpr, _ = roc_curve(y_test, y_proba)
        roc_auc = auc(fpr, tpr)
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.plot(fpr, tpr, lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate (Recall)')
        plt.title('ROC Curve')
        plt.legend(loc="lower right")
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 3, 2)
        precision, recall, _ = precision_recall_curve(y_test, y_proba)
        plt.plot(recall, precision, color='blue', lw=2)
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 3, 3)
        plt.plot(thresholds, f1_scores, 'g-', lw=2)
        plt.axvline(x=optimal_threshold, color='red', linestyle='--',
                    label=f'Optimal: {optimal_threshold:.2f}')
        plt.xlabel('Threshold')
        plt.ylabel('F1-Score')
        plt.title('Threshold Optimization')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

        # confusion matrix
        cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
        plt.figure(figsize=(6, 4))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=["Legit (Pred)", "Attack (Pred)"],
                    yticklabels=["Legit (True)", "Attack (True)"])
        plt.title(f"Confusion Matrix (Threshold: {optimal_threshold:.2f})")
        plt.ylabel("Actual")
        plt.xlabel("Predicted")
        plt.show()

    def train_model_from_csv(self,
                             csv_path,
                             nrows=20000,
                             label_column="Is Attack IP",
                             use_smote=True,
                             save_model=True,
                             test_size=0.2,
                             random_state=42,
                             threshold=0.5):
        """
        Train the random-forest risk model from a login CSV and evaluate it.

        csv_path: CSV with at least "Login Timestamp", "Country" and label_column
        nrows: number of rows to read from the file
        label_column: column holding the target (renamed to "label" internally)
        use_smote: oversample the training split with SMOTE when imblearn is installed
        save_model: dump model, scaler, feature names and threshold to self.ml_model_path
        test_size / random_state: forwarded to the train/test split (and the estimators)
        threshold: kept for backward compatibility and currently unused; the
            decision threshold is selected by an F1 sweep over the test split

        On success sets self.model, self.scaler, self.ml_feature_columns and
        self.optimal_threshold. Returns None; when no model can be trained
        (no login pairs, or only one class present) a warning is printed and
        self.model is left unchanged.
        Raises ValueError when label_column is not a column of the CSV.
        """
        # 1) read
        df = pd.read_csv(csv_path, nrows=nrows)
        if label_column not in df.columns:
            raise ValueError(f"Label column '{label_column}' not found in {csv_path}")

        # 2) sanitize - the placeholder '-' means "unknown" in Region/City.
        #    NaN is used explicitly as the replacement value to mark it missing.
        for col in ("Region", "City"):
            if col in df.columns:
                df[col] = df[col].replace("-", np.nan)

        # 3) map standard columns
        df = df.rename(columns={"Login Timestamp": "time", label_column: "label"})

        # 4) parse times; rows without a usable time or country are dropped
        df["time"] = pd.to_datetime(df["time"], errors="coerce")
        df = df.dropna(subset=["time", "Country"])

        print(f"📂 Loaded {len(df)} rows. Preview:")
        print(df.head(3))

        # 5) precompute coordinates (the helper expects lowercase column names)
        self.precompute_all_coords(df.rename(columns={"Country": "country", "City": "city"}))

        # 6) create features & labels for ML. No placeholder "User ID" column is
        #    added: without real ids the feature builder pairs consecutive rows.
        X, y = self.build_ml_features_for_df(df)

        if X.shape[0] == 0:
            print("⚠️ Not enough consecutive login pairs to build training examples.")
            return

        print("⚖️ Class distribution before train/test split:")
        print(y.value_counts())

        if y.nunique() < 2:
            print("⚠️ Labels contain a single class; a classifier cannot be trained.")
            return

        # 7) split - stratify when possible, otherwise fall back to a plain split
        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, random_state=random_state, stratify=y)
        except ValueError:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, random_state=random_state)

        if pd.Series(y_train).nunique() < 2:
            print("⚠️ Training split contains a single class; a classifier cannot be trained.")
            return

        # 8) optional SMOTE (training split only)
        if use_smote and IMB_AVAILABLE:
            print("✨ Applying SMOTE oversampling on training set...")
            try:
                X_train, y_train = SMOTE(random_state=random_state).fit_resample(X_train, y_train)
            except ValueError as e:
                # e.g. too few minority samples; continue with the original split
                print(f"⚠️ SMOTE skipped: {e}")
            else:
                print("Post-SMOTE class distribution:", pd.Series(y_train).value_counts())

        # 9) scale numeric features
        self.scaler = StandardScaler()
        X_train_scaled = self.scaler.fit_transform(X_train)
        X_test_scaled = self.scaler.transform(X_test)

        # 10) train
        clf = RandomForestClassifier(n_estimators=200, random_state=random_state, class_weight="balanced")
        clf.fit(X_train_scaled, y_train)

        # 11) evaluation with threshold optimization (best F1 on the test split)
        y_proba = clf.predict_proba(X_test_scaled)[:, 1]
        thresholds = np.arange(0.1, 0.9, 0.05)
        f1_scores = [f1_score(y_test, (y_proba >= t).astype(int)) for t in thresholds]
        optimal_threshold = thresholds[np.argmax(f1_scores)]
        print(f"🎯 Optimal threshold based on F1-score: {optimal_threshold:.2f}")

        y_pred = (y_proba >= optimal_threshold).astype(int)

        print("📊 Model Performance (Optimal Threshold):")
        print(classification_report(y_test, y_pred))
        print(f"AUC = {roc_auc_score(y_test, y_proba):.3f}")

        self._plot_training_diagnostics(clf, X.columns, y_test, y_proba, y_pred,
                                        thresholds, f1_scores, optimal_threshold)

        # 12) keep and optionally save the artefacts
        self.model = clf
        self.ml_feature_columns = X.columns.tolist()
        self.optimal_threshold = float(optimal_threshold)
        if save_model:
            joblib.dump({"model": clf,
                         "scaler": self.scaler,
                         "features": self.ml_feature_columns,
                         "threshold": self.optimal_threshold},
                        self.ml_model_path)
            print(f"✅ Model + scaler saved to {self.ml_model_path}")

    # ------------------------
    # Predict / unified scoring
    # ------------------------
    def load_model(self, model_path=None):
        path = model_path if model_path else self.ml_model_path
        if os.path.exists(path):
            obj = joblib.load(path)
            self.model = obj.get("model", obj)
            self.scaler = obj.get("scaler", self.scaler)
            self.ml_feature_columns = obj.get("features", self.ml_feature_columns)
            print(f"✅ Loaded model from {path}")
        else:
            raise FileNotFoundError(path)

    def ml_risk_for_pair(self, last_login, current_login):
        """Return probability of attack (0-1) from ML model for a given pair (dict-like)."""
        if not self.model:
            raise Exception("Model not trained or loaded.")
        
        dummy_df = pd.DataFrame([{
            "time": current_login.get("time"),
            "User ID": current_login.get("user_id", None),
            "Country": current_login.get("Country", current_login.get("country")),
            "City": current_login.get("City", current_login.get("city")),
            "Device Type": current_login.get("Device Type", current_login.get("device_type")),
            "Round-Trip Time [ms]": current_login.get("Round-Trip Time [ms]", current_login.get("rtt_ms", None)),
            "Is Attack IP": current_login.get("Is Attack IP", current_login.get("is_attack_ip", False)),
            "Login Successful": current_login.get("Login Successful", current_login.get("login_successful", True)),
            "User Agent String": current_login.get("User Agent String", current_login.get("user_agent", "")),
            "ASN": current_login.get("ASN", current_login.get("asn", None)),
            "label": current_login.get("label", False)
        }])
        
        two_rows = pd.concat([
            pd.DataFrame([{
                "time": last_login.get("time"),
                "User ID": last_login.get("user_id", None),
                "Country": last_login.get("Country", last_login.get("country")),
                "City": last_login.get("City", last_login.get("city")),
                "Device Type": last_login.get("Device Type", last_login.get("device_type")),
                "Round-Trip Time [ms]": last_login.get("Round-Trip Time [ms]", last_login.get("rtt_ms", None)),
                "Is Attack IP": last_login.get("Is Attack IP", last_login.get("is_attack_ip", False)),
                "Login Successful": last_login.get("Login Successful", last_login.get("login_successful", True)),
                "User Agent String": last_login.get("User Agent String", last_login.get("user_agent", "")),
                "ASN": last_login.get("ASN", last_login.get("asn", None)),
                "label": last_login.get("label", False)
            }]),
            dummy_df
        ], ignore_index=True)
        X_single, _ = self.build_ml_features_for_df(two_rows)
        if X_single.shape[0] == 0:
            return 0.5
        X_single = X_single.fillna(0)
        X_scaled = self.scaler.transform(X_single)
        proba = self.model.predict_proba(X_scaled)[:, 1][0]
        return float(proba)

    def unified_risk_score(self, user_history_df, last_login, current_login,
                           weights={"ml": 0.4, "travel": 0.3, "behavior": 0.2, "technical": 0.1}):
        """
        Compute the final unified risk score (0-1) and return breakdown + recommended action.
        last_login/current_login: dict-like with keys matching dataset columns (time must be datetime)
        user_history_df: DataFrame of prior logins (may be empty)
        """
        # Prepare datetimes
        last_time = pd.to_datetime(last_login.get("time", last_login.get("Login Timestamp", None)))
        curr_time = pd.to_datetime(current_login.get("time", current_login.get("Login Timestamp", None)))

        # ML risk (probability)
        try:
            ml_prob = self.ml_risk_for_pair(last_login, current_login)
        except Exception:
            ml_prob = 0.5

        # Travel analysis
        travel_info = self.travel_plausibility(last_time,
                                              last_login.get("City", last_login.get("city")),
                                              last_login.get("Country", last_login.get("country")),
                                              curr_time,
                                              current_login.get("City", current_login.get("city")),
                                              current_login.get("Country", current_login.get("country")))
        travel_risk = travel_info["travel_risk"]

        # Behavioral
        behavior_info = self.behavioral_consistency_score(user_history_df, current_login)
        behavior_risk = 1.0 - behavior_info["consistency_score"]

        # Technical
        tech_info = self.technical_score(current_login)
        technical_risk = tech_info["technical_score"]

        # Combine
        final = (weights["ml"] * ml_prob +
                 weights["travel"] * travel_risk +
                 weights["behavior"] * behavior_risk +
                 weights["technical"] * technical_risk)
        final = max(0.0, min(1.0, final))

        # Decide action
        if final < 0.3:
            level, action = "LOW", "ALLOW"
        elif final < 0.6:
            level, action = "MEDIUM", "ALLOW_WITH_OTP"
        else:
            # if travel impossible, prefer BLOCK
            action = "BLOCK" if not travel_info["plausible"] else "STRICT_VERIFICATION"
            level = "HIGH"

        return {
            "final_risk_score": final,
            "risk_level": level,
            "action": action,
            "components": {
                "ml_prob": ml_prob,
                "travel_risk": travel_risk,
                "behavior_risk": behavior_risk,
                "technical_risk": technical_risk
            },
            "travel_info": travel_info,
            "behavior_info": behavior_info,
            "technical_info": tech_info
        }


if __name__ == "__main__":
    # Configuration - change these paths as needed. The dataset location can
    # also be supplied through the BANTAI_DATASET environment variable.
    dataset_path = os.environ.get(
        "BANTAI_DATASET",
        r"C:\Users\Brando\Desktop\School\Project\BantAI_Datawave\rba-dataset.csv",
    )
    model_path = "bantai_model.pkl"

    if not os.path.exists(dataset_path):
        print(f"❌ Dataset not found: {dataset_path}")
    else:
        # Initialize BantAI system
        print("🚀 Initializing BantAI TravelAware Fraud Detection System...")
        bantai = BantAI_TravelAware(
            cache_file="geocache.json",
            ml_model_path=model_path,
            geocode_delay=1.0
        )

        # Train ML component on dataset subset
        print("📚 Training ML model on dataset...")
        # label_column can be "Is Attack IP" or "Is Account Takeover"
        bantai.train_model_from_csv(
            dataset_path,
            nrows=20000,  # Use subset for faster training
            label_column="Is Attack IP",
            use_smote=True,
            save_model=True,
            threshold=0.5  # This will be optimized automatically
        )

        # train_model_from_csv() returns early without setting a model when it
        # cannot build training examples, so only report success if one exists.
        if bantai.model is None:
            print("⚠️ Training did not produce a model; see the messages above.")
        else:
            print("✅ Training completed! Model saved and ready for fraud detection.")


    #20k sweet 78
    #50k sweet 58