In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm

import random
import time
%matplotlib inline
import plotly.express as px
import plotly.io as pio
# pio.renderers.default = "colab"
import plotly.graph_objects as go

# from google.colab import drive
from pathlib import Path
import pickle
import os

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from functools import *
import pandas as pd
import gc

# import comet_ml
import itertools


In [2]:

# A helper class to get access to intermediate activations (inspired by Garcon)
# It's a dummy module that is the identity function by default
# I can wrap any intermediate activation in a HookPoint and get a convenient
# way to add PyTorch hooks
class HookPoint(nn.Module):
    def __init__(self):
        super().__init__()
        self.fwd_hooks = []
        self.bwd_hooks = []

    def give_name(self, name):
        # Called by the model at initialisation
        self.name = name

    def add_hook(self, hook, dir='fwd'):
        # Hook format is fn(activation, hook_name)
        # Change it into PyTorch hook format (this includes input and output,
        # which are the same for a HookPoint)
        def full_hook(module, module_input, module_output):
            return hook(module_output, name=self.name)
        if dir=='fwd':
            handle = self.register_forward_hook(full_hook)
            self.fwd_hooks.append(handle)
        elif dir=='bwd':
            handle = self.register_backward_hook(full_hook)
            self.bwd_hooks.append(handle)
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(self, dir='fwd'):
        if (dir=='fwd') or (dir=='both'):
            for hook in self.fwd_hooks:
                hook.remove()
            self.fwd_hooks = []
        if (dir=='bwd') or (dir=='both'):
            for hook in self.bwd_hooks:
                hook.remove()
            self.bwd_hooks = []
        if dir not in ['fwd', 'bwd', 'both']:
            raise ValueError(f"Invalid direction {dir}")

    def forward(self, x):
        return x

# Define network architecture
# I defined my own transformer from scratch so I'd fully understand each component
# - I expect this wasn't necessary or particularly important, and a bunch of this
# replicates existing PyTorch functionality

# Embed & Unembed
class Embed(nn.Module):
    def __init__(self, d_vocab, d_model):
        super().__init__()
        self.W_E = nn.Parameter(torch.randn(d_model, d_vocab)/np.sqrt(d_model))

    def forward(self, x):
        return torch.einsum('dbp -> bpd', self.W_E[:, x])

class Unembed(nn.Module):
    def __init__(self, d_vocab, d_model):
        super().__init__()
        self.W_U = nn.Parameter(torch.randn(d_model, d_vocab)/np.sqrt(d_vocab))

    def forward(self, x):
        return (x @ self.W_U)

# Positional Embeddings
class PosEmbed(nn.Module):
    def __init__(self, max_ctx, d_model):
        super().__init__()
        self.W_pos = nn.Parameter(torch.randn(max_ctx, d_model)/np.sqrt(d_model))

    def forward(self, x):
        return x+self.W_pos[:x.shape[-2]]

# LayerNorm
class LayerNorm(nn.Module):
    def __init__(self, d_model, epsilon = 1e-4, model=[None]):
        super().__init__()
        self.model = model
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon

    def forward(self, x):
        if self.model[0].use_ln:
            x = x - x.mean(axis=-1)[..., None]
            x = x / (x.std(axis=-1)[..., None] + self.epsilon)
            x = x * self.w_ln
            x = x + self.b_ln
            return x
        else:
            return x

# Attention
class Attention(nn.Module):
    def __init__(self, d_model, num_heads, d_head, n_ctx, model):
        super().__init__()
        self.model = model
        self.W_K = nn.Parameter(torch.randn(num_heads, d_head, d_model)/np.sqrt(d_model))
        self.W_Q = nn.Parameter(torch.randn(num_heads, d_head, d_model)/np.sqrt(d_model))
        self.W_V = nn.Parameter(torch.randn(num_heads, d_head, d_model)/np.sqrt(d_model))
        self.W_O = nn.Parameter(torch.randn(d_model, d_head * num_heads)/np.sqrt(d_model))
        self.register_buffer('mask', torch.tril(torch.ones((n_ctx, n_ctx))))
        self.d_head = d_head
        self.hook_k = HookPoint()
        self.hook_q = HookPoint()
        self.hook_v = HookPoint()
        self.hook_z = HookPoint()
        self.hook_attn = HookPoint()
        self.hook_attn_pre = HookPoint()

    def forward(self, x):
        k = self.hook_k(torch.einsum('ihd,bpd->biph', self.W_K, x))
        q = self.hook_q(torch.einsum('ihd,bpd->biph', self.W_Q, x))
        v = self.hook_v(torch.einsum('ihd,bpd->biph', self.W_V, x))
        attn_scores_pre = torch.einsum('biph,biqh->biqp', k, q)
        attn_scores_masked = torch.tril(attn_scores_pre) - 1e10 * (1 - self.mask[:x.shape[-2], :x.shape[-2]])
        attn_matrix = self.hook_attn(F.softmax(self.hook_attn_pre(attn_scores_masked/np.sqrt(self.d_head)), dim=-1))
        z = self.hook_z(torch.einsum('biph,biqp->biqh', v, attn_matrix))
        z_flat = einops.rearrange(z, 'b i q h -> b q (i h)')
        out = torch.einsum('df,bqf->bqd', self.W_O, z_flat)
        return out

# MLP Layers
class MLP(nn.Module):
    def __init__(self, d_model, d_mlp, act_type, model):
        super().__init__()
        self.model = model
        self.W_in = nn.Parameter(torch.randn(d_mlp, d_model)/np.sqrt(d_model))
        self.b_in = nn.Parameter(torch.zeros(d_mlp))
        self.W_out = nn.Parameter(torch.randn(d_model, d_mlp)/np.sqrt(d_model))
        self.b_out = nn.Parameter(torch.zeros(d_model))
        self.act_type = act_type
        # self.ln = LayerNorm(d_mlp, model=self.model)
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()
        assert act_type in ['ReLU', 'GeLU']

    def forward(self, x):
        x = self.hook_pre(torch.einsum('md,bpd->bpm', self.W_in, x) + self.b_in)
        if self.act_type=='ReLU':
            x = F.relu(x)
        elif self.act_type=='GeLU':
            x = F.gelu(x)
        x = self.hook_post(x)
        x = torch.einsum('dm,bpm->bpd', self.W_out, x) + self.b_out
        return x

# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model):
        super().__init__()
        self.model = model
        # self.ln1 = LayerNorm(d_model, model=self.model)
        self.attn = Attention(d_model, num_heads, d_head, n_ctx, model=self.model)
        # self.ln2 = LayerNorm(d_model, model=self.model)
        self.mlp = MLP(d_model, d_mlp, act_type, model=self.model)
        self.hook_attn_out = HookPoint()
        self.hook_mlp_out = HookPoint()
        self.hook_resid_pre = HookPoint()
        self.hook_resid_mid = HookPoint()
        self.hook_resid_post = HookPoint()

    def forward(self, x):
        x = self.hook_resid_mid(x + self.hook_attn_out(self.attn((self.hook_resid_pre(x)))))
        x = self.hook_resid_post(x + self.hook_mlp_out(self.mlp((x))))
        return x

# Full transformer
class Transformer(nn.Module):
    def __init__(self, num_layers, d_vocab, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, use_cache=False, use_ln=True):
        super().__init__()
        self.cache = {}
        self.use_cache = use_cache

        self.embed = Embed(d_vocab, d_model)
        self.pos_embed = PosEmbed(n_ctx, d_model)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model=[self]) for i in range(num_layers)])
        # self.ln = LayerNorm(d_model, model=[self])
        self.unembed = Unembed(d_vocab, d_model)
        self.use_ln = use_ln

        for name, module in self.named_modules():
            if type(module)==HookPoint:
                module.give_name(name)

    def forward(self, x):
        x = self.embed(x)
        x = self.pos_embed(x)
        for block in self.blocks:
            x = block(x)
        # x = self.ln(x)
        x = self.unembed(x)
        return x

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def hook_points(self):
        return [module for name, module in self.named_modules() if 'hook' in name]

    def remove_all_hooks(self):
        for hp in self.hook_points():
            hp.remove_hooks('fwd')
            hp.remove_hooks('bwd')

    def cache_all(self, cache, incl_bwd=False):
        # Caches all activations wrapped in a HookPoint
        def save_hook(tensor, name):
            cache[name] = tensor.detach()
        def save_hook_back(tensor, name):
            cache[name+'_grad'] = tensor[0].detach()
        for hp in self.hook_points():
            hp.add_hook(save_hook, 'fwd')
            if incl_bwd:
                hp.add_hook(save_hook_back, 'bwd')

# Helper functions
def cuda_memory():
    print(torch.cuda.memory_allocated()/1e9)

def cross_entropy_high_precision(logits, labels):
    # Shapes: batch x seq x vocab, batch x seq
    # Cast logits to float64 because log_softmax has a float32 underflow on overly
    # confident data and can only return multiples of 1.2e-7 (the smallest float x
    # such that 1+x is different from 1 in float32). This leads to loss spikes
    # and dodgy gradients

    logprobs = F.log_softmax(logits.to(torch.float64), dim=-1)
    # print(logprobs.shape, labels.shape)
    prediction_logprobs = torch.gather(logprobs, index=labels[:, :, None], dim=-1)
    # print(prediction_logprobs.shape)
    loss = -torch.mean(prediction_logprobs)
    return loss

def full_loss(model, data, arr_len, device='cuda'):
    """
    Calculate the full loss and accuracy of the model.

    Parameters:
    model (nn.Module): The PyTorch model.
    data (Tensor): The input data.
    arr_len (int): The length of the array.

    Returns:
    tuple: A tuple containing the loss and accuracy.
    """
    # Take the final position only
    logits = model(data)[:, arr_len:-1]

    # labels = torch.tensor([fn(i, j) for i, j, _ in data]).to('cuda')
    # labels = torch.tensor([np.sort(x) for x in data])
    labels = torch.tensor(data[:, arr_len + 1:]).to(device)

    # Calculate loss
    loss = cross_entropy_high_precision(logits, labels)

    # Calculate accuracy
    predictions = torch.argmax(logits, dim=2)
    # print(predictions.shape, labels.shape, len(labels))
    accuracy = torch.sum(predictions == labels) / (arr_len * len(labels))
    # Calculate exact match accuracy
    exact_match_accuracy = torch.sum(torch.all(predictions == labels, dim=-1)) / labels.shape[0]
    return loss, accuracy,exact_match_accuracy


In [3]:
def unflatten_first(tensor):
    if tensor.shape[0]==p*p:
        return einops.rearrange(tensor, '(x y) ... -> x y ...', x=p, y=p)
    else:
        return tensor
def cos(x, y):
    return (x.dot(y))/x.norm()/y.norm()
def mod_div(a, b):
    return (a*pow(b, p-2, p))%p
def normalize(tensor, axis=0):
    return tensor/(tensor).pow(2).sum(keepdim=True, axis=axis).sqrt()
def extract_freq_2d(tensor, freq):
    # Takes in a pxpx... or batch x ... tensor, returns a 3x3x... tensor of the
    # Linear and quadratic terms of frequency freq
    tensor = unflatten_first(tensor)
    # Extracts the linear and quadratic terms corresponding to frequency freq
    index_1d = [0, 2*freq-1, 2*freq]
    # Some dumb manipulation to use fancy array indexing rules
    # Gets the rows and columns in index_1d
    return tensor[[[i]*3 for i in index_1d], [index_1d]*3]
def get_cov(tensor, norm=True):
    # Calculate covariance matrix
    if norm:
        tensor = normalize(tensor, axis=1)
    return tensor @ tensor.T
def is_close(a, b):
    return ((a-b).pow(2).sum()/(a.pow(2).sum().sqrt())/(b.pow(2).sum().sqrt())).item()

In [4]:
def write_to_file(x, file_name=None):
    # Define the default file name and its extension
    default_file_name = "plot"
    file_extension = ".html"

    # If file_name is not provided, use the default file name with an incrementing number
    if file_name is None:
        i = 1
        while True:
            file_name = f"{default_file_name}_{i}{file_extension}"
            if not os.path.exists(file_name):
                break
            i += 1

    # Write x to the file
    with open("plots_html/" + file_name, "w") as f:
        f.write(x.to_html())

    return x
