In [1]:
import numpy as np
import random
from PIL import Image
from abc import ABC, abstractmethod
from collections import defaultdict

random.seed(1)
np.random.seed(1)

SPRITE_SIZE = 32

class GridEntity:
    sprite_cache = {}
    @classmethod
    def preload_sprites(cls, symbols, orientations, sprite_path, sprite_size):
        for symbol in symbols:
            for orientation in orientations:
                img_path = f'{sprite_path}{symbol}.png'
                try:
                    img = Image.open(img_path).convert('RGBA').resize((sprite_size, sprite_size))
                    rotation = {'up': 180, 'down': 0, 'left': 270, 'right': 90}[orientation]
                    cls.sprite_cache[(symbol, orientation)] = img.rotate(rotation)
                except FileNotFoundError:
                    print(f"File not found: {img_path}")

    @classmethod
    def load_sprite(cls, symbol, orientation, sprite_path, sprite_size):
        # This method now simply retrieves the sprite from the cache
        return cls.sprite_cache[(symbol, orientation)]

    def __init__(self, x, y, symbol, orientation=None, speed=1, sprite_path='sprites/', sprite_size=64):
        self.x = x
        self.y = y
        self.symbol = symbol
        self.orientation = orientation
        self.speed = speed
        self.orientation_to_rotation = {'up': 180, 'down': 0, 'left': 270, 'right': 90}
        self.sprite_size = sprite_size
        self.sprite_path = sprite_path
        self.sprite = self.load_sprite(symbol, orientation, sprite_path, sprite_size)

    def set_position(self, x, y):
        self.x = x
        self.y = y

    def move_forward(self):
        if self.orientation == 'up':
            self.y -= self.speed
        elif self.orientation == 'down':
            self.y += self.speed
        elif self.orientation == 'left':
            self.x -= self.speed
        elif self.orientation == 'right':
            self.x += self.speed

    def turn_left(self):
        turns = {'up': 'left', 'left': 'down', 'down': 'right', 'right': 'up'}
        self.orientation = turns[self.orientation]
        self.sprite = self.sprite.rotate(90)

    def turn_right(self):
        turns = {'up': 'right', 'right': 'down', 'down': 'left', 'left': 'up'}
        self.orientation = turns[self.orientation]
        self.sprite = self.sprite.rotate(-90)

    def __repr__(self):
        return f"{self.symbol} at ({self.x}, {self.y}) facing {self.orientation}"

class Vehicle(GridEntity):
    def __init__(self, x, y, symbol, size, orientation, speed):
        super().__init__(x, y, symbol, orientation, speed)
        self.size = size

    def predict_next_position(self):
        dx, dy = 0, 0
        if self.orientation == 'up':
            dy = -self.speed
        elif self.orientation == 'down':
            dy = self.speed
        elif self.orientation == 'left':
            dx = -self.speed
        elif self.orientation == 'right':
            dx = self.speed
        return self.x + dx, self.y + dy

    def set_speed(self, speed):
        self.speed = speed

    def intervene_rotation(self, new_orientation):
        if new_orientation not in ['up', 'down', 'left', 'right']:
            raise ValueError("Invalid orientation")
        self.orientation = new_orientation
        self.sprite = self.load_sprite(self.symbol, self.orientation, self.sprite_path, self.sprite_size)


class Pedestrian(GridEntity):
    def __init__(self, x, y, symbol='P', orientation=None, speed=1):
        super().__init__(x, y, symbol, orientation, speed)

    def predict_random_walk(self):
        dx, dy = np.random.choice([-1, 0, 1]), np.random.choice([-1, 0, 1])
        return self.x + dx, self.y + dy

    def move_to(self, x, y):
        self.x = x
        self.y = y

class Obstacle(GridEntity):
    def __init__(self, x, y, symbol='O', orientation='down', speed=0):
        super().__init__(x, y, symbol, orientation, speed)

    def move_to(self, x, y):
        self.x = x
        self.y = y
    

