In [2]:
import torch

In [3]:
embedding = torch.nn.Embedding(10, 3)
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
print(embedding(input))

tensor([[[ 0.3749, -0.5688, -0.4704],
         [ 0.4444, -0.7900, -0.8251],
         [ 1.4089, -0.7410, -0.8014],
         [ 0.7591, -0.3077, -1.0916]],

        [[ 1.4089, -0.7410, -0.8014],
         [ 0.7296, -0.4963, -0.9164],
         [ 0.4444, -0.7900, -0.8251],
         [ 0.4487, -1.5731,  0.6695]]], grad_fn=<EmbeddingBackward0>)


In [5]:
from transformers import Gemma2Model, Gemma2Config

config = Gemma2Config()

model = Gemma2Model.from_pretrained("google/gemma-2-2b-it")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
model

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((2304,),

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Gemma2RotaryEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder for rotary embedding initialization
        pass

    def forward(self, x):
        # Apply rotary embedding
        return x


class PytorchGELUTanh(nn.Module):
    def forward(self, x):
        return F.gelu(x) * torch.tanh(x)


class Gemma2RMSNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(features))

    def forward(self, x):
        norm = x.norm(2, dim=-1, keepdim=True)
        return x * self.weight / (norm + self.eps)


class Gemma2Attention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, 2048, bias=False)
        self.k_proj = nn.Linear(embed_dim, 1024, bias=False)
        self.v_proj = nn.Linear(embed_dim, 1024, bias=False)
        self.o_proj = nn.Linear(2048, embed_dim, bias=False)
        self.rotary_emb = Gemma2RotaryEmbedding()

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q, k = self.rotary_emb(q), self.rotary_emb(k)
        # Implement attention mechanism
        scores = torch.matmul(q, k.transpose(-2, -1)) / q.size(-1)**0.5
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, v)
        return self.o_proj(context)


class Gemma2MLP(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.gate_proj = nn.Linear(embed_dim, 9216, bias=False)
        self.up_proj = nn.Linear(embed_dim, 9216, bias=False)
        self.down_proj = nn.Linear(9216, embed_dim, bias=False)
        self.act_fn = PytorchGELUTanh()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x) * self.up_proj(x)))