#Plotting functions
# This is mostly a bunch of over-engineered mess to hack Plotly into producing
# the pretty pictures I want, I recommend not reading too closely unless you
# want Plotly hacking practice
def to_numpy(tensor, flat=False):
    if type(tensor)!=torch.Tensor:
        return tensor
    if flat:
        return tensor.flatten().detach().cpu().numpy()
    else:
        return tensor.detach().cpu().numpy()
def imshow(tensor, xaxis=None, yaxis=None, animation_name='Snapshot', file_name=None, **kwargs):
    if tensor.shape[0]==p*p:
        tensor = unflatten_first(tensor)
    tensor = torch.squeeze(tensor)
    write_to_file(px.imshow(to_numpy(tensor, flat=False),
              labels={'x':xaxis, 'y':yaxis, 'animation_name':animation_name},
              **kwargs), file_name=file_name).show()
# Set default colour scheme
imshow = partial(imshow, color_continuous_scale='Blues')
# Creates good defaults for showing divergent colour scales (ie with both
# positive and negative values, where 0 is white)
imshow_div = partial(imshow, color_continuous_scale='RdBu', color_continuous_midpoint=0.0)
# Presets a bunch of defaults to imshow to make it suitable for showing heatmaps
# of activations with x axis being input 1 and y axis being input 2.
inputs_heatmap = partial(imshow, xaxis='Input 1', yaxis='Input 2', color_continuous_scale='RdBu', color_continuous_midpoint=0.0)
def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):
    if type(y)==torch.Tensor:
        y = to_numpy(y, flat=True)
    if type(x)==torch.Tensor:
        x=to_numpy(x, flat=True)
    fig = px.line(x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    fig.show()
def scatter(x, y, **kwargs):
    px.scatter(x=to_numpy(x, flat=True), y=to_numpy(y, flat=True), **kwargs).show()
def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, file_name=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    write_to_file(fig, file_name=file_name)
    fig.show()
def line_marker(x, **kwargs):
    lines([x], mode='lines+markers', **kwargs)
def animate_lines(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, xaxis='x', yaxis='y', file_name=None, **kwargs):
    if type(lines_list)==list:
        lines_list = torch.stack(lines_list, axis=0)
    lines_list = to_numpy(lines_list, flat=False)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if hover is not None:
        hover = [i for j in range(len(snapshot_index)) for i in hover]
    print(lines_list.shape)
    rows=[]
    for i in range(lines_list.shape[0]):
        for j in range(lines_list.shape[1]):
            rows.append([lines_list[i][j], snapshot_index[i], j])
    df = pd.DataFrame(rows, columns=[yaxis, snapshot, xaxis])
    write_to_file(px.line(df, x=xaxis, y=yaxis, animation_frame=snapshot, range_y=[lines_list.min(), lines_list.max()], hover_name=hover,**kwargs), file_name=file_name).show()

def imshow_fourier(tensor, title='', animation_name='snapshot', facet_labels=[], **kwargs):
    # Set nice defaults for plotting functions in the 2D fourier basis
    # tensor is assumed to already be in the Fourier Basis
    if tensor.shape[0]==p*p:
        tensor = unflatten_first(tensor)
    tensor = torch.squeeze(tensor)
    fig=px.imshow(to_numpy(tensor),
            x=fourier_basis_names,
            y=fourier_basis_names,
            labels={'x':'x Component',
                    'y':'y Component',
                    'animation_frame':animation_name},
            title=title,
            color_continuous_midpoint=0.,
            color_continuous_scale='RdBu',
            **kwargs)
    fig.update(data=[{'hovertemplate':"%{x}x * %{y}y<br>Value:%{z:.4f}"}])
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    fig.show()

def animate_multi_lines(lines_list, y_index=None, snapshot_index = None, snapshot='snapshot', hover=None, swap_y_animate=False, file_name=None, **kwargs):
    # Can plot an animation of lines with multiple lines on the plot.
    if type(lines_list)==list:
        lines_list = torch.stack(lines_list, axis=0)
    lines_list = to_numpy(lines_list, flat=False)
    if swap_y_animate:
        lines_list = lines_list.transpose(1, 0, 2)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if y_index is None:
        y_index = [str(i) for i in range(lines_list.shape[1])]
    if hover is not None:
        hover = [i for j in range(len(snapshot_index)) for i in hover]
    print(lines_list.shape)
    rows=[]
    for i in range(lines_list.shape[0]):
        for j in range(lines_list.shape[2]):
            rows.append(list(lines_list[i, :, j])+[snapshot_index[i], j])
    df = pd.DataFrame(rows, columns=y_index+[snapshot, 'x'])
    write_to_file(px.line(df, x='x', y=y_index, animation_frame=snapshot, range_y=[lines_list.min(), lines_list.max()], hover_name=hover, **kwargs), file_name=file_name).show()

def animate_scatter(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, yaxis='y', xaxis='x', color=None, color_name = 'color', file_name=None, **kwargs):
    # Can plot an animated scatter plot
    # lines_list has shape snapshot x 2 x line
    if type(lines_list)==list:
        lines_list = torch.stack(lines_list, axis=0)
    lines_list = to_numpy(lines_list, flat=False)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if hover is not None:
        hover = [i for j in range(len(snapshot_index)) for i in hover]
    if color is None:
        color = np.ones(lines_list.shape[-1])
    if type(color)==torch.Tensor:
        color = to_numpy(color)
    if len(color.shape)==1:
        color = einops.repeat(color, 'x -> snapshot x', snapshot=lines_list.shape[0])
    print(lines_list.shape)
    rows=[]
    for i in range(lines_list.shape[0]):
        for j in range(lines_list.shape[2]):
            rows.append([lines_list[i, 0, j].item(), lines_list[i, 1, j].item(), snapshot_index[i], color[i, j]])
    print([lines_list[:, 0].min(), lines_list[:, 0].max()])
    print([lines_list[:, 1].min(), lines_list[:, 1].max()])
    df = pd.DataFrame(rows, columns=[xaxis, yaxis, snapshot, color_name])
    write_to_file(px.scatter(df, x=xaxis, y=yaxis, animation_frame=snapshot, range_x=[lines_list[:, 0].min(), lines_list[:, 0].max()], range_y=[lines_list[:, 1].min(), lines_list[:, 1].max()], hover_name=hover, color=color_name, **kwargs), file_name=file_name).show()

In [5]:
lr=1e-4 #@param
weight_decay = 1.0 #@param
p=113 #@param
d_model = 128 #@param
fn_name = 'add' #@param ['add', 'subtract', 'x2xyy2','rand']
frac_train = 0.3 #@param
dataset_size = 10000 #@param
num_epochs = 20000 #@param
save_models = True #@param
save_every = 1000 #@param
# Stop training when test loss is <stopping_thresh
stopping_thresh = -1 #@param
seed = 0 #@param

arr_len = 5 #@param

start = 1 #@param
end = 100 #@param

num_layers = 1

batch_style = 'full'
d_vocab = p+1
n_ctx = 3
d_mlp = 4*d_model
num_heads = 4
assert d_model % num_heads == 0
d_head = d_model//num_heads
act_type = 'ReLU' #@param ['ReLU', 'GeLU']
# batch_size = 512
use_ln = False
random_answers = np.random.randint(low=0, high=p, size=(p, p))
fns_dict = {'add': lambda x,y:(x+y)%p, 'subtract': lambda x,y:(x-y)%p, 'x2xyy2':lambda x,y:(x**2+x*y+y**2)%p, 'rand':lambda x,y:random_answers[x][y]}
fn = fns_dict[fn_name]


In [6]:
model = Transformer(num_layers=num_layers, d_vocab=(end - start + 1 + 1), d_model=d_model, d_mlp=d_mlp, d_head=d_head, num_heads=num_heads, n_ctx=2 * arr_len + 1, act_type=act_type, use_cache=True, use_ln=use_ln)

In [7]:
state_dict = torch.load('./max5_unbalanced/3000.pth')
init_dict = torch.load('./max5_unbalanced/init.pth')
train_data = init_dict['train_data']
test_data = init_dict['test_data']

  state_dict = torch.load('./max5_unbalanced/3000.pth')
  init_dict = torch.load('./max5_unbalanced/init.pth')


In [8]:
cache = {}
model.remove_all_hooks()
model.load_state_dict(state_dict=state_dict['model'])
model.cache_all(cache)
model.to('cuda')

Transformer(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn): HookPoint()
        (hook_attn_pre): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)

