In [99]:
import functools
import json
import os
from typing import Any, List, Tuple, Union
import matplotlib.pyplot as plt
import torch
import torch as t
import torch.nn.functional as F
from fancy_einsum import einsum
from sklearn.linear_model import LinearRegression
from torch import nn
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from einops import rearrange, repeat
import pandas as pd
import numpy as np

import w5d5_tests
from w5d5_transformer import ParenTransformer, SimpleTokenizer

MAIN = __name__ == "__main__"
DEVICE = t.device("cpu")

In [100]:
if MAIN:
    model = ParenTransformer(ntoken=5, nclasses=2, d_model=56, nhead=2, d_hid=56, nlayers=3).to(DEVICE)
    state_dict = t.load("w5d5_balanced_brackets_state_dict.pt")
    model.to(DEVICE)
    model.load_simple_transformer_state_dict(state_dict)
    model.eval()
    tokenizer = SimpleTokenizer("()")
    with open("w5d5_brackets_data.json") as f:
        data_tuples: List[Tuple[str, bool]] = json.load(f)
        print(f"loaded {len(data_tuples)} examples")
    assert isinstance(data_tuples, list)

class DataSet:
    '''A dataset containing sequences, is_balanced labels, and tokenized sequences'''

    def __init__(self, data_tuples: list):
        '''
        data_tuples is List[Tuple[str, bool]] signifying sequence and label
        '''
        self.strs = [x[0] for x in data_tuples]
        self.isbal = t.tensor([x[1] for x in data_tuples]).to(device=DEVICE, dtype=t.bool)
        self.toks = tokenizer.tokenize(self.strs).to(DEVICE)
        self.open_proportion = t.tensor([s.count("(") / len(s) for s in self.strs])
        self.starts_open = t.tensor([s[0] == "(" for s in self.strs]).bool()

    def __len__(self) -> int:
        return len(self.strs)

    def __getitem__(self, idx) -> Union["DataSet", tuple[str, t.Tensor, t.Tensor]]:
        if type(idx) == slice:
            return self.__class__(list(zip(self.strs[idx], self.isbal[idx])))
        return (self.strs[idx], self.isbal[idx], self.toks[idx])

    @property
    def seq_length(self) -> int:
        return self.toks.size(-1)

    @classmethod
    def with_length(cls, data_tuples: list[tuple[str, bool]], selected_len: int) -> "DataSet":
        return cls([(s, b) for (s, b) in data_tuples if len(s) == selected_len])

    @classmethod
    def with_start_char(cls, data_tuples: list[tuple[str, bool]], start_char: str) -> "DataSet":
        return cls([(s, b) for (s, b) in data_tuples if s[0] == start_char])

if MAIN:
    N_SAMPLES = 5000
    data_tuples = data_tuples[:N_SAMPLES]
    data = DataSet(data_tuples)
    # plot histogram of sequence lengths with plotly express
    fig = px.histogram(x=[len(s) for s in data.strs])
    fig.show()

loaded 100000 examples


In [101]:
# mysolution
def my_is_balanced_forloop(s: str) -> bool:
        stack = []
        for char in s:
            if char == '(':
                stack.append(char)
            elif char == ')':
                if len(stack) == 0:
                    return False
                elif stack.pop() != '(':
                    return False
        return len(stack) == 0

def is_balanced_forloop(parens: str) -> bool:
    """Return True if the parens are balanced.
    Parens is just the ( and ) characters, no begin or end tokens.
    """
    i = 0
    for c in parens:
        if c == "(":
            i += 1
        elif c == ")":
            i -= 1
            if i < 0:
                return False
        else:
            raise ValueError(parens)
    return i == 0


if MAIN:
    examples = ["()", "))()()()()())()(())(()))(()(()(()(", "((()()()()))", "(()()()(()(())()", "()(()(((())())()))"]
    labels = [True, False, True, False, True]
    for (parens, expected) in zip(examples, labels):
        actual = is_balanced_forloop(parens)
        assert expected == actual, f"{parens}: expected {expected} got {actual}"
    print("is_balanced_forloop ok!")

is_balanced_forloop ok!


