In [None]:
import numpy as np
import pandas as pd
import torch
import dgl
import random
import copy
import networkx as nx
import matplotlib.pyplot as plt
import sumolib
import traci
import os
import subprocess
from typing import List, Dict
import geopandas as gpd
import osmnx as ox
from shapely.geometry import Point
import requests
from dgl.data.utils import save_graphs

class MarkovServiceDemand:
    """Per-vehicle service-demand model: Poisson start-up, then a Markov chain.

    Every vehicle is in exactly one state of ``services + ['none']``. The first
    time a vehicle is seen, a Poisson(0.5) draw decides whether it starts with a
    service (sampled from a vehicle-type-dependent distribution) or with
    ``'none'``. Afterwards ``get_next_service`` moves it along a fixed
    transition matrix whose row is re-weighted by vehicle type and speed.

    NOTE(review): the transition matrix, the initial distributions and the
    duration rates are hard-coded for the five service names used below; a
    ``services`` list with other names will raise ``KeyError`` in
    ``get_next_service``.
    """

    # Initial-service distribution per vehicle type, used when the Poisson
    # draw fires. Unknown vehicle types fall back to the 'car' row.
    _INITIAL_SERVICE_PROBS = {
        'car': {
            'cooperative_perception': 0.3,
            'platooning_control': 0.4,
            'edge_object_recognition': 0.15,
            'predictive_collision_avoidance': 0.1,
            'infrastructure_vision': 0.05
        },
        'truck': {
            'cooperative_perception': 0.1,
            'platooning_control': 0.3,
            'edge_object_recognition': 0.2,
            'predictive_collision_avoidance': 0.35,
            'infrastructure_vision': 0.05
        },
        'bus': {
            'cooperative_perception': 0.15,
            'platooning_control': 0.35,
            'edge_object_recognition': 0.15,
            'predictive_collision_avoidance': 0.3,
            'infrastructure_vision': 0.05
        },
        'emergency': {
            'cooperative_perception': 0.05,
            'platooning_control': 0.2,
            'edge_object_recognition': 0.1,
            'predictive_collision_avoidance': 0.15,
            'infrastructure_vision': 0.5
        },
        'motorcycle': {
            'cooperative_perception': 0.25,
            'platooning_control': 0.45,
            'edge_object_recognition': 0.15,
            'predictive_collision_avoidance': 0.1,
            'infrastructure_vision': 0.05
        }
    }

    # Rate of the exponential lifetime of each state (mean lifetime = 1 / rate),
    # truncated to an integer number of steps by get_service_duration.
    _DURATION_RATES = {
        'cooperative_perception': 0.08,
        'platooning_control': 0.12,
        'edge_object_recognition': 0.05,
        'predictive_collision_avoidance': 0.1,
        'infrastructure_vision': 0.15,
        'none': 0.05,
    }

    # Lifetime returned for a state that has no entry in _DURATION_RATES.
    _DEFAULT_DURATION = 10

    def __init__(self, services: List[str]):
        """Create the model.

        Args:
            services: service names; ``'none'`` is appended as an extra state.
        """
        self.services = services + ['none']  # Include 'none' as a state
        self.transition_matrix = self._create_transition_matrix()
        self.current_services = {}  # Current state per vehicle ID
        self.poisson_triggered = set()  # Vehicle IDs that already had their Poisson draw

    def _create_transition_matrix(self) -> Dict[str, Dict[str, float]]:
        """Return the row-stochastic transition matrix between states (including 'none').

        ``matrix[current][next]`` is the base probability of moving from
        ``current`` to ``next``; each row sums to 1.
        """
        return {
            'none': {
                'cooperative_perception': 0.15,
                'platooning_control': 0.25,
                'edge_object_recognition': 0.08,
                'predictive_collision_avoidance': 0.07,
                'infrastructure_vision': 0.05,
                'none': 0.4
            },
            'cooperative_perception': {
                'cooperative_perception': 0.6,
                'platooning_control': 0.08,
                'edge_object_recognition': 0.04,
                'predictive_collision_avoidance': 0.03,
                'infrastructure_vision': 0.02,
                'none': 0.23
            },
            'platooning_control': {
                'cooperative_perception': 0.05,
                'platooning_control': 0.7,
                'edge_object_recognition': 0.04,
                'predictive_collision_avoidance': 0.08,
                'infrastructure_vision': 0.03,
                'none': 0.1
            },
            'edge_object_recognition': {
                'cooperative_perception': 0.08,
                'platooning_control': 0.08,
                'edge_object_recognition': 0.55,
                'predictive_collision_avoidance': 0.12,
                'infrastructure_vision': 0.07,
                'none': 0.1
            },
            'predictive_collision_avoidance': {
                'cooperative_perception': 0.04,
                'platooning_control': 0.08,
                'edge_object_recognition': 0.08,
                'predictive_collision_avoidance': 0.65,
                'infrastructure_vision': 0.05,
                'none': 0.1
            },
            'infrastructure_vision': {
                'cooperative_perception': 0.01,
                'platooning_control': 0.02,
                'edge_object_recognition': 0.02,
                'predictive_collision_avoidance': 0.05,
                'infrastructure_vision': 0.85,
                'none': 0.05
            }
        }

    def initialize_vehicle(self, vid: str, vehicle_data: Dict) -> str:
        """Assign the initial state of ``vid`` (only on the first call per vehicle).

        The first call draws k ~ Poisson(0.5). If k >= 1 the vehicle starts with a
        service sampled by ``_get_initial_service``, otherwise with ``'none'``.
        Later calls do not redraw; they return the vehicle's current state.

        Args:
            vid: vehicle identifier.
            vehicle_data: context dict; only ``'vehicle_type'`` is read here.

        Returns:
            The vehicle's state name.
        """
        if vid in self.poisson_triggered:
            # Already initialized: report the current state.
            return self.current_services.get(vid, 'none')

        self.poisson_triggered.add(vid)
        # lambda = 0.5 keeps start-up demand moderate (P(k >= 1) ~= 0.39).
        k = np.random.poisson(0.5)
        initial_service = self._get_initial_service(vehicle_data) if k >= 1 else 'none'
        self.current_services[vid] = initial_service
        return initial_service

    def _get_initial_service(self, vehicle_data: Dict) -> str:
        """Sample an initial service from the distribution of the vehicle's type.

        Unknown or missing ``'vehicle_type'`` values use the 'car' distribution.
        """
        vehicle_type = vehicle_data.get('vehicle_type', 'car')
        probs = self._INITIAL_SERVICE_PROBS.get(vehicle_type, self._INITIAL_SERVICE_PROBS['car'])
        return random.choices(list(probs.keys()), weights=list(probs.values()))[0]

    def get_next_service(self, vid: str, vehicle_data: Dict) -> str:
        """Perform one Markov transition for ``vid`` and store the new state.

        A vehicle with no recorded state is treated as being in ``'none'``.

        Args:
            vid: vehicle identifier.
            vehicle_data: context dict (``'vehicle_type'``, ``'speed'``) used to
                re-weight the transition row.

        Returns:
            The new state name.
        """
        current = self.current_services.get(vid, 'none')

        next_probs = self.transition_matrix[current].copy()
        next_probs = self._adjust_for_context(next_probs, vehicle_data)

        next_service = random.choices(
            list(next_probs.keys()),
            weights=list(next_probs.values())
        )[0]

        self.current_services[vid] = next_service
        return next_service

    def _adjust_for_context(self, probabilities: Dict[str, float], vehicle_data: Dict) -> Dict[str, float]:
        """Return a re-weighted, renormalized copy of a transition row.

        Emergency vehicles boost infrastructure_vision (x2.5) and
        predictive_collision_avoidance (x1.5). Trucks/buses boost
        predictive_collision_avoidance (x2.0) and platooning_control (x1.3).
        ``speed < 10`` boosts cooperative_perception (x1.8) and damps
        platooning_control (x0.7). ``speed > 60`` boosts platooning_control (x1.5)
        and predictive_collision_avoidance (x1.3).

        NOTE(review): these speed thresholds read like km/h, but the vehicle
        manager in this file passes ``traci.vehicle.getSpeed`` (m/s), so the
        ``> 60`` branch is effectively unreachable — confirm the intended unit.
        """
        adjusted = probabilities.copy()
        vehicle_type = vehicle_data.get('vehicle_type', 'car')
        speed = vehicle_data.get('speed', 0)

        if vehicle_type == 'emergency':
            adjusted['infrastructure_vision'] *= 2.5
            adjusted['predictive_collision_avoidance'] *= 1.5
        elif vehicle_type in ['truck', 'bus']:
            adjusted['predictive_collision_avoidance'] *= 2.0
            adjusted['platooning_control'] *= 1.3

        if speed < 10:  # Traffic jam
            adjusted['cooperative_perception'] *= 1.8
            adjusted['platooning_control'] *= 0.7
        elif speed > 60:  # Highway
            adjusted['platooning_control'] *= 1.5
            adjusted['predictive_collision_avoidance'] *= 1.3

        # Renormalize so the row is a probability distribution again.
        total = sum(adjusted.values())
        for service in adjusted:
            adjusted[service] /= total

        return adjusted

    def get_service_duration(self, service: str) -> int:
        """Sample how many steps ``service`` stays active (at least 1).

        Draws a single exponential sample with the state's rate from
        ``_DURATION_RATES``. States without a rate get ``_DEFAULT_DURATION``
        without consuming randomness.
        """
        rate = self._DURATION_RATES.get(service)
        if rate is None:
            return self._DEFAULT_DURATION
        return max(1, int(np.random.exponential(1 / rate)))

