In [1]:
from transformers import AutoTokenizer
import torch
import numpy as np

# Hugging Face model id; only its tokenizer is loaded here.
model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'
# Used below to turn the saved token ids back into token strings for plot labels.
tokenizer = AutoTokenizer.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [112]:
# Model geometry used to build per-layer / per-head row labels for the AtP plots.
# NOTE(review): hard-coded; presumably matches the config of `model_name`
# (32 layers, 32 query heads, 8 shared key/value heads) — TODO confirm.
n_layers = 32
n_heads = 32
n_key_value_heads = 8

### Plotting

In [None]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_atp(atp, str_tokens, component, n_last_tokens=128, val=1, prepend_bos=True):
    """Heatmap of attribution-patching (AtP) scores for one hooked component.

    Row labels depend on `component`:
      * 'z', 'q', 'result'  -> one row per attention head  ("L<layer>H<head>")
      * 'k', 'v'            -> one row per key/value head  ("L<layer>K<head>" / "L<layer>V<head>")
      * anything else       -> one row per layer           ("L<layer> <COMPONENT>")

    Columns are the last `n_last_tokens` positions, labelled "token | i" where
    i counts from the start of that window (not from the start of the prompt).

    Args:
        atp: tensor indexed as atp[row, position]; moved to CPU/NumPy for plotting.
        str_tokens: token strings, aligned with atp's position axis.
        component: hook name selecting the row-label scheme (see above).
        n_last_tokens: number of trailing positions to show.
        val: the diverging 'RdBu' colour scale is clipped to [-val, val].
        prepend_bos: currently unused; kept so existing calls keep working.

    Returns:
        The plotly Express figure.
    """
    window = str_tokens[-n_last_tokens:]
    xs = [f"{tok} | {i}" for i, tok in enumerate(window)]

    per_head = component in ('z', 'q', 'result')
    per_kv_head = component in ('k', 'v')
    if per_head:
        ys = [f'L{i}H{j}' for i in range(n_layers) for j in range(n_heads)]
    elif per_kv_head:
        ys = [f'L{i}{component.upper()}{j}' for i in range(n_layers) for j in range(n_key_value_heads)]
    else:
        ys = [f"L{l} {component.upper()}" for l in range(n_layers)]

    scores = atp[:, -n_last_tokens:].cpu().numpy()
    return px.imshow(
        scores,
        x=xs,
        y=ys,
        color_continuous_scale='RdBu',
        zmin=-val,
        zmax=val,
        aspect='auto',
    )

In [None]:
def _stack_atp_plots(str_tokens, atps, components, titles, n_last_tokens, val, layout_kwargs):
    """Stack one `plot_atp` heatmap per component as rows of a shared-x figure.

    Args:
        str_tokens: token strings used for the x labels.
        atps: AtP tensors, one per row, in display order.
        components: hook name for each tensor (selects row labels in `plot_atp`).
        titles: subplot title for each row.
        n_last_tokens: number of trailing token positions to show.
        val: colour scale is clipped to [-val, val] in every row.
        layout_kwargs: dict of extra keyword arguments for `fig.update_layout`.

    Returns:
        The combined plotly Figure.
    """
    fig = make_subplots(rows=len(atps), cols=1, subplot_titles=titles, shared_xaxes=True, vertical_spacing=0.05)

    for row, (atp, component) in enumerate(zip(atps, components), start=1):
        # Pass `val` through so every panel is built on the same scale as the figure.
        panel = plot_atp(atp, str_tokens, component=component, prepend_bos=False,
                         n_last_tokens=n_last_tokens, val=val)
        # Only the traces are copied; the panel's own layout (incl. its colour axis)
        # is discarded, so the shared colour axis is configured once below.
        for trace in panel.data:
            fig.add_trace(trace, row=row, col=1)

    fig.update_layout(
        coloraxis1=dict(colorscale='RdBu', cmin=-val, cmax=val),
        showlegend=False,
        **layout_kwargs
    )

    return fig


