In [14]:
import pytorch_lightning as pl
from transformers import PreTrainedTokenizerFast, PretrainedConfig, AutoTokenizer, AutoConfig, AutoModelForCausalLM
from typing import Type

class LanguageModelModule(pl.LightningModule):
    def __init__(
        self,
        learning_rate: float,
        model_config: PretrainedConfig,
        tokenizer: PreTrainedTokenizerFast,
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.model = AutoModelForCausalLM.from_config(model_config)
        self.tokenizer = tokenizer

    def forward(self, input_ids, attention_mask=None, labels=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

model_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B")
model_cls = AutoModelForCausalLM
learning_rate = 5e-5
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

lm_module = LanguageModelModule(
    learning_rate=learning_rate,
    model_config=model_config,
    tokenizer=tokenizer,
)

In [17]:
tokenizer

Qwen2TokenizerFast(name_or_path='Qwen/Qwen2.5-0.5B', vocab_size=151643, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|endoftext|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>', '<|box_end|>', '<|quad_start|>', '<|quad_end|>', '<|vision_start|>', '<|vision_end|>', '<|vision_pad|>', '<|image_pad|>', '<|video_pad|>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151646: AddedToken("<|object_ref_start|>", rstrip=False, lstrip=False, single_word=False, nor

In [12]:
from transformers import Qwen2Config

hidden_size = 64

config_2 = Qwen2Config(
    num_hidden_layers=3,
    hidden_size=hidden_size,
    intermediate_size=hidden_size * 4,  # MLP hidden dim, following GPT-2 approach x4
    num_attention_heads=8,
    num_key_value_heads=2, # if equal to the num_attention heads, the MHA if 1 then MQA, else GQA
    vocab_size=1024,
    max_position_embeddings=512,  # Maximum sequence length
    attention_probs_dropout_prob=0.1,
)

In [6]:
AutoModelForCausalLM.from_config(model_config)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((

In [13]:
AutoModelForCausalLM.from_config(config_2)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(1024, 64)
    (layers): ModuleList(
      (0-2): 3 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=64, out_features=64, bias=True)
          (k_proj): Linear(in_features=64, out_features=16, bias=True)
          (v_proj): Linear(in_features=64, out_features=16, bias=True)
          (o_proj): Linear(in_features=64, out_features=64, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=64, out_features=256, bias=False)
          (up_proj): Linear(in_features=64, out_features=256, bias=False)
          (down_proj): Linear(in_features=256, out_features=64, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((64,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((64,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((64,), eps=1e-06)
    

In [10]:
model_config

Qwen2Config {
  "_name_or_path": "Qwen/Qwen2.5-0.5B",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "hidden_act": "silu",
  "hidden_size": 896,
  "initializer_range": 0.02,
  "intermediate_size": 4864,
  "max_position_embeddings": 32768,
  "max_window_layers": 24,
  "model_type": "qwen2",
  "num_attention_heads": 14,
  "num_hidden_layers": 24,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3",
  "use_cache": true,
  "use_mrope": false,
  "use_sliding_window": false,
  "vocab_size": 151936
}

In [11]:
config_2

Qwen2Config {
  "attention_dropout": 0.0,
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "silu",
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_position_embeddings": 512,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 8,
  "num_hidden_layers": 3,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "transformers_version": "4.46.3",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 1024
}