class CombinedVehicleManager:
    """Drive a SUMO simulation through TraCI and attach service demand to vehicles.

    Construction loads ``network_file`` with sumolib, generates ``route_file``
    when it does not exist, writes ``manhattan_config.sumocfg`` and starts SUMO
    (``sumo-gui`` or ``sumo``). Each :meth:`update` call advances SUMO by one
    step, adds/removes vehicles to stay at ``target_vehicles`` and returns the
    per-vehicle features as a dict of tensors.

    If the network cannot be loaded or SUMO fails to start, ``self.sumo`` is
    ``None`` and :meth:`update` only returns empty tensors.
    """

    def __init__(self, services: List[str], service_specs: Dict[str, Dict],
                 target_vehicles: int, demand_model: "MarkovServiceDemand",
                 network_file: str = "manhattan.net.xml",
                 route_file: str = "manhattan.rou.xml",
                 use_gui: bool = True):
        """Load the network, prepare routes/config and start SUMO.

        Args:
            services: service names without ``'none'``; a service's index in
                ``services + ['none']`` is its ``type_service`` code.
            service_specs: per-service dict with ``'cpu'``, ``'ram'`` and
                ``'data_size'`` entries.
            target_vehicles: number of vehicles :meth:`update` tries to keep
                in the network.
            demand_model: Markov model deciding which service each vehicle requests.
            network_file: SUMO ``.net.xml`` to load.
            route_file: SUMO ``.rou.xml``; generated when missing.
            use_gui: start ``sumo-gui`` instead of console ``sumo``.
        """
        self.services = services
        self.service_specs = service_specs
        self.target_vehicles = target_vehicles
        self.demand_model = demand_model
        self.current_vehicles = {}
        self.vehicle_counter = 0
        # Serial for dynamically added route IDs (see _add_vehicles_safe).
        self._dynamic_route_serial = 0
        self.first_appearance_timestamps = {}
        self.service_durations = {}

        # NOTE(review): max_speed values here look like km/h, whereas the
        # generated vTypes use m/s; no method of this class reads this table.
        self.vehicle_specs = {
            "car": {"length": 4.3, "max_speed": 50, "accel": 2.6, "decel": 4.5},
            "truck": {"length": 12.0, "max_speed": 40, "accel": 1.3, "decel": 3.5},
            "bus": {"length": 14.0, "max_speed": 36, "accel": 1.2, "decel": 3.0},
            "motorcycle": {"length": 2.5, "max_speed": 60, "accel": 3.0, "decel": 6.0},
            "emergency": {"length": 6.0, "max_speed": 70, "accel": 3.5, "decel": 6.5}
        }

        # Close a TraCI connection left over from a previous run (e.g. a
        # re-executed notebook cell). Failing here just means nothing was open;
        # unlike a bare except this does not swallow KeyboardInterrupt.
        try:
            traci.close()
        except Exception:
            pass

        # Use the existing network file; without it there is nothing to simulate.
        if not os.path.exists(network_file):
            print(f"‚ùå Network file {network_file} not found")
            self.sumo = None
            self.net = None
            return

        try:
            self.net = sumolib.net.readNet(network_file)
            print(f"‚úÖ YOUR Manhattan network loaded: {len(self.net.getNodes())} nodes, {len(self.net.getEdges())} edges")
        except Exception as e:
            print(f"‚ùå Error loading network: {e}")
            self.sumo = None
            self.net = None
            return

        # Reuse an existing route file; otherwise generate one for this network.
        if not os.path.exists(route_file):
            print(f"üîÑ Creating routes for your Manhattan network...")
            self._generate_routes_for_existing_network(target_vehicles, route_file)
        else:
            print(f"‚úÖ Existing route file used: {route_file}")

        self.sumo_config = self._create_sumo_config(network_file, route_file)

        try:
            if use_gui:
                print("üöó Starting SUMO-GUI with YOUR network...")
                traci.start(['sumo-gui', '-c', self.sumo_config, '--start', '--delay', '100'])
            else:
                print("üöó Starting SUMO (console mode) with YOUR network...")
                traci.start(['sumo', '-c', self.sumo_config])

            self.sumo = traci
            print("‚úÖ SUMO started with YOUR Manhattan network")

        except Exception as e:
            print(f"‚ùå SUMO error: {e}")
            self.sumo = None

    def _generate_routes_for_existing_network(self, num_vehicles: int, route_file: str):
        """Write a SUMO route file with ``num_vehicles`` randomly typed vehicles.

        Declares the five vTypes, then one ``route_{i}`` / ``veh{i}`` pair per
        vehicle. Departures are uniform in [0, 500) s, sorted ascending. Types
        are drawn with weights 0.65/0.15/0.08/0.10/0.02. Routes come from
        ``_create_manhattan_style_route``; vehicles whose route is empty
        (network without usable edges) are skipped.
        """
        routes_content = """<?xml version="1.0" encoding="UTF-8"?>
<routes>
    <vType id="car" length="4.3" maxSpeed="13.89" accel="2.6" decel="4.5" sigma="0.5" color="yellow"/>
    <vType id="truck" length="12.0" maxSpeed="11.11" accel="1.3" decel="3.5" sigma="0.7" color="red"/>
    <vType id="bus" length="14.0" maxSpeed="10.0" accel="1.2" decel="3.0" sigma="0.8" color="blue"/>
    <vType id="motorcycle" length="2.5" maxSpeed="16.67" accel="3.0" decel="6.0" sigma="0.3" color="green"/>
    <vType id="emergency" length="6.0" maxSpeed="19.44" accel="3.5" decel="6.5" sigma="0.2" color="white"/>
"""

        # Non-internal edges only (internal junction edges start with ':').
        edges = [edge.getID() for edge in self.net.getEdges()
                 if not edge.getID().startswith(':')]

        print(f"üîç YOUR network has {len(edges)} available edges")

        vehicle_types = ["car", "truck", "bus", "motorcycle", "emergency"]
        type_probabilities = [0.65, 0.15, 0.08, 0.10, 0.02]

        # Classify edges by orientation once instead of once per vehicle. On
        # failure, each route call retries and reports via its own fallback.
        try:
            orientation = self._split_edges_by_orientation(edges)
        except Exception:
            orientation = None

        routes = [self._create_manhattan_style_route(edges, orientation)
                  for _ in range(num_vehicles)]

        # Staggered departures
        departure_times = sorted([random.uniform(0, 500) for _ in range(num_vehicles)])

        for i, (depart, route_edges) in enumerate(zip(departure_times, routes)):
            if not route_edges:
                # SUMO rejects a route with an empty edge list.
                continue
            vtype = random.choices(vehicle_types, weights=type_probabilities, k=1)[0]

            routes_content += f"""
    <route id="route_{i}" edges="{' '.join(route_edges)}"/>
    <vehicle id="veh{i}" type="{vtype}" depart="{depart:.1f}" route="route_{i}"/>"""

        routes_content += "\n</routes>"

        with open(route_file, 'w', encoding='utf-8') as f:
            f.write(routes_content)
        print(f"‚úÖ Routes generated for YOUR Manhattan: {route_file}")

    def _split_edges_by_orientation(self, edges: List[str]):
        """Split edge IDs into ``(horizontal, vertical)`` lists.

        An edge is horizontal when the |dx| between its from/to node coordinates
        exceeds |dy|; all other edges (including ties) count as vertical.
        """
        horizontal_edges = []
        vertical_edges = []
        for edge_id in edges:
            edge_obj = self.net.getEdge(edge_id)
            from_xy = edge_obj.getFromNode().getCoord()
            to_xy = edge_obj.getToNode().getCoord()
            if abs(from_xy[0] - to_xy[0]) > abs(from_xy[1] - to_xy[1]):
                horizontal_edges.append(edge_id)
            else:
                vertical_edges.append(edge_id)
        return horizontal_edges, vertical_edges

    def _shortest_edge_path(self, from_id: str, to_id: str):
        """Return edge IDs of the sumolib shortest path ``from_id -> to_id``, or None if unreachable."""
        path, _cost = self.net.getShortestPath(self.net.getEdge(from_id), self.net.getEdge(to_id))
        if not path:
            return None
        return [edge.getID() for edge in path]

    @staticmethod
    def _single_edge_route(edges: List[str]) -> List[str]:
        """One random edge as a (always connected) route; ``[]`` when there are no edges."""
        return [random.choice(edges)] if edges else []

    def _create_manhattan_style_route(self, edges: List[str], orientation=None) -> List[str]:
        """Build a connected grid-like route: horizontal -> vertical -> horizontal.

        Three waypoint edges are drawn (random horizontal, vertical, horizontal).
        Consecutive waypoints are joined by shortest paths so that SUMO receives
        a connected route; the original implementation emitted the three
        unconnected edges directly. If a leg is unreachable, the route is
        truncated to the connected prefix. A single random edge is used when the
        network lacks one of the orientations or an error occurs.

        Args:
            edges: candidate edge IDs.
            orientation: optional precomputed ``(horizontal, vertical)`` split of
                ``edges``; computed here when ``None``.
        """
        try:
            if orientation is None:
                orientation = self._split_edges_by_orientation(edges)
            horizontal_edges, vertical_edges = orientation

            if not (horizontal_edges and vertical_edges):
                return self._single_edge_route(edges)

            first = random.choice(horizontal_edges)
            turn = random.choice(vertical_edges)
            last = random.choice(horizontal_edges)

            leg_one = self._shortest_edge_path(first, turn)
            if leg_one is None:
                return [first]
            leg_two = self._shortest_edge_path(turn, last)
            if leg_two is None:
                return leg_one
            # leg_two starts with `turn`, which already ends leg_one.
            return leg_one + leg_two[1:]

        except Exception as e:
            print(f"‚ö†Ô∏è Manhattan route creation error: {e}")
            return self._single_edge_route(edges)

    def _create_sumo_config(self, network_file: str, route_file: str) -> str:
        """Write ``manhattan_config.sumocfg`` in the working directory and return its name.

        The config runs from t=0 to t=10000 with 1 s steps and
        ``ignore-route-errors`` enabled.
        """
        config_content = f"""<?xml version="1.0" encoding="UTF-8"?>
<configuration>
    <input>
        <net-file value="{network_file}"/>
        <route-files value="{route_file}"/>
    </input>
    <time>
        <begin value="0"/>
        <end value="10000"/>
    </time>
    <processing>
        <step-length value="1.0"/>
        <no-step-log value="true"/>
        <ignore-route-errors value="true"/>
        <no-internal-links value="false"/>
    </processing>
    <report>
        <no-warnings value="false"/>
        <verbose value="true"/>
    </report>
</configuration>"""

        config_file = "manhattan_config.sumocfg"
        with open(config_file, 'w', encoding='utf-8') as f:
            f.write(config_content)

        return config_file

    def _get_markov_service(self, vid: str, vehicle_data: Dict) -> str:
        """Ask the demand model for ``vid``'s service and reset its remaining lifetime.

        Unknown vehicles go through the model's one-time Poisson initialization;
        known ones make a Markov transition. ``'none'`` gets lifetime 0, so a
        new transition is attempted on the next update.
        """
        if vid not in self.demand_model.current_services:
            service = self.demand_model.initialize_vehicle(vid, vehicle_data)
        else:
            service = self.demand_model.get_next_service(vid, vehicle_data)

        if service != 'none':
            self.service_durations[vid] = self.demand_model.get_service_duration(service)
        else:
            self.service_durations[vid] = 0

        return service

    def _forget_departed(self, active: Dict) -> None:
        """Drop manager-level bookkeeping of vehicles not present in ``active``.

        Both dicts are only read for vehicles that were tracked in the previous
        update, so pruning the rest changes no result; it just stops them
        growing forever.
        """
        for store in (self.first_appearance_timestamps, self.service_durations):
            for vid in [v for v in store if v not in active]:
                del store[vid]

    def update(self, t: int) -> Dict[str, torch.Tensor]:
        """Advance SUMO one step and return the per-vehicle features at time ``t``.

        Keeps the vehicle count near ``target_vehicles``. For every vehicle in the
        network it either keeps its active service (decrementing the remaining
        lifetime) or asks the demand model for a new one. Vehicles raising
        ``TraCIException`` are skipped for this step; any other error prints a
        message and yields empty tensors.

        NOTE(review): vehicles added with ``depart="now"`` only appear in
        ``getIDList()`` once SUMO has inserted them, so while insertions are
        pending or blocked each step requests more vehicles than needed.
        NOTE(review): ``'id'`` uses ``hash(vid)``, which Python randomizes per
        process for strings; IDs are not reproducible across runs.
        """
        if self.sumo is None:
            return self._empty_tensor_dict()

        try:
            self.sumo.simulationStep()

            # Vehicle count control
            current_count = len(self.sumo.vehicle.getIDList())
            if current_count < self.target_vehicles:
                self._add_vehicles_safe(self.target_vehicles - current_count)
            elif current_count > self.target_vehicles:
                self._remove_vehicles_safe(current_count - self.target_vehicles)

            updated_vehicles = {}
            for vid in self.sumo.vehicle.getIDList():
                try:
                    x, y = self.sumo.vehicle.getPosition(vid)
                    speed = self.sumo.vehicle.getSpeed(vid)
                    road_id = self.sumo.vehicle.getRoadID(vid)
                    vehicle_type = self.sumo.vehicle.getTypeID(vid)

                    vehicle_data = {
                        'vehicle_type': vehicle_type,
                        'speed': speed,
                        'road_id': road_id
                    }

                    if vid in self.current_vehicles:
                        first_appearance = self.first_appearance_timestamps.get(vid, t)
                        current_service = self.current_vehicles[vid]['service']

                        if vid in self.service_durations and self.service_durations[vid] > 0:
                            # Active service: keep it and count down its lifetime.
                            self.service_durations[vid] -= 1
                            service = current_service
                        else:
                            # Service expired (or 'none'): Markov transition.
                            service = self._get_markov_service(vid, vehicle_data)
                    else:
                        # New vehicle: record first appearance, Poisson-initialize.
                        first_appearance = t
                        self.first_appearance_timestamps[vid] = first_appearance
                        service = self._get_markov_service(vid, vehicle_data)

                    updated_vehicles[vid] = {
                        'id': hash(vid) % 1000000,
                        'sumo_id': vid,
                        'position': torch.tensor([x, y], dtype=torch.float),
                        'service': service,
                        'cpu_demand': self.service_specs[service]['cpu'] if service != 'none' else 0,
                        'ram_demand': self.service_specs[service]['ram'] if service != 'none' else 0,
                        'data_size': self.service_specs[service]['data_size'] if service != 'none' else 0,
                        'speed': speed,
                        'vehicle_type': vehicle_type,
                        'road_id': road_id,
                        'timestamp': t,
                        'timestamp_apparition': first_appearance,
                        'service_remaining_duration': self.service_durations.get(vid, 0)
                    }

                except traci.TraCIException:
                    continue

            self.current_vehicles = updated_vehicles
            self._forget_departed(updated_vehicles)
            return self._to_tensor_format()

        except Exception as e:
            print(f"‚ùå Error during SUMO update: {e}")
            return self._empty_tensor_dict()

    def _dynamic_route(self, start_edge: str, end_edge: str, vehicle_type: str) -> List[str]:
        """Return a connected edge list ``start_edge -> end_edge`` as routed by SUMO.

        Falls back to ``[start_edge]`` (always a valid route) when SUMO finds
        no route or rejects the request.
        """
        try:
            stage = self.sumo.simulation.findRoute(start_edge, end_edge, vType=vehicle_type)
        except traci.TraCIException:
            return [start_edge]
        return list(stage.edges) or [start_edge]

    def _add_vehicles_safe(self, num_to_add: int):
        """Try to add ``num_to_add`` vehicles departing now; failures are printed and skipped.

        Each vehicle gets a uniformly chosen type and a connected route between
        two random non-internal edges. ``vehicle_counter`` counts successful
        additions and names the vehicles. Route IDs use their own serial:
        TraCI cannot delete a route, so reusing a route ID after a failed
        ``vehicle.add`` would make every later ``route.add`` fail.
        """
        if getattr(self, 'net', None) is None:
            return

        edges = [edge.getID() for edge in self.net.getEdges() if not edge.getID().startswith(':')]
        if not edges:
            return

        vehicle_types = ["car", "truck", "bus", "motorcycle", "emergency"]

        for i in range(num_to_add):
            try:
                start_edge = random.choice(edges)
                end_edge = random.choice(edges)

                serial = getattr(self, '_dynamic_route_serial', 0)
                self._dynamic_route_serial = serial + 1
                route_id = f"dynamic_route_{serial}"
                veh_id = f"dynamic_veh_{self.vehicle_counter}"
                chosen_type = random.choice(vehicle_types)

                self.sumo.route.add(route_id, self._dynamic_route(start_edge, end_edge, chosen_type))
                self.sumo.vehicle.add(veh_id, route_id, typeID=chosen_type, depart="now")

                self.vehicle_counter += 1

            except Exception as e:
                print(f"‚ö†Ô∏è Error adding vehicle {i}: {e}")
                continue

    def _remove_vehicles_safe(self, num_to_remove: int):
        """Remove up to ``num_to_remove`` randomly chosen vehicles (best effort).

        Successfully removed vehicles are also dropped from ``current_vehicles``
        and ``service_durations``.
        """
        vehicles = self.sumo.vehicle.getIDList()
        if not vehicles:
            return
        for vid in random.sample(vehicles, min(num_to_remove, len(vehicles))):
            try:
                self.sumo.vehicle.remove(vid)
            except Exception:
                # The vehicle may already have left the network.
                continue
            self.current_vehicles.pop(vid, None)
            self.service_durations.pop(vid, None)

    def _to_tensor_format(self) -> Dict[str, torch.Tensor]:
        """Stack ``self.current_vehicles`` into one tensor per feature.

        dtypes are pinned to match ``_empty_tensor_dict``. Without that,
        ``torch.tensor([15, 0])`` infers int64 for the demand columns, so
        populated and empty snapshots would disagree.

        Raises:
            ValueError: when a vehicle's SUMO type is not one of the five known
                types (``update`` turns this into an empty snapshot).
        """
        all_vehicle_types = ["car", "truck", "bus", "motorcycle", "emergency"]
        all_services = self.services + ['none']

        if not self.current_vehicles:
            return self._empty_tensor_dict()

        rows = list(self.current_vehicles.values())

        def column(key: str, dtype: torch.dtype) -> torch.Tensor:
            return torch.tensor([row[key] for row in rows], dtype=dtype)

        return {
            'id': column('id', torch.long),
            'position': torch.stack([row['position'] for row in rows]),
            'type_service': torch.tensor([all_services.index(row['service']) for row in rows], dtype=torch.long),
            'vehicle_type': torch.tensor([all_vehicle_types.index(row['vehicle_type']) for row in rows], dtype=torch.long),
            'cpu_demand': column('cpu_demand', torch.float),
            'ram_demand': column('ram_demand', torch.float),
            'data_size': column('data_size', torch.float),
            'speed': column('speed', torch.float),
            'timestamp_apparition': column('timestamp_apparition', torch.long),
            'timestamp': column('timestamp', torch.long),
            'service_remaining_duration': column('service_remaining_duration', torch.long),
        }

    def _empty_tensor_dict(self):
        """Feature dict with zero vehicles, using the same keys and dtypes as ``_to_tensor_format``."""
        return {
            'id': torch.tensor([], dtype=torch.long),
            'position': torch.tensor([], dtype=torch.float).reshape(0, 2),
            'type_service': torch.tensor([], dtype=torch.long),
            'vehicle_type': torch.tensor([], dtype=torch.long),
            'cpu_demand': torch.tensor([], dtype=torch.float),
            'ram_demand': torch.tensor([], dtype=torch.float),
            'data_size': torch.tensor([], dtype=torch.float),
            'speed': torch.tensor([], dtype=torch.float),
            'timestamp_apparition': torch.tensor([], dtype=torch.long),
            'timestamp': torch.tensor([], dtype=torch.long),
            'service_remaining_duration': torch.tensor([], dtype=torch.long)
        }

