# In [None]:
import os
import json
import numpy as np
import tensorflow as tf
from datetime import datetime, timedelta
from sklearn.preprocessing import MinMaxScaler
import random

class BandwidthPredictor:
    """
    Predict bandwidth demand per service group using TFLite models.

    On construction the predictor reads JSON metadata (service mappings,
    per-group statistics, feature list and scaler parameters) from
    ``metadata_dir`` and loads every ``*.tflite`` file found in ``models_dir``.

    Inputs the caller does not supply (usage percentage, network metrics) are
    synthesised with the ``random`` module, so results are only reproducible
    if the caller seeds ``random`` beforehand.
    """

    # Day periods tested in this order; an hour that matches none of them is
    # treated as "night" (whose 23 -> 5 range wraps around midnight).
    _DAY_PERIODS = ("morning", "workday", "evening")

    # Per-period entries are (start_hour, end_hour, min_usage_pct, max_usage_pct);
    # "weekend_boost" is added to both bounds on Saturdays and Sundays.
    _DEFAULT_USAGE_PATTERN = {
        "morning": (5, 9, 40, 60),
        "workday": (9, 17, 50, 70),
        "evening": (17, 23, 60, 80),
        "night": (23, 5, 30, 50),
        "weekend_boost": 5,
    }
    _USAGE_PATTERNS = {
        "Streaming": {
            "morning": (5, 9, 30, 50),
            "workday": (9, 17, 40, 60),
            "evening": (17, 23, 70, 90),
            "night": (23, 5, 50, 70),
            "weekend_boost": 10,
        },
        "Gaming": {
            "morning": (5, 9, 20, 40),
            "workday": (9, 17, 30, 50),
            "evening": (17, 23, 75, 95),
            "night": (23, 5, 60, 80),
            "weekend_boost": 15,
        },
        "Social Media": {
            "morning": (5, 9, 50, 70),
            "workday": (9, 17, 60, 80),
            "evening": (17, 23, 70, 90),
            "night": (23, 5, 40, 60),
            "weekend_boost": 5,
        },
        "Shopping": {
            "morning": (5, 9, 30, 50),
            "workday": (9, 17, 50, 70),
            "evening": (17, 23, 60, 80),
            "night": (23, 5, 20, 40),
            "weekend_boost": 10,
        },
        "Software": {
            "morning": (5, 9, 60, 80),
            "workday": (9, 17, 80, 95),
            "evening": (17, 23, 40, 60),
            "night": (23, 5, 20, 40),
            "weekend_boost": -20,
        },
    }

    # (low, high) sampling ranges in the order:
    # signal_strength, packet_loss, latency, jitter.
    _NETWORK_PROFILES = {
        "Streaming": ((-70, -30), (0, 0.2), (10, 50), (1, 5)),
        "Gaming": ((-65, -25), (0, 0.1), (5, 30), (0.5, 3)),
        "Social Media": ((-75, -40), (0, 0.3), (20, 80), (2, 8)),
        "Shopping": ((-80, -45), (0, 0.4), (30, 100), (3, 10)),
    }
    # Used for "Software" and for any service group not listed above.
    _SOFTWARE_NETWORK_PROFILE = ((-70, -35), (0, 0.2), (15, 70), (1, 7))

    # Relative likelihood of each service group for the time-series generator.
    # Key order is significant: it is the order handed to random.choices.
    _WORK_HOURS_PROBS = {
        "Software": 0.5,
        "Streaming": 0.1,
        "Social Media": 0.2,
        "Shopping": 0.1,
        "Gaming": 0.1,
    }
    _EVENING_PROBS = {
        "Software": 0.1,
        "Streaming": 0.4,
        "Social Media": 0.2,
        "Shopping": 0.1,
        "Gaming": 0.2,
    }
    _WEEKEND_PROBS = {
        "Software": 0.1,
        "Streaming": 0.3,
        "Social Media": 0.2,
        "Shopping": 0.2,
        "Gaming": 0.2,
    }
    _OFF_PEAK_PROBS = {
        "Software": 0.2,
        "Streaming": 0.2,
        "Social Media": 0.4,
        "Shopping": 0.1,
        "Gaming": 0.1,
    }

    # Allocation percentage used when no usable "max" statistic exists, and
    # the bounds every reported allocation is clamped to.
    _DEFAULT_ALLOCATION_PCT = 75.0
    _MIN_ALLOCATION_PCT = 10
    _MAX_ALLOCATION_PCT = 95

    # Identifiers reported when the service mappings have no entry.
    _DEFAULT_GROUP_ID = 1000
    _DEFAULT_SERVICE_ID = 100

    def __init__(self, models_dir="tflite_models", metadata_dir="metadata"):
        """
        Initialize the bandwidth predictor with models and metadata for edge deployment

        Parameters:
        -----------
        models_dir : str
            Directory containing TFLite models
        metadata_dir : str
            Directory containing metadata files

        Raises:
        -------
        RuntimeError
            If the metadata or the models cannot be loaded
        """
        self.models_dir = models_dir
        self.metadata_dir = metadata_dir

        # Metadata first, then the TFLite models.
        self._load_metadata()
        self._load_models()

        print(f"Loaded {len(self.models)} service group models")
        print(f"Available service groups: {list(self.models.keys())}")

    def _read_metadata_json(self, filename):
        """Parse and return one JSON file from the metadata directory."""
        path = os.path.join(self.metadata_dir, filename)
        # JSON is read as UTF-8 explicitly rather than with the platform default.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _build_scaler(params):
        """Create a MinMaxScaler whose ``min_``/``scale_`` come from stored parameters."""
        scaler = MinMaxScaler()
        scaler.min_ = np.array(params["min"])
        scaler.scale_ = np.array(params["scale"])
        return scaler

    def _load_metadata(self):
        """
        Load metadata files for prediction

        Populates ``service_mappings``, ``service_group_stats``,
        ``feature_info`` and ``scalers`` (one feature scaler and one target
        scaler per service group).

        Raises:
        -------
        RuntimeError
            If any file is missing or malformed; the original error is chained
        """
        try:
            self.service_mappings = self._read_metadata_json("service_mappings.json")
            self.service_group_stats = self._read_metadata_json("service_group_stats.json")
            self.feature_info = self._read_metadata_json("feature_list.json")
            scaler_params = self._read_metadata_json("scalers.json")

            self.scalers = {
                service_group: {
                    "features": self._build_scaler(params["features"]),
                    "target": self._build_scaler(params["target"]),
                }
                for service_group, params in scaler_params.items()
            }

            print("Successfully loaded metadata")

        except Exception as e:
            raise RuntimeError(f"Error loading metadata: {str(e)}") from e

    def _load_models(self):
        """
        Load TFLite models

        Every ``*.tflite`` file in ``models_dir`` is loaded into its own
        interpreter. The service group name is derived from the file name
        (text before ``_model`` / ``_fallback_model``, underscores replaced by
        spaces, title-cased). A fallback model is used only when no regular
        model exists for the same service group.

        Raises:
        -------
        RuntimeError
            If the directory is missing, holds no models, or a model fails to load
        """
        self.models = {}
        self.interpreters = {}
        self.input_details = {}
        self.output_details = {}

        try:
            if not os.path.exists(self.models_dir):
                raise RuntimeError(f"Models directory {self.models_dir} not found")

            model_files = [
                name for name in os.listdir(self.models_dir) if name.endswith(".tflite")
            ]
            # os.listdir order is arbitrary: process regular models first so a
            # fallback is skipped (never loaded and then overwritten) whenever
            # a regular model exists for the same group.
            model_files.sort(key=lambda name: ("fallback" in name, name))

            for model_file in model_files:
                if "fallback" in model_file:
                    stem = model_file.split("_fallback_model")[0]
                else:
                    stem = model_file.split("_model")[0]
                service_group = stem.replace("_", " ").title()

                if "fallback" in model_file and service_group in self.models:
                    continue

                model_path = os.path.join(self.models_dir, model_file)
                with open(model_path, "rb") as f:
                    model_content = f.read()

                interpreter = tf.lite.Interpreter(model_content=model_content)
                interpreter.allocate_tensors()

                self.models[service_group] = model_content
                self.interpreters[service_group] = interpreter
                self.input_details[service_group] = interpreter.get_input_details()
                self.output_details[service_group] = interpreter.get_output_details()

                print(f"Loaded model for {service_group}")

            if not self.models:
                raise RuntimeError("No models found in the models directory")

        except Exception as e:
            raise RuntimeError(f"Error loading models: {str(e)}") from e

    def generate_time_features(self, timestamp):
        """
        Generate time-based features for a timestamp

        Parameters:
        -----------
        timestamp : datetime
            Timestamp to generate features for

        Returns:
        --------
        tuple
            Time features (hour, day_of_week, is_weekend, hour_features, day_features)
            where day_of_week is 0 for Monday, is_weekend is 0/1, and the last
            two items are one-hot lists of length 24 and 7
        """
        hour = timestamp.hour
        day_of_week = timestamp.weekday()
        is_weekend = 1 if day_of_week >= 5 else 0

        hour_features = [1 if h == hour else 0 for h in range(24)]
        day_features = [1 if d == day_of_week else 0 for d in range(7)]

        return hour, day_of_week, is_weekend, hour_features, day_features

    def _estimate_usage_percentage(self, service_group, hour, is_weekend):
        """Draw a random usage percentage from the group's range for this hour."""
        pattern = self._USAGE_PATTERNS.get(service_group, self._DEFAULT_USAGE_PATTERN)

        # Start from the "night" range and override it if a daytime period matches.
        min_usage, max_usage = pattern["night"][2], pattern["night"][3]
        for period in self._DAY_PERIODS:
            start_hour, end_hour, period_min, period_max = pattern[period]
            if start_hour <= hour < end_hour:
                min_usage, max_usage = period_min, period_max
                break

        if is_weekend:
            min_usage += pattern["weekend_boost"]
            max_usage += pattern["weekend_boost"]

        # Keep both bounds inside 0-100 after the weekend adjustment.
        min_usage = max(0, min(100, min_usage))
        max_usage = max(0, min(100, max_usage))

        return random.uniform(min_usage, max_usage)

    @staticmethod
    def _encode_device_group(device_group, service_group, hour, day_of_week):
        """Return 1 for a work device and 0 for a personal device."""
        if device_group is not None:
            return 1 if device_group == "work_device" else 0
        # No hint from the caller: assume a work device only for Software
        # traffic between 08:00 and 18:59 on weekdays.
        if service_group == "Software" and 8 <= hour <= 18 and day_of_week < 5:
            return 1
        return 0

    def _sample_network_metrics(self, service_group):
        """Randomly sample (signal_strength, packet_loss, latency, jitter)."""
        profile = self._NETWORK_PROFILES.get(service_group, self._SOFTWARE_NETWORK_PROFILE)
        # Draw in the fixed order signal, loss, latency, jitter so a seeded
        # ``random`` produces the same sequence on every run.
        return tuple(random.uniform(low, high) for low, high in profile)

    def prepare_features(self, timestamp, service_group, usage_percentage=None, device_group=None):
        """
        Prepare input features for prediction

        Parameters:
        -----------
        timestamp : datetime
            Timestamp for prediction
        service_group : str
            Service group to predict for
        usage_percentage : float, optional
            Usage percentage (if known); otherwise drawn at random from a
            per-group, per-time-of-day range
        device_group : str, optional
            Device group ('personal_device' or 'work_device')

        Returns:
        --------
        numpy.ndarray
            Scaled features for model input, shape (1, n_features), float32

        Raises:
        -------
        KeyError
            If no scaler is registered for ``service_group``
        """
        hour, day_of_week, is_weekend, hour_features, day_features = (
            self.generate_time_features(timestamp)
        )

        if usage_percentage is None:
            usage_percentage = self._estimate_usage_percentage(service_group, hour, is_weekend)

        device_group_encoded = self._encode_device_group(
            device_group, service_group, hour, day_of_week
        )

        signal_strength, packet_loss, latency, jitter = self._sample_network_metrics(
            service_group
        )

        # Usage minutes are derived from the usage percentage.
        usage_minutes = usage_percentage * 0.3

        basic_features = self.feature_info["basic_features"]
        network_features = self.feature_info["network_features"]

        # Usage percentage is always the first feature; the device flag is
        # included only when the metadata lists it.
        feature_values = [usage_percentage]
        if "device_group_encoded" in basic_features:
            feature_values.append(device_group_encoded)

        # Network features keep this fixed order and are included only when
        # the metadata lists them.
        network_candidates = (
            ("signal_strength", signal_strength),
            ("packet_loss", packet_loss),
            ("latency", latency),
            ("jitter", jitter),
            ("usage_minutes", usage_minutes),
        )
        network_values = [
            value for name, value in network_candidates if name in network_features
        ]

        time_values = [hour, day_of_week, is_weekend]

        all_features = feature_values + network_values + time_values + hour_features + day_features

        feature_array = np.array(all_features).reshape(1, -1)
        scaled_features = self.scalers[service_group]["features"].transform(feature_array)

        return scaled_features.astype(np.float32)

    def predict(self, timestamp, service_group, usage_percentage=None, device_group=None):
        """
        Make prediction for a specific service group

        Parameters:
        -----------
        timestamp : datetime
            Timestamp for prediction
        service_group : str
            Service group to predict for
        usage_percentage : float, optional
            Usage percentage (if known)
        device_group : str, optional
            Device group ('personal_device' or 'work_device')

        Returns:
        --------
        dict
            Prediction results with keys ``service_group``, ``group_id``,
            ``service_id``, ``bandwidth_allocation`` (string such as
            ``"50.00%"``) and ``predicted_bandwidth_mbps`` (string with two
            decimals)

        Raises:
        -------
        ValueError
            If no model is loaded for ``service_group``
        """
        if service_group not in self.models:
            raise ValueError(f"No model available for service group: {service_group}")

        scaled_features = self.prepare_features(timestamp, service_group, usage_percentage, device_group)

        interpreter = self.interpreters[service_group]
        input_details = self.input_details[service_group]
        output_details = self.output_details[service_group]

        interpreter.set_tensor(input_details[0]['index'], scaled_features)
        interpreter.invoke()
        scaled_prediction = interpreter.get_tensor(output_details[0]['index'])

        # Map the model output back to the original bandwidth scale; convert
        # to a plain float so later arithmetic does not mix NumPy scalars.
        bandwidth_prediction = float(
            self.scalers[service_group]['target'].inverse_transform(
                scaled_prediction.reshape(-1, 1)
            )[0][0]
        )

        stats = self.service_group_stats.get(service_group)
        if stats is not None:
            # Allow 10% head-room above the recorded max and below the recorded min.
            max_bandwidth = stats['max']
            constrained_bandwidth = max(
                min(bandwidth_prediction, max_bandwidth * 1.1), stats['min'] * 0.9
            )
        else:
            max_bandwidth = None
            constrained_bandwidth = bandwidth_prediction

        # A missing or non-positive maximum cannot serve as a denominator
        # (it would raise ZeroDivisionError or yield nan), so use the
        # conservative default instead.
        if max_bandwidth is not None and max_bandwidth > 0:
            allocation_percentage = (constrained_bandwidth / max_bandwidth) * 100
        else:
            allocation_percentage = self._DEFAULT_ALLOCATION_PCT

        allocation_percentage = max(
            min(allocation_percentage, self._MAX_ALLOCATION_PCT), self._MIN_ALLOCATION_PCT
        )

        mapping = self.service_mappings.get(service_group, {})
        group_id = mapping.get("group_id", self._DEFAULT_GROUP_ID)
        # Any valid service of the group will do: take the first one listed.
        service_id = next(iter(mapping.get("services", {}).values()), self._DEFAULT_SERVICE_ID)

        return {
            "service_group": service_group,
            "group_id": group_id,
            "service_id": service_id,
            "bandwidth_allocation": f"{allocation_percentage:.2f}%",
            "predicted_bandwidth_mbps": f"{constrained_bandwidth:.2f}"
        }

    def predict_all_groups(self, timestamp, usage_percentages=None, device_groups=None):
        """
        Predict bandwidth for all service groups at a given time

        Parameters:
        -----------
        timestamp : datetime
            Timestamp for prediction
        usage_percentages : dict, optional
            Dictionary mapping service groups to usage percentages
        device_groups : dict, optional
            Dictionary mapping service groups to device groups

        Returns:
        --------
        dict
            Predictions for all service groups; a group whose prediction fails
            is reported on stdout and left out of the result
        """
        results = {}

        for service_group in self.models:
            usage = usage_percentages.get(service_group) if usage_percentages else None
            device = device_groups.get(service_group) if device_groups else None

            # Best effort: one failing group must not abort the others.
            try:
                results[service_group] = self.predict(timestamp, service_group, usage, device)
            except Exception as e:
                print(f"Error predicting for {service_group}: {str(e)}")

        return results

    def _service_probabilities(self, hour, is_weekend):
        """Return the service-group weights that apply to this hour and day type."""
        if 9 <= hour <= 17 and not is_weekend:
            return self._WORK_HOURS_PROBS   # weekday work hours: Software dominates
        if 18 <= hour <= 23:
            return self._EVENING_PROBS      # evenings: Streaming and Gaming dominate
        if is_weekend:
            return self._WEEKEND_PROBS      # rest of the weekend: more entertainment
        return self._OFF_PEAK_PROBS         # weekday nights/early mornings: more Social Media

    def generate_time_series_predictions(self, start_time, end_time, interval_hours=1):
        """
        Generate predictions for a time range

        For every step one service group is picked at random, weighted by the
        time of day and restricted to groups that have a loaded model, and a
        prediction is made for it.

        Parameters:
        -----------
        start_time : datetime
            Start timestamp for predictions
        end_time : datetime
            End timestamp for predictions (inclusive)
        interval_hours : int or float, optional
            Hours between prediction points; must be greater than zero

        Returns:
        --------
        dict
            Time series predictions in the format required for bandwidth allocation:
            ``{"responseData": {"YYYY-MM-DD HH:MM:SS": prediction, ...}}``

        Raises:
        -------
        ValueError
            If ``interval_hours`` is not greater than zero
        """
        # A zero or negative step would never reach end_time and loop forever.
        if interval_hours <= 0:
            raise ValueError(f"interval_hours must be greater than 0, got {interval_hours}")

        step = timedelta(hours=interval_hours)
        available_groups = set(self.models.keys())
        predictions = {}

        current = start_time
        while current <= end_time:
            formatted_timestamp = current.strftime('%Y-%m-%d %H:%M:%S')

            candidates = self._service_probabilities(current.hour, current.weekday() >= 5)
            service_probs = {k: v for k, v in candidates.items() if k in available_groups}

            if not service_probs:
                # None of the known groups has a model: use the first loaded one.
                service_group = list(self.models.keys())[0]
            else:
                # Re-normalise after filtering, then make a weighted random choice.
                total = sum(service_probs.values())
                groups = list(service_probs.keys())
                weights = [v / total for v in service_probs.values()]
                service_group = random.choices(groups, weights=weights, k=1)[0]

            predictions[formatted_timestamp] = self.predict(current, service_group)

            current += step

        return {"responseData": predictions}

    def save_predictions_to_json(self, predictions, output_file='edge_predictions.json'):
        """
        Save predictions to a JSON file

        Parameters:
        -----------
        predictions : dict
            Prediction results
        output_file : str, optional
            Path to save the JSON file

        Returns:
        --------
        None
        """
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(predictions, f, indent=2)

        print(f"Predictions saved to {output_file}")


