In [None]:
import glob
import os
import pickle
import random
import warnings
from typing import List, Dict, Tuple, Callable

import numpy as np
import plotly.graph_objects as go
import torch
from numpy.linalg import norm
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Suppress all warnings so the rendered notebook output stays clean.
# NOTE(review): this also hides numerical warnings such as a zero-norm division in
# cosine_similarity; consider narrowing it to specific warning categories.
warnings.filterwarnings('ignore')

# Layers (1-based) whose hidden-state histograms are plotted individually; each of them
# has a fixed colour in COLORS.
layer_of_interest = [1, 2, 3, 5, 10, 15, 18]

# Fixed RGB colour per 1-based layer number, used for per-layer traces.
COLORS = {
    1: (255, 127, 14), # orange #FF7F0E
    3: (44, 160, 44), # green #2CA02C
    18: (214, 39, 40), # red #D62728
    2: (148, 103, 189), # purple #9467BD
    5: (31, 119, 180), # blue #1F77B4
    10: (188, 189, 34), # lime green #BCBD22
    15: (23, 190, 207), # light blue #17BECF
    4: (140, 86, 75), # brown #8C564B
    6: (227, 119, 194), # pink #E377C2
    7: (127, 127, 127), # grey #7F7F7F
}

# Same keys as COLORS; layers 1, 3 and 18 use darker shades, all other entries are identical.
COLORS_DARKER = {
    1: (229, 118, 20), # dark orange #E57614
    3: (31, 138, 44), # dark green #1F8A2C
    18: (176, 32, 32), # dark red #B02020
    2: (148, 103, 189), # purple #9467BD
    5: (31, 119, 180), # blue #1F77B4
    10: (188, 189, 34), # lime green #BCBD22
    15: (23, 190, 207), # light blue #17BECF
    4: (140, 86, 75), # brown #8C564B
    6: (227, 119, 194), # pink #E377C2
    7: (127, 127, 127), # grey #7F7F7F
}

# 18 colours indexed by 0-based layer index (box plots, dimensionality-reduction markers).
RAINBOW_COLORS = [
    "rgb(244, 67, 54)",
    "rgb(232, 30, 99)",
    "rgb(192, 35, 120)",
    "rgb(169, 37, 150)",
    "rgb(156, 39, 176)",
    "rgb(103, 58, 183)",
    "rgb(63, 81, 181)",
    "rgb(33, 150, 243)",
    "rgb(3, 169, 244)",
    "rgb(0, 188, 212)",
    "rgb(0, 150, 136)",
    "rgb(76, 175, 80)",
    "rgb(139, 195, 74)",
    "rgb(205, 220, 57)",
    "rgb(255, 235, 59)",
    "rgb(255, 193, 7)",
    "rgb(255, 152, 0)",
    "rgb(255, 87, 34)"
]

# Font sizes shared by all figures in this notebook.
FONT_SIZE_TITLE = 24
FONT_SIZE_AXIS = 22
FONT_SIZE_TICKS = 20
FONT_SIZE_LEGEND = 20

# Analyze Causes
Why are the results as they are? In this section we try to answer this question by analyzing the hidden states and attention patterns of the TrXL in different layers. To this end, a trained SHELM model was used to sample several example episodes of the Memory and PsychLab environments. During sampling, the hidden states, attention distributions, and rewards were saved and later visualized and analyzed.

In [None]:
# Seed every random number generator used in this notebook.
np.random.seed(101)
# torch.manual_seed seeds the CPU generator and all CUDA devices, whereas
# torch.cuda.manual_seed only seeds the current GPU.
torch.manual_seed(101)
random.seed(101)

observations_path = os.path.join("data", "observations")


def _sorted_observations(environment: str, episode: str) -> List[str]:
    """Return the episode's PNG frames sorted numerically by the ``step-<n>-`` part of their path."""
    pattern = os.path.join(observations_path, environment, episode, "*.png")
    # Numeric sort so that step-10 comes after step-9 (a plain string sort would not).
    return sorted(
        glob.glob(pattern),
        key=lambda path: int(path.split('step-')[-1].split('-')[0])
    )


