## Download MobileLLM (125M) weights from Hugging Face

In [None]:
!mkdir -p ../data/MobileLLM

!sudo apt install openmpi-bin openmpi-doc libopenmpi-dev
!curl -L -o ../data/MobileLLM/model.safetensors https://huggingface.co/mia-llm/MobileLLM-125M-wikitext2raw-hosein/resolve/main/model.safetensors
!curl -L -o ../data/MobileLLM/config.json https://huggingface.co/mia-llm/MobileLLM-125M-wikitext2raw-hosein/resolve/main/config.json

In [None]:
from transformers import LlamaConfig
from attention_approximation.modeling_llama import LlamaForCausalLM as TeacherModel
from attention_approximation.pytorch import intersect_dicts
from attention_approximation.utils import LOGGER
import safetensors
from attention_approximation.modeling_llama_approximated import LlamaModel
from copy import copy
import torch
from collections.abc import Callable
from torch import nn
from attention_approximation.modeling_llama_approximated import LlamaApproximatedAttention


# `import safetensors` alone does not load the `safetensors.torch` submodule; import it explicitly
# instead of relying on another library having imported it as a side effect.
import safetensors.torch

model_config_path = "../data/MobileLLM/config.json"
model_weights_path = "../data/MobileLLM/model.safetensors"
# Prefer CUDA, then Apple MPS, then fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# Instantiate the teacher and load the pretrained weights.
# `from_json_file` is a classmethod, so there is no need to build a default config first.
config = LlamaConfig.from_json_file(model_config_path)
model = TeacherModel(config)
checkpoint = safetensors.torch.load_file(model_weights_path)
csd = intersect_dicts(checkpoint, model.state_dict())  # keep checkpoint entries that intersect the model's state dict
model.load_state_dict(csd, strict=False)  # non-strict: parameters absent from `csd` keep their initial values
LOGGER.info(f"Transferred {len(csd)}/{len(model.state_dict())} items from pretrained weights")
model.eval().to(device)

# Student config: a shallow copy of the teacher config plus the approximation hyper-parameters.
student_config = copy(config)
student_config.factorization_rank = 16  # low-rank factorization rank (previous choice: config.hidden_size // 4)
student_config.layer_sharing = False
student_config.seq_length = 512


class AttentionDistillationWrapper(nn.Module):
    """Run a student attention module side by side with a teacher attention module.

    Both modules receive exactly the same positional and keyword arguments. The teacher is
    evaluated under ``torch.inference_mode()`` so it never records an autograd graph. The
    wrapper returns a tuple shaped like the teacher's output, with the second slot replaced
    by the distillation loss between student and teacher hidden states.

    Args:
        student_att: Attention class (or factory) instantiated as
            ``student_att(config=config, all_indices=config.all_indices)``.
        teacher_att: Already-built teacher attention module.
        config: Configuration passed to the student; must carry an ``all_indices`` attribute.
        return_teacher_outputs: If False (default, the historical behavior), the student's
            hidden states are propagated to the rest of the network. If True, a grad-free copy
            of the teacher's hidden states is propagated instead, so later layers see the exact
            teacher activations.
    """

    def __init__(
        self,
        student_att: Callable,
        teacher_att: nn.Module,
        config: "LlamaConfig",
        return_teacher_outputs: bool = False,
    ):
        super().__init__()
        self.student_att = student_att(config=config, all_indices=config.all_indices)
        self.teacher_att = teacher_att
        self.return_teacher_outputs = return_teacher_outputs

    def forward(self, *args, **kwargs):
        """Return ``(hidden_states, att_loss) + teacher_outputs[2:]``.

        ``att_loss`` is the mean, over all leading positions, of the L2 distance between the
        student and teacher hidden states along the last dimension, scaled by
        ``1 / sqrt(last_dim)``. Anything the teacher returns beyond its first two elements is
        forwarded unchanged; a teacher that returns a bare tensor contributes nothing extra.
        """
        student_outputs = self.student_att(*args, **kwargs)
        with torch.inference_mode():
            teacher_outputs = self.teacher_att(*args, **kwargs)

        teacher_is_tuple = isinstance(teacher_outputs, tuple)
        student_hidden_states = student_outputs[0] if isinstance(student_outputs, tuple) else student_outputs
        teacher_hidden_states = teacher_outputs[0] if teacher_is_tuple else teacher_outputs

        att_loss = torch.linalg.vector_norm(student_hidden_states - teacher_hidden_states, dim=-1).mean() * (
            student_hidden_states.size(-1) ** -0.5
        )

        if self.return_teacher_outputs:
            # Tensors produced under inference_mode cannot be saved for backward by later
            # layers; cloning outside inference mode yields a normal (grad-free) tensor.
            hidden_states = teacher_hidden_states.clone()
        else:
            hidden_states = student_hidden_states

        # Slicing a bare tensor teacher output would slice the batch, not extra outputs.
        extra_outputs = teacher_outputs[2:] if teacher_is_tuple else ()
        return (hidden_states, att_loss) + extra_outputs


# Enumerate every (sequence position, hidden channel) coordinate of a
# seq_length x hidden_size grid in row-major order. The result is a
# (seq_length * hidden_size, 2) long tensor that the student attention
# receives via `student_config.all_indices`.
grid_y, grid_x = torch.meshgrid(
    torch.arange(student_config.seq_length, dtype=torch.long),
    torch.arange(student_config.hidden_size, dtype=torch.long),
    indexing="ij",
)
all_indices = torch.stack((grid_y.reshape(-1), grid_x.reshape(-1)), dim=1)
student_config.all_indices = all_indices.to(device)


# Freeze every teacher parameter. Assigning `model.requires_grad = False` would only set a plain
# Python attribute on the module; `requires_grad_(False)` flips the flag on each parameter.
# This runs before the wrappers are built, so the student modules created below stay trainable.
model.requires_grad_(False)
for layer in model.model.layers:
    # Run the approximated (student) attention next to the original (teacher) attention.
    layer.self_attn = AttentionDistillationWrapper(student_att=LlamaApproximatedAttention, teacher_att=layer.self_attn, config=student_config)

model = model.to(device)  # make sure the newly built student modules live on `device`
# Smoke test: random token ids with the same sequence length that sized `student_config.all_indices`,
# allocated directly on the target device.
x = torch.randint(0, config.vocab_size, (8, student_config.seq_length), device=device)
model(x)