class TrafficLight(GridEntity):
    def __init__(self, x, y, state='red', symbol='L', orientation=None, speed=0, frequency=None):
        super().__init__(x, y, symbol, orientation, speed)
        self.state = state
        self.orientation = orientation
        self.frequency = frequency
        self.traffic_light_sprite_cache = {}
        self._populate_tl_sprite_cache()     

        self.update_sprite()

    def update_sprite(self):
        self.sprite = self.traffic_light_sprite_cache[(self.orientation, self.state)]

    def update(self, step):
        if not self.frequency:
            return
        if step % sum(self.frequency) < self.frequency[0]:
            self.state = 'red'
        else:
            self.state = 'green'
        self.update_sprite()

    def change_orientation(self, new_orientation):
        if new_orientation not in ['up', 'down', 'left', 'right']:
            raise ValueError("Invalid orientation")
        self.orientation = new_orientation
        self.update_sprite()
    
    def _populate_tl_sprite_cache(self):
        # Iterate over each orientation and state to create and cache sprites
        for orient in ['up', 'down', 'left', 'right']:
            # Load the base sprite for the current orientation
            base_sprite = self.load_sprite(self.symbol, orient, self.sprite_path, self.sprite_size)

            # Create and cache the red and green sprites for this orientation
            red_sprite = Image.blend(base_sprite, Image.new('RGBA', base_sprite.size, (255, 0, 0, 128)), 0.5)
            green_sprite = Image.blend(base_sprite, Image.new('RGBA', base_sprite.size, (0, 255, 0, 128)), 0.5)

            # Cache the sprites with keys as (orientation, state)
            self.traffic_light_sprite_cache[(orient, 'red')] = red_sprite
            self.traffic_light_sprite_cache[(orient, 'green')] = green_sprite

    def intervene_state(self):
        self.state = 'green' if self.state == 'red' else 'red'
        self.update_sprite()
    
    def __repr__(self):
        return f"Traffic light at ({self.x}, {self.y}) facing {self.orientation} with state {self.state}."