if not os.path.exists(observations_path):
    raise FileNotFoundError(
        f"Observations not found at '{observations_path}'. "
        "Ensure the observations are downloaded with the script in `data/download_results.sh`"
    )

observations_psych = _sorted_observations("Continuous-Recognition", "ep1")
observations_mem = _sorted_observations("Memory", "ep1")
observations_psych2 = _sorted_observations("Continuous-Recognition", "ep2")
observations_mem2 = _sorted_observations("Memory", "ep2")
observations_psych3 = _sorted_observations("Continuous-Recognition", "ep3")
observations_mem3 = _sorted_observations("Memory", "ep3")

In [None]:
def cosine_similarity(a, b):
    """
    Return the cosine of the angle between vectors ``a`` and ``b``.

    There is no guard for zero-length vectors: a zero norm leads to a division by
    zero (``nan``/``inf``, with a numpy RuntimeWarning).
    """
    magnitude_product = norm(a) * norm(b)
    return np.dot(a, b) / magnitude_product

def load_model_output(env_v):
    """
    Load the model output data from a pickle file.

    Args:
        env_v (str): The environment version suffix (e.g. "mem", "mem2", "psy3"); the file read is
            ``data/investigation-data/investigation-data-{env_v}.pkl``.

    Returns:
        Tuple[Dict[int, List[float]], Dict[int, List[float]], List[np.ndarray], Dict[int, List[float]], List[str]]:
            A tuple containing the following data:
            - cosine_distances (Dict[int, List[float]]): The cosine distances for all evaluated layers.
            - l2_distances (Dict[int, List[float]]): The L2 distances for all evaluated layers.
            - hidden_states (List[np.ndarray]): The hidden states for each timestep.
            - attentions (Dict[int, List[float]]): The attention weights for all evaluated layers.
            - tokens (List[str]): The tokens for the sampled episode.

    Raises:
        FileNotFoundError: If the data file does not exist; the message names the missing path.
        KeyError: If the pickled dict lacks one of the expected keys.

    Note:
        Ensure the data is downloaded with the script in `data/download_results.sh`.
        ``pickle.load`` can execute arbitrary code, so only load files from a trusted source.
    """
    data_path = os.path.join("data", "investigation-data", f"investigation-data-{env_v}.pkl")
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"Model output not found at '{data_path}'. "
            "Ensure the data is downloaded with the script in `data/download_results.sh`"
        )

    # SECURITY: unpickling runs arbitrary code from the file; only load trusted downloads.
    with open(data_path, "rb") as f:
        data = pickle.load(f)

    return (
        data["cosine_distances"],
        data["l2_distances"],
        data["hidden_states"],
        data["attention"],  # stored under the singular key "attention"
        data["tokens"],
    )

In [None]:
(cosine_distances_mem, l2_distances_mem, hidden_states_mem,
 attention_mem, tokens_mem) = load_model_output(env_v="mem")

In [None]:
(cosine_distances_mem2, l2_distances_mem2, hidden_states_mem2,
 attention_mem2, tokens_mem2) = load_model_output(env_v="mem2")

In [None]:
(cosine_distances_mem3, l2_distances_mem3, hidden_states_mem3,
 attention_mem3, tokens_mem3) = load_model_output(env_v="mem3")

In [None]:
(cosine_distances_psy, l2_distances_psy, hidden_states_psy,
 attention_psy, tokens_psy) = load_model_output(env_v="psy")

In [None]:
(cosine_distances_psy2, l2_distances_psy2, hidden_states_psy2,
 attention_psy2, tokens_psy2) = load_model_output(env_v="psy2")

In [None]:
(cosine_distances_psy3, l2_distances_psy3, hidden_states_psy3,
 attention_psy3, tokens_psy3) = load_model_output(env_v="psy3")

## Hidden States

### Raw Hidden States
The raw hidden states can be checked visually by plotting them as heatmaps. Because there is a hidden state for every layer at every timestep, the visualizations can be grouped in two ways:
- by timestep - the hidden states of all layers at one timestep are shown next to each other
- by layer - the hidden states of one layer over all timesteps are shown next to each other

Furthermore, we include histograms showing the distribution of the values in the hidden states of selected layers at every timestep, as well as box plots covering all layers.

