In [3]:
import numpy as np
from enum import Enum
import random

class Direction(Enum):
    """Heading of a vehicle on the grid, listed clockwise from NORTH.

    Integer values are fixed (0-3) so members can be used as compact
    numeric codes.
    """

    NORTH = 0  # TrafficEnvironment moves these to row x - 1
    EAST = 1   # ... to column y + 1
    SOUTH = 2  # ... to row x + 1
    WEST = 3   # ... to column y - 1

class Vehicle:
    """A car that travels in a straight line along one road.

    Attributes:
        position: Current (x, y) grid cell.
        direction: Heading the car travels in (a ``Direction`` member).
        waiting_time: Steps this car has spent waiting; starts at 0 and is
            incremented by the environment.
        has_crossed: Becomes True once the car has driven off the grid.
    """

    def __init__(self, position, direction):
        # Kinematic state first, then the per-car bookkeeping counters.
        self.position, self.direction = position, direction
        self.waiting_time, self.has_crossed = 0, False

class TrafficLight:
    """Signal head sitting on one grid cell.

    Attributes:
        position: (x, y) cell the light controls.
        is_green: Whether traffic may proceed; a new light starts red.
        green_duration: Counter initialised (and reset) to 0; it is not
            advanced by any code visible in this module.
    """

    def __init__(self, position):
        self.position = position
        self.is_green, self.green_duration = False, 0

