<a href="https://colab.research.google.com/github/melwina/arena_exercises/blob/main/1_2_bonus_toy_trained.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### I trained a 2-layer attention-only transformer with positional encodings. This is based on the suggested extensions from ARENA exercise 1.2.

### We can see some potential Q-composition; see the attention visualisation.

In [None]:
# !pip install transformer_lens
# !pip install circuitsvis
import torch as t
from transformer_lens import (
    ActivationCache,
    FactoredMatrix,
    HookedTransformer,
    HookedTransformerConfig,
    utils,
)



In [None]:
# Location of the saved weights; this path is what t.load reads further down.
model_path = "/content/induction_head_model (3).pt"

In [None]:
import circuitsvis as cv
from IPython.display import display

# Use the GPU when PyTorch reports one as available, otherwise stay on the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
print(device)
def get_model_config():
    """Return the HookedTransformerConfig for the 2-layer attention-only model.

    The hyperparameters are fixed: 2 layers, 12 heads of width 32, a 384-wide
    residual stream, a 512-token context, a 50278-entry vocabulary, causal
    attention and the "gpt2" tokenizer. The device is "cuda" when PyTorch
    reports CUDA as available and "cpu" otherwise.
    """
    target_device = t.device("cuda" if t.cuda.is_available() else "cpu")
    hyperparams = dict(
        d_model=384,
        d_head=32,
        n_heads=12,
        n_layers=2,
        n_ctx=512,
        d_vocab=50278,
        attention_dir="causal",
        attn_only=True,
        tokenizer_name="gpt2",
        seed=42,
        device=target_device,
    )
    return HookedTransformerConfig(**hyperparams)

# Instantiate the architecture described by get_model_config(), then fill it
# with the weights stored on disk.
model = HookedTransformer(get_model_config())
model_path = "/content/induction_head_model (3).pt"
# NOTE(review): t.load unpickles the file, so only load checkpoints you trust.
state_dict = t.load(model_path, map_location="cpu")
# Drop the "pos_encoder.pe" entry before loading; presumably HookedTransformer
# has no parameter by that name and would reject it - confirm.
if "pos_encoder.pe" in state_dict:
    del state_dict["pos_encoder.pe"]
model.load_state_dict(state_dict)

# Switch to evaluation mode for the analysis that follows.
model.eval()

def generate_repeated_tokens(model, seq_len=50, batch_size=1):
    """Build a batch of sequences shaped like (bos, *rand, *rand).

    Each row starts with the tokenizer's BOS id, followed by ``seq_len // 2``
    random token ids and then a repeat of those same ids, cut off so that the
    row is exactly ``seq_len`` tokens long. The repeat of position ``i``
    therefore sits at position ``i + seq_len // 2``.

    Args:
        model: object exposing ``cfg.d_vocab``, ``cfg.device`` and
            ``tokenizer.bos_token_id``.
        seq_len: total number of tokens per row (BOS included); must be >= 1.
        batch_size: number of rows to generate.

    Returns:
        Integer tensor of shape ``(batch_size, seq_len)`` on ``model.cfg.device``.

    Raises:
        ValueError: if ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be at least 1, got {seq_len}")

    # BOS plus seq_len // 2 random tokens. Drawing one slot more than
    # seq_len // 2 guarantees that, after appending the repeat, there are at
    # least seq_len tokens to truncate from - the previous ceil-based length
    # came up one token short for even seq_len.
    n_random = seq_len // 2
    first_pass = t.randint(
        0, model.cfg.d_vocab, (batch_size, n_random + 1), device=model.cfg.device
    )
    first_pass[:, 0] = model.tokenizer.bos_token_id

    # Repeat everything except the BOS token, then trim to the requested length.
    second_pass = first_pass[:, 1:]
    return t.cat([first_pass, second_pass], dim=1)[:, :seq_len]

# Run the model once on a repeated random sequence and keep the cached
# activations so the attention patterns can be inspected.
tokens = generate_repeated_tokens(model, seq_len=100)
logits, cache = model.run_with_cache(tokens)

print(type(cache))
attention_pattern = cache["pattern", 0]
print(attention_pattern.shape)

# Label every position so that a token and its repeat carry the same number.
# The sequence is (bos, *rand, *rand) and the repeat of position i sits at
# i + seq_len // 2, so positions 0..half (inclusive) are the first pass and
# everything after that is the repeat. Using `i < half` here would label the
# last first-pass token "1", clashing with BOS instead of matching its repeat.
seq_len = tokens.shape[1]
half = seq_len // 2
str_tokens = [str(i + 1) if i <= half else str(i - half + 1) for i in range(seq_len)]

# Show every layer-0 head for the first (only) batch element.
print("Layer 0 Head Attention Patterns:")
display(
    cv.attention.attention_patterns(
        tokens=str_tokens,
        attention=attention_pattern[0]
    )
)

# Same visualisation for layer 1.
attention_pattern_layer1 = cache["pattern", 1]
print("\nLayer 1 Head Attention Patterns:")
display(
    cv.attention.attention_patterns(
        tokens=str_tokens,
        attention=attention_pattern_layer1[0]
    )
)

def max_attention_to_token(cache, comparison_token, comparison_token_refs_diagonal=True, percentage=0.4):
    """List the heads whose strongest attention usually lands on a given key.

    For every layer and head (counts taken from the module-level ``model``),
    each query row of the first batch element's attention pattern is reduced
    to its argmax key position, and the rows whose argmax equals the target
    position are counted.

    Args:
        cache: activation cache indexed as ``cache["pattern", layer]``.
        comparison_token: when ``comparison_token_refs_diagonal`` is true, an
            offset added to the query position (0 = same position,
            -1 = previous position); otherwise an absolute key position.
        comparison_token_refs_diagonal: selects between the two meanings of
            ``comparison_token`` described above.
        percentage: a head is reported when
            ``abs(count - seq_len) < seq_len * (1 - percentage)``, i.e. when
            the count exceeds roughly ``percentage`` of the sequence length.

    Returns:
        List of ``"layer.head"`` strings for the heads that pass the threshold.
    """
    matching_heads = []

    for layer in range(model.cfg.n_layers):
        layer_pattern = cache["pattern", layer][0]  # first batch element
        n_positions = layer_pattern.shape[-1]

        for head in range(model.cfg.n_heads):
            head_pattern = layer_pattern[head]
            hits = 0

            for query in range(n_positions):
                if comparison_token_refs_diagonal:
                    target = query + comparison_token
                    # Rows whose target falls outside the sequence are skipped.
                    if not (0 <= target < n_positions):
                        continue
                else:
                    target = comparison_token

                if t.argmax(head_pattern[query]).item() == target:
                    hits += 1

            if abs(hits - n_positions) < n_positions * (1 - percentage):
                matching_heads.append(f"{layer}.{head}")

    return matching_heads



def current_attn_detector(cache: ActivationCache) -> list[str]:
    """
    Find current-token heads: heads whose strongest attention is usually on the query's own position.

    Returns "layer.head" strings, e.g. ["0.2", "1.4", "1.9"], as selected by
    max_attention_to_token with a diagonal offset of 0.
    """
    diagonal_offset = 0
    return max_attention_to_token(cache, diagonal_offset)


def prev_attn_detector(cache: ActivationCache) -> list[str]:
    """
    Find prev-token heads: heads whose strongest attention is usually on the position just before the query.

    Returns "layer.head" strings, e.g. ["0.2", "1.4", "1.9"], as selected by
    max_attention_to_token with a diagonal offset of -1.
    """
    diagonal_offset = -1
    return max_attention_to_token(cache, diagonal_offset)


def first_attn_detector(cache: ActivationCache) -> list[str]:
    """
    Find first-token heads: heads whose strongest attention is usually on position 0.

    Returns "layer.head" strings, e.g. ["0.2", "1.4", "1.9"], as selected by
    max_attention_to_token with an absolute (non-diagonal) target of 0.
    """
    first_position = 0
    return max_attention_to_token(
        cache,
        first_position,
        comparison_token_refs_diagonal=False,
    )

def induction_attn_detector(cache: ActivationCache) -> list[str]:
    """
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be induction heads

    Remember - the tokens used to generate rep_cache are (bos_token, *rand_tokens, *rand_tokens)

    With that layout the repeat of position i sits at i + seq_len // 2, so an
    induction head at a repeated position k should put its strongest attention
    on the token that followed the first occurrence, i.e. on position
    k - seq_len // 2 + 1. A head is reported when that holds for more than
    `percentage` (0.4) of the repeated positions.

    Layer and head counts are read from the module-level `model`; only the
    first batch element of each cached pattern is inspected.
    """
    induction_heads = []
    percentage = 0.4

    for layer in range(model.cfg.n_layers):
        attention = cache["pattern", layer][0]  # first batch element
        seqn_len = attention.shape[-1]

        # Integer offset between a token and its repeat. The earlier float
        # expression `seqn_len / 2` produced x.5 targets for odd lengths, which
        # can never equal an integer argmax, so no head was ever detected.
        repeat_offset = seqn_len // 2
        # Query positions that hold a repeated token.
        repeated_positions = range(repeat_offset + 1, seqn_len)

        for head in range(model.cfg.n_heads):
            head_pattern = attention[head]

            hits = sum(
                t.argmax(head_pattern[k]).item() == k - repeat_offset + 1
                for k in repeated_positions
            )
            if hits > percentage * len(repeated_positions):
                induction_heads.append(f"{layer}.{head}")

    return induction_heads

# Classify the heads from the cached attention patterns and report both lists.
prev_heads = prev_attn_detector(cache)
induction_heads = induction_attn_detector(cache)

print(
    f"\nHeads attending to previous token: {', '.join(prev_heads)}",
    f"Induction heads: {', '.join(induction_heads)}",
    sep="\n",
)

cpu
<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([1, 12, 99, 99])
Layer 0 Head Attention Patterns:



Layer 1 Head Attention Patterns:



Heads attending to previous token: 
Induction heads: 
