# Task Relevant Concepts, the Instruction feature and Adversarial Prompts for MemoryDT

## Set Up 

In [337]:
import sys 
import torch
import numpy as np
import pandas as pd
import plotly.express as px


import pandas as pd
import umap
from sklearn.metrics.pairwise import cosine_similarity
from scipy.cluster import hierarchy


sys.path.append("..")

# get the labels. 
from src.streamlit_app.constants import SPARSE_CHANNEL_NAMES
import itertools 

# Labels for the state-embedding features used below: "<channel>, (<a>,<b>)"
# for every sparse channel and every pair of grid coordinates 0..6
# (channel-major order, matching itertools.product).
all_index_labels = [SPARSE_CHANNEL_NAMES, list(range(7)), list(range(7))]
indices = list(itertools.product(*all_index_labels))
index_labels = [f"{channel}, ({first},{second})" for channel, first, second in indices]


from src.streamlit_app.environment import get_env_and_dt
from src.environments.registration import register_envs
from src.environments.memory import MemoryEnv
from src.decision_transformer.offline_dataset import one_hot_encode_observation
from src.streamlit_app.causal_analysis_components import get_updated_obs
from src.visualization import render_minigrid_observations, render_minigrid_observation
from src.environments.utils import reverse_one_hot
import PIL
import io
from IPython.display import Image

# Register the project's custom environments. The "Overriding environment ...
# already in registry" warnings in the output below report ids that were
# already registered when this ran.
register_envs()

# Functions for dealing with trajectories/adversaries
def get_all_memory_env_scenarios(env, rtg=0.8920, full_trajectory=False):
    """Roll out a fixed action script in each of the four MemoryEnv layouts.

    For every (instruction object, target position) combination in
    {"key", "ball"} x {"top", "bottom"} the 7x7 grid is regenerated through the
    environment's private ``_gen_grid`` and a scripted policy is stepped:
    action 2 at every step except step index 4 (only reached when
    ``full_trajectory`` is True), which is 0 for a "top" target and 1 otherwise.

    The results are left-padded to a fixed context of 9 observations and
    8 actions.

    Args:
        env: MemoryEnv instance (reset once, then re-gridded per scenario).
        rtg: constant return-to-go value used at every timestep.
        full_trajectory: take 6 steps instead of 4.

    Returns:
        obs: one-hot observations, shape (4, 9, ...), all-zero padding frames first.
        actions: int actions, shape (4, 8, 1), front-padded with 7.
        rtg: shape (4, 9, 1), every entry equal to ``rtg``.
        time: shape (4, 9, 1); 0 for padding frames, then 0, 1, 2, ... for the
            real observations.
        scenario_labels: one "<instruction>, <target pair>" label per scenario,
            in rollout order.
    """
    context_len = 9  # observations per padded trajectory; actions get one fewer
    n_steps = 6 if full_trajectory else 4

    env.reset()

    all_observations = []
    all_actions = []
    for target_obj in ["key", "ball"]:
        for target_pos in ["top", "bottom"]:
            # regenerate the grid with the requested instruction/target layout
            env._gen_grid(7, 7, target_obj=target_obj, target_pos=target_pos)
            observations = [env.gen_obs()['image']]
            episode_actions = []
            for i in range(n_steps):
                # step index 4 turns (0 for a top target, 1 otherwise); all other steps use action 2
                if i == 4:
                    action = 0 if target_pos == "top" else 1
                else:
                    action = 2
                next_obs, *_ = env.step(action)
                observations.append(next_obs['image'])
                episode_actions.append(action)

            all_observations.append(observations)
            all_actions.append(episode_actions)

    all_observations = torch.from_numpy(np.stack(all_observations))
    all_actions = torch.from_numpy(np.stack(all_actions))

    one_hot_s = torch.stack([one_hot_encode_observation(all_observations[i]) for i in range(all_observations.shape[0])])
    n_obs = one_hot_s.shape[1]
    n_pad = context_len - n_obs

    # left-pad observations with all-zero frames up to context_len
    obs = torch.cat([
        torch.zeros(one_hot_s.shape[0], n_pad, *one_hot_s.shape[2:]),
        one_hot_s,
        ], dim=1)

    # left-pad actions with 7's up to context_len - 1 (one fewer than observations)
    actions = torch.cat([
        torch.ones(all_actions.shape[0], context_len - 1 - all_actions.shape[1]) * 7,
        all_actions,
        ], dim=1).to(int).unsqueeze(-1)

    # Timesteps: 0 for every padding frame, then 0..n_obs-1 for the real
    # observations. Derived from n_obs instead of being hard-coded so the
    # 7-observation full_trajectory rollouts also get correct timesteps; the
    # 4-step case is unchanged (0,0,0,0,0,1,2,3,4).
    time = torch.cat([
        torch.zeros(n_pad, dtype=torch.long),
        torch.arange(n_obs),
        ], dim=0)
    # repeat it for each trajectory
    time = torch.stack([time] * actions.shape[0]).unsqueeze(-1)

    # constant return-to-go at every timestep
    rtg = torch.ones(actions.shape[0], time.shape[1]).unsqueeze(-1) * rtg

    scenario_labels = [
        'Key, Key-Ball',
        'Key, Ball-Key',
        'Ball, Ball-Key',
        'Ball, Key-Ball'
    ]

    return obs, actions, rtg, time, scenario_labels

def get_images_batched(env, obs):
    """Render every frame of a batch of one-hot observation trajectories.

    Each trajectory ``obs[k]`` is decoded frame by frame with
    ``reverse_one_hot`` and then rendered with ``render_minigrid_observations``.

    Args:
        env: environment handed to the renderer.
        obs: batch indexed ``obs[trajectory][frame]``.

    Returns:
        torch.Tensor stacking the rendered images of every trajectory.
    """
    def _decode(trajectory):
        return np.stack([reverse_one_hot(frame) for frame in trajectory])

    rendered = [render_minigrid_observations(env, _decode(trajectory)) for trajectory in obs]
    return torch.from_numpy(np.stack(rendered))

def get_trajectory_plotly_animation(all_images, scenario_labels=None, start_frame=4):
    """Animate rendered trajectories, one facet column per scenario.

    Args:
        all_images: rendered frames indexed ``[scenario, frame, ...]`` (e.g. the
            output of ``get_images_batched``).
        scenario_labels: facet titles, one per scenario. A flat list of strings
            is used directly; the legacy nested form ``[[label, ...]]`` is still
            accepted. Defaults to the four MemoryEnv scenario labels.
        start_frame: index of the first frame to animate; earlier frames are dropped.

    Returns:
        plotly figure with an animation slider over frames.
    """
    # None instead of a mutable list default; the labels are the same as before.
    if scenario_labels is None:
        scenario_labels = [
            'Key, Key-Ball',
            'Key, Ball-Key',
            'Ball, Ball-Key',
            'Ball, Key-Ball',
        ]
    elif len(scenario_labels) > 0 and not isinstance(scenario_labels[0], str):
        # legacy call style: labels wrapped in an outer list
        scenario_labels = scenario_labels[0]

    # show all trajectories with animation frame.
    fig = px.imshow(all_images[:, start_frame:], animation_frame=1, facet_col=0)

    # Label ticks 0..6 at the centre of each grid cell. Assumes 224-px wide
    # renders, i.e. 7 cells of 32 px (centre = right edge - 16).
    labels = [str(i) for i in range(7)]
    tick_positions = np.linspace(0, 224, 8)[1:] - 16
    fig.update_xaxes(tickmode='array', tickvals=tick_positions, ticktext=labels)
    fig.update_yaxes(tickmode='array', tickvals=tick_positions, ticktext=labels)

    # facet titles end in "=<facet index>"; replace them with the scenario labels
    fig.for_each_annotation(lambda a: a.update(text=scenario_labels[int(a.text.split("=")[-1])]))
    fig.update_layout(font=dict(size=24), width=1200)
    return fig

def create_animation(fig, file_name="test.gif", 
                     additional_frames_start=0,
                     additional_frames_end=5, 
                     duration=500, loop=0):
    """
    Creates an animated GIF from a Plotly figure with frames.

    Each plotly frame is pushed into the figure's data, rendered to PNG with
    ``fig.to_image`` and collected; the images are then written as one GIF.
    Note that ``fig`` is mutated: afterwards it shows the last frame and its
    first two update-menu buttons have a frame duration of 0.

    Parameters:
    - fig: The Plotly figure object containing the frames for the animation.
    - file_name (str): The name of the output GIF file.
    - additional_frames_start (int): The number of additional frames to add at the start of the GIF.
    - additional_frames_end (int): The number of additional frames to add at the end of the GIF.
    - duration (int): The duration of each frame in the GIF in milliseconds.
    - loop (int): The number of times to loop the GIF. 0 for infinite looping.

    Returns:
    None

    Raises:
    - ValueError: if ``fig`` has no frames.
    """
    # `import PIL` alone does not guarantee the PIL.Image submodule is loaded,
    # and the notebook's top-level `Image` is IPython.display.Image, so import
    # the submodule explicitly here.
    import PIL.Image

    if not fig.frames:
        raise ValueError("fig has no animation frames to turn into a GIF")

    # Zero the frame duration stored on the first two update-menu buttons.
    # This is loop-invariant, so it is done once before rendering.
    fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 0
    fig.layout.updatemenus[0].buttons[1].args[1]["frame"]["duration"] = 0

    frames = []
    for step, plotly_frame in enumerate(fig.frames):
        # show this frame's traces and move the slider to match
        fig.update(data=plotly_frame.data)
        fig.layout.sliders[0].update(active=step)
        # generate image of current state
        frames.append(PIL.Image.open(io.BytesIO(fig.to_image(format="png"))))

    # repeat the first / last image so the animation holds on them
    frames = (
        [frames[0]] * additional_frames_start
        + frames
        + [frames[-1]] * additional_frames_end
    )

    # create animated GIF
    frames[0].save(
        file_name,
        save_all=True,
        append_images=frames[1:],
        optimize=False,
        duration=duration,
        loop=loop,
    )

