In [1]:
import functools
import json
import os
from typing import Any, List, Tuple, Union
import matplotlib.pyplot as plt
import torch
import torch as t
import torch.nn.functional as F
from fancy_einsum import einsum
from sklearn.linear_model import LinearRegression
from torch import nn
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from einops import rearrange, repeat
import pandas as pd
import numpy as np
import sklearn

import w5d5_tests
from w5d5_transformer import ParenTransformer, SimpleTokenizer

# True only when this code runs as the top-level module; every `if MAIN:` cell below is gated on it.
MAIN = __name__ == "__main__"
# Device passed to the `.to(DEVICE)` calls below (model, labels and tokens).
DEVICE = t.device("cpu")

loaded 100000 examples


In [2]:
if MAIN:
    model = ParenTransformer(ntoken=5, nclasses=2, d_model=56, nhead=2, d_hid=56, nlayers=3).to(DEVICE)
    # map_location keeps the checkpoint loadable on DEVICE even when it was
    # saved from a different device (e.g. a GPU that is not present here).
    state_dict = t.load("w5d5_balanced_brackets_state_dict.pt", map_location=DEVICE)
    model.load_simple_transformer_state_dict(state_dict)
    model.eval()  # the notebook only runs inference on this model
    tokenizer = SimpleTokenizer("()")
    with open("w5d5_brackets_data.json") as f:
        # Each entry is a (bracket string, is_balanced) pair.
        data_tuples: List[Tuple[str, bool]] = json.load(f)
    print(f"loaded {len(data_tuples)} examples")
    assert isinstance(data_tuples, list)


class DataSet:
    '''A dataset containing sequences, is_balanced labels, and tokenized sequences.

    Attributes set by the constructor:
        strs: the raw bracket strings.
        isbal: bool tensor on DEVICE with one label per string.
        toks: result of the module-level `tokenizer.tokenize(strs)`, moved to DEVICE.
        open_proportion: per-string fraction of characters equal to "("
            (0.0 for an empty string).
        starts_open: bool tensor, True where the first character is "("
            (False for an empty string).

    Construction relies on the module-level names `tokenizer` and `DEVICE`.
    '''

    def __init__(self, data_tuples: list):
        '''
        data_tuples is List[Tuple[str, bool]] signifying sequence and label
        '''
        self.strs = [x[0] for x in data_tuples]
        self.isbal = t.tensor([x[1] for x in data_tuples]).to(device=DEVICE, dtype=t.bool)
        self.toks = tokenizer.tokenize(self.strs).to(DEVICE)
        # Guard the division: an empty string has no brackets at all.
        self.open_proportion = t.tensor([s.count("(") / len(s) if s else 0.0 for s in self.strs])
        # s[:1] is "" for an empty string, so this never raises IndexError.
        self.starts_open = t.tensor([s[:1] == "(" for s in self.strs]).bool()

    def __len__(self) -> int:
        '''Number of examples.'''
        return len(self.strs)

    def __getitem__(self, idx) -> Union["DataSet", tuple[str, t.Tensor, t.Tensor]]:
        '''
        A slice returns a new DataSet over the selected examples; any other
        index returns the (string, label, tokens) triple for that example.
        '''
        if isinstance(idx, slice):
            # Hand plain Python bools back to the constructor rather than 0-dim tensors.
            return self.__class__(list(zip(self.strs[idx], self.isbal[idx].tolist())))
        return (self.strs[idx], self.isbal[idx], self.toks[idx])

    @property
    def seq_length(self) -> int:
        '''Size of the last dimension of the token tensor.'''
        return self.toks.size(-1)

    @classmethod
    def with_length(cls, data_tuples: list[tuple[str, bool]], selected_len: int) -> "DataSet":
        '''Build a DataSet from only those examples whose string has exactly selected_len characters.'''
        return cls([(s, b) for (s, b) in data_tuples if len(s) == selected_len])

    @classmethod
    def with_start_char(cls, data_tuples: list[tuple[str, bool]], start_char: str) -> "DataSet":
        '''Build a DataSet from only those examples whose first character is start_char.'''
        # s[:1] instead of s[0] so that empty strings are skipped rather than raising.
        return cls([(s, b) for (s, b) in data_tuples if s[:1] == start_char])


