# induction head
> This nb contains some exploration for the induction head

- skip_showdoc: true
- skip_exec: true

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class InductionHead(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(InductionHead, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)  # Query, Key, Value projections
        self.out_proj = nn.Linear(embed_dim, embed_dim)  # Output projection
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        # x shape: (seq_len, batch_size, embed_dim)
        batch_size, seq_len, embed_dim = x.size()

        # Project input to Q, K, V
        qkv = self.qkv_proj(x)  # Shape: (seq_len, batch_size, 3 * embed_dim)
        qkv = qkv.reshape(seq_len, batch_size, self.num_heads, 3 * embed_dim // self.num_heads)
        qkv = qkv.permute(2, 1, 0, 3)  # (num_heads, batch_size, seq_len, 3 * embed_dim // num_heads)
        Q, K, V = qkv.chunk(3, dim=-1)  # Split into Q, K, V

        # Compute attention
        attn_output, attn_weights = self.attention(Q, K, V)

        # Concatenate heads
        attn_output = attn_output.permute(1, 0, 2).contiguous().reshape(batch_size, seq_len, embed_dim)

        # Output projection
        output = self.out_proj(attn_output)
        return output, attn_weights

In [None]:
class InductionHead(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(InductionHead, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)  # Query, Key, Value projections
        self.out_proj = nn.Linear(embed_dim, embed_dim)  # Output projection
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        # x shape: (seq_len, batch_size, embed_dim)
        seq_len, batch_size, embed_dim = x.size()

        # Project input to Q, K, V
        qkv = self.qkv_proj(x)  # Shape: (seq_len, batch_size, 3 * embed_dim)
        qkv = qkv.reshape(seq_len, batch_size, 3, self.num_heads, embed_dim // self.num_heads)
        qkv = qkv.permute(2, 3, 1, 0, 4)  # (3, num_heads, batch_size, seq_len, embed_dim_per_head)
        Q, K, V = qkv[0], qkv[1], qkv[2]  # Split into Q, K, V

        # Reshape Q, K, V for attention mechanism
        # Q = Q.reshape(batch_size * self.num_heads, seq_len, embed_dim // self.num_heads)
        # K = K.reshape(batch_size * self.num_heads, seq_len, embed_dim // self.num_heads)
        # V = V.reshape(batch_size * self.num_heads, seq_len, embed_dim // self.num_heads)

        Q = Q.reshape(batch_size * self.num_heads, seq_len, embed_dim )
        K = K.reshape(batch_size * self.num_heads, seq_len, embed_dim )
        V = V.reshape(batch_size * self.num_heads, seq_len, embed_dim )

        # Apply attention (query, key, value)
        attn_output, attn_weights = self.attention(Q, K, V)

        # Reshape attention output back
        attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len, embed_dim // self.num_heads)
        attn_output = attn_output.permute(2, 0, 1, 3).reshape(seq_len, batch_size, embed_dim)

        # Output projection
        output = self.out_proj(attn_output)
        return output, attn_weights

In [None]:
# Example usage:
embed_dim = 512
num_heads = 8
induction_head = InductionHead(embed_dim, num_heads)

In [None]:
# Dummy input (batch_size, seq_len, embed_dim)
x = torch.randn(10, 20, embed_dim)
output, attn_weights = induction_head(x)
print(output.shape)  # Expected: (batch_size, seq_len, embed_dim)

RuntimeError: shape '[160, 10, 512]' is invalid for input of size 102400

In [None]:
x.device

device(type='cpu')