In [None]:
def plot_hidden_states_slider(
        hidden_states: List[np.ndarray],
        group_by: str = "layer"
) -> go.Figure:
    """
    Plot hidden states as heatmaps with a slider, grouped either by 'layer' or 'timestep'.

    Parameters:
    - hidden_states (List[np.ndarray]): One array per timestep. The list is stacked with
      ``np.array`` and indexed as ``[timestep, layer, hidden_unit]``, so every element must
      have the same 2D shape ``(n_layers, hidden_dim)``.
    - group_by (str, optional): "layer" shows one layer per slider step with timesteps on the
      x-axis; "timestep" shows one timestep per slider step with layers on the x-axis.
      Defaults to "layer".

    Returns:
    - go.Figure: A Plotly figure object displaying the hidden states with a slider.

    Raises:
    - ValueError: If ``group_by`` is not "layer"/"timestep" or ``hidden_states`` is empty.
    """
    if group_by not in ("layer", "timestep"):
        raise ValueError(f"group_by must be 'layer' or 'timestep', got {group_by!r}")
    if len(hidden_states) == 0:
        raise ValueError("hidden_states must not be empty")

    data = np.array(hidden_states)
    # Derive the number of layers from the data instead of assuming 18.
    n_steps, n_layers = data.shape[0], data.shape[1]

    if group_by == "layer":
        titles = [f"Hidden States of Layer {i}" for i in range(1, n_layers + 1)]
        num_cols = n_layers
        x_title = "Timesteps"
        x = list(range(1, n_steps + 1))
    else:
        titles = [f"Hidden States at Timestep {i}" for i in range(1, n_steps + 1)]
        num_cols = n_steps
        x_title = "Layers"
        x = list(range(1, n_layers + 1))

    fig = go.Figure()
    for col in range(num_cols):
        # One layer across all timesteps, or all layers at one timestep; transposed so
        # that the hidden units run along the y-axis.
        z = data[:, col, :].T if group_by == "layer" else data[col, :, :].T
        fig.add_trace(
            go.Heatmap(
                visible=False,
                z=z,
                zmin=-2,
                zmax=2,
                x=x,
                xgap=1,
                colorscale='Inferno',
                colorbar_tickfont_size=FONT_SIZE_LEGEND
            )
        )

    # Show the first heatmap; the slider toggles the others.
    fig.data[0].visible = True

    steps = []
    for i in range(num_cols):
        visible = [False] * len(fig.data)
        visible[i] = True
        steps.append(dict(
            method="update",
            label=f"{group_by.capitalize()} {i+1}",
            # Only update the title text so the font/position set below are kept.
            args=[{"visible": visible},
                  {"title.text": titles[i]}],
        ))

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        sliders=sliders,
        title=dict(
            # Title of the initially active slider step.
            text=titles[0],
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.989,
        ),
        xaxis=dict(
            title=x_title,
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        height=700,
        width=900,
        margin=dict(
            l=10,
            r=20,
            b=20,
            t=40,
            pad=10
        ),
    )

    return fig

In [None]:
def plot_hidden_states_distribution(
        hidden_states: List[np.ndarray],
        title: str
) -> go.Figure:
    """
    Overlay histograms of the hidden-state values of the layers in ``layer_of_interest``,
    one group of histograms per timestep, with a slider to switch between timesteps.

    Parameters:
    - hidden_states (List[np.ndarray]): One entry per timestep; ``hidden_states[t][n-1]`` is
      used as the hidden-state values of (1-based) layer ``n``.
    - title (str): The title of the plot.

    Returns:
    - go.Figure: A Plotly figure whose slider shows the histograms of one timestep at a time.
    """
    traces_per_step = len(layer_of_interest)
    fig = go.Figure()

    # Build every histogram up front in timestep-major order, all hidden by default.
    for states_at_t in hidden_states:
        for layer in layer_of_interest:
            histogram = go.Histogram(
                visible=False,
                x=states_at_t[layer - 1],  # layer numbers are 1-based
                name=f"Layer {layer}",
                marker_color=f"rgb{COLORS[layer]}"
            )
            fig.add_trace(histogram)

    # Initially show the histograms of the first timestep.
    for trace_idx in range(traces_per_step):
        fig.data[trace_idx].visible = True

    total_traces = len(fig.data)
    steps = []
    for t in range(len(hidden_states)):
        first = t * traces_per_step
        visibility = [first <= k < first + traces_per_step for k in range(total_traces)]
        steps.append(
            dict(
                method="update",
                label=f"Timestep {t+1}",
                args=[{"visible": visibility},
                      {"legend_traceorder": "normal"}],  # layout attribute
            )
        )

    slider = dict(active=0, pad={"t": 50}, steps=steps)

    fig.update_layout(
        sliders=[slider],
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.989,
        ),
        xaxis=dict(
            range=[-2., 2.],
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.15,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(0,0,0,0)",
            traceorder="normal"
        ),
        barmode='overlay',
        height=700,
        width=900,
        margin=dict(l=10, r=20, b=20, t=40, pad=10),
    )

    # Semi-transparent bars so overlaid layer histograms stay readable.
    fig.update_traces(opacity=0.55)

    return fig

