In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import math

import numpy as np
from huggingface_hub import hf_hub_download
from datasets import load_dataset

In [59]:
class RoPE(nn.Module):
    """Rotary position embedding applied to interleaved feature pairs.

    Features ``(2i, 2i + 1)`` of the last dimension are rotated by the angle
    ``pos * base ** (-2i / head_dim)``. Cos/sin tables for positions
    ``0 .. max_tokens - 1`` are precomputed once and registered as
    non-persistent buffers (excluded from ``state_dict``).

    Args:
        head_dim: Size of the rotated (last) dimension; must be even.
        max_tokens: Number of positions to precompute.
        base: Frequency base for the rotation angles.
    """

    def __init__(self, head_dim, max_tokens=4096, base=10000):
        super(RoPE, self).__init__()

        assert head_dim % 2 == 0, "head_dim must be even"

        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_tokens, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)  # (max_tokens, head_dim // 2) angles
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

        # Duplicate each angle so entries 2i and 2i+1 share pair i's cos/sin.
        cos_cached = freqs_cis.real.repeat_interleave(2, dim=1)
        sin_cached = freqs_cis.imag.repeat_interleave(2, dim=1)

        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)

    def _rotate_half(self, x):
        """Map every feature pair ``(a, b)`` to ``(-b, a)`` (a 90-degree turn)."""
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    def forward(self, x, seq_len, offset=0):
        """Rotate ``x`` as if it holds positions ``offset .. offset + seq_len - 1``.

        Args:
            x: Tensor whose last dim is ``head_dim`` and whose second-to-last
                dim is the sequence; the tables are broadcast against it with
                shape ``(1, 1, seq_len, head_dim)``.
            seq_len: Number of positions to take from the tables.
            offset: Absolute position of the first element.

        Returns:
            The rotated tensor, in the same dtype as ``x``.

        Raises:
            ValueError: if the requested positions are outside the tables.
        """
        end = offset + seq_len
        num_positions = self.cos_cached.shape[0]
        if offset < 0 or end > num_positions:
            # A short slice would otherwise broadcast silently (or fail with
            # an opaque shape error) instead of reporting the overflow.
            raise ValueError(
                f"RoPE positions [{offset}, {end}) exceed the {num_positions} precomputed positions."
            )

        cos = self.cos_cached[offset:end].unsqueeze(0).unsqueeze(1)
        sin = self.sin_cached[offset:end].unsqueeze(0).unsqueeze(1)
        rotated = (x * cos) + (self._rotate_half(x) * sin)
        # The tables are float32; cast back so low-precision inputs are not
        # silently upcast (attention kernels need q/k/v dtypes to agree).
        return rotated.to(x.dtype)


