In [19]:
from transformers import AutoModelForCausalLM, AutoConfig


# Load only the model configuration (architecture hyper-parameters) from the local
# model directory. No weights are read here; the model itself is instantiated from
# this config with AutoModelForCausalLM.from_config.
config = AutoConfig.from_pretrained('/models/mistral-100M/')
config

MistralConfig {
  "_name_or_path": "/models/mistral-100M/",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 16,
  "hidden_act": "silu",
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 768,
  "max_position_embeddings": 1024000,
  "model_type": "mistral",
  "num_attention_heads": 8,
  "num_hidden_layers": 4,
  "num_key_value_heads": 2,
  "pad_token_id": 10,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3",
  "unsloth_version": "2024.8",
  "use_cache": true,
  "vocab_size": 131072
}

In [20]:
# Build the architecture described by `config` without loading any checkpoint weights;
# parameters start from the library's default initialization and are re-initialized below.
model = AutoModelForCausalLM.from_config(config=config)
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(131072, 256, padding_idx=10)
    (layers): ModuleList(
      (0-3): 4 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=256, out_features=128, bias=False)
          (k_proj): Linear(in_features=256, out_features=32, bias=False)
          (v_proj): Linear(in_features=256, out_features=32, bias=False)
          (o_proj): Linear(in_features=128, out_features=256, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=256, out_features=768, bias=False)
          (up_proj): Linear(in_features=256, out_features=768, bias=False)
          (down_proj): Linear(in_features=768, out_features=256, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((256,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((256,), eps=1e-05)
      )
    

In [21]:
def count_parameters(model, trainable_only=False):
    """Return the number of scalar parameter elements in ``model``.

    Args:
        model: a ``torch.nn.Module``.
        trainable_only: when True, count only parameters with ``requires_grad=True``
            (useful after freezing layers). Defaults to False, which counts every
            parameter, matching the previous behaviour.

    Returns:
        int: sum of ``numel()`` over the selected ``model.parameters()``
        (0 for a module without parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if not trainable_only or p.requires_grad
    )

# NOTE(review): the count shown below (69,798,144) is well under the "100M" in the model
# directory name. About 67.1M of it is two 131072x256 matrices (input embeddings plus,
# presumably, an untied lm_head since tie_word_embeddings is false). Confirm the intended size.
count_parameters(model=model)

69798144

## Re-initialize weights, cast to bfloat16, and save

In [23]:
import torch


# Re-initialize every parameter with more than one dimension (the embedding and projection
# weight matrices) with Xavier-uniform, scaled by gain = 1 / sqrt(2 * num_hidden_layers).
# 1-D parameters (the RMSNorm scales) keep their default values.
INIT_SEED = 0  # fixed so the saved checkpoint is reproducible
INIT_GAIN = (config.num_hidden_layers * 2) ** -0.5

torch.manual_seed(INIT_SEED)
for p in model.parameters():
    if p.dim() > 1:
        torch.nn.init.xavier_uniform_(p, gain=INIT_GAIN)

# The loop above also overwrote the embedding row at padding_idx. nn.Embedding creates that
# row as zeros and never gives it a gradient, so a random vector there would stay frozen
# through training. Restore it to zero.
embeddings = model.get_input_embeddings()
if embeddings.padding_idx is not None:
    with torch.no_grad():
        embeddings.weight[embeddings.padding_idx].zero_()


In [25]:
# Cast all floating-point parameters and buffers to bfloat16 (the torch_dtype recorded in the
# config) in place. Module.to returns the module itself, so its repr is displayed.
model.to(torch.bfloat16)

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(131072, 256, padding_idx=10)
    (layers): ModuleList(
      (0-3): 4 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=256, out_features=128, bias=False)
          (k_proj): Linear(in_features=256, out_features=32, bias=False)
          (v_proj): Linear(in_features=256, out_features=32, bias=False)
          (o_proj): Linear(in_features=128, out_features=256, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=256, out_features=768, bias=False)
          (up_proj): Linear(in_features=256, out_features=768, bias=False)
          (down_proj): Linear(in_features=768, out_features=256, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((256,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((256,), eps=1e-05)
      )
    

In [26]:
# Write config.json and the current in-memory weights into the same directory the config was
# read from (existing files there may be overwritten).
model.save_pretrained(save_directory='/models/mistral-100M')