In [102]:
def is_balanced_vectorized(tokens: t.Tensor) -> bool:
    '''
    tokens: sequence of tokens including begin, end and pad tokens - recall that 3 is '(' and 4 is ')'

    One solution is to map begin, pad, and end tokens to zero, map open paren to 1 and close paren to -1. Then calculate the cumulative sum of the sequence. Your sequence is unbalanced if and only if:

    The last element of the cumulative sum is nonzero
    Any element of the cumulative sum is negative
    '''
    bracketdict = {3:1, 4:-1, 0:0, 1:0, 2:0}
    bracketlist = [bracketdict[token.item()] for token in tokens]
    cumsum = np.cumsum(bracketlist)
    return cumsum[-1] == 0 and min(cumsum) >= 0

if MAIN:
    for (tokens, expected) in zip(tokenizer.tokenize(examples), labels):
        actual = is_balanced_vectorized(tokens)
        assert expected == actual, f"{tokens}: expected {expected} got {actual}"
    print("is_balanced_vectorized ok!")

is_balanced_vectorized ok!


In [103]:
if MAIN:
    toks = tokenizer.tokenize(examples).to(DEVICE)
    out = model(toks)
    prob_balanced = out.exp()[:, 1]
    print("Model confidence:\n" + "\n".join([f"{ex:34} : {prob:.4%}" for ex, prob in zip(examples, prob_balanced)]))

def run_model_on_data(model: ParenTransformer, data: DataSet, batch_size: int = 200) -> t.Tensor:
    '''Return probability that each example is balanced'''
    ln_probs = []
    for i in range(0, len(data.strs), batch_size):
        toks = data.toks[i : i + batch_size]
        with t.no_grad():
            out = model(toks)
        ln_probs.append(out)
    out = t.cat(ln_probs).exp()
    assert out.shape == (len(data), 2)
    return out

if MAIN:
    test_set = data
    n_correct = t.sum((run_model_on_data(model, test_set).argmax(-1) == test_set.isbal).int())
    print(f"\nModel got {n_correct} out of {len(data)} training examples correct!")