class MLAttention(nn.Module):
    """Multi-head attention whose keys and values come from a shared low-rank latent.

    ``Wdkv`` projects each token to a ``ukv_dim``-wide latent; ``Wuk`` and
    ``Wuv`` reconstruct full-width keys and values from it. For incremental
    decoding only the latent is cached (``ukv_dim`` values per token).

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        ukv_dim: Width of the key/value latent.
        max_tokens: Capacity of the latent cache and of the RoPE tables.
        causal: If True (default), each query only attends to keys at its own
            or earlier absolute positions, with or without the cache.
    """

    def __init__(self, embed_dim, num_heads, ukv_dim, max_tokens=4096, causal=True):
        super(MLAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ukv_dim = ukv_dim
        self.max_tokens = max_tokens
        self.causal = causal

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads

        self.Wdkv = nn.Linear(embed_dim, ukv_dim, bias=False)
        self.Wuk = nn.Linear(ukv_dim, embed_dim, bias=False)
        self.Wuv = nn.Linear(ukv_dim, embed_dim, bias=False)
        self.Wq = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wo = nn.Linear(embed_dim, embed_dim, bias=False)

        self.rope = RoPE(self.head_dim, max_tokens)

        # Lazily allocated on the first cached forward call.
        self.kv_latent_cache = None  # (batch, max_tokens, ukv_dim)
        self.cache_pos = None  # (batch,) number of cached tokens per row

        # Not read anywhere in this class; kept for callers that inspect it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _init_cache(self, batch_size, device, dtype):
        """Allocate a zeroed latent cache and reset every row's position to 0."""
        self.kv_latent_cache = torch.zeros(
            (batch_size, self.max_tokens, self.ukv_dim),
            device=device,
            dtype=dtype
        )
        self.cache_pos = torch.zeros(batch_size, dtype=torch.long, device=device)

    def forward(self, x, use_cache=True):
        """Attend from ``x`` to itself and, with ``use_cache``, to cached history.

        Args:
            x: Input of shape ``(batch, seq_len, embed_dim)``.
            use_cache: If True, append this chunk's latents to the cache and
                attend over every cached position; positions continue from
                where the previous cached call stopped. If False, attend only
                within ``x`` (positions start at 0) and leave the cache as is.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.

        Raises:
            ValueError: if the cache (or, without the cache, ``seq_len``)
                would exceed ``max_tokens``.
        """
        batch_size, seq_len, _ = x.shape

        q = self.Wq(x)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        latent_kv = self.Wdkv(x)

        attn_mask = None
        q_offset = 0
        if use_cache:
            cache = self.kv_latent_cache
            if (
                cache is None
                or cache.shape[0] != batch_size
                or cache.device != latent_kv.device
                or cache.dtype != latent_kv.dtype
            ):
                # Allocate with the latent's dtype (not x's) so scatter_ works
                # when Wdkv runs at a different precision, e.g. under autocast.
                self._init_cache(batch_size, latent_kv.device, latent_kv.dtype)

            # RoPE takes one offset for the whole batch; cache_pos is always
            # advanced uniformly, so row 0 is representative.
            q_offset = int(self.cache_pos[0].item())

            # Absolute position of each token in this chunk: (batch, seq_len).
            chunk_pos = self.cache_pos.view(-1, 1) + torch.arange(
                seq_len, device=latent_kv.device
            ).view(1, -1)
            if torch.any(chunk_pos >= self.max_tokens):
                raise ValueError("Sequence length exceeds maximum tokens in cache.")

            self.kv_latent_cache.scatter_(
                1, chunk_pos.unsqueeze(-1).expand(-1, -1, self.ukv_dim), latent_kv
            )
            self.cache_pos += seq_len

            max_pos = int(self.cache_pos.max().item())
            full_latent_kv = self.kv_latent_cache[:, :max_pos, :]

            kv_pos = torch.arange(max_pos, device=latent_kv.device).view(1, 1, -1)
            if self.causal:
                # A query at absolute position p sees keys 0..p. This also hides
                # any slot past the row's own cache end (p < cache end).
                mask = kv_pos <= chunk_pos.unsqueeze(-1)  # (batch, seq_len, max_pos)
            else:
                # Only hide slots past each row's cache end.
                mask = kv_pos < self.cache_pos.view(-1, 1, 1)  # (batch, 1, max_pos)
            attn_mask = mask.unsqueeze(1)  # broadcast over heads

        else:
            full_latent_kv = latent_kv

        attn_seq_len = full_latent_kv.shape[1]

        k = self.Wuk(full_latent_kv)
        v = self.Wuv(full_latent_kv)

        k = k.view(batch_size, attn_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, attn_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = self.rope(q, seq_len, offset=q_offset)
        # Keys cover absolute positions 0..attn_seq_len-1 (the cache is filled
        # from slot 0), so they are always rotated starting at offset 0.
        k = self.rope(k, attn_seq_len, offset=0)

        attn_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            # Without the cache q and k are the same sequence, so SDPA's
            # built-in causal mask is exactly the right one.
            is_causal=self.causal and attn_mask is None,
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        output = self.Wo(attn_output)

        return output

    def reset_cache(self):
        """Zero the cache contents and positions, keeping the allocation."""
        if self.kv_latent_cache is not None:
            self.kv_latent_cache.zero_()

        if self.cache_pos is not None:
            self.cache_pos.zero_()

    def null_cache(self):
        """Drop the cache tensors; the next cached call reallocates them."""
        self.kv_latent_cache = None
        self.cache_pos = None

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    Args:
        dim: Size of the normalized (last) dimension.
        eps: Added to the mean square before the inverse square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the inverse RMS of its last dimension."""
        inv_rms = (x.square().mean(dim=-1, keepdim=True) + self.eps).rsqrt()
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32, cast back to the input dtype, then apply the gain.
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

class FeedForward(nn.Module):
    """SwiGLU feed-forward: ``w_down(silu(w_gate(x)) * w_up(x))``, all bias-free.

    Args:
        embed_dim: Input/output width.
        hidden_dim: Width of the gated hidden layer.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.w_gate = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, embed_dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w_gate(x))
        hidden = gate * self.w_up(x)
        return self.w_down(hidden)
    
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: RMSNorm -> MLAttention -> residual, then RMSNorm -> FeedForward -> residual.

    Args:
        config: Dict providing ``embed_dim``, ``num_heads``, ``ukv_dim``,
            ``max_tokens`` and ``ffn_hidden_dim``.
    """

    def __init__(self, config):
        super().__init__()
        dim = config['embed_dim']
        # Submodules are built in this order so parameter initialization
        # consumes the RNG in a fixed sequence.
        self.attention = MLAttention(
            embed_dim=dim,
            num_heads=config['num_heads'],
            ukv_dim=config['ukv_dim'],
            max_tokens=config['max_tokens'],
        )
        self.feed_forward = FeedForward(embed_dim=dim, hidden_dim=config['ffn_hidden_dim'])
        self.attention_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)

    def forward(self, x, use_cache=True):
        attn_out = self.attention(self.attention_norm(x), use_cache=use_cache)
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out
    
class DeepSeekModel(nn.Module):
    """Decoder-only language model.

    Token embedding -> ``num_layers`` TransformerBlocks -> RMSNorm -> bias-free
    linear head producing ``vocab_size`` logits per position.

    Args:
        config: Dict providing ``vocab_size``, ``embed_dim``, ``num_layers``
            plus the keys TransformerBlock reads.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        vocab_size = config['vocab_size']
        embed_dim = config['embed_dim']

        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config['num_layers'])
        )
        self.norm = RMSNorm(embed_dim)
        self.output_head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, input_ids, use_cache=True):
        """Return logits for ``input_ids`` (token indices, ``(batch, seq_len)``)."""
        hidden = self.tok_embeddings(input_ids)
        for block in self.layers:
            hidden = block(hidden, use_cache=use_cache)
        return self.output_head(self.norm(hidden))

    def reset_cache(self):
        """Zero every layer's attention cache (allocations are kept)."""
        for block in self.layers:
            block.attention.reset_cache()

    @classmethod
    def from_config(cls, config):
        """Alternate constructor; equivalent to ``cls(config)``."""
        return cls(config)

In [61]:
# Smoke test: build a small model, run one uncached forward pass, then a short
# greedy generation loop that exercises the latent KV cache.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_config = {
    'vocab_size': 10000,
    'max_tokens': 4096,
    'num_layers': 4,
    'embed_dim': 256,
    'num_heads': 4,
    'ukv_dim': 64,
    'ffn_hidden_dim': 256 * 4
}

BATCH_SIZE = 16

print("=== Full Model Test ===")
print(f"Testing with Batch Size: {BATCH_SIZE}")
print(f"Configuration: {model_config}")

model = DeepSeekModel.from_config(model_config).to(device)
print(f"\nModel Architecture:\n{model}")

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal Trainable Parameters: {num_params / 1e6:.2f}M")

print("\n=== Testing Training Pass ===")
model.train()
model.reset_cache()
dummy_ids = torch.randint(0, model_config['vocab_size'], (4, 128)).to(device)
logits = model(dummy_ids, use_cache=False)
print(f"Input IDs shape: {dummy_ids.shape}")
print(f"Output logits shape: {logits.shape}")
assert logits.shape == (4, 128, model_config['vocab_size']), "Output shape mismatch"
print("Training pass successful!")

print("\n=== Testing Inference Pass ===")
model.eval()
model.reset_cache()

input_ids = torch.randint(0, model_config['vocab_size'], (BATCH_SIZE, 1)).to(device)

# Every use_cache=True call appends its tokens to the cache, so the prompt is
# fed once and afterwards only the newest token. Re-feeding the whole
# sequence would store duplicated history in the cache.
step_input = input_ids
with torch.no_grad():  # inference only: don't build an autograd graph across steps
    for i in range(5):
        logits = model(step_input, use_cache=True)
        next_token_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        print(f"Step {i + 1}: Next token ID shape: {next_token_id.shape}")

        input_ids = torch.cat([input_ids, next_token_id], dim=1)
        step_input = next_token_id

print("\nFinal generated sequence length:", input_ids.shape[1])
assert input_ids.shape[1] == 6, "Generated sequence length mismatch"
# 1 prompt token + 4 fed-back tokens; the last generated token was never fed.
assert int(model.layers[0].attention.cache_pos[0]) == 5, "KV cache length mismatch"
print("Inference pass successful!")




=== Full Model Test ===
Testing with Batch Size: 16
Configuration: {'vocab_size': 10000, 'max_tokens': 4096, 'num_layers': 4, 'embed_dim': 256, 'num_heads': 4, 'ukv_dim': 64, 'ffn_hidden_dim': 1024}

Model Architecture:
DeepSeekModel(
  (tok_embeddings): Embedding(10000, 256)
  (layers): ModuleList(
    (0-3): 4 x TransformerBlock(
      (attention): MLAttention(
        (Wdkv): Linear(in_features=256, out_features=64, bias=False)
        (Wuk): Linear(in_features=64, out_features=256, bias=False)
        (Wuv): Linear(in_features=64, out_features=256, bias=False)
        (Wq): Linear(in_features=256, out_features=256, bias=False)
        (Wo): Linear(in_features=256, out_features=256, bias=False)
        (rope): RoPE()
      )
      (feed_forward): FeedForward(
        (w_gate): Linear(in_features=256, out_features=1024, bias=False)
        (w_up): Linear(in_features=256, out_features=1024, bias=False)
        (w_down): Linear(in_features=1024, out_features=256, bias=False)
      )
  