def update_observations(obs, frame_to_update, positions, channels):
    """Return an edited copy of a batch of observation trajectories.

    For trajectory ``k`` of the batch, frame ``frame_to_update`` is replaced,
    once per entry of ``positions``, by what ``get_updated_obs`` returns for
    that position and channel ``channels[k]``. Each call receives the
    trajectory including the earlier edits. ``obs`` itself is not modified.

    Args:
        obs: batch of trajectories indexed ``obs[k][frame]``; must support ``.clone()``.
        frame_to_update: index of the frame to edit.
        positions: sequence of two-element grid coordinates.
        channels: one channel id per trajectory (at least ``obs.shape[0]`` entries).

    Returns:
        The edited copy.
    """
    edited = obs.clone()
    for k in range(edited.shape[0]):
        chan = channels[k]
        trajectory = edited[k]
        for pos in positions:
            trajectory[frame_to_update] = get_updated_obs(
                trajectory, frame_to_update, pos[0], pos[1], chan
            )
    return edited


# Functions for dealing with embeddings
def get_channel_labels(index_labels):
    """Return the channel name of each ``"<channel>, (<a>,<b>)"`` label.

    The channel is everything before the first comma; a label without a
    comma is returned unchanged.
    """
    return [label.partition(",")[0] for label in index_labels]

def plot_norms(norms, channel_labels, index_labels):
    """Strip plot of embedding L2 norms, coloured by channel.

    Args:
        norms: one norm per embedding (plotted on the y axis).
        channel_labels: channel name per embedding, used for colour.
        index_labels: full label per embedding, shown on hover.

    Returns:
        The plotly figure (not shown).
    """
    palette = [
        "#0068c9",
        "#83c9ff",
        "#ff2b2b",
        "#ffabab",
        "#29b09d",
        "#7defa1",
        "#ff8700",
        "#ffd16a",
        "#6d3fc0",
        "#d5dae5",
    ]
    fig = px.strip(
        y=norms,
        color=channel_labels,
        hover_name=index_labels,
        labels={"y": "L2 Norm", "x": "Channel"},
        orientation="v",
        template="plotly+presentation",
        color_discrete_sequence=palette,
    )

    # bigger points; constant legend marker size independent of the points
    fig.update_traces(marker=dict(size=7))
    fig.update_xaxes(title_text="Channel")
    fig.update_layout(
        font=dict(size=24),
        legend={'itemsizing': 'constant'},
        height=800,
        width=1000,
    )
    return fig

def plot_umap_result(ummap_result, channel_labels, index_labels):
    """Scatter the first two columns of a UMAP projection, coloured by channel.

    Args:
        ummap_result: 2-column projection (column 0 on x, column 1 on y).
        channel_labels: channel name per point, used for colour.
        index_labels: full label per point, shown on hover.

    Returns:
        The plotly figure (not shown).
    """
    scatter_options = dict(
        x=0,
        y=1,
        color=channel_labels,
        hover_name=index_labels,
        opacity=0.5,
        template="plotly",
    )
    fig = px.scatter(ummap_result, **scatter_options)
    fig.update_traces(marker=dict(size=16))
    fig.update_layout(height=800, width=1200)
    return fig

def get_filtered_embeddings(embeddings, index_labels, channels = ("key", "ball"), norm = 0):
    """Select feature embeddings by channel name and minimum L2 norm.

    Args:
        embeddings: 2-D tensor whose COLUMNS are the per-feature embeddings,
            one column per entry of ``index_labels``.
        index_labels: labels of the form "<channel>, (<a>,<b>)".
        channels: channel names to keep. The default is a tuple rather than a
            list, to avoid a shared mutable default.
        norm: keep only columns whose L2 norm is >= this value.

    Returns:
        (restricted_embeddings, restricted_labels, restricted_channels), where
        restricted_embeddings has one ROW per kept feature (it is transposed
        relative to ``embeddings``).
    """
    channel_labels = get_channel_labels(index_labels)
    norms = torch.norm(embeddings, dim=0)

    # Build a real boolean tensor mask. The previous list mixed Python False
    # with 0-d tensors from `norms[i] >= norm`.
    channel_mask = torch.tensor(
        [label in channels for label in channel_labels],
        dtype=torch.bool,
        device=norms.device,
    )
    index_mask = channel_mask & (norms >= norm)

    restricted_embeddings = embeddings.T[index_mask]
    keep = index_mask.tolist()
    restricted_labels = [label for label, kept in zip(index_labels, keep) if kept]
    restricted_channels = get_channel_labels(restricted_labels)

    assert len(restricted_embeddings) == len(restricted_labels)
    assert len(restricted_embeddings) == len(restricted_channels)

    return restricted_embeddings, restricted_labels, restricted_channels 

def get_long_similarity_df(cosine_similarity_df):
    """Convert a square, labelled similarity matrix into a long table of unique pairs.

    Each unordered pair of distinct labels appears once, as the pair with
    ``channel_1 < channel_2`` in string order; the diagonal is dropped.
    Labels are expected to look like "<channel>, (<a>,<b>)".

    Returns:
        DataFrame with columns channel_1, channel_2, cosine_similarity,
        same_channel, same_position and category
        ("<same_channel>, <same_position>").
    """
    long_df = pd.melt(cosine_similarity_df, ignore_index=False).reset_index()
    # Rename positionally: the auto-generated "index"/"variable" names change
    # when the matrix's index or columns are named.
    long_df.columns = ["channel_1", "channel_2", "cosine_similarity"]

    # Keep each unordered pair once and drop self-pairs. The .copy() makes the
    # column assignments below act on an independent frame rather than a
    # slice (avoids SettingWithCopyWarning).
    long_df = long_df[long_df["channel_1"] < long_df["channel_2"]].copy()

    channel_1 = long_df["channel_1"]
    channel_2 = long_df["channel_2"]
    # channel = text before the first comma; position = text after the first "("
    long_df["same_channel"] = channel_1.str.split(",").str[0] == channel_2.str.split(",").str[0]
    long_df["same_position"] = channel_1.str.split("(").str[1] == channel_2.str.split("(").str[1]

    # combine the two so we can have different rows for each combination
    long_df["category"] = long_df["same_channel"].astype(str) + ", " + long_df["same_position"].astype(str)

    return long_df

def get_similar_embeddings(cosine_similarity_df_long, embeddings, index_labels, similarity_threshold = 0.5, embedding_string = "ball, (1,6)"):
    """Select embeddings strongly (anti-)correlated with a given embedding.

    A pair is kept when either side contains ``embedding_string`` (plain
    substring match) and ``|cosine_similarity| > similarity_threshold``.
    Every label in a kept pair is selected; this includes
    ``embedding_string``'s own label whenever at least one pair matches.

    The default ``embedding_string`` is "ball, (1,6)". The old default
    "ball (1,6)" lacked the comma that every index label contains, so it
    never matched anything.

    Args:
        cosine_similarity_df_long: output of ``get_long_similarity_df``.
        embeddings: rows aligned with ``index_labels``.
        index_labels: label per row of ``embeddings``.
        similarity_threshold: strict lower bound on |cosine similarity|.
        embedding_string: label (or substring) of the reference embedding.

    Returns:
        (restricted_embeddings, restricted_labels, restricted_channels) in the
        original ``index_labels`` order.
    """
    involves_reference = (
        cosine_similarity_df_long["channel_1"].str.contains(embedding_string, regex=False)
        | cosine_similarity_df_long["channel_2"].str.contains(embedding_string, regex=False)
    )
    is_strong = cosine_similarity_df_long["cosine_similarity"].abs() > similarity_threshold
    matched_pairs = cosine_similarity_df_long[involves_reference & is_strong]

    # every label that appears on either side of a matched pair
    vocab_items = set(matched_pairs["channel_1"]) | set(matched_pairs["channel_2"])

    index_mask = [label in vocab_items for label in index_labels]
    restricted_embeddings = embeddings[index_mask]
    restricted_labels = [label for label, keep in zip(index_labels, index_mask) if keep]
    restricted_channels = get_channel_labels(restricted_labels)

    return restricted_embeddings, restricted_labels, restricted_channels

def plot_cosine_similarity_heatmap(df, restricted_labels, reorder = False, title="Pairwise Cosine Similarity Heatmap"):
    """Plot a square similarity DataFrame as a heatmap, optionally cluster-ordered.

    Args:
        df: square DataFrame of pairwise cosine similarities.
        restricted_labels: tick labels, in the same order as df's rows/columns.
        reorder: if True, permute the rows and columns (and the tick labels)
            into the leaf order of a scipy hierarchical clustering of df's rows.
        title: figure title (previously ignored in favour of a hard-coded
            "Reordered - ..." title).

    Returns:
        plotly figure.
    """
    restricted_labels = list(restricted_labels)
    if reorder:
        # linkage is only needed for reordering, so compute it only here
        linkage = hierarchy.linkage(df.to_numpy())
        leaves = hierarchy.dendrogram(
            linkage, no_plot=True, color_threshold=-np.inf
        )["leaves"]
        df = df.iloc[leaves, leaves]
        # Tick text is positional, so it must be permuted together with df.
        # Otherwise the reordered heatmap shows the original labels.
        restricted_labels = [restricted_labels[i] for i in leaves]

    # plot the cosine similarity matrix
    fig = px.imshow(
            df,
            color_continuous_scale="RdBu",
            title=title,
            color_continuous_midpoint=0.0,
            labels={"color": "Cosine Similarity"},
        )
    tick_settings = dict(
        tickmode="array",
        tickvals=list(range(len(restricted_labels))),
        ticktext=restricted_labels,
        showgrid=False,
    )
    fig.update_xaxes(**tick_settings)
    fig.update_yaxes(**tick_settings)

    # don't show axes if there are more than 20 rows (labels become unreadable)
    if df.shape[0] > 20:
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
    return fig