Model confidence:
()                                 : 99.9987%
))()()()()())()(())(()))(()(()(()( : 0.0003%
((()()()()))                       : 99.9987%
(()()()(()(())()                   : 0.0006%
()(()(((())())()))                 : 99.9982%

Model got 5000 out of 5000 training examples correct!


In [104]:
def get_post_final_ln_dir(model: ParenTransformer) -> t.Tensor:
    return model.decoder.weight[0, :] - model.decoder.weight[1, :]

In [105]:
def get_inputs(model: ParenTransformer, data: DataSet, module: nn.Module) -> t.Tensor:
    '''
    Get the inputs to a particular submodule of the model when run on the data.
    Returns a tensor of size (data_pts, seq_pos, emb_size).
    '''
    module_out = None
    def hook_fn(module, input, output):
        input = input[0]
        nonlocal module_out
        module_out = input
    hook = module.register_forward_hook(hook_fn)
    model_out = model(data.toks)
    hook.remove() # ensures that the model is the same after we run this function
    return module_out

def get_outputs(model: ParenTransformer, data: DataSet, module: nn.Module) -> t.Tensor:
    '''
    Get the outputs from a particular submodule of the model when run on the data.
    Returns a tensor of size (data_pts, seq_pos, emb_size).
    '''
    module_out = None
    def hook_fn(module, input, output):
        input = input[0]
        nonlocal module_out
        module_out = output
    hook = module.register_forward_hook(hook_fn)
    model_out = model(data.toks)
    hook.remove()
    return module_out
    

if MAIN:
    w5d5_tests.test_get_inputs(get_inputs, model, data)
    w5d5_tests.test_get_outputs(get_outputs, model, data)

All tests in `test_get_inputs` passed.
All tests in `test_get_outputs` passed.


In [106]:
def get_ln_fit(
    model: ParenTransformer, data: DataSet, ln_module: nn.LayerNorm, seq_pos: Union[None, int]
) -> Tuple[LinearRegression, t.Tensor]:
    '''
    if seq_pos is None, find best fit for all sequence positions. Otherwise, fit only for given seq_pos.

    Returns: A tuple of a (fitted) sklearn LinearRegression object and a dimensionless tensor containing the r^2 of the fit (hint: wrap a value in torch.tensor() to make a dimensionless tensor)
    '''
    inp = get_inputs(model, data, ln_module).detach().numpy()
    out = get_outputs(model, data, ln_module).detach().numpy()
    # print(f'inp shape: {inp.shape}')
    # print(f'out shape: {out.shape}')
    if seq_pos is None:
        inp = rearrange(inp, 'b s e -> (b s) e')
        out = rearrange(out, 'b s e -> (b s) e')
    else:
        inp = inp[:, seq_pos, :]
        out = out[:, seq_pos, :]
    print(f'inp shape: {inp.shape}')
    print(f'out shape: {out.shape}')

    linear_regression = LinearRegression().fit(inp, out)
    r2 = linear_regression.score(inp, out)
    return linear_regression, torch.tensor(r2)
    


if MAIN:
    (final_ln_fit, r2) = get_ln_fit(model, data, model.norm, seq_pos=0)
    print("r^2: ", r2)
    print(r2.shape)
    w5d5_tests.test_final_ln_fit(model, data, get_ln_fit)

inp shape: (5000, 56)
out shape: (5000, 56)
r^2:  tensor(0.9820, dtype=torch.float64)
torch.Size([])
inp shape: (5000, 56)
out shape: (5000, 56)
All tests in `test_final_ln_fit` passed.


In [107]:
print(model)

ParenTransformer(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (layers): ModuleList(
    (0): TransformerBlock(
      (norm1): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
      (self_attn): MultiheadAttention(
        (W_Q): Linear(in_features=56, out_features=56, bias=True)
        (W_K): Linear(in_features=56, out_features=56, bias=True)
        (W_V): Linear(in_features=56, out_features=56, bias=True)
        (W_O): Linear(in_features=56, out_features=56, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (linear1): Linear(in_features=56, out_features=56, bias=True)
      (linear2): Linear(in_features=56, out_features=56, bias=True)
      (activation): ReLU()
    )
    (1): TransformerBlock(
      (norm1): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
      (self_attn)

In [108]:
def get_L(model, data):
    (final_ln_fit, r2) = get_ln_fit(model, data, model.norm, seq_pos=0)
    L = t.from_numpy(final_ln_fit.coef_)
    return L

def get_pre_final_ln_dir(model: ParenTransformer, data: DataSet) -> t.Tensor:
    post_final_ln_dir = get_post_final_ln_dir(model)
    L = get_L(model, data)
    return t.einsum("i,ij->j", post_final_ln_dir, L)

if MAIN:
    w5d5_tests.test_pre_final_ln_dir(model, data, get_pre_final_ln_dir)

inp shape: (5000, 56)
out shape: (5000, 56)
All tests in `test_pre_final_ln_dir` passed.


In [109]:
def get_out_by_head(model: ParenTransformer, data: DataSet, layer: int) -> t.Tensor:
    '''
    Get the output of the heads in a particular layer when the model is run on the data.
    Returns a tensor of shape (batch, num_heads, seq, embed_width)
    '''
    module = model.layers[layer].self_attn.W_O
    r = get_inputs(model, data, module)
    # print(f'r shape: {r.shape}')
    r = rearrange(r, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=model.nhead)
    # print(f'r shape: {r.shape}')
    W_O = module.weight
    # print(f'weight shape: {W_O.shape}')
    W_O = rearrange(W_O, 'emb (nheads headsize) -> nheads emb headsize', nheads=model.nhead)
    # print(f'weight shape: {W_O.shape}')
    out = einsum('batch nheads seq headsize, nheads emb headsize -> batch nheads seq emb', r, W_O)
    # print(f'out shape: {out.shape}')
    return out

if MAIN:
    w5d5_tests.test_get_out_by_head(get_out_by_head, model, data)

All tests in `test_get_out_by_head` passed.


In [110]:
def get_out_by_components(model: ParenTransformer, data: DataSet) -> t.Tensor:
    '''
    Computes a tensor of shape [10, dataset_size, seq_pos, emb] representing the output of the model's components when run on the data.
    The first dimension is  [embeddings, head 0.0, head 0.1, mlp 0, head 1.0, head 1.1, mlp 1, head 2.0, head 2.1, mlp 2]
    '''
    out = []
    out.append(get_outputs(model, data, model.pos_encoder))
    for i in range(model.nlayers):
        out.append(get_out_by_head(model, data, i)[:, 0, :, :])
        out.append(get_out_by_head(model, data, i)[:, 1, :, :])
        out.append(get_outputs(model, data, model.layers[i].linear2))
    return t.stack(out, dim=0)

if MAIN:
    w5d5_tests.test_get_out_by_component(get_out_by_components, model, data)

All tests in `test_get_out_by_component` passed.


In [111]:
def hists_per_comp(magnitudes, data, n_layers=3, xaxis_range=(-1, 1)):
    num_comps = magnitudes.shape[0]
    titles = {
        (1, 1): "embeddings",
        (2, 1): "head 0.0",
        (2, 2): "head 0.1",
        (2, 3): "mlp 0",
        (3, 1): "head 1.0",
        (3, 2): "head 1.1",
        (3, 3): "mlp 1",
        (4, 1): "head 2.0",
        (4, 2): "head 2.1",
        (4, 3): "mlp 2"
    }
    assert num_comps == len(titles)

    fig = make_subplots(rows=n_layers+1, cols=3)
    for ((row, col), title), mag in zip(titles.items(), magnitudes):
        if row == n_layers+2: break
        fig.add_trace(go.Histogram(x=mag[data.isbal].numpy(), name="Balanced", marker_color="blue", opacity=0.5, legendgroup = '1', showlegend=title=="embeddings"), row=row, col=col)
        fig.add_trace(go.Histogram(x=mag[~data.isbal].numpy(), name="Unbalanced", marker_color="red", opacity=0.5, legendgroup = '2', showlegend=title=="embeddings"), row=row, col=col)
        fig.update_xaxes(title_text=title, row=row, col=col, range=xaxis_range)
    fig.update_layout(width=1200, height=250*(n_layers+1), barmode="overlay", legend=dict(yanchor="top", y=0.92, xanchor="left", x=0.4), title="Histograms of component significance")
    fig.show()

if MAIN:
    out_by_components = get_out_by_components(model, data)[:, :, 0, :].detach()
    unbalanced_dir = get_pre_final_ln_dir(model, data).detach()
    magnitudes = einsum("component sample emb, emb -> component sample", out_by_components, unbalanced_dir)
    magnitudes = magnitudes - magnitudes[:, data.isbal].mean(-1, keepdim=True)

    assert "magnitudes" in locals(), "You need to define `magnitudes`"
    hists_per_comp(magnitudes, data, xaxis_range=[-10, 20])

inp shape: (5000, 56)
out shape: (5000, 56)


In [112]:
if MAIN:
    NEG = 0
    TOTAL_ELEVATION = 1
    OK = 2

    def is_balanced_forloop_byfail(parens: str) -> int:
        """Return True if the parens are balanced.
        Parens is just the ( and ) characters, no begin or end tokens.
        """
        i = 0
        for c in reversed(parens):
            if c == ")":
                i += 1
            elif c == "(":
                i -= 1
                if i < 0:
                    return NEG
            else:
                raise ValueError(parens)
        return OK if i == 0 else TOTAL_ELEVATION

    def negative(parens: str) -> int:
        i = 0
        for c in reversed(parens):
            if c == ")":
                i += 1
            elif c == "(":
                i -= 1
                if i < 0:
                    return False
        return True

    def total_elevation(parens: str) -> int:
        i = 0
        for c in reversed(parens):
            if c == ")":
                i += 1
            elif c == "(":
                i -= 1
        return i != 0
    
    negative_failure = t.zeros(N_SAMPLES, dtype=bool)
    total_elevation_failure = t.zeros(N_SAMPLES, dtype=bool)
    for i, (s,_,_) in enumerate(data):
        neg = negative(s)
        tot = total_elevation(s)
        negative_failure[i] = neg
        total_elevation_failure[i] = tot
    
    h20_in_d = magnitudes[7] - magnitudes[7, data.isbal].mean(0)
    h21_in_d = magnitudes[8] - magnitudes[8, data.isbal].mean(0)

    failure_types = np.full(len(h20_in_d), "", dtype=np.dtype("U32"))
    failure_types_dict = {
        "both failures": negative_failure & total_elevation_failure,
        "just neg failure": negative_failure & ~total_elevation_failure,
        "just total elevation failure": ~negative_failure & total_elevation_failure,
        "balanced": ~negative_failure & ~total_elevation_failure
    }
    for name, mask in failure_types_dict.items():
        failure_types = np.where(mask, name, failure_types)
    failures_df = pd.DataFrame({
        "Head 2.0 contribution": h20_in_d,
        "Head 2.1 contribution": h21_in_d,
        "Failure type": failure_types
    })[data.starts_open.tolist()]
    fig = px.scatter(
        failures_df, 
        x="Head 2.0 contribution", y="Head 2.1 contribution", color="Failure type", 
        title="h20 vs h21 for different failure types", template="simple_white", height=600, width=800,
        category_orders={"color": failure_types_dict.keys()}
    ).update_traces(marker_size=4)
    fig.show()

In [113]:
! pip install plotly==5.9.0

You should consider upgrading via the '/Users/jon/ml/arena/venv/bin/python3 -m pip install --upgrade pip' command.[0m


In [114]:
print(failure_types_dict.keys())
if MAIN:
    fig = px.scatter(
        x=data.open_proportion, y=h20_in_d, color=failure_types, 
        title="Head 2.0 contribution vs proportion of open brackets '('", template="simple_white", height=500, width=800,
        labels={"x": "Open-proportion", "y": "Head 2.0 contribution"}, category_orders={"color": list(failure_types_dict.keys())}
    ).update_traces(marker_size=4, opacity=0.5).update_layout(legend_title_text='Failure type')
    fig.show()

dict_keys(['both failures', 'just neg failure', 'just total elevation failure', 'balanced'])


In [115]:
# print(list(failure_types_dict.values())[0].shape)
if MAIN:
    fig = px.scatter(
        x=data.open_proportion, y=h21_in_d, color=failure_types, 
        title="Head 2.1 contribution vs proportion of open brackets '('", template="simple_white", height=500, width=800,
        labels={"x": "Open-proportion", "y": "Head 2.1 contribution"}, category_orders={"color": failure_types_dict.keys()}
    ).update_traces(marker_size=4, opacity=0.5).update_layout(legend_title_text='Failure type')
    fig.show()

In [116]:
def get_attn_probs(model: ParenTransformer, tokenizer: SimpleTokenizer, data: DataSet, layer: int, head: int) -> t.Tensor:
    '''
    Returns: (N_SAMPLES, max_seq_len, max_seq_len) tensor that sums to 1 over the last dimension.
    '''
    attn = model.layers[layer].self_attn
    inputs = get_inputs(model, data, attn)
    attn_scores = attn.attention_pattern_pre_softmax(inputs)
    print(attn_scores[:,head,:,:].shape)
    padding_mask = (data.toks == tokenizer.PAD_TOKEN)[:, None, None, :]
    print(padding_mask.shape)
    attn_scores = t.where(padding_mask, t.full_like(attn_scores, -100000), attn_scores)
    attn_scores = attn_scores[:,head,:,:] # TODO why the fuck doesn't this work
    # attn_scores[padding_mask] = -1000
    
    attn_probs = t.softmax(attn_scores, dim=-1)
    assert (attn_probs.sum(-1) - 1).abs().max() < 1e-5
    print(attn_probs.shape)
    return attn_probs.detach()

if MAIN:
    attn_probs = get_attn_probs(model, tokenizer, data, 2, 0)
    attn_probs_open = attn_probs[data.starts_open].mean(0)[[0]]

    px.bar(
        y=attn_probs_open.squeeze().numpy(), labels={"y": "Probability", "x": "Key Position"},
        template="simple_white", height=500, width=600, title="Avg Attention Probabilities for '(' query from query 0"
    ).update_layout(showlegend=False, hovermode='x unified').show()

torch.Size([5000, 42, 42])
torch.Size([5000, 1, 1, 42])
torch.Size([5000, 42, 42])


In [117]:
model

ParenTransformer(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (layers): ModuleList(
    (0): TransformerBlock(
      (norm1): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
      (self_attn): MultiheadAttention(
        (W_Q): Linear(in_features=56, out_features=56, bias=True)
        (W_K): Linear(in_features=56, out_features=56, bias=True)
        (W_V): Linear(in_features=56, out_features=56, bias=True)
        (W_O): Linear(in_features=56, out_features=56, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (linear1): Linear(in_features=56, out_features=56, bias=True)
      (linear2): Linear(in_features=56, out_features=56, bias=True)
      (activation): ReLU()
    )
    (1): TransformerBlock(
      (norm1): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((56,), eps=1e-05, elementwise_affine=True)
      (self_attn)

In [120]:
def get_WV(model: ParenTransformer, layer: int, head: int) -> t.Tensor:
    '''
    Returns the W_V matrix of a head. Should be a CPU tensor of size (d_model / num_heads, d_model)
    '''
    headsize = model.d_model // model.nhead
    if head == 0:
        return model.layers[layer].self_attn.W_V.weight[:headsize,:].detach().cpu()
    elif head == 1:
        return model.layers[layer].self_attn.W_V.weight[headsize:,:].detach().cpu()

def get_WO(model: ParenTransformer, layer: int, head: int) -> t.Tensor:
    '''
    Returns the W_O matrix of a head. Should be a CPU tensor of size (d_model, d_model / num_heads)
    '''
    headsize = model.d_model // model.nhead
    if head == 0:
        return model.layers[layer].self_attn.W_O.weight[:,:headsize].detach().cpu()
    elif head == 1:
        return model.layers[layer].self_attn.W_O.weight[:,headsize:].detach().cpu()

def get_WOV(model: ParenTransformer, layer: int, head: int) -> t.Tensor:
    return get_WO(model, layer, head) @ get_WV(model, layer, head)

def get_pre_20_dir(model, data):
    '''
    Returns the direction propagated back through the OV matrix of 2.0 and then through the layernorm before the layer 2 attention heads.
    '''
    W_OV = get_WOV(model, 2, 0)
    pre_final_ln_dir = get_pre_final_ln_dir(model, data)
    norm1 = model.layers[2].norm1
    (fit, r) = get_ln_fit(model, data, norm1, 1)
    L = t.from_numpy(fit.coef_)
    return t.einsum('i, ij, jk->k', pre_final_ln_dir, W_OV, L)
    
    

if MAIN:
    w5d5_tests.test_get_WV(model, get_WV)
    w5d5_tests.test_get_WO(model, get_WO)
    w5d5_tests.test_get_pre_20_dir(model, data, get_pre_20_dir)

if MAIN:
    out_by_components = get_out_by_components(model, data)[:, :, 0, :].detach()
    unbalanced_dir = get_pre_20_dir(model, data).detach()
    magnitudes = einsum("component sample emb, emb -> component sample", out_by_components, unbalanced_dir)
    magnitudes = magnitudes - magnitudes[:, data.isbal].mean(-1, keepdim=True)
    assert "magnitudes" in locals()
    hists_per_comp(magnitudes, data, n_layers=2, xaxis_range=(-7, 7))

All tests in `test_get_WV` passed.
All tests in `test_get_WO` passed.
inp shape: (5000, 56)
out shape: (5000, 56)
inp shape: (5000, 56)
out shape: (5000, 56)
pre_final_ln_dir: torch.Size([56])
WOV: torch.Size([56, 56])
L: torch.Size([56, 56])
All tests in `test_get_pre_20_dir` passed.
inp shape: (5000, 56)
out shape: (5000, 56)
inp shape: (5000, 56)
out shape: (5000, 56)
pre_final_ln_dir: torch.Size([56])
WOV: torch.Size([56, 56])
L: torch.Size([56, 56])


In [None]:
def out_by_neuron(model: ParenTransformer, data, layer):
    '''
    Return shape: [len(data), seq_len, neurons, out]
    '''
    mlp = model.layers[layer].linear1
    act = model.layers[layer].activation
    mlp = model.layers[layer].linear2

@functools.cache
def out_by_neuron_in_20_dir(model, data, layer):
    pass