def plot_qkv_atp(str_tokens, atp_q, atp_k, atp_v, n_last_tokens=256, val=1, **kwargs):
    """Plot query / key / value AtP scores as three stacked heatmaps.

    Extra keyword arguments (e.g. ``height``) are forwarded to ``fig.update_layout``.
    """
    return _stack_atp_plots(
        str_tokens,
        [atp_q, atp_k, atp_v],
        ['q', 'k', 'v'],
        ("Queries", "Keys", "Values"),
        n_last_tokens,
        val,
        kwargs,
    )


def plot_comp_atp(str_tokens, atp_rs, atp_mlp, atp_attn, n_last_tokens=256, val=1, **kwargs):
    """Plot residual-stream / MLP-output / attention-output AtP as three stacked heatmaps.

    Extra keyword arguments (e.g. ``height``) are forwarded to ``fig.update_layout``.
    """
    return _stack_atp_plots(
        str_tokens,
        [atp_rs, atp_mlp, atp_attn],
        ['resid_pre', 'mlp_out', 'attn_out'],
        ("Residual Stream", "MLPs", "Attention"),
        n_last_tokens,
        val,
        kwargs,
    )

In [176]:
def plot_patterns(str_tokens, patterns, patterns_label, n_cols, query_offset, key_offset):
    """Plot a grid of attention-pattern heatmaps, one subplot per labelled head.

    Args:
        str_tokens: token strings of the prompt; axis labels are "tok (position)",
            where position is the absolute index into `str_tokens`.
        patterns: sequence of 2-D (query, key) attention matrices, paired in order
            with `patterns_label` (zip stops at the shorter of the two).
        patterns_label: subplot titles, e.g. head names such as 'L12H5+'.
        n_cols: number of subplot columns; rows = ceil(len(patterns) / n_cols).
        query_offset: first query position shown (earlier rows are cropped).
        key_offset: first key position shown (earlier columns are cropped).

    Returns:
        A plotly Figure whose heatmaps share one 'Blues' colour axis on [0, 1].

    Raises:
        ValueError: if `patterns` is empty (there is nothing to lay out).
    """
    n_plots = len(patterns)
    if n_plots == 0:
        raise ValueError("plot_patterns: `patterns` is empty")
    n_rows = n_plots // n_cols + int(n_plots % n_cols != 0)  # ceil(n_plots / n_cols)
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=patterns_label, vertical_spacing=0.05, horizontal_spacing=0.05)

    labels = [f"{tok} ({i})" for i, tok in enumerate(str_tokens)]
    query_labels = labels[query_offset:]
    key_labels = labels[key_offset:]

    for i, (pattern, _label) in enumerate(zip(patterns, patterns_label)):
        row, col = divmod(i, n_cols)
        heatmap = px.imshow(
            pattern[query_offset:, key_offset:],
            labels=dict(x="Keys", y="Queries", color="Attention Score"),
            x=key_labels,
            y=query_labels,
        ).data[0]
        fig.add_trace(heatmap, row=row + 1, col=col + 1)

    # These settings are figure-wide, so apply them once rather than per subplot.
    fig.update_xaxes(tickangle=35)
    fig.update_layout(
        coloraxis=dict(
            colorscale='Blues',
            # Attention weights are presumably softmax probabilities in [0, 1];
            # cmin must be below cmax for the scale to map low -> light.
            cmin=0,
            cmax=1,
            colorbar=dict(title="Score"),
        ),
        height=400 * n_rows,
        width=500 * n_cols,
        title_text="Attention Patterns"
    )
    return fig

### Data

In [147]:
# Short model name: the final path component of the Hub id.
model_label = model_name.split('/')[-1]

# Which saved prompt pair under patches/ to analyse.
idx = 2

# v1 clean / corrupted token ids, loaded onto CPU.
clean_tokens_v1, corr_tokens_v1 = (
    torch.load(path, map_location=torch.device('cpu'))
    for path in (f'patches/{idx}_clean_tokens.bin', f'patches/{idx}_corr_tokens.bin')
)