# Example usage
if __name__ == "__main__":
    # Build the predictor from the default model and metadata directories.
    predictor = BandwidthPredictor()

    # Example 1: one prediction for the Streaming group at the current time.
    now = datetime.now()
    single_prediction = predictor.predict(now, "Streaming")
    print("Prediction for current time:")
    print(json.dumps(single_prediction, indent=2))

    # Example 2: a 24-hour forecast starting now, one point every 2 hours.
    window_start = datetime.now()
    window_end = window_start + timedelta(hours=24)
    print(f"\nGenerating predictions from {window_start} to {window_end}")
    next_day_series = predictor.generate_time_series_predictions(
        window_start, window_end, interval_hours=2
    )
    predictor.save_predictions_to_json(next_day_series, "next_24_hours_predictions.json")

    # Example 3: a fixed range, 2025-03-27 08:00 through 2025-03-29 23:00,
    # one point every 4 hours.
    range_start = datetime(2025, 3, 27, 8, 0, 0)
    range_end = datetime(2025, 3, 29, 23, 0, 0)
    print(f"\nGenerating predictions for custom date range: {range_start} to {range_end}")
    range_series = predictor.generate_time_series_predictions(
        range_start, range_end, interval_hours=4
    )
    predictor.save_predictions_to_json(range_series, "custom_date_range_predictions.json")

    print("\nAll predictions completed successfully!")