## Interpreting DQN network with social attention

In [118]:
# Install plotting / interpretability packages into the kernel's environment.
# NOTE(review): versions are not pinned, so a re-run may install different releases.
%pip install plotly nbformat shap transformer_lens

MDAHRAS-MACBOOK-PRO.LOCAL-M-7XH9-M-7XH9-M-7XH9-M-7XH9-M-7XH9
Note: you may need to restart the kernel to use updated packages.


In [120]:
# Install rl-agents and highway-env directly from the manavdahra GitHub repositories.
# NOTE(review): no commit or tag is pinned in these URLs, so a re-run may resolve to different code.
%pip install rl-agents@git+https://github.com/manavdahra/rl-agents#egg=rl-agents
%pip install highway-env@git+https://github.com/manavdahra/highway-env#egg=highway-env

MDAHRAS-MACBOOK-PRO.LOCAL-M-7XH9-M-7XH9-M-7XH9-M-7XH9-M-7XH9
Collecting rl-agents@ git+https://github.com/manavdahra/rl-agents#egg=rl-agents
  Cloning https://github.com/manavdahra/rl-agents to /private/var/folders/40/b3pz_mbj6bl7vh33p2tyg6j00000gn/T/pip-install-ro__gdut/rl-agents_3ed8c6fd0cb9465480697a9bc628a8ea
  Running command git clone --filter=blob:none --quiet https://github.com/manavdahra/rl-agents /private/var/folders/40/b3pz_mbj6bl7vh33p2tyg6j00000gn/T/pip-install-ro__gdut/rl-agents_3ed8c6fd0cb9465480697a9bc628a8ea
  Resolved https://github.com/manavdahra/rl-agents to commit b65a875cd76c2b58a6124ed95235d041896667fe
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting highway_env@ git+https://github.com/manavdahra/highway-env (from rl-agents@ git+https://github.com/manavdahra/rl-agents#egg=rl-agents)
  Cloning https://github.com/manavdahra/highway-en

## Library imports

In [121]:
import torch
import numpy as np
import gymnasium as gym
import highway_env
import plotly.express as px

from pathlib import Path
from rl_agents.agents.deep_q_network.pytorch import DQNAgent

# Register the environments provided by highway_env with gymnasium.
gym.register_envs(highway_env)
# Seed torch's and NumPy's global random number generators.
torch.manual_seed(123)
np.random.seed(123)

In [122]:
from rl_agents.agents.common.factory import load_agent, load_environment

# Paths (relative to the notebook's working directory) of the JSON configs
# used to build the environment and the agent.
env_config = 'config/env.json'
agent_config = 'config/agents/DQNAgent/ego_attention_4h.json'

# Build the environment first, then the agent that acts in it.
env = load_environment(env_config)
agent = load_agent(agent_config, env)


[33mWARN: The environment intersection-v0 is out of date. You should consider upgrading to version `v1`.[0m



## Train agent

In [None]:
from rl_agents.trainer.evaluation import Evaluation

# Train the agent for 500 episodes with both environment and agent display turned off.
# NOTE(review): `recover=True` presumably resumes from a previously saved model and
# `directory` presumably is where run artifacts are written — confirm against rl_agents' Evaluation.
evaluation = Evaluation(
    env, 
    agent, 
    num_episodes=500, 
    display_env=False, 
    display_agent=False,
    recover=True,
    directory="../../output/intersection-v0/ego-attention"
)
evaluation.train()

## Load and Evaluate model

In [123]:
def load_agent_model(model_path):
    """Load saved weights into the notebook-level ``agent`` and report success.

    Args:
        model_path: Location of the saved model, forwarded to ``agent.load``
            as ``filename``.

    Prints a confirmation when ``agent.load`` returns a truthy value. An agent
    whose ``load`` raises ``NotImplementedError`` is skipped silently.
    """
    try:
        loaded_from = agent.load(filename=model_path)
    except NotImplementedError:
        # Best-effort: agents without load support are simply left untouched.
        return
    if loaded_from:
        print("Loaded {} model from {}".format(agent.__class__.__name__, loaded_from))

# Load the most recent saved model ("latest.tar") from the ego-attention output directory.
load_agent_model("../../output/intersection-v0/ego-attention/saved_models/latest.tar")

Loaded DQNAgent model from ../../output/intersection-v0/ego-attention/saved_models/latest.tar



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



## Evaluate agent

In [124]:
import base64
from IPython import display as ipythondisplay

def show_videos(path="videos"):
    """Embed every ``*.mp4`` file found directly under ``path`` in the notebook output.

    Each file is read, base64-encoded and inlined as a ``data:`` URI in an HTML
    ``<video>`` tag; the tags are joined with ``<br>`` and displayed at once.

    Args:
        path: Directory to search for ``.mp4`` files.
    """
    template = """<video alt="{}" autoplay
                      loop controls style="height: 400px;">
                      <source src="data:video/mp4;base64,{}" type="video/mp4" />
                 </video>"""
    tags = [
        template.format(clip, base64.b64encode(clip.read_bytes()).decode("ascii"))
        for clip in Path(path).glob("*.mp4")
    ]
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(tags)))

