# Training a DQN with social attention on `intersection-v0`



In [1]:
%pip install rl-agents@git+https://github.com/manavdahra/rl-agents
%pip install highway-env@git+https://github.com/manavdahra/highway-env

mdahras-MacBook-Pro.local
Collecting rl-agents@ git+https://github.com/manavdahra/rl-agents
  Cloning https://github.com/manavdahra/rl-agents to /private/var/folders/40/b3pz_mbj6bl7vh33p2tyg6j00000gn/T/pip-install-ajeqa262/rl-agents_42130d062df1463bb205d01f9974c949
  Running command git clone --filter=blob:none --quiet https://github.com/manavdahra/rl-agents /private/var/folders/40/b3pz_mbj6bl7vh33p2tyg6j00000gn/T/pip-install-ajeqa262/rl-agents_42130d062df1463bb205d01f9974c949
  Resolved https://github.com/manavdahra/rl-agents to commit b65a875cd76c2b58a6124ed95235d041896667fe
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting highway_env@ git+https://github.com/manavdahra/highway-env (from rl-agents@ git+https://github.com/manavdahra/rl-agents)
  Cloning https://github.com/manavdahra/highway-env to /private/var/folders/40/b3pz_mbj6bl7vh33p2tyg6j00000gn/T/p

In [2]:
import base64
import html
from pathlib import Path

import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from IPython import display as ipythondisplay

def record_videos(env, video_folder="videos"):
    """Wrap ``env`` in a ``RecordVideo`` wrapper that records every episode.

    The wrapper is also registered on the unwrapped environment through
    ``set_record_video_wrapper`` so that intermediate frames are captured.

    Args:
        env: the environment to wrap.
        video_folder: directory the recorded videos are written to.

    Returns:
        The ``RecordVideo``-wrapped environment.
    """
    def record_every_episode(episode_id):
        # Trigger recording unconditionally, whatever the episode index.
        return True

    recorder = RecordVideo(env, video_folder=video_folder, episode_trigger=record_every_episode)
    env.unwrapped.set_record_video_wrapper(recorder)  # capture intermediate frames
    return recorder


def show_videos(path="videos"):
    """Display every ``*.mp4`` file found directly inside ``path`` as inline HTML5 video.

    Each video is base64-encoded into a ``data:`` URI, so the notebook output is
    self-contained. Files are shown in sorted path order because ``Path.glob``
    yields them in an arbitrary, filesystem-dependent order.

    Args:
        path: directory searched (non-recursively) for ``.mp4`` files.
    """
    video_paths = sorted(Path(path).glob("*.mp4"))
    if not video_paths:
        # Make an empty result visible instead of silently rendering a blank output.
        ipythondisplay.display(
            ipythondisplay.HTML(data="<p>No videos found in {}</p>".format(html.escape(str(path))))
        )
        return

    snippets = []
    for mp4 in video_paths:
        video_b64 = base64.b64encode(mp4.read_bytes()).decode("ascii")
        # Escape the path: it is interpolated into an HTML attribute.
        # `title` is used because `alt` is not a valid <video> attribute.
        snippets.append(
            """<video title="{}" autoplay
                      loop controls style="height: 400px;">
                      <source src="data:video/mp4;base64,{}" type="video/mp4" />
                 </video>""".format(
                html.escape(str(mp4)), video_b64
            )
        )
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(snippets)))

## Training

We use a policy architecture based on social attention, see [[Leurent and Mercat, 2019]](https://arxiv.org/abs/1911.12250).


In [3]:
from rl_agents.trainer.evaluation import Evaluation
from rl_agents.agents.common.factory import load_agent, load_environment

# Environment and agent configuration files from the rl-agents repository.
# NOTE(review): these paths are relative to the current working directory, so the
# notebook presumably must be run from a directory that contains rl-agents'
# `config/` folder (e.g. a checkout's `scripts/` directory). TODO confirm.
env_config = 'config/env.json'
agent_config = 'config/agents/DQNAgent/ego_attention_4h.json'

env = load_environment(env_config)
agent = load_agent(agent_config, env)

# 100 training episodes with environment/agent display turned off. Checkpoints are
# written below `directory` (see the "Saved DQNAgent model" log lines);
# recover=True presumably resumes from a previously saved checkpoint there.
evaluation = Evaluation(
    env,
    agent,
    directory="../../output/intersection-v0/ego-attention",
    num_episodes=100,
    recover=True,
    display_env=False,
    display_agent=False,
)
print("Ready to train {} on {}".format(agent, env))

  logger.deprecation(
Preferred device cuda:best unavailable, switching to default cpu
  checkpoint = torch.load(filename, map_location=self.device)


Ready to train <rl_agents.agents.deep_q_network.pytorch.DQNAgent object at 0x1275a45f0> on <OrderEnforcing<PassiveEnvChecker<IntersectionEnv<intersection-v0>>>>


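Run TensorBoard locally to visualize training, pointing it at the training output directory, e.g. `tensorboard --logdir ../../output/intersection-v0/ego-attention`.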

Start training. This should take about an hour.

In [4]:
# Long-running (about an hour). Checkpoints are written under the Evaluation's output
# directory as training progresses (see the "Saved DQNAgent model" log lines).
evaluation.train()

  return self.value_net(torch.tensor(states, dtype=torch.float).to(self.device)).data.cpu().numpy()
2025-01-18 10:38:22.057 Python[57675:4112351] +[IMKClient subclass]: chose IMKClient_Modern
[INFO] Episode 0 score: -3.7 
[INFO] Saved DQNAgent model to ../../output/intersection-v0/ego-attention/run_20250118-103757_57675/checkpoint-0.tar 
[INFO] Episode 1 score: -2.0 
[INFO] Saved DQNAgent model to ../../output/intersection-v0/ego-attention/run_20250118-103757_57675/checkpoint-1.tar 
[INFO] Episode 2 score: -1.4 
[INFO] Episode 3 score: -1.4 
[INFO] Episode 4 score: 0.0 
[INFO] Episode 5 score: 6.1 
[INFO] Episode 6 score: 2.5 
[INFO] Episode 7 score: -3.3 
[INFO] Episode 8 score: -2.4 
[INFO] Saved DQNAgent model to ../../output/intersection-v0/ego-attention/run_20250118-103757_57675/checkpoint-8.tar 
[INFO] Episode 9 score: -1.4 
[INFO] Episode 10 score: -1.4 
[INFO] Episode 11 score: 6.0 
[INFO] Episode 12 score: 3.8 
[INFO] Episode 13 score: 0.0 
[INFO] Episode 14 score: -2.4 
[INFO

Progress can be visualised in the locally running TensorBoard instance (see above), which refreshes every 30s (or on demand). You may need to click the *Fit domain to data* buttons below each graph.

## Testing

In [10]:
#@title Run the learned policy for a few episodes.
# Build a fresh environment/agent pair for evaluation.
env = load_environment(env_config)
# NOTE(review): presumably lets frames be rendered for the recorded videos
# without a display window. Confirm against the highway-env docs.
env.unwrapped.config["offscreen_rendering"] = True
agent = load_agent(agent_config, env)
# Pass the same `directory` as the training cell. Without it, `recover=True`
# looks for a checkpoint in the default location instead of where training
# saved its models ("../../output/intersection-v0/ego-attention/...").
evaluation = Evaluation(
    env,
    agent,
    num_episodes=20,
    training=False,
    recover=True,
    directory="../../output/intersection-v0/ego-attention",
)
evaluation.test()
show_videos(evaluation.run_directory)

  logger.deprecation(
  checkpoint = torch.load(filename, map_location=self.device)
[INFO] Episode 0 score: 8.0 
[INFO] Episode 1 score: -2.0 
[INFO] Episode 2 score: 0.0 
[INFO] Episode 3 score: 0.0 
[INFO] Episode 4 score: 0.0 
[INFO] Episode 5 score: 9.0 
[INFO] Episode 6 score: 9.0 
[INFO] Episode 7 score: 8.0 
[INFO] Episode 8 score: 9.0 
[INFO] Episode 9 score: -1.3 
[INFO] Episode 10 score: 9.0 
[INFO] Episode 11 score: 10.0 
[INFO] Episode 12 score: -0.6 
[INFO] Episode 13 score: 9.0 
[INFO] Episode 14 score: 9.0 
[INFO] Episode 15 score: 9.0 
[INFO] Episode 16 score: 1.0 
[INFO] Episode 17 score: 9.0 
[INFO] Episode 18 score: 8.0 
[INFO] Episode 19 score: 9.0 
