In [None]:
import glob
import os
import pickle
import random
import warnings
from typing import List, Dict, Tuple

import numpy as np
import plotly.graph_objects as go
import torch
from numpy.linalg import norm
from plotly.subplots import make_subplots

# Suppress every Python warning globally so cell output stays uncluttered.
# NOTE(review): this also hides numpy RuntimeWarnings (e.g. division by zero).
warnings.filterwarnings('ignore')

# Layer indices considered in the attention analysis.
layer_of_interest = [1, 2, 3, 5, 10, 15, 18]

# RGB colour per (presumably) layer index; values are the D3 category10 palette.
COLORS = {
    1: (255, 127, 14),    # orange  #FF7F0E
    3: (44, 160, 44),     # green   #2CA02C
    18: (214, 39, 40),    # red     #D62728
    2: (148, 103, 189),   # purple  #9467BD
    5: (31, 119, 180),    # blue    #1F77B4
    10: (188, 189, 34),   # olive   #BCBD22
    15: (23, 190, 207),   # cyan    #17BECF
    4: (140, 86, 75),     # brown   #8C564B
    6: (227, 119, 194),   # pink    #E377C2
    7: (127, 127, 127),   # grey    #7F7F7F
}

# Same palette with darker orange/green/red for keys 1, 3 and 18. Every other entry,
# and the key order, is inherited unchanged from COLORS (overriding an existing key
# keeps its original position in the dict).
COLORS_DARKER = {
    **COLORS,
    1: (229, 118, 20),    # darker orange #E57614
    3: (31, 138, 44),     # darker green  #1F8A2C
    18: (176, 32, 32),    # darker red    #B02020
}

# 18 colours cycling through the hue wheel; used to colour one box per head.
RAINBOW_COLORS = [
    "rgb(244, 67, 54)",
    "rgb(232, 30, 99)",
    "rgb(192, 35, 120)",
    "rgb(169, 37, 150)",
    "rgb(156, 39, 176)",
    "rgb(103, 58, 183)",
    "rgb(63, 81, 181)",
    "rgb(33, 150, 243)",
    "rgb(3, 169, 244)",
    "rgb(0, 188, 212)",
    "rgb(0, 150, 136)",
    "rgb(76, 175, 80)",
    "rgb(139, 195, 74)",
    "rgb(205, 220, 57)",
    "rgb(255, 235, 59)",
    "rgb(255, 193, 7)",
    "rgb(255, 152, 0)",
    "rgb(255, 87, 34)"
]

# Font sizes shared by all figures in this notebook.
FONT_SIZE_TITLE, FONT_SIZE_AXIS = 24, 22
FONT_SIZE_TICKS = FONT_SIZE_LEGEND = 20

# Analyze Causes
Why are the results as they are? In this section we try to answer some of these questions by analyzing the hidden states and attention patterns of the TrXL in different layers. To this end, a trained SHELM model was used to sample several example episodes of the Memory and Psychlab environments. During sampling, the hidden states, attention distributions and rewards were saved and later visualized and analyzed.

In [None]:
# Seed every RNG the analysis could touch so re-runs are reproducible.
SEED = 101
np.random.seed(SEED)
random.seed(SEED)
# torch.manual_seed seeds the CPU generator *and* all CUDA devices; a bare
# torch.cuda.manual_seed would leave the CPU generator unseeded.
torch.manual_seed(SEED)

observations_path = os.path.join("data", "observations")

if not os.path.exists(observations_path):
    raise FileNotFoundError(
        f"No observations found at '{observations_path}'. "
        "Ensure the observations are downloaded with the script in `data/download_results.sh`"
    )


def _frame_step(frame_path: str) -> int:
    """Return the step number encoded as ``step-<n>-`` in a frame's file name."""
    # Parse only the file name so directory names can never affect the sort key.
    return int(os.path.basename(frame_path).split('step-')[-1].split('-')[0])


def _sorted_frames(env_dir: str, episode: str) -> List[str]:
    """List the ``*.png`` frames of one episode, ordered numerically by step (not lexicographically)."""
    pattern = os.path.join(observations_path, env_dir, episode, "*.png")
    return sorted(glob.glob(pattern), key=_frame_step)