class CombinedDocumentSimulator:
    def __init__(self, num_edges=8, target_vehicles=30, time_steps=230,
                 network_file="manhattan.net.xml", use_gui=True,
                 save_snapshots: bool = True, output_file: str = "snapshots.bin"):
        """Set up the service catalogue, the demand model and the SUMO-backed vehicle manager.

        Args:
            num_edges: number of edge servers; ``_init_edges`` adds one extra
                cloud node on top.
            target_vehicles: vehicle count the vehicle manager tries to maintain.
            time_steps: stored as ``self.time_steps`` (presumably the run length;
                not used in this method).
            network_file: SUMO ``.net.xml`` used both for vehicles and edge placement.
            use_gui: start ``sumo-gui`` instead of console ``sumo``.
            save_snapshots: stored as ``self.save_snapshots``.
            output_file: stored as ``self.output_file``.

        Raises:
            Whatever ``sumolib.net.readNet`` raises when the vehicle manager could
            not load ``network_file`` (e.g. the file is missing).
        """
        self.num_edges = num_edges
        self.time_steps = time_steps
        self.services = [
            'cooperative_perception',
            'platooning_control',
            'edge_object_recognition',
            'predictive_collision_avoidance',
            'infrastructure_vision'
        ]

        # Resource demand per service; 'none' requests nothing.
        self.service_specs = {
            'cooperative_perception': {'cpu': 15, 'ram': 10, 'data_size': 2},
            'platooning_control': {'cpu': 7, 'ram': 5, 'data_size': 1.6},
            'edge_object_recognition': {'cpu': 5, 'ram': 3, 'data_size': 1},
            'predictive_collision_avoidance': {'cpu': 10, 'ram': 7, 'data_size': 1.8},
            'infrastructure_vision': {'cpu': 3, 'ram': 2, 'data_size': 0.6},
            'none': {'cpu': 0, 'ram': 0, 'data_size': 0}
        }

        self.save_snapshots = save_snapshots
        self.output_file = output_file
        self.saved_snapshots = []

        # Realistic demand model
        self.demand_model = MarkovServiceDemand(self.services)

        self.vehicle_manager = CombinedVehicleManager(
            services=self.services,
            service_specs=self.service_specs,
            target_vehicles=target_vehicles,
            demand_model=self.demand_model,
            network_file=network_file,
            use_gui=use_gui
        )

        # Reuse the network the vehicle manager already parsed: re-reading a
        # large Manhattan net.xml is slow and yields the same (read-only) data.
        # When the manager could not load it, read it here as before, which
        # raises for a missing or invalid file.
        manager_net = getattr(self.vehicle_manager, 'net', None)
        self.sumo_net = manager_net if manager_net is not None else sumolib.net.readNet(network_file)
        self.current_edge_state = None

    def _get_optimal_edge_positions(self) -> torch.Tensor:
        """Calculate optimal edge positions to cover Manhattan"""
        # Get all edges from Manhattan network
        sumo_edges = list(self.sumo_net.getEdges())
        
        # Extract center positions of SUMO edges
        edge_positions = []
        for edge in sumo_edges:
            # Get edge shape (list of points)
            shape = edge.getShape()
            if len(shape) > 0:
                # Calculate edge center
                x_coords = [point[0] for point in shape]
                y_coords = [point[1] for point in shape]
                center_x = sum(x_coords) / len(x_coords)
                center_y = sum(y_coords) / len(y_coords)
                edge_positions.append((center_x, center_y))
        
        # If not enough edges found, use node coordinates
        if len(edge_positions) < self.num_edges:
            nodes = list(self.sumo_net.getNodes())
            for node in nodes:
                coord = node.getCoord()
                if coord:
                    edge_positions.append((coord[0], coord[1]))
        
        # Use simplified k-means to find optimal positions
        if len(edge_positions) >= self.num_edges:
            # Select most spaced positions
            selected_positions = self._select_dispersed_positions(edge_positions, self.num_edges)
        else:
            # Fallback: circular positions around network center
            selected_positions = self._get_circular_positions(self.num_edges)
        
        return torch.tensor(selected_positions, dtype=torch.float)

    def _select_dispersed_positions(self, positions: List[tuple], num_to_select: int) -> List[tuple]:
        """Select most dispersed positions"""
        if len(positions) <= num_to_select:
            return positions
        
        # Simple method: select most distant positions
        selected = [random.choice(positions)]
        
        for _ in range(1, num_to_select):
            max_min_distance = -1
            best_candidate = None
            
            for candidate in positions:
                if candidate in selected:
                    continue
                
                # Calculate minimum distance to already selected positions
                min_dist = min(self._distance(candidate, sel) for sel in selected)
                
                if min_dist > max_min_distance:
                    max_min_distance = min_dist
                    best_candidate = candidate
            
            if best_candidate:
                selected.append(best_candidate)
        
        return selected

    def _get_circular_positions(self, num_positions: int) -> List[tuple]:
        """Circular positions around network center (fallback)"""
        # Calculate Manhattan network center
        nodes = list(self.sumo_net.getNodes())
        if nodes:
            coords = [node.getCoord() for node in nodes if node.getCoord()]
            if coords:
                center_x = sum(coord[0] for coord in coords) / len(coords)
                center_y = sum(coord[1] for coord in coords) / len(coords)
            else:
                center_x, center_y = 500, 500
        else:
            center_x, center_y = 500, 500
            
        radius = 400  # Larger radius for Manhattan
        
        positions = []
        for i in range(num_positions):
            angle = 2 * np.pi * i / num_positions
            x = center_x + radius * np.cos(angle)
            y = center_y + radius * np.sin(angle)
            positions.append((x, y))
        
        return positions

    def _distance(self, pos1: tuple, pos2: tuple) -> float:
        """Calculate distance between two positions"""
        return np.sqrt((pos1[0] - pos2[0])**2 + (pos1[1] - pos2[1])**2)

    def _get_cloud_position(self, edge_positions: torch.Tensor) -> torch.Tensor:
        """Position cloud at distance to recover uncovered vehicles"""
        # Calculate edge center of gravity
        center_x = torch.mean(edge_positions[:, 0])
        center_y = torch.mean(edge_positions[:, 1])
        
        # Calculate average edge radius
        distances_from_center = torch.norm(edge_positions - torch.tensor([center_x, center_y]), dim=1)
        mean_radius = torch.mean(distances_from_center)
        
        # Position cloud at 2x average radius (further for Manhattan)
        cloud_distance = mean_radius * 2.0 + 300
        
        # Arbitrary direction (e.g., northeast)
        cloud_x = center_x + cloud_distance * 0.7
        cloud_y = center_y + cloud_distance * 0.7
        
        return torch.tensor([[cloud_x, cloud_y]])

    def _init_edges(self) -> Dict[str, torch.Tensor]:
        """Edge initialization with optimal positions"""
        cpu_capacity = torch.full((self.num_edges + 1,), 25)
        ram_capacity = torch.full((self.num_edges + 1,), 17)
    
        # Cloud with more capacity
        cpu_capacity[-1] = 150
        ram_capacity[-1] = 100
        
        # Get optimal positions for edges
        edge_positions = self._get_optimal_edge_positions()
        
        # Get cloud position
        cloud_position = self._get_cloud_position(edge_positions)
        
        # Combine positions
        positions = torch.cat([edge_positions, cloud_position], dim=0)
    
        # Initialize services_hosted and TTL_services_hosted
        services_hosted = torch.zeros(self.num_edges + 1, len(self.services))
        TTL_services_hosted = torch.zeros(self.num_edges + 1, len(self.services))
        
        # For last edge (cloud), set all services to 1 and TTL to 1000
        services_hosted[-1, :] = torch.ones(len(self.services))
        TTL_services_hosted[-1, :] = torch.full((len(self.services),), 1000)
    
        return {
            'id': torch.arange(self.num_edges + 1),
            'position': positions,
            'cpu_capacity': cpu_capacity,
            'ram_capacity': ram_capacity,
            'cpu_available': cpu_capacity.clone(),
            'ram_available': ram_capacity.clone(),
            'services_hosted': services_hosted,
            'TTL_services_hosted': TTL_services_hosted
        }

    def _update_services(self, edge_feats: Dict[str, torch.Tensor], 
                            vehicle_feats: Dict[str, torch.Tensor], 
                            is_first_snapshot: bool) -> Dict[str, torch.Tensor]:
        """Service update with adaptation to realistic demand"""       
        if is_first_snapshot:
            edge_feats['cpu_available'] = edge_feats['cpu_capacity'].clone()
            edge_feats['ram_available'] = edge_feats['ram_capacity'].clone()
    
        # Decrement TTL of services deployed in cloud (last edge)
        cloud_index = -1  # Cloud index
        services_deployed = edge_feats['services_hosted'][cloud_index] > 0
        
        # Decrement TTL by 1 only for deployed services
        edge_feats['TTL_services_hosted'][cloud_index, services_deployed] -= 1
    
        return edge_feats

    def _calculate_local_demand(self, vehicle_feats: Dict[str, torch.Tensor], 
                              edge_pos: torch.Tensor) -> torch.Tensor:
        """Calculate local demand around an edge (only for vehicles with service)"""
        distances = torch.norm(vehicle_feats['position'] - edge_pos, dim=1)
        nearby_vehicles = distances < 400  # Larger radius for Manhattan
        
        if not nearby_vehicles.any():
            return torch.zeros(len(self.services))
        
        # Filter only vehicles with service (exclude 'none')
        service_vehicles = vehicle_feats['type_service'] < len(self.services)
        valid_nearby = nearby_vehicles & service_vehicles
        
        if not valid_nearby.any():
            return torch.zeros(len(self.services))
        
        local_service_counts = torch.bincount(
            vehicle_feats['type_service'][valid_nearby],
            minlength=len(self.services)
        )
        
        return local_service_counts.float() / max(1, valid_nearby.sum())

    def generate_snapshot(self, t: int) -> dgl.DGLGraph:
        """Advance the simulation to step ``t`` and build the vehicle->edge graph.

        The result is a DGL heterograph with node types 'vehicle' and 'edge'
        (the last 'edge' node is the cloud) and a single relation
        ('vehicle', 'connects', 'edge'):

        * a vehicle is linked to every edge server strictly closer than 150;
        * a vehicle with no such edge server is linked to the cloud instead.

        Edge features, always attached (empty tensors when there are no
        links) so every snapshot has the same schema:

        * 'bandwidth': 100.0 for edge-server links, 15.0 for cloud links;
        * 'distance': the vehicle/edge-server distance for edge-server links,
          a fixed 1500.0 for cloud links.

        Side effects: replaces ``self.current_edge_state`` with the updated
        edge features and, when ``self.save_snapshots`` is set, appends the
        graph to ``self.saved_snapshots``.
        """
        is_first_snapshot = (t == 0)
        if self.current_edge_state is None:
            self.current_edge_state = self._init_edges()

        vehicle_feats = self.vehicle_manager.update(t)
        num_vehicles = len(vehicle_feats['id'])

        # Update a deep copy so the previous snapshot's state is not mutated.
        updated_edge_feats = self._update_services(
            copy.deepcopy(self.current_edge_state), vehicle_feats, is_first_snapshot
        )
        self.current_edge_state = updated_edge_feats

        # Vehicle -> edge-server distances; the cloud (last 'edge' row) is excluded.
        dist_matrix = torch.cdist(vehicle_feats['position'], updated_edge_feats['position'][:-1])
        in_range = dist_matrix < 150
        device = dist_matrix.device

        # One link per (vehicle, edge server) pair within range.
        connections = in_range.nonzero(as_tuple=False)
        edge_src, edge_dst = connections[:, 0], connections[:, 1]

        # Vehicles without any edge server in range fall back to the cloud.
        cloud_idx = self.num_edges
        cloud_src = (~in_range.any(dim=1)).nonzero(as_tuple=False).squeeze(1)
        cloud_dst = torch.full((len(cloud_src),), cloud_idx, dtype=torch.long, device=device)

        # Edge-server links first, then cloud links; the feature tensors below
        # are concatenated in exactly the same order.
        src = torch.cat([edge_src, cloud_src])
        dst = torch.cat([edge_dst, cloud_dst])

        g = dgl.heterograph({
            ('vehicle', 'connects', 'edge'): (src, dst)
        }, num_nodes_dict={
            'vehicle': num_vehicles,
            'edge': self.num_edges + 1
        })

        g.nodes['edge'].data.update(updated_edge_feats)
        g.nodes['vehicle'].data.update(vehicle_feats)

        num_local, num_cloud = len(edge_src), len(cloud_src)
        bandwidth = torch.cat([
            torch.full((num_local,), 100.0, device=device),
            torch.full((num_cloud,), 15.0, device=device),
        ])
        distance = torch.cat([
            dist_matrix[edge_src, edge_dst],
            torch.full((num_cloud,), 1500.0, device=device),
        ])
        g.edges['connects'].data.update({
            'bandwidth': bandwidth,
            'distance': distance,
        })

        if self.save_snapshots:
            self.saved_snapshots.append(g)
        return g

    def generate(self) -> List[dgl.DGLGraph]:
        """Generate one snapshot per time step, in order, and return them all.

        Every step from 0 to ``self.time_steps - 1`` is kept; nothing is
        filtered out here (an empty snapshot 0 is returned as-is).
        """
        snapshots = []
        for step in range(self.time_steps):
            snapshots.append(self.generate_snapshot(step))
        return snapshots
    
    def save_snapshots_to_file(self):
        """Write ``self.saved_snapshots`` to ``self.output_file`` with DGL.

        Only a message is printed when saving is disabled or there is nothing
        to save. A failed save is reported on stdout rather than raised, so a
        save problem does not abort the calling script. Summary statistics are
        printed only after a successful save.
        """
        if not self.save_snapshots or not self.saved_snapshots:
            print("❌ No snapshots to save")
            return

        snapshots = self.saved_snapshots
        try:
            dgl.save_graphs(self.output_file, snapshots)
        except Exception as e:  # deliberate best-effort: report and continue
            print(f"❌ Save error: {e}")
            return

        print(f"✅ {len(snapshots)} snapshots saved in {self.output_file}")

        total_vehicles = sum(g.number_of_nodes('vehicle') for g in snapshots)
        total_edges = sum(g.number_of_edges() for g in snapshots)
        print(f"📊 Total: {total_vehicles} vehicles, {total_edges} connections")

        # The list is stored as-is, so graphs are reloaded with indices 0..N-1.
        # Which simulation steps (if any) were dropped is not known here.
        print(f"📅 Saved snapshots: {len(snapshots)} (file indices 0 to {len(snapshots) - 1})")

    def visualize_snapshot(self, g: dgl.DGLGraph, snapshot_id: int):
        """Draw one snapshot with NetworkX/Matplotlib and print a short summary.

        Vehicle nodes are drawn only when the graph has a 'vehicle_type'
        vehicle feature. Their fill colour encodes 'type_service' (5 = no
        service) and their border colour encodes 'vehicle_type'. 'edge' nodes
        are grey, except the last one, which is the cloud and is drawn black
        and larger. Links are drawn only between nodes that were drawn.

        Args:
            g: snapshot graph with 'vehicle' and 'edge' node types, as built by
                ``generate_snapshot``.
            snapshot_id: label used in the title and the printed summary.
        """
        plt.figure(figsize=(16, 12))
        nx_g = nx.Graph()

        vehicle_type_colors = {
            0: 'red',      # Car
            1: 'blue',     # Truck
            2: 'green',    # Bus
            3: 'orange',   # Motorcycle
            4: 'purple',   # Emergency vehicle
            'edge': 'gray',
            'cloud': 'black'
        }

        # Fill colour per service id, in the same order as the legend below.
        service_colors = {
            0: 'lightcoral',    # cooperative_perception
            1: 'lightblue',     # platooning_control
            2: 'lightgreen',    # edge_object_recognition
            3: 'gold',          # predictive_collision_avoidance
            4: 'violet',        # infrastructure_vision
            5: 'white'          # none
        }

        def node_ids(ntype):
            """Integer ids of ``ntype`` nodes ('id' feature if present, else 0..n-1)."""
            data = g.nodes[ntype].data
            if 'id' in data:
                return [int(x) for x in data['id'].cpu().numpy()]
            return list(range(g.number_of_nodes(ntype)))

        ids_by_type = {ntype: node_ids(ntype) for ntype in g.ntypes}
        cloud_idx = g.number_of_nodes('edge') - 1 if 'edge' in g.ntypes else None

        def node_name(kind, idx):
            """NetworkX key of node ``idx``. It is shared by both passes so link endpoints always match."""
            if kind == 'vehicle':
                return f"V{ids_by_type['vehicle'][idx]}"
            if idx == cloud_idx:
                return "Cloud"
            return f"E{ids_by_type['edge'][idx]}"

        # Add nodes
        for ntype in g.ntypes:
            data = g.nodes[ntype].data
            count = g.number_of_nodes(ntype)

            if ntype == 'vehicle' and 'vehicle_type' in data:
                positions = data['position'].cpu().numpy()
                vehicle_types = data['vehicle_type'].cpu().numpy()
                service_types = data['type_service'].cpu().numpy() if 'type_service' in data else [5] * count

                for i in range(count):
                    vehicle_type = int(vehicle_types[i])
                    service_type = int(service_types[i])
                    nx_g.add_node(node_name('vehicle', i),
                                  pos=positions[i],
                                  ntype=ntype,
                                  vehicle_type=vehicle_type,
                                  service_type=service_type,
                                  color=service_colors.get(service_type, 'white'),
                                  edge_color=vehicle_type_colors.get(vehicle_type, 'pink'),
                                  size=200 if service_type == 5 else 250)

            elif ntype == 'edge':
                positions = data['position'].cpu().numpy()

                for i in range(count):
                    is_cloud = i == cloud_idx
                    nx_g.add_node(node_name('edge', i),
                                  pos=positions[i],
                                  ntype=ntype,
                                  color=vehicle_type_colors['cloud' if is_cloud else 'edge'],
                                  size=500 if is_cloud else 350)

        # Add edges (the endpoint naming is hoisted out of the per-link loop)
        for etype in g.etypes:
            src_type, _, dst_type = g.to_canonical_etype(etype)
            src_kind = 'vehicle' if src_type == 'vehicle' else 'edge'
            dst_kind = 'vehicle' if dst_type == 'vehicle' else 'edge'

            src, dst = g.edges(etype=etype)
            for s, d in zip(src.cpu().numpy(), dst.cpu().numpy()):
                src_id = node_name(src_kind, int(s))
                dst_id = node_name(dst_kind, int(d))
                if src_id in nx_g and dst_id in nx_g:
                    nx_g.add_edge(src_id, dst_id, etype=etype)

        # Draw graph
        if len(nx_g.nodes()) == 0:
            print("No nodes to display")
            plt.close()
            return

        pos = {node: data.get('pos', (0, 0)) for node, data in nx_g.nodes(data=True)}
        colors = [data.get('color', 'pink') for node, data in nx_g.nodes(data=True)]
        edge_colors = [data.get('edge_color', 'black') for node, data in nx_g.nodes(data=True)]
        sizes = [data.get('size', 300) for node, data in nx_g.nodes(data=True)]

        # Draw nodes with colored border
        nx.draw_networkx_nodes(nx_g, pos, node_color=colors, node_size=sizes,
                               edgecolors=edge_colors, linewidths=2, alpha=0.8)

        # Labels only for small graphs; they become unreadable otherwise.
        if len(nx_g.nodes()) < 50:
            labels = {node: node for node in nx_g.nodes()}
            nx.draw_networkx_labels(nx_g, pos, labels, font_size=8, font_weight='bold')

        nx.draw_networkx_edges(nx_g, pos, alpha=0.3, edge_color='gray', width=1.5)

        # Legend
        legend_elements = [
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightcoral',
                       markeredgecolor='red', markersize=10, label='Cooperative Perception'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightblue',
                       markeredgecolor='blue', markersize=10, label='Platooning Control'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen',
                       markeredgecolor='green', markersize=10, label='Edge-Assisted Object Recognition'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='gold',
                       markeredgecolor='orange', markersize=10, label='Predictive Collision Avoidance'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='violet',
                       markeredgecolor='purple', markersize=10, label='Infrastructure-Assisted Vision'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='white',
                       markeredgecolor='black', markersize=10, label='No service'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='gray',
                       markersize=10, label='Edge'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='black',
                       markersize=10, label='Cloud')
        ]

        plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, 1))
        plt.title(f"Snapshot {snapshot_id} - Manhattan with Services\n(V=Vehicle, E=Edge, Color=interior=service, border=vehicle type)")
        plt.axis('off')
        plt.tight_layout()
        plt.show()

        # Statistics
        num_with_service = sum(1 for node, data in nx_g.nodes(data=True)
                               if data.get('ntype') == 'vehicle' and data.get('service_type', 5) != 5)

        print(f"\nSnapshot {snapshot_id}:")
        print(f"Nodes: {len(nx_g.nodes())}, Edges: {len(nx_g.edges())}")
        print(f"Vehicles: {g.number_of_nodes('vehicle')}, Edges: {g.number_of_nodes('edge')}")
        print(f"Vehicles with service: {num_with_service}/{g.number_of_nodes('vehicle')}")


