In [10]:
# Handy snippet to get repo root from anywhere in the repo.
# `git rev-parse --show-toplevel` prints the top-level directory of the
# enclosing git working tree; it raises CalledProcessError outside a repo.
import sys
from subprocess import check_output

# Pass the command as an argv list so no intermediate shell is spawned, and let
# subprocess decode the output as UTF-8 directly.
ROOT = check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip()

# Make repo-local modules importable; guarded so that re-running the cell does
# not append duplicate entries to sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [21]:
import torch as t
import numpy as np
import pandas as pd
import einops
import plotly.express as px

from utils import ntensor_to_long

In [30]:
# Read the cities statements table from the repo's datasets folder and
# preview its first rows.
cities_csv_path = f"{ROOT}/datasets/cities.csv"
df_data = pd.read_csv(cities_csv_path)
df_data.head()

Unnamed: 0,statement,label,city,country,correct_country
0,The city of Krasnodar is in Russia.,1,Krasnodar,Russia,Russia
1,The city of Krasnodar is in South Africa.,0,Krasnodar,South Africa,Russia
2,The city of Lodz is in Poland.,1,Lodz,Poland,Poland
3,The city of Lodz is in the Dominican Republic.,0,Lodz,the Dominican Republic,Poland
4,The city of Maracay is in Venezuela.,1,Maracay,Venezuela,Venezuela


In [16]:
# Load the saved activations and the saved directions, then print both shapes.
# The bare strings below are the author's axis-order annotations.
activations_path = f"{ROOT}/activations/llama2-7b_cities.pt"
directions_path = f"{ROOT}/directions/llama2-7b_cities_mm.pt"

activations = t.load(activations_path)
"shape: (statement layer pos d_model)"
directions = t.load(directions_path)
"shape: (layer pos d_model)"

shapes = (activations.shape, directions.shape)
print(*shapes)

torch.Size([1496, 32, 2, 4096]) torch.Size([32, 2, 4096])


In [96]:
# Sanity check: the previously saved directions should agree, within an
# absolute tolerance of 5e-4, with the last-position slice of the current
# directions.
# The file is addressed relative to ROOT, like every other artifact in this
# notebook, so the cell works regardless of the kernel's working directory.
# map_location="cpu" places the loaded tensor on the CPU at load time, so the
# load also succeeds on machines without the device it was saved from.
old_dirs = t.load(f"{ROOT}/directions/llama2-7b_cities_mm_old.pt", map_location="cpu")
t.allclose(old_dirs, directions[:, -1], atol=5e-4)

True

In [112]:
# Scale every (layer, pos) direction to unit L2 norm along d_model.
unit_directions = directions / directions.norm(dim=-1, keepdim=True)

# Dot each activation with the unit direction of the same layer and position,
# reordering the result to (pos, layer, statement).
raw_projections = einops.einsum(
    activations,
    unit_directions,
    "statement layer pos d_model, layer pos d_model -> pos layer statement"
)

# Standardise over the statement axis: subtract the mean and divide by the
# standard deviation separately for each (pos, layer) pair.
statement_mean = raw_projections.mean(dim=-1, keepdim=True)
statement_std = raw_projections.std(dim=-1, keepdim=True)
projections = (raw_projections - statement_mean) / statement_std
"shape: (pos layer statement)"

'shape: (pos layer statement)'

In [113]:
# Flatten the (pos, layer, statement) projections tensor into a long-format
# frame with a "projection" value column and one column per tensor axis.
df = ntensor_to_long(projections, "projection", ["pos", "layer", "statement"])
df["pos"] = df["pos"].map({0: "penultimate", 1: "final"})

# Repeat the per-statement labels once for every (pos, layer) pair. The repeat
# count comes from the tensor's own shape instead of a hard-coded 32 * 2, so
# the cell keeps working for a different number of layers or positions.
# NOTE(review): tiling assumes ntensor_to_long emits rows with `statement`
# varying fastest — confirm against utils.ntensor_to_long.
n_pos, n_layer, n_statement = projections.shape
assert len(df_data) == n_statement, (
    f"df_data has {len(df_data)} rows but projections has {n_statement} statements"
)
df["label"] = np.tile(df_data["label"].tolist(), n_pos * n_layer)
df["label"] = df["label"].map({0: False, 1: True})