observations_psych = _sorted_frames("Continuous-Recognition", "ep1")
observations_mem = _sorted_frames("Memory", "ep1")
observations_psych2 = _sorted_frames("Continuous-Recognition", "ep2")
observations_mem2 = _sorted_frames("Memory", "ep2")
observations_psych3 = _sorted_frames("Continuous-Recognition", "ep3")
observations_mem3 = _sorted_frames("Memory", "ep3")

In [None]:
def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    There is no zero-norm guard: a zero vector gives a numpy division by zero
    (``nan``/``inf``) rather than an exception.
    """
    magnitude_product = norm(a) * norm(b)
    return np.dot(a, b) / magnitude_product

def load_model_output(
        env_v: str,
        data_dir: str = os.path.join("data", "investigation-data"),
) -> Tuple[Dict[int, List[float]], Dict[int, List[float]], List[np.ndarray], Dict[int, List[float]], List[str]]:
    """
    Load the saved model output of one sampled episode from a pickle file.

    Args:
        env_v (str): The environment/episode identifier, e.g. "mem", "mem2", "psy3".
            The file read is ``<data_dir>/investigation-data-<env_v>.pkl``.
        data_dir (str, optional): Directory containing the pickle files.
            Defaults to ``data/investigation-data``.

    Returns:
        Tuple[Dict[int, List[float]], Dict[int, List[float]], List[np.ndarray], Dict[int, List[float]], List[str]]:
            A tuple containing the following data:
            - cosine_distances: The cosine distances for all evaluated layers.
            - l2_distances: The L2 distances for all evaluated layers.
            - hidden_states: The hidden states for each timestep.
            - attentions: The attention weights keyed by layer. Each value is a sequence over
              timesteps, and each element is indexed as ``[head][memory_slot]``; this is how the
              plotting functions in this notebook consume it.
            - tokens: The tokens for the sampled episode.

    Raises:
        FileNotFoundError: If the data file does not exist. The message names the missing path.
        KeyError: If the pickle lacks one of the expected keys.

    Note:
        Ensure the data is downloaded with the script in `data/download_results.sh`.
    """
    data_path = os.path.join(data_dir, f"investigation-data-{env_v}.pkl")
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"'{data_path}' not found. "
            "Ensure the data is downloaded with the script in `data/download_results.sh`"
        )

    # SECURITY: pickle.load can execute arbitrary code. Only load files produced by the
    # project's own download script, never untrusted pickles.
    with open(data_path, "rb") as f:
        data = pickle.load(f)

    return (
        data["cosine_distances"],
        data["l2_distances"],
        data["hidden_states"],
        data["attention"],  # stored under the singular key "attention"
        data["tokens"],
    )

In [None]:
(cosine_distances_mem, l2_distances_mem, hidden_states_mem,
 attention_mem, tokens_mem) = load_model_output("mem")

In [None]:
(cosine_distances_mem2, l2_distances_mem2, hidden_states_mem2,
 attention_mem2, tokens_mem2) = load_model_output("mem2")

In [None]:
(cosine_distances_mem3, l2_distances_mem3, hidden_states_mem3,
 attention_mem3, tokens_mem3) = load_model_output("mem3")

In [None]:
(cosine_distances_psy, l2_distances_psy, hidden_states_psy,
 attention_psy, tokens_psy) = load_model_output("psy")

In [None]:
(cosine_distances_psy2, l2_distances_psy2, hidden_states_psy2,
 attention_psy2, tokens_psy2) = load_model_output("psy2")

In [None]:
(cosine_distances_psy3, l2_distances_psy3, hidden_states_psy3,
 attention_psy3, tokens_psy3) = load_model_output("psy3")

## Attention
Visualizing the attention mechanism in a transformer can be quite challenging considering there are multiple layers each with multiple attention heads. In the case of the TrXL, the default model has 18 hidden layers, each containing 16 attention heads.
### Softmax Distribution
One way of showing differences between attention heads is to compare their softmax distributions. This can be done either with a violin plot, which shows the distribution directly, or by sorting and summing up the softmax weights and counting how many tokens are necessary to reach a certain threshold. This way, attention heads can loosely be categorized by how wide their window of attention is, i.e. whether they focus on only a few single tokens or attend to many tokens. Furthermore, we also show which token in the memory each input token attends to most.

In [None]:
def plot_softmax_comparison(
        attention: Dict[int, List[float]],
        negative: int,
        positive: int,
        num_steps: int = 12,
        num_heads: int = 16,
        env: str = "Memory",
) -> go.Figure:
    """
    Plot a comparison of the softmax values for two layers in the attention mechanism.

    Args:
        attention (Dict[int, List[float]]): The attention data keyed by layer. ``attention[layer][step][head]``
            is the sequence of softmax weights of that head at that step.
        negative (int): The layer used for the negative (left) side of each violin.
        positive (int): The layer used for the positive (right) side of each violin.
        num_steps (int, optional): Number of timesteps (grid columns) to plot. Defaults to 12.
        num_heads (int, optional): Number of attention heads (grid rows) to plot. Defaults to 16.
        env (str, optional): Environment name shown in the title. Defaults to "Memory".

    Returns:
        go.Figure: The plot showing the softmax values for each step and head for the negative and positive layers.

    Note:
        The plot is a grid of split violin plots, with one violin per step (column) and head (row).
        Each violin shows the distribution of softmax values for the negative and positive layers,
        coloured by side (negative: green, positive: red).
        Both sides read from ``attention``. An earlier version read the positive side from the
        global ``attention_mem``, which silently plotted episode 1 for every episode.
    """
    row_titles = [f"Head {i}" for i in range(1, num_heads + 1)]

    fig = make_subplots(
        rows=num_heads, cols=num_steps, shared_yaxes=True, shared_xaxes=True,
        row_titles=row_titles,
        vertical_spacing=0.0, horizontal_spacing=0.0,
    )

    # (layer, violin side, line colour) for the two halves of every split violin.
    sides = (
        (negative, "negative", "rgb(44, 160, 44)"),
        (positive, "positive", "rgb(214, 39, 40)"),
    )

    for row in range(num_heads):
        for col in range(num_steps):
            # Only the first grid cell contributes legend entries (one per layer).
            show_legend = row == 0 and col == 0
            for layer, side, color in sides:
                values = attention[layer][col][row]
                fig.add_trace(
                    go.Violin(
                        x=[f"Step {col+1}"] * len(values),
                        y=values,
                        legendgroup=f'Layer {layer}', scalegroup=f'Layer {layer}', name=f'Layer {layer}',
                        side=side,
                        line_color=color,
                        showlegend=show_legend,
                        width=0.5,
                        points=False,
                    ),
                    row=row + 1,
                    col=col + 1
                )

    fig.update_layout(
        title=dict(
            text=f"Attention Softmax Distribution - Layer {negative} vs. Layer {positive} - {env}",
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.996,
        ),

        xaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            visible=False,
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.032,
            xanchor="center",
            x=.5,
            bgcolor="rgba(0,0,0,0)",
            font_size=FONT_SIZE_LEGEND
        ),
        violingap=0,
        violinmode='overlay',
        height=2000,
        width=1100,
        margin=dict(
            l=40,
            r=20,
            b=30,
            t=40,
            pad=10
        ),
    )

    # Move the "Head i" row titles to the left edge and rotate them vertically.
    fig.for_each_annotation(
        lambda a: a.update(x=-0.035, textangle=-90) if a.text in row_titles else None
    )

    return fig

In [None]:
def _tokens_to_reach_threshold(steps, threshold: float) -> np.ndarray:
    """
    Count, per head and step, how many of the largest softmax weights are needed to reach ``threshold``.

    Args:
        steps: Sequence over timesteps. Each element is a 2D array with one row per head and
            one softmax weight per memory slot.
        threshold (float): Cumulative softmax mass to reach.

    Returns:
        np.ndarray: Integer array of shape (num_heads, num_steps).
    """
    counts = []
    for heads in steps:
        heads = np.asarray(heads)
        descending = -np.sort(-heads, axis=-1)
        cumsum = np.cumsum(descending, axis=-1)
        # Slots whose running sum is still below the threshold, plus the one that crosses it.
        count = np.count_nonzero(cumsum < threshold, axis=-1) + 1
        # If rounding keeps the total mass below ``threshold`` (e.g. threshold=1.0), the count
        # would exceed the number of slots, so clamp it.
        counts.append(np.minimum(count, heads.shape[-1]))
    return np.stack(counts, axis=1)


def plot_heatmaps_slider(
        attention: Dict[int, List[float]],
        title: str,
        threshold: float=0.9
) -> go.Figure:
    """
    Plot heatmaps with a slider to toggle between different layers of attention. The heatmap shows
    the number of tokens necessary to sum up the softmax distribution to at least the threshold.

    Args:
        attention (Dict[int, List[float]]): Attention weights keyed by layer. Each value is a sequence
            over timesteps of (head x memory slot) softmax arrays.
        title (str): The title of the plot. If it contains "Memory", the small-memory colour range
            (1-16) is used and each cell is annotated with its count.
        threshold (float, optional): The cumulative softmax mass to reach. Defaults to 0.9.

    Returns:
        go.Figure: The plot object.

    Raises:
        ValueError: If ``attention`` is empty.

    Note:
        - One heatmap is created for every layer in ``attention``, with a slider to toggle between them.
    """
    if not attention:
        raise ValueError("`attention` must contain at least one layer")

    is_memory = "Memory" in title
    # Colour-scale maximum: the memory length used elsewhere in this notebook (16 vs. 256 slots).
    mem_len = 16 if is_memory else 256

    fig = go.Figure()
    for layer, layer_steps in attention.items():
        num_tokens = _tokens_to_reach_threshold(layer_steps, threshold)
        num_heads, num_steps = num_tokens.shape
        # Only the small Memory heatmaps get each cell annotated with its count.
        text_kwargs = dict(text=num_tokens, texttemplate="%{text}") if is_memory else {}
        fig.add_trace(
            go.Heatmap(
                z=num_tokens,
                x=[f"Step {i+1}" for i in range(num_steps)],
                y=[f"Head {i+1}" for i in range(num_heads)],
                colorscale='Inferno',
                zmin=1,
                zmax=mem_len,
                visible=False,
                colorbar_tickfont_size=FONT_SIZE_LEGEND,
                **text_kwargs,
            )
        )

    # Show the first layer initially.
    fig.data[0].visible = True

    # One slider step per layer, each making exactly its own trace visible.
    n_traces = len(fig.data)
    steps = [
        dict(
            method="update",
            label=f"Layer {layer}",
            args=[{"visible": [j == i for j in range(n_traces)]}],
        )
        for i, layer in enumerate(attention)
    ]

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        sliders=sliders,
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.985,
        ),
        xaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.15,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(0,0,0,0)",
            traceorder="normal",
            font_size=FONT_SIZE_LEGEND
        ),
        height=700,
        width=800,
        margin=dict(
            l=10,
            r=40,
            b=20,
            t=40,
            pad=10
        ),
    )
    return fig

In [None]:
def plot_attention_heatmaps_over_tokens_slider(
        attention: Dict[int, List[float]],
        layer: int,
        tokens: List[str],
        title: str
) -> go.Figure:
    """
    Generates a plot of attention heatmaps over tokens with a slider to navigate through time steps.

    Parameters:
    - attention (Dict[int, List[float]]): Attention weights keyed by layer. Each value is a sequence over
      timesteps of 2D arrays shaped (n_heads, n_tokens).
    - layer (int): The layer of attention to plot.
    - tokens (List[str]): The list of tokens. Currently unused; kept for interface compatibility.
    - title (str): The title of the plot. If it contains "Memory", a smaller slider padding is used.

    Returns:
    - fig (go.Figure): The plot of attention heatmaps over tokens with a slider.

    Raises:
    - ValueError: If no timesteps are recorded for ``layer``.

    Note:
    - The plot includes a slider to navigate through the time steps.
    - Token columns are shown in reverse memory order. Head 1 is drawn at the top.
    """
    timesteps = attention[layer]
    if len(timesteps) == 0:
        raise ValueError(f"No attention recorded for layer {layer}")

    fig = go.Figure()
    for timestep in timesteps:
        n_heads, n_tokens = timestep.shape
        fig.add_trace(
            go.Heatmap(
                # Flip both axes. Columns are reversed so tokens appear in reverse memory order.
                # Rows are reversed so z[i] is the data of head y[i]; the y labels below are reversed
                # too, and Plotly pairs z rows with y labels by index.
                z=np.flip(timestep, (0, 1)),
                x=[f"Token {i+1}" for i in range(n_tokens)],
                y=[f"Head {i+1}" for i in range(n_heads)][::-1],
                colorscale='turbo',
                zmin=0.,
                zmax=1.,
                visible=False,
                colorbar_tickfont_size=FONT_SIZE_LEGEND
            )
        )

    # Show the first timestep initially.
    fig.data[0].visible = True

    # One slider step per timestep, each making exactly its own trace visible.
    n_traces = len(fig.data)
    steps = [
        dict(
            method="update",
            label=f"Timestep {i+1}",
            args=[{"visible": [j == i for j in range(n_traces)]}],
        )
        for i in range(n_traces)
    ]

    sliders = [dict(
        active=0,
        pad={"t": 50 if "Memory" in title else 80},
        steps=steps
    )]

    fig.update_layout(
        sliders=sliders,
        title=dict(
            text=title,
            font_size=FONT_SIZE_TITLE,
            xanchor="center",
            x=0.5,
            yanchor="top",
            y=0.985,
        ),
        xaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        yaxis=dict(
            tickfont_size=FONT_SIZE_TICKS,
            title_font_size=FONT_SIZE_AXIS,
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.15,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(0,0,0,0)",
            traceorder="normal",
            font_size=FONT_SIZE_LEGEND
        ),
        height=700,
        width=1000,
        margin=dict(
            l=10,
            r=40,
            b=20,
            t=40,
            pad=10
        ),
    )
    return fig

In [None]:
def _memory_window_tokens(tokens: List[str], idx: int, mem_len: int) -> List[str]:
    """Return the ``mem_len`` tokens in the TrXL memory at timestep ``idx``, left-padded with ``<EMPTY>``."""
    if idx < mem_len:
        # Memory not yet full: the first (mem_len - idx - 1) slots are still empty.
        return ["<EMPTY>"] * (mem_len - (idx + 1)) + tokens[:idx + 1]
    # Memory full: sliding window of the most recent mem_len tokens.
    start_idx = idx - mem_len + 1
    return tokens[start_idx:start_idx + mem_len]


def plot_max_attention_per_layer_slider(
        attention: Dict[int, List[float]],
        tokens: List[str],
        layer: int,
        env: str
) -> go.Figure:
    """
    Generates a plot of the maximum attention per head for each time step in a given layer using a slider.
    The left axis shows the current input token. The right axis shows, for each head, the memory slot
    with the highest attention weight, coloured by that weight.

    Parameters:
    - attention (Dict[int, List[float]]): Attention weights keyed by layer. Each value is a sequence over
      timesteps of 2D arrays shaped (n_heads, n_memory_slots).
    - tokens (List[str]): The list of tokens of the episode.
    - layer (int): The layer of attention to plot.
    - env (str): The environment name. If it contains "Memory", a 16-slot memory is assumed, otherwise 256.

    Returns:
    - fig (go.Figure): The plot of the maximum attention per head for each time step in the given layer with a slider.

    Raises:
    - ValueError: If no timesteps are recorded for ``layer``.

    Note:
    - The plot includes a slider to navigate through the time steps.
    """
    timesteps = attention[layer]
    if len(timesteps) == 0:
        raise ValueError(f"No attention recorded for layer {layer}")

    is_memory = "Memory" in env
    mem_len = 16 if is_memory else 256

    fig = go.Figure()
    for idx, timestep in enumerate(timesteps):
        n_heads, _ = timestep.shape
        mem_tokens = _memory_window_tokens(tokens, idx, mem_len)

        fig.add_trace(
            go.Parcoords(
                visible=False,
                line=dict(
                    color=np.max(timestep, axis=1),
                    colorscale="turbo",
                    showscale=True,
                    cmin=0.,
                    cmax=1.
                ),
                dimensions=[
                    dict(
                        range=[0, len(tokens)],
                        tickvals=[i for i in range(1, len(tokens) + 1)],
                        ticktext=[f"{i+1}. {token}" for i, token in enumerate(tokens)],
                        label='Input Token',
                        # Every head starts at the current input token (1-based).
                        values=[idx + 1 for _ in range(n_heads)]
                    ),
                    dict(
                        range=[0, mem_len],
                        tickvals=[i for i in range(1, mem_len + 1)],
                        ticktext=[f"{i}. {token}" for i, token in enumerate(mem_tokens, 1)],
                        label='Memory',
                        # 1-based index of the most attended memory slot per head.
                        values=np.argmax(timestep, axis=1) + 1
                    )
                ],
            )
        )

    # Show the first timestep initially.
    fig.data[0].visible = True

    # One slider step per trace; derived from fig.data rather than a leaked loop variable.
    steps = []
    for i in range(len(fig.data)):
        step = dict(
            method="update",
            label=f"Timestep {i+1}",
            args=[{"visible": [False] * len(fig.data)}]
        )
        step["args"][0]["visible"][i] = True
        steps.append(step)

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        plot_bgcolor='white',
        paper_bgcolor='white',
        sliders=sliders,
        title=f"Max. Attention Per Head for Layer {layer} - {env}",
        title_x=0.5,
        title_y=0.985,
        title_xanchor="center",
        title_yanchor="top",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.15,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(0,0,0,0)",
            traceorder="normal"
        ),
        height=700 if is_memory else 1000,
        width=1000,
        margin=dict(
            l=100,
            r=40,
            b=20,
            t=40,
            pad=10
        ),
    )

    return fig

#### Memory Environment
##### Episode 1

In [None]:
plot_softmax_comparison(attention_mem, negative=3, positive=2).show()

In [None]:
plot_heatmaps_slider(attention_mem, title="No. Tokens for Softmax >= 0.9 - Memory", threshold=0.9).show()

In [None]:
plot_attention_heatmaps_over_tokens_slider(attention_mem, layer=3, tokens=tokens_mem, title="Memory").show()

In [None]:
# Explicit .show() for consistency with the other plotting cells.
plot_max_attention_per_layer_slider(attention_mem, tokens_mem, layer=3, env="Memory").show()

##### Episode 2

In [None]:
plot_softmax_comparison(attention_mem2, negative=3, positive=18).show()

In [None]:
plot_heatmaps_slider(attention_mem2, title="No. Tokens for Softmax >= 0.9 - Memory", threshold=0.9).show()

In [None]:
plot_attention_heatmaps_over_tokens_slider(attention_mem2, layer=3, tokens=tokens_mem2, title="Memory").show()

In [None]:
# Explicit .show() for consistency with the other plotting cells.
plot_max_attention_per_layer_slider(attention_mem2, tokens_mem2, layer=3, env="Memory").show()

##### Episode 3

In [None]:
plot_softmax_comparison(attention_mem3, negative=3, positive=18).show()

In [None]:
plot_heatmaps_slider(attention_mem3, title="No. Tokens for Softmax >= 0.9 - Memory", threshold=0.9).show()

In [None]:
plot_attention_heatmaps_over_tokens_slider(attention_mem3, layer=3, tokens=tokens_mem3, title="Memory").show()

In [None]:
# Explicit .show() for consistency with the other plotting cells.
plot_max_attention_per_layer_slider(attention_mem3, tokens_mem3, layer=15, env="Memory").show()

#### Continuous Recognition Environment
Not all of the above visualizations are suitable for this environment; the violin-plot comparison of the softmax distributions is omitted here.

##### Episode 1

In [None]:
plot_heatmaps_slider(attention_psy, title="No. Tokens for Softmax >= 0.9 - Continuous Recognition", threshold=0.9).show()

In [None]:
plot_attention_heatmaps_over_tokens_slider(attention_psy, layer=3, tokens=tokens_psy, title="Continuous Recognition").show()

In [None]:
# Explicit .show() for consistency with the other plotting cells.
plot_max_attention_per_layer_slider(attention_psy, tokens_psy, layer=5, env="Continuous Recognition").show()

##### Episode 2

In [None]:
plot_heatmaps_slider(attention_psy2, title="No. Tokens for Softmax >= 0.9 - Continuous Recognition", threshold=0.9).show()

In [None]:
plot_attention_heatmaps_over_tokens_slider(attention_psy2, layer=18, tokens=tokens_psy2, title="Continuous Recognition").show()

In [None]:
plot_max_attention_per_layer_slider(attention_psy2, tokens_psy2, layer=3, env="Continuous Recognition").show()

##### Episode 3

In [None]:
plot_heatmaps_slider(attention_psy3, title="No. Tokens for Softmax >= 0.9 - Continuous Recognition", threshold=0.9).show()

In [None]:
plot_attention_heatmaps_over_tokens_slider(attention_psy3, layer=18, tokens=tokens_psy3, title="Continuous Recognition").show()

In [None]:
plot_max_attention_per_layer_slider(attention_psy3, tokens_psy3, layer=10, env="Continuous Recognition").show()

## Rewards
Seeing the different attention patterns in the different heads raises the question of which heads are the most important ones. To analyze this, a trained SHELM model was initialized with one of the 16 heads masked out. This model was then used to sample 50 episodes while collecting the rewards. Due to limited compute resources, not all possible head maskings could be evaluated.

In [None]:
reward_path = os.path.join("data", "rewards")

# Fail early with a pointer to the download script if the reward data is missing.
if not os.path.exists(reward_path):
    raise FileNotFoundError(f"Ensure the reward data is downloaded with the script in `data/download_results.sh`")

# Accumulated rewards per masked-out head, one array per environment (consumed by plot_head_mask).
rews_mem = np.load(os.path.join(reward_path, "Memory-Rewards.npy"))
rews_psy = np.load(os.path.join(reward_path, "Psychlab-Rewards.npy"))

In [None]:
def plot_head_mask(
        data: np.ndarray,
        title: str
    ) -> go.Figure:
    """
    Generates a box plot of the accumulated rewards obtained when a given head is masked out during inference.

    Parameters:
    - data (np.ndarray): A 2D array with one row per masked-out head. Row ``i`` holds the accumulated
      rewards of the episodes sampled with head ``i + 1`` masked out.
    - title (str): The title of the plot. If it contains "Memory", the y-axis is fixed to [0, 1],
      otherwise to [0, 55].

    Returns:
    - fig (go.Figure): The box plot with one box per masked-out head.
    """
    fig = go.Figure()

    for head, rews in enumerate(data, 1):
        fig.add_trace(
            go.Box(
                name=f"Head {head}",
                y=rews,
                # Cycle through the palette so more rows than colours cannot raise IndexError.
                marker_color=RAINBOW_COLORS[(head - 1) % len(RAINBOW_COLORS)],
                showlegend=False
            )
        )

    fig.update_layout(
        title=title,
        xaxis_title="Masked Out Head",
        yaxis_title="Accumulated Reward",
        title_x=0.5,
        title_y=0.985,
        title_xanchor="center",
        title_yanchor="top",
        legend_traceorder="reversed",
        height=700,
        width=900,
        margin=dict(
            l=10,
            r=20,
            b=20,
            t=40,
            pad=10
        ),
    )

    fig.update_yaxes(
        range=[0., 1.] if "Memory" in title else [0., 55.]
    )

    return fig

#### Memory Environment

In [None]:
plot_head_mask(rews_mem, title="Reward with Attention Head masked out - Memory").show()

#### Continuous Recognition Environment

In [None]:
plot_head_mask(rews_psy, title="Reward with Attention Head masked out - Continuous Recognition").show()