class Gemma2DecoderLayer(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.self_attn = Gemma2Attention(embed_dim)
        self.mlp = Gemma2MLP(embed_dim)
        self.input_layernorm = Gemma2RMSNorm(embed_dim)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(embed_dim)
        self.post_feedforward_layernorm = Gemma2RMSNorm(embed_dim)
        self.post_attention_layernorm = Gemma2RMSNorm(embed_dim)

    def forward(self, x):
        x = x + self.self_attn(self.input_layernorm(x))
        x = x + self.mlp(self.pre_feedforward_layernorm(x))
        return self.post_feedforward_layernorm(x)


class Gemma2Model(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.layers = nn.ModuleList([Gemma2DecoderLayer(embed_dim) for _ in range(num_layers)])
        self.norm = Gemma2RMSNorm(embed_dim)

    def forward(self, input_ids):
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


# Example instantiation
model = Gemma2Model(vocab_size=256, embed_dim=8, num_layers=4)
print(model)

Gemma2Model(
  (embed_tokens): Embedding(256, 8, padding_idx=0)
  (layers): ModuleList(
    (0-3): 4 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=8, out_features=2048, bias=False)
        (k_proj): Linear(in_features=8, out_features=1024, bias=False)
        (v_proj): Linear(in_features=8, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=8, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=8, out_features=9216, bias=False)
        (up_proj): Linear(in_features=8, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=8, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm()
      (pre_feedforward_layernorm): Gemma2RMSNorm()
      (post_feedforward_layernorm): Gemma2RMSNorm()
      (post_attention_layernorm): Gemma2RMSNorm()
    )
  )
  (norm): G

In [10]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Gemma2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        
    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x

class Gemma2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_position_embeddings = max_position_embeddings
        
    def forward(self, x, seq_len):
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos, sin

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class Gemma2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, (self.num_heads//2) * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, (self.num_heads//2) * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        
        self.rotary_emb = Gemma2RotaryEmbedding(self.head_dim)
        
    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length = hidden_states.shape[:2]
        
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        
        q = q.view(batch_size, seq_length, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_length, self.num_heads//2, self.head_dim)
        v = v.view(batch_size, seq_length, self.num_heads//2, self.head_dim)
        
        cos, sin = self.rotary_emb(v, seq_length)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        
        k = k.repeat(1, 1, 2, 1)  # Repeat to match number of heads
        v = v.repeat(1, 1, 2, 1)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        if attention_mask is not None:
            scores = scores + attention_mask
            
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        
        attn_output = attn_output.reshape(batch_size, seq_length, -1)
        attn_output = self.o_proj(attn_output)
        
        return attn_output

class Gemma2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.GELU()
        
    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

class Gemma2DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = Gemma2Attention(config)
        self.mlp = Gemma2MLP(config)
        
        self.input_layernorm = Gemma2RMSNorm(config.hidden_size)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size)
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size)
        
    def forward(self, hidden_states, attention_mask=None):
        # Self attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        
        # MLP
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        
        return hidden_states

class Gemma2Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.layers = nn.ModuleList([Gemma2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = Gemma2RMSNorm(config.hidden_size)
        
    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.embed_tokens(input_ids)
        
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
            
        hidden_states = self.norm(hidden_states)
        return hidden_states

config = Gemma2Config(
    vocab_size=256,
    hidden_size=8,
    num_hidden_layers=4,
    num_attention_heads=2,
    head_dim=4,
    intermediate_size=16
)

model = Gemma2Model(config)
print(model)

Gemma2Model(
  (embed_tokens): Embedding(256, 8, padding_idx=0)
  (layers): ModuleList(
    (0-3): 4 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=8, out_features=8, bias=False)
        (k_proj): Linear(in_features=8, out_features=4, bias=False)
        (v_proj): Linear(in_features=8, out_features=4, bias=False)
        (o_proj): Linear(in_features=8, out_features=8, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=8, out_features=16, bias=False)
        (up_proj): Linear(in_features=8, out_features=16, bias=False)
        (down_proj): Linear(in_features=16, out_features=8, bias=False)
        (act_fn): GELU(approximate='none')
      )
      (input_layernorm): Gemma2RMSNorm()
      (pre_feedforward_layernorm): Gemma2RMSNorm()
      (post_feedforward_layernorm): Gemma2RMSNorm()
      (post_attention_layernorm): Gemma2RMSNorm()
    )
  )
  (norm): Gemma2RMSNor

In [12]:
embed_tokens = nn.Embedding(256, 8, padding_idx=0)
embed_tokens

Embedding(256, 8, padding_idx=0)

In [13]:
example_input = [1, 5, 3]
embed_tokens(torch.tensor(example_input))

tensor([[-1.7893, -1.2624,  1.7711,  0.6604,  1.3104, -1.6788, -0.0572, -0.6773],
        [-0.3582, -0.4415, -0.0198, -0.9026,  0.8654, -2.0028, -0.7236, -0.2078],
        [-0.1072,  1.5299,  0.3908, -0.2816, -0.1998, -0.7492, -1.3824,  1.2076]],
       grad_fn=<EmbeddingBackward0>)

In [75]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F


def precompute_rotary_embedding(dim: int, seq_len: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_embedding(x, freqs):
    """
    Apply rotary embedding to the input tensor.
    x: Tensor of shape [batch_size, seq_len, num_heads, head_dim].
    freqs: Tensor of shape [seq_len, head_dim // 2].
    """
    batch_size, seq_len, num_heads, head_dim = x.shape

    # Ensure head_dim is divisible by 2
    assert head_dim % 2 == 0, "head_dim must be divisible by 2 for rotary embeddings"

    # Adjust freqs to match the current sequence length
    freqs = freqs[:seq_len].to(x.device)  # Slice freqs to match seq_len

    # Split head_dim into real and imaginary parts
    x = x.view(batch_size, seq_len, num_heads, head_dim // 2, 2)
    x_complex = torch.view_as_complex(x)  # Convert to complex numbers

    # Expand freqs to match [seq_len, num_heads, head_dim // 2]
    freqs = freqs.unsqueeze(1).repeat(1, num_heads, 1)

    # Apply rotary embedding
    x_rotated = torch.view_as_real(x_complex * freqs)

    # Reshape back to original dimensions
    return x_rotated.view(batch_size, seq_len, num_heads, head_dim)


class Gemma2RMSNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(features))

    def forward(self, x):
        norm = x.norm(2, dim=-1, keepdim=True)
        return x * self.weight / (norm + self.eps)


class Gemma2RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, seq_len: int, theta: float = 10000.0):
        super().__init__()
        self.freqs = precompute_rotary_embedding(dim, seq_len, theta)

    def forward(self, x):
        return apply_rotary_embedding(x, self.freqs[:x.size(1)].to(x.device))


class PytorchGELUTanh(nn.Module):
    def forward(self, x):
        return F.gelu(x) * torch.tanh(x)


class Gemma2Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim, seq_len):
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=False)
        self.rotary_emb = Gemma2RotaryEmbedding(head_dim, seq_len)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()
        head_dim = self.q_proj.out_features // self.num_heads

        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)

        # Apply rotary embedding
        freqs = self.rotary_emb.freqs[:seq_len].to(x.device)
        q = apply_rotary_embedding(q, freqs)
        k = apply_rotary_embedding(k, freqs)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)

        # Apply the causal mask
        if mask is not None:
            mask = mask.expand(batch_size, self.num_heads, seq_len, seq_len)  # Broadcast to match scores
            scores = scores.masked_fill(mask, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        return self.o_proj(attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1))


