In [3]:
from safetensors import safe_open


state = dict()

with safe_open("/models/Llama-3.2-1B-Instruct/model.safetensors", 'pt', device='cpu') as f:
	for key in f.keys():
		state[key] = f.get_tensor(key)

for k, t in state.items():
	print(k, t.shape)


model.embed_tokens.weight torch.Size([128256, 2048])
model.layers.0.input_layernorm.weight torch.Size([2048])
model.layers.0.mlp.down_proj.weight torch.Size([2048, 8192])
model.layers.0.mlp.gate_proj.weight torch.Size([8192, 2048])
model.layers.0.mlp.up_proj.weight torch.Size([8192, 2048])
model.layers.0.post_attention_layernorm.weight torch.Size([2048])
model.layers.0.self_attn.k_proj.weight torch.Size([512, 2048])
model.layers.0.self_attn.o_proj.weight torch.Size([2048, 2048])
model.layers.0.self_attn.q_proj.weight torch.Size([2048, 2048])
model.layers.0.self_attn.v_proj.weight torch.Size([512, 2048])
model.layers.1.input_layernorm.weight torch.Size([2048])
model.layers.1.mlp.down_proj.weight torch.Size([2048, 8192])
model.layers.1.mlp.gate_proj.weight torch.Size([8192, 2048])
model.layers.1.mlp.up_proj.weight torch.Size([8192, 2048])
model.layers.1.post_attention_layernorm.weight torch.Size([2048])
model.layers.1.self_attn.k_proj.weight torch.Size([512, 2048])
model.layers.1.self_at

In [18]:
# prune weights

N_HIDDEN_SIZE = 768
N_INTERMEDIATE_SIZE = 3072
N_K_SIZE = 192

new_state = {}

for k, t in state.items():
	if k.startswith('model.layers.'):
		n = int(k.split('.')[2])
		if n > 3:
			continue

	if '.input_layernorm.' in k or 'model.norm.' in k or '.post_attention_layernorm.' in k:
		nt = t[:N_HIDDEN_SIZE].contiguous()
	elif '.self_attn.o_proj.' in k or '.self_attn.q_proj.' in k:
		nt = t[:N_HIDDEN_SIZE, :N_HIDDEN_SIZE].contiguous()
	elif '.mlp.down_proj.' in k:
		nt = t[:N_HIDDEN_SIZE, :N_INTERMEDIATE_SIZE].contiguous()
	elif '.mlp.gate_proj.' in k or '.mlp.up_proj.' in k:
		nt = t[:N_INTERMEDIATE_SIZE, :N_HIDDEN_SIZE].contiguous()
	elif '.self_attn.k_proj.' in k or '.self_attn.v_proj.' in k:
		nt = t[:N_K_SIZE, :N_HIDDEN_SIZE].contiguous()
	else:
		nt = t[:, :N_HIDDEN_SIZE].contiguous()
	new_state[k] = nt

for k, t in new_state.items():
	print(k, t.shape)

model.embed_tokens.weight torch.Size([128256, 768])
model.layers.0.input_layernorm.weight torch.Size([768])
model.layers.0.mlp.down_proj.weight torch.Size([768, 3072])
model.layers.0.mlp.gate_proj.weight torch.Size([3072, 768])
model.layers.0.mlp.up_proj.weight torch.Size([3072, 768])
model.layers.0.post_attention_layernorm.weight torch.Size([768])
model.layers.0.self_attn.k_proj.weight torch.Size([192, 768])
model.layers.0.self_attn.o_proj.weight torch.Size([768, 768])
model.layers.0.self_attn.q_proj.weight torch.Size([768, 768])
model.layers.0.self_attn.v_proj.weight torch.Size([192, 768])
model.layers.1.input_layernorm.weight torch.Size([768])
model.layers.1.mlp.down_proj.weight torch.Size([768, 3072])
model.layers.1.mlp.gate_proj.weight torch.Size([3072, 768])
model.layers.1.mlp.up_proj.weight torch.Size([3072, 768])
model.layers.1.post_attention_layernorm.weight torch.Size([768])
model.layers.1.self_attn.k_proj.weight torch.Size([192, 768])
model.layers.1.self_attn.o_proj.weight t

In [19]:
from safetensors.torch import save_file


save_file(new_state, "/models/Llama-3.2-100m-raw/model.safetensors", metadata={'format': 'pt'})

In [21]:
## test
from transformers import AutoTokenizer, AutoModelForCausalLM


model = AutoModelForCausalLM.from_pretrained('/models/Llama-3.2-100m-raw')
model

Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /models/Llama-3.2-100m-raw and are newly initialized: ['model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 768)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (k_proj): Linear(in_features=768, out_features=192, bias=False)
          (v_proj): Linear(in_features=768, out_features=192, bias=False)
          (o_proj): Linear(in_features=768, out_features=768, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=768, out_features=3072, bias=False)
          (up_proj): Linear(in_features=768, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=768, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((768,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((768,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNor