In [None]:
from einops import rearrange
from torch import nn
from torch.nn import functional as F
import torch


class AttentionProbe(nn.Module):
    """Attention-pooling probe with one learned query vector per head.

    Every head scores each position of the input sequence against its own
    learned query, softmaxes those scores over the sequence, and takes the
    resulting weighted average of the inputs. The per-head averages are
    concatenated and mapped to ``output_dim`` values by a linear layer.

    Args:
        input_dim: Size of the last (feature) dimension of the inputs.
        output_dim: Number of values produced per input sequence.
        num_heads: Number of learned queries, i.e. independent pooling heads.
    """

    def __init__(
        self, input_dim: int, output_dim: int = 1, num_heads: int = 1,
    ):
        super().__init__()

        # Shape (1, num_heads, 1, input_dim): one query position per head.
        # The leading axis is expanded over the batch in ``forward``.
        self.query = nn.Parameter(
            torch.randn(1, num_heads, 1, input_dim),
        )
        self.output = nn.Linear(input_dim * num_heads, output_dim)

    def forward(self, x, mask=None):
        """Pool each sequence with the learned queries and project the result.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, input_dim)``. The head
                axis is added here; callers must not supply it.
            mask: Optional boolean (or 0/1) tensor of shape
                ``(batch_size, seq_len)``. Truthy entries mark positions that
                may be attended to; falsy entries (e.g. padding) are ignored.
                Every row should keep at least one position: a fully masked
                row has nothing to attend to and its output is not meaningful
                (NaN on some PyTorch versions).

        Returns:
            Tensor of shape ``(batch_size, output_dim)``.
        """
        num_heads = self.query.shape[1]

        # Insert a head axis and broadcast it, so every head attends over the
        # same keys/values; ``expand`` creates views and copies no data.
        x = x[:, None].expand(-1, num_heads, -1, -1)
        q = self.query.expand(x.shape[0], -1, -1, -1)

        attn_mask = None
        if mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): broadcasts over the head
            # axis and the single query position.
            attn_mask = mask.to(torch.bool)[:, None, None, :]

        # Scores are scaled by 1 / sqrt(input_dim), the function's default.
        out = F.scaled_dot_product_attention(q, x, x, attn_mask=attn_mask)

        # (batch, heads, 1, dim) -> (batch, heads * dim)
        return self.output(out.flatten(start_dim=1))


# Smoke test on random data. The last dimension of `x` must equal the probe's
# `input_dim` (2304 here), because the learned queries are dotted with `x`
# along that axis; a different width makes the attention call fail.
probe = AttentionProbe(2304, 1, 8)
x = torch.randn(2, 10, 2304)  # (batch_size, seq_len, input_dim)
out = probe(x)
assert out.shape == (2, 1)  # one output value per sequence in the batch

In [16]:
import numpy as np

# Directory holding the per-sample `sample_{i}.npz` archives.
ACTIVATION_DIR = (
    "/mnt/ssd-1/nora/MOSAIC/output/activations/google_gemma-2-2b/"
    "Anthropic_election_questions/test/layer_5/16k"
)
# Number of archives to read (indices 0 .. NUM_SAMPLES - 1).
NUM_SAMPLES = 743

samples = []
for i in range(NUM_SAMPLES):
    # `np.load` on an .npz returns a lazy archive that keeps its file handle
    # open; the context manager closes it once the array has been read, so
    # the loop does not accumulate hundreds of open files.
    with np.load(f"{ACTIVATION_DIR}/sample_{i}.npz") as npz:
        samples.append(npz["hidden_state"])


In [15]:
# Show the shape of every loaded array, in load order.
[arr.shape for arr in samples]

[(16, 2304),
 (16, 2304),
 (16, 2304),
 (16, 2304),
 (16, 2304),
 (16, 2304),
 (16, 2304),
 (16, 2304),
 (16, 2304),
 (16, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (24, 2304),
 (24, 2304),
 (24, 2304),
 (24, 2304),
 (24, 2304),
 (24, 2304),
 (24, 2304),
 (24, 2304),
 (24, 2304),
 (24, 2304),
 (21, 2304),
 (21, 2304),
 (21, 2304),
 (21, 2304),
 (21, 2304),
 (21, 2304),
 (21, 2304),
 (21, 2304),
 (21, 2304),
 (21, 2304),
 (18, 2304),
 (18, 2304),
 (18, 2304),
 (18, 2304),
 (18, 2304),
 (18, 2304),
 (18, 2304),
 (18, 2304),
 (18, 2304),
 (18, 2304),
 (23, 2304),
 (23, 2304),
 (23, 2304),
 (23, 2304),
 (23, 2304),
 (23, 2304),
 (23, 2304),
 (23, 2304),
 (23, 2304),
 (23, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (22, 2304),
 (19, 2304),
 (19, 2304),
 (19, 2304),
 (19, 2304),
 (19, 2304),
 (19, 2304),
 (19, 2304),