In [10]:
model(torch.tensor([[1, 2, 3, 4, 5, 0]]))[0][-1].argmax()

tensor(20, device='cuda:0')

In [11]:
# Helper variables
W_O = einops.rearrange(model.blocks[0].attn.W_O, 'm (i h)->i m h', i=num_heads)
W_K = model.blocks[0].attn.W_K
W_Q = model.blocks[0].attn.W_Q
W_V = model.blocks[0].attn.W_V
W_in = model.blocks[0].mlp.W_in
W_out = model.blocks[0].mlp.W_out
W_pos = model.pos_embed.W_pos.T
# We remove the equals sign dimension from the Embed and Unembed, so we can
# apply a Fourier Transform over R^p
W_E = model.embed.W_E[:, :-1]
W_U = model.unembed.W_U[:, :-1].T

# The initial value of the residual stream at position 2 - constant for all inputs
final_pos_resid_initial = model.embed.W_E[:, -1] + W_pos[:, 2]
print('W_O', W_O.shape)
print('W_K', W_K.shape)
print('W_Q', W_Q.shape)
print('W_V', W_V.shape)
print('W_in', W_in.shape)
print('W_out', W_out.shape)
print('W_pos', W_pos.shape)
print('W_E', W_E.shape)
print('W_U', W_U.shape)
print('Initial residual stream value at final pos:', final_pos_resid_initial.shape)

W_O torch.Size([4, 128, 32])
W_K torch.Size([4, 32, 128])
W_Q torch.Size([4, 32, 128])
W_V torch.Size([4, 32, 128])
W_in torch.Size([512, 128])
W_out torch.Size([128, 512])
W_pos torch.Size([128, 11])
W_E torch.Size([128, 100])
W_U torch.Size([100, 128])
Initial residual stream value at final pos: torch.Size([128])


In [12]:
test_loss, test_acc, test_exact_acc = full_loss(model, test_data[:120], arr_len)
(test_loss, test_acc, test_exact_acc)

(tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>),
 tensor(0.2000, device='cuda:0'),
 tensor(1., device='cuda:0'))

In [13]:
for k in cache.keys():
    print(k, cache[k].shape)

blocks.0.hook_resid_pre torch.Size([120, 7, 128])
blocks.0.attn.hook_k torch.Size([120, 4, 7, 32])
blocks.0.attn.hook_q torch.Size([120, 4, 7, 32])
blocks.0.attn.hook_v torch.Size([120, 4, 7, 32])
blocks.0.attn.hook_attn_pre torch.Size([120, 4, 7, 7])
blocks.0.attn.hook_attn torch.Size([120, 4, 7, 7])
blocks.0.attn.hook_z torch.Size([120, 4, 7, 32])
blocks.0.hook_attn_out torch.Size([120, 7, 128])
blocks.0.hook_resid_mid torch.Size([120, 7, 128])
blocks.0.mlp.hook_pre torch.Size([120, 7, 512])
blocks.0.mlp.hook_post torch.Size([120, 7, 512])
blocks.0.hook_mlp_out torch.Size([120, 7, 128])
blocks.0.hook_resid_post torch.Size([120, 7, 128])


In [14]:
# Extracts out key activations
# Attention values
attn_mat = cache['blocks.0.attn.hook_attn'][:, :, 5, :5] # Attention of 0 (CLS) with all inputs
print('Attention Matrix:', attn_mat.shape)


neuron_acts_pre = cache['blocks.0.mlp.hook_pre'][:, 5] # Before non-linearity in MLP 
print('Neuron Activations Pre:', neuron_acts_pre.shape)

neuron_acts = cache['blocks.0.mlp.hook_post'][:, 5] # After non-linearity in MLP 
print('Neuron Activations:', neuron_acts.shape)