if MAIN:
    N_SAMPLES = 5000
    # Work on the first N_SAMPLES examples only.
    data_tuples = data_tuples[:N_SAMPLES]
    data = DataSet(data_tuples)

    # Read the lengths straight from the stored strings; iterating `data` itself
    # would build a (string, label, tokens) triple per example just to measure it.
    bracket_lengths = [len(s) for s in data.strs]
    fig = px.histogram(
        title="Length of bracket strings in dataset",
        x=bracket_lengths,
        nbins=max(bracket_lengths),
        template="simple_white"
    )
    fig.show()

loaded 100000 examples


In [3]:
def is_balanced_forloop(parens: str) -> bool:
    '''Check bracket balance with one left-to-right scan.

    parens contains only the "(" and ")" characters, no begin or end tokens.
    Returns False as soon as a ")" has no "(" to match; raises ValueError when
    the scan reaches any other character.
    '''
    depth = 0  # number of "(" still waiting for their ")"
    for ch in parens:
        if ch == ")":
            if depth == 0:
                return False
            depth -= 1
        elif ch == "(":
            depth += 1
        else:
            raise ValueError(f"Unexpected character {ch}")
    return depth == 0

if MAIN:
    # Hand-labelled bracket strings; `examples` and `labels` are reused by later cells.
    examples = [
        "()",
        "))()()()()())()(())(()))(()(()(()(",
        "((()()()()))",
        "(()()()(()(())()",
        "()(()(((())())()))",
    ]
    labels = [True, False, True, False, True]
    for parens, expected in zip(examples, labels):
        actual = is_balanced_forloop(parens)
        assert actual == expected, f"{parens}: expected {expected} got {actual}"
    print("is_balanced_forloop ok!")

is_balanced_forloop ok!


In [4]:
def is_balanced_vectorized(tokens: t.Tensor) -> bool:
    '''
    Return True if a tokenized bracket sequence is balanced, using tensor ops only.

    tokens: 1-D sequence of tokens including begin, end and pad tokens - recall that
        3 is '(' and 4 is ')'. Every other token id leaves the bracket depth unchanged.

    Returns a plain Python bool. An empty sequence counts as balanced.
    '''
    if tokens.numel() == 0:
        return True
    # +1 for every '(' and -1 for every ')'; all other tokens contribute 0.
    change = (tokens == 3).long() - (tokens == 4).long()
    depth = change.cumsum(dim=0)
    # Balanced means the depth ends at zero and never dips below zero on the way.
    return bool(depth[-1] == 0) and bool(depth.min() >= 0)


if MAIN:
    # Re-check the hand-labelled `examples`, this time on their tokenized form:
    # iterating the tokenizer output yields one token sequence per example.
    for (tokens, expected) in zip(tokenizer.tokenize(examples), labels):
        actual = is_balanced_vectorized(tokens)
        assert expected == actual, f"{tokens}: expected {expected} got {actual}"
    print("is_balanced_vectorized ok!")

is_balanced_vectorized ok!


In [5]:
if MAIN:
    toks = tokenizer.tokenize(examples).to(DEVICE)
    # Inference only: skip autograd bookkeeping, as run_model_on_data does.
    with t.no_grad():
        out = model(toks)
    # The output is exponentiated to obtain probabilities; column 1 is read as "balanced".
    prob_balanced = out.exp()[:, 1]
    print("Model confidence:\n" + "\n".join([f"{ex:34} : {prob:.4%}" for ex, prob in zip(examples, prob_balanced)]))

def run_model_on_data(model: ParenTransformer, data: DataSet, batch_size: int = 200) -> t.Tensor:
    '''Run the model over the whole dataset in mini-batches, without gradients.

    The model outputs are concatenated along dim 0 and exponentiated, so the
    result is a (len(data), 2) tensor with one column per class (elsewhere in
    this notebook column 1 is read as the "balanced" probability).
    '''
    n_examples = len(data.strs)
    batch_outputs = []
    with t.no_grad():
        for start in range(0, n_examples, batch_size):
            batch_outputs.append(model(data.toks[start : start + batch_size]))
    probs = t.cat(batch_outputs).exp()
    assert probs.shape == (len(data), 2)
    return probs

