In [50]:
import matplotlib.pyplot as plt; import numpy as np; import time, torch; device = 'cuda' if torch.cuda.is_available() else 'cpu'
from transformers import AutoTokenizer, TrainingArguments, DefaultDataCollator, Trainer, EarlyStoppingCallback
# ---- model hyper-parameters ----------------------------------------------
# Parenthesised "Gemma: a / b" notes record the reference values the author
# listed for the released Gemma configs (larger model first, smaller second).
vocab_size = 50257  # GPT-2 BPE size; token files are uint16, so keep < 65536. TODO: take from tokenizer.vocab_size (Gemma: 256128)
num_hidden_layers = 8  # number of decoder blocks (Gemma: 28 / 18)
num_attention_heads = 6  # query heads (Gemma: 16 / 8)
num_key_value_heads = 3  # shared K/V heads; must divide num_attention_heads (Gemma: 16 / 1)
hidden_size = 56 * num_attention_heads  # embedding width, 336 here (Gemma: 3072 / 2048)
intermediate_size = 1 * hidden_size  # MLP inner width (x4 / x8 are common); dominates step time (Gemma: 24576 / 16384)
head_dim = 96  # per-head attention width, independent of hidden_size; author notes it barely affects time (Gemma: 256)
rms_norm_eps = 1e-4  # RMSNorm epsilon (Gemma: 1e-6)
rope_theta = 2400  # RoPE base frequency; smaller than Gemma's 10000.0 for this small model (author notes 1000 may also work)

def apply_rotary_emb(x: torch.Tensor, dim: int, theta: "float | None" = None) -> torch.Tensor:
    """Apply rotary position embeddings (RoPE) to a query or key tensor.

    The last axis of ``x`` is split into two halves ``a`` and ``b``.  The pair
    ``(a_i, b_i)`` at sequence position ``t`` is rotated by the angle
    ``t * theta ** (-2i / dim)``.  The result is therefore
    ``[a*cos - b*sin, a*sin + b*cos]``.

    Args:
        x: tensor of shape (batch, seq_len, n_heads, dim), which is how
            GemmaAttention calls it.  ``dim`` must equal ``x.size(-1)`` and
            be even.
        dim: rotary width (the head dimension).
        theta: RoPE base frequency.  Defaults to the module-level
            ``rope_theta``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    if theta is None:
        theta = rope_theta
    # Build the angle table on x's own device rather than the global `device`,
    # so the function also works when the model lives elsewhere.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=x.device).float() / dim))
    positions = torch.arange(x.size(1), device=x.device)
    angles = torch.outer(positions, inv_freq).float()  # (seq_len, dim/2)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64 e^{i*angle}
    # (B, T, H, D) -> (B, H, T, D).  Pair the first half (real) with the second half (imag).
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    # Broadcast the (1, T, D/2) rotation over batch and heads, then return to the input dtype.
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # (B, H, T, D/2, 2)
    # Put the real parts back in the first half and the imaginary parts in the second half.
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    return x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation with a zero-centred learned scale.

    Computes ``x / sqrt(mean(x**2, last axis) + rms_norm_eps)`` in float32 and
    casts back to the input dtype.  It then multiplies by ``1 + weight``.
    ``weight`` starts at zero, so a fresh layer applies unit scale.
    Author's recorded comparison, presumably validation losses:
    RMS 4.326552, RMS without weight 4.410741, RMS' 4.554899.
    """

    def __init__(self, dim: int = hidden_size):
        super().__init__()
        # One learnable scale offset per feature of the last axis.
        self.weight = torch.nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last axis, stabilised by ``rms_norm_eps``."""
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + rms_norm_eps)

    def forward(self, x):
        # Normalise in float32, restore the caller's dtype, then apply the scale.
        normed = self._norm(x.float()).type_as(x)
        return normed * (1 + self.weight)
        
class GemmaAttention(torch.nn.Module):
    """Causal grouped-query self-attention (GQA) with rotary position embeddings.

    Each of the ``num_key_value_heads`` K/V heads is shared by
    ``num_attention_heads // num_key_value_heads`` query heads.  For example,
    6 query and 3 K/V heads gives 2 queries per K/V head; a single K/V head
    would be MQA.  Q, K and V come from one fused projection.  The
    concatenated head outputs are projected back to ``hidden_size``.

    Raises:
        ValueError: if ``num_attention_heads`` is not a multiple of
            ``num_key_value_heads``.
    """

    def __init__(self):
        super().__init__()
        if num_attention_heads % num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) must be a multiple of "
                f"num_key_value_heads ({num_key_value_heads})"
            )
        # Fused projection producing [Q | K | V] along the last axis.
        self.qkv_proj = torch.nn.Linear(hidden_size, (num_attention_heads + 2 * num_key_value_heads) * head_dim, bias=False)
        # Maps the concatenated attention outputs back to the hidden size.
        self.o_proj = torch.nn.Linear(num_attention_heads * head_dim, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Attend over ``hidden_states`` of shape (B, T, hidden_size); returns (B, T, hidden_size)."""
        batch_size, input_len, _ = hidden_states.shape
        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([num_attention_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim], dim=-1)
        xq = xq.view(batch_size, -1, num_attention_heads, head_dim)
        xk = xk.view(batch_size, -1, num_key_value_heads, head_dim)
        xv = xv.view(batch_size, -1, num_key_value_heads, head_dim)
        # Rotary position information is applied to queries and keys only.
        xq = apply_rotary_emb(xq, head_dim)
        xk = apply_rotary_emb(xk, head_dim)
        if num_key_value_heads != num_attention_heads:
            # Replicate each K/V head so every query head has a partner: (B, T, n_heads, head_dim).
            n_rep = num_attention_heads // num_key_value_heads
            xk = torch.repeat_interleave(xk, n_rep, dim=2)
            xv = torch.repeat_interleave(xv, n_rep, dim=2)
        q, k, v = (t.transpose(1, 2) for t in (xq, xk, xv))  # (B, n_heads, T, head_dim)
        # is_causal=True masks future positions.  SDPA's default scale is 1/sqrt(head_dim).
        output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=True)
        output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)  # (B, T, n_heads * head_dim)
        return self.o_proj(output)

