# load Best Model

In [1]:
### Please choose a model (L48 Normal Baby10M / L22 FORGETTER Baby10M / L18 Normal Baby100M / L16 FORGETTER Baby100M )
# num_hidden_layers = 48; num_attention_heads = 3; num_key_value_heads = 1; hidden_size = num_attention_heads*208; intermediate_size = hidden_size*4; head_dim = 224; rms_norm_eps = 1e-6; rope_theta = 1000.0
# num_hidden_layers = 22; num_attention_heads = 8; num_key_value_heads = 4; hidden_size = num_attention_heads*84; intermediate_size = hidden_size*4; head_dim = 192; rms_norm_eps = 1e-6; rope_theta = 1000.0
num_hidden_layers = 18; num_attention_heads = 8; num_key_value_heads = 4; hidden_size = num_attention_heads*72; intermediate_size = hidden_size*4; head_dim = 256; rms_norm_eps = 1e-4; rope_theta = 4000.0
# num_hidden_layers = 16; num_attention_heads = 9; num_key_value_heads = 3; hidden_size = num_attention_heads*72; intermediate_size = hidden_size*4; head_dim = 192; rms_norm_eps = 1e-4; rope_theta = 4000.0


vocab_size = 50257; T=512; from transformers import AutoTokenizer; import matplotlib.pyplot as plt; import numpy as np; import torch; device = 'cuda' if torch.cuda.is_available() else 'cpu'

def apply_rotary_emb(x: torch.Tensor, dim: int) -> torch.Tensor: # seq_len = x.size(1) # N
    freqs = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, device=device).float() / dim)) # Dynamically compute frequency cis
    t = torch.arange(x.size(1), device=device); freqs = torch.outer(t, freqs).float(); freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    return x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

class RMSNorm(torch.nn.Module): # RMS:4.326552, RMS_no_weight:4.410741 # RMS':4.554899
    def __init__(self, dim: int = hidden_size):
        super().__init__(); self.weight = torch.nn.Parameter(torch.zeros(dim)) # one weight per feature to be learned
    def _norm(self, x): # mean square for each feature (across the last dimension)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + rms_norm_eps)
    def forward(self, x): # ensure the data type matches the input.
        return self._norm(x.float()).type_as(x) * (1 + self.weight)

class GemmaAttention(torch.nn.Module): # MQA = K,V shared by 4Qs
    def __init__(self):
        super().__init__(); self.qkv_proj = torch.nn.Linear(hidden_size, (num_attention_heads + 2 * num_key_value_heads) * head_dim, bias=False); self.o_proj = torch.nn.Linear(num_attention_heads * head_dim, hidden_size, bias=False) # concatenated attention outputs back to the hidden size.
    def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor:  # in=(B, T, hidden_size)
        batch_size, input_len, _ = hidden_states.shape
        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([num_attention_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim],dim=-1)
        xq = xq.view(batch_size, -1, num_attention_heads, head_dim); xk = xk.view(batch_size, -1, num_key_value_heads, head_dim); xv = xv.view(batch_size, -1, num_key_value_heads, head_dim)
        xq = apply_rotary_emb(xq, head_dim); xk = apply_rotary_emb(xk, head_dim)
        if num_key_value_heads != num_attention_heads:  # Q/KV multiples of K and V to match Q
            xk = torch.repeat_interleave(xk, num_attention_heads // num_key_value_heads, dim=2) # [B, T, n_local_heads, head_dim]
            xv = torch.repeat_interleave(xv, num_attention_heads // num_key_value_heads, dim=2)
        q = xq.transpose(1, 2); k = xk.transpose(1, 2); v = xv.transpose(1, 2) # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=True) # B nh T hs        
        output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)  # [B, T, "hidden_dim"]
        return self.o_proj(output)

class GemmaDecoderLayer(torch.nn.Module): # normalize before and after the attention mechanism
    def __init__(self):
        super().__init__(); self.self_attn = GemmaAttention(); self.input_layernorm = RMSNorm(); self.post_attention_layernorm = RMSNorm(); self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size); self.up_proj = torch.nn.Linear(hidden_size, intermediate_size); self.down_proj = torch.nn.Linear(intermediate_size, hidden_size) # mlp
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:  # input_size = (B, T, hidden_size)
        residual = hidden_states # Self Attention Block
        hidden_states = self.input_layernorm(hidden_states); hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states # MLP Block
        hidden_states = self.post_attention_layernorm(hidden_states); gate = torch.nn.functional.gelu(self.gate_proj(hidden_states)); up = self.up_proj(hidden_states); fuse = gate * up; hidden_states = self.down_proj(fuse) # mlp
        return residual + hidden_states

class minGemma(torch.nn.Module):
    def __init__(self):
        super().__init__(); self.embedder = torch.nn.Embedding(vocab_size, hidden_size); self.layers = torch.nn.ModuleList(GemmaDecoderLayer() for _ in range(num_hidden_layers)); self.norm = RMSNorm();
    def forward(self, input_token_ids: torch.Tensor) -> torch.Tensor: # (B, T)
        hidden_states = self.embedder(input_token_ids[:,:-1]) # (B, T) & (vocab_size, hidden_size) -> (B, T, hidden_size)
        hidden_states = hidden_states * (hidden_size**0.5)
        for i in range(len(self.layers)):
            hidden_states = self.layers[i](hidden_states) # shortened too much???
        hidden_states = self.norm(hidden_states) # -> (B, T, hidden_size)
        embedder_weight = self.embedder.weight
        logits = torch.matmul(hidden_states, embedder_weight.t()); b,t,v=logits.shape; # (B, T, hidden_size) @ (hidden_size, vocab_size) -> (B, T, vocab_size)
        loss = torch.nn.functional.cross_entropy(logits.view(b*t,v), input_token_ids[:,1:].reshape(b*t)) #, weight=None, ignore_index=-100, reduction='mean')
        return loss, logits # logits, loss