In [None]:
def plot_hidden_states_heatmap_box_combined(
        hidden_states: List[np.ndarray],
        title: str
) -> go.Figure:
    """
    Plot, for each timestep, a heatmap of all layers' hidden states above a box plot of each
    layer's value distribution; a slider switches between timesteps.

    Args:
        hidden_states (List[np.ndarray]): One 2D array ``(n_layers, hidden_dim)`` per timestep;
            the list is stacked with ``np.array`` into ``(n_steps, n_layers, hidden_dim)``.
        title (str): The title of the plot.

    Returns:
        go.Figure: A plotly figure object containing the heatmap and box plot.

    Raises:
        ValueError: If ``hidden_states`` is empty or cannot be stacked into a 3D array.
    """
    data = np.array(hidden_states)
    if data.ndim != 3 or data.shape[0] == 0:
        raise ValueError("hidden_states must be a non-empty list of equally shaped 2D arrays")
    n_steps, n_layers, _ = data.shape

    # Heatmap (row 1) and box plots (row 2) share the x-axis and each heatmap column is one
    # layer, so both rows use the same categorical "Layer i" labels.
    layer_labels = [f"Layer {i+1}" for i in range(n_layers)]
    # Each timestep contributes one heatmap followed by one box per layer.
    traces_per_step = n_layers + 1

    fig = make_subplots(
        rows=2, cols=1, shared_yaxes=False, shared_xaxes=True,
        vertical_spacing=0.05
    )

    for step in range(n_steps):
        fig.add_trace(
            go.Heatmap(
                z=data[step, :, :].T,
                zmin=-2,
                zmax=2,
                x=layer_labels,
                xgap=5,
                visible=False,
                colorscale='Inferno',
                colorbar=dict(
                    len=0.5,
                    y=0.77,
                    tickfont_size=FONT_SIZE_LEGEND
                )
            ),
            col=1,
            row=1
        )

        for idx, hidden in enumerate(data[step]):
            fig.add_trace(
                go.Box(
                    name=layer_labels[idx],
                    y=hidden,
                    # Cycle the palette in case there are more layers than colours.
                    marker_color=RAINBOW_COLORS[idx % len(RAINBOW_COLORS)],
                    visible=False,
                ),
                col=1,
                row=2
            )

    # Show the heatmap and boxes of the first timestep.
    for i in range(traces_per_step):
        fig.data[i].visible = True

    steps = []
    for i in range(n_steps):
        start_index = i * traces_per_step
        visible = [start_index <= j < start_index + traces_per_step for j in range(len(fig.data))]
        steps.append(dict(
            method="update",
            label=f"Timestep {i+1}",
            args=[{"visible": visible},
                  {"legend_traceorder": "normal"}],  # layout attribute
        ))

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        sliders=sliders,
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.989,
        ),
        xaxis2=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            visible=False,
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis2=dict(
            range=[-2.2, 2.2],
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        showlegend=False,
        barmode='overlay',
        height=800,
        width=1100,
        margin=dict(
            l=10,
            r=20,
            b=20,
            t=40,
            pad=10
        ),
    )

    return fig