In [125]:
from rl_agents.trainer.evaluation import Evaluation
# Run 10 test episodes with training and recovery switched off, then display
# any .mp4 files found in this evaluation's run directory.
evaluation = Evaluation(env, agent, num_episodes=10, training = False, recover = False)
evaluation.test()
show_videos(evaluation.run_directory)

[INFO] Episode 0 score: 0.0 
[INFO] Episode 1 score: -0.6 
[INFO] Episode 2 score: 0.0 
[INFO] Episode 3 score: 5.0 
[INFO] Episode 4 score: -1.3 
[INFO] Episode 5 score: 9.0 
[INFO] Episode 6 score: 9.0 
[INFO] Episode 7 score: 9.0 
[INFO] Episode 8 score: -2.5 
[INFO] Episode 9 score: 9.0 


## Analysis and interpretation

- The agent analyzer runs the agent and records its activations and activity so that we can draw some key insights
- Analyse the agent's architecture

In [6]:
from collections import defaultdict
from PIL import Image

def render_images(images):
    """Open each array in ``images`` as a PIL image and show it.

    Args:
        images: Iterable of arrays accepted by ``PIL.Image.fromarray``.
    """
    for frame in images:
        picture = Image.fromarray(frame)
        picture.show()

class ModelAnalyzer:
    """Roll a DQN agent through one episode while recording what its network does.

    A forward hook is attached to every ``torch.nn.Linear`` sub-module of
    ``agent.value_net``. Each hook appends the layer's detached output to
    ``self.activations[<module name>]``. Hooks fire on every forward pass
    through a hooked layer, so the activation lists are indexed by forward
    pass, not by environment step.

    Args:
        agent: Agent exposing ``act(state)`` and a ``value_net`` module that
            provides ``get_attention_matrix(state)``.
        environment: Optional environment to roll out in. When omitted, the
            notebook-level ``env`` is used, so ``ModelAnalyzer(agent)`` behaves
            as before.

    Call ``remove_hooks()`` before discarding an analyzer. Otherwise its hooks
    stay attached to the shared network and keep accumulating activations.
    """

    def __init__(self, agent: "DQNAgent", *args, environment=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.agent = agent
        # Fall back to the notebook-level environment for backward compatibility.
        self.env = environment if environment is not None else env
        self.reset()

    def remove_hooks(self):
        """Detach every forward hook this analyzer registered on the network."""
        for handle in getattr(self, 'hooks', []):
            handle.remove()
        self.hooks = []

    def reset(self):
        """Start a fresh episode: clear all recordings and re-register the hooks."""
        self.remove_hooks()

        state, _ = self.env.reset()
        self.states = [state]
        self.actions = []
        self.activations = defaultdict(list)
        self.done = False
        self.attention_matrix = []
        self.images = []
        self.total_reward = 0

        self._register_hooks(self.agent.value_net)

    def step(self):
        """Let the agent act on the latest state and record the outcome."""
        action = self.agent.act(self.states[-1])
        self.actions.append(action)
        state, reward, done, truncated, _ = self.env.step(action)
        self.total_reward += reward
        self.done = done or truncated
        # The frame is rendered after the action has been applied.
        self.images.append(self.env.render())
        self._record(state)

    def stats(self):
        """Print the episode reward and the size of every recorded collection."""
        print(f"Total reward: {self.total_reward}")
        print(f"States: {len(self.states)}")
        print(f"Actions: {len(self.actions)}")
        print(f"Attention matrices: {len(self.attention_matrix)}")
        print(f"images: {len(self.images)}")

        for key in self.activations:
            print(f"Activations for {key}: {len(self.activations[key])}")

    def _record(self, obs):
        """Store the new observation and the attention matrix computed for it."""
        self.states.append(obs)
        state = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        # The result is detached anyway, so skip building an autograd graph.
        with torch.no_grad():
            attention = self.agent.value_net.get_attention_matrix(state)
        self.attention_matrix.append(attention.detach())

    def _register_hooks(self, module: torch.nn.Module):
        """Attach an activation-recording hook to each Linear layer of ``module``."""
        for name, sub_module in module.named_modules():
            if isinstance(sub_module, torch.nn.Linear):
                self.hooks.append(sub_module.register_forward_hook(self._get_activation(name)))

    def _get_activation(self, name):
        """Build a forward hook that files layer outputs under ``name``."""
        def hook(module, args, output):
            self.activations[name].append(output.detach())

        return hook

### Check the agent architecture

In [7]:
# Print the module tree of the evaluated agent's value network.
print(evaluation.agent.value_net)

EgoAttentionNetwork(
  (ego_embedding): MultiLayerPerceptron(
    (layers): ModuleList(
      (0): Linear(in_features=7, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (others_embedding): MultiLayerPerceptron(
    (layers): ModuleList(
      (0): Linear(in_features=7, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (attention_layer): EgoAttention(
    (value_all): Linear(in_features=64, out_features=64, bias=False)
    (key_all): Linear(in_features=64, out_features=64, bias=False)
    (query_ego): Linear(in_features=64, out_features=64, bias=False)
    (attention_combine): Linear(in_features=64, out_features=64, bias=False)
  )
  (output_layer): MultiLayerPerceptron(
    (layers): ModuleList(
      (0-1): 2 x Linear(in_features=64, out_features=64, bias=True)
    )
    (predict): Linear(in_features=64, out_features=3, bias=True)
  )
)


### Run analyzer using the agent

In [12]:
# Wrap the agent in an analyzer and step it until the episode ends
# (terminated or truncated).
analyzer = ModelAnalyzer(agent)

while not analyzer.done:
    analyzer.step()

# Print how many states, actions, frames and per-layer activations were recorded.
analyzer.stats()

Total reward: 9.0
States: 10
Actions: 9
Attention matrices: 9
images: 9
Activations for ego_embedding.layers.0: 297
Activations for ego_embedding.layers.1: 297
Activations for others_embedding.layers.0: 297
Activations for others_embedding.layers.1: 297
Activations for attention_layer.key_all: 297
Activations for attention_layer.value_all: 297
Activations for attention_layer.query_ego: 297
Activations for attention_layer.attention_combine: 297
Activations for output_layer.layers.0: 279
Activations for output_layer.layers.1: 279
Activations for output_layer.predict: 279


## Observation
- Feature optimization
- Logit lens on attention mechanism
- Circuit analysis on actions

## Intervention
- Path tracing

## Explainable RL
- Discovering symbolic policies on Deep RL networks. [Source](https://proceedings.mlr.press/v139/landajuela21a/landajuela21a.pdf)

### Attention pattern

In [148]:
import torch.nn.functional as F
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Build a second agent from the same config to act as the untrained baseline;
# no saved model is loaded into it in this notebook.
untrained_agent = load_agent(agent_config, env)

def analyze_qk_circuit(value_net):
    """Return the row-wise softmax of the combined query/key weight product.

    Computes ``softmax(W_Q^T @ W_K)`` from the ``query_ego`` and ``key_all``
    linear layers of ``value_net.attention_layer``. The weights are used as one
    full matrix each; no per-head split is applied.

    Args:
        value_net: Network exposing ``attention_layer.query_ego`` and
            ``attention_layer.key_all`` as linear layers.

    Returns:
        Tensor of shape ``(1, in_features, in_features)``; the leading axis is
        the one callers iterate over as "heads".
    """
    attention = value_net.attention_layer
    query_weight = attention.query_ego.weight.detach()
    key_weight = attention.key_all.weight.detach()

    # unsqueeze(0) replaces the former hard-coded view(1, 64, 64), so any
    # layer width works.
    qk_product = torch.matmul(query_weight.transpose(-2, -1), key_weight).unsqueeze(0)
    return F.softmax(qk_product, dim=-1)

def analyze_ov_circuit(value_net):
    """Return the combined output/value weight product ``W_O @ W_V``.

    Uses the ``attention_combine`` and ``value_all`` linear layers of
    ``value_net.attention_layer``. The weights are used as one full matrix
    each; no per-head split is applied.

    Args:
        value_net: Network exposing ``attention_layer.value_all`` and
            ``attention_layer.attention_combine`` as linear layers.

    Returns:
        Tensor of shape ``(1, out_features_of_combine, in_features_of_value)``;
        the leading axis is the one callers iterate over as "heads".
    """
    attention = value_net.attention_layer
    value_weight = attention.value_all.weight.detach()
    output_weight = attention.attention_combine.weight.detach()

    # unsqueeze(0) replaces the former hard-coded view(1, 64, 64), so any
    # layer width works.
    return torch.matmul(output_weight, value_weight).unsqueeze(0)

def analyze_embedding_weights(embedding):
    """Compose the first two linear layers of an embedding into one weight matrix.

    ``torch.nn.Linear`` stores its weight as ``(out_features, in_features)``
    and computes ``x @ W.T``, so the bias-free, activation-free composition of
    the two layers is ``W0.T @ W1.T``. The previous ``W0.T @ W1`` contracted
    the output axes of both layers. It only ran because layer 1 is square.

    Args:
        embedding: Module whose ``layers[0]`` and ``layers[1]`` are linear layers.

    Returns:
        Tensor of shape ``(in_features of layer 0, out_features of layer 1)``.
        Row ``f`` holds the linearised contribution of input feature ``f`` to
        each output unit; any nonlinearity between the layers is ignored.
    """
    first_weight = embedding.layers[0].weight.detach()
    second_weight = embedding.layers[1].weight.detach()

    return torch.matmul(first_weight.T, second_weight.T)

def analyze_embedding_biases(embedding):
    """Collect the detached bias vectors of the embedding's first two layers.

    Args:
        embedding: Module whose ``layers[0]`` and ``layers[1]`` carry a ``bias``.

    Returns:
        Two-element list ``[bias of layer 0, bias of layer 1]``.
    """
    return [embedding.layers[index].bias.detach() for index in (0, 1)]

def compute_diff(a: torch.Tensor, b: torch.Tensor):
    """Return the element-wise squared difference of ``a`` and ``b`` as a NumPy array."""
    delta = a - b
    return delta.pow(2).numpy()

def _plot_circuit_comparison(untrained_circuit, trained_circuit, circuit_label):
    """Draw untrained / trained / squared-difference heatmaps for each leading slice.

    One figure is produced per index of the first axis of ``untrained_circuit``.
    All three panels share one colour axis.
    """
    heads = untrained_circuit.size(0)
    for h in range(heads):
        fig = make_subplots(rows=1, cols=3, subplot_titles=("Untrained", "Trained", "Difference"))
        fig.add_trace(go.Heatmap(z=untrained_circuit[h].numpy(), coloraxis="coloraxis"), row=1, col=1)
        fig.add_trace(go.Heatmap(z=trained_circuit[h].numpy(), coloraxis="coloraxis"), row=1, col=2)
        fig.add_trace(go.Heatmap(z=compute_diff(trained_circuit[h], untrained_circuit[h]), coloraxis="coloraxis"), row=1, col=3)
        fig.update_layout(title_text=f"Attention head: {h+1} {circuit_label} Circuits", coloraxis=dict(colorscale="RdBu"))
        fig.show()


def compare_qk_circuit(untrained_agent, trained_agent):
    """Plot the QK circuits of both agents next to their squared difference.

    Args:
        untrained_agent: Baseline agent; its ``value_net`` is analysed.
        trained_agent: Trained agent; its ``value_net`` is analysed.
    """
    _plot_circuit_comparison(
        analyze_qk_circuit(untrained_agent.value_net),
        analyze_qk_circuit(trained_agent.value_net),
        "QK",
    )


def compare_ov_circuit(untrained_agent, trained_agent):
    """Plot the OV circuits of both agents next to their squared difference.

    Args:
        untrained_agent: Baseline agent; its ``value_net`` is analysed.
        trained_agent: Trained agent; its ``value_net`` is analysed.
    """
    _plot_circuit_comparison(
        analyze_ov_circuit(untrained_agent.value_net),
        analyze_ov_circuit(trained_agent.value_net),
        "OV",
    )
    

# Compare the untrained baseline with the agent held by the analyzer.
compare_qk_circuit(untrained_agent, analyzer.agent)
compare_ov_circuit(untrained_agent, analyzer.agent)



The comparisons above show the QK and OV circuits of the untrained and trained agents, together with the element-wise squared difference between their values.
A quick look reveals that the agent has indeed learned some structure/pattern for taking actions, which helps it avoid crashes with other vehicles.

Observations of structure:
- Regular patterns with gaps in between. It's worth studying whether these gaps change if the environment settings change.
- The weight patterns also visibly alternate between high and low activations. Why?
- The QK circuit reveals that neurons 10 and 41 are assigned high attention weights, qualifying them as candidates for further analysis.
- The OV circuit reveals clear activations for a few neurons, again with alternating patterns.

> TODO: understand these patterns and try to make some conclusions about them

### Analyze Embedding matrices
Next, let's analyze the embedding matrices of the agent's architecture, which sit before the attention layer of the transformer. These matrices are crucial for understanding which features the agent considers important, and this will guide us in the right direction when applying interpretability techniques.

In [149]:
def _plot_embedding_comparison(untrained_embedding, trained_embedding, label):
    """Plot weight heatmaps and bias-shift bars for one embedding module.

    Figure 1: composed weight matrices of the untrained and trained embedding
    and their squared difference, with rows labelled by input feature.
    Figure 2: absolute bias change per layer, one bar per bias entry.
    """
    features = ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"]

    untrained_wt = analyze_embedding_weights(untrained_embedding)
    trained_wt = analyze_embedding_weights(trained_embedding)

    # Three columns need three titles; "Difference" was missing before.
    fig = make_subplots(rows=1, cols=3, subplot_titles=("Untrained", "Trained", "Difference"))
    fig.add_trace(go.Heatmap(z=untrained_wt.numpy(), coloraxis="coloraxis", y=features), row=1, col=1)
    fig.add_trace(go.Heatmap(z=trained_wt.numpy(), coloraxis="coloraxis", y=features), row=1, col=2)
    fig.add_trace(go.Heatmap(z=compute_diff(trained_wt, untrained_wt), coloraxis="coloraxis", y=features), row=1, col=3)
    fig.update_layout(
        title_text=f"{label} Embeddings Matrix",
        coloraxis=dict(colorscale="RdBu"),
    )
    fig.show()

    untrained_biases = analyze_embedding_biases(untrained_embedding)
    trained_biases = analyze_embedding_biases(trained_embedding)

    fig = make_subplots(rows=1, cols=2)
    for layer_idx, (trained_b, untrained_b) in enumerate(zip(trained_biases, untrained_biases)):
        bias_shift = torch.abs(trained_b - untrained_b)
        # A bias has one entry per output unit of its layer, not one per input
        # feature, so index the bars by unit instead of by feature name.
        fig.add_trace(
            go.Bar(
                x=list(range(bias_shift.numel())),
                y=bias_shift.numpy(),
                name=f"Layer {layer_idx} Biases diff",
            ),
            row=1,
            col=layer_idx + 1,
        )
    fig.update_layout(title_text=f"{label} Embeddings Biases")
    fig.show()


def compare_ego_embeddings_matrix(untrained_agent, trained_agent):
    """Compare the ``ego_embedding`` weights and biases of the two agents.

    Args:
        untrained_agent: Baseline agent; its ``value_net.ego_embedding`` is used.
        trained_agent: Trained agent; its ``value_net.ego_embedding`` is used.
    """
    _plot_embedding_comparison(
        untrained_agent.value_net.ego_embedding,
        trained_agent.value_net.ego_embedding,
        "Ego",
    )


def compare_others_embeddings_matrix(untrained_agent, trained_agent):
    """Compare the ``others_embedding`` weights and biases of the two agents.

    Args:
        untrained_agent: Baseline agent; its ``value_net.others_embedding`` is used.
        trained_agent: Trained agent; its ``value_net.others_embedding`` is used.
    """
    _plot_embedding_comparison(
        untrained_agent.value_net.others_embedding,
        trained_agent.value_net.others_embedding,
        "Others",
    )

# Compare both embedding modules of the untrained baseline and the analyzed agent.
compare_ego_embeddings_matrix(untrained_agent, analyzer.agent)
compare_others_embeddings_matrix(untrained_agent, analyzer.agent)

The graphs above hint at the possibility that the agent's ego embedding layer (responsible for tracking observations of itself) has learnt to assign more weight to the features `y`, `vx`, `x`, etc., in decreasing order of importance.
On the other hand, the agent's others embedding layer (responsible for tracking observations of other vehicles) has learnt to assign more weight to the features `vx`, `x`, etc.

Qualitatively, that seems reasonable. In order to make a successful turn at the intersection, the agent needs to track the following key information:
- Its own vertical position. Most likely the agent has learnt to slow down as its `y` coordinate increases and it gets closer to the intersection, and conversely to speed up when it is far away from the intersection.
- The horizontal velocity `vx` of the other vehicles. This seems more important, as the agent learns to wait at the intersection when a vehicle is crossing it at high speed, thereby increasing its chances of avoiding a collision.
- Other features like `presence` and `x` also show medium-level weight values.

> Studying the absolute differences of the biases between the untrained and trained models doesn't reveal any good insights. A different approach to analyzing them may reveal better insights.

### MLP Layer analysis

In [91]:
def analyze_mlp_hidden_weights(agent):
    """Compose the two hidden linear layers of the output MLP into one matrix.

    ``torch.nn.Linear`` stores its weight as ``(out_features, in_features)``
    and computes ``x @ W.T``, so the bias-free, activation-free composition of
    the two layers is ``W0.T @ W1.T``. The previous ``W0.T @ W1`` contracted
    the output axes of both layers. It only ran because the layers are square.

    Args:
        agent: Agent whose ``value_net.output_layer.layers[0]`` and
            ``layers[1]`` are linear layers.

    Returns:
        Tensor of shape ``(in_features of layer 0, out_features of layer 1)``;
        any nonlinearity between the layers is ignored.
    """
    hidden_layers = agent.value_net.output_layer.layers
    first_weight = hidden_layers[0].weight.detach()
    second_weight = hidden_layers[1].weight.detach()

    return torch.matmul(first_weight.T, second_weight.T)

def compare_mlp_layer(untrained_agent, trained_agent):
    """Plot the composed hidden-layer weights of both agents and their squared difference.

    Args:
        untrained_agent: Baseline agent passed to ``analyze_mlp_hidden_weights``.
        trained_agent: Trained agent passed to ``analyze_mlp_hidden_weights``.
    """
    before = analyze_mlp_hidden_weights(untrained_agent)
    after = analyze_mlp_hidden_weights(trained_agent)

    fig = make_subplots(rows=1, cols=3, subplot_titles=("Untrained", "Trained", "Difference"))
    panels = (before, after, compute_diff(after, before))
    # Columns are 1-based: untrained, trained, difference.
    for column, panel in enumerate(panels, start=1):
        fig.add_trace(go.Heatmap(z=panel, coloraxis="coloraxis"), row=1, col=column)
    fig.update_layout(title_text="MLP Layer", coloraxis=dict(colorscale="RdBu"))
    fig.show()

# Compare the output MLP's hidden layers of the untrained baseline and the analyzed agent.
compare_mlp_layer(untrained_agent, analyzer.agent)

### Perturbing activations

In [103]:
def perturb_activations(activations, perturbation):
    """Return ``activations`` shifted by ``perturbation`` (plain addition)."""
    return activations + perturbation


def analyze_final_layer_perturbations(analyzer, perturbation=0.1):
    """Print how much the ``predict`` layer's output moves when its input is shifted.

    For every recorded output of ``output_layer.layers.1`` the value is fed to
    ``output_layer.predict`` twice, once as recorded and once with
    ``perturbation`` added, and the L2 norm of the output difference is printed.

    Because ``predict`` is a single linear layer, adding the same constant to
    every input gives the same output shift for every record. This is why all
    printed values are identical. A per-unit or random perturbation would be
    needed to learn anything input-specific.

    NOTE(review): the hook stores the raw output of ``layers.1``; if the MLP
    applies an activation function before ``predict``, it is skipped here —
    confirm against the MultiLayerPerceptron implementation.

    Args:
        analyzer: Object with ``activations`` (mapping of layer name to list of
            tensors) and ``agent.value_net.output_layer.predict``.
        perturbation: Value added to the recorded activations. Defaults to 0.1.
    """
    predict = analyzer.agent.value_net.output_layer.predict
    recorded = analyzer.activations["output_layer.layers.1"]

    # No gradients are needed for this read-only probe.
    with torch.no_grad():
        for i in range(len(recorded)):
            hidden = recorded[i].detach()
            org_out = predict(hidden).detach()

            pert_in = perturb_activations(hidden, perturbation)
            pert_out = predict(pert_in).detach()

            diff = torch.norm(org_out - pert_out, p=2).item()
            print(f"Step: {i}, Perturbation: {diff}")

# Probe the predict layer with the activations recorded by the analyzer.
analyze_final_layer_perturbations(analyzer)

Step: 0, Perturbation: 0.09943672269582748
Step: 1, Perturbation: 0.09943672269582748
Step: 2, Perturbation: 0.09943672269582748
Step: 3, Perturbation: 0.09943672269582748
Step: 4, Perturbation: 0.09943672269582748
Step: 5, Perturbation: 0.09943672269582748
Step: 6, Perturbation: 0.09943672269582748
Step: 7, Perturbation: 0.09943672269582748
Step: 8, Perturbation: 0.09943672269582748
Step: 9, Perturbation: 0.09943672269582748
Step: 10, Perturbation: 0.09943672269582748
Step: 11, Perturbation: 0.09943672269582748
Step: 12, Perturbation: 0.09943672269582748
Step: 13, Perturbation: 0.09943672269582748
Step: 14, Perturbation: 0.09943672269582748
Step: 15, Perturbation: 0.09943672269582748
Step: 16, Perturbation: 0.09943672269582748
Step: 17, Perturbation: 0.09943672269582748
Step: 18, Perturbation: 0.09943672269582748
Step: 19, Perturbation: 0.09943672269582748
Step: 20, Perturbation: 0.09943672269582748
Step: 21, Perturbation: 0.09943672269582748
Step: 22, Perturbation: 0.0994367226958274

### Logit Lens

In [83]:
### TODO apply logit lens technique here ###

## SHAP technique

In [None]:
### TODO apply shapley values here ###

### Activation analysis

In [None]:
def attn_scores(query: torch.Tensor, key: torch.Tensor):
    """Scaled dot-product attention weights: ``softmax(Q K^T / sqrt(d_k))``.

    Args:
        query: Tensor whose last dimension is the feature size ``d_k``.
        key: Tensor with the same last dimension; its last two axes are
            transposed before the product.

    Returns:
        Softmax over the last axis of the scaled query/key products.
    """
    scale = np.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1)) / scale
    return F.softmax(logits, dim=-1)

```
Total reward: 10.0
States: 11
Actions: 10
Attention matrices: 10
images: 10
Activations for ego_embedding.layers.0: 330
Activations for ego_embedding.layers.1: 330
Activations for others_embedding.layers.0: 330
Activations for others_embedding.layers.1: 330
Activations for attention_layer.key_all: 330
Activations for attention_layer.value_all: 330
Activations for attention_layer.query_ego: 330
Activations for attention_layer.attention_combine: 330
Activations for output_layer.layers.0: 310
Activations for output_layer.layers.1: 310
Activations for output_layer.predict: 310
```

In [None]:
def analyze_activations(analyzer, start=150, stop=160, acts_per_step=None):
    """Plot attention scores next to the rendered scene for a window of recorded passes.

    For each recorded forward pass ``i`` in ``[start, stop)``, the query
    activation is repeated along the entity axis of the key activation,
    attention scores are computed with ``attn_scores`` and shown as a heatmap
    beside a rendered frame.

    Args:
        analyzer: Object with ``activations`` (layer name -> list of tensors)
            and ``images`` (list of rendered frames).
        start: First activation index to plot. Defaults to 150.
        stop: One past the last activation index to plot. Defaults to 160.
            The window is clipped to the number of recorded activations.
        acts_per_step: Number of recorded forward passes per rendered frame,
            used to map an activation index to a frame. When ``None`` it is
            derived as ``len(query activations) / len(images)``. The recorded
            stats show 33 passes per frame, so the former hard-coded 30
            drifted by a frame.
            NOTE(review): this assumes the same number of forward passes in
            every step — confirm against the agent's ``act`` implementation.
    """
    activations = analyzer.activations

    q_acts = activations['attention_layer.query_ego']
    k_acts = activations['attention_layer.key_all']
    images = analyzer.images

    if not q_acts or not images:
        print("No recorded activations or frames to plot")
        return

    if acts_per_step is None:
        acts_per_step = len(q_acts) / len(images)

    # Clip the window so short episodes do not raise IndexError.
    first = max(start, 0)
    last = min(stop, len(q_acts), len(k_acts))
    for i in range(first, last):
        frame_idx = min(int(i // acts_per_step), len(images) - 1)
        fig = make_subplots(rows=1, cols=2, subplot_titles=("Attn scores", "Scene"))
        scores = attn_scores(q_acts[i].repeat(1, k_acts[i].size(1), 1), k_acts[i]).squeeze(0).numpy()
        fig.add_trace(go.Heatmap(z=scores, coloraxis="coloraxis"), row=1, col=1)
        fig.add_trace(go.Image(z=images[frame_idx]), row=1, col=2)
        fig.update_layout(title_text=f"Step: {i}", coloraxis=dict(colorscale="RdBu"))
        fig.show()


def analyze_actions(analyzer):
    """Show the scene and the predict-layer output around each "critical" action.

    Actions 0 and 1 are treated as critical. For a critical action taken at
    step ``s``, index ``s - 1`` is used to look up the frame and the recorded
    predict-layer output.

    Step 0 is skipped. Its lookup index would be ``-1``, which Python resolves
    to the last frame of the episode rather than an earlier one.

    NOTE(review): ``activations["output_layer.predict"]`` is appended on every
    forward pass, so indexing it by step number may not select the pass that
    produced the action — confirm how many passes the agent makes per step.

    Args:
        analyzer: Object with ``actions``, ``images`` and ``activations``.
    """
    predict_acts = analyzer.activations["output_layer.predict"]
    critical_actions = [
        0, # slower 
        1, # idle
    ]
    print(analyzer.actions)

    critical_steps = [
        (step - 1, action)
        for step, action in enumerate(analyzer.actions)
        if action in critical_actions and step > 0
    ]

    for step, action in critical_steps:
        px.imshow(analyzer.images[step], title=f"Step: {step}").show()
        # print(IntersectionEnv.ACTIONS[action], predict_acts[step])
        px.imshow(predict_acts[step], title=f"Predicted actions: {step}").show()

# Plot attention scores alongside rendered frames for the analyzed episode.
analyze_activations(analyzer)
# The action-level analysis is currently disabled.
# analyze_actions(analyzer)