# Decode ids to token strings, stripping the 'Ġ' character from every piece
# (presumably the byte-level BPE space marker).
clean_str_tokens_v1, corr_str_tokens_v1 = (
    [piece.replace('Ġ', '') for piece in tokenizer.convert_ids_to_tokens(ids)]
    for ids in (clean_tokens_v1, corr_tokens_v1)
)

# Same for the v2 prompt pair.
clean_tokens_v2, corr_tokens_v2 = (
    torch.load(path, map_location=torch.device('cpu'))
    for path in (f'patches/{idx}_clean_tokens_v2.bin', f'patches/{idx}_corr_tokens_v2.bin')
)

clean_str_tokens_v2, corr_str_tokens_v2 = (
    [piece.replace('Ġ', '') for piece in tokenizer.convert_ids_to_tokens(ids)]
    for ids in (clean_tokens_v2, corr_tokens_v2)
)

In [148]:
def load_patch(idx, comp, patch, map_location='cpu'):
    """Load a saved AtP result from ``patches/{idx}_s1_AtP_{comp}{patch}.bin``.

    Args:
        idx: prompt index used in the file name.
        comp: hooked component name, e.g. 'resid_pre', 'result', 'q'.
        patch: file-name suffix selecting the variant ('' for v1, '_v2' for v2).
        map_location: forwarded to ``torch.load``. Defaults to CPU, matching how the
            token files are loaded, so results saved from an accelerator can be
            opened on a CPU-only machine. Pass ``None`` to keep the saved device.

    Returns:
        The saved object, presumably a tensor indexed as [row, token position].
    """
    return torch.load(f'patches/{idx}_s1_AtP_{comp}{patch}.bin', map_location=map_location)

## Version 1

### Components

In [149]:
# Component-level AtP results (residual stream, MLP out, attention out) for v1.
atp_rs_v1, atp_mlp_v1, atp_attn_v1 = (
    load_patch(idx, component, '') for component in ('resid_pre', 'mlp_out', 'attn_out')
)

In [150]:
# Stacked component AtP heatmaps for v1; `height` is forwarded to update_layout.
fig = plot_comp_atp(clean_str_tokens_v1, atp_rs_v1, atp_mlp_v1, atp_attn_v1, val=1, height=1200)
fig

### Heads

In [151]:
atp_h_v1 = load_patch(idx, comp='result', patch='')  # per-head ('result' hook) AtP, v1

In [152]:
# Distribution of each head's peak |AtP| over token positions (v1). This uses the
# same statistic (absolute value) and cutoff (0.2) as the head selection below,
# so heads with strongly negative effects are visible too.
fig = px.histogram(atp_h_v1.abs().max(1).values.cpu().numpy())
fig.add_vline(x=0.2, line_dash="dash", line_color="red")
fig.show()

In [153]:
# One "L<layer>H<head>" label per attention head, layer-major; assumed to follow
# the row order of the per-head AtP tensor.
h_ys = [
    f'L{i}H{j}'
    for i in range(n_layers)
    for j in range(n_heads)
]

In [154]:
# Keep heads whose largest |AtP| over any token position exceeds the cutoff.
threshold = 0.2

h_mask = atp_h_v1.abs().max(1).values > threshold
h_ys_ = np.array(h_ys)[h_mask]

# Number of selected heads (shown as the cell output).
len(h_ys_)

18

In [155]:
# For each selected head, take the position of its largest |AtP| and suffix the
# label with the sign of the AtP value there ('+' if positive, '-' otherwise).
value_mask = atp_h_v1.abs().argmax(1)[h_mask]
heads_v1 = [
    label + ('+' if peak > 0 else '-')
    for label, peak in zip(h_ys_, atp_h_v1[h_mask, value_mask])
]

In [156]:
# Heatmap of per-token AtP for the heads selected above, over the last
# `n_last_tokens` positions of the clean v1 prompt.
n_last_tokens = 256
val = 1  # colour scale is clipped to [-val, val]
# Column labels "token | i", with i counted from the start of the shown window.
xs = [f"{tok} | {i}" for i, tok in enumerate(clean_str_tokens_v1[-n_last_tokens:])]