class GemmaDecoderLayer(torch.nn.Module):
    """One pre-norm transformer block: attention sub-layer, then a gated-GELU MLP.

    Each sub-layer RMS-normalises its input and adds its output back onto the
    residual stream:

        h   = x + self_attn(input_layernorm(x))
        out = h + down_proj(gelu(gate_proj(n)) * up_proj(n)),  where n = post_attention_layernorm(h)
    """

    def __init__(self):
        super().__init__()
        self.self_attn = GemmaAttention()
        self.input_layernorm = RMSNorm()
        self.post_attention_layernorm = RMSNorm()
        # Gated MLP projections (torch.nn.Linear defaults, i.e. with bias).
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size)

    def _mlp(self, x):
        """Gated MLP: GELU(gate) * up, projected back to hidden_size."""
        gated = torch.nn.functional.gelu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the block on (B, T, hidden_size) input; returns the same shape."""
        attn_out = self.self_attn(hidden_states=self.input_layernorm(hidden_states))
        hidden_states = hidden_states + attn_out
        mlp_out = self._mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out

class minGemma(torch.nn.Module):
    """Small Gemma-style decoder-only language model with tied input/output embeddings.

    ``forward`` takes a batch of ``T + 1`` token ids per row.  It predicts
    tokens ``1..T`` from tokens ``0..T-1`` and returns ``(loss, logits)``,
    loss first.
    """

    def __init__(self):
        super().__init__()
        self.embedder = torch.nn.Embedding(vocab_size, hidden_size)
        self.norm = RMSNorm()
        self.layers = torch.nn.ModuleList(GemmaDecoderLayer() for _ in range(num_hidden_layers))

    def forward(self, input_token_ids: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Compute next-token logits and the mean cross-entropy loss.

        Args:
            input_token_ids: integer tensor of shape (batch, seq_len + 1).

        Returns:
            ``(loss, logits)``.  ``loss`` is a scalar.  ``logits`` has shape
            (batch, seq_len, vocab_size).
        """
        inputs = input_token_ids[:, :-1]
        targets = input_token_ids[:, 1:]
        # Gemma-style: scale token embeddings by sqrt(hidden_size).
        hidden_states = self.embedder(inputs) * (hidden_size ** 0.5)
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        hidden_states = self.norm(hidden_states)  # (batch, seq_len, hidden_size)
        # Output projection reuses (ties) the input embedding matrix.
        logits = torch.matmul(hidden_states, self.embedder.weight.t())  # (batch, seq_len, vocab_size)
        # Flatten using the actual batch shape rather than the global B/T, so
        # partial batches and other sequence lengths work too.
        loss = torch.nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return loss, logits