Attention Matrix: torch.Size([120, 4, 5])
Neuron Activations Pre: torch.Size([120, 512])
Neuron Activations: torch.Size([120, 512])


In [15]:
def get_activations(model, data, device='cuda'):
    cache = {}
    model.remove_all_hooks()
    model.cache_all(cache)
    model.to(device)
    with torch.no_grad():
        loss, acc, exact_acc = full_loss(model, data, arr_len, device=device)
    # Extracts out key activations
    # Attention values
    attn_mat = cache['blocks.0.attn.hook_attn'][:, :, 5, :5] # Attention of 0 (CLS) with all inputs
    neuron_acts_pre = cache['blocks.0.mlp.hook_pre'][:, 5] # Before non-linearity in MLP 
    neuron_acts = cache['blocks.0.mlp.hook_post'][:, 5] # After non-linearity in MLP 
    return attn_mat, neuron_acts_pre, neuron_acts, cache


In [16]:
rng1 = np.random.RandomState(400)
tensor = np.sort(rng1.randint(1, 100, (5,)))
# tensor = [1,2,3,4,5]
permutations = torch.tensor(list(itertools.permutations(tensor)))
permutations = torch.cat((permutations, torch.zeros((120, 1), dtype=torch.int64)), dim=-1)
permutations.shape, permutations[:5]

(torch.Size([120, 6]),
 tensor([[13, 58, 63, 80, 93,  0],
         [13, 58, 63, 93, 80,  0],
         [13, 58, 80, 63, 93,  0],
         [13, 58, 80, 93, 63,  0],
         [13, 58, 93, 63, 80,  0]]))

In [17]:
attn_mat, neuron_acts_pre, neuron_acts, cache = get_activations(model, permutations)
print('Attention Matrix:', attn_mat.shape)
print('Neuron Activations Pre:', neuron_acts_pre.shape)
print('Neuron Activations:', neuron_acts.shape)


Attention Matrix: torch.Size([120, 4, 5])
Neuron Activations Pre: torch.Size([120, 512])
Neuron Activations: torch.Size([120, 512])


  labels = torch.tensor(data[:, arr_len + 1:]).to(device)


In [18]:
# Compare positional embeddings
def compare_tensors(v, w):
    return ((v-w).pow(2).sum()/v.pow(2).sum().sqrt()/w.pow(2).sum().sqrt()).item()
print('Positions 0 and 1 are symmetric')
print('Difference in position embeddings', compare_tensors(W_pos[:, 0], W_pos[:, 1]))
print('Cosine sim of position embeddings', cos(W_pos[:, 0], W_pos[:, 1]).item())
print('Cosine sim of position embeddings', cos(W_pos[:, 0], W_pos[:, 2]).item())
print('Cosine sim of position embeddings', cos(W_pos[:, 0], W_pos[:, 3]).item())
print('Cosine sim of position embeddings', cos(W_pos[:, 0], W_pos[:, 4]).item())
print('Cosine sim of position embeddings', cos(W_pos[:, 0], W_pos[:, 5]).item())


Positions 0 and 1 are symmetric
Difference in position embeddings 1.281852126121521
Cosine sim of position embeddings 0.3598152995109558
Cosine sim of position embeddings 0.4259050190448761
Cosine sim of position embeddings 0.5002790093421936
Cosine sim of position embeddings 0.45464444160461426
Cosine sim of position embeddings 0.0025431266985833645


In [22]:
cosine_similarities = F.cosine_similarity(W_pos.T[:6].unsqueeze(1), W_pos.T[:6].unsqueeze(0), dim=2)
imshow(cosine_similarities.detach().cpu(), 
       xaxis='position', yaxis='position',
       title='Cosine similarity of every pair of position embeddings', file_name='max_pos_em_cos_sim')

In [23]:
cosine_similarities = F.cosine_similarity(neuron_acts_pre.unsqueeze(1), neuron_acts_pre.unsqueeze(0), dim=2)
imshow(cosine_similarities.detach().cpu(), 
       xaxis='permutation', yaxis='permutation',
       title='Cosine similarity of neuron activations pre-ReLU corresponding to two permuatations',
       file_name='max_neuron_cos_sim_perm_inv.html')

In [19]:
# TODO:
# mlp_out = neuron_acts @ W_out.T
# # Index by -1 to look at just the final position
# # Since it's a 1L transformer, the residual stream at positions 0 and 1 don't
# # matter post the attention layer
# x_1 = cache['blocks.0.hook_resid_mid'][:, -1]
# # Average x_1 across the batch of all data to get a bias term, constant for all inputs
# average_x_1 = einops.reduce(x_1, 'batch model -> 1 model', 'mean')

# print('Loss with skip connection:', test_logits((mlp_out + x_1)@W_U.T).item())
# print('Loss with skip connection as bias term:', test_logits((mlp_out + average_x_1)@W_U.T).item())
# print('Loss with no skip connection:', test_logits((mlp_out)@W_U.T).item())

In [20]:
# perm = [4, 2, 1, 3, 0, 5]

# permuted_test_data = test_data[:, perm]
# _, neuron_acts_pre_1, _, _ = get_activations(model, test_data)
# _, neuron_acts_pre_2, _, _ = get_activations(model, permuted_test_data)

# F.cosine_similarity(neuron_acts_pre_1, neuron_acts_pre_2).min()

In [21]:
perm_similarities = F.cosine_similarity(permutations.float().unsqueeze(1), permutations.float().unsqueeze(0), dim=2)
imshow(perm_similarities)

In [22]:
for i in range(120):
    print(i, permutations[i])

