A wrong use of some ops like views might distort the proper backpropagation gradient flow. The following test ensure the newly architetture does not mess up gradient information across the batch dimension.

In [2]:
from attention_approximation.modeling_llama_approximated import LlamaForCausalLM
import torch
from attention_approximation.data import DistributedDataLoader

device = torch.device('mps')
model_config_path = "../data/MobileLLM/config.json"
data_path = "../data/minipile"

config = LlamaForCausalLM.config_class.from_json_file(model_config_path)
config.factorization_rank = 32
config.layer_sharing = False
config.seq_length = 512

model = LlamaForCausalLM(config)
loader = DistributedDataLoader(data_path, 4, 512, "train")
model.to(device)


Found 16 shards for split train


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaApproximatedAttention(
          (cp_circuit): CPCircuitLayer(
            (seq_mode_factor): Linear(in_features=576, out_features=32, bias=False)
            (cp): CP()
          )
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
)

In [12]:
# Test a forward
model.eval()
x, y = loader.next_batch()
x = x.to(device)
with torch.inference_mode():
    outs = model(x).logits
outs.shape

torch.Size([4, 512, 32000])

In [None]:
model.eval()  # Set to eval mode to avoid batch norm issues
# Load batch
x, y = loader.next_batch()
x = x.to(device)

# Get embeddings and detach to make it a leaf variable
embeddings = model.model.embed_tokens(x).detach().clone()
embeddings.requires_grad = True

# Forward pass through the rest of the model
hidden_states = embeddings

_, seq_len, hidden_size = hidden_states.size()
device = hidden_states.device
grid_y, grid_x = torch.meshgrid(
    torch.arange(seq_len, dtype=torch.long, device=device),
    torch.arange(hidden_size, dtype=torch.long, device=device),
    indexing="ij"
)
all_indices = torch.stack([grid_y, grid_x], dim=-1).view(-1, 2)

for layer in model.model.layers:
    hidden_states = layer(hidden_states, all_indices)[0]
hidden_states = model.model.norm(hidden_states)
outs = torch.nn.functional.linear(hidden_states, model.model.embed_tokens.weight)

# Test: loss depends only on batch element 2
test_batch_idx = 2
loss = outs[test_batch_idx].sum()

# Backward pass
loss.backward()

# Verify that only embeddings[test_batch_idx] has non-zero gradients
print(f"Testing gradient dependencies for batch index {test_batch_idx}:")
for i in range(embeddings.shape[0]):
    if i == test_batch_idx:
        has_nonzero = (embeddings.grad[i] != 0).any().item()
        print(f"  embeddings.grad[{i}] has non-zero values: {has_nonzero} (expected: True)")
        assert has_nonzero, f"Expected non-zero gradients for batch {i}"
    else:
        all_zero = (embeddings.grad[i] == 0).all().item()
        print(f"  embeddings.grad[{i}] is all zeros: {all_zero} (expected: True)")
        assert all_zero, f"Expected zero gradients for batch {i}, but found non-zero values!"

print("\n✓ Test passed! No information leakage across batch dimension.")

LlamaConfig {
  "architectures": [
    "MobileLLMForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_mobilellm.MobileLLMConfig",
    "AutoModelForCausalLM": "modeling_mobilellm.MobileLLMForCausalLM"
  },
  "bos_token_id": 1,
  "dtype": "float16",
  "eos_token_id": 2,
  "factorization_rank": 32,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 576,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "layer_sharing": false,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 9,
  "num_hidden_layers": 30,
  "num_key_value_heads": 3,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "seq_length": 512,
  "share_embedding": true,
  "tie_word_embeddings": false,
  "transformers_version": "4.56.0",
  "use_cache": true,
  "vocab_size": 32000
}



Found 16 shards for split train


Testing gradient dependencies for batch index 2:
  embeddings.grad[0] is all zeros: True (expected: True)
  embeddings.grad[1] is all zeros: True (expected: True)
  embeddings.grad[2] has non-zero values: True (expected: True)
  embeddings.grad[3] is all zeros: True (expected: True)

✓ Test passed! No information leakage across batch dimension.


This test make sure that the model can overfit a batch


In [None]:
from attention_approximation.pytorch import init_seeds
init_seeds(1337)

config.factorization_rank = 16
model = LlamaForCausalLM(config)
model = model.to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

x, y = loader.next_batch()
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)

for i in range(200):
    optimizer.zero_grad()
    outputs = model(x).logits
    loss = torch.nn.functional.cross_entropy(outputs.view(-1, outputs.size(-1)), y.view(-1))
    loss.backward()
    optimizer.step()
    if i <= 20 or i % 50:
        print(f"step {i}, loss = {loss.item()}")



In [10]:
from attention_approximation.data import TokenDataset, distribute_dataloader
import os
from attention_approximation.pytorch import RANK, DistributedEvalSampler, seed_worker
from pathlib import Path
from transformers import AutoTokenizer
import json

tokenizer = AutoTokenizer.from_pretrained("facebook/MobileLLM-350M-layer-share", use_fast=False, legacy=False)

ds = TokenDataset('../data/minipileA', seq_len=30, split="train")
dl = distribute_dataloader(ds, batch=5, shuffle=True, mode="train")
for x, y in dl:
    bs = x.size(0)
    for i in range(bs):
        print(f"Text: {tokenizer.decode(x[i])}, Next token: {tokenizer.decode(y[i][-1])}")
    break




Text: , the LRM or RLM to put a "strong" character to fully enclose an ambiguous punctuation character and thus make it, Next token: inherit
Text: on Monday the 2nd.

Some future dates to be aware of:

SDOT presentation and discussion with the Seattle Bicycle, Next token: Ad
Text: had declines during this time. Fig. 10
Fig. 8. Population indices and trends of breeding mourning do, Next token: ves
Text: ousings.

The first three Nimitz class carriers retained their original twin nuclear reactors which were of a fission design., Next token: 

Text: S.E.2d 554 (1984). As the record reveals that defense counsel questioned the witness exhaust, Next token: ively


In [3]:
from attention_approximation.modeling_llama import LlamaForCausalLM
import safetensors
model_config_path = "../data/MobileLLM/config.json"
model_weights_path = "../data/MobileLLM/model.safetensors"
config = LlamaForCausalLM.config_class.from_json_file(model_config_path)
msd = safetensors.torch.load_file(model_weights_path)
print(msd.keys())
model = LlamaForCausalLM(config)
model.load_state_dict(msd, strict=True)
model.eval()


LlamaForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


dict_keys(['model.embed_tokens.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', '

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
)

In [4]:
sum(p.numel() for p in model.parameters())

124635456

In [5]:
from attention_approximation.modeling_llama_approximated import LlamaForCausalLM
import torch
model_config_path = "../data/MobileLLM/config.json"

config = LlamaForCausalLM.config_class.from_json_file(model_config_path)
config.factorization_rank = 128
config.layer_sharing = False
config.seq_length = 512

msd = torch.load("../checkpoints/CF128/last.pt", map_location='cpu')
model = LlamaForCausalLM(config)
model.load_state_dict(msd, strict=False)
model.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (cp_circuit): CPCircuitLayer(
            (seq_mode_factor_1): Linear(in_features=576, out_features=128, bias=False)
            (seq_mode_factor_2): Linear(in_features=576, out_features=128, bias=False)
          )
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )


In [6]:
sum(p.numel() for p in model.parameters())

115826496