## Interpreting DQN network

In [1]:
%pip install rl-agents@git+https://github.com/manavdahra/rl-agents
%pip install highway-env@git+https://github.com/manavdahra/highway-env
%pip install stable_baselines3 plotly nbformat

Collecting rl-agents@ git+https://github.com/manavdahra/rl-agents
  Cloning https://github.com/manavdahra/rl-agents to /private/var/folders/8s/gwrqq0qn3b149hb0dtwlh8sw0000gn/T/pip-install-2j6_j0se/rl-agents_a87c56ff7e47419e9f4d0d17580ea6d9
  Running command git clone --filter=blob:none --quiet https://github.com/manavdahra/rl-agents /private/var/folders/8s/gwrqq0qn3b149hb0dtwlh8sw0000gn/T/pip-install-2j6_j0se/rl-agents_a87c56ff7e47419e9f4d0d17580ea6d9
  Resolved https://github.com/manavdahra/rl-agents to commit b65a875cd76c2b58a6124ed95235d041896667fe
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting highway_env@ git+https://github.com/manavdahra/highway-env (from rl-agents@ git+https://github.com/manavdahra/rl-agents)
  Cloning https://github.com/manavdahra/highway-env to /private/var/folders/8s/gwrqq0qn3b149hb0dtwlh8sw0000gn/T/pip-install-2j6_j0se/highwa

In [2]:
import torch
import numpy as np
from stable_baselines3 import DQN
import gymnasium as gym
import highway_env

gym.register_envs(highway_env)
torch.manual_seed(123)
np.random.seed(123)

In [3]:
env_name = "highway-v0"

env = gym.make("highway-v0", render_mode="rgb_array", config={"vehicles_count": 50, "lanes_count": 3})
dqn_model = DQN(
        "MlpPolicy",
        env=env,
        policy_kwargs=dict(net_arch=[256, 256]),
        learning_rate=5e-4,
        buffer_size=10000,
        learning_starts=200,
        batch_size=32,
        gamma=0.8,
        train_freq=1,
        gradient_steps=1,
        target_update_interval=50,
        verbose=1,
        tensorboard_log="output/highway-v0/tb/dqn",
)

filepath = "../../output/highway-v0/dqn.zip"
dqn_model.set_parameters(filepath)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


### Register hooks on model policy layers


Let's first analyse the model layers and it's architecture

In [4]:
class ModelAnalyzer:
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model
        self.records = []
        self.activations = {}
        self.hooks = []
        
        self._register_hooks()
        
    def run_single(self, obs: torch.Tensor):
        self.model.predict(obs, deterministic=True)
            
    def _record_activations(self, observation, last_action, new_action):
        if last_action == new_action:
            return
        
        self.records.append({
            "obs": observation,
            "last_action": last_action,
            "new_action": new_action,
            "hidden_1": self.activations["hidden_1"],
            "hidden_1_relu": self.activations["hidden_1_relu"],
            "hidden_2": self.activations["hidden_2"],
            "hidden_2_relu": self.activations["hidden_2_relu"],
        })

    def _register_hooks(self):
        self.hidden_1 = self.model.q_net.q_net[0]
        self.hidden_1_relu = self.model.q_net.q_net[1]
        self.hidden_2 = self.model.q_net.q_net[2]
        self.hidden_2_relu = self.model.q_net.q_net[3]

        self.hooks.append(self.hidden_1.register_forward_hook(self._get_activation("hidden_1")))
        self.hooks.append(self.hidden_1_relu.register_forward_hook(self._get_activation("hidden_1_relu")))
        self.hooks.append(self.hidden_2.register_forward_hook(self._get_activation("hidden_2")))
        self.hooks.append(self.hidden_2_relu.register_forward_hook(self._get_activation("hidden_2_relu")))
    
    def _get_activation(self, name):
        def hook(module, args, output):
            self.activations[name] = output.detach()
        
        return hook

    def _cleanup_hooks(self):
        for hook in self.hooks:
            hook.remove()

