## Interpreting DQN network with social attention

In [1]:
%pip install rl-agents@git+https://github.com/manavdahra/rl-agents
%pip install highway-env@git+https://github.com/manavdahra/highway-env
%pip install plotly nbformat

mdahras-MacBook-Pro.local
Collecting rl-agents@ git+https://github.com/manavdahra/rl-agents
  Cloning https://github.com/manavdahra/rl-agents to /private/var/folders/40/b3pz_mbj6bl7vh33p2tyg6j00000gn/T/pip-install-5xxus69n/rl-agents_4e02728afd874d7c9da144ae50a58c69
  Running command git clone --filter=blob:none --quiet https://github.com/manavdahra/rl-agents /private/var/folders/40/b3pz_mbj6bl7vh33p2tyg6j00000gn/T/pip-install-5xxus69n/rl-agents_4e02728afd874d7c9da144ae50a58c69
  Resolved https://github.com/manavdahra/rl-agents to commit b65a875cd76c2b58a6124ed95235d041896667fe
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting highway_env@ git+https://github.com/manavdahra/highway-env (from rl-agents@ git+https://github.com/manavdahra/rl-agents)
  Cloning https://github.com/manavdahra/highway-env to /private/var/folders/40/b3pz_mbj6bl7vh33p2tyg6j00000gn/T/p

## Library imports

In [2]:
import torch
import numpy as np
import gymnasium as gym
import highway_env
import plotly.express as px

from pathlib import Path
from rl_agents.agents.deep_q_network.pytorch import DQNAgent

# Reference highway_env through gymnasium so its environments can be created by id.
# NOTE(review): depending on the gymnasium version, the registration itself may
# already happen when highway_env is imported - confirm.
gym.register_envs(highway_env)

# One seed shared by every random number generator seeded in this notebook.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)

## Load trained agent

In [3]:
from rl_agents.trainer.evaluation import Evaluation
from rl_agents.agents.common.factory import load_agent, load_environment

# Paths of the environment and agent definitions (relative to the notebook).
env_config = 'config/env.json'
agent_config = 'config/agents/DQNAgent/ego_attention_4h.json'

# The environment is built first because the agent is constructed against it.
env = load_environment(env_config)
agent = load_agent(agent_config, env)

# Keyword options for the evaluation harness, gathered in one place.
evaluation_options = {
    "num_episodes": 100,
    "display_env": False,
    "display_agent": False,
    "recover": True,
    "directory": "../../output/intersection-v0/ego-attention",
}
evaluation = Evaluation(env, agent, **evaluation_options)
print(f"Ready to evaluate and interpret {agent} on {env}")

  logger.deprecation(
Preferred device cuda:best unavailable, switching to default cpu
  checkpoint = torch.load(filename, map_location=self.device)
[INFO] Loaded DQNAgent model from ../../output/intersection-v0/ego-attention/saved_models/latest.tar 


Ready to evaluate and interpret <rl_agents.agents.deep_q_network.pytorch.DQNAgent object at 0x103240f20> on <OrderEnforcing<PassiveEnvChecker<IntersectionEnv<intersection-v0>>>>


## Evaluate agent

In [4]:
import base64
from IPython import display as ipythondisplay

def show_videos(path="videos"):
    """Embed every ``*.mp4`` file found directly inside ``path`` in the cell output.

    Each file is inlined as a base64 ``data:`` URI inside a ``<video>`` tag and
    all tags are displayed together, separated by ``<br>``.

    Args:
        path: Directory that is searched (non-recursively) for ``*.mp4`` files.
    """
    from html import escape  # local import: only this function needs it

    video_tags = []
    # Path.glob yields entries in arbitrary order; sorting keeps the display
    # order stable between runs.
    for mp4 in sorted(Path(path).glob("*.mp4")):
        video_b64 = base64.b64encode(mp4.read_bytes())
        video_tags.append(
            """<video alt="{}" autoplay
                      loop controls style="height: 400px;">
                      <source src="data:video/mp4;base64,{}" type="video/mp4" />
                 </video>""".format(
                # Escape the path so quotes or angle brackets in a file name
                # cannot break out of the alt attribute.
                escape(str(mp4), quote=True), video_b64.decode("ascii")
            )
        )
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(video_tags)))

In [5]:
# Rebuild the evaluation harness for a short run with training disabled.
evaluation = Evaluation(
    env,
    agent,
    num_episodes=20,
    training=False,
    recover=True,
)
evaluation.test()
# Embed the *.mp4 files found in this run's directory.
show_videos(evaluation.run_directory)

2025-01-18 21:27:14.503 Python[73821:5037005] +[IMKClient subclass]: chose IMKClient_Modern
  return self.value_net(torch.tensor(states, dtype=torch.float).to(self.device)).data.cpu().numpy()