In [None]:
def plot_hidden_state_box(
        hidden_states: List[np.ndarray],
        title: str
) -> go.Figure:
    """
    Show one box per layer that summarises the values of its hidden state, with a slider to
    step through the timesteps.

    Args:
        hidden_states (List[np.ndarray]): One entry per timestep; iterating an entry yields the
            per-layer hidden-state vectors.
        title (str): The title of the plot.

    Returns:
        go.Figure: A plotly figure object containing the box plot.
    """
    fig = go.Figure()

    # Add a box for every (timestep, layer) pair up front; all start hidden.
    for states_at_t in hidden_states:
        for layer_idx, values in enumerate(states_at_t):
            box = go.Box(
                name=f"Layer {layer_idx+1}",
                y=values,
                marker_color=RAINBOW_COLORS[layer_idx],
                visible=False
            )
            fig.add_trace(box)

    # The layer count of the first timestep sets the size of each timestep's trace block.
    per_step = len(hidden_states[0])
    for k in range(per_step):
        fig.data[k].visible = True

    n_traces = len(fig.data)
    steps = []
    for t in range(len(hidden_states)):
        visible = [False] * n_traces
        for k in range(t * per_step, t * per_step + per_step):
            visible[k] = True
        steps.append(dict(
            method="update",
            label=f"Timestep {t+1}",
            args=[{"visible": visible},
                  {"legend_traceorder": "normal"}],  # layout attribute
        ))

    slider = dict(active=0, pad={"t": 70}, steps=steps)

    fig.update_layout(
        sliders=[slider],
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.989,
        ),
        xaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            range=[-2.2, 2.2],
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        showlegend=False,
        barmode='overlay',
        height=700,
        width=1100,
        margin=dict(l=10, r=20, b=50, t=40, pad=10),
    )

    return fig

#### Memory Environment
##### Episode 1
###### Grouped by Timestep

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_mem, group_by="timestep").show()

###### Grouped by Layer

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_mem, group_by="layer").show()

###### Value Distribution

In [None]:
plot_hidden_states_distribution(
    hidden_states=hidden_states_mem,
    title="Distribution of Hidden States - Memory",
).show()

In [None]:
plot_hidden_states_heatmap_box_combined(
    hidden_states=hidden_states_mem,
    title="Hidden States - Memory",
).show()

##### Episode 2
###### Grouped by Timestep

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_mem2, group_by="timestep").show()

###### Grouped by Layer

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_mem2, group_by="layer").show()

###### Value Distribution

In [None]:
plot_hidden_states_distribution(hidden_states_mem2, "Distribution of Hidden States - Memory").show()

In [None]:
plot_hidden_states_heatmap_box_combined(
    hidden_states=hidden_states_mem2,
    title="Hidden States - Memory",
).show()

##### Episode 3
###### Grouped by Timestep

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_mem3, group_by="timestep").show()

###### Grouped by Layer

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_mem3, group_by="layer").show()

###### Value Distribution

In [None]:
plot_hidden_states_distribution(hidden_states_mem3, "Distribution of Hidden States - Memory").show()

In [None]:
plot_hidden_states_heatmap_box_combined(
    hidden_states=hidden_states_mem3,
    title="Hidden States - Memory",
).show()

#### Continuous Recognition Environment
##### Episode 1
###### Grouped by Timestep

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_psy, group_by="timestep").show()

###### Grouped by Layer

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_psy, group_by="layer").show()

###### Value Distribution

In [None]:
plot_hidden_states_distribution(hidden_states_psy, "Distribution of Hidden States - Continuous Recognition").show()

In [None]:
plot_hidden_states_heatmap_box_combined(
    hidden_states=hidden_states_psy,
    title="Hidden States - Continuous Recognition",
).show()

##### Episode 2
###### Grouped by Timestep

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_psy2, group_by="timestep").show()

###### Grouped by Layer

In [None]:
plot_hidden_states_slider(hidden_states=hidden_states_psy2, group_by="layer").show()

###### Value Distribution

In [None]:
plot_hidden_states_distribution(hidden_states_psy2, "Distribution of Hidden States - Continuous Recognition").show()

In [None]:
plot_hidden_states_heatmap_box_combined(
    hidden_states=hidden_states_psy2,
    title="Hidden States - Continuous Recognition",
).show()