if MAIN:
    test_set = data
    # argmax over the two probability columns gives the predicted class, which is
    # compared element-wise against the boolean is-balanced labels.
    n_correct = t.sum((run_model_on_data(model, test_set).argmax(-1) == test_set.isbal).int())
    # NOTE(review): the message uses len(data); test_set is the same object here, so the count matches.
    print(f"\nModel got {n_correct} out of {len(data)} training examples correct!")

Model confidence:
()                                 : 99.9987%
))()()()()())()(())(()))(()(()(()( : 0.0003%
((()()()()))                       : 99.9987%
(()()()(()(())()                   : 0.0006%
()(()(((())())()))                 : 99.9982%

Model got 5000 out of 5000 training examples correct!


In [6]:
def get_post_final_ln_dir(model: ParenTransformer) -> t.Tensor:
    '''Return row 0 minus row 1 of the decoder weight matrix.

    Presumably each row is the read-out vector of one output class, which makes
    this the direction (after the final layer norm) along which class 0 gains on
    class 1 - confirm against the model definition.
    '''
    decoder_weight = model.decoder.weight
    row_0, row_1 = decoder_weight[0], decoder_weight[1]
    return row_0 - row_1

In [7]:
def get_inputs(model: ParenTransformer, data: DataSet, module: nn.Module) -> t.Tensor:
    '''
    Get the inputs to a particular submodule of the model when run on the data.

    A forward hook is attached to `module` for one `run_model_on_data` pass and
    stores a detached copy of the first positional input of every call. The
    hook is removed again even if the pass raises.

    Returns a tensor of size (data_pts, seq_pos, emb_size): the per-call inputs
    concatenated along dim 0.
    '''
    acts = []

    def fn(module, input, output):
        # `input` is the tuple of positional arguments; only the first one is recorded.
        acts.append(input[0].detach().clone())

    handle = module.register_forward_hook(fn)
    try:
        run_model_on_data(model, data)
    finally:
        # Always unhook, otherwise a failed run leaves a stale hook that keeps
        # recording during every later forward pass of this module.
        handle.remove()

    return t.cat(acts, dim=0)

def get_outputs(model: ParenTransformer, data: DataSet, module: nn.Module) -> t.Tensor:
    '''
    Get the outputs from a particular submodule of the model when run on the data.

    A forward hook is attached to `module` for one `run_model_on_data` pass and
    stores a detached copy of the output of every call. The hook is removed
    again even if the pass raises.

    Returns a tensor of size (data_pts, seq_pos, emb_size): the per-call outputs
    concatenated along dim 0.
    '''
    acts = []

    def fn(module, input, output):
        acts.append(output.detach().clone())

    handle = module.register_forward_hook(fn)
    try:
        run_model_on_data(model, data)
    finally:
        # Always unhook, otherwise a failed run leaves a stale hook that keeps
        # recording during every later forward pass of this module.
        handle.remove()

    return t.cat(acts, dim=0)

#if MAIN:
#    w5d5_tests.test_get_inputs(get_inputs, model, data)
#    w5d5_tests.test_get_outputs(get_outputs, model, data)

In [8]:
def get_ln_fit(
    model: ParenTransformer, data: DataSet, ln_module: nn.LayerNorm, seq_pos: Union[None, int]
) -> Tuple[LinearRegression, t.Tensor]:
    '''
    Approximate `ln_module` by a linear regression from its inputs to its outputs,
    both recorded while the model runs over `data`.

    seq_pos: when None, every (example, position) pair is one regression sample;
    otherwise only the activations at that sequence position are used.

    Returns: (fitted sklearn LinearRegression, dimensionless tensor holding the
    r^2 score of that fit on the same samples).
    '''
    ln_in = get_inputs(model, data, ln_module)
    ln_out = get_outputs(model, data, ln_module)

    # Pick the regression samples first so the fit/score code exists only once.
    if seq_pos is None:
        X = rearrange(ln_in, 'b s e -> (b s) e')
        Y = rearrange(ln_out, 'b s e -> (b s) e')
    else:
        X = ln_in[:, seq_pos]
        Y = ln_out[:, seq_pos]

    fit = LinearRegression()
    fit.fit(X, Y)
    return fit, t.tensor(fit.score(X, Y))


if MAIN:
    # Fit the model's final layer norm (model.norm) using sequence position 0 only.
    (final_ln_fit, r2) = get_ln_fit(model, data, model.norm, seq_pos=0)
    print("r^2: ", r2)
    w5d5_tests.test_final_ln_fit(model, data, get_ln_fit)

r^2:  tensor(0.9820, dtype=torch.float64)
All tests in `test_final_ln_fit` passed.


In [16]:
def get_pre_final_ln_dir(model: ParenTransformer, data: DataSet) -> t.Tensor:
    '''
    Map the post-final-layer-norm direction back to the space before that layer norm.

    The final layer norm (model.norm) is replaced by its linear fit at sequence
    position 0; the returned vector is the post-LN direction contracted with the
    fit's coefficient matrix over the output dimension.
    '''
    ln_fit, _r2 = get_ln_fit(model, data, model.norm, seq_pos=0)
    ln_coefs = t.from_numpy(ln_fit.coef_)
    post_dir = get_post_final_ln_dir(model)
    return t.einsum("o,oi->i", post_dir, ln_coefs)


if MAIN:
    # Check get_pre_final_ln_dir with the course-provided test.
    w5d5_tests.test_pre_final_ln_dir(model, data, get_pre_final_ln_dir)

All tests in `test_pre_final_ln_dir` passed.


In [23]:
def get_out_by_head(model: ParenTransformer, data: DataSet, layer: int) -> t.Tensor:
    '''
    Split the attention output of one layer into per-head contributions.

    The input to the layer's W_O projection is recorded, cut into one chunk per
    head along its last dimension, and each chunk is multiplied by the matching
    column block of W_O's weight. W_O's bias is not included.

    Returns a tensor of shape (batch, num_heads, seq, embed_width)
    '''
    out_proj = model.layers[layer].self_attn.W_O
    n_heads = model.nhead
    concat_heads = get_inputs(model, data, out_proj)
    per_head_in = rearrange(concat_heads, 'b s (h w) -> b h s w', h=n_heads)
    per_head_weight = rearrange(out_proj.weight, 'e (h w) -> h e w', h=n_heads)
    return t.einsum('b h s w, h e w -> b h s e', per_head_in, per_head_weight)

if MAIN:
    # Check get_out_by_head with the course-provided test.
    w5d5_tests.test_get_out_by_head(get_out_by_head, model, data)

All tests in `test_get_out_by_head` passed.


In [27]:
def get_out_by_components(model: ParenTransformer, data: DataSet) -> t.Tensor:
    '''
    Computes a tensor of shape [n_components, dataset_size, seq_pos, emb] representing the output of the model's components when run on the data.

    The first dimension is the output of model.pos_encoder followed, for every
    layer, by one entry per attention head and then the output of that layer's
    `linear2`. For the 3-layer, 2-head model used here that is 10 components:
    [embeddings, head 0.0, head 0.1, mlp 0, head 1.0, head 1.1, mlp 1, head 2.0, head 2.1, mlp 2]
    '''
    components = [get_outputs(model, data, model.pos_encoder)]
    for layer in range(model.nlayers):
        head_out = get_out_by_head(model, data, layer)
        # One entry per head, however many heads the layer has (dim 1 indexes heads).
        components.extend(head_out.unbind(dim=1))
        components.append(get_outputs(model, data, model.layers[layer].linear2))
    return t.stack(components, dim=0)

if MAIN:
    # Check get_out_by_components with the course-provided test.
    w5d5_tests.test_get_out_by_component(get_out_by_components, model, data)

All tests in `test_get_out_by_component` passed.


In [30]:
if MAIN:
    # Sanity check: the components must add up to the input of the final layer norm.
    # get_out_by_head multiplies by W_O's weight only, so each layer's W_O bias is
    # added back here once.
    biases = sum(model.layers[layer].self_attn.W_O.bias for layer in range(model.nlayers))
    out_by_components = get_out_by_components(model, data)
    summed_terms = out_by_components.sum(dim=0) + biases
    pre_final_ln = get_inputs(model, data, model.norm)
    t.testing.assert_close(summed_terms, pre_final_ln)

In [32]:
def hists_per_comp(magnitudes, data, n_layers=3, xaxis_range=(-1, 1)):
    '''
    Show one subplot per component with overlaid histograms of its magnitudes
    for balanced (blue) and unbalanced (red) samples.

    magnitudes: indexed as [component, sample]; must hold exactly 10 components
        in the order embeddings, then (head x.0, head x.1, mlp x) for layers 0-2.
    data: only `data.isbal` is read, as a boolean mask over samples.
    n_layers: the grid has n_layers + 1 rows; components whose row falls beyond
        the grid are not drawn.
    xaxis_range: x-axis range applied to every subplot.
    '''
    # (row, col, title) of every component, in the order of magnitudes' first dim.
    layout = [
        (1, 1, "embeddings"),
        (2, 1, "head 0.0"),
        (2, 2, "head 0.1"),
        (2, 3, "mlp 0"),
        (3, 1, "head 1.0"),
        (3, 2, "head 1.1"),
        (3, 3, "mlp 1"),
        (4, 1, "head 2.0"),
        (4, 2, "head 2.1"),
        (4, 3, "mlp 2"),
    ]
    num_comps = magnitudes.shape[0]
    assert num_comps == len(layout)

    n_rows = n_layers + 1
    fig = make_subplots(rows=n_rows, cols=3)
    for (row, col, title), mag in zip(layout, magnitudes):
        if row == n_rows + 1:
            break
        # Only the first subplot contributes legend entries, to avoid duplicates.
        in_legend = title == "embeddings"
        balanced_trace = go.Histogram(
            x=mag[data.isbal].numpy(), name="Balanced", marker_color="blue",
            opacity=0.5, legendgroup='1', showlegend=in_legend,
        )
        unbalanced_trace = go.Histogram(
            x=mag[~data.isbal].numpy(), name="Unbalanced", marker_color="red",
            opacity=0.5, legendgroup='2', showlegend=in_legend,
        )
        fig.add_trace(balanced_trace, row=row, col=col)
        fig.add_trace(unbalanced_trace, row=row, col=col)
        fig.update_xaxes(title_text=title, row=row, col=col, range=xaxis_range)
    fig.update_layout(
        width=1200,
        height=250 * n_rows,
        barmode="overlay",
        legend=dict(yanchor="top", y=0.92, xanchor="left", x=0.4),
        title="Histograms of component significance",
    )
    fig.show()

if MAIN:
    # Output of every component at sequence position 0 -> (component, sample, emb)
    out_by_components = get_out_by_components(model, data)[:, :, 0, :].detach()
    # Direction (in the space before the final layer norm) returned by get_pre_final_ln_dir
    unbalanced_dir = get_pre_final_ln_dir(model, data).detach()
    # Project every component's output onto that direction
    magnitudes = einsum("component sample emb, emb -> component sample", out_by_components, unbalanced_dir)
    # Centre each component on its mean over the balanced samples
    magnitudes = magnitudes - magnitudes[:, data.isbal].mean(-1, keepdim=True)

    # Leftover exercise guard: `magnitudes` was assigned just above
    assert "magnitudes" in locals(), "You need to define `magnitudes`"
    hists_per_comp(magnitudes, data, xaxis_range=[-10, 20])

In [44]:
def is_balanced_vectorized_return_both(tokens: t.Tensor) -> Tuple[t.Tensor, t.Tensor]:
    '''
    Classify how each bracket sequence fails to be balanced, scanning right to left.

    tokens: token ids with the sequence along dim 0, e.g. (seq, batch); 3 is '('
        and 4 is ')', every other id leaves the altitude unchanged. The input is
        not modified.

    Returns (total_elevation_failure, negative_failure), two bool tensors shaped
    like tokens without its first dimension:
        total_elevation_failure: the numbers of '(' and ')' differ.
        negative_failure: reading from the right, at some point more '(' than
            ')' have been seen, i.e. some '(' has no ')' after it to close it.
    '''
    # flip returns a new tensor, so the caller's tokens are left untouched.
    reversed_tokens = tokens.flip(0)
    change = (reversed_tokens == 3).long() - (reversed_tokens == 4).long()
    brackets_altitude = change.cumsum(0)
    total_elevation_failure = brackets_altitude[-1] != 0
    negative_failure = brackets_altitude.max(0).values > 0

    # One flag per sequence, whatever the batch size is.
    assert total_elevation_failure.shape == negative_failure.shape == tokens.shape[1:]

    return total_elevation_failure, negative_failure

In [45]:
if MAIN:
    # is_balanced_vectorized_return_both scans along dim 0, hence the transpose to (seq, batch).
    total_elevation_failure, negative_failure = is_balanced_vectorized_return_both(data.toks.T.clone())
    # Components 7 and 8 are heads 2.0 and 2.1; centre each on its mean over balanced samples.
    h20_in_d = magnitudes[7] - magnitudes[7, data.isbal].mean(0)
    h21_in_d = magnitudes[8] - magnitudes[8, data.isbal].mean(0)

    # Label every sample with exactly one of the four (mutually exclusive) failure types.
    failure_types = np.full(len(h20_in_d), "", dtype=np.dtype("U32"))
    failure_types_dict = {
        "both failures": negative_failure & total_elevation_failure,
        "just neg failure": negative_failure & ~total_elevation_failure,
        "just total elevation failure": ~negative_failure & total_elevation_failure,
        "balanced": ~negative_failure & ~total_elevation_failure
    }
    for name, mask in failure_types_dict.items():
        failure_types = np.where(mask, name, failure_types)
    # Keep only the sequences that start with an open bracket.
    failures_df = pd.DataFrame({
        "Head 2.0 contribution": h20_in_d,
        "Head 2.1 contribution": h21_in_d,
        "Failure type": failure_types
    })[data.starts_open.tolist()]
    fig = px.scatter(
        failures_df,
        x="Head 2.0 contribution", y="Head 2.1 contribution", color="Failure type",
        title="h20 vs h21 for different failure types", template="simple_white", height=600, width=800,
        # category_orders is keyed by column name; with a DataFrame the colour
        # column is "Failure type", so a "color" key would be ignored.
        category_orders={"Failure type": list(failure_types_dict)}
    ).update_traces(marker_size=4)
    fig.show()

In [46]:
if MAIN:
    # Head 2.0's contribution against the fraction of "(" in each sequence, coloured by failure type.
    fig = px.scatter(
        x=data.open_proportion,
        y=h20_in_d,
        color=failure_types,
        title="Head 2.0 contribution vs proportion of open brackets '('",
        template="simple_white",
        height=500,
        width=800,
        labels={"x": "Open-proportion", "y": "Head 2.0 contribution"},
        category_orders={"color": failure_types_dict.keys()},
    )
    fig.update_traces(marker_size=4, opacity=0.5)
    fig.update_layout(legend_title_text='Failure type')
    fig.show()

In [47]:
if MAIN:
    # Head 2.1's contribution against the fraction of "(" in each sequence, coloured by failure type.
    fig = px.scatter(
        x=data.open_proportion,
        y=h21_in_d,
        color=failure_types,
        title="Head 2.1 contribution vs proportion of open brackets '('",
        template="simple_white",
        height=500,
        width=800,
        labels={"x": "Open-proportion", "y": "Head 2.1 contribution"},
        category_orders={"color": failure_types_dict.keys()},
    )
    fig.update_traces(marker_size=4, opacity=0.5)
    fig.update_layout(legend_title_text='Failure type')
    fig.show()

In [70]:
def get_attn_probs(
    model: ParenTransformer,
    tokenizer: SimpleTokenizer,
    data: DataSet,
    layer: int,
    head: int,
) -> t.Tensor:
    """
    Attention probabilities of one head when the model is run on `data`.

    The inputs to the layer's self-attention module are recorded with
    get_inputs, turned into pre-softmax scores by the module's
    attention_pattern_pre_softmax, and restricted to the requested head. Key
    positions holding the padding token are filled with -1e9 before the
    softmax, so padded keys receive (numerically) zero probability.

    Returns: (N_SAMPLES, max_seq_len, max_seq_len) tensor that sums to 1 over the last dimension.
    """
    attn_module = model.layers[layer].self_attn
    attn_inputs = get_inputs(model, data, attn_module)

    head_scores = attn_module.attention_pattern_pre_softmax(attn_inputs)[:, head, :, :]
    # (batch, 1, key): the same key mask is broadcast over every query position.
    is_pad_key = (data.toks == tokenizer.PAD_TOKEN).unsqueeze(1)

    head_probs = head_scores.masked_fill(is_pad_key, -1e9).softmax(dim=-1)
    return head_probs.detach()


if MAIN:
    # Attention pattern of layer 2, head 0.
    attn_probs = get_attn_probs(model, tokenizer, data, 2, 0)
    # Average over the sequences that start with "(" and keep only query position 0
    # (the list index [[0]] keeps that dimension, hence the squeeze below).
    attn_probs_open = attn_probs[data.starts_open].mean(0)[[0]]
    px.bar(
        y=attn_probs_open.squeeze().numpy(),
        labels={"y": "Probability", "x": "Key Position"},
        template="simple_white",
        height=500,
        width=600,
        title="Avg Attention Probabilities for '(' query from query 0",
    ).update_layout(showlegend=False, hovermode="x unified").show()


In [None]:
def get_WV(model: ParenTransformer, layer: int, head: int) -> t.Tensor:
    '''
    Returns the W_V matrix of a head. Should be a CPU tensor of size (d_model / num_heads, d_model)

    Not implemented yet: the body is a placeholder, so a call currently returns None.
    '''
    # TODO: implement
    pass

def get_WO(model: ParenTransformer, layer: int, head: int) -> t.Tensor:
    '''
    Returns the W_O matrix of a head. Should be a CPU tensor of size (d_model, d_model / num_heads)

    The layer's W_O weight has one block of columns per head, laid out head
    after head (the same layout get_out_by_head relies on); this returns a
    detached copy of the block belonging to `head`.

    Raises IndexError if `head` is not in range(model.nhead).
    '''
    if not 0 <= head < model.nhead:
        raise IndexError(f"head {head} out of range for a model with {model.nhead} heads")
    W_O = model.layers[layer].self_attn.W_O.weight
    head_width = W_O.shape[1] // model.nhead
    first_col = head * head_width
    return W_O[:, first_col : first_col + head_width].detach().clone().cpu()

def get_WOV(model: ParenTransformer, layer: int, head: int) -> t.Tensor:
    '''Combined output-value matrix of one head: get_WO(...) @ get_WV(...).'''
    W_O = get_WO(model, layer, head)
    W_V = get_WV(model, layer, head)
    return W_O @ W_V

def get_pre_20_dir(model, data):
    '''
    Returns the direction propagated back through the OV matrix of 2.0 and then through the layernorm before the layer 2 attention heads.

    Not implemented yet: the body is a placeholder, so a call currently returns None.
    '''
    # TODO: implement
    pass

if MAIN:
    # Course-provided checks for get_WV, get_WO and get_pre_20_dir defined above.
    w5d5_tests.test_get_WV(model, get_WV)
    w5d5_tests.test_get_WO(model, get_WO)
    w5d5_tests.test_get_pre_20_dir(model, data, get_pre_20_dir)

if MAIN:
    # Relies on `magnitudes` computed in an earlier cell.
    assert "magnitudes" in locals()
    # With n_layers=2 only the embeddings and layers 0-1 are drawn; hists_per_comp stops at the layer-2 row.
    hists_per_comp(magnitudes, data, n_layers=2, xaxis_range=(-7, 7))