def collator(idx, *, val=None, train=None, seq_len=None):
    """Turn sampled start offsets into a batch of ``seq_len + 1``-token windows.

    Offsets live in one combined index space.  Values below ``len(val)`` are
    start positions in the validation stream.  Values at or above it are
    training-stream positions shifted by ``len(val)``.  Each window is
    ``seq_len + 1`` tokens long so the model can split it into inputs and
    shifted targets.

    Args:
        idx: list of dataset items.  Each item's first element is a start
            offset, either a 0-d integer tensor (as yielded by
            ``TensorDataset``) or a plain int.
        val: validation token stream to slice.  Defaults to the module-level
            ``val_data``.
        train: training token stream to slice.  Defaults to the module-level
            ``train_data``.
        seq_len: context length.  Defaults to the module-level ``T``.

    Returns:
        ``{'input_token_ids': int64 tensor of shape (len(idx), seq_len + 1)}``.
    """
    val = val_data if val is None else val
    train = train_data if train is None else train
    seq_len = T if seq_len is None else seq_len
    n_val = len(val)
    windows = []
    for item in idx:
        start = int(item[0])
        # Route every sample on its own offset (not only idx[0]) so a batch that
        # mixes validation and training offsets is still sliced correctly.
        source, start = (val, start) if start < n_val else (train, start - n_val)
        windows.append(torch.from_numpy(source[start:start + seq_len + 1].astype(np.int64)))
    return {'input_token_ids': torch.stack(windows)}

# ---- data ------------------------------------------------------------------
# Token streams stored as uint16 ids (so vocab_size must stay below 65536).
train_data = np.memmap('train_WK2_mk.bin', dtype=np.uint16, mode='r')
val_data = np.memmap('val_WK2_mk.bin', dtype=np.uint16, mode='r')
T = 512  # context length; each sample carries T + 1 tokens
B = 8  # per-device batch size for training and evaluation
N_step = 8000  # B * N_step random training windows are sampled below
n_steps = 8000  # logging / evaluation / checkpoint interval, in steps
print(str(T * B * N_step / 1000000) + " million tokens")

# ---- model -----------------------------------------------------------------
model = minGemma().to(device)
print(f'L{num_hidden_layers} att{num_attention_heads} kv_heads{num_key_value_heads} hidden{hidden_size} intermediate{intermediate_size} head_dim{head_dim} T{T} Nparam{sum(p.numel() for p in model.parameters()) / 10**6:.1f}')
torch.cuda.empty_cache()

# ---- training --------------------------------------------------------------
training_args = TrainingArguments(
    output_dir='./',
    num_train_epochs=1,
    learning_rate=7e-3,
    weight_decay=1.0,
    per_device_train_batch_size=B,
    per_device_eval_batch_size=B,
    logging_strategy='steps',
    logging_steps=n_steps,
    eval_strategy='steps',
    eval_steps=n_steps,
    save_strategy='steps',
    save_steps=n_steps,
    load_best_model_at_end=True,
    report_to='none',
)
# Datasets hold only random window start offsets; `collator` slices the tokens.
# Training offsets are shifted by len(val_data) so they can be told apart from
# validation offsets.  The train offsets are drawn before the eval offsets.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=torch.utils.data.TensorDataset(torch.randint(len(train_data) - T - 1, (B * N_step,)) + len(val_data)),
    eval_dataset=torch.utils.data.TensorDataset(torch.randint(len(val_data) - T - 1, (B * 64,))),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.0)],
)
# Presumably lets evaluation use the loss returned by forward(); the batches carry no separate labels.
trainer.can_return_loss = True
trainer.train()

32.768 million tokens
L8 att6 kv_heads3 hidden336 intermediate336 head_dim96 T512 Nparam24.3


Step,Training Loss,Validation Loss
8000,4.7881,4.156473


TrainOutput(global_step=8000, training_loss=4.78814453125, metrics={'train_runtime': 1643.3021, 'train_samples_per_second': 38.946, 'train_steps_per_second': 4.868, 'total_flos': 0.0, 'train_loss': 4.78814453125, 'epoch': 1.0})