def main():
    """Run the simulation, save the snapshots and display them.

    Steps:
    1. Build a CombinedDocumentSimulator on "manhattan.net.xml".
    2. Generate ``time_steps`` snapshots, dropping snapshot 0 when it has no
       vehicles.
    3. Draw each snapshot whose step is a multiple of 10 and has vehicles.
    4. Write the kept snapshots to "manhattan_snapshots.bin".
    5. Print their node features.

    Returns:
        The list of kept snapshot graphs.
    """
    print("🚀 STARTING SIMULATION WITH YOUR EXISTING NETWORK")

    # Configuration
    services = [
        'cooperative_perception',
        'platooning_control',
        'edge_object_recognition',
        'predictive_collision_avoidance',
        'infrastructure_vision'
    ]

    # Fixed typo: the class is MarkovServiceDemand (the old name raised NameError).
    # NOTE(review): demand_model is not passed to the simulator below; confirm
    # whether CombinedDocumentSimulator is meant to receive it.
    demand_model = MarkovServiceDemand(services)

    # Simulator that keeps every generated snapshot in memory for saving.
    simulator = CombinedDocumentSimulator(
        num_edges=7,
        target_vehicles=20,
        time_steps=200,
        network_file="manhattan.net.xml",
        use_gui=True,
        save_snapshots=True,
        output_file="manhattan_snapshots.bin"
    )

    print("⏳ Generating and saving snapshots...")

    graphs = []
    for t in range(simulator.time_steps):
        g = simulator.generate_snapshot(t)
        num_vehicles = g.number_of_nodes('vehicle')

        # Drop the initial snapshot when no vehicle has spawned yet.
        if t == 0 and num_vehicles == 0:
            print("⏭️  Snapshot 0 ignored (no vehicles)")
            continue

        graphs.append(g)

        # Graphical visualization of some snapshots (only when there are vehicles)
        if t % 10 == 0 and num_vehicles > 0:
            print(f"🎨 Graphical visualization of snapshot {t}")
            simulator.visualize_snapshot(g, t)

    # generate_snapshot() already appends every snapshot (including a dropped
    # snapshot 0) to saved_snapshots; replace it with the filtered list.
    simulator.saved_snapshots = graphs
    simulator.save_snapshots_to_file()

    print("\n📊 NODE FEATURES VISUALIZATION")
    visualize_all_nodes_features(graphs)

    return graphs