##### Episode 3
###### Grouped by Timestep

In [None]:
# Episode 3 section: plot the third Continuous Recognition episode (was psy2 by copy-paste).
plot_hidden_states_slider(hidden_states_psy3, "timestep").show()

###### Grouped by Layer

In [None]:
# Episode 3 section: plot the third Continuous Recognition episode (was psy2 by copy-paste).
plot_hidden_states_slider(hidden_states_psy3, "layer").show()

###### Value Distribution

In [None]:
# Episode 3 section: plot the third Continuous Recognition episode (was psy2 by copy-paste).
plot_hidden_states_distribution(hidden_states_psy3, "Distribution of Hidden States - Continuous Recognition").show()

In [None]:
plot_hidden_states_heatmap_box_combined(
    hidden_states=hidden_states_psy3,
    title="Hidden States - Continuous Recognition",
).show()

### Distance/ Similarity between Hidden States
Visually inspecting the hidden states may not reveal many new insights; therefore, we compare them in more detail by calculating various distance/similarity measures and checking whether any patterns are present.
<br>
Similar to before, the similarities can be calculated for two groupings:
- timesteps - checking how similar the same layer is between two consecutive timesteps
- layers - checking how similar two layers are at the same timestep

The comparison between timesteps is visualized as a simple line plot showing the distance/ similarity between the same layer at consecutive timesteps.<br>
The comparison between layers is shown as a heatmap, as this allows checking the distance/similarity between all pairs of layers at a glance.

In [None]:
def plot_distance(
        data: Dict[int, List[float]],
        dist_type: str,
        title: str
) -> go.Figure:
    """
    Plot one line per layer showing a distance/similarity measure between the same layer
    at consecutive timesteps.

    Args:
        data (Dict[int, List[float]]): Maps a layer number to its per-step values. Keys must be
            convertible with ``int()`` and present in ``COLORS``.
        dist_type (str): Name of the plotted measure (e.g. "Cosine Similarity"); used as the
            y-axis title.
        title (str): The title of the plot.

    Returns:
        go.Figure: A plotly figure object containing the line plot.
    """
    # Layers that share a legend group, so they are grouped together in the legend.
    legend_groups = (
        ('group1', (1, 10)),
        ('group2', (2, 15)),
        ('group3', (3, 18)),
        ('group4', (5,)),
    )

    fig = go.Figure()

    for n_layer, y in data.items():
        fig.add_trace(
            go.Scatter(
                name=f"Layer {n_layer}",
                x=list(range(1, len(y) + 1)),
                y=y,
                mode="lines+markers",
                line=dict(
                    color=f"rgb{str(COLORS[int(n_layer)])}"
                )
            )
        )

    for group, layers in legend_groups:
        for layer in layers:
            fig.update_traces(legendgroup=group, selector=dict(name=f'Layer {layer}'))

    fig.update_layout(
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.989,
        ),
        xaxis=dict(
            title="Interaction Step",
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            title=dist_type,
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.25,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(0,0,0,0)",
            font_size=FONT_SIZE_LEGEND
        ),
        height=700,
        width=900,
        margin=dict(
            l=10,
            r=20,
            b=20,
            t=40,
            pad=10
        ),
    )

    return fig