fig = px.imshow(
    atp_h_v1[h_mask, -n_last_tokens:].cpu().numpy(), 
    x=xs,
    y=h_ys_,
    color_continuous_scale='RdBu', zmin=-val, zmax=val, aspect='auto'
)

# Last expression of the cell, so its value is what the cell displays.
fig.update_layout(
    coloraxis1=dict(colorscale='RdBu', cmin=-val, cmax=val),
    showlegend=False,
    height = 800
)

### QKV

In [157]:
# Query / key / value AtP results for v1.
atp_q_v1, atp_k_v1, atp_v_v1 = (load_patch(idx, hook, '') for hook in ('q', 'k', 'v'))

In [158]:
# Peak |AtP| over positions for each query head / key head / value head (v1).
# Titles match the q/k/v tensors plotted; statistic and cutoff (0.2) match the
# masks computed below.
fig = make_subplots(rows=3, cols=1, subplot_titles=("Queries", "Keys", "Values"), shared_xaxes=True, vertical_spacing=0.05)

for i, atp in enumerate([atp_q_v1, atp_k_v1, atp_v_v1]):
    plot = px.histogram(atp.abs().max(1).values.cpu().numpy())
    for trace in plot.data:
        fig.add_trace(trace, row=1+i, col=1)

fig.add_vline(x=0.2, line_dash="dash", line_color="red")
fig.update_layout(height=600)
fig.show()

In [159]:
# Row labels: one per query head, and one per shared key / value head.
q_ys = [f'L{i}H{j}'
        for i in range(n_layers)
        for j in range(n_heads)]
k_ys = [f'L{i}K{j}'
        for i in range(n_layers)
        for j in range(n_key_value_heads)]
v_ys = [f'L{i}V{j}'
        for i in range(n_layers)
        for j in range(n_key_value_heads)]

In [160]:
# Keep only heads whose peak |AtP| over positions exceeds the cutoff.
threshold = 0.2

q_mask, k_mask, v_mask = (
    scores.abs().max(1).values > threshold
    for scores in (atp_q_v1, atp_k_v1, atp_v_v1)
)

In [161]:
# Row labels of the heads that pass each mask.
q_ys_ = np.array(q_ys)[q_mask]
k_ys_ = np.array(k_ys)[k_mask]
v_ys_ = np.array(v_ys)[v_mask]

# Counts of selected query / key / value heads (shown as the cell output).
len(q_ys_), len(k_ys_), len(v_ys_)

(10, 5, 9)

In [162]:
# Per-token AtP for the selected query / key / value heads of v1, one panel each.
fig = make_subplots(rows=3, cols=1, subplot_titles=("Queries", "Keys", "Values"), shared_xaxes=True, vertical_spacing=0.05)
n_last_tokens = 256
val = 1
xs = [f"{tok} | {i}" for i, tok in enumerate(clean_str_tokens_v1[-n_last_tokens:])]

for i, (atp, ys, mask) in enumerate(zip([atp_q_v1, atp_k_v1, atp_v_v1], [q_ys_, k_ys_, v_ys_], [q_mask, k_mask, v_mask])):
    # Keep only the masked rows and the last n_last_tokens positions.
    plot = px.imshow(
        atp[mask, -n_last_tokens:].cpu().numpy(), 
        x=xs,
        y=ys,
        color_continuous_scale='RdBu', zmin=-val, zmax=val, aspect='auto'
    )
    # Only the traces are copied; px's own layout is dropped, which is presumably
    # why the shared colour axis is configured again below.
    for trace in plot.data:
        fig.add_trace(trace, row=1+i, col=1)
    
fig.update_layout(
    coloraxis1=dict(colorscale='RdBu', cmin=-val, cmax=val),
    showlegend=False,
    height = 1200
)

### Patterns