def _node_data_to_frame(node_data) -> pd.DataFrame:
    """Flatten a node-feature mapping into a DataFrame, one row per node.

    A 1-D feature becomes one column. A multi-dimensional feature is
    flattened per node into columns ``<name>_0``, ``<name>_1``, and so on.
    """
    columns = {}
    for key, value in node_data.items():
        # detach/cpu so GPU or grad-tracking tensors can be converted too.
        array = value.detach().cpu().numpy()
        if array.ndim > 1:
            flat = array.reshape(array.shape[0], -1)
            for i in range(flat.shape[1]):
                columns[f"{key}_{i}"] = flat[:, i]
        else:
            columns[key] = array
    return pd.DataFrame(columns)


def visualize_all_nodes_features(glist):
    """Print node features and summary statistics for every snapshot in ``glist``.

    The printed snapshot numbers are positions in ``glist`` (0-based), not the
    simulation steps the graphs came from. The last 'edge' node is counted
    as the cloud.
    """
    service_names = [
        'cooperative_perception',
        'platooning_control',
        'edge_object_recognition',
        'predictive_collision_avoidance',
        'infrastructure_vision',
        'none'
    ]

    for t, g in enumerate(glist):
        print(f"\n{'='*50}")
        print(f"=== SNAPSHOT {t} ===")
        print(f"{'='*50}")

        # 1. Edge-server / cloud nodes
        if 'edge' in g.ntypes and g.number_of_nodes('edge') > 0:
            edge_data = g.nodes['edge'].data
            edge_df = _node_data_to_frame(edge_data)
            edge_df.index.name = 'Edge_ID'
            print("\n📡 EDGE NODES FEATURES:")
            print(edge_df)

            print(f"\n📈 Edge Statistics - Snapshot {t}:")
            print(f"Number of edges: {g.number_of_nodes('edge')}")
            if 'cpu_available' in edge_data:
                cpu_avail = edge_data['cpu_available'].detach().cpu().numpy()
                ram_avail = edge_data['ram_available'].detach().cpu().numpy()
                print(f"Average available CPU: {cpu_avail.mean():.2f}")
                print(f"Average available RAM: {ram_avail.mean():.2f}")

        # 2. Vehicle nodes
        if 'vehicle' in g.ntypes and g.number_of_nodes('vehicle') > 0:
            vehicle_data = g.nodes['vehicle'].data
            vehicle_df = _node_data_to_frame(vehicle_data)
            vehicle_df.index.name = 'Vehicle_ID'
            print("\n🚗 VEHICLE NODES FEATURES:")
            print(vehicle_df)

            print(f"\n📊 Vehicle Statistics - Snapshot {t}:")
            print(f"Number of vehicles: {g.number_of_nodes('vehicle')}")
            if 'type_service' in vehicle_data:
                service_ids = vehicle_data['type_service'].detach().cpu().numpy()
                unique, counts = np.unique(service_ids, return_counts=True)
                print("Service distribution:")
                for service_id, count in zip(unique, counts):
                    sid = int(service_id)
                    # Guard both ends: a negative id would otherwise index from the end.
                    service_name = service_names[sid] if 0 <= sid < len(service_names) else 'unknown'
                    print(f"  {service_name}: {count} vehicles")

        # 3. Links (graph edges)
        num_links = g.number_of_edges()
        if num_links > 0:
            print(f"\n🔗 CONNECTIONS - Snapshot {t}:")
            print(f"Total connections: {num_links}")

            # Count links to the cloud vs. to edge servers.
            if 'edge' in g.ntypes and 'vehicle' in g.ntypes:
                cloud_idx = g.number_of_nodes('edge') - 1
                edge_connections = 0
                cloud_connections = 0

                for etype in g.etypes:
                    _, dst = g.edges(etype=etype)
                    to_cloud = int((dst == cloud_idx).sum())
                    cloud_connections += to_cloud
                    edge_connections += len(dst) - to_cloud

                total = edge_connections + cloud_connections
                print(f"Connections to local edges: {edge_connections}")
                print(f"Connections to cloud: {cloud_connections}")
                if total:
                    print(f"Local connection rate: {edge_connections/total*100:.1f}%")

if __name__ == "__main__":
    # Script entry point. The returned snapshot list is bound to the
    # module-level name `graphs` so it can be inspected after the run.
    graphs = main()