0 tensor([13, 58, 63, 80, 93,  0])
1 tensor([13, 58, 63, 93, 80,  0])
2 tensor([13, 58, 80, 63, 93,  0])
3 tensor([13, 58, 80, 93, 63,  0])
4 tensor([13, 58, 93, 63, 80,  0])
5 tensor([13, 58, 93, 80, 63,  0])
6 tensor([13, 63, 58, 80, 93,  0])
7 tensor([13, 63, 58, 93, 80,  0])
8 tensor([13, 63, 80, 58, 93,  0])
9 tensor([13, 63, 80, 93, 58,  0])
10 tensor([13, 63, 93, 58, 80,  0])
11 tensor([13, 63, 93, 80, 58,  0])
12 tensor([13, 80, 58, 63, 93,  0])
13 tensor([13, 80, 58, 93, 63,  0])
14 tensor([13, 80, 63, 58, 93,  0])
15 tensor([13, 80, 63, 93, 58,  0])
16 tensor([13, 80, 93, 58, 63,  0])
17 tensor([13, 80, 93, 63, 58,  0])
18 tensor([13, 93, 58, 63, 80,  0])
19 tensor([13, 93, 58, 80, 63,  0])
20 tensor([13, 93, 63, 58, 80,  0])
21 tensor([13, 93, 63, 80, 58,  0])
22 tensor([13, 93, 80, 58, 63,  0])
23 tensor([13, 93, 80, 63, 58,  0])
24 tensor([58, 13, 63, 80, 93,  0])
25 tensor([58, 13, 63, 93, 80,  0])
26 tensor([58, 13, 80, 63, 93,  0])
27 tensor([58, 13, 80, 93, 63,  0])
28

In [23]:
logits = model(permutations)
logits.shape

torch.Size([120, 6, 101])

In [24]:
torch.argmax(logits[:, 5], dim=-1)

tensor([93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93,
        93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93,
        93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93,
        93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93,
        93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93,
        93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93,
        93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93], device='cuda:0')

In [25]:
print(F.cosine_similarity(neuron_acts_pre, neuron_acts_pre))
# print('Difference in neuron activations for (x,y) and (y,x)', compare_tensors(neuron_acts_square, neuron_acts_square.permute(1, 0, 2)))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 

In [26]:
attn_mat.shape

torch.Size([120, 4, 5])

In [24]:
imshow((attn_mat),
            xaxis='Position', yaxis='Heads', color_continuous_scale='RdBu', color_continuous_midpoint=0.0,
           title=f'Attention score for heads at all positions for every permutation',
           animation_frame=0,
           animation_name='head',
           file_name='max_attn_all_positions_all_permutations.html')


In [28]:
W_E.shape,W_K.shape, W_pos.shape

(torch.Size([128, 100]), torch.Size([4, 32, 128]), torch.Size([128, 11]))

In [29]:
ls = []
for i in range(1, 100):
    k = W_K[0] @ (W_E[:, i] + W_pos[:, 0])
    q = W_Q[0] @ (W_E[:, 50] + W_pos[:, 4])
    ls.append(torch.dot(k, q).cpu().item())
ls2 = []
for i in range(1, 100):
    k = W_K[1] @ (W_E[:, i] + W_pos[:, 0])
    q = W_Q[1] @ (W_E[:, 50] + W_pos[:, 4])
    ls2.append(torch.dot(k, q).cpu().item())
ls3 = []
for i in range(1, 100):
    k = W_K[2] @ (W_E[:, i] + W_pos[:, 0])
    q = W_Q[2] @ (W_E[:, 50] + W_pos[:, 4])
    ls3.append(torch.dot(k, q).cpu().item())
ls4 = []
for i in range(1, 100):
    k = W_K[3] @ (W_E[:, i] + W_pos[:, 0])
    q = W_Q[3] @ (W_E[:, 50] + W_pos[:, 4])
    ls4.append(torch.dot(k, q).cpu().item())
# ls
lines([ls, ls2, ls3, ls4])

In [30]:
x = np.array([ls, ls2, ls3, ls4]).mean(axis=0)

x[1:]-x[:-1]