class Gridworld:
    @classmethod
    def get_possible_intervention(cls, entity):
        if isinstance(entity, Vehicle):
            return ['turn']
        elif isinstance(entity, TrafficLight):
            return ['change_orientation', 'change_state']
        elif isinstance(entity, Pedestrian):
            return ['move_to']
        elif isinstance(entity, Obstacle):
            return ['move_to']

    @property
    def entities(self):
        return [entity for entities in self.entity_map.values() for entity in entities]

    def __init__(self, width, height, sprite_size=32):
        self.width = width
        self.height = height
        self.grid = np.full((height, width), ' ', dtype='<U1')
        self.entity_map = defaultdict(list)  # For efficient spatial queries
        self.step_count = 0
        self.sprite_size = sprite_size

    def add_entity(self, entity):
        self.entity_map[(entity.x, entity.y)].append(entity)

    # Optimized movement and collision detection
    def move_entity(self, entity, new_x, new_y):
        if self.is_cell_free(new_x, new_y) and 0 <= new_x < self.width and 0 <= new_y < self.height:
            self.entity_map[(entity.x, entity.y)].remove(entity)
            entity.set_position(new_x, new_y)
            self.entity_map[(new_x, new_y)].append(entity)

    def update_grid(self):
        self.grid.fill(' ')
        for entity in self.entities:
            if 0 <= entity.x < self.width and 0 <= entity.y < self.height:
                self.grid[entity.y][entity.x] = entity.symbol

    def display(self):
        for row in self.grid:
            print(' '.join(row))

    def step(self):
        self.step_count += 1

        # Update traffic lights based on step count
        # for position, entities in list(self.entity_map.items()):
        #     for entity in entities:
        #         if isinstance(entity, TrafficLight):
        #             entity.update(self.step_count)

        # Enforce traffic rules
        self.enforce_traffic_rules()
        # self.randomly_change_car_orientation()
        # Temporary structure to store entity movements
        movements = []

        # Update each entity
        for position, entities in list(self.entity_map.items()):
            for entity in entities:
                if isinstance(entity, Vehicle):
                    next_x, next_y = entity.predict_next_position()
                    movements.append((entity, next_x, next_y))
                elif isinstance(entity, Pedestrian):
                    next_x, next_y = entity.predict_random_walk()
                    movements.append((entity, next_x, next_y))
                # elif isinstance(entity, TrafficLight):
                #     entity.update(self.step_count)

        # Apply movements
        for entity, next_x, next_y in movements:
            self.move_entity(entity, next_x, next_y)

        # Handle collisions
        self.handle_collisions()

    def intervene(self, entity, intervention, **intervention_args):
        if isinstance(entity, Vehicle):
            if intervention == 'turn':
                entity.intervene_rotation(intervention_args['new_orientation'])
        elif isinstance(entity, TrafficLight):
            if intervention == 'change_orientation':
                entity.change_orientation(intervention_args['new_orientation'])
            elif intervention == 'change_state':
                entity.intervene_state()
        elif isinstance(entity, Pedestrian):
            if intervention == 'move_to':
                self.move_entity(entity, intervention_args['x'], intervention_args['y'])
        elif isinstance(entity, Obstacle):
            if intervention == 'move_to':
                self.move_entity(entity, intervention_args['x'], intervention_args['y'])
    
    def random_intervention(self):
        entity = random.choice(self.entities)
        possible_interventions = self.get_possible_intervention(entity)
        intervention = random.choice(possible_interventions)
        if intervention == 'turn':
            new_orientation = random.choice(['up', 'down', 'left', 'right'])
            self.intervene(entity, intervention, new_orientation=new_orientation)
        elif intervention == 'change_orientation':
            new_orientation = random.choice(['up', 'down', 'left', 'right'])
            self.intervene(entity, intervention, new_orientation=new_orientation)
        elif intervention == 'change_state':
            self.intervene(entity, intervention)
        elif intervention == 'move_to':
            # Pick a random cell around the entity that is free
            x, y = entity.x, entity.y
            possible_moves = [(x-1, y), (x+1, y), (x, y-1), (x, y+1)]
            possible_moves = [(x, y) for x, y in possible_moves if self.is_cell_free(x, y)]
            if possible_moves:
                new_x, new_y = random.choice(possible_moves)
                self.intervene(entity, intervention, x=new_x, y=new_y)
    
    def semi_random_intervention(self):
        """
        Make reasonable random interventions
        i.e., don't turn a car into a wall
        if there's an obstacle in front of a car, either turn it or move the obstacle
        """
        causals = self.get_causals()
        binary_interventions = [0 for _ in range(len(causals))]
        entity = random.choice(self.entities)

        if isinstance(entity, Vehicle):
            next_pos = entity.predict_next_position()
            if not self.is_cell_free(*next_pos):
                # Either turn the car or move the obstacle if there is one. Don't move light or pedestrian

                # Check if there is an obstacle in front of the car
                obstacle = self.get_entity_at_position(*next_pos)
                if obstacle and isinstance(obstacle, Obstacle):
                    # Move the obstacle
                    possible_moves = self.get_free_cells_around_entity(obstacle)
                    if possible_moves:
                        new_x, new_y = random.choice(possible_moves)
                        self.intervene(obstacle, 'move_to', x=new_x, y=new_y)
                    else:
                        # Turn the car
                        new_orientation = random.choice(['up', 'down', 'left', 'right'])
                        self.intervene(entity, 'turn', new_orientation=new_orientation)
                else:
                    # Turn the car
                    new_orientation = random.choice(['up', 'down', 'left', 'right'])
                    self.intervene(entity, 'turn', new_orientation=new_orientation)
            else:
                # Turn the car
                new_orientation = random.choice(['up', 'down', 'left', 'right'])
                self.intervene(entity, 'turn', new_orientation=new_orientation)
        elif isinstance(entity, TrafficLight):
            # Change the state of the traffic light
            self.intervene(entity, 'change_state')
            print(f"Intervened on traffic light at ({entity.x}, {entity.y})")
        elif isinstance(entity, Pedestrian):
            # Move the pedestrian to a random free cell
            possible_moves = self.get_free_cells_around_entity(entity)
            if possible_moves:
                new_x, new_y = random.choice(possible_moves)
                self.intervene(entity, 'move_to', x=new_x, y=new_y)
        elif isinstance(entity, Obstacle):
            # Move the obstacle to a random free cell
            possible_moves = self.get_free_cells_around_entity(entity)
            if possible_moves:
                new_x, new_y = random.choice(possible_moves)
                self.intervene(entity, 'move_to', x=new_x, y=new_y)

    
    def randomly_change_car_orientation(self):
        for position, entities in list(self.entity_map.items()):
            for entity in entities:
                if isinstance(entity, Vehicle):
                    if random.random() < 0.1:
                        if random.random() < 0.5:
                            entity.turn_left()
                        else:
                            entity.turn_right()                        


    def move_vehicle(self, vehicle):
        next_x, next_y = vehicle.predict_next_position()
        self.move_entity(vehicle, next_x, next_y)

    def move_pedestrian(self, pedestrian):
        next_x, next_y = pedestrian.predict_random_walk()
        self.move_entity(pedestrian, next_x, next_y)

    def handle_collisions(self):
        for pos, entities in self.entity_map.items():
            if len(entities) > 1:
                # Handle collisions between entities at the same position
                for entity1 in entities:
                    for entity2 in entities:
                        if entity1 != entity2:
                            self.resolve_collision(entity1, entity2)

    def resolve_collision(self, entity1, entity2):
        if isinstance(entity1, Vehicle) and isinstance(entity2, Vehicle):
            entity1.set_speed(0)
            entity2.set_speed(0)
            print(f"Collision between {entity1.symbol} and {entity2.symbol} at ({entity1.x}, {entity1.y})")

    def enforce_traffic_rules(self):
        for position, entities in list(self.entity_map.items()):
            for entity in entities:
                if isinstance(entity, Vehicle):
                    self.check_traffic_light(entity)

    def check_traffic_light(self, vehicle):
        vehicle.set_speed(1)
        for position, entities in list(self.entity_map.items()):
            for entity in entities:
                if isinstance(entity, TrafficLight):
                    if self.is_light_ahead(vehicle, entity):
                        if entity.state == 'red':
                            vehicle.set_speed(0)
                        else:
                            vehicle.set_speed(1)
                    # else:
                    #     vehicle.set_speed(1)


    def is_light_ahead(self, vehicle, light):
        dx, dy = vehicle.predict_next_position()
        if vehicle.orientation == 'up' and light.orientation == 'down':
            return light.y < vehicle.y and light.x == vehicle.x
        elif vehicle.orientation == 'down' and light.orientation == 'up':
            return light.y > vehicle.y and light.x == vehicle.x
        elif vehicle.orientation == 'left' and light.orientation == 'right':
            return light.x < vehicle.x and light.y == vehicle.y
        elif vehicle.orientation == 'right' and light.orientation == 'left':
            return light.x > vehicle.x and light.y == vehicle.y
        return False

    def is_cell_free(self, x, y):
        return not self.entity_map[(x, y)]

    def get_free_cells_around_entity(self, entity):
        x, y = entity.x, entity.y
        possible_moves = [(x-1, y), (x+1, y), (x, y-1), (x, y+1)]
        possible_moves = [(x, y) for x, y in possible_moves if self.is_cell_free(x, y)]
        return possible_moves

    def get_entity_at_position(self, x, y):
        if self.entity_map[(x, y)]:
            return self.entity_map[(x, y)][0]
        return None

    def render(self):
        sprite_size = self.sprite_size
        env_img = Image.new('RGBA', (self.width * sprite_size, self.height * sprite_size), (255, 255, 255, 0))
        for entities in self.entity_map.values():
            for entity in entities:
                img = entity.sprite
                x, y = entity.x * sprite_size, entity.y * sprite_size
                env_img.paste(img, (x, y), img)
        return env_img

    def randomly_initialize(self, num_cars=5, num_lights=5, num_boulders=5, x_percent=80, y_percent=20, z_percent=30, fixed_light_positions=[]):
        if fixed_light_positions:
            self.are_light_positions_fixed = True
        grid_size = self.width  # Assuming width and height are the same
        min_dist_from_edge = lambda percent: int(percent / 100 * grid_size)

        for (light_x, light_y, light_orientation) in fixed_light_positions:
            # Place light
            light = TrafficLight(light_x, light_y, 'red', 'L', light_orientation, frequency=(100, 1))
            self.add_entity(light)

            # Calculate car position based on light position and orientation
            car_orientation = get_opposite_orientation(light_orientation)
            car_x, car_y = calculate_light_position(light_x, light_y, car_orientation, min_dist=x_percent, max_dist=x_percent, grid_size=grid_size)
            vehicle = Vehicle(car_x, car_y, 'C', size=1, orientation=car_orientation, speed=1)
            self.add_entity(vehicle)

        for _ in range(num_cars - len(fixed_light_positions)):
            # 1. Semi-randomly place cars
            orientation = random.choice(['up', 'down', 'left', 'right'])
            min_dist = min_dist_from_edge(x_percent)
            if orientation == 'up':
                y = random.randint(min_dist, grid_size - 1)
                x = random.randint(0, grid_size - 1)
            elif orientation == 'down':
                y = random.randint(0, grid_size - min_dist - 1)
                x = random.randint(0, grid_size - 1)
            elif orientation == 'left':
                x = random.randint(min_dist, grid_size - 1)
                y = random.randint(0, grid_size - 1)
            else:  # right
                x = random.randint(0, grid_size - min_dist - 1)
                y = random.randint(0, grid_size - 1)

            vehicle = Vehicle(x, y, 'C', size=1, orientation=orientation, speed=1)
            self.add_entity(vehicle)

            # 2. Place traffic light facing the car
            light_x, light_y = self.calculate_light_position(x, y, orientation, min_dist=y_percent, grid_size=grid_size)
            light = TrafficLight(light_x, light_y, 'red', 'L', self.get_opposite_orientation(orientation), frequency=(100, 1))
            self.add_entity(light)

        # Randomly place additional traffic lights if needed
        for _ in range(num_lights - len(fixed_light_positions)):
            x, y = np.random.randint(0, grid_size), np.random.randint(0, grid_size)
            orientation = random.choice(['up', 'down', 'left', 'right'])
            light = TrafficLight(x, y, 'red', 'L', orientation, frequency=(100, 1))
            self.add_entity(light)

        # 3. Randomly place boulders
        for _ in range(num_boulders):
            x, y = np.random.randint(0, grid_size), np.random.randint(0, grid_size)
            # With z% chance, place the boulder between a car and its facing traffic light
            if random.randint(0, 100) < z_percent:
                # Select a random car
                car = random.choice([entity for entity in self.entities if isinstance(entity, Vehicle)])
                light_x, light_y = self.calculate_light_position(car.x, car.y, car.orientation, min_dist=y_percent, grid_size=grid_size)
                boulder_x, boulder_y = (car.x + light_x) // 2, (car.y + light_y) // 2  # Place boulder halfway between car and light
            else:
                boulder_x, boulder_y = x, y

            boulder = Obstacle(boulder_x, boulder_y, 'O')
            self.add_entity(boulder)

    def get_causals(self, are_light_positions_fixed=True):
        """

        Returns a dictionary of causal variables and their values
        The causal variables are:
        1. Car positions
        2. Car orientations
        3. Traffic light positions (if not fixed)
        4. Traffic light states
        5. Traffic light orientations
        6. Pedestrian positions
        7. Obstacle positions
        """
        causal_dict = {}
        for entity in self.entities:
            if isinstance(entity, Vehicle):
                causal_dict[f'car_{entity.symbol}_position'] = (entity.x, entity.y)
                causal_dict[f'car_{entity.symbol}_orientation'] = entity.orientation
            elif isinstance(entity, TrafficLight):
                if are_light_positions_fixed:
                    causal_dict[f'traffic_light_{entity.symbol}_position'] = (entity.x, entity.y)
                causal_dict[f'traffic_light_{entity.symbol}_state'] = entity.state
                causal_dict[f'traffic_light_{entity.symbol}_orientation'] = entity.orientation
            elif isinstance(entity, Pedestrian):
                causal_dict[f'pedestrian_{entity.symbol}_position'] = (entity.x, entity.y)
            elif isinstance(entity, Obstacle):
                causal_dict[f'obstacle_{entity.symbol}_position'] = (entity.x, entity.y)
        return causal_dict

        
    
    @staticmethod
    def get_opposite_orientation(orientation):
        return {'up': 'down', 'down': 'up', 'left': 'right', 'right': 'left'}[orientation]

    @staticmethod
    def calculate_light_position(x, y, orientation, min_dist=10, max_dist=20, grid_size=50):
        # Generate a random distance within the specified range
        distance = random.randint(min_dist, max_dist)

        # Calculate the offset based on orientation
        offset_x, offset_y = {'up': (0, -distance), 'down': (0, distance), 
                            'left': (-distance, 0), 'right': (distance, 0)}[orientation]

        # Calculate the new position and ensure it is within grid boundaries
        light_x = min(max(x + offset_x, 0), grid_size - 1)
        light_y = min(max(y + offset_y, 0), grid_size - 1)

        return light_x, light_y


