In [1]:
import plotly.express as px
import torch

from pathlib import Path

from torch import Tensor

from models import ToyTransformer
from train import NUM_SYMBOLS, ARRAY_LEN, EMBEDDING_DIM, PROJECTION_DIM, NUM_STEPS

In [2]:
checkpoint_path = Path("artifacts/dense_checkpointing/")

In [3]:
# initialize model
model = ToyTransformer(
    num_symbols=NUM_SYMBOLS,
    seq_len=(ARRAY_LEN + 1),
    embedding_dim=EMBEDDING_DIM,
    projection_dim=PROJECTION_DIM,
)

In [4]:
def get_queries(model: ToyTransformer) -> Tensor:
    with torch.no_grad():
        symbols = torch.arange(NUM_SYMBOLS)
        symbol_embeddings = model.symbol_embed(symbols)
        queries = model.attention_head.query_projection(symbol_embeddings)
        return queries

In [5]:
def get_keys(model: ToyTransformer) -> Tensor:
    with torch.no_grad():
        positions = torch.arange(ARRAY_LEN)
        position_embeddings = model.position_embed(positions)
        keys = model.attention_head.key_projection(position_embeddings)
        return keys

In [6]:
def get_qk_circuit(model: ToyTransformer) -> Tensor:
    with torch.no_grad():
        symbols = torch.arange(NUM_SYMBOLS)
        symbol_embeddings = model.symbol_embed(symbols)
        queries = model.attention_head.query_projection(symbol_embeddings)
        positions = torch.arange(ARRAY_LEN)
        position_embeddings = model.position_embed(positions)
        keys = model.attention_head.key_projection(position_embeddings)
        return torch.einsum("s d, p d -> s p", queries, keys)

In [7]:
checkpoint_fns: list[Path] = list(filter(lambda fn: fn.stem.isnumeric(), checkpoint_path.glob("*.pt")))
checkpoint_fns.sort(key=lambda fn: int(fn.stem))

In [22]:
qk_circuits = torch.zeros(size=(NUM_STEPS, NUM_SYMBOLS, ARRAY_LEN))
out_circuits = torch.zeros(size=(NUM_STEPS, NUM_SYMBOLS, NUM_SYMBOLS))

for checkpoint_fn in checkpoint_fns:
    step = int(checkpoint_fn.stem)
    if step >= NUM_STEPS:
        break

    checkpoint = torch.load(checkpoint_fn)
    model.load_state_dict(checkpoint)

    with torch.no_grad():
        symbols = torch.arange(NUM_SYMBOLS)
        symbol_embeddings = model.symbol_embed(symbols)
        positions = torch.arange(ARRAY_LEN)
        position_embeddings = model.position_embed(positions)

        out_circuit = model.unembed(model.attention_head.output_projection(symbol_embeddings))
        out_circuits[step] = out_circuit

        queries = model.attention_head.query_projection(symbol_embeddings)
        keys = model.attention_head.key_projection(position_embeddings)
        qk_circuit = torch.einsum("s d, p d -> s p", queries, keys)
        qk_circuits[step] = qk_circuit
        


In [23]:
torch.save(out_circuits, "artifacts/dense_checkpointing/out_circuits.pt")
torch.save(qk_circuits, "artifacts/dense_checkpointing/qk_circuits.pt")

In [33]:
px.imshow(out_circuits[:100:10],
          animation_frame=0,
          title="OV circuit",
          color_continuous_scale="RdBu",
          range_color=(-2, 2)
          )

In [36]:
px.imshow(qk_circuits[100:200:10],
          animation_frame=0,
          title="QK circuit",
          xaxis_title="Position",
          color_continuous_scale="RdBu",
          range_color=(-60, 60)
          )

TypeError: imshow() got an unexpected keyword argument 'xaxis_title'