[INFO] Episode 0 score: 9.0 
[INFO] Episode 1 score: -1.0 
[INFO] Episode 2 score: 10.0 
[INFO] Episode 3 score: -1.0 
[INFO] Episode 4 score: 9.0 
[INFO] Episode 5 score: 9.0 
[INFO] Episode 6 score: 9.0 
[INFO] Episode 7 score: -1.0 
[INFO] Episode 8 score: 9.0 
[INFO] Episode 9 score: 9.0 
[INFO] Episode 10 score: 9.0 
[INFO] Episode 11 score: -1.0 
[INFO] Episode 12 score: -1.0 
[INFO] Episode 13 score: -0.3 
[INFO] Episode 14 score: 10.0 
[INFO] Episode 15 score: 1.0 
[INFO] Episode 16 score: -1.0 
[INFO] Episode 17 score: -1.0 
[INFO] Episode 18 score: -1.0 
[INFO] Episode 19 score: 10.0 


## Analysis and interpretation

- The model analyzer runs the agent for one episode and records its layer activations, attention matrices and rendered frames to draw some key insights
- Analyse the agent's architecture

In [22]:
from collections import defaultdict
from PIL import Image

def render_images(images):
    """Open every array in ``images`` as a PIL image via ``Image.show()``.

    Args:
        images: Iterable of arrays accepted by ``Image.fromarray``.
    """
    for frame in images:
        picture = Image.fromarray(frame)
        picture.show()

class ModelAnalyzer:
    """Roll an agent through an episode while recording what its network computes.

    A forward hook is attached to every ``torch.nn.Linear`` sub-module of
    ``agent.value_net``; each hook appends the layer's detached output to
    ``self.activations[<module name>]``. After every environment step the
    analyzer also stores the new observation, the rendered frame and the
    tensor returned by ``agent.value_net.get_attention_matrix``.

    Attributes (all re-created by :meth:`reset`):
        states: Observations returned by ``env.step``.
        activations: ``defaultdict(list)`` mapping module name -> recorded outputs.
        attention_matrix: List with one attention tensor per recorded step.
        images: List with one rendered frame per step.
        hooks: Handles of the registered forward hooks.
        done: ``True`` once the episode terminated or was truncated.
        previous_state: Observation the agent acts on at the next step.
    """

    def __init__(self, agent: "DQNAgent", *args, environment=None, **kwargs):
        """Store the agent and environment, then start a fresh recording.

        Args:
            agent: Agent whose ``value_net`` is instrumented and whose ``act``
                method chooses the actions.
            environment: Environment to roll out in. When omitted, the
                notebook-level ``env`` is used, so ``ModelAnalyzer(agent)``
                keeps working as before.
        """
        super().__init__(*args, **kwargs)
        self.agent = agent
        # The global ``env`` is only looked up when no environment is passed.
        self.env = env if environment is None else environment
        self.reset()

    def reset(self):
        """Remove existing hooks, clear all recordings and reset the environment."""
        # reset() also runs from __init__, before ``hooks`` exists.
        for hook in getattr(self, "hooks", []):
            hook.remove()

        self.states = []
        self.activations = defaultdict(list)
        self.hooks = []
        self.done = False
        self.previous_state, _ = self.env.reset()
        self.attention_matrix = []
        self.images = []

        self._register_hooks(self.agent.value_net)

    def step(self):
        """Advance the episode by one agent action and record the outcome.

        If the previous call finished the episode, the analyzer is reset
        instead and no action is taken during this call.
        """
        if self.done:
            self.reset()
            return

        action = self.agent.act(self.previous_state)
        state, _, done, truncated, _ = self.env.step(action)
        # The next action must be chosen from the newest observation; without
        # this update the agent would keep acting on the observation from reset().
        self.previous_state = state
        self.done = done or truncated
        self.images.append(self.env.render())
        self._record(state)

    def _record(self, obs):
        """Store ``obs`` and the attention matrix the network produces for it."""
        self.states.append(obs)
        # Add a leading axis of size 1 before passing the observation to the network.
        state = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        # NOTE(review): if get_attention_matrix runs hooked Linear layers, the
        # forward hooks fire for this call as well as for agent.act(), so
        # ``activations`` may hold more entries per layer than ``states`` - confirm.
        self.attention_matrix.append(self.agent.value_net.get_attention_matrix(state).detach())

    def _register_hooks(self, module: torch.nn.Module):
        """Attach a recording forward hook to every ``torch.nn.Linear`` in ``module``."""
        for name, sub_module in module.named_modules():
            if isinstance(sub_module, torch.nn.Linear):
                self.hooks.append(sub_module.register_forward_hook(self._get_activation(name)))

    def _get_activation(self, name):
        """Return a forward hook that appends the detached output under ``name``."""
        def hook(module, args, output):
            self.activations[name].append(output.detach())

        return hook

### DQN with Social attention architecture 

In [23]:
# Print the module tree of the agent's value network.
print(agent.value_net)