class TrafficEnvironment:
    """Grid-world simulation of vehicles crossing a single signalised junction.

    The grid is ``size`` x ``size``. A horizontal road (row ``size // 2``)
    and a vertical road (column ``size // 2``) cross at the centre cell,
    which carries one ``TrafficLight``. Each call to :meth:`step` may spawn
    a vehicle at a road end, advances every vehicle one cell along its
    heading (unless a red light is directly ahead) and drops vehicles that
    have driven off the grid.

    Grid codes: 0 = empty, 1 = road, 2 = intersection. :meth:`_get_state`
    additionally writes 3 into every cell occupied by a vehicle.

    NOTE: there is no collision model -- several vehicles may share a cell,
    e.g. while queued in front of a red light.
    """

    # Base per-step spawn probability for each supported density level.
    _SPAWN_PROBABILITY = {'low': 0.1, 'medium': 0.2, 'high': 0.3}

    def __init__(self, size=10, traffic_density='medium', seed=None):
        """Create the road layout and an empty set of vehicles.

        Args:
            size: Side length of the square grid.
            traffic_density: One of 'low', 'medium' or 'high'.
            seed: Optional seed for a private ``random.Random`` generator.
                When None (default) the global ``random`` module is used,
                so ``random.seed(...)`` keeps controlling the simulation.

        Raises:
            ValueError: If ``traffic_density`` is not a supported level.
        """
        self.size = size
        self.grid = np.zeros((size, size))  # 0: empty, 1: road, 2: intersection
        self.vehicles = []
        self.traffic_lights = {}
        # Validate like set_traffic_density() does, instead of silently
        # treating unknown strings as 'high'.
        self.set_traffic_density(traffic_density)
        self.time_step = 0
        # Source of randomness for spawning; the global module keeps the
        # historical behaviour when no seed is given.
        self._rng = random if seed is None else random.Random(seed)

        # Initialize roads and intersections
        self.setup_roads()

    def setup_roads(self):
        """Create a simple intersection with two perpendicular roads"""
        mid = self.size // 2
        self.grid[mid, :] = 1   # horizontal road
        self.grid[:, mid] = 1   # vertical road
        self.grid[mid, mid] = 2  # the crossing cell
        # A single light controls the crossing for all directions.
        self.traffic_lights[(mid, mid)] = TrafficLight((mid, mid))

    def add_vehicle(self, density_factor=1.0):
        """Possibly spawn one vehicle at a randomly chosen road entry point.

        The spawn probability is the base rate for the current density level
        (0.1 / 0.2 / 0.3 for low / medium / high) multiplied by
        ``density_factor``.
        """
        # Levels outside the table can only appear if the attribute was
        # assigned directly; they keep the old fallback to the 'high' rate.
        spawn_probability = (
            self._SPAWN_PROBABILITY.get(self.traffic_density, 0.3) * density_factor
        )
        if self._rng.random() < spawn_probability:
            mid = self.size // 2
            # Each road end paired with the heading that drives into the grid.
            entry_points = [
                ((0, mid), Direction.SOUTH),
                ((self.size - 1, mid), Direction.NORTH),
                ((mid, 0), Direction.EAST),
                ((mid, self.size - 1), Direction.WEST),
            ]
            position, direction = self._rng.choice(entry_points)
            self.vehicles.append(Vehicle(position, direction))

    def update_vehicle_positions(self):
        """Advance every active vehicle by one cell, respecting red lights.

        A vehicle whose next cell holds a red traffic light stays where it is
        (in front of the junction) and accrues one step of ``waiting_time``.
        Vehicles already inside the junction always move on, so they clear it.
        """
        for vehicle in self.vehicles:
            if vehicle.has_crossed:
                continue

            light = self.traffic_lights.get(self._next_position(vehicle))
            if light is not None and not light.is_green:
                # Red ahead: hold at the approach rather than entering the junction.
                vehicle.waiting_time += 1
            else:
                self._move_vehicle(vehicle)

    def _next_position(self, vehicle):
        """Return the cell one step ahead of ``vehicle`` along its heading."""
        x, y = vehicle.position
        if vehicle.direction == Direction.NORTH:
            return (x - 1, y)
        if vehicle.direction == Direction.SOUTH:
            return (x + 1, y)
        if vehicle.direction == Direction.EAST:
            return (x, y + 1)
        return (x, y - 1)  # WEST

    def _move_vehicle(self, vehicle):
        """Move ``vehicle`` one cell forward, or mark it crossed if it leaves the grid."""
        new_x, new_y = self._next_position(vehicle)
        if 0 <= new_x < self.size and 0 <= new_y < self.size:
            vehicle.position = (new_x, new_y)
        else:
            vehicle.has_crossed = True

    def step(self):
        """Perform one simulation step and return the resulting state grid."""
        self.time_step += 1
        self.add_vehicle()
        self.update_vehicle_positions()

        # Clean up vehicles that have driven off the grid.
        self.vehicles = [v for v in self.vehicles if not v.has_crossed]

        return self._get_state()

    def _get_state(self):
        """Return a copy of the grid with every occupied cell set to 3."""
        state = self.grid.copy()
        for vehicle in self.vehicles:
            x, y = vehicle.position
            state[x, y] = 3  # 3 represents vehicle
        return state

    def reset(self):
        """Remove all vehicles, rewind time and turn every light red.

        Returns:
            The state grid after the reset (roads and intersection only).
        """
        self.vehicles = []
        self.time_step = 0
        for light in self.traffic_lights.values():
            light.is_green = False
            light.green_duration = 0
        return self._get_state()

    def set_traffic_density(self, density):
        """Set traffic density to low, medium, or high.

        Raises:
            ValueError: If ``density`` is not one of the supported levels.
        """
        if density not in ('low', 'medium', 'high'):
            raise ValueError("Density must be 'low', 'medium', or 'high'")
        self.traffic_density = density

def test_environment():
    """Smoke-test TrafficEnvironment by printing 20 simulated steps.

    Uses a 10x10 grid at medium density. After every step it prints the
    step number, the number of vehicles on the grid, how many of them have
    a non-zero waiting time, and the raw state grid.
    """
    env = TrafficEnvironment(size=10, traffic_density='medium')
    total_steps = 20

    done = 0
    while done < total_steps:
        grid_state = env.step()

        on_grid = env.vehicles
        n_total = len(on_grid)
        n_waiting = len([car for car in on_grid if car.waiting_time > 0])

        print(f"Step {env.time_step}:")
        print(f"Total vehicles: {n_total}")
        print(f"Waiting vehicles: {n_waiting}")
        print("Current state:")
        print(grid_state)
        print("\n")

        done += 1

KeyboardInterrupt: 