## Interpreting DQN network with social attention

In [None]:
# Install the plotting and interpretability dependencies used by this notebook (versions are not pinned).
%pip install plotly nbformat shap transformer_lens scikit-image torch-lucent captum

In [None]:
# Install rl-agents and highway-env straight from the manavdahra GitHub forks.
# No branch, tag or commit is given, so each install follows the fork's default branch.
%pip install rl-agents@git+https://github.com/manavdahra/rl-agents#egg=rl-agents
%pip install highway-env@git+https://github.com/manavdahra/highway-env#egg=highway-env

## Import libraries

In [1]:
import torch
import numpy as np
import gymnasium as gym
import highway_env
import plotly.express as px

from pathlib import Path
from rl_agents.agents.deep_q_network.pytorch import DQNAgent

# Make the highway_env environments available through gymnasium.
gym.register_envs(highway_env)
# Seed the global torch and numpy random number generators.
torch.manual_seed(123)
np.random.seed(123)

In [2]:
from rl_agents.agents.common.factory import load_agent, load_environment

# JSON configuration files (paths are relative to the notebook's working directory).
env_config = 'config/env.json'
agent_config = 'config/agents/DQNAgent/ego_attention.json'

# Build the environment first: the agent factory needs it as an argument.
env = load_environment(env_config)
agent = load_agent(agent_config, env)

  logger.deprecation(
Preferred device cuda:best unavailable, switching to default cpu


## Train agent

In [None]:
from rl_agents.trainer.evaluation import Evaluation

# Train the agent for 1000 episodes with rendering of the environment and agent disabled.
# NOTE(review): recover=False presumably starts from scratch instead of resuming a checkpoint - confirm in rl_agents.
evaluation = Evaluation(
    env, 
    agent, 
    num_episodes=1000, 
    display_env=False, 
    display_agent=False,
    recover=False,
    # The "Load agent" cell reads saved_models/latest.tar from below this directory.
    directory="../../output/intersection-v0/ego-attention"
)
evaluation.train()

## Load agent

In [33]:
def load_agent_model(model_path):
    """Load saved weights into the notebook-level ``agent`` and report the outcome.

    Every outcome is printed, so an analysis can no longer run unnoticed on an
    agent whose weights were never loaded.

    SECURITY: the warning emitted when this cell runs shows that loading goes
    through ``torch.load`` with ``weights_only=False`` (pickle). Only load
    checkpoint files you trust.

    Args:
        model_path: Path of the checkpoint file handed to ``agent.load``.
    """
    agent_name = agent.__class__.__name__
    try:
        # Keep the argument and the returned value in separate names.
        loaded_path = agent.load(filename=model_path)
    except NotImplementedError:
        # Still best-effort (no crash), but no longer silent.
        print("{} does not support loading saved models; {} was not loaded".format(agent_name, model_path))
        return

    if loaded_path:
        print("Loaded {} model from {}".format(agent_name, loaded_path))
    else:
        print("No model was loaded from {}".format(model_path))

load_agent_model("../../output/intersection-v0/ego-attention/saved_models/latest.tar")

Loaded DQNAgent model from ../../output/intersection-v0/ego-attention/saved_models/latest.tar



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



## Evaluate agent

In [34]:
import base64
from IPython import display as ipythondisplay

def show_videos(path="videos"):
    """Embed every ``*.mp4`` file found directly under ``path`` as an inline HTML5 video.

    The files are base64-encoded into ``data:`` URIs, so the rendered output
    holds the full video bytes.

    Args:
        path: Directory searched (non-recursively) for ``.mp4`` files.
    """
    # Function-scope import: ``html`` is not imported at cell level.
    from html import escape

    video_tags = []
    # Path.glob yields entries in an unspecified order; sort so the videos
    # appear in a stable, name-based order on every run.
    for mp4 in sorted(Path(path).glob("*.mp4")):
        video_b64 = base64.b64encode(mp4.read_bytes())
        video_tags.append(
            """<video alt="{}" autoplay
                      loop controls style="height: 400px;">
                      <source src="data:video/mp4;base64,{}" type="video/mp4" />
                 </video>""".format(
                # Escape the path so quotes or angle brackets cannot break the markup.
                escape(str(mp4), quote=True), video_b64.decode("ascii")
            )
        )
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(video_tags)))

In [35]:
from rl_agents.trainer.evaluation import Evaluation
# Run 10 evaluation episodes with training disabled.
evaluation = Evaluation(env, agent, num_episodes=10, training = False, recover = False)
evaluation.test()
# Embed any .mp4 files found in this run's output directory.
show_videos(evaluation.run_directory)

[INFO] Episode 0 score: 9.0 
[INFO] Episode 1 score: 9.0 
[INFO] Episode 2 score: -2.0 
[INFO] Episode 3 score: 10.0 
[INFO] Episode 4 score: -2.5 
[INFO] Episode 5 score: 9.0 
[INFO] Episode 6 score: 3.0 
[INFO] Episode 7 score: 9.0 
[INFO] Episode 8 score: 10.0 
[INFO] Episode 9 score: -1.0 


## Analysis and interpretation

- The agent analyzer runs the agent and records its activations and activity so that key insights can be drawn from them
- Analyse the agent's architecture

In [36]:
# Environment constants used to label the plots below.
# Names of the per-vehicle observation features (7 entries, matching in_features=7 of the embedding layers).
# NOTE(review): the order is assumed to match the observation columns set in config/env.json - confirm.
features = ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"]
# Human-readable name for each discrete action index.
actions_map = {
    0: "SLOWER",
    1: "IDLE",
    2: "FASTER",
}