EgoAttentionNetwork(
  (ego_embedding): MultiLayerPerceptron(
    (layers): ModuleList(
      (0): Linear(in_features=7, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (others_embedding): MultiLayerPerceptron(
    (layers): ModuleList(
      (0): Linear(in_features=7, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (attention_layer): EgoAttention(
    (value_all): Linear(in_features=64, out_features=64, bias=False)
    (key_all): Linear(in_features=64, out_features=64, bias=False)
    (query_ego): Linear(in_features=64, out_features=64, bias=False)
    (attention_combine): Linear(in_features=64, out_features=64, bias=False)
  )
  (output_layer): MultiLayerPerceptron(
    (layers): ModuleList(
      (0-1): 2 x Linear(in_features=64, out_features=64, bias=True)
    )
    (predict): Linear(in_features=64, out_features=3, bias=True)
  )
)


### Render a random scene

### Run analyzer using the agent

In [28]:
# Build a fresh analyzer; its constructor resets the environment and registers the hooks.
analyzer = ModelAnalyzer(agent)

# Step until the episode ends (terminated or truncated).
while not analyzer.done:
    analyzer.step()

In [30]:
# Sanity check: print the shape of each recorded query_ego output next to the
# shape of a rendered frame. zip() stops at the shorter of the two lists.
for img, q_act in zip(analyzer.images, analyzer.activations["attention_layer.query_ego"]):
    print(q_act.shape)
    print(img.shape)

torch.Size([1, 1, 64])
(600, 1200, 3)
torch.Size([1, 1, 64])
(600, 1200, 3)
torch.Size([1, 1, 64])
(600, 1200, 3)
torch.Size([1, 1, 64])
(600, 1200, 3)
torch.Size([1, 1, 64])
(600, 1200, 3)
torch.Size([1, 1, 64])
(600, 1200, 3)
torch.Size([1, 1, 64])
(600, 1200, 3)
torch.Size([1, 1, 64])
(600, 1200, 3)
torch.Size([1, 1, 64])
(600, 1200, 3)
torch.Size([1, 1, 64])
(600, 1200, 3)


### Attention pattern

In [None]:
def analyze_attention_weights(analyzer, step=-1):
    """Plot the query-key weight product and one recorded attention matrix.

    Args:
        analyzer: Object exposing ``agent.value_net.attention_layer`` (with
            ``query_ego`` and ``key_all`` linear layers) and
            ``attention_matrix``, either a list of per-step tensors or a
            single tensor.
        step: Index of the recorded step to plot when ``attention_matrix`` is
            a list. Defaults to the last recorded step.

    Raises:
        ValueError: If ``attention_matrix`` is an empty list.
    """
    attention_layer = analyzer.agent.value_net.attention_layer

    q_wt = attention_layer.query_ego.weight.detach().cpu().numpy()
    k_wt = attention_layer.key_all.weight.detach().cpu().numpy()
    # nn.Linear stores its weight as (out_features, in_features) and computes
    # x @ W.T, so the score q . k equals x_q @ (W_q.T @ W_k) @ x_k.T.
    # NOTE(review): if the layer splits its features into several heads, this
    # matrix is the sum of the per-head products - confirm.
    qk_wt = q_wt.T @ k_wt

    recorded = analyzer.attention_matrix
    if isinstance(recorded, (list, tuple)):
        # The analyzer appends one tensor per step; pick a single step to plot.
        if not recorded:
            raise ValueError("No attention matrix recorded; step the analyzer before plotting.")
        recorded = recorded[step]
    # assumes each tensor is (1, heads, 1, entities) - TODO confirm
    attn_tensor = recorded.squeeze(0).squeeze(1)

    print(attn_tensor.shape)
    attn_mat = attn_tensor.detach().cpu().numpy()
    px.imshow(qk_wt, title="QK weights", color_continuous_scale="RdBu").show()
    px.imshow(attn_mat, title="Attention weights", color_continuous_scale="RdBu").show()
    
# Plot the attention weights using the notebook-level analyzer.
analyze_attention_weights(analyzer)

### Activation analysis

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def analyze_activations(analyzer, step=0):
    """Print the ego-embedding and ego-query activations recorded at one step.

    Args:
        analyzer: Object whose ``activations`` maps module names to lists of
            recorded output tensors (one list entry per forward pass).
        step: Index of the recorded forward pass to inspect. Defaults to the
            first one.

    Raises:
        ValueError: If one of the inspected layers has no recorded activations.
    """
    activations = analyzer.activations
    layer_names = ("ego_embedding.layers.1", "attention_layer.query_ego")

    selected = []
    for name in layer_names:
        # .get() avoids creating empty entries in the analyzer's defaultdict.
        recorded = activations.get(name)
        if not recorded:
            raise ValueError(f"No activations recorded for {name!r}; step the analyzer first.")
        # Each layer holds a list of tensors, so index the step first and then
        # the tensor's leading axis (presumably the batch axis - confirm).
        selected.append(recorded[step][0])

    for tensor in selected:
        print(tensor)
    # TODO: plot the QK circuit (q @ k.T) and the attention_combine output from
    # the recorded "attention_layer.*" activations.

# Show the network structure again for reference next to the printed activations.
print(analyzer.agent.value_net)
# Print the recorded activations using the notebook-level analyzer.
analyze_activations(analyzer)