model = minGemma().to(device); print(f'L{num_hidden_layers}' f' att{num_attention_heads}' f' kv_heads{num_key_value_heads}' f' hidden{hidden_size}' f' intermediate{intermediate_size}' f' head_dim{head_dim}' f' T{T}')


### Please choose a model to load
# model.load_state_dict(torch.load('Normal Models_Baby10M/minGemma-hidden_layers48-att_heads3-kv_heads1-hidden624-intermediate2496-head_dim224-T512--2025-07-20-01-06.pth'))
# model.load_state_dict(torch.load('FORGETTER Models_Baby10M/minGemma-hidden_layers22-att_heads8-kv_heads4-hidden672-intermediate2688-head_dim192-T512--2025-06-21-17-20.pth'))
model.load_state_dict(torch.load('Normal Models_Baby100M/minGemma-hidden_layers18-att_heads8-kv_heads4-hidden576-intermediate2304-head_dim256-T512--2025-07-21-10-59.pth'))
# model.load_state_dict(torch.load('FORGETTER Models_Baby100M/minGemma-hidden_layers16-att_heads9-kv_heads3-hidden648-intermediate2592-head_dim192-T512--2025-06-25-10-20.pth'))

  from .autonotebook import tqdm as notebook_tqdm


L18 att8 kv_heads4 hidden576 intermediate2304 head_dim256 T512


<All keys matched successfully>

### BLiMP for best models

In [2]:
# fast version # BLiMP for MinGemma "model"
import os, json, re; import numpy as np; tokenizer = AutoTokenizer.from_pretrained('gpt2'); tokenizer.pad_token = tokenizer.eos_token; model.eval()
accuracy=[]; files = os.listdir("./blimp-master/data/")
for filename in files:
    correct = 0; total = 0
    f = open("./blimp-master/data/"+filename); print(filename)
    for _ in range(20):
        lines = [f.readline() for _ in range(50)] # 22.7G for L48 #  lines = f.readlines() # len(list)==1000
        batch_bad = [re.sub(r"\n+", "\n", json.loads(x)["sentence_bad"]).replace("\n"," ") for x in lines]
        batch_good = [re.sub(r"\n+", "\n", json.loads(x)["sentence_good"]).replace("\n"," ") for x in lines]
        
        bads = tokenizer.batch_encode_plus(batch_bad, padding="longest", max_length=512, truncation=True, return_tensors='pt')["input_ids"]
        bads = torch.tensor(bads, device='cuda'); # torch.cuda.empty_cache(); print(bads.shape)
        pred = model(bads)[1].to("cpu")

        goods = tokenizer.batch_encode_plus(batch_good, padding="longest", max_length=512, truncation=True, return_tensors='pt')["input_ids"]
        goods = torch.tensor(goods, device='cuda'); # torch.cuda.empty_cache()
        pred2 = model(goods)[1].to("cpu")

        for l in range(len(lines)):
            ans = bads[l,1:]; ans = ans[ans!=50256]
            likeli_list = [pred[l, i, ans[i]] for i in range(len(ans))]
            likelihood_bad = sum(torch.tensor(likeli_list) - torch.log(torch.sum(torch.exp(pred[l,0:len(ans),:]),1)))

            ans2 = goods[l,1:]; ans2 = ans2[ans2!=50256]
            likeli_list2 = [pred2[l, i, ans2[i]] for i in range(len(ans2))]
            likelihood_good = sum(torch.tensor(likeli_list2) - torch.log(torch.sum(torch.exp(pred2[l,0:len(ans2),:]),1)))
        
            total += 1
            if likelihood_bad < likelihood_good:
                correct += 1
    accuracy.append(correct/total); print(correct/total)
np.mean(np.array(accuracy)) # 0.7033731343283581 for L48N_Baby10M  # 0.7256567164179104 for L22F_Baby10M  # 0.7668955223880597 for L18N_Baby100M  # 0.7748059701492538 for L16F_Baby100M     # cf:  0.7175970149253731 for our best model for WikiText-103

adjunct_island.jsonl


  bads = torch.tensor(bads, device='cuda'); # torch.cuda.empty_cache(); print(bads.shape)
  output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=True) # B nh T hs
  goods = torch.tensor(goods, device='cuda'); # torch.cuda.empty_cache()


0.801
anaphor_gender_agreement.jsonl
0.946
anaphor_number_agreement.jsonl
0.976
animate_subject_passive.jsonl
0.793
animate_subject_trans.jsonl
0.87
causative.jsonl
0.724
complex_NP_island.jsonl
0.518
coordinate_structure_constraint_complex_left_branch.jsonl
0.552
coordinate_structure_constraint_object_extraction.jsonl
0.805
determiner_noun_agreement_1.jsonl
0.98
determiner_noun_agreement_2.jsonl
0.97
determiner_noun_agreement_irregular_1.jsonl
0.858
determiner_noun_agreement_irregular_2.jsonl
0.956
determiner_noun_agreement_with_adjective_1.jsonl
0.946
determiner_noun_agreement_with_adj_2.jsonl
0.908
determiner_noun_agreement_with_adj_irregular_1.jsonl
0.837
determiner_noun_agreement_with_adj_irregular_2.jsonl
0.885
distractor_agreement_relational_noun.jsonl
0.804
distractor_agreement_relative_clause.jsonl
0.7
drop_argument.jsonl
0.778
ellipsis_n_bar_1.jsonl
0.812
ellipsis_n_bar_2.jsonl
0.831
existential_there_object_raising.jsonl
0.702
existential_there_quantifiers_1.jsonl
0.984
exis

0.7668955223880597