In [None]:
def plot_distance_heatmap(
        hidden_states: List[np.ndarray],
        environment: str
) -> go.Figure:
    """
    Plot, per timestep, a heatmap of the pairwise cosine similarity between the hidden states
    of all layers; a slider switches between timesteps.

    Args:
        hidden_states (List[np.ndarray]): One 2D array per timestep whose rows are the
            per-layer hidden-state vectors. The layer count is taken from the data.
        environment (str): Environment name; a name containing "Memory" gets the MiniGrid
            title, anything else the PsychLab title.

    Returns:
        go.Figure: A plotly figure object containing the heatmap.
    """
    title = "MiniGrid - Memory" if "Memory" in environment else "PsychLab - Continuous Recognition"

    fig = go.Figure()

    for states_at_t in hidden_states:
        states = np.asarray(states_at_t, dtype=float)
        n_layers = states.shape[0]
        # Vectorised cosine similarity: normalise each layer vector, then take all pairwise
        # dot products at once. A zero vector yields NaN entries, as cosine_similarity would.
        unit = states / norm(states, axis=1, keepdims=True)
        similarity = unit @ unit.T
        layer_labels = [f"Layer {i+1}" for i in range(n_layers)]

        fig.add_trace(
            go.Heatmap(
                visible=False,
                z=similarity,
                zmin=0.,
                zmax=1.,
                x=layer_labels,
                y=layer_labels,
                colorscale='Inferno',
                colorbar_tickfont_size=FONT_SIZE_LEGEND
            )
        )

    # Show the first timestep; the slider toggles the others.
    fig.data[0].visible = True

    steps = []
    for i in range(len(hidden_states)):
        visible = [False] * len(fig.data)
        visible[i] = True
        steps.append(dict(
            method="update",
            label=f"Timestep {i+1}",
            args=[{"visible": visible}],
        ))

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        sliders=sliders,
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.989,
        ),
        xaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            ticksuffix="  ",
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        height=700,
        width=900,
        margin=dict(
            l=10,
            r=20,
            b=20,
            t=40,
            pad=10
        ),
    )

    return fig

#### Memory Environment
##### Episode 1

In [None]:
sim_line_mem = plot_distance(
    data=cosine_distances_mem,
    dist_type="Cosine Similarity",
    title="MiniGrid - Memory",
)
sim_line_mem.show()

In [None]:
sim_heat_mem = plot_distance_heatmap(hidden_states=hidden_states_mem, environment="Memory")
sim_heat_mem.show()

##### Episode 2

In [None]:
plot_distance(
    data=cosine_distances_mem2,
    dist_type="Cosine Similarity",
    title="Cosine Similarity - Memory",
).show()

In [None]:
plot_distance_heatmap(hidden_states=hidden_states_mem2, environment="Memory").show()

##### Episode 3

In [None]:
plot_distance(
    data=cosine_distances_mem3,
    dist_type="Cosine Similarity",
    title="Cosine Similarity - Memory",
).show()

In [None]:
plot_distance_heatmap(hidden_states=hidden_states_mem3, environment="Memory").show()

#### Continuous Recognition Environment
##### Episode 1

In [None]:
plot_distance(
    data=cosine_distances_psy,
    dist_type="Cosine Similarity",
    title="Cosine Similarity - Continuous Recognition",
).show()

In [None]:
plot_distance_heatmap(hidden_states=hidden_states_psy, environment="Continuous Recognition").show()

In [None]:
# NOTE(review): this cell sits under "Episode 1" but plots episode 3
# (cosine_distances_psy3); confirm which episode is intended for this figure.
sim_line_psy = plot_distance(
    data=cosine_distances_psy3,
    dist_type="Cosine Similarity",
    title="PsychLab - Continuous Recognition",
)
sim_line_psy.show()

##### Episode 2

In [None]:
plot_distance(
    data=cosine_distances_psy2,
    dist_type="Cosine Similarity",
    title="Cosine Similarity - Continuous Recognition",
).show()

In [None]:
plot_distance_heatmap(hidden_states=hidden_states_psy2, environment="Continuous Recognition").show()

##### Episode 3

In [None]:
plot_distance(
    data=cosine_distances_psy3,
    dist_type="Cosine Similarity",
    title="Cosine Similarity - Continuous Recognition",
).show()

In [None]:
plot_distance_heatmap(hidden_states=hidden_states_psy3, environment="Continuous Recognition").show()

### Dimensionality Reduction
To further investigate how similar certain hidden states are, we apply t-SNE and PCA to project the per-layer hidden states of each timestep into two dimensions and look for further similarities between layers.