class Gemma2MLP(nn.Module):
    def __init__(self, embed_dim, intermediate_dim):
        super().__init__()
        self.gate_proj = nn.Linear(embed_dim, intermediate_dim, bias=False)
        self.up_proj = nn.Linear(embed_dim, intermediate_dim, bias=False)
        self.down_proj = nn.Linear(intermediate_dim, embed_dim, bias=False)
        self.act_fn = PytorchGELUTanh()

    def forward(self, x):
        gated_output = self.gate_proj(x)
        up_output = self.up_proj(x)
        fused = gated_output * up_output
        activated = self.act_fn(fused)
        return self.down_proj(activated)


class Gemma2DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim, intermediate_dim, seq_len):
        super().__init__()
        self.self_attn = Gemma2Attention(embed_dim, num_heads, head_dim, seq_len)
        self.mlp = Gemma2MLP(embed_dim, intermediate_dim)
        self.input_layernorm = Gemma2RMSNorm(embed_dim)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(embed_dim)
        self.post_feedforward_layernorm = Gemma2RMSNorm(embed_dim)
        self.post_attention_layernorm = Gemma2RMSNorm(embed_dim)

    def forward(self, x, mask=None):
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(x, mask) + residual

        residual = x
        x = self.pre_feedforward_layernorm(x)
        x = self.mlp(x) + residual
        x = self.post_feedforward_layernorm(x)

        return x


class Gemma2Model(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, head_dim, intermediate_dim, seq_len):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.layers = nn.ModuleList([
            Gemma2DecoderLayer(embed_dim, num_heads, head_dim, intermediate_dim, seq_len) for _ in range(num_layers)
        ])
        self.norm = Gemma2RMSNorm(embed_dim)

    def forward(self, input_ids):
        x = self.embed_tokens(input_ids)

        # Dynamically generate a causal mask based on the input sequence length
        seq_len = input_ids.size(1)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(1)  # Shape: [1, 1, seq_len, seq_len]

        for layer in self.layers:
            x = layer(x, causal_mask)
        return self.norm(x)


    def save_weights(self, directory):
        """Save weights for each component to a directory."""
        os.makedirs(directory, exist_ok=True)
        torch.save(self.embed_tokens.state_dict(), os.path.join(directory, "embed_tokens.pth"))
        torch.save(self.norm.state_dict(), os.path.join(directory, "norm.pth"))
        for i, layer in enumerate(self.layers):
            torch.save(layer.state_dict(), os.path.join(directory, f"layer_{i}.pth"))

    def load_weights(self, directory):
        """Load weights for each component from a directory."""
        self.embed_tokens.load_state_dict(torch.load(os.path.join(directory, "embed_tokens.pth")))
        self.norm.load_state_dict(torch.load(os.path.join(directory, "norm.pth")))
        for i, layer in enumerate(self.layers):
            layer.load_state_dict(torch.load(os.path.join(directory, f"layer_{i}.pth")))


In [None]:
# Example instantiation
vocab_size = 256000
embed_dim = 2304
num_layers = 26
num_heads = 8
head_dim = 256
intermediate_dim = 9216
seq_len = 1024

model = Gemma2Model(vocab_size, embed_dim, num_layers, num_heads, head_dim, intermediate_dim, seq_len)
print(model)

In [78]:
# Example parameters for a smaller model
vocab_size = 25600
embed_dim = 256
num_layers = 4
num_heads = 2
head_dim = 128
intermediate_dim = 512
seq_len = 16

# Instantiate smaller model
model = Gemma2Model(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    head_dim=head_dim,
    intermediate_dim=intermediate_dim,
    seq_len=seq_len
)

# Example input
example_input = [1, 5, 12]
input_tensor = torch.tensor([example_input])  # Shape: [1, seq_len]

# Forward pass
logits = model(input_tensor)
print("Output shape:", logits.shape)  # Should match [1, seq_len, embed_dim]

# Project to vocabulary logits
vocab_logits = logits @ model.embed_tokens.weight.T  # Shape: [1, seq_len, vocab_size]

# Extract logits for the last token
last_token_logits = vocab_logits[:, -1, :]  # Shape: [1, vocab_size]

# Apply softmax to get probabilities
probs = torch.softmax(last_token_logits, dim=-1)

# Get the token with the highest probability (greedy decoding)
predicted_token = torch.argmax(probs, dim=-1).item()

print(f"Predicted next token: {predicted_token}")

Output shape: torch.Size([1, 3, 256])
Predicted next token: 12
