In [None]:
# risk_assessment.py

import datetime
import json
import logging
import math
import os
import statistics
import time
import urllib.request
import uuid

import numpy as np  # Make sure to install numpy if not already installed
import psutil
from flask import Flask, request, jsonify

# The Flask application object; the endpoint functions register their routes on it.
app = Flask(import_name=__name__)

class RiskAssessment:
    """Score anomalies per entity and track each entity's risk level.

    Every recorded anomaly gets a ``Threat_Risk`` score: a weighted sum of
    normalised detector confidence (C), attack-type criticality (A), network
    segment criticality (S) and the entity's anomaly count relative to all
    tracked entities (P).  ``determine_risk_level`` maps that score to
    'normal', 'low risk', 'high risk' or 'critical'.  Elevated levels fall
    back to 'normal' after enough anomaly-free rounds (``update_no_anomaly``).
    State is persisted as JSON in ``data_file`` after every change.
    """

    # File the per-entity state is persisted to; relative paths resolve
    # against the current working directory.  Can be overridden per instance.
    data_file = 'entity_risks.json'

    # Consecutive anomaly-free rounds after which an elevated risk level is
    # reset to 'normal' by update_no_anomaly().  Levels not listed never decay.
    downgrade_rounds = {'low risk': 1, 'high risk': 2, 'critical': 3}

    # Seconds to wait for the Policy Engine before a notification is abandoned.
    pe_timeout = 10

    def __init__(self, pe_url=None, data_file=None):
        """Create the assessor and load any previously persisted state.

        Args:
            pe_url: Base URL of the Policy Engine notified on 'high risk' /
                'critical' updates; a falsy value disables notifications.
            data_file: Optional path overriding the class-level ``data_file``.
        """
        # Per-entity state: {entity_id: {'risk_level', 'no_anomaly_rounds',
        # 'history', 'past_anomalies_count'}}
        self.entity_risks = {}
        # Attack type -> criticality (A) on a 1-10 scale.
        self.attack_criticality = {
            'privilege_escalation': 9,
            'phishing': 7,
            'lateral_movement': 8,
            'data_exfiltration': 10,
            # Add more attack types as needed
        }
        # Entity -> network segment.
        self.entity_segments = {
            'iot_device1': 'production_network',
            'iot_device2': 'production_network',
            'user': 'user_network',
            'server': 'data_center',
            # Add more entities as needed
        }
        # Segment -> criticality (S) on a 1-10 scale.
        self.segment_criticality = {
            'production_network': 10,
            'data_center': 10,
            'user_network': 5,
            'default_segment': 5,
            # Add more segments as needed
        }
        self.pe_url = pe_url  # URL of the Policy Engine
        if data_file is not None:
            self.data_file = data_file
        # Configure logging first so problems while loading can be logged too.
        self.setup_logging()
        self.load_data()

    def setup_logging(self):
        """Send root-logger INFO+ records to ``risk_assessment.log``."""
        logging.basicConfig(
            filename='risk_assessment.log',
            level=logging.INFO,
            format='%(asctime)s %(levelname)s: %(message)s'
        )

    def load_data(self):
        """Load ``entity_risks`` from ``data_file``; start empty if unavailable.

        A missing file is the normal first-run case.  An unreadable file, or
        one that does not hold a JSON object, is reported and ignored.
        """
        try:
            with open(self.data_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except FileNotFoundError:
            # File doesn't exist yet
            self.entity_risks = {}
            return
        except Exception as e:
            print(f"Error loading data: {e}")
            logging.error(f"Error loading data from {self.data_file}: {e}")
            self.entity_risks = {}
            return

        if not isinstance(loaded, dict):
            print(f"Error loading data: expected a JSON object in {self.data_file}")
            logging.error(f"Error loading data from {self.data_file}: not a JSON object")
            self.entity_risks = {}
            return

        self.entity_risks = loaded
        print(f"Data loaded from {self.data_file}")

    def save_data(self):
        """Persist ``entity_risks`` to ``data_file`` atomically (best effort).

        The JSON is written to a uniquely named sibling temp file and then
        moved over the target with ``os.replace``.  A crash mid-write
        therefore cannot leave a truncated file that ``load_data`` would
        discard.  Errors are reported, not raised.
        """
        tmp_path = f"{self.data_file}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.entity_risks, f)
            os.replace(tmp_path, self.data_file)
        except Exception as e:
            print(f"Error saving data: {e}")
            logging.error(f"Error saving data to {self.data_file}: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # temp file was never created or is already gone
            return
        print(f"Data saved to {self.data_file}")

    @staticmethod
    def _perf_start():
        """Return a (wall-clock, CPU-times) snapshot for ``_perf_elapsed``."""
        return time.perf_counter(), os.times()

    @staticmethod
    def _perf_elapsed(start):
        """Return (processing, user CPU, system CPU, total CPU) seconds since *start*.

        ``time.perf_counter`` is monotonic, so clock adjustments cannot
        produce negative durations.  ``os.times()`` reports this process's
        user and system CPU time.
        """
        start_wall, start_cpu = start
        end_wall = time.perf_counter()
        end_cpu = os.times()
        user = end_cpu.user - start_cpu.user
        system = end_cpu.system - start_cpu.system
        return end_wall - start_wall, user, system, user + system

    def record_anomaly(self, entity_id, attack_type, confidence, request_id):
        """Record an anomaly for *entity_id*, rescore it and return its new risk level.

        Threat_Risk = 0.25*C + 0.35*A + 0.20*S + 0.20*P, where
          C = confidence / 100, clamped to [0, 1];
          A = (attack criticality - 1) / 9, with unknown attack types scoring 5;
          S = (segment criticality - 1) / 9, with unknown segments scoring 5;
          P = (log(count + 1) - median) / IQR over all tracked entities'
              log(count + 1) values, or 0 when the IQR is 0.

        The anomaly is appended to the entity's history and the state is
        saved.  The Policy Engine is notified when the new level is
        'high risk' or 'critical' and ``pe_url`` is set.

        Args:
            entity_id: Identifier of the affected entity.
            attack_type: Key into ``attack_criticality``.
            confidence: Detector confidence on a 0-100 scale.
            request_id: Correlation id used in log lines and PE notifications.

        Returns:
            The entity's new risk level.

        Raises:
            ValueError: If *confidence* is not a finite number.  The check
                runs before any state is modified.
        """
        if not math.isfinite(confidence):
            raise ValueError(f"confidence must be a finite number, got {confidence!r}")

        perf = self._perf_start()

        record = self.entity_risks.get(entity_id)
        if record is None:
            record = {
                'risk_level': 'normal',
                'no_anomaly_rounds': 0,
                'history': [],
                'past_anomalies_count': 0,
            }
            self.entity_risks[entity_id] = record
        else:
            # A fresh anomaly restarts the anomaly-free countdown.
            record['no_anomaly_rounds'] = 0
        record['past_anomalies_count'] += 1

        # Confidence of threat (C): clamped so out-of-range input cannot push
        # the score beyond the 0-1 scale the other factors use.
        C_norm = min(max(confidence / 100, 0.0), 1.0)

        # Attack criticality (A), min-max normalised from the 1-10 scale.
        A = self.attack_criticality.get(attack_type, 5)
        A_norm = (A - 1) / (10 - 1)

        # Segment criticality (S), min-max normalised from the 1-10 scale.
        segment = self.entity_segments.get(entity_id, 'default_segment')
        S = self.segment_criticality.get(segment, 5)
        S_norm = (S - 1) / (10 - 1)

        # Past anomalies (P): robust z-score (median / IQR) of log(count + 1)
        # against every tracked entity, this one included.
        P_log = math.log(record['past_anomalies_count'] + 1)
        P_log_values = [
            math.log(r['past_anomalies_count'] + 1) for r in self.entity_risks.values()
        ]
        median_P_log = statistics.median(P_log_values)
        Q1, Q3 = np.percentile(P_log_values, [25, 75])
        IQR_P_log = float(Q3 - Q1)
        # A zero IQR (a single entity, or identical counts) would divide by zero.
        P_norm = 0.0 if IQR_P_log == 0 else (P_log - median_P_log) / IQR_P_log

        # Weights
        w_C, w_A, w_S, w_P = 0.25, 0.35, 0.20, 0.20
        Threat_Risk = (w_C * C_norm) + (w_A * A_norm) + (w_S * S_norm) + (w_P * P_norm)

        risk_level = self.determine_risk_level(Threat_Risk)
        previous_risk_level = record['risk_level']
        record['risk_level'] = risk_level

        record['history'].append({
            'timestamp': datetime.datetime.now().isoformat(),
            'attack_type': attack_type,
            'confidence': confidence,
            'C_norm': C_norm,
            'A_norm': A_norm,
            'S_norm': S_norm,
            'P_norm': P_norm,
            'Threat_Risk': Threat_Risk,
            'risk_level': risk_level
        })

        print(f"Entity '{entity_id}' risk level updated from '{previous_risk_level}' to '{risk_level}'")

        self.save_data()

        if self.pe_url and risk_level in ('high risk', 'critical'):
            self.notify_pe_risk_update(entity_id, risk_level, request_id)

        processing_time, user_cpu_time, system_cpu_time, total_cpu_time = self._perf_elapsed(perf)
        logging.info(
            f"Request ID: {request_id} | Entity ID: {entity_id} | Risk Level Updated: {risk_level} | "
            f"Processing Time: {processing_time:.6f}s | CPU Time: User={user_cpu_time:.6f}s "
            f"System={system_cpu_time:.6f}s Total={total_cpu_time:.6f}s"
        )

        return risk_level

    def update_no_anomaly(self, request_id):
        """Advance one anomaly-free round for every entity not at 'normal'.

        Each such entity's ``no_anomaly_rounds`` is incremented.  Once it
        reaches ``downgrade_rounds[level]``, the level is reset to 'normal'
        and the counter to 0.  State is saved afterwards and timing is logged.
        """
        perf = self._perf_start()

        for entity_id, data in self.entity_risks.items():
            risk_level = data['risk_level']
            if risk_level == 'normal':
                continue  # nothing to decay

            data['no_anomaly_rounds'] += 1
            no_anomaly_rounds = data['no_anomaly_rounds']
            required = self.downgrade_rounds.get(risk_level)

            if required is not None and no_anomaly_rounds >= required:
                data['risk_level'] = 'normal'
                data['no_anomaly_rounds'] = 0
                print(f"Entity '{entity_id}' risk level decreased to 'normal'")
            else:
                print(f"Entity '{entity_id}' risk level remains at '{risk_level}', no_anomaly_rounds: {no_anomaly_rounds}")

        self.save_data()

        processing_time, user_cpu_time, system_cpu_time, total_cpu_time = self._perf_elapsed(perf)
        logging.info(
            f"Request ID: {request_id} | Action: Update No Anomaly | "
            f"Processing Time: {processing_time:.6f}s | CPU Time: User={user_cpu_time:.6f}s "
            f"System={system_cpu_time:.6f}s Total={total_cpu_time:.6f}s"
        )

    def determine_risk_level(self, Threat_Risk):
        """Map a Threat_Risk score to a level: >=0.8 critical, >=0.6 high, >=0.4 low."""
        if Threat_Risk >= 0.8:
            return 'critical'
        elif Threat_Risk >= 0.6:
            return 'high risk'
        elif Threat_Risk >= 0.4:
            return 'low risk'
        else:
            return 'normal'

    def get_risk_status(self, entity_id, request_id):
        """Return *entity_id*'s current risk level ('normal' if untracked) and log timing."""
        perf = self._perf_start()

        risk_level = self.entity_risks.get(entity_id, {'risk_level': 'normal'})['risk_level']

        processing_time, user_cpu_time, system_cpu_time, total_cpu_time = self._perf_elapsed(perf)
        logging.info(
            f"Request ID: {request_id} | Entity ID: {entity_id} | Action: Get Risk Status | "
            f"Risk Level: {risk_level} | Processing Time: {processing_time:.6f}s | "
            f"CPU Time: User={user_cpu_time:.6f}s System={system_cpu_time:.6f}s "
            f"Total={total_cpu_time:.6f}s"
        )

        return risk_level

    def get_entity_history(self, entity_id, request_id):
        """Return the stored anomaly history list for *entity_id* ([] if untracked).

        *request_id* is accepted for signature symmetry with the other
        methods but is not used.
        """
        return self.entity_risks.get(entity_id, {}).get('history', [])

    def notify_pe_risk_update(self, entity_id, risk_level, request_id):
        """POST the new risk level to ``{pe_url}/receive_risk_update`` (best effort).

        Sends a JSON body with 'entity_id', 'risk_level' and 'request_id',
        using only the standard library and a ``pe_timeout`` second timeout.
        Any failure (bad URL, connection error, timeout, non-2xx response) is
        logged and never raised, so scoring cannot fail because the PE is down.
        """
        payload = {
            'entity_id': entity_id,
            'risk_level': risk_level,
            'request_id': request_id
        }
        try:
            req = urllib.request.Request(
                f'{self.pe_url}/receive_risk_update',
                data=json.dumps(payload).encode('utf-8'),
                headers={'Content-Type': 'application/json'},
                method='POST',
            )
            # urlopen raises urllib.error.HTTPError for error status codes,
            # playing the role of requests' raise_for_status().
            with urllib.request.urlopen(req, timeout=self.pe_timeout) as response:
                response.read()
            logging.info(f"Request ID: {request_id} | RiskAssessment: Notified PE of risk update for entity '{entity_id}' to '{risk_level}'.")
        except Exception as e:
            logging.error(f"Request ID: {request_id} | RiskAssessment Error: Failed to notify PE of risk update: {e}")

# Flask API Endpoints

# Shared RiskAssessment used by the endpoint functions.  It stays None until
# the __main__ section creates it, so importing this module does not by
# itself start the service.
RiskAssessment_instance = None

@app.route('/record_anomaly', methods=['POST'])
def record_anomaly_endpoint():
    """
    API endpoint to record an anomaly.

    Expects a JSON object with 'entity_id' (non-empty string), 'attack_type',
    'confidence' (number on a 0-100 scale, default 0) and an optional
    'request_id'.  A UUID4 is generated when 'request_id' is absent or null.

    Returns:
        200 with {'request_id', 'entity_id', 'risk_level'} on success,
        400 with {'error'} for a malformed body, missing entity_id or bad confidence,
        503 if the shared RiskAssessment instance has not been created,
        500 with {'error'} for unexpected failures.
    """
    if RiskAssessment_instance is None:
        # Happens when the module is served without running the __main__ section.
        return jsonify({'error': 'Risk assessment service is not initialized'}), 503

    # silent=True: a missing or non-JSON body yields None instead of raising,
    # so it is reported as a client error rather than a generic 500.
    json_data = request.get_json(silent=True)
    if not isinstance(json_data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    entity_id = json_data.get('entity_id')
    # Entity ids become JSON object keys when persisted, so only strings
    # round-trip unchanged across restarts.
    if not isinstance(entity_id, str) or not entity_id:
        return jsonify({'error': "'entity_id' must be a non-empty string"}), 400

    attack_type = json_data.get('attack_type')
    request_id = json_data.get('request_id') or str(uuid.uuid4())

    try:
        confidence = float(json_data.get('confidence', 0))
    except (TypeError, ValueError):
        return jsonify({'error': "'confidence' must be a number"}), 400
    if not math.isfinite(confidence):
        return jsonify({'error': "'confidence' must be a finite number"}), 400

    try:
        risk_level = RiskAssessment_instance.record_anomaly(entity_id, attack_type, confidence, request_id)
    except Exception as e:
        logging.exception(f"Error in record_anomaly_endpoint: {e}")
        return jsonify({'error': str(e)}), 500

    response = {
        'request_id': request_id,
        'entity_id': entity_id,
        'risk_level': risk_level
    }
    return jsonify(response), 200

@app.route('/update_no_anomaly', methods=['POST'])
def update_no_anomaly_endpoint():
    """
    API endpoint to update the system when no anomaly is detected.

    The JSON body is optional and may carry a 'request_id'.  A UUID4 is
    generated when the id is absent or null, or when no JSON body is sent.

    Returns:
        200 with {'request_id', 'status'} on success,
        503 if the shared RiskAssessment instance has not been created,
        500 with {'error'} for unexpected failures.
    """
    if RiskAssessment_instance is None:
        # Happens when the module is served without running the __main__ section.
        return jsonify({'error': 'Risk assessment service is not initialized'}), 503

    # silent=True: an empty or non-JSON body yields None instead of raising,
    # which matters because the body is optional for this endpoint.
    json_data = request.get_json(silent=True)
    if not isinstance(json_data, dict):
        json_data = {}
    request_id = json_data.get('request_id') or str(uuid.uuid4())

    try:
        RiskAssessment_instance.update_no_anomaly(request_id)
    except Exception as e:
        logging.exception(f"Error in update_no_anomaly_endpoint: {e}")
        return jsonify({'error': str(e)}), 500

    response = {
        'request_id': request_id,
        'status': 'No anomaly update processed successfully'
    }
    return jsonify(response), 200

@app.route('/get_risk_status', methods=['POST'])
def get_risk_status_endpoint():
    """
    API endpoint to get the risk status of an entity.

    Expects a JSON object with 'entity_id' (non-empty string) and an optional
    'request_id'.  A UUID4 is generated when 'request_id' is absent or null.

    Returns:
        200 with {'request_id', 'entity_id', 'risk_level'} on success,
        400 with {'error'} for a malformed body or missing entity_id,
        503 if the shared RiskAssessment instance has not been created,
        500 with {'error'} for unexpected failures.
    """
    if RiskAssessment_instance is None:
        # Happens when the module is served without running the __main__ section.
        return jsonify({'error': 'Risk assessment service is not initialized'}), 503

    # silent=True: a missing or non-JSON body yields None instead of raising,
    # so it is reported as a client error rather than a generic 500.
    json_data = request.get_json(silent=True)
    if not isinstance(json_data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    entity_id = json_data.get('entity_id')
    # Without this check a missing id would silently report 'normal'.
    if not isinstance(entity_id, str) or not entity_id:
        return jsonify({'error': "'entity_id' must be a non-empty string"}), 400
    request_id = json_data.get('request_id') or str(uuid.uuid4())

    try:
        risk_level = RiskAssessment_instance.get_risk_status(entity_id, request_id)
    except Exception as e:
        logging.exception(f"Error in get_risk_status_endpoint: {e}")
        return jsonify({'error': str(e)}), 500

    response = {
        'request_id': request_id,
        'entity_id': entity_id,
        'risk_level': risk_level
    }
    return jsonify(response), 200

if __name__ == '__main__':
    # Policy Engine base URL.  The PE_URL environment variable overrides the
    # default; setting it to an empty string disables PE notifications.
    pe_url = os.environ.get('PE_URL', 'http://192.52.33.4:5000')

    # Create the shared instance used by the Flask endpoints.
    RiskAssessment_instance = RiskAssessment(pe_url=pe_url)

    # Bind address and port default to all interfaces on port 5000.
    # RA_HOST / RA_PORT override them.
    # NOTE: app.run starts Flask's development server; use a production WSGI
    # server when exposing this beyond a trusted network.
    host = os.environ.get('RA_HOST', '0.0.0.0')
    port = int(os.environ.get('RA_PORT', '5000'))
    app.run(host=host, port=port)