# Create a random assortment of vehicles and pedestrians
import random
from itertools import product
from copy import deepcopy

grid_x, grid_y = 20, 20
gridworld = Gridworld(grid_x, grid_y, sprite_size=SPRITE_SIZE)

# Preload sprites
symbols = ['C', 'P', 'L', 'O']
orientations = ['up', 'down', 'left', 'right']
GridEntity.preload_sprites(symbols, orientations, sprite_path='sprites/', sprite_size=SPRITE_SIZE)

# Create vehicles, pedestrians, and traffic lights
vehicles = []
pedestrians = []
traffic_lights = []


for i in range(5):
    # Sample two random pairs of coordinates for the vehicle and obstacle without replacement
    xs, ys = np.random.choice(grid_x, size=(2, 2), replace=False)

    orientation = random.choice(['up', 'down', 'left', 'right'])

    # Create a vehicle
    vehicle = Vehicle(xs[0], ys[0], 'C', size=1, orientation=orientation, speed=1)
    vehicles.append(vehicle)
    gridworld.add_entity(vehicle)

    # Create random obstacle
    obstacle = Obstacle(xs[1], ys[1], 'O')
    obstacle2 = Obstacle(xs[1] + 1, ys[1], 'O')
    gridworld.add_entity(obstacle)
    
    # Create a traffic light
    light_x, light_y = gridworld.calculate_light_position(xs[0], ys[0], orientation, grid_size=grid_x)
    light = TrafficLight(light_x, light_y, 'red', 'L', gridworld.get_opposite_orientation(orientation), frequency=(100, 1))
    traffic_lights.append(light)
    gridworld.add_entity(light)

# Run the simulation
frames = []
for _ in range(30):  # Number of steps in the animation
    gridworld.semi_random_intervention()
    gridworld.step()

    # Render a fresh frame for each step
    frame = gridworld.render()
    frames.append(frame.copy())  # Store a copy of the frame
    # Show the frame
    # frame.show()
# Create an animated GIF from the frames
frames[0].save('gridworld_random.gif', save_all=True, append_images=frames[1:], duration=300, loop=0, disposal=2)

Intervened on traffic light at (4, 1)
Intervened on traffic light at (19, 1)
Intervened on traffic light at (3, 19)
Intervened on traffic light at (3, 19)
Intervened on traffic light at (3, 19)
Intervened on traffic light at (19, 5)
Intervened on traffic light at (19, 5)
Intervened on traffic light at (4, 1)
Intervened on traffic light at (15, 2)