In [5]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def analyze_hidden_layer_weights(analyzer: ModelAnalyzer):
    def get_layer_weights(layer):
        weights, biases = layer.parameters()
        return weights.T.detach().numpy(), biases.detach().numpy()
    
    wts_1, b_1 = get_layer_weights(analyzer.hidden_1)
    wts_2, b_2 = get_layer_weights(analyzer.hidden_2)
    
    px.bar(b_1, title="Hidden layer 1 bias", color_continuous_scale="RdBu").show()
    px.bar(b_2, title="Hidden layer 2 bias", color_continuous_scale="RdBu").show()
    px.imshow(wts_1, title="Hidden layer 1 weights", color_continuous_scale="RdBu").show()
    px.imshow(wts_2, title="Hidden layer 2 weights", color_continuous_scale="RdBu").show()

def analyze_hidden_layer_activations(analyzer: ModelAnalyzer):
    obs = env.observation_space.sample()
    analyzer.run_single(obs)
    h_activations = analyzer.activations["hidden_1_relu"].T.detach().numpy()
    print(h_activations.shape)
    fig = px.bar(h_activations, title="Hidden layer 1 Relu activations")
    fig.show()

analyzer = ModelAnalyzer(dqn_model)
analyze_hidden_layer_weights(analyzer)

In [6]:
activations = {}
def get_activation(name):
    def hook(module, args, output):
        activations[name] = output.detach()
    
    return hook

In [38]:
from rl_agents.trainer.evaluation import Evaluation
from rl_agents.agents.common.factory import load_agent, load_environment

# Get the environment and agent configurations from the rl-agents repository
env_config = 'config/env.json'
agent_config = 'config/agents/DQNAgent/ego_attention_4h.json'

env = load_environment(env_config)
env.unwrapped.config["offscreen_rendering"] = True
agent = load_agent(agent_config, env)

evaluation = Evaluation(
    env, 
    agent, 
    num_episodes=1, 
    directory="../../output/intersection-v0/ego-attention/"
)

def register_hooks():
    attn_layer = agent.value_net.attention_layer
    attn_layer.query_ego.register_forward_hook(get_activation("query_ego"))
    attn_layer.key_all.register_forward_hook(get_activation("key_all"))
    attn_layer.value_all.register_forward_hook(get_activation("value_all"))
    attn_layer.attention_combine.register_forward_hook(get_activation("attention_combine"))
            
register_hooks()

evaluation.test()


[33mWARN: The environment intersection-v0 is out of date. You should consider upgrading to version `v1`.[0m

[INFO] Episode 0 score: 9.0 


In [42]:
import pickle
import os
from pathlib import Path

def save_activations(filepath: Path, activations):
    os.remove(filepath)
    os.makedirs(filepath, exist_ok=True)
    with open(filepath, "wb") as f:
        pickle.dump(activations, f)
        
filepath = Path("../../data_dir/activations.pkl")
save_activations(filepath, activations)

IsADirectoryError: [Errno 21] Is a directory: '../../data_dir/activations.pkl'

In [None]:
def load_activations(filepath: str):
    activations = {}
    with open(filepath, "rb") as f:
        activations = pickle.load(f)
    return activations

activations = load_activations(filepath)

In [50]:
for key in activations:
    print(key)
print(activations["key_all"].shape)
print(activations["query_ego"].shape)
print(activations["value_all"].shape)
print(activations["attention_combine"].shape)



key_all
value_all
query_ego
attention_combine
torch.Size([1, 15, 64])
torch.Size([1, 1, 64])
torch.Size([1, 15, 64])
torch.Size([1, 64])


In [35]:
def analyze_ego_attention_layer(activations):
    print(activations["query_ego"].shape)
    print(activations["key_all"].shape)
    print(activations["value_all"].shape)
    q = activations["query_ego"]
    k = activations["key_all"]
    v = activations["value_all"]
    attn_combine = activations["attention_combine"]
    
    qk = q @ k.T
    px.imshow(qk, title="QK circuit", color_continuous_scale="Rdbu").show()
    px.imshow(attn_combine, title="Attn combined", color_continuous_scale="RdBu").show()
    
analyze_ego_attention_layer(activations)

torch.Size([1, 1, 64])
torch.Size([1, 15, 64])
torch.Size([1, 15, 64])


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [64, 64] but got: [64, 15].