[33mWARN: Overriding environment DynamicObstaclesMultiEnv-v0 already in registry.[0m


[33mWARN: Overriding environment CrossingMultiEnv-v0 already in registry.[0m


[33mWARN: Overriding environment MultiRoomMultiEnv-v0 already in registry.[0m


[33mWARN: Overriding environment Probe1-v0 already in registry.[0m


[33mWARN: Overriding environment Probe2-v0 already in registry.[0m


[33mWARN: Overriding environment Probe3-v0 already in registry.[0m


[33mWARN: Overriding environment Probe4-v0 already in registry.[0m


[33mWARN: Overriding environment Probe5-v0 already in registry.[0m


[33mWARN: Overriding environment Probe6-v0 already in registry.[0m


[33mWARN: Overriding environment MiniGrid-MemoryS7RandomDirection-v0 already in registry.[0m


[33mWARN: Overriding environment MiniGrid-MemoryS7FixedStart-v0 already in registry.[0m



In [366]:
# GET MODEL AND ENVIRONMENT
# NOTE(review): model_index is not referenced in the visible cells; it appears
# to be a reference list of available checkpoints — confirm before removing.
model_index = {
    # The original post, reproduced with 1/10th of the total training epochs
    "models/MiniGrid-MemoryS7FixedStart-v0/WorkingModel.pt": "MemoryDT",
    "models/MiniGrid-MemoryS7FixedStart-v0/MemoryGatedMLP.pt": "MemoryDTGatedMLP",
    "models/MiniGrid-Dynamic-Obstacles-8x8-v0/ReproduceOriginalPostShort.pt": "DynamicObstaclesDT_reproduction",
}

# Load the MemoryDT checkpoint (with its environment) and the gated-MLP variant.
env, dt = get_env_and_dt("../models/MiniGrid-MemoryS7FixedStart-v0/WorkingModel.pt")
_, dt_gated_mlp = get_env_and_dt("../models/MiniGrid-MemoryS7FixedStart-v0/MemoryGatedMLP.pt")


# State-embedding weights. Columns are treated as the per-feature embeddings
# below (e.g. norms over embeddings.T rows, one per entry of index_labels).
embeddings = dt.state_embedding.weight.detach()
# NOTE(review): mean(dim=0) subtracts, for each feature column, its mean over
# the embedding dimension, and the division normalises each ROW (dim=1), not
# each feature column. If the intent was "centre by the mean feature vector,
# then unit-normalise each feature", that would be mean(dim=1, keepdim=True)
# and norm(dim=0) — confirm. Changing it would also change the results of the
# norm=0.8 filters applied to normalized_embeddings later on.
centred_embeddings = embeddings - embeddings.mean(dim=0)
normalized_embeddings = centred_embeddings / torch.norm(centred_embeddings, dim=1, keepdim=True)
# L2 norm of each feature's raw (uncentred) embedding
norms = torch.norm(embeddings.T, dim=1)
channel_labels = get_channel_labels(index_labels)

# Pairwise cosine similarity of all key/ball feature embeddings, in wide and long form.
restricted_embeddings, restricted_labels, restricted_channels = get_filtered_embeddings(normalized_embeddings, index_labels, channels=["key", "ball"])
cosine_similarity_matrix = cosine_similarity(restricted_embeddings)
cosine_similarity_df = pd.DataFrame(cosine_similarity_matrix, columns=restricted_labels, index=restricted_labels)
cosine_similarity_df_long = get_long_similarity_df(cosine_similarity_df)




## Understanding Embedding Space

Norms of the embeddings vary greatly, as weight decay was not applied to them, and some channels/positions were more important than others.

In [367]:

# Strip plot of each feature's raw embedding L2 norm, coloured by channel.
plot_norms(norms, channel_labels, index_labels).show()

### UMAP


UMAP struggles on the full set of embeddings, since many of them are unimportant and their directions are likely meaningless.

In [358]:
# UMAP of every raw (uncentred) feature embedding under cosine distance.
# random_state fixes the layout; as the warning below notes, it also forces n_jobs=1.
reducer = umap.UMAP(
    n_neighbors=8,
    min_dist=0.05,
    n_components=2,
    metric="cosine",
    random_state=42,
)

ummap_result = reducer.fit_transform(embeddings.T)
plot_umap_result(ummap_result, channel_labels, index_labels).show()


n_jobs value -1 overridden to 1 by setting random_state. Use no seed for parallelism.



Clustering only the large-norm key and ball embeddings reveals more structure, but the result is hard to interpret. Clustering just keys/balls is more useful.

In [368]:
# Keep only key/ball features whose normalized-embedding column norm is >= 0.8,
# then run UMAP (cosine distance) on just those.
restricted_embeddings_tmp, restricted_labels_tmp, restricted_channels_tmp = get_filtered_embeddings(normalized_embeddings, index_labels, channels=["key", "ball"], norm = 0.8)

reducer = umap.UMAP(
    n_neighbors=3,
    min_dist=0.2,
    n_components=2,
    metric="cosine",
    random_state=42,
)

ummap_result = reducer.fit_transform(restricted_embeddings_tmp)
plot_umap_result(
    ummap_result, 
    restricted_channels_tmp, 
    restricted_labels_tmp
).show()


n_jobs value -1 overridden to 1 by setting random_state. Use no seed for parallelism.



Cosine similarity plots, however, show more structure.

In [369]:
def plot_cosine_similarity_heatmap(df, restricted_labels, reorder = False, title="Pairwise Cosine Similarity Heatmap"):
    """Plot a square similarity DataFrame as a heatmap, optionally cluster-ordered.

    This redefinition deliberately shows no figure title; ``title`` is
    accepted only so the signature matches the earlier definition.

    Args:
        df: square DataFrame of pairwise cosine similarities.
        restricted_labels: tick labels, in the same order as df's rows/columns.
        reorder: if True, permute the rows and columns (and the tick labels)
            into the leaf order of a scipy hierarchical clustering of df's rows.
        title: unused.

    Returns:
        plotly figure.
    """
    restricted_labels = list(restricted_labels)
    if reorder:
        # linkage is only needed for reordering, so compute it only here
        linkage = hierarchy.linkage(df.to_numpy())
        leaves = hierarchy.dendrogram(
            linkage, no_plot=True, color_threshold=-np.inf
        )["leaves"]
        df = df.iloc[leaves, leaves]
        # Tick text is positional, so it must be permuted together with df.
        # Otherwise the reordered heatmap shows the original labels.
        restricted_labels = [restricted_labels[i] for i in leaves]

    # plot the cosine similarity matrix
    fig = px.imshow(
            df,
            color_continuous_scale="RdBu",
            color_continuous_midpoint=0.0,
            labels={"color": "Cosine Similarity"},
        )
    tick_settings = dict(
        tickmode="array",
        tickvals=list(range(len(restricted_labels))),
        ticktext=restricted_labels,
        showgrid=False,
    )
    fig.update_xaxes(**tick_settings)
    fig.update_yaxes(**tick_settings)

    # don't show axes if there are more than 20 rows (labels become unreadable)
    if df.shape[0] > 20:
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
    return fig

# Key/ball similarity heatmap in the original label order ...
fig = plot_cosine_similarity_heatmap(cosine_similarity_df, restricted_labels)
fig.update_layout(height=800, width=1000)
fig.update_layout(font=dict(size=24))
fig.show()

# ... and reordered by hierarchical clustering.
fig = plot_cosine_similarity_heatmap(cosine_similarity_df, restricted_labels, reorder=True)
fig.update_layout(height=800, width=1000)
fig.update_layout(font=dict(size=24))
fig.show()

Some embeddings are highly correlated, and this sometimes holds across keys/balls or across positions. We can visualize the distribution of cosine similarities for pairs grouped by whether they share a channel and/or a position. Using absolute cosine similarity as a distance metric, it's possible to pull out groups of related embeddings.

In [389]:
# Recompute the key/ball similarity tables, keeping only features whose
# normalized-embedding norm is >= 0.8. These globals are reused by later cells.
restricted_embeddings, restricted_labels, restricted_channels = get_filtered_embeddings(normalized_embeddings, index_labels, channels=["key", "ball"], norm = 0.8)
cosine_similarity_matrix = cosine_similarity(restricted_embeddings)
cosine_similarity_df = pd.DataFrame(cosine_similarity_matrix, columns=restricted_labels, index=restricted_labels)
cosine_similarity_df_long = get_long_similarity_df(cosine_similarity_df)

# One violin per "<same channel>, <same position>" category of embedding pair.
fig = px.violin(
    cosine_similarity_df_long,
    x="cosine_similarity",
    hover_data=["channel_1", "channel_2"],
    y = "category",
    color="category",
    labels={"category": "Same Channel, Same Position"},
    template="plotly",
    points="all",
)
# rename x to "Cosine Similarity"
fig.update_xaxes(title_text="Cosine Similarity")
# make font much larger
fig.update_layout(font=dict(size=18))

# remove legend
fig.update_layout(showlegend=False)

# make it wide and tall
fig.update_layout(height=500, width=1200)
fig.show()

Take-aways from this:
1. Most aligned embeddings corresponded to different channels at different positions (i.e., embeddings that could be present simultaneously).
2. Most antipodal embeddings corresponded to keys/balls at the same positions, which were anti-correlated.

We can then find subsets of embeddings which we think might have geometric structure.

In [None]:

# Embeddings whose |cosine similarity| with "ball, (2,6)" exceeds 0.6, plus
# that embedding itself when any pair matches, as a cluster-ordered heatmap.
restricted_embeddings_tmp, restricted_labels_tmp, restricted_channels_tmp = get_similar_embeddings(
    cosine_similarity_df_long, restricted_embeddings,restricted_labels,
    similarity_threshold = 0.60,embedding_string = "ball, (2,6)")

cosine_similarity_matrix_tmp = cosine_similarity(restricted_embeddings_tmp)
cosine_similarity_df_tmp = pd.DataFrame(cosine_similarity_matrix_tmp, columns=restricted_labels_tmp, index=restricted_labels_tmp)
fig = plot_cosine_similarity_heatmap(cosine_similarity_df_tmp, restricted_labels_tmp, reorder=True)
# increase font size
fig.update_layout(font=dict(size=18))
fig.update_layout(height=800, width=1000)
fig.show()

In [None]:
# Same as above, but for embeddings strongly (anti-)correlated with "ball, (5,2)".
restricted_embeddings_tmp, restricted_labels_tmp, restricted_channels_tmp = get_similar_embeddings(
    cosine_similarity_df_long, restricted_embeddings,restricted_labels,
    similarity_threshold = 0.60, embedding_string = "ball, (5,2)")

cosine_similarity_matrix_tmp = cosine_similarity(restricted_embeddings_tmp)
cosine_similarity_df_tmp = pd.DataFrame(cosine_similarity_matrix_tmp, columns=restricted_labels_tmp, index=restricted_labels_tmp)
fig = plot_cosine_similarity_heatmap(cosine_similarity_df_tmp, restricted_labels_tmp, reorder=True)
# increase font size
fig.update_layout(font=dict(size=18))
fig.update_layout(height=800, width=1000)
fig.show()

Having found these clusters, we can proceed to PCA to understand the underlying geometry. We used the streamlit app for PCA analysis. 

## Adversarial Examples



### Creating Eval Trajectories 

To understand the model's behavior better, we can prompt it with scenarios in which its observations are slightly different. These scenarios are not limited to those that are possible under the game rules.

There are two major categories of the kinds of observations we are generating:
1. Flipping objects. Here we change things like what the instruction was at the start of the trajectory or the top target item. 
2. Adding objects. Here we add objects that wouldn't be present in a normal run of the model. 

In [None]:
# Fixed-layout 7x7 memory environment that renders to RGB arrays.
env = MemoryEnv(
    size=7, 
    random_direction=False, 
    random_start_pos=False, 
    random_length=False, 
    render_mode="rgb_array")

# get all scenarios: full 6-step rollouts, used here only for rendering
image_obs, actions, rtg, time, scenario_labels = get_all_memory_env_scenarios(env, rtg=0.0, full_trajectory=True)

all_images = get_images_batched(env, image_obs)
# full rollouts have 7 observations padded to 9, so start_frame=2 skips the padding
fig = get_trajectory_plotly_animation(all_images, start_frame=2)
fig.show()
# make animation
create_animation(fig, file_name="full_trajectory.gif", additional_frames_start=1, duration=300, loop=0)
# display(Image(filename="full_trajectory.gif")) # render in notebook

In [None]:
# The full rollouts above (with the last two frames) are only for the GIF; using them as obs would break the other analyses.
# So obs is rebuilt from the 4-step rollouts: four all-zero padding frames followed by the five observations leading up to the decision point.
obs, actions, rtg, time, scenario_labels = get_all_memory_env_scenarios(env, rtg=0.0, full_trajectory=False)

#  Example usage of the "update_observations" function. 
# Frame 4 is the first non-padding frame of obs.
frame_to_update = 4
positions = [(4, 2)]
channels = [6,6,5,5] # 6 is ball, 5 is key. # We edit all four base scenarios at once.

new_obs = update_observations(obs, frame_to_update, positions, channels)
updated_images = get_images_batched(env, new_obs)
fig = get_trajectory_plotly_animation(updated_images)
fig.show()

### Observation Library


Flipping instructions/targets

- Flipped instruction S5
- Flipped targets S5
- Flipped targets S9
- Flipped targets S5 and S9

Adding Objects/Adversarial

- Adding Instruction Complement at 0,5 in S5
- Adding Instruction Complement at 4,2 in S5
- Adding Instruction Complement at 4,2 and 0,5 in S5
- Adding Instruction Complement at 0,5 in S9
- Adding Instruction Complement at 4,2 in S9
- Adding Instruction Complement at 4,2 and 0,5 in S9

In [None]:
# A library of alternative (edited) versions of the base scenarios follows.

def generate_scenario_dict(obs):
    """Build named, edited copies of a batch of base-scenario observations.

    Every edit is applied to all base scenarios at once through
    ``update_observations``, so ``obs`` itself is never modified. Channel ids
    follow the notebook convention 6 = ball, 5 = key. Names ending in "S5"
    (or the un-suffixed instruction/complement edits) change frame 4; names
    ending in "S9" change frame 8.

    Args:
        obs: batch of one-hot observation trajectories, one per base scenario
            (four, matching the per-scenario channel lists below).

    Returns:
        dict mapping a human-readable edit name to its edited observation
        batch, starting with the unedited "Original".
    """
    base = obs.clone()

    # per-scenario channel choices (6 = ball, 5 = key)
    instruction_channels = [6, 6, 5, 5]
    first_swap = [6, 5, 5, 6]
    second_swap = [5, 6, 6, 5]

    def edit(source, frame, cells, chans):
        return update_observations(
            source,
            frame_to_update=frame,
            positions=cells,
            channels=chans,
        )

    def add_complement(frame, cells):
        # place the instruction's complement object at `cells`
        return edit(base, frame, cells, instruction_channels)

    # --- flipped instruction / targets ---
    instruction_flipped = edit(base, 4, [(2, 6)], instruction_channels)
    targets_s5 = edit(edit(base, 4, [(1, 2)], first_swap).clone(), 4, [(5, 2)], second_swap)
    targets_s9 = edit(edit(base, 8, [(1, 6)], first_swap), 8, [(5, 6)], second_swap)
    targets_s5_s9 = edit(edit(targets_s5, 8, [(1, 6)], first_swap), 8, [(5, 6)], second_swap)

    # --- adversarial additions, then assemble in presentation order ---
    return {
        "Original": base,
        "Instruction Flipped": instruction_flipped,
        # S5
        "Complement (0,5)": add_complement(4, [(0, 5)]),
        "Complement (4,2)": add_complement(4, [(4, 2)]),
        "Complement (0,5), (4,2)": add_complement(4, [(0, 5), (4, 2)]),
        # S9
        "Instruction Complement (0,5) (S9)": add_complement(8, [(0, 5)]),
        "Instruction Complement (4,2) (S9)": add_complement(8, [(4, 2)]),
        "Instruction Complement (0,5), (4,2) (S9)": add_complement(8, [(0, 5), (4, 2)]),

        # Target changes
        "Target Complement (S5)": targets_s5,
        "Target Complement (S9)": targets_s9,
        "Target Complement (S5, S9)": targets_s5_s9,
    }
    
# All edited variants of the four base scenarios, keyed by edit name.
scenario_dict = generate_scenario_dict(obs)

In [None]:
# Stack every edit's batch in scenario_dict insertion order (one block of base
# scenarios per edit) and render all frames.
all_obs = torch.cat([scenario_dict[scenario].clone() for scenario in scenario_dict.keys()], dim=0)
updated_images = get_images_batched(env, all_obs)

def post_observation_plot(images, facet_labels, facet_col_wrap=4):
    """Show a batch of rendered observations as a wrapped grid of facets.

    Args:
        images: image batch; its first axis becomes the facets.
        facet_labels: title for each facet, indexed by facet position.
        facet_col_wrap: number of facets per row.

    Returns:
        plotly figure.
    """
    # Label ticks 0..6 at the centre of each grid cell in a 224-px render
    # (7 cells of 32 px; centre = right edge - 16).
    cell_centres = np.linspace(0, 224, 8)[1:] - 16
    cell_names = [str(i) for i in range(7)]

    fig = px.imshow(images, facet_col=0, facet_col_wrap=facet_col_wrap)
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(tickmode='array', tickvals=cell_centres, ticktext=cell_names)

    def _retitle(annotation):
        # facet titles end in "=<facet index>"
        facet_idx = int(annotation.text.split("=")[-1])
        annotation.update(text=facet_labels[facet_idx])

    fig.for_each_annotation(_retitle)
    fig.for_each_annotation(lambda annotation: annotation.update(font=dict(size=18)))
    fig.update_layout(font=dict(size=24), width=1200)
    return fig

# Show scenario 0 of each edit below. updated_images stacks the scenario_dict
# entries in insertion order, one block of n_scenarios rows per edit, so the
# rows are looked up by key. The old hard-coded index 3 pointed at scenario 3
# of "Original", not at "Instruction Flipped".
facet_labels = [
    "Original",
    "Instruction Flipped",
    "Complement (0,5)",
    "Complement (4,2)",
    "Complement (0,5), (4,2)",
]
scenario_names = list(scenario_dict.keys())
n_scenarios = scenario_dict["Original"].shape[0]
idxes = [scenario_names.index(label) * n_scenarios for label in facet_labels]

# frame 4 = first non-padding observation
fig = post_observation_plot(updated_images[idxes,...][:,4], facet_labels, facet_col_wrap=5)
fig.show()

In [None]:
# Rows of updated_images follow scenario_dict insertion order, one block of
# n_scenarios rows per edit, so the row of scenario 0 is looked up by key.
# The old hard-coded 22 pointed into "Instruction Complement (0,5) (S9)".
scenario_names = list(scenario_dict.keys())
n_scenarios = scenario_dict["Original"].shape[0]

# Frame 4: original vs. targets swapped at S5.
idxes = [0, scenario_names.index("Target Complement (S5)") * n_scenarios]

facet_labels = [
    "Original_S5",
    "target_complement_S5",
]

fig = post_observation_plot(updated_images[idxes,...][:,4], facet_labels, facet_col_wrap=2)
fig.show()

# Frame 8: original vs. targets swapped at S9.
idxes = [0, scenario_names.index("Target Complement (S9)") * n_scenarios]

facet_labels = [
    "Original_S9",
    "target_complement_S9",
]

fig = post_observation_plot(updated_images[idxes,...][:,8], facet_labels, facet_col_wrap=2)
fig.show()


In [None]:
# Every edit once, at frame 4 (first non-padding frame): [::4] picks scenario 0
# of each edit, since each edit contributes 4 consecutive rows.

fig = post_observation_plot(updated_images[::4,4], list(scenario_dict.keys()))
# make the figure larger
fig.update_layout(height =1200, width=1200)
# reduce facet col font size
fig.for_each_annotation(lambda a: a.update(font=dict(size=12)))
fig.show()

In [None]:
# Every edit once, at frame 8 (the last frame): [::4] picks scenario 0 of each
# edit, since each edit contributes 4 consecutive rows.

fig = post_observation_plot(updated_images[::4,8], list(scenario_dict.keys()))
# make the figure larger
fig.update_layout(height =1200, width=1200)
# reduce facet col font size
fig.for_each_annotation(lambda a: a.update(font=dict(size=12)))
fig.show()

### Generate Activation Cache and Plot Effectiveness of Changes

####  Instruction and Target Patches

In [None]:
import pandas as pd
from src.streamlit_app.constants import IDX_TO_ACTION

def get_predictions_for_scenarios(dt, scenario_dict, edit_descriptions, rtg = 0.8920, actions=None, time=None):
    """Run ``dt`` on several edited observation batches and return action logits.

    The batches named in ``edit_descriptions`` are concatenated along dim 0.
    The same actions / timesteps are repeated once per edit, and a constant
    return-to-go is used.

    Args:
        dt: decision transformer exposing ``to_tokens``,
            ``transformer.run_with_cache`` and ``get_logits``.
        scenario_dict: mapping from edit name to an observation batch.
        edit_descriptions: keys of ``scenario_dict`` to evaluate, in order.
        rtg: constant return-to-go value for every timestep.
        actions: 3-D action tensor for ONE edit's batch; repeated along dim 0
            per edit. Defaults to the notebook-level ``actions`` global (the
            previous, implicit behaviour).
        time: 3-D timestep tensor for ONE edit's batch. Defaults to the
            notebook-level ``time`` global.

    Returns:
        (action_preds, tokens, cache)
    """
    # Previously these were silently read from notebook globals. They are now
    # explicit, with the globals kept as a fallback so existing calls still work.
    if actions is None:
        actions = globals()["actions"]
    if time is None:
        time = globals()["time"]

    # stack all the obs together
    all_obs = torch.cat([scenario_dict[scenario].clone() for scenario in edit_descriptions], dim=0)

    # constant RTG, shaped like one edit's batch: (batch, seq, 1)
    rtg = torch.ones(actions.shape[0], time.shape[1]).unsqueeze(-1) * rtg

    # convert to tokens, repeating the shared inputs once per edit
    num_edits = len(edit_descriptions)
    tokens = dt.to_tokens(
        all_obs, 
        actions.repeat(num_edits,1,1),
        rtg.repeat(num_edits,1,1), 
        time.repeat(num_edits,1,1))

    # run transformer
    x, cache = dt.transformer.run_with_cache(tokens, remove_batch_dim=False)

    # get preds
    _, action_preds, _ = dt.get_logits(
        x, 
        batch_size=all_obs.shape[0], 
        seq_length=all_obs.shape[1],
        no_actions=False,  # we always pad now.
    ) # internal method that sometimes gets different args so we need to know them .

    return action_preds, tokens, cache

def compile_action_preferences_df(
        action_preds, 
        edit_descriptions, 
        scenario_labels,
        reference_control="Original",
        reference_test="Instruction Flipped"
    ):
    """Summarise final-step action logits as one row per (edit, scenario).

    Rows of ``action_preds`` are taken to be edit-major: every scenario for the
    first edit, then every scenario for the next. This is how
    ``get_predictions_for_scenarios`` stacks them. Scenario labels are split as
    ``"<instruction>, <target>"``. The correct action is "left" when the
    instruction equals the part of the target before the first "-", and
    "right" otherwise.

    ``percent_change`` rescales ``left_minus_right`` per scenario so that the
    ``reference_control`` edit maps to 0 and the ``reference_test`` edit maps to 1.

    Returns a DataFrame with columns scenario, Edit, instruction, target,
    correct_action, left_minus_right and percent_change.
    """
    n_scenarios = len(scenario_labels)

    summary = pd.DataFrame(action_preds[:, -1, :].detach(), columns=IDX_TO_ACTION.values())
    summary["Edit"] = [edit for edit in edit_descriptions for _ in range(n_scenarios)]
    summary["scenario"] = scenario_labels * len(edit_descriptions)
    summary["left_minus_right"] = summary["left"] - summary["right"]
    summary["instruction"] = summary["scenario"].map(lambda label: label.split(", ")[0])
    summary["target"] = summary["scenario"].map(lambda label: label.split(", ")[1])

    def _correct_action(row):
        # "left" when the instruction names the left-hand object of the target pair
        return "left" if row["instruction"] == row["target"].split("-")[0] else "right"

    summary["correct_action"] = summary.apply(_correct_action, axis=1)

    # Attach the control / test left_minus_right of each scenario as extra columns.
    for reference_edit, suffix in ((reference_control, "_control"), (reference_test, "_test")):
        reference_lmr = (
            summary[summary["Edit"] == reference_edit]
            .groupby("scenario")["left_minus_right"]
            .first()
            .reset_index()
        )
        summary = pd.merge(summary, reference_lmr, on="scenario", how="left", suffixes=("", suffix))

    control = summary["left_minus_right_control"]
    summary["percent_change"] = (summary["left_minus_right"] - control) / (
        summary["left_minus_right_test"] - control
    )

    return summary[
        ["scenario", "Edit", "instruction", "target", "correct_action", "left_minus_right", "percent_change"]
    ]

def adversarial_experiment(dt, scenario_dict, edit_descriptions, reference_control, reference_test):
    """Compare action preferences across scenario edits at a high and a low RTG.

    Runs ``get_predictions_for_scenarios`` at RTG 0.8920 and at RTG 0.000. Each
    run is summarised with ``compile_action_preferences_df``, which uses the
    notebook-level ``scenario_labels``. The two summaries are then stacked,
    high RTG first.

    Args:
        dt: the decision transformer to probe.
        scenario_dict: maps edit description -> stacked observations for that edit.
        edit_descriptions: the edits to run, in order.
        reference_control: edit that defines 0 for ``percent_change``.
        reference_test: edit that defines 1 for ``percent_change``.

    Returns:
        (logits_df, tokens).

        ``logits_df`` has one row per (RTG, edit, scenario). The RTG label is
        appended to ``scenario``, and ``Edit`` is an ordered categorical.

        ``tokens`` is the high-RTG token batch concatenated with the low-RTG
        batch along dim 0, so its rows line up with ``logits_df``.
    """
    frames = []
    token_batches = []
    for rtg_value, rtg_label in ((0.8920, "0.892"), (0.000, "0.000")):
        action_preds, rtg_tokens, _ = get_predictions_for_scenarios(
            dt, scenario_dict, edit_descriptions, rtg=rtg_value
        )
        frame = compile_action_preferences_df(
            action_preds, edit_descriptions, scenario_labels,
            reference_control=reference_control,
            reference_test=reference_test,
        )
        frame["RTG"] = rtg_label
        frames.append(frame)
        token_batches.append(rtg_tokens)

    logits_df = pd.concat(frames)
    logits_df["scenario"] = logits_df["scenario"] + ", " + logits_df["RTG"]

    # canonical display order of the edits
    edit_order = [
        "Original",
        "Complement (0,5)",
        "Complement (4,2)",
        "Complement (0,5), (4,2)",
        "Instruction Flipped",
        "Target Complement (S5)",
        "Target Complement (S9)",
        "Target Complement (S5, S9)",
        "Instruction Complement (0,5) (S9)",
        "Instruction Complement (4,2) (S9)",
        "Instruction Complement (0,5), (4,2) (S9)",
    ]
    # pd.Categorical replaces values missing from its categories with NaN, so
    # append any edit not in the canonical list instead of silently losing its label.
    edit_order += [
        edit for edit in dict.fromkeys(edit_descriptions) if edit not in edit_order
    ]
    logits_df["Edit"] = pd.Categorical(logits_df["Edit"], edit_order)

    tokens = torch.cat(token_batches, dim=0)
    return logits_df, tokens

def get_bar_chart(
        logits_df, 
        scenario_labels,
        metric = "percent_change"):
    """Bar chart of ``metric`` for each edit, with one facet per scenario.

    Values are rounded to 2 decimals and printed on the bars. Axis titles and
    tick labels are hidden. ``scenario_labels`` is accepted for call-site
    compatibility but is not used.
    """
    palette = [
        "#0068c9", "#83c9ff", "#ff2b2b", "#ffabab", "#29b09d",
        "#7defa1", "#ff8700", "#ffd16a", "#6d3fc0", "#d5dae5",
    ]
    figure = px.bar(
        logits_df.round(2),
        x="Edit",
        y=metric,
        color="Edit",
        facet_col="scenario",
        facet_col_wrap=4,
        barmode="group",
        text_auto=".2f",
        template="plotly+presentation",
        color_discrete_sequence=palette,
    )
    # facet titles arrive as "scenario=<value>"; keep only the value
    figure.for_each_annotation(lambda ann: ann.update(text=ann.text.split("=")[-1]))
    # hide axis titles and ticks; y-axis settings end up on the right-hand side
    figure.update_xaxes(title_text="", showticklabels=False)
    figure.update_yaxes(title_text="", showticklabels=False, tickangle=0, side="right")
    # wide, borderless bars
    figure.update_traces(marker_line_width=0, width=0.8)
    figure.update_layout(font=dict(size=24), height=800, width=1600)
    return figure

# Instruction adversaries: complement the instruction at (0,5), at (4,2), or both.
# percent_change treats "Original" as 0 and "Instruction Flipped" as 1.
edit_descriptions = [
    "Original",
    "Complement (0,5)",
    "Complement (4,2)",
    "Complement (0,5), (4,2)",
    "Instruction Flipped",
]

logits_df, tokens = adversarial_experiment(dt, scenario_dict, edit_descriptions, reference_control="Original", reference_test="Instruction Flipped")
# drop the control rows; their percent_change is trivially 0
logits_df_tmp = logits_df[logits_df["Edit"] != "Original"]
# fig = get_bar_chart(logits_df_tmp, scenario_labels, metric="left_minus_right")
fig = get_bar_chart(logits_df_tmp, scenario_labels)
fig.show()

# Target adversaries: complement the target at S5, at S9, or both.
# percent_change treats "Original" as 0 and "Target Complement (S5, S9)" as 1.
edit_descriptions = [
    "Original",
    "Target Complement (S5)",
    "Target Complement (S9)",
    "Target Complement (S5, S9)",
    ]

logits_df, tokens = adversarial_experiment(dt, scenario_dict, edit_descriptions, reference_control="Original", reference_test="Target Complement (S5, S9)")
logits_df_tmp = logits_df[logits_df["Edit"] != "Original"]
# fig = get_bar_chart(logits_df_tmp, scenario_labels, metric="left_minus_right")
fig = get_bar_chart(logits_df_tmp, scenario_labels)
fig.show()

In [None]:
logits_df  # display the target-complement results (RTG appended to `scenario`)

#### Same Experiments with MemoryDT GatedMLP

In [None]:
# Same instruction and target experiments, run on the MemoryDT GatedMLP variant.
# `adversarial_experiment` returns `(logits_df, tokens)` and must be unpacked;
# otherwise `logits_df` is a tuple and `logits_df["Edit"]` raises TypeError.
edit_descriptions = [
    "Original",
    "Complement (0,5)",
    "Complement (4,2)",
    "Complement (0,5), (4,2)",
    "Instruction Flipped",
]


logits_df, tokens = adversarial_experiment(dt_gated_mlp, scenario_dict, edit_descriptions, reference_control="Original", reference_test="Instruction Flipped")
logits_df_tmp = logits_df[logits_df["Edit"] != "Original"]
fig = get_bar_chart(logits_df_tmp, scenario_labels)
fig.show()

# Select the scenarios to use
edit_descriptions = [
    "Original",
    "Target Complement (S5)",
    "Target Complement (S9)",
    "Target Complement (S5, S9)",
    ]

logits_df, tokens = adversarial_experiment(dt_gated_mlp, scenario_dict, edit_descriptions, reference_control="Original", reference_test="Target Complement (S5, S9)")
logits_df_tmp = logits_df[logits_df["Edit"] != "Original"]
fig = get_bar_chart(logits_df_tmp, scenario_labels)
fig.show()

## Intervening Directly on Principal Components (Features)



### Load in Features from App

Here we experiment with intervening directly on the principal components we found. We load these from the app.

In [390]:
from src.streamlit_app.constants import STATE_EMBEDDING_LABELS
from src.streamlit_app.features import load_features

# State-embedding weight matrix (the saved output below shows torch.Size([256, 980])).
embeddings = dt.state_embedding.weight.detach()
print(embeddings.shape)

# Saved features (one row per feature) plus the metadata table describing them.
features, feature_metadata = load_features("../features")
print(features.shape)
feature_metadata

torch.Size([256, 980])
torch.Size([16, 256])


Unnamed: 0,file_name,model,feature_names,generated_via,embeddings_idx,embeddings_labels,timestamp,feature_idx
0,target_features,MemoryDT,targets_at_end,PCA-Embeddings-Subsets,"[258, 307, 335, 286, 331, 282, 254, 303]","[key, (1,6), ball, (1,6), ball, (5,6), key, (5...",TODO: add timestamp,targets_at_end
1,target_features,MemoryDT,targets_at_beginning,PCA-Embeddings-Subsets,"[258, 307, 335, 286, 331, 282, 254, 303]","[key, (1,6), ball, (1,6), ball, (5,6), key, (5...",TODO: add timestamp,targets_at_beginning
2,target_features,MemoryDT,PC3,PCA-Embeddings-Subsets,"[258, 307, 335, 286, 331, 282, 254, 303]","[key, (1,6), ball, (1,6), ball, (5,6), key, (5...",TODO: add timestamp,PC3_target_features
3,target_features,MemoryDT,PC4,PCA-Embeddings-Subsets,"[258, 307, 335, 286, 331, 282, 254, 303]","[key, (1,6), ball, (1,6), ball, (5,6), key, (5...",TODO: add timestamp,PC4_target_features
4,target_features,MemoryDT,PC5,PCA-Embeddings-Subsets,"[258, 307, 335, 286, 331, 282, 254, 303]","[key, (1,6), ball, (1,6), ball, (5,6), key, (5...",TODO: add timestamp,PC5_target_features
5,target_features,MemoryDT,PC6,PCA-Embeddings-Subsets,"[258, 307, 335, 286, 331, 282, 254, 303]","[key, (1,6), ball, (1,6), ball, (5,6), key, (5...",TODO: add timestamp,PC6_target_features
6,target_features,MemoryDT,PC7,PCA-Embeddings-Subsets,"[258, 307, 335, 286, 331, 282, 254, 303]","[key, (1,6), ball, (1,6), ball, (5,6), key, (5...",TODO: add timestamp,PC7_target_features
7,target_features,MemoryDT,PC8,PCA-Embeddings-Subsets,"[258, 307, 335, 286, 331, 282, 254, 303]","[key, (1,6), ball, (1,6), ball, (5,6), key, (5...",TODO: add timestamp,PC8_target_features
8,instructions,MemoryDT,instruction_feature_1,PCA-Embeddings-Subsets,"[299, 250, 325, 276, 314, 265, 324, 275]","[ball, (0,5), key, (0,5), ball, (4,3), key, (4...",TODO: add timestamp,instruction_feature_1
9,instructions,MemoryDT,instruction_feature_2,PCA-Embeddings-Subsets,"[299, 250, 325, 276, 314, 265, 324, 275]","[ball, (0,5), key, (0,5), ball, (4,3), key, (4...",TODO: add timestamp,instruction_feature_2


In [391]:
# Gram matrix of the saved features. Features from the same PCA run should be
# orthogonal. Features from different runs (target_features vs instructions)
# need not be orthogonal to each other.
fig = px.imshow(features @ features.T, 
          template="plotly_white", 
          color_continuous_midpoint=0, 
          color_continuous_scale="RdBu") # check orthogonality (PCA components)

fig.update_layout(width=1200, height=1200)
fig.show()

In [392]:
# Look up the instruction and target features by name, then project each one onto the
# normalized state embeddings. The result has one row per STATE_EMBEDDING_LABELS entry.
# NOTE(review): `normalized_embeddings` is not defined in this chunk — presumably a
# normalized copy of `embeddings` from an earlier cell; confirm.
# NOTE(review): assumes the feature_metadata index equals the row of `features`; confirm.

instruction_0_idx = feature_metadata[feature_metadata["feature_names"] == "instruction_feature_1"].index[0]
instruction_1_idx = feature_metadata[feature_metadata["feature_names"] == "instruction_feature_2"].index[0]
target_0_idx = feature_metadata[feature_metadata["feature_names"] == "targets_at_end"].index[0]
target_1_idx = feature_metadata[feature_metadata["feature_names"] == "targets_at_beginning"].index[0]

# named aliases reused by later cells
instruction_feature_1 = features[instruction_0_idx]
instruction_feature_2 = features[instruction_1_idx]
targets_at_end = features[target_0_idx]
targets_at_beginning = features[target_1_idx]

# "channel" is the text before the first comma of each embedding label
df = pd.DataFrame(
    {   "label": STATE_EMBEDDING_LABELS,
        "channel": [i.split(",")[0] for i in STATE_EMBEDDING_LABELS],
        "instruction_0": features[instruction_0_idx] @ normalized_embeddings,
        "instruction_1": features[instruction_1_idx] @ normalized_embeddings,
        "target_0": features[target_0_idx] @ normalized_embeddings,
        "target_1": features[target_1_idx] @ normalized_embeddings,
    }
)

In [393]:
# Distribution of each feature's projection across all embedding columns.
fig = px.strip(df, y=["instruction_0", "instruction_1", "target_0", "target_1"], hover_data=["label"], template = "plotly_white")
# update figure size
fig.update_layout(height=800, width=1000)
fig.show()

In [394]:
# Instruction feature 1 vs instruction feature 2 projection per embedding column,
# coloured by channel.
fig = px.scatter(
    df, 
    x= "instruction_0",
    y="instruction_1",
    color="channel",
    hover_data=["label"],
    template="plotly",
)
fig.update_layout(height=800, width=1000)
fig.show()

### Distribution of PCs in Adversarial Examples

In [395]:
# Rerun the instruction adversaries and measure the instruction features at token 13.
edit_descriptions = [
    "Original",
    "Complement (0,5)",
    "Complement (4,2)",
    "Complement (0,5), (4,2)",
    "Instruction Flipped",
]

logits_df, tokens = adversarial_experiment(dt, scenario_dict, edit_descriptions, reference_control="Original", reference_test="Instruction Flipped")

# Token 13 of every row (called S5 in later cells). Rows of `tokens` follow logits_df order.
token_embeddings = tokens[:,13,:].detach()


# same aliases as in the earlier cell; re-bound so this cell can run on its own
instruction_feature_1 = features[instruction_0_idx]
instruction_feature_2 = features[instruction_1_idx]
targets_at_end = features[target_0_idx]
targets_at_beginning = features[target_1_idx]


df = pd.DataFrame(
    { 
        "Instruction PC0": instruction_feature_1 @ token_embeddings.T,
        "Instruction PC1": instruction_feature_2 @ token_embeddings.T,
    },
    index = logits_df.index
)

# attach the projections as extra columns next to the logit summaries
logits_df = pd.concat([logits_df, df], axis=1)

def plot_feature_against_logit_difference(logits_df, feature="Instruction PC0"):
    """Scatter a feature projection against the left-minus-right logit difference.

    There is one facet per ``scenario`` (four per row), and points are coloured by
    ``Edit``. The x-axis is titled "Feature Projection" whichever column
    ``feature`` names.
    """
    palette = [
        "#0068c9", "#83c9ff", "#ff2b2b", "#ffabab", "#29b09d",
        "#7defa1", "#ff8700", "#ffd16a", "#6d3fc0", "#d5dae5",
    ]
    figure = px.scatter(
        logits_df,
        x=feature,
        y="left_minus_right",
        color="Edit",
        facet_col="scenario",
        facet_col_wrap=4,
        template="plotly+presentation",
        color_discrete_sequence=palette,
        labels={feature: "Feature Projection"},
    )
    # large markers for presentation
    figure.update_traces(marker=dict(size=26))
    # facet titles arrive as "scenario=<value>"; keep only the value
    figure.for_each_annotation(lambda ann: ann.update(text=ann.text.split("=")[-1]))
    figure.update_layout(font=dict(size=20), height=800, width=1600)
    return figure

# Instruction PC0 projection (token 13) vs the left-minus-right logit difference.
fig = plot_feature_against_logit_difference(logits_df, feature="Instruction PC0")
fig.show()

This is a great result. As we add adversaries, the feature flips, and as it flips, the left − right logit difference changes almost exactly linearly. The same can't be said for PC1, which doesn't have a first-order effect that can be easily visualized. 

In [396]:
# Instruction PC1 projection (token 13) vs the left-minus-right logit difference.
fig = plot_feature_against_logit_difference(logits_df, feature="Instruction PC1")
fig.show()

Instruction PC1 doesn't vary as obviously with the adversaries, so the adversarial examples aren't as useful for studying PC1. 

In [397]:
# Target adversaries: measure the target features at token 13 (S5) and at the last token (S9).
edit_descriptions = [
    "Original",
    "Target Complement (S5)",
    "Target Complement (S9)",
    "Target Complement (S5, S9)",
    ]

logits_df, tokens = adversarial_experiment(dt, scenario_dict, edit_descriptions, reference_control="Original", reference_test="Target Complement (S5, S9)")

token_embeddings_S5 = tokens[:,13,:].detach()
token_embeddings_S9 = tokens[:,-1,:].detach()

# I mislabelled the PCs when I saved them: `targets_at_beginning` is the S9 target
# feature and `targets_at_end` is the S5 target feature.
df = pd.DataFrame(
    { 
        "S9 Target Feature at S5": targets_at_beginning @ token_embeddings_S5.T,
        "S9 Target Feature at S9": targets_at_beginning @ token_embeddings_S9.T,
        "S5 Target Feature at S5": targets_at_end @ token_embeddings_S5.T,
        "S5 Target Feature at S9": targets_at_end @ token_embeddings_S9.T,
    },
    index = logits_df.index
)

logits_df = pd.concat([logits_df, df], axis=1)

plot_feature_against_logit_difference(logits_df, feature="S5 Target Feature at S5").show()

Here we see that the S5 target feature does vary when we vary the targets. However, since we're looking at the projection of the S5 position of the residual stream, changes to S9 alone have no effect on the feature. 
On the other hand, changes to the S5 feature don't affect the logit difference very much, hence the square pattern.

In [398]:
plot_feature_against_logit_difference(logits_df, feature="S9 Target Feature at S9").show()  # S9 target feature at S9 vs logit diff

Here we see the S9 target feature and the left-minus-right logit difference for each of the target situations. Clearly, flipping the target at S9 flips the target feature AND changes the logit difference. 

In [399]:
# plot_feature_against_logit_difference(logits_df, feature="S9 Target Feature at S5").show() # S9 target feature is not effected by changes at S5

### Continuous Feature Intervention

In [400]:
import pandas as pd

def get_hook(position, coefficients, feature):
    '''Build a forward hook that pins one token's component along ``feature``.

    The hook removes the current component of ``resid_pre[:, position, :]``
    along the unit-normalised ``feature``. It then writes ``coefficients`` as
    the new component. Each patched embedding is appended to the returned
    list, so callers can inspect what was actually written.

    - position: int, the token index to patch (negative values count from the end).
    - coefficients: a (batch_size,) tensor with one target projection per batch
      row, or a scalar / 0-d tensor applied to every row. It is cast to the
      activation's device and dtype.
    - feature: (hidden_dim,) direction; it is normalised to unit length.

    Returns ``(hook, cached_embeddings)``. Register it as, e.g.,
    ``[("blocks.0.hook_resid_pre", hook)]``.
    '''
    cached_embeddings = []
    # normalize the feature once; the hook reuses it on every call
    feature = feature / torch.norm(feature)

    def hook(resid_pre, hook):
        '''Overwrite the feature component at ``position`` in place and return the activation.

        - resid_pre: (batch_size, seq_length, hidden_dim) residual stream.
        - hook: the hook-point object supplied by the framework. It is unused,
          but the parameter must keep this name because hook frameworks such as
          TransformerLens call ``fn(activation, hook=hook_point)``.
        '''
        embedding = resid_pre[:, position, :].clone()
        unit = feature.to(device=embedding.device, dtype=embedding.dtype)
        target = torch.as_tensor(coefficients, device=embedding.device, dtype=embedding.dtype)
        if target.ndim == 0:
            # a single coefficient applies to every batch row
            target = target.expand(embedding.shape[0])

        current = embedding @ unit  # (batch_size,) current projection onto the feature
        # drop the current component, then add the requested one
        embedding = embedding - current[:, None] * unit + target[:, None] * unit

        cached_embeddings.append(embedding.detach())
        resid_pre[:, position, :] = embedding
        # returning the (in-place patched) activation works whether the framework
        # keeps in-place edits or uses the returned value
        return resid_pre

    return hook, cached_embeddings


def get_predictions_for_scenarios_with_hooks(
        dt, 
        scenario_dict, 
        edit_descriptions, 
        editing_hooks,
        rtg = 0.8920):
    '''Run ``dt`` on the stacked scenario edits with forward hooks active.

    The inputs are built exactly like ``get_predictions_for_scenarios``:
    observations from ``scenario_dict`` are concatenated in
    ``edit_descriptions`` order, the notebook-level ``actions`` / ``time``
    tensors are tiled per edit, and a constant ``rtg`` is used. The forward
    pass then runs inside ``dt.transformer.hooks(fwd_hooks=editing_hooks)``.

    Example::

        ave_hook, cached_embeddings = get_hook(
            position=13,
            coefficients=torch.ones(batch_size),  # or a single scalar
            feature=features[0],
        )
        editing_hooks = [("blocks.0.hook_resid_pre", ave_hook)]

    Returns:
        (action_preds, tokens, cache). A clone of ``tokens`` is fed to the
        transformer, so in-place edits made by hooks do not reach the
        returned ``tokens``.
    '''

    # stack all the obs together
    all_obs = torch.cat([scenario_dict[scenario].clone() for scenario in edit_descriptions], dim=0)

    # set RTG: constant value with a trailing feature dim of 1
    rtg = torch.ones(actions.shape[0], time.shape[1]).unsqueeze(-1) * rtg

    # convert to tokens 
    num_edits = len(edit_descriptions)
    tokens = dt.to_tokens(
        all_obs, 
        actions.repeat(num_edits,1,1),
        rtg.repeat(num_edits,1,1), 
        time.repeat(num_edits,1,1))

    # run transformer with hooks
    with dt.transformer.hooks(fwd_hooks=editing_hooks):
        x, cache = dt.transformer.run_with_cache(tokens.clone(), remove_batch_dim=False)

    # get preds
    _, action_preds, _ = dt.get_logits(
        x, 
        batch_size=all_obs.shape[0], 
        seq_length=all_obs.shape[1],
        no_actions=False,  # we always pad now.
    ) # internal method that sometimes gets different args so we need to know them .

    return action_preds, tokens, cache


def calculate_percent_change(df, grouping_col, value_col, reference_control, reference_test):
    """Rescale ``value_col`` per group so the control edit is 0 and the test edit is 1.

    For each ``grouping_col`` group, the first ``value_col`` of the row whose
    ``Edit`` equals ``reference_control`` is merged in as ``<value_col>_control``.
    The row for ``reference_test`` is merged in the same way as
    ``<value_col>_test``. Then
    ``percent_change = (value - control) / (test - control)``.

    Groups without a matching reference row get NaN. The input frame is not
    modified; a new frame with a fresh RangeIndex is returned.
    """
    result = df
    for reference_edit, suffix in ((reference_control, "_control"), (reference_test, "_test")):
        reference_values = (
            result[result["Edit"] == reference_edit]
            .groupby(grouping_col)[value_col]
            .first()
            .reset_index()
        )
        result = pd.merge(result, reference_values, on=grouping_col, how="left", suffixes=("", suffix))

    baseline = result[f"{value_col}_control"]
    span = result[f"{value_col}_test"] - baseline
    result["percent_change"] = (result[value_col] - baseline) / span
    return result


def feature_range_experiment(dt, scenario_dict, feature, scenario_labels, injection_position = 13, injection_magnitudes = torch.tensor([-2,-1,0,1,2])):
    """Sweep one token's component along ``feature`` and record action preferences.

    The "Original" scenarios are repeated once per entry of ``injection_magnitudes``.
    A ``get_hook`` hook on ``blocks.0.hook_resid_pre`` sets the component of the
    token at ``injection_position`` along the unit-normalised ``feature`` to the
    matching magnitude. The sweep is run at RTG 0.8920 and at RTG 0.000.

    Args:
        dt: decision transformer to probe.
        scenario_dict: must contain "Original", with one row per scenario.
        feature: (hidden_dim,) direction to intervene on; it need not be unit length.
        scenario_labels: one label per scenario row (see
            ``compile_action_preferences_df``).
        injection_position: token index to patch. This notebook uses 13 for S5
            and -1 for the last token, S9.
        injection_magnitudes: 1-D tensor of projections to write.

    Returns:
        DataFrame with one row per (RTG, magnitude, scenario) and a reset index.

        - ``Edit`` holds the injected magnitude.
        - ``scenario`` has the RTG label appended.
        - ``feature_before`` / ``feature_after`` are the token's projections onto
          the unit feature before and after the hook, so ``feature_after``
          matches ``Edit`` up to float error.
        - ``percent_change`` is NaN here (see below); use the ``normalize`` option
          of the feature_experiment_* plotting helpers instead.
    """
    n_scenarios = len(scenario_labels)
    edit_descriptions = ["Original"] * injection_magnitudes.shape[0]
    # Rows are magnitude-major (every scenario for the first magnitude, then the
    # next), so each magnitude is repeated once per scenario: [-2,-2,-2,-2,-1,...].
    row_magnitudes = injection_magnitudes.repeat_interleave(n_scenarios)
    # the hook normalises the feature internally; measure in the same unit basis
    unit_feature = feature / torch.norm(feature)

    ave_hook, cached_embeddings = get_hook(
        position=injection_position,
        coefficients=row_magnitudes,
        feature=feature,
    )
    editing_hooks = [("blocks.0.hook_resid_pre", ave_hook)]

    def _run_at_rtg(rtg_value, rtg_label):
        # One hooked forward pass at a single RTG, summarised as a DataFrame.
        action_preds, tokens, _ = get_predictions_for_scenarios_with_hooks(
            dt, scenario_dict, edit_descriptions, editing_hooks, rtg=rtg_value
        )
        # No row has Edit "-2"/"2" (all rows are "Original"), so percent_change is
        # NaN; kept for column compatibility with compile_action_preferences_df.
        frame = compile_action_preferences_df(
            action_preds, edit_descriptions, scenario_labels,
            reference_control="-2",
            reference_test="2",
        )
        frame["feature_before"] = unit_feature @ tokens[:, injection_position].detach().T
        # presumably the hook fires once per forward pass, so the last cached entry
        # is this run's patched embedding
        frame["feature_after"] = unit_feature @ cached_embeddings.pop().detach().T
        frame["RTG"] = rtg_label
        return frame

    # combine the two runs, high RTG first
    logits_df = pd.concat([_run_at_rtg(0.8920, "0.892"), _run_at_rtg(0.000, "0.000")])
    logits_df["Edit"] = row_magnitudes.repeat(2)
    logits_df["scenario"] = logits_df["scenario"] + ", " + logits_df["RTG"]
    logits_df.reset_index(inplace=True, drop=True)

    return logits_df

def feature_experiment_bar_chart(
        logits_df,
        normalize = False,
        negative_control = None, 
        positive_control = None,
        facet_col_wrap = 4,
        ):
    """Bar chart of a feature-sweep result, with one facet per scenario.

    By default it plots ``left_minus_right`` for each ``Edit``. With
    ``normalize=True`` it plots ``percent_change`` instead, computed by
    ``calculate_percent_change`` with ``negative_control`` mapped to 0 and
    ``positive_control`` mapped to 1.
    """
    if normalize:
        plotted = calculate_percent_change(logits_df, "scenario", "left_minus_right", negative_control, positive_control)
        y_column = "percent_change"
    else:
        plotted = logits_df
        y_column = "left_minus_right"

    palette = [
        "#0068c9", "#83c9ff", "#ff2b2b", "#ffabab", "#29b09d",
        "#7defa1", "#ff8700", "#ffd16a", "#6d3fc0", "#d5dae5",
    ]
    figure = px.bar(
        plotted.round(2),
        x="Edit",
        y=y_column,
        color="Edit",
        facet_col="scenario",
        facet_col_wrap=facet_col_wrap,
        barmode="group",
        text_auto=".2f",
        template="plotly+presentation",
        color_discrete_sequence=palette,
    )
    # strip the "scenario=" prefix from facet titles and set their font size to 24
    figure.for_each_annotation(
        lambda ann: ann.update(text=ann.text.split("=")[-1], font=dict(size=24))
    )
    # hide axis titles and ticks; y-axis settings end up on the right-hand side
    figure.update_xaxes(title_text="", showticklabels=False)
    figure.update_yaxes(title_text="", showticklabels=False, tickangle=0, side="right")
    # wide, borderless bars
    figure.update_traces(marker_line_width=0, width=0.8)
    figure.update_layout(height=800, width=1600)
    return figure

def feature_experiment_scatter_chart(
        logits_df,
        normalize = False,
        negative_control = None, 
        positive_control = None,
        facet_col_wrap = 4,
        ):
    """Scatter chart of a feature-sweep result, with one facet per scenario.

    By default it plots ``left_minus_right`` against ``Edit`` (the injected
    magnitude), coloured by ``Edit`` on the RdBu scale. With ``normalize=True``
    it plots ``percent_change`` instead (see ``calculate_percent_change``).
    """
    if normalize:
        plotted = calculate_percent_change(logits_df, "scenario", "left_minus_right", negative_control, positive_control)
        y_column = "percent_change"
    else:
        plotted = logits_df
        y_column = "left_minus_right"

    figure = px.scatter(
        plotted.round(2),
        x="Edit",
        y=y_column,
        color="Edit",
        facet_col="scenario",
        facet_col_wrap=facet_col_wrap,
        template="plotly+presentation",
        color_continuous_scale="RdBu",
    )
    # large markers with a thin black outline
    figure.update_traces(marker=dict(size=26, line=dict(width=1, color="black")))
    # facet titles arrive as "scenario=<value>"; keep only the value
    figure.for_each_annotation(lambda ann: ann.update(text=ann.text.split("=")[-1]))
    figure.update_layout(font=dict(size=20), height=800, width=1600)
    return figure

In [401]:
# Sweep each instruction feature, and a hand-weighted mix of the two, at token 13 (S5).
logits_df_instruction_feature_0 = feature_range_experiment(
    dt, scenario_dict, feature = instruction_feature_1, 
    injection_position=13,
    scenario_labels = scenario_labels)

logits_df_instruction_feature_1 = feature_range_experiment(
    dt, scenario_dict, feature = instruction_feature_2, 
    injection_position=13,
    scenario_labels = scenario_labels)

# mixing weights 0.70 / 0.15 are chosen by hand here
logits_df_both_instruction_features = feature_range_experiment(
    dt, scenario_dict, feature = instruction_feature_1 * 0.70 + instruction_feature_2*0.15,
    injection_position=13,
    scenario_labels = scenario_labels)

In [402]:
# Raw left-minus-right for the S5 instruction sweeps; the alternatives are left commented out.
feature_experiment_scatter_chart(logits_df_instruction_feature_0, normalize = False).show()
feature_experiment_scatter_chart(logits_df_instruction_feature_1, normalize = False).show()
# feature_experiment_scatter_chart(logits_df_both_instruction_features, normalize = False).show()

# feature_experiment_bar_chart(logits_df_instruction_feature_0, normalize = False, positive_control=2,negative_control=-2).show()
# feature_experiment_bar_chart(logits_df_instruction_feature_1, normalize = False, positive_control=2,negative_control=-2).show()
# feature_experiment_bar_chart(logits_df_both_instruction_features, normalize = False, positive_control=2,negative_control=-2).show()

In [403]:
# Sweep instruction feature 1 at the last token (S9) instead of S5.
logits_df_instruction_feature_0_S9 = feature_range_experiment(
    dt, scenario_dict, feature = instruction_feature_1, 
    injection_position=-1,
    scenario_labels = scenario_labels)

feature_experiment_scatter_chart(logits_df_instruction_feature_0_S9, normalize = False).show()

In [404]:

# Sweep instruction feature 2 at the last token (S9).
logits_df_instruction_feature_1_S9 = feature_range_experiment(
    dt, scenario_dict, feature = instruction_feature_2, 
    injection_position=-1,
    scenario_labels = scenario_labels)

feature_experiment_scatter_chart(logits_df_instruction_feature_1_S9, normalize = False).show()

In [405]:
# `targets_at_end` is really the S5 target feature (the names were swapped when the
# features were saved). Sweep it at S5 (token 13) and at S9 (last token).
logits_df_target_feature_0_S5 = feature_range_experiment(
    dt, scenario_dict, feature = targets_at_end, # actually the S5 target feature
    injection_position=13,
    scenario_labels = scenario_labels)

feature_experiment_scatter_chart(logits_df_target_feature_0_S5, normalize = False).show()

logits_df_target_feature_0_S9 = feature_range_experiment(
    dt, scenario_dict, feature = targets_at_end, # actually the S5 target feature
    injection_position=-1,
    scenario_labels = scenario_labels)

feature_experiment_scatter_chart(logits_df_target_feature_0_S9, normalize = False).show()

In [406]:

# `targets_at_beginning` is really the S9 target feature (the names were swapped when
# the features were saved). Sweep it at S5 (token 13) and at S9 (last token).
logits_df_target_feature_1_S5 = feature_range_experiment(
    dt, scenario_dict, feature = targets_at_beginning, # actually the S9 target feature
    injection_position=13,
    scenario_labels = scenario_labels)

feature_experiment_scatter_chart(logits_df_target_feature_1_S5, normalize = False).show()


logits_df_target_feature_1_S9 = feature_range_experiment(
    dt, scenario_dict, feature = targets_at_beginning, # actually the S9 target feature
    injection_position=-1,
    scenario_labels = scenario_labels)

feature_experiment_scatter_chart(logits_df_target_feature_1_S9, normalize = False).show()