In [163]:
# Attention patterns saved for the v1 clean / corrupted prompts. Mapped to CPU like
# the token loads, so the files open on a machine without the saving device.
clean_patterns_v1 = torch.load(f'patches/{idx}_patterns_clean_v1.bin', map_location=torch.device('cpu'))
corr_patterns_v1 = torch.load(f'patches/{idx}_patterns_corr_v1.bin', map_location=torch.device('cpu'))

In [178]:
# Corrupted-prompt attention patterns for the selected v1 heads. Token labels are
# cut to their last 10 characters; only positions from 400 on are shown.
plot_str_tokens = [token[-10:] for token in corr_str_tokens_v1]
plot_patterns(
    plot_str_tokens,
    corr_patterns_v1,
    heads_v1,
    n_cols=3,
    query_offset=400,
    key_offset=400,
)

## Version 2

### Components

In [181]:
# Component-level AtP results (residual stream, MLP out, attention out) for v2.
atp_rs_v2, atp_mlp_v2, atp_attn_v2 = (
    load_patch(idx, component, '_v2') for component in ('resid_pre', 'mlp_out', 'attn_out')
)

In [182]:
# Stacked component AtP heatmaps for v2; `height` is forwarded to update_layout.
fig = plot_comp_atp(clean_str_tokens_v2, atp_rs_v2, atp_mlp_v2, atp_attn_v2, val=1.5, height=1200)
fig

### Heads

In [183]:
atp_h_v2 = load_patch(idx, comp='result', patch='_v2')  # per-head ('result' hook) AtP, v2

In [184]:
# Distribution of each head's peak |AtP| over token positions (v2). This uses the
# same statistic (absolute value) and cutoff (0.2) as the head selection below,
# so heads with strongly negative effects are visible too.
fig = px.histogram(atp_h_v2.abs().max(1).values.cpu().numpy())
fig.add_vline(x=0.2, line_dash="dash", line_color="red")
fig.show()

In [185]:
# One "L<layer>H<head>" label per attention head, layer-major; assumed to follow
# the row order of the per-head AtP tensor.
h_ys = [
    f'L{i}H{j}'
    for i in range(n_layers)
    for j in range(n_heads)
]

In [186]:
# Keep v2 heads whose largest |AtP| over any token position exceeds the cutoff.
threshold = 0.2

h_mask = atp_h_v2.abs().max(1).values > threshold
h_ys_ = np.array(h_ys)[h_mask]

# Number of selected heads (shown as the cell output).
len(h_ys_)

18

In [187]:
# For each selected head, take the position of its largest |AtP| and suffix the
# label with the sign of the AtP value there ('+' if positive, '-' otherwise).
value_mask = atp_h_v2.abs().argmax(1)[h_mask]
heads_v2 = [
    label + ('+' if peak > 0 else '-')
    for label, peak in zip(h_ys_, atp_h_v2[h_mask, value_mask])
]

In [188]:
# Heatmap of per-token AtP for the selected v2 heads, over the last
# `n_last_tokens` positions of the clean v2 prompt.
n_last_tokens = 256
val = 1  # colour scale is clipped to [-val, val]
# Column labels "token | i", with i counted from the start of the shown window.
xs = [f"{tok} | {i}" for i, tok in enumerate(clean_str_tokens_v2[-n_last_tokens:])]

fig = px.imshow(
    atp_h_v2[h_mask, -n_last_tokens:].cpu().numpy(), 
    x=xs,
    y=h_ys_,
    color_continuous_scale='RdBu', zmin=-val, zmax=val, aspect='auto'
)

# Last expression of the cell, so its value is what the cell displays.
fig.update_layout(
    coloraxis1=dict(colorscale='RdBu', cmin=-val, cmax=val),
    showlegend=False,
    height = 800
)

### QKV

In [189]:
# Query / key / value AtP results for v2.
atp_q_v2, atp_k_v2, atp_v_v2 = (load_patch(idx, hook, '_v2') for hook in ('q', 'k', 'v'))