In [153]:
# Animated scatter (one frame per layer) of the standardized projections at
# the final token, coloured by label.
final_rows = df.query("pos == 'final'")

# Symmetric y-range with 10% headroom over the largest absolute projection,
# fixed so that every animation frame shares the same scale.
ylims = 1.1 * final_rows.projection.abs().max()

fig_final = px.scatter(
    final_rows,
    x="statement",
    y="projection",
    color="label",
    animation_frame="layer",
    width=800,
    height=600,
)
fig_final.update_layout(
    title="Projection of resid_post onto factuality direction at the final token",
    xaxis_title="Cities Statement Index",
    yaxis_title="Projection (Standardized)",
    yaxis_range=[-ylims, ylims],
)

In [154]:
# Animated scatter (one frame per layer) of the standardized projections at
# the penultimate token, coloured by label.
penultimate_rows = df.query("pos == 'penultimate'")

# Symmetric y-range with 10% headroom over the largest absolute projection,
# fixed so that every animation frame shares the same scale.
ylims = 1.1 * penultimate_rows.projection.abs().max()

fig_penultimate = px.scatter(
    penultimate_rows,
    x="statement",
    y="projection",
    color="label",
    animation_frame="layer",
    width=800,
    height=600,
)
fig_penultimate.update_layout(
    title="Projection of resid_post onto factuality direction at the penultimate token",
    xaxis_title="Cities Statement Index",
    yaxis_title="Projection (Standardized)",
    yaxis_range=[-ylims, ylims],
)

In [155]:
# Write both projection scatter figures to HTML files under figs/.
for figure, html_path in (
    (fig_final, f"{ROOT}/figs/proj_final_llama2-7b_cities.html"),
    (fig_penultimate, f"{ROOT}/figs/proj_penultimate_llama2-7b_cities.html"),
):
    figure.write_html(html_path)

In [170]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Split the directions by token position (index 0 and index 1 on the pos
# axis), then scale each per-layer direction to unit L2 norm so that dot
# products between them are cosine similarities.
penult_dirs, final_dirs = directions[:, 0], directions[:, 1]
final_dirs_normed = final_dirs / final_dirs.norm(dim=-1, keepdim=True)
penult_dirs_normed = penult_dirs / penult_dirs.norm(dim=-1, keepdim=True)

In [188]:
# Three side-by-side heatmaps of pairwise dot products between the
# unit-normalised per-layer directions, all sharing one colour axis.
# Each entry: (row directions -> y axis, column directions -> x axis,
#              x-axis title, y-axis title).
heatmap_panels = [
    (final_dirs_normed, final_dirs_normed, "Layer @ Final Token", "Layer @ Final Token"),
    (penult_dirs_normed, penult_dirs_normed, "Layer @ Penultimate Token", "Layer @ Penultimate Token"),
    (final_dirs_normed, penult_dirs_normed, "Layer @ Penultimate Token", "Layer @ Final Token"),
]

fig = make_subplots(rows=1, cols=len(heatmap_panels))
for panel_col, (row_dirs, col_dirs, x_title, y_title) in enumerate(heatmap_panels, start=1):
    fig.add_trace(
        go.Heatmap(z=row_dirs @ col_dirs.T, coloraxis="coloraxis"),
        row=1,
        col=panel_col,
    )
    fig.update_xaxes(title_text=x_title, row=1, col=panel_col)
    fig.update_yaxes(title_text=y_title, row=1, col=panel_col)

# Shared diverging colour scale pinned to [-1, 1].
fig.update_layout(
    title_text="Cosine similarities of factuality directions",
    coloraxis=dict(colorscale='RdBu', cmin=-1, cmax=1),
    height=500, width=1300,
)

In [189]:
# Write the cosine-similarity heatmap figure to an HTML file under figs/.
fig.write_html(f"{ROOT}/figs/direction_cosims_llama2-7b_cities.html")