In [7]:
!pip install tf-agents[reverb]

In [11]:
from __future__ import absolute_import, division, print_function

import keras
import numpy as np
import tensorflow as tf
import reverb
import random

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers.py_driver import PyDriver
from tf_agents.environments import py_environment
from tf_agents.environments import tf_py_environment
from tf_agents.networks import sequential
from tf_agents.trajectories import time_step as ts
from tf_agents.policies.py_tf_eager_policy import PyTFEagerPolicy
from tf_agents.specs import tensor_spec
from tf_agents.specs.array_spec import BoundedArraySpec
from tf_agents.utils import common

In [12]:
# --- Agent / training settings (consumed by later cells) ---
# Presumably the hidden-layer widths of the Q-network — confirm against the model cell.
fully_connected_layers: tuple[int, int] = (100, 50)

num_iterations: int = 20000
initial_collect_steps: int = 100
collect_steps_per_iteration: int = 1
batch_size: int = 64
learning_rate: float = 0.001
log_interval: int = 200

# Replay buffer capacity.
max_buffer_size: int = 10000

# --- Evaluation settings ---
eval_interval: int = 1000
num_eval_episodes: int = 10

# --- Game environment settings (read by FastTrafficEnv) ---
lane_length: int = 5  # number of columns in each of the two lanes
spawn_rate: float = 0.4  # per-tick chance that a new car appears
total_ticks: int = 50  # tick limit for a single episode

In [13]:
class FastTrafficEnv(py_environment.PyEnvironment):
    """Two-lane traffic-dodging game exposed as a TF-Agents ``PyEnvironment``.

    The board is a ``(2, lane_length)`` int32 grid whose row index is the lane.
    The player always sits in column 0; cars spawn in the last column and move
    one column toward the player every tick.  Each step the agent either
    switches lanes (action ``0``) or stays in its lane (action ``1``).

    An episode ends with ``ts.termination`` when, after the action is applied,
    a car occupies the player's cell, or with ``ts.truncation`` once
    ``total_ticks`` ticks have elapsed.  The step after either of these starts
    a new episode.

    Observation encoding: ``0`` empty cell, ``1`` car, ``2`` player.

    Reads the module-level ``lane_length``, ``spawn_rate`` and ``total_ticks``
    settings.
    """

    def __init__(self, seed=None):
        """Create the environment and start the first game.

        Args:
            seed: Optional seed for a private ``random.Random`` instance, for
                reproducible episodes.  When ``None`` (the default) the global
                ``random`` module is used, as before.
        """
        super().__init__()
        # Must be set before ``_init_game`` because it draws the start lane.
        self._random = random.Random(seed) if seed is not None else random
        self._action_spec = BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=1, name="action"
        )
        self._observation_spec = BoundedArraySpec(
            shape=(2, lane_length),
            dtype=np.int32,
            minimum=0,
            maximum=2,
            name="observation",
        )
        self._init_game()

    def action_spec(self):
        """Scalar int32 action in ``{0, 1}``: ``0`` switches lanes, ``1`` stays."""
        return self._action_spec

    def observation_spec(self):
        """``(2, lane_length)`` int32 grid with values in ``[0, 2]``."""
        return self._observation_spec

    def _reset(self):
        """Start a new game and return its ``FIRST`` time step."""
        self._init_game()
        return ts.restart(self.parse_observation())

    def _step(self, action):
        """Advance the game by one tick (see ``next_frame``)."""
        return self.next_frame(action)

    # --- Inner game mechanics ---

    def _init_game(self):
        """Clear the board, place the player in a random lane, reset counters."""
        self.cars = np.zeros((2, lane_length), dtype=np.int32)
        self.player = np.int32(self._random.choice([0, 1]))
        self.game_ended = False
        self.ticks = 0

    def parse_observation(self) -> np.ndarray:
        """Return a copy of the board with the player's cell set to ``2``.

        NOTE(review): the player marker overwrites a car that shares column 0
        of the player's lane, so the agent cannot see the car that will hit it
        on the next step if it stays.  Encoding the overlap distinctly would
        require widening the observation spec's ``maximum`` — confirm intent.
        """
        observation = self.cars.copy()
        observation[self.player, 0] = 2
        return observation

    def try_spawn_car(self, chance=None):
        """Possibly place a car in the last column of a randomly chosen lane.

        Args:
            chance: Optional pre-drawn value in ``[0, 1)``; drawn from the
                environment's RNG when ``None``.  A car spawns only when
                ``chance <= spawn_rate``.
        """
        if chance is None:
            chance = self._random.random()

        if chance > spawn_rate:
            return None

        lane = self._random.choice([0, 1])
        self.cars[lane, -1] = 1

    def move_cars(self):
        """Drop cars in column 0 and shift every remaining car one column left."""
        self.cars[:, 0] = 0
        # Roll along the lane axis only.  Without ``axis`` NumPy rolls the
        # flattened grid, moving values between lanes (harmless before only
        # because column 0 had just been cleared).
        self.cars = np.roll(self.cars, shift=-1, axis=1)

    def check_collision(self) -> bool:
        """Return ``True`` when a car occupies the player's cell (column 0)."""
        return bool(self.cars[self.player, 0])

    def apply_action(self, action):
        """Switch lanes for action ``0``; action ``1`` keeps the current lane."""
        if action == 1:
            return None

        self.player = 1 - int(self.player)

    def next_frame(self, action):
        """Apply ``action`` and advance the game by one tick.

        Returns:
            * ``ts.restart`` if the previous step ended the episode (the
              action is ignored and a new game starts);
            * ``ts.termination`` if the player's cell holds a car after the
              action;
            * ``ts.truncation`` when this tick reaches ``total_ticks``;
            * otherwise ``ts.transition``.
            The reward is the current tick count in every non-restart case.
        """
        if self.ticks >= total_ticks or self.game_ended:
            # Previous step was LAST: ignore the action and start afresh.
            return self.reset()

        self.apply_action(action=action)
        if self.check_collision():
            self.game_ended = True
            # NOTE(review): a crash is rewarded with the tick count rather than
            # penalised — confirm this reward design is intended.
            return ts.termination(
                observation=self.parse_observation(), reward=self.ticks
            )

        self.move_cars()
        self.try_spawn_car()
        self.ticks += 1

        observation = self.parse_observation()
        if self.ticks >= total_ticks:
            # Time limit reached: emit a LAST step so collected trajectories
            # are closed instead of jumping from MID straight to FIRST.
            # Truncation (discount 1) is used because the state itself is not
            # terminal; the episode is simply cut off.
            self.game_ended = True
            return ts.truncation(observation=observation, reward=self.ticks)

        return ts.transition(observation=observation, reward=self.ticks)