array([ 0.01305618,  0.00574246,  0.01544912, -0.03370505,  0.01835608,
        0.0263018 , -0.02591334,  0.02011297, -0.00806889, -0.01804621,
        0.02668996, -0.02258878,  0.00831505, -0.00105119,  0.01209413,
       -0.0191979 ,  0.01823095, -0.01252526,  0.00319972, -0.00217357,
        0.01148374, -0.00456416,  0.0174726 , -0.00814295, -0.00384265,
       -0.02033288,  0.02494357,  0.00105955, -0.01393462,  0.0166932 ,
       -0.03284729,  0.02732111, -0.00379943,  0.00145738, -0.00535817,
        0.00031187, -0.00362004,  0.01050191, -0.00827826,  0.01850024,
       -0.0141461 , -0.01934663,  0.01409043,  0.01951981, -0.01692074,
       -0.00851756,  0.02333037, -0.03089407,  0.05086175, -0.02177195,
        0.01301089, -0.01306636,  0.00792992,  0.01616719, -0.06972338,
        0.06612445, -0.0313461 ,  0.00011305,  0.00865683, -0.0325943 ,
        0.0220516 ,  0.0023869 ,  0.00129042,  0.02147316, -0.03188105,
        0.02036005,  0.02091812, -0.00598021, -0.02099945, -0.00

In [31]:
torch.argmax(model(torch.tensor([[71, 72, 78, 60, 75, 0]]))[0, -1])

tensor(78, device='cuda:0')

In [32]:
vary_last = torch.tensor([[30, 60, 90, i, j, 0] for i in range(1, 101) for j in range(1, 101)])
vary_last.shape
# print(vary_last.grad)

torch.Size([10000, 6])

In [33]:
attn_mat, neuron_acts_pre, neuron_acts, cache = get_activations(model, vary_last, device='cpu')



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [34]:
attn_mat.shape

torch.Size([10000, 4, 5])

In [35]:
h = 0
imshow(attn_mat[:, h, 4].reshape(100, 100))

# imshow((attn_mat[:, h, -2:].reshape(20, 20, 2)),
#             xaxis='Position', yaxis='Heads', color_continuous_scale='RdBu', color_continuous_midpoint=0.0,
#            title=f'Attention score for heads at all positions for every permutation',
#         #    animation_frame=0,
#            animation_name='head')


In [36]:
W_V.shape, W_E.shape, (W_V.cpu()[0] @ (W_E + W_pos[:, 0, None]).cpu()).shape

(torch.Size([4, 32, 128]), torch.Size([128, 100]), torch.Size([32, 100]))

In [37]:
imshow(W_V.cpu()[0] @ (W_E + W_pos[:, 0, None]).cpu())

In [29]:
lines((W_Q.cpu()[0] @ (W_E + W_pos[:, 0, None]).cpu())[:, 1:], xaxis='input number', yaxis='q dimension', file_name='max_q_lines.html')

In [30]:
lines((W_K.cpu()[0] @ (W_E + W_pos[:, 0, None]).cpu())[:, :], xaxis='input number', yaxis='k dimension', file_name='max_k_lines.html')

In [40]:
imshow((W_K.cpu()[0] @ (W_E + W_pos[:, 0, None]).cpu())[:, 1:])

In [41]:
W_E.shape

torch.Size([128, 100])

In [42]:
lines((W_K.cpu()[0] @ (W_E + W_pos[:, 0, None]).cpu())[:, 1:])

In [43]:
lines()

TypeError: lines() missing 1 required positional argument: 'lines_list'

In [44]:
W_Q.shape, W_V.shape, W_E.shape

(torch.Size([4, 32, 128]), torch.Size([4, 32, 128]), torch.Size([128, 100]))

In [45]:
imshow(init_dict['model']['blocks.0.attn.W_Q'][0])

In [46]:
imshow(W_Q[0])


In [19]:
h= 0
imshow(torch.stack([
    (
    (W_Q.cpu()[h] 
        @ (W_E + W_pos[:, 0, None]).cpu())[:, :].T 
    @ (W_K.cpu()[h] 
        @ (W_E + W_pos[:, 0, None]).cpu())[:, :]
    )[1:, 1:] 
    for h in range(4)]), animation_frame=0,
    xaxis="W_K @ W_E[Input]",
    yaxis="W_Q @ W_E[Input]",
    title="Attention score for query of every number with key of every number",
    file_name='max_q_times_k_values.html')


In [48]:
(W_Q[0] @ W_E.cpu()).shape

torch.Size([32, 100])

In [22]:
h=1
import numpy as np
import matplotlib.pyplot as plt
q_m = []
q_b = []
for i in range(32):
    x = np.arange(100)
    y = (W_Q.cpu()[h] @ W_E.cpu())[i, :].detach().numpy()

    m, b = np.polyfit(x, y, 1)
    q_m.append(m)
    q_b.append(b)
# print(m, b)
# plt.plot(x, y, 'yo', x, m*x+b, '--k')
# plt.show()
q_m, q_b


([0.005245351983796098,
  0.004684352363731019,
  0.005374979446615556,
  -0.007046514641539105,
  -0.005557156313839992,
  -0.00713371201524609,
  0.006880712861849341,
  -0.007235332412920174,
  0.007600406947828068,
  0.006087459715571817,
  0.00633484561576823,
  -0.005737040206214916,
  -0.006683256924621807,
  0.0052245109147992155,
  0.00563146407615628,
  -0.005540145690986837,
  0.005251996354631659,
  -0.007265018123280128,
  0.0080316242766269,
  -0.007176804897260005,
  0.0044668804599934,
  0.006444435307426068,
  -0.005649205552796316,
  0.005822286150684307,
  -0.005564831570460064,
  -0.005856828112884498,
  0.006236830958818875,
  0.004655058308477421,
  0.008753900731348751,
  0.006798400394003429,
  -0.007286078750574725,
  0.006177228497312775],
 [-0.3242576988654628,
  -0.28438163119209203,
  -0.3209326536497932,
  0.42161690617717873,
  0.33879218185827653,
  0.43384158307620757,
  -0.42020753786155707,
  0.43325140540358964,
  -0.45180365676689666,
  -0.361269406

In [23]:
import numpy as np
import matplotlib.pyplot as plt
k_m = []
k_b = []
for i in range(32):
    x = np.arange(100)
    y = (W_K.cpu()[h] @ W_E.cpu())[i, :].detach().numpy()

    m, b = np.polyfit(x, y, 1)
    k_m.append(m)
    k_b.append(b)
# print(m, b)
# plt.plot(x, y, 'yo', x, m*x+b, '--k')
# plt.show()
k_m, k_b


([-0.012436838962974296,
  -0.012928009095139711,
  -0.012069755128534225,
  0.011374593048697089,
  0.012974906779628414,
  0.012172606797973263,
  -0.012333837014755995,
  0.012567891390796358,
  -0.013041689986669863,
  -0.012529895600999442,
  -0.013889023164909491,
  0.013371105061406579,
  0.01139362374491704,
  -0.011568945774234589,
  -0.013668058965724652,
  0.013646985540373831,
  -0.012700307621446512,
  0.0133374358554243,
  -0.0144470701701302,
  0.013908503730707327,
  -0.010405020366143828,
  -0.013200220597902036,
  0.013545885171555149,
  -0.011961171854907148,
  0.01316051931947273,
  0.013651948504852946,
  -0.012364380270576745,
  -0.012196876210147421,
  -0.012926734443463355,
  -0.012520399914868344,
  0.011737883746816237,
  -0.012120573753468688],
 [0.7397982343597944,
  0.752646327138127,
  0.7123545527353992,
  -0.6659808033887332,
  -0.7675659866896596,
  -0.7282095131385001,
  0.7291017778991438,
  -0.7351234371601465,
  0.7632156175275397,
  0.7346300109510

In [24]:
q_m = torch.tensor(q_m)
q_b = torch.tensor(q_b)
k_b = torch.tensor(k_b)
k_m = torch.tensor(k_m)

In [25]:
Q = q_m.unsqueeze(1) @ torch.arange(100, dtype=torch.double).unsqueeze(0) + q_b.unsqueeze(1)
lines(Q, xaxis='input number', yaxis='linear regressed q values for each dimension', file_name='max_q_lin_reg.html')

In [26]:
K = k_m.unsqueeze(1) @ torch.arange(100, dtype=torch.double).unsqueeze(0) + k_b.unsqueeze(1)
lines(K, xaxis='input number', yaxis='linear regressed k values for each dimension', file_name='max_k_lin_reg.html')

In [27]:
imshow(Q.T @ K, xaxis='input number', yaxis='input number', title='Attention score according to the linear regressed q, k  values', file_name='max_lin_reg_attn_scores.html')

In [134]:
((q_m.sign() == 1).sum(), k_m.sign())

(tensor(19),
 tensor([-1., -1., -1.,  1.,  1.,  1., -1.,  1., -1., -1., -1.,  1.,  1., -1.,
         -1.,  1., -1.,  1., -1.,  1., -1., -1.,  1., -1.,  1.,  1., -1., -1.,
         -1., -1.,  1., -1.], dtype=torch.float64))

In [135]:
a = 0
b1 = 0
b2 = 0
c = 0
for i in range(32):
    a += q_m[i] * k_m[i]
    b1 += q_m[i] * k_b[i]
    b2 += k_m[i] * q_b[i]
    c += k_b[i] * q_b[i]
a, b1, b2, c
    

(tensor(-0.0025, dtype=torch.float64),
 tensor(0.1492, dtype=torch.float64),
 tensor(0.1519, dtype=torch.float64),
 tensor(-8.9321, dtype=torch.float64))

In [165]:
b2/a

tensor(-59.8491, dtype=torch.float64)

In [143]:
c/b2

tensor(-58.8098, dtype=torch.float64)

In [136]:
def f(x1, x2):
    return a * x1 * x2 + b1 * x1 + b2 * x2 + c

In [142]:
[(i, f(0, i)) for i in range(1, 100)]

[(1, tensor(-8.7802, dtype=torch.float64)),
 (2, tensor(-8.6283, dtype=torch.float64)),
 (3, tensor(-8.4765, dtype=torch.float64)),
 (4, tensor(-8.3246, dtype=torch.float64)),
 (5, tensor(-8.1727, dtype=torch.float64)),
 (6, tensor(-8.0208, dtype=torch.float64)),
 (7, tensor(-7.8689, dtype=torch.float64)),
 (8, tensor(-7.7171, dtype=torch.float64)),
 (9, tensor(-7.5652, dtype=torch.float64)),
 (10, tensor(-7.4133, dtype=torch.float64)),
 (11, tensor(-7.2614, dtype=torch.float64)),
 (12, tensor(-7.1095, dtype=torch.float64)),
 (13, tensor(-6.9576, dtype=torch.float64)),
 (14, tensor(-6.8058, dtype=torch.float64)),
 (15, tensor(-6.6539, dtype=torch.float64)),
 (16, tensor(-6.5020, dtype=torch.float64)),
 (17, tensor(-6.3501, dtype=torch.float64)),
 (18, tensor(-6.1982, dtype=torch.float64)),
 (19, tensor(-6.0464, dtype=torch.float64)),
 (20, tensor(-5.8945, dtype=torch.float64)),
 (21, tensor(-5.7426, dtype=torch.float64)),
 (22, tensor(-5.5907, dtype=torch.float64)),
 (23, tensor(-5.438

In [30]:
# How much can I corrupt V[smtg] with V[others...] till it no longer gives the correct reustl

pos = 0
totss = []
totss2 = []
for num in range(1, 100, 10):
    tots = []
    tots2 = []
    for k in np.linspace(0, 1, 10):
        tot = 0
        tot2 = 0
        for idx in range(1, 100):
            o = model.blocks[0].attn.W_O.cpu() @ torch.cat([(W_V.cpu()[h] @ ((1-k) * W_E[:, num].cpu() + (W_E[:, idx].cpu() + W_E[:, idx-1].cpu()) *(k/2) + W_pos[:, pos].cpu())) for h in range(4)])
            # print(o.shape)
            tot += ((num == (W_U.cpu() @ W_out.cpu() @ (torch.relu(W_in.cpu() @ o))).argmax(dim=-1).item()))
            tot2 += (abs(idx - (W_U.cpu() @ W_out.cpu() @ (torch.relu(W_in.cpu() @ o))).argmax(dim=-1).item()) < 2)
            # tot += ((num == (W_U.cpu() @ o).argmax(dim=-1).item()))
            # tot2 += (abs(idx - (W_U.cpu() @ o).argmax(dim=-1).item()) < 2)
        tots.append(tot)
        tots2.append(tot2)
    totss.append(tots)
    totss2.append(tots2)
# lines(tots)
# lines([tots, tots2])
lines(totss, x=np.linspace(0, 1, 10), xaxis="Ratio of corruption", yaxis="accuracy", file_name='max_robustness.html')
lines(totss2, x=np.linspace(0, 1, 10), xaxis="Ratio of corruption", yaxis="inaccuracy")

In [150]:
np.log(0.7)/np.log(2.71)

-0.3577666205215387

In [162]:
[np.exp(b1 * x)/(np.exp(b1 * x) + 4*np.exp(b1 * (x-16))) for x in range(100)]

[tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.7313, dtype=torch.float64),
 tensor(0.73

In [32]:
pos = 4
ls = []
for idx in range(1, 100):
    o = model.blocks[0].attn.W_O @ torch.cat([(W_V[h] @ (W_E[:, idx] + W_pos[:, pos])) for h in range(4)])
    # print(o.shape)
    print(idx, (W_U @ W_out @ (torch.relu(W_in @ o))).argmax(dim=-1).item())
    ls.append((W_U @ W_out @ (torch.relu(W_in @ o))).argmax(dim=-1).item())

lines([ls], xaxis="number", yaxis="final unembedded output corresponding to a number", file_name='max_ov_circuit.html')

1 25
2 25
3 18
4 21
5 21
6 18
7 74
8 23
9 25
10 21
11 23
12 23
13 40
14 21
15 23
16 18
17 18
18 26
19 40
20 21
21 21
22 22
23 23
24 42
25 25
26 26
27 27
28 28
29 29
30 30
31 31
32 32
33 33
34 34
35 35
36 36
37 37
38 38
39 39
40 40
41 41
42 42
43 43
44 44
45 45
46 46
47 47
48 48
49 49
50 50
51 51
52 52
53 53
54 54
55 55
56 56
57 57
58 58
59 59
60 60
61 61
62 62
63 63
64 64
65 65
66 66
67 67
68 68
69 69
70 70
71 71
72 72
73 73
74 74
75 75
76 76
77 77
78 78
79 79
80 80
81 81
82 82
83 83
84 84
85 85
86 86
87 87
88 88
89 89
90 90
91 91
92 92
93 93
94 94
95 95
96 96
97 97
98 98
99 99