In [None]:
def plot_dimensionality_reduction(
        hidden_states: List[np.ndarray],
        reduction_func: Callable,
        title: str
) -> go.Figure:
    """
    Project each timestep's per-layer hidden states to 2D and show one labelled scatter per
    timestep, switchable with a slider.

    Args:
        hidden_states (List[np.ndarray]): One entry per timestep; each entry is passed to
            ``reduction_func`` unchanged.
        reduction_func (Callable): Maps one timestep's entry to an array whose first two
            columns are used as x/y coordinates (e.g. ``tsne.fit_transform``).
        title (str): The title of the plot.

    Returns:
        go.Figure: A plotly figure object containing the dimensionality reduction plot.
    """
    def _lower_bound(value):
        # Shift half a step down, then round to a multiple of 5.
        return int(round(value / 5 - .5, 0) * 5)

    def _upper_bound(value):
        # Shift half a step up, then round to a multiple of 5.
        return int(round(value / 5 + .5, 0) * 5)

    # The bounds start at 0, so the final axis ranges always include the origin.
    x_lo = x_hi = y_lo = y_hi = 0
    fig = go.Figure()

    for states_at_t in hidden_states:
        proj = reduction_func(states_at_t)
        xs = proj[:, 0]
        ys = proj[:, 1]

        # Grow a shared axis range across all timesteps so the view does not jump.
        x_lo = min(x_lo, _lower_bound(np.min(xs)))
        x_hi = max(x_hi, _upper_bound(np.max(xs)))
        y_lo = min(y_lo, _lower_bound(np.min(ys)))
        y_hi = max(y_hi, _upper_bound(np.max(ys)))

        point_labels = [f"Layer {i}" for i in range(1, len(proj) + 1)]
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                mode="markers+text",
                marker=dict(size=13, color=RAINBOW_COLORS),
                text=point_labels,
                textfont_size=14,
                textposition="top center",
                visible=False
            )
        )

    # Show the first timestep; the slider toggles the others.
    fig.data[0].visible = True

    n_traces = len(fig.data)
    steps = [
        dict(
            method="update",
            label=f"Timestep {t+1}",
            args=[{"visible": [k == t for k in range(n_traces)]}]
        )
        for t in range(len(hidden_states))
    ]

    slider = dict(active=0, pad={"t": 50}, steps=steps)

    fig.update_xaxes(range=[x_lo, x_hi])
    fig.update_yaxes(range=[y_lo, y_hi])

    fig.update_layout(
        sliders=[slider],
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.989,
        ),
        xaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.15,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(0,0,0,0)",
            traceorder="normal",
            font_size=FONT_SIZE_LEGEND
        ),
        height=700,
        width=900,
        margin=dict(l=10, r=40, b=20, t=40, pad=10),
    )
    return fig

In [None]:
# Reducers reused by every projection cell below. TSNE has a fixed random_state for
# reproducibility and a low perplexity, presumably because each call only sees one point
# per layer (TODO confirm).
pca = PCA(n_components=2)
tsne = TSNE(perplexity=5, random_state=42, n_components=2)

#### Memory Environment
##### Episode 1

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_mem,
    reduction_func=tsne.fit_transform,
    title="tSNE - Memory",
).show()

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_mem,
    reduction_func=pca.fit_transform,
    title="PCA - Memory",
).show()

##### Episode 2

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_mem2,
    reduction_func=tsne.fit_transform,
    title="tSNE - Memory",
).show()

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_mem2,
    reduction_func=pca.fit_transform,
    title="PCA - Memory",
).show()

##### Episode 3

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_mem3,
    reduction_func=tsne.fit_transform,
    title="tSNE - Memory",
).show()

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_mem3,
    reduction_func=pca.fit_transform,
    title="PCA - Memory",
).show()

#### Continuous Recognition Environment
##### Episode 1

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_psy,
    reduction_func=tsne.fit_transform,
    title="tSNE - Continuous Recognition",
).show()

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_psy,
    reduction_func=pca.fit_transform,
    title="PCA - Continuous Recognition",
).show()

##### Episode 2

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_psy2,
    reduction_func=tsne.fit_transform,
    title="tSNE - Continuous Recognition",
).show()

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_psy2,
    reduction_func=pca.fit_transform,
    title="PCA - Continuous Recognition",
).show()

##### Episode 3

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_psy3,
    reduction_func=tsne.fit_transform,
    title="tSNE - Continuous Recognition",
).show()

In [None]:
plot_dimensionality_reduction(
    hidden_states=hidden_states_psy3,
    reduction_func=pca.fit_transform,
    title="PCA - Continuous Recognition",
).show()