### Model analyzer
It does the following:
1. Encapsulates the agent
2. Registers forward hooks to capture activations of each linear layer and records scene images.
3. Runs evaluations and prints statistics.

In [37]:
from collections import defaultdict
from PIL import Image

# A second agent built from the same config; no saved weights are loaded into it,
# so it serves as the "untrained" baseline in the comparisons below.
untrained_agent = load_agent(agent_config, env)

def render_images(images):
    """Show each image array through PIL's ``Image.show()``.

    Args:
        images: Iterable of arrays accepted by ``Image.fromarray``.
    """
    for frame in images:
        picture = Image.fromarray(frame)
        picture.show()

class ModelAnalyzer:
    """Steps an agent through the notebook-level ``env`` and records what happens.

    Per episode it collects the observed states, chosen actions, rendered
    frames, attention matrices and, via forward hooks, the output of every
    ``torch.nn.Linear`` module inside ``agent.value_net``.
    """

    def __init__(self, agent: DQNAgent, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.agent = agent
        self.reset()

    def reset(self):
        """Remove old hooks, reset the environment and clear all recordings."""
        # On the first call there is no ``hooks`` attribute yet, so nothing is removed.
        for stale_hook in getattr(self, 'hooks', []):
            stale_hook.remove()

        initial_state, _ = env.reset()
        self.states = [initial_state]
        self.actions = []
        self.activations = defaultdict(list)
        self.hooks = []
        self.done = False
        self.attention_matrix = []
        self.images = []
        self.total_reward = 0

        self._register_hooks(self.agent.value_net)

    def step(self):
        """Let the agent act on the latest state and record the transition."""
        chosen_action = self.agent.act(self.states[-1])
        self.actions.append(chosen_action)
        next_state, reward, terminated, truncated, _ = env.step(chosen_action)
        self.total_reward += reward
        self.done = terminated or truncated
        # The frame is rendered after the action has been applied.
        self.images.append(env.render())
        self._record(next_state)

    def stats(self):
        """Print the total reward and how many items of each kind were recorded."""
        summary = (
            ("Total reward", self.total_reward),
            ("States", len(self.states)),
            ("Actions", len(self.actions)),
            ("Attention matrices", len(self.attention_matrix)),
            ("images", len(self.images)),
        )
        for label, value in summary:
            print(f"{label}: {value}")

        for layer_name, recorded in self.activations.items():
            print(f"Activations for {layer_name}: {len(recorded)}")

    def _record(self, obs):
        """Store the new observation and the attention matrix computed for it."""
        self.states.append(obs)
        # Add a leading batch dimension of size 1.
        batched_obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        # NOTE(review): this call presumably runs part of the hooked network, so it
        # would add entries to ``self.activations`` too (the embedding and attention
        # layers report more activations than the output layers) - confirm.
        attention = self.agent.value_net.get_attention_matrix(batched_obs)
        self.attention_matrix.append(attention.detach())

    def _register_hooks(self, module: torch.nn.Module):
        """Attach a recording forward hook to every ``torch.nn.Linear`` in ``module``."""
        self.hooks.extend(
            linear.register_forward_hook(self._get_activation(layer_name))
            for layer_name, linear in module.named_modules()
            if isinstance(linear, torch.nn.Linear)
        )

    def _get_activation(self, name):
        """Build a forward hook that appends the detached output under ``name``."""
        def hook(module, args, output):
            self.activations[name].append(output.detach())

        return hook



Print the agent's model architecture

In [38]:
# Print the module tree of the agent's Q-value network.
print(evaluation.agent.value_net)

EgoAttentionNetwork(
  (ego_embedding): MultiLayerPerceptron(
    (layers): ModuleList(
      (0): Linear(in_features=7, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (others_embedding): MultiLayerPerceptron(
    (layers): ModuleList(
      (0): Linear(in_features=7, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (attention_layer): EgoAttention(
    (value_all): Linear(in_features=64, out_features=64, bias=False)
    (key_all): Linear(in_features=64, out_features=64, bias=False)
    (query_ego): Linear(in_features=64, out_features=64, bias=False)
    (attention_combine): Linear(in_features=64, out_features=64, bias=False)
  )
  (output_layer): MultiLayerPerceptron(
    (layers): ModuleList(
      (0-1): 2 x Linear(in_features=64, out_features=64, bias=True)
    )
    (predict): Linear(in_features=64, out_features=3, bias=True)
  )
)


### Run analyzer on trained agent and print stats

In [43]:
# Record one full episode played by the agent whose weights were loaded above.
analyzer = ModelAnalyzer(agent)

# Keep stepping until the environment reports termination or truncation.
while not analyzer.done:
    analyzer.step()

analyzer.stats()

Total reward: 7.475491573760404
States: 12
Actions: 11
Attention matrices: 11
images: 11
Activations for ego_embedding.layers.0: 363
Activations for ego_embedding.layers.1: 363
Activations for others_embedding.layers.0: 363
Activations for others_embedding.layers.1: 363
Activations for attention_layer.key_all: 363
Activations for attention_layer.value_all: 363
Activations for attention_layer.query_ego: 363
Activations for attention_layer.attention_combine: 363
Activations for output_layer.layers.0: 341
Activations for output_layer.layers.1: 341
Activations for output_layer.predict: 341


In [44]:
# Translate the recorded action indices into their readable names.
print("Actions sequence: ", [actions_map[a] for a in analyzer.actions])

Actions sequence:  ['FASTER', 'IDLE', 'FASTER', 'FASTER', 'SLOWER', 'SLOWER', 'FASTER', 'FASTER', 'FASTER', 'FASTER', 'IDLE']


## Interpretation of the agent
The process of interpretation of the agent model can be broken down into 2 major categories:
- Interpretation by Observation (TODO: add motivation and context here)
- Interpretation by Intervention (TODO: add motivation and context here)

### Observations:
- Attention layer QK and OV circuits
    - Add explanations and some theory [Attention heads](https://transformer-circuits.pub/2021/framework/index.html#:~:text=CONCEPTUAL%20TAKE%2DAWAYS)
    - Observations - Done
- Study feature importance using captum
    - Explain integrated gradients approach [Integrated gradients](https://arxiv.org/pdf/1703.01365)
    - Observations - Done
    - Explain neuron conductance approach: [Neuron conductance](https://arxiv.org/pdf/1805.12233)
    - Observations - TODO
- Plot neuron activations
    - Motivation and reasoning for viewing individual neurons behaviour
    - Observations - In progress
- Semantic dictionary
    - Explain what this means and why we need this
    - Observations - TODO
- Feature optimization
    - Explainer
    - Observations - TODO
    - Render a scene with given input observations
- Logit lens on attention mechanism
    - Explainer
- Circuit analysis on actions - TODO
- Tensor-to-tensor visualisation - TODO

### Intervention:
- Ablation - TODO
- Model Editing - TODO
- Path tracing - TODO

### Explainable RL (not in scope)
- Discovering symbolic policies on Deep RL networks. [Source](https://proceedings.mlr.press/v139/landajuela21a/landajuela21a.pdf)

In [14]:
import torch.nn.functional as F
def cosine_similarity(a: torch.Tensor, b: torch.Tensor):
    """Return the cosine similarity of ``a`` and ``b`` as a numpy array.

    The similarity is taken along ``F.cosine_similarity``'s default dimension.
    """
    similarity = F.cosine_similarity(a, b)
    return similarity.numpy()

def compute_square_diff(a: torch.Tensor, b: torch.Tensor):
    """Return the element-wise squared difference ``(a - b)**2`` as a numpy array."""
    delta = a - b
    return torch.pow(delta, 2).numpy()

def compute_diff(a: torch.Tensor, b: torch.Tensor):
    """Return the element-wise absolute difference ``|a - b|`` as a numpy array."""
    delta = a - b
    return torch.abs(delta).numpy()

### QK and OV Circuits

In [45]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def analyze_qk_circuit(value_net):
    """Return the row-wise softmax of the per-head product ``W_Q^T @ W_K``.

    Args:
        value_net: Network exposing ``attention_layer`` with a ``config`` dict
            (keys ``"heads"`` and ``"feature_size"``) and the linear layers
            ``query_ego`` and ``key_all``.

    Returns:
        Detached tensor of shape ``(heads, features_per_head, features_per_head)``.
    """
    attention = value_net.attention_layer
    heads = attention.config["heads"]
    features_per_head = attention.config["feature_size"] // heads
    # NOTE(review): this view only matches a (feature_size, feature_size) weight
    # when heads == 1; with more heads the element counts differ and view() fails.
    block_shape = (heads, features_per_head, features_per_head)

    query_weights = attention.query_ego.weight.detach().view(*block_shape).transpose(-2, -1)
    key_weights = attention.key_all.weight.detach().view(*block_shape)

    qk_product = torch.matmul(query_weights, key_weights)
    return F.softmax(qk_product, dim=-1).detach()

def analyze_ov_circuit(value_net):
    """Return the per-head product ``W_O @ W_V`` of the attention layer.

    Args:
        value_net: Network exposing ``attention_layer`` with a ``config`` dict
            (keys ``"heads"`` and ``"feature_size"``) and the linear layers
            ``value_all`` and ``attention_combine``.

    Returns:
        Detached tensor of shape ``(heads, features_per_head, features_per_head)``.
    """
    attention = value_net.attention_layer
    heads = attention.config["heads"]
    features_per_head = attention.config["feature_size"] // heads
    # NOTE(review): as in the QK analysis, this view only fits a square
    # (feature_size, feature_size) weight when heads == 1.
    block_shape = (heads, features_per_head, features_per_head)

    value_weights = attention.value_all.weight.detach().view(*block_shape)
    output_weights = attention.attention_combine.weight.detach().view(*block_shape)

    return torch.matmul(output_weights, value_weights).detach()

def analyze_embedding_weights(embedding):
    """Return the linear map obtained by composing the first two embedding layers.

    ``nn.Linear`` stores its weight as ``(out_features, in_features)`` and
    computes ``x @ W.T``. Ignoring biases and any non-linearity between the
    layers, two stacked layers therefore map ``x`` to ``x @ W0.T @ W1.T``.
    The earlier ``W0.T @ W1`` paired layer-0 output units with layer-1
    *output* units and only ran because layer 1 is square.

    Args:
        embedding: Module exposing ``layers[0]`` and ``layers[1]`` as linear layers.

    Returns:
        Detached tensor of shape ``(in_features, out_features)``: row ``f`` holds
        the contribution of input feature ``f`` to each layer-1 output unit.
    """
    first_weight = embedding.layers[0].weight.detach()   # (hidden, in_features)
    second_weight = embedding.layers[1].weight.detach()  # (out_features, hidden)

    return torch.matmul(first_weight.T, second_weight.T).detach()

def analyze_embedding_biases(embedding):
    """Return the detached bias vectors of the first two embedding layers.

    Returns:
        A two-element list ``[bias_of_layer_0, bias_of_layer_1]``.
    """
    return [embedding.layers[index].bias.detach() for index in (0, 1)]

def _plot_circuit_comparison(untrained_circuit, trained_circuit, circuit_name):
    """Plot untrained / trained / absolute-difference heatmaps, one figure per head.

    Args:
        untrained_circuit: Tensor whose first dimension indexes attention heads.
        trained_circuit: Tensor of the same shape for the trained agent.
        circuit_name: Short label ("QK" or "OV") used in the figure title.
    """
    heads = untrained_circuit.size(0)
    for h in range(heads):
        fig = make_subplots(rows=1, cols=3, subplot_titles=("Untrained", "Trained", "Difference"))
        fig.add_trace(go.Heatmap(z=untrained_circuit[h].numpy(), coloraxis="coloraxis"), row=1, col=1)
        fig.add_trace(go.Heatmap(z=trained_circuit[h].numpy(), coloraxis="coloraxis"), row=1, col=2)
        fig.add_trace(go.Heatmap(z=compute_diff(trained_circuit[h], untrained_circuit[h]), coloraxis="coloraxis"), row=1, col=3)
        # All three panels share one colour axis, so the colours are directly comparable.
        fig.update_layout(title_text=f"Attention head: {h+1} {circuit_name} Circuits", coloraxis=dict(colorscale="RdBu"))
        fig.show()


def compare_qk_circuit(untrained_agent, trained_agent):
    """Plot the QK circuit of the untrained vs. trained agent for every head."""
    _plot_circuit_comparison(
        analyze_qk_circuit(untrained_agent.value_net),
        analyze_qk_circuit(trained_agent.value_net),
        "QK",
    )


def compare_ov_circuit(untrained_agent, trained_agent):
    """Plot the OV circuit of the untrained vs. trained agent for every head."""
    _plot_circuit_comparison(
        analyze_ov_circuit(untrained_agent.value_net),
        analyze_ov_circuit(trained_agent.value_net),
        "OV",
    )
    

# Compare the attention weights of the untrained baseline with those of the trained agent.
compare_qk_circuit(untrained_agent, analyzer.agent)
compare_ov_circuit(untrained_agent, analyzer.agent)

The comparisons above show the QK and OV circuits of the untrained and trained agents, together with the absolute difference between them.
A quick look reveals that indeed the agent learned some structure/pattern of taking actions which helps in avoiding crashes with other vehicles. 

Observations of structure:
- Rectangular blocks of slightly higher activations can be seen. It's worth studying whether these gaps change if the environment settings change.
- It is also visible that the weight patterns alternate between high and low activations. Why ?
- QK circuit reveals that the Neurons 17, 31 and 47 have high attention weights assigned, qualifying them as candidates for further analysis.
- OV circuit reveals a clear activation for neuron 31 again with alternating patterns.

> TODO: understand these patterns and try to make some conclusions about them

## Feature importance

### Study average feature importance using Integrated gradients approach

In [None]:
from captum.attr import IntegratedGradients
import copy 

def get_feature_importance(net, target, states=None):
    """Average Integrated-Gradients attribution of each input feature.

    Args:
        net: Network to attribute; called by captum with the stacked states.
        target: Index of the output (action) whose value is attributed.
        states: Optional sequence of observations. Defaults to the states
            recorded by the notebook-level ``analyzer``.

    Returns:
        numpy array with the attributions averaged over the first two axes.
        NOTE(review): those axes are assumed to be (time step, vehicle), which
        leaves one value per feature - confirm against the observation shape.
    """
    if states is None:
        states = analyzer.states
    # Stack into one ndarray first: building a tensor from a list of ndarrays is
    # very slow and triggers a PyTorch UserWarning.
    inputs = torch.tensor(np.asarray(states), dtype=torch.float32).requires_grad_()
    # The convergence delta was requested but never used, so it is not computed.
    attributions = IntegratedGradients(net).attribute(inputs, target=target)

    return np.mean(attributions.detach().numpy(), axis=(0, 1))

def visualize_importances(untrained_value_net, trained_value_net):
    """Bar-plot feature importances, untrained (left) vs. trained (right), one row per action.

    Rows follow the iteration order of ``actions_map``, so the figure always has
    exactly one populated row per action and each row is labelled with its action.

    Args:
        untrained_value_net: Baseline network passed to ``get_feature_importance``.
        trained_value_net: Trained network passed to ``get_feature_importance``.
    """
    fig = make_subplots(rows=len(actions_map), cols=2, subplot_titles=("Untrained", "Trained"), shared_yaxes=True)

    for row, action in enumerate(actions_map, start=1):
        untrained_importances = get_feature_importance(untrained_value_net, action)
        trained_importances = get_feature_importance(trained_value_net, action)

        fig.add_trace(go.Bar(x=features, y=untrained_importances, name="Untrained"), row=row, col=1)
        fig.add_trace(go.Bar(x=features, y=trained_importances, name="Trained"), row=row, col=2)
        # Without a label the rows cannot be told apart.
        fig.update_yaxes(title_text=actions_map[action], row=row, col=1)

    fig.update_layout(title_text="Feature Importances")
    fig.show()

# Attribute on a deep copy of the trained value network rather than on the live one.
value_net_copy = copy.deepcopy(analyzer.agent.value_net)
visualize_importances(untrained_agent.value_net, value_net_copy)

Studying which features the model attends to on average shows that it does learn to track key features such as:
- Presence - Whether a vehicle is present on the road or not
- x - X coordinate of the vehicle
- y - Y coordinate of the vehicle
- cos_h - Cosine of the heading angle of the vehicle

The average importance of learned feature is compared with an untrained agent copy to verify the results.

Since the agent has 2 different embedding layers -
Ego embedding layer - responsible for tracking features of the driving agent in question
Others embedding layer - responsible for tracking features for Other vehicles 

It would be better to analyze features learnt individually in each of the embedding layers.

### Analyze Embedding matrices

Let's find what Embedding matrices represent in both Ego and Other embedding layers

In [None]:
def _plot_embedding_comparison(untrained_wt, trained_wt, title):
    """Plot untrained / trained / squared-difference heatmaps of an embedding matrix.

    Args:
        untrained_wt: Tensor returned by ``analyze_embedding_weights`` for the baseline.
        trained_wt: Tensor returned by ``analyze_embedding_weights`` for the trained agent.
        title: Figure title.
    """
    fig = make_subplots(rows=1, cols=3, subplot_titles=("Untrained", "Trained", "Difference"))
    # Convert to numpy so all three panels receive the same array type.
    fig.add_trace(go.Heatmap(z=untrained_wt.numpy(), coloraxis="coloraxis", y=features), row=1, col=1)
    fig.add_trace(go.Heatmap(z=trained_wt.numpy(), coloraxis="coloraxis", y=features), row=1, col=2)
    fig.add_trace(go.Heatmap(z=compute_square_diff(trained_wt, untrained_wt), coloraxis="coloraxis", y=features), row=1, col=3)
    fig.update_layout(
        title_text=title,
        coloraxis=dict(colorscale="RdBu"),
    )
    fig.show()


def compare_ego_embeddings_matrix(untrained_agent, trained_agent):
    """Compare the ego-embedding matrices of the untrained and trained agents."""
    _plot_embedding_comparison(
        analyze_embedding_weights(untrained_agent.value_net.ego_embedding),
        analyze_embedding_weights(trained_agent.value_net.ego_embedding),
        "Ego Embeddings Matrix",
    )


def compare_others_embeddings_matrix(untrained_agent, trained_agent):
    """Compare the others-embedding matrices of the untrained and trained agents."""
    # Rows are labelled with the notebook-level ``features`` list; the local copy
    # that shadowed it has been removed.
    _plot_embedding_comparison(
        analyze_embedding_weights(untrained_agent.value_net.others_embedding),
        analyze_embedding_weights(trained_agent.value_net.others_embedding),
        "Others Embeddings Matrix",
    )

# Plot the embedding matrices of the untrained baseline next to those of the trained agent.
compare_ego_embeddings_matrix(untrained_agent, analyzer.agent)
compare_others_embeddings_matrix(untrained_agent, analyzer.agent)

# TODO: Refine the observations collected here
The graphs above show that the agent's ego embedding layer (responsible for tracking observations of itself) has learnt to assign more weight to the features `y`, `vx` and `sin_h`.
On the other hand, the agent's others embedding layer (responsible for tracking observations of other vehicles) has learnt to assign more weight to the features `x`, `y`, `vy` and `sin_h`.

Qualitatively, that seems reasonable. In order for the agent to make a successful turn at the intersection, the agent needs to track following key information:
- Its own vertical position. Most likely the agent has learnt to slow down as its `y` coordinate increases and it gets closer to the intersection, and conversely to speed up when it is far away from the intersection.
- Horizontal velocity of the other vehicles `vx`. This seems more important as agent learns to wait at the intersection when a vehicle is crossing the intersection at high speed. Thereby increasing its chances of avoiding collision.
- Other features like `presence` and `x` also show medium level weights values.

Studying feature importance using Layer conductance


In [81]:
# List the names of the layers whose activations were recorded.
for k in analyzer.activations.keys():
    print(k)

ego_embedding.layers.0
ego_embedding.layers.1
others_embedding.layers.0
others_embedding.layers.1
attention_layer.key_all
attention_layer.value_all
attention_layer.query_ego
attention_layer.attention_combine
output_layer.layers.0
output_layer.layers.1
output_layer.predict


### Layer conductance

In [118]:
from captum.attr import LayerConductance
import copy 

def get_layer_conductance(net, layer, target):
    """Mean conductance of every neuron in ``layer`` over the recorded states.

    Args:
        net: Network that contains ``layer``.
        layer: Module whose neurons are attributed.
        target: Index of the output (action) being explained.

    Returns:
        1-D numpy array: the conductance averaged over every axis except the
        last one (the neuron axis).
    """
    # Stack into one ndarray first: building a tensor from a list of ndarrays is
    # very slow and triggers a PyTorch UserWarning.
    states = torch.tensor(np.asarray(analyzer.states), dtype=torch.float32)
    cond_vals = LayerConductance(net, layer).attribute(states, target=target)
    leading_axes = tuple(range(cond_vals.dim() - 1))
    return np.mean(cond_vals.detach().numpy(), axis=leading_axes)


def analyze_layer_conductances(net, layer, layer_name):
    """Bar-plot the mean conductance of each neuron in ``layer``, one row per action.

    Rows follow the iteration order of ``actions_map``, so the figure always has
    exactly one populated row per action.

    Args:
        net: Network that contains ``layer``.
        layer: Module whose neurons are attributed.
        layer_name: Label used in the figure title.
    """
    fig = make_subplots(rows=len(actions_map), cols=1)

    for row, (action, action_name) in enumerate(actions_map.items(), start=1):
        cond_vals = get_layer_conductance(net, layer, action)
        fig.add_trace(go.Bar(y=cond_vals, name=f"Neuron activations for Action {action_name}"), row=row, col=1)

    fig.update_layout(title_text=f"Layer Conductances for layer: {layer_name}")
    fig.show()

# Attribute on a deep copy of the trained value network rather than on the live one.
value_net_copy = copy.deepcopy(analyzer.agent.value_net)

# Embedding layers of the ego vehicle.
analyze_layer_conductances(value_net_copy, value_net_copy.ego_embedding.layers[0], "Ego embedding layer 0")
analyze_layer_conductances(value_net_copy, value_net_copy.ego_embedding.layers[1], "Ego embedding layer 1")

# Embedding layers of the other vehicles.
analyze_layer_conductances(value_net_copy, value_net_copy.others_embedding.layers[0], "Others embedding layer 0")
analyze_layer_conductances(value_net_copy, value_net_copy.others_embedding.layers[1], "Others embedding layer 1")
# analyze_layer_conductances(value_net_copy, value_net_copy.output_layer.predict, "predict")

# Output MLP, including the final prediction layer.
analyze_layer_conductances(value_net_copy, value_net_copy.output_layer.layers[0], "Output layer 0")
analyze_layer_conductances(value_net_copy, value_net_copy.output_layer.layers[1], "Output layer 1")
analyze_layer_conductances(value_net_copy, value_net_copy.output_layer.predict, "predict")

Embedding layers - 
Studying the activation patterns of the embedding layers reveals that both the Ego and Others embedding layers learn to assign weight to the same set of neurons.
Their activation patterns are almost identical irrespective of the action taken by the model. This makes sense, as embedding layers can be thought of as sensory layers which do not encapsulate agent's decision making process but only as filters of the key important features in the environment.

Output layers - 
After studying the activation patterns of output layers, it shows that the activation patterns of these layers differ minutely on different actions taken. 


### Neuron conductance

In [136]:
from captum.attr import NeuronConductance

def get_neuron_conductance(net, layer, neuron, target):
    """Mean conductance of each input feature through one neuron of ``layer``.

    This helper used to be a second, four-argument ``get_layer_conductance``.
    That definition replaced the three-argument function that
    ``analyze_layer_conductances`` relies on, and it failed on every call
    because ``.detach()`` was invoked on an array already converted to numpy.

    Args:
        net: Network that contains ``layer``.
        layer: Module holding the neuron of interest.
        neuron: Neuron selector forwarded to captum's ``neuron_selector``.
            NOTE(review): captum expects a tuple selector when the layer output
            has more than two dimensions - confirm for the embedding layers.
        target: Index of the output (action) being explained.

    Returns:
        1-D numpy array: the attributions averaged over every axis except the last.
    """
    # Stack into one ndarray first: building a tensor from a list of ndarrays is
    # very slow and triggers a PyTorch UserWarning.
    states = torch.tensor(np.asarray(analyzer.states), dtype=torch.float32)
    states.requires_grad_()
    attributions = NeuronConductance(net, layer).attribute(states, neuron_selector=neuron, target=target)
    attributions = attributions.detach().numpy()
    leading_axes = tuple(range(attributions.ndim - 1))
    return np.mean(attributions, axis=leading_axes)


def analyze_neuron_conductances(net, layer, layer_name, neuron, target):
    """Bar-plot the per-feature conductance through one neuron for one target action.

    Args:
        net: Network that contains ``layer``.
        layer: Module holding the neuron of interest.
        layer_name: Label used in the figure title.
        neuron: Neuron selector, see ``get_neuron_conductance``.
        target: Index of the output (action) being explained.
    """
    cond_vals = get_neuron_conductance(net, layer, neuron, target)

    # Only one trace is drawn, so a single subplot row is enough.
    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(go.Bar(x=features, y=cond_vals), row=1, col=1)
    fig.update_layout(title_text=f"Neuron Conductances for layer: {layer_name} and neuron: {neuron}")
    fig.show()

# analyze_neuron_conductances(value_net_copy, value_net_copy.ego_embedding.layers[0], "Ego embedding layer 0", 0, 0)


In [138]:
def analyze_neuron(analyzer, layer_name, neurons_range):
    """Plot how the selected neurons of one layer evolve over the recorded forward passes.

    Args:
        analyzer: Object whose ``activations[layer_name]`` is a list of recorded
            layer outputs.
        layer_name: Key into ``analyzer.activations``.
        neurons_range: Indices of the neurons to plot.
    """
    # Take the neurons at position [0, 0] of every recorded output.
    # NOTE(review): this presumably selects the first batch element and first
    # vehicle/entity - confirm against the layer's output shape.
    neuron_activations = np.array([a[0,0,neurons_range].detach().numpy() for a in analyzer.activations[layer_name]])
    # Transpose to (neuron, recorded pass) so each row becomes one line in the plot.
    neuron_activations = neuron_activations.transpose()

    plot_neuron_activations(neuron_activations, f"Neuron activations for: {layer_name}", [f"neuron {i}" for i in neurons_range])

def plot_neuron_activations(neuron_activations, title, names):
    """Draw one mean-centred line per neuron in a single figure.

    Args:
        neuron_activations: Sequence of per-neuron activation arrays.
        title: Figure title.
        names: Trace names, indexed in step with ``neuron_activations``.
    """
    fig = go.Figure()
    for idx in range(len(neuron_activations)):
        series = neuron_activations[idx]
        # Subtract the mean so neurons with different offsets share one y-axis.
        centred = series - series.mean()
        fig.add_trace(go.Scatter(y=centred, mode='lines', name=names[idx]))
    fig.update_layout(title=title, xaxis_title="Time step", yaxis_title="Activation")
    fig.show()

# Plot all 64 neurons of each embedding layer (each layer has out_features=64).
analyze_neuron(analyzer, "ego_embedding.layers.0", range(64))
analyze_neuron(analyzer, "ego_embedding.layers.1", range(64))
analyze_neuron(analyzer, "others_embedding.layers.0", range(64))
analyze_neuron(analyzer, "others_embedding.layers.1", range(64))
# analyze_neuron(analyzer, "attention_layer.key_all", range(64))

In [None]:
from skimage.transform import resize

def normalize_and_convert(image):
    """Min-max scale ``image`` to the range [0, 255] and cast it to ``uint8``.

    Args:
        image: Numeric array (any shape) with at least one element.

    Returns:
        ``uint8`` array of the same shape. A constant image (max == min) carries
        no contrast to stretch and is returned as all zeros, instead of dividing
        by zero and casting NaN to ``uint8``.
    """
    lowest = np.min(image)
    value_range = np.max(image) - lowest
    if value_range == 0:
        return np.zeros(np.shape(image), dtype=np.uint8)
    # Normalize to [0, 1], then scale to [0, 255].
    scaled = (image - lowest) / value_range
    return (scaled * 255).astype(np.uint8)

def plot_actions(analyzer):
    """Show, for every recorded step, the rendered scene beside its attention matrix.

    Args:
        analyzer: Object exposing the parallel lists ``images``, ``actions``
            and ``attention_matrix``.
    """
    for step_index, frame in enumerate(analyzer.images):
        fig = make_subplots(rows=1, cols=2)

        # Left panel: the frame resized to 300x600 and converted to uint8.
        scene = normalize_and_convert(resize(frame, (300, 600), anti_aliasing=True))
        fig.add_trace(go.Image(z=scene), row=1, col=1)

        # Right panel: attention matrix with dimensions 0 and 1 squeezed away.
        attention = analyzer.attention_matrix[step_index].squeeze(0).squeeze(1)
        fig.add_trace(go.Heatmap(z=attention, coloraxis="coloraxis"), row=1, col=2)

        action_name = actions_map[analyzer.actions[step_index]]
        fig.update_layout(coloraxis=dict(colorscale='RdBu'), title_text=f"Step: {step_index+1} Action: {action_name}")
        fig.show()

# One figure per recorded step of the episode.
plot_actions(analyzer)

### MLP Layer analysis

In [None]:
def analyze_mlp_hidden_weights(agent):
    """Return the linear map obtained by composing the two hidden output-MLP layers.

    ``nn.Linear`` stores its weight as ``(out_features, in_features)`` and
    computes ``x @ W.T``. Ignoring biases and any non-linearity, the two layers
    together map ``x`` to ``x @ W0.T @ W1.T``. The earlier ``W0.T @ W1`` paired
    layer-0 output units with layer-1 *output* units, which is not that
    composition.

    Args:
        agent: Agent exposing ``value_net.output_layer.layers[0]`` and ``[1]``.

    Returns:
        Tensor of shape ``(in_features of layer 0, out_features of layer 1)``.
    """
    output_layer = agent.value_net.output_layer
    first_weight = output_layer.layers[0].weight.detach()   # (hidden, in_features)
    second_weight = output_layer.layers[1].weight.detach()  # (out_features, hidden)

    return torch.matmul(first_weight.T, second_weight.T)

def compare_mlp_layer(untrained_agent, trained_agent):
    """Plot the output-MLP weight product: untrained, trained and their absolute difference.

    Args:
        untrained_agent: Baseline agent.
        trained_agent: Trained agent.
    """
    untrained_mlp_layer = analyze_mlp_hidden_weights(untrained_agent)
    trained_mlp_layer = analyze_mlp_hidden_weights(trained_agent)

    fig = make_subplots(rows=1, cols=3, subplot_titles=("Untrained", "Trained", "Difference"))
    # Panels in left-to-right column order; all share one colour axis.
    panels = (
        untrained_mlp_layer,
        trained_mlp_layer,
        compute_diff(trained_mlp_layer, untrained_mlp_layer),
    )
    for column, values in enumerate(panels, start=1):
        fig.add_trace(go.Heatmap(z=values, coloraxis="coloraxis"), row=1, col=column)
    fig.update_layout(title_text="MLP Layer", coloraxis=dict(colorscale="RdBu"))
    fig.show()

# Compare the output MLP of the untrained baseline with that of the trained agent.
compare_mlp_layer(untrained_agent, analyzer.agent)

### Perturbing activations

In [None]:
# Additive perturbation helper
def perturb_activations(activations, perturbation):
    """Return ``activations`` shifted by ``perturbation`` (plain addition)."""
    shifted = activations + perturbation
    return shifted

def analyze_final_layer_perturbations(analyzer, perturbation=0.1):
    """Print how much the ``predict`` layer's output moves when its input is shifted.

    For every recorded output of ``output_layer.layers.1`` the value is fed to
    ``output_layer.predict`` once unchanged and once shifted by
    ``perturbation``; the L2 norm of the output difference is printed.

    NOTE(review): ``predict`` is a ``Linear`` layer, so the difference equals
    ``W @ perturbation`` and is identical for every step, whatever the input.
    A perturbation applied before a non-linearity would be more informative.
    NOTE(review): the recorded values are the raw outputs of the linear layer;
    any activation function the MLP applies before ``predict`` is skipped
    here - confirm.

    Args:
        analyzer: Object exposing ``activations`` and ``agent.value_net``.
        perturbation: Value added to every input element (default ``0.1``,
            the previously hard-coded constant).
    """
    predict = analyzer.agent.value_net.output_layer.predict
    recorded = analyzer.activations["output_layer.layers.1"]

    # No gradients are needed for this read-only probe.
    with torch.no_grad():
        for i, layer_output in enumerate(recorded):
            # ``layer_output`` avoids shadowing the builtin ``input``.
            original_out = predict(layer_output)
            perturbed_out = predict(perturb_activations(layer_output, perturbation))

            diff = torch.norm(original_out - perturbed_out, p=2).item()
            print(f"Step: {i}, Perturbation: {diff}")

# Probe the final prediction layer with the recorded activations of the episode.
analyze_final_layer_perturbations(analyzer)

### Logit Lens

In [83]:
### TODO apply logit lens technique here ###

## SHAP technique

In [None]:
### TODO apply shapley values here ###

### Activation analysis

In [None]:
def attn_scores(query: torch.Tensor, key: torch.Tensor):
    """Scaled dot-product attention weights ``softmax(Q K^T / sqrt(d_k))``.

    Args:
        query: Tensor whose last dimension has size ``d_k``.
        key: Tensor whose last two dimensions are transposed before the product.

    Returns:
        Tensor of attention weights, normalised over the last dimension.
    """
    scale = np.sqrt(query.size(-1))
    raw_scores = torch.matmul(query, key.transpose(-2, -1))
    return F.softmax(raw_scores / scale, dim=-1)

```
Total reward: 10.0
States: 11
Actions: 10
Attention matrices: 10
images: 10
Activations for ego_embedding.layers.0: 330
Activations for ego_embedding.layers.1: 330
Activations for others_embedding.layers.0: 330
Activations for others_embedding.layers.1: 330
Activations for attention_layer.key_all: 330
Activations for attention_layer.value_all: 330
Activations for attention_layer.query_ego: 330
Activations for attention_layer.attention_combine: 330
Activations for output_layer.layers.0: 310
Activations for output_layer.layers.1: 310
Activations for output_layer.predict: 310
```

In [None]:
def analyze_activations(analyzer, start=150, stop=160):
    """Plot attention scores rebuilt from recorded query/key activations beside the scene.

    Args:
        analyzer: Object exposing ``activations`` and ``images``.
        start: Index of the first recorded activation to plot (default 150).
        stop: One past the last index to plot (default 160). It is clamped to
            the number of recorded activations, so a short episode plots fewer
            figures instead of raising ``IndexError``.
    """
    activations = analyzer.activations

    q_acts = activations['attention_layer.query_ego']
    k_acts = activations['attention_layer.key_all']

    if not q_acts or not analyzer.images:
        print("No recorded activations or images to plot")
        return

    # Several forward passes are recorded per environment step (the stats output
    # shows 363 activations for 11 frames), so derive the ratio from the data
    # instead of hard-coding 30.
    # NOTE(review): assumes every step records the same number of passes, and
    # shows the frame rendered at the end of that step - confirm the alignment.
    acts_per_step = max(len(q_acts) // len(analyzer.images), 1)
    last_image = len(analyzer.images) - 1

    for i in range(start, min(stop, len(q_acts), len(k_acts))):
        fig = make_subplots(rows=1, cols=2, subplot_titles=("Attn scores", "Scene"))
        # Repeat the single ego query once per key so both operands have matching lengths.
        scores = attn_scores(q_acts[i].repeat(1, k_acts[i].size(1), 1), k_acts[i]).squeeze(0).numpy()
        fig.add_trace(go.Heatmap(z=scores, coloraxis="coloraxis"), row=1, col=1)
        fig.add_trace(go.Image(z=analyzer.images[min(i // acts_per_step, last_image)]), row=1, col=2)
        fig.update_layout(title_text=f"Step: {i}", coloraxis=dict(colorscale="RdBu"))
        fig.show()


def analyze_actions(analyzer):
    """Show the scene and the ``predict`` layer output for every SLOWER / IDLE action.

    Args:
        analyzer: Object exposing ``actions``, ``images`` and ``activations``.
    """
    predict_acts = analyzer.activations["output_layer.predict"]
    critical_actions = [
        0, # slower
        1, # idle
    ]
    print(analyzer.actions)

    critical_steps = []
    for step, action in enumerate(analyzer.actions):
        if action in critical_actions:
            # Frames are rendered after each action, so index ``step - 1`` is the
            # scene the agent saw before acting. Clamp at 0: for the first action
            # ``-1`` would wrap around to the last frame of the episode.
            critical_steps.append((max(step - 1, 0), action))

    for step, action in critical_steps:
        px.imshow(analyzer.images[step], title=f"Step: {step}").show()
        # NOTE(review): ``predict_acts`` holds several entries per environment
        # step (the stats output shows 341 for 11 steps), so indexing it with a
        # step number probably picks the wrong forward pass - confirm and realign.
        px.imshow(predict_acts[step], title=f"Predicted actions: {step}").show()

# Plot attention scores rebuilt from the recorded query/key activations next to the scene.
analyze_activations(analyzer)
# analyze_actions(analyzer)