In [190]:
# Peak |AtP| over positions for each query head / key head / value head (v2).
# Titles match the q/k/v tensors plotted; statistic and cutoff (0.1) match the
# masks computed below.
fig = make_subplots(rows=3, cols=1, subplot_titles=("Queries", "Keys", "Values"), shared_xaxes=True, vertical_spacing=0.05)

for i, atp in enumerate([atp_q_v2, atp_k_v2, atp_v_v2]):
    plot = px.histogram(atp.abs().max(1).values.cpu().numpy())
    for trace in plot.data:
        fig.add_trace(trace, row=1+i, col=1)

fig.add_vline(x=0.1, line_dash="dash", line_color="red")
fig.update_layout(height=600)
fig.show()

In [191]:
# Row labels: one per query head, and one per shared key / value head.
q_ys = [f'L{i}H{j}'
        for i in range(n_layers)
        for j in range(n_heads)]
k_ys = [f'L{i}K{j}'
        for i in range(n_layers)
        for j in range(n_key_value_heads)]
v_ys = [f'L{i}V{j}'
        for i in range(n_layers)
        for j in range(n_key_value_heads)]

In [192]:
# Keep only heads whose peak |AtP| over positions exceeds the (looser) v2 cutoff.
threshold = 0.1

q_mask, k_mask, v_mask = (
    scores.abs().max(1).values > threshold
    for scores in (atp_q_v2, atp_k_v2, atp_v_v2)
)

In [193]:
import numpy as np

# Row labels of the v2 heads that pass each mask.
q_ys_ = np.array(q_ys)[q_mask]
k_ys_ = np.array(k_ys)[k_mask]
v_ys_ = np.array(v_ys)[v_mask]

# Counts of selected query / key / value heads (shown as the cell output).
len(q_ys_), len(k_ys_), len(v_ys_)

(25, 12, 18)

In [194]:
# Per-token AtP for the selected query / key / value heads of v2, one panel each,
# over a wider window (last 512 positions) than the v1 plot.
fig = make_subplots(rows=3, cols=1, subplot_titles=("Queries", "Keys", "Values"), shared_xaxes=True, vertical_spacing=0.05)
n_last_tokens = 512
val = 1
xs = [f"{tok} | {i}" for i, tok in enumerate(clean_str_tokens_v2[-n_last_tokens:])]

for i, (atp, ys, mask) in enumerate(zip([atp_q_v2, atp_k_v2, atp_v_v2], [q_ys_, k_ys_, v_ys_], [q_mask, k_mask, v_mask])):
    # Keep only the masked rows and the last n_last_tokens positions.
    plot = px.imshow(
        atp[mask, -n_last_tokens:].cpu().numpy(), 
        x=xs,
        y=ys,
        color_continuous_scale='RdBu', zmin=-val, zmax=val, aspect='auto'
    )
    # Only the traces are copied; px's own layout is dropped, which is presumably
    # why the shared colour axis is configured again below.
    for trace in plot.data:
        fig.add_trace(trace, row=1+i, col=1)
    
fig.update_layout(
    coloraxis1=dict(colorscale='RdBu', cmin=-val, cmax=val),
    showlegend=False,
    height = 1200
)

### Patterns

In [198]:
# Attention patterns saved for the v2 clean / corrupted prompts. Mapped to CPU like
# the token loads, so the files open on a machine without the saving device.
clean_patterns_v2 = torch.load(f'patches/{idx}_patterns_clean_v2.bin', map_location=torch.device('cpu'))
corr_patterns_v2 = torch.load(f'patches/{idx}_patterns_corr_v2.bin', map_location=torch.device('cpu'))

In [199]:
# Corrupted-prompt attention patterns for the selected v2 heads. Token labels are
# cut to their last 10 characters; only positions from 400 on are shown.
plot_str_tokens = [token[-10:] for token in corr_str_tokens_v2]
plot_patterns(
    plot_str_tokens,
    corr_patterns_v2,
    heads_v2,
    n_cols=3,
    query_offset=400,
    key_offset=400,
)