Notebook for loading a Llama-based model from a given llama_variant, inspecting its layers and parameter count, and helping to define a new model variant.

For running:
    1. conda:
        - For SmolLM2-135M: open-instruct-env
        - For models defined in Absty fms-fsdp: /proj/data-eng/fsdp/train_env
        - For models defined in NEW fms-fsdp-docling: /proj/data-eng/granite_docling/conda-envs/fms-fsdp-env

    2. Before running, copy this notebook into the matching repo root:
        - `/proj/data-eng/fsdp/fms-fsdp` for Absty PPLN (conda env: /proj/data-eng/fsdp/train_env)
        - `/proj/data-eng/granite_docling/fms-fsdp-docling/` for New Repo (conda env: /proj/data-eng/granite_docling/conda-envs/fms-fsdp-env)

        Otherwise, "from fms_fsdp.utils.config_utils import get_model_config" fails with
        ModuleNotFoundError: No module named 'fms_fsdp'. For example:

        cp ./dqa_refine_llama_based_models.ipynb /proj/data-eng/fsdp/fms-fsdp/
        cd /proj/data-eng/fsdp/fms-fsdp/ && code ./dqa_refine_llama_based_models.ipynb
        
    3. Quick test (should exit without an error):
        python -c "from fms_fsdp.utils.config_utils import get_model_config"
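
The same check can be run from inside the notebook kernel. This is a minimal sketch that uses only the standard library; the helper name `has_module` is hypothetical, and it only reports whether the top-level packages can be found from the current environment and working directory.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module can be located on the current sys.path."""
    return importlib.util.find_spec(name) is not None


# fms_fsdp comes from the fms-fsdp repo, fms from foundation-model-stack.
for name in ("fms_fsdp", "fms"):
    status = "found" if has_module(name) else "MISSING (check conda env and working directory)"
    print(f"{name}: {status}")
```

If `fms_fsdp` is reported missing, the notebook is most likely not running from one of the repo directories listed in step 2.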

NOTE: 
    1. Repo where Llama class is defined: 
            /proj/data-eng/fsdp/foundation-model-stack/fms/models/llama.py
    2. Repo where to define NEW Llama-based model: 
            /proj/data-eng/fsdp/fms-fsdp/fms_fsdp/utils/config_utils.py (starcoder branch, appears to be STALE! -> consider switching to main)
    3. Conda env for showing (in BlueVela):
        SmolLM2 model: open-instruct-env
        Llama models: /proj/data-eng/fsdp/train_env
    

CONDA_ENV_PATH="/proj/data-eng/fsdp/train_env"
CODE_PATH="/proj/data-eng/fsdp/fms-fsdp"
ENV_FILE="/proj/data-eng/fsdp/env/train_v01.env"
DATA_PATH="/proj/data-eng/fsdp/data/R00"
OUTPUT_PATH="/proj/data-eng/fsdp/experiments"


IBM repos for:
    + llama.py (where LLaMAConfig() params are offered): https://github.com/foundation-model-stack/foundation-model-stack.git
    + config_utils.py (where model_variant is defined): origin  https://github.com/foundation-model-stack/fms-fsdp.git 

Parameter count of SmolLM2-135M vs llamaVer3_granite4tiktoken (Slack thread): https://ibm-research.slack.com/archives/C08STR7BYKW/p1747684015003109?thread_ts=1747402534.300819&cid=C08STR7BYKW

    SmolLM2 - SmolLM2 embedding layer + Granite4-tiktoken embedding layer = Granite4-tiktoken
    134,515,008 - 28,311,552 + 57,802,752  = 164,006,208
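
The embedding swap can be checked with plain arithmetic. This is a small sketch using only numbers quoted in this notebook (vocabulary sizes 49,152 and 100,352, hidden size 576); the variable names are illustrative.

```python
# Swap SmolLM2's embedding table for the Granite4-tiktoken one (hidden size 576).
HIDDEN = 576
smollm2_total = 134_515_008
smollm2_embd = 49_152 * HIDDEN    # 28,311,552
granite_embd = 100_352 * HIDDEN   # 57,802,752

granite_total = smollm2_total - smollm2_embd + granite_embd
print(f"{granite_total:,}")  # 164,006,208
```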

- SmolLM2's LlamaDecoderLayer and llamaVer3_granite4tiktoken's LLaMABlock have exactly the same parameter count, 3,540,096 per block, and their architectures are functionally almost identical.
- The only structural difference is the QKV projection:
    SmolLM2: separate linear layers for Q, K, and V (q_proj, k_proj, v_proj), with 576 + 192 + 192 output features.
    llamaVer3_granite4tiktoken: a single fused linear layer (qkv_fused), with 960 output features.
  The two are mathematically EQUIVALENT: the fused weight is the three projection matrices stacked into one. Fusing only changes how the projection is executed; one large matrix multiplication can be faster because it reduces memory-bandwidth overhead.
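
The equivalence can be illustrated with a toy example in pure Python. This is a sketch under the assumption that the fused weight is simply Q, K, and V stacked along the output dimension in that order; the actual row layout inside FMS's FusedQKV is not verified here, and the sizes are toy values rather than the real 576/192.

```python
import random


def matvec(w, x):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d_model, q_dim, kv_dim = 6, 6, 2  # toy sizes; the real model uses 576, 576, 192

w_q = rand_matrix(q_dim, d_model)
w_k = rand_matrix(kv_dim, d_model)
w_v = rand_matrix(kv_dim, d_model)
x = [random.uniform(-1, 1) for _ in range(d_model)]

# Fused weight: the three projection matrices stacked along the output dimension.
w_fused = w_q + w_k + w_v
fused_out = matvec(w_fused, x)

# Slicing the fused output recovers q, k, and v exactly.
q = fused_out[:q_dim]
k = fused_out[q_dim:q_dim + kv_dim]
v = fused_out[q_dim + kv_dim:]
assert q == matvec(w_q, x) and k == matvec(w_k, x) and v == matvec(w_v, x)

# Same parameter count at the real sizes: 576 * (576 + 192 + 192) = 576 * 960.
assert 576 * 576 + 2 * 576 * 192 == 576 * 960 == 552_960
```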



In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
%reload_ext autoreload
%autoreload 2
'''
conda env:
        - For models defined in Absty fms-fsdp: /proj/data-eng/fsdp/train_env
        - For models defined in NEW fms-fsdp-docling: /proj/data-eng/granite_docling/conda-envs/fms-fsdp-env
'''
from fms_fsdp.utils.config_utils import get_model_config
from fms.models.llama import LLaMA
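def print_model(model_variant: str = "llama2mod_starcoder"):
    """Build the FMS LLaMA model for `model_variant`, print its size and layers, and return it."""
    config = get_model_config(model_variant)
    model = LLaMA(config)
    # Count only the parameters that receive gradients.
    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"== Number of trainable params: {num_trainable_params:,}")
    print(model)
    return model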

#### 1. NEW models defined in fms-fsdp-docling:
- conda: /proj/data-eng/granite_docling/conda-envs/fms-fsdp-env
- script: /proj/data-eng/granite_docling/fms-fsdp-docling/fms_fsdp/utils/config_utils.py

In [2]:
_ = print_model("llama165m_granite4tiktoken")

== Number of trainable params: 164,006,208
LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(100352, 576)
    (head): Linear(in_features=576, out_features=100352, bias=False)
  )
  (layers): ModuleList(
    (0-29): 30 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (in_proj): FusedQKV(
          (qkv_fused): Linear(in_features=576, out_features=960, bias=False)
        )
        (dense): Linear(in_features=576, out_features=576, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (wg1_fused): Linear(in_features=576, out_features=3072, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=1536, out_features=576, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)


In [11]:
'''
1.4B Llama model used for the ablation study
NOTE: also uses FusedQKV, but an unfused MLP (separate w1 and wg)
'''
_ = print_model("llama2mod_starcoder")


== model_variant: llama2mod_starcoder
== Number of trainable params: 1,410,959,360
LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(49152, 2048)
    (head): Linear(in_features=2048, out_features=49152, bias=False)
  )
  (layers): ModuleList(
    (0-23): 24 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (in_proj): FusedQKV(
          (qkv_fused): Linear(in_features=2048, out_features=6144, bias=False)
        )
        (dense): Linear(in_features=2048, out_features=2048, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (w1): Linear(in_features=2048, out_features=5472, bias=False)
        (wg): Linear(in_features=2048, out_features=5472, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=5472, out_features=2048, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)
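
The 1.4B total can be reproduced from the shapes printed above. This is a sketch under two assumptions: each LayerNormParameterized holds a single weight vector of size 2048, and the embedding and head are counted separately (untied) for this variant, which is the case in which the numbers match.

```python
# Shapes taken from the printed llama2mod_starcoder module tree.
dim, ffn, layers, vocab = 2048, 5472, 24, 49_152

attn = dim * 6144 + dim * dim   # qkv_fused + dense
mlp = 3 * dim * ffn             # w1, wg, w2
block = attn + mlp + 2 * dim    # plus ln and ff_ln weight vectors (assumed size dim)

# Untied embedding + head, all blocks, final dec_norm.
total = 2 * vocab * dim + layers * block + dim
print(f"{block:,}")  # 50,401,280
print(f"{total:,}")  # 1,410,959,360
```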


In [7]:
''' Original HuggingFaceTB/SmolLM2-135M
Need to switch to conda env: open-instruct-env
Ref: 
    github: https://github.com/huggingface/smollm
    blog: https://huggingface.co/blog/smollm
            https://huggingface.co/HuggingFaceTB/SmolLM2-135M
    model + training params:
        https://github.com/huggingface/smollm/blob/main/text/pretraining/smollm2/config_smollm2_135M.yaml

PP (Sec.4.1): 
    1. SmolLM2-1.7B uses the Llama2 architecture.
    2. Tokenizer: from SmolLM1, 49,152 tokens, trained on a mixture of 
        70% FineWeb-edu, 
        15% Cosmopedia-v2, 
        8% OpenWebMath, 
        5% StarCoder Data and 
        2% StackOverflow.
PP (Sec.6): 
    - Same architecture as SmolLM2-1.7B but uses GQA
    - Trained using the WSD scheduler with 20% decay and a learning rate of 3.0 × 10^-3.
    - SFT: a filtered version of SmolTalk3, removing complex IF (instruction following, e.g., function calling)


1. Embeddings (model.embed_tokens) (49152, 576): 28,311,552 parameters
2. Decoder Layers (model.layers): There are 30 LlamaDecoderLayers.
    Each LlamaDecoderLayer consists of:
        2.1 self_attn (Q, K, V, O projections): (576 * 576) + (576 * 192) + (576 * 192) + (576 * 576) = 331,776 + 110,592 + 110,592 + 331,776 = 884,736 parameters.
        2.2 mlp (gate, up, down projections): (576 * 1536) + (576 * 1536) + (1536 * 576) = 884,736 + 884,736 + 884,736 = 2,654,208 parameters.
        2.3 input_layernorm: 576 parameters.
        2.4 post_attention_layernorm: 576 parameters.
    Total per LlamaDecoderLayer: 884,736 + 2,654,208 + 576 + 576 = 3,540,096 parameters.
    Total for 30 layers: 30 * 3,540,096 = 106,202,880 parameters.
3. Final Normalization (model.norm): 576 parameters.
4. Language Model Head (lm_head): 0 parameters, because tie_word_embeddings: true means its weights are shared with the embedding layer and are not counted separately.

Sum: 28,311,552 (embeddings) + 106,202,880 (decoder layers) + 576 (final norm) + 0 (lm_head) = 134,515,008 parameters.


'''

import torch
from transformers import AutoModelForCausalLM
from transformers import AutoConfig

def inspect_checkpoint(checkpoint:str = "HuggingFaceTB/SmolLM2-135M"):

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="cpu",  # load on CPU (lowercase device string)
        torch_dtype=torch.bfloat16,
    )
    # Print every RoPE-related config attribute (e.g. rope_theta, rope_scaling).
    config = AutoConfig.from_pretrained(checkpoint)
    for attr in dir(config):
        if "rope" in attr.lower():
            print(f"== {attr}: {getattr(config, attr)}")

    # tie word embd:
    print(f"== Tied word embeddings: {config.tie_word_embeddings}")
    input_embeddings = model.get_input_embeddings()
    output_embeddings = model.lm_head
    # Check if weights are the same object
    tied = input_embeddings.weight.data_ptr() == output_embeddings.weight.data_ptr()
    print(f"== Tied word embeddings via embd weight: {tied}")
    

    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"== Number of trainable params of {checkpoint}: {num_trainable_params:,}")
    print(f"== Layers:\n{model}")

inspect_checkpoint()

== rope_interleaved: False
== rope_scaling: None
== rope_theta: 100000
== Tied word embeddings: True
== Tied word embeddings via embd weight: True
== Number of trainable params of HuggingFaceTB/SmolLM2-135M: 134,515,008
== Layers:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): Si
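
The per-layer breakdown in the docstring of the cell above can be turned into a small reusable counter. This is a sketch, not part of either library: the function name and arguments are hypothetical, and it assumes bias-free linear layers, a Q projection as wide as the hidden size, and norms with a single weight vector each.

```python
def llama_param_count(vocab, dim, n_layers, kv_dim, ffn_dim, tie_heads):
    """Count parameters of a bias-free Llama-style decoder.

    kv_dim is the output width of each of the K and V projections
    (equal to dim for plain multi-head attention, smaller under GQA).
    """
    attn = dim * dim + 2 * dim * kv_dim + dim * dim  # Q, K, V, O
    mlp = 3 * dim * ffn_dim                          # gate, up, down
    norms = 2 * dim                                  # two norm weight vectors per block
    block = attn + mlp + norms
    embed = vocab * dim
    head = 0 if tie_heads else vocab * dim           # a tied head adds no parameters
    return embed + n_layers * block + dim + head     # dim = final norm


smollm2 = llama_param_count(49_152, 576, 30, 192, 1536, tie_heads=True)
print(f"{smollm2:,}")  # 134,515,008
```

Passing `vocab=100_352` with the other arguments unchanged gives the Granite4-tiktoken count of 164,006,208.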

In [8]:
inspect_checkpoint("/proj/data-eng/granite_docling/experiments/llama165m_granite4tiktoken-2T-8node/hf_model/step_1650000_ckp")

== rope_scaling: None
== rope_theta: 10000.0
== Tied word embeddings: False
== Tied word embeddings via embd weight: False
== Number of trainable params of /proj/data-eng/granite_docling/experiments/llama165m_granite4tiktoken-2T-8node/hf_model/step_1650000_ckp: 221,808,960
== Layers:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(100352, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=153

In [9]:
inspect_checkpoint("/proj/data-eng/granite_docling/experiments/00-bkup/WRONG-CONFIG-llama165m_granite4tiktoken-2T-pdbs2-8node/hf_model/step_1900000_ckp")

== rope_scaling: None
== rope_theta: 10000.0
== Tied word embeddings: False
== Tied word embeddings via embd weight: False
== Number of trainable params of /proj/data-eng/granite_docling/experiments/00-bkup/WRONG-CONFIG-llama165m_granite4tiktoken-2T-pdbs2-8node/hf_model/step_1900000_ckp: 221,808,960
== Layers:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(100352, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_pr
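
Both converted checkpoints report 221,808,960 parameters with untied embeddings, while the FMS config for the same variant reports 164,006,208. The gap is exactly one extra (100352, 576) matrix, which is consistent with lm_head being stored separately from the embedding. A quick check, using only numbers from this notebook:

```python
embd = 100_352 * 576        # one embedding-sized matrix: 57,802,752
tied_total = 164_006_208    # FMS model, head tied to the embedding
untied_total = tied_total + embd

print(f"{untied_total:,}")  # 221,808,960
```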

In [3]:
print(f"params in Embd layer of SmolLM2: {49152* 576:,}")
print(f"params in Embd layer of llamaVer3_granite4tiktoken: {100352* 576:,}")   

params in Embd layer of SmolLM2: 28,311,552
params in Embd layer of llamaVer3_granite4tiktoken: 57,802,752


In [12]:
%reload_ext autoreload
%autoreload 2
'''
tie_heads=False: 221,808,960
tie_heads=True: 164,006,208

Layer description                                        Tensor shape     Parameters
    1. shared.emb (WordEmbedding)                        (100352, 576)    57,802,752
    2. LLaMABlock (one instance)                         -                 3,540,096
    3. dec_norm (LayerNormParameterized)                 (576,)                  576
    4. shared.head (Linear), weights tied to shared.emb  (576, 100352)             0

TOTAL = 57,802,752 (shared.emb) + 106,202,880 (LLaMABlocks) + 576 (dec_norm) + 0 (shared.head) = 164,006,208 parameters.

Detail on LLaMABlock (one instance):	3,540,096

Each LLaMABlock consists of:
    1. ln (LayerNormParameterized): hidden_size = 576 parameters.
    2. ff_ln (LayerNormParameterized): hidden_size = 576 parameters.
    3. attn (MultiHeadAttention):
        3.1 in_proj.qkv_fused: in_features × out_features = 576×960 = 552,960 parameters.
        3.2 dense: in_features × out_features = 576×576 = 331,776 parameters.
        -> Total attn parameters: 552,960+331,776 = 884,736 parameters.
    4. ff_sub_layer (GatedLinearUnit):
        4.1 w1: in_features × out_features = 576×1536 = 884,736 parameters.
        4.2 wg: in_features × out_features = 576×1536 = 884,736 parameters.
        4.3 w2: in_features × out_features = 1536×576 = 884,736 parameters.
        -> Total ff_sub_layer parameters: 884,736+884,736+884,736 = 2,654,208 parameters.

    -> Total parameters per LLaMABlock: 576(ln)+576(ff_ln)+884,736(attn)+2,654,208(ff_sub_layer)
        = 3,540,096 
    -> 30 layers = 30×3,540,096 = 106,202,880

Final summation: see TOTAL above (164,006,208 parameters).

Compare to SmolLM2-135M (from the earlier cell):

1. Embeddings (model.embed_tokens) (49152, 576): 28,311,552 parameters
2. Decoder Layers (model.layers): There are 30 LlamaDecoderLayers.
    Each LlamaDecoderLayer consists of:
        2.1 self_attn (Q, K, V, O projections): (576 * 576) + (576 * 192) + (576 * 192) + (576 * 576) = 331,776 + 110,592 + 110,592 + 331,776 = 884,736 parameters.
        2.2 mlp (gate, up, down projections): (576 * 1536) + (576 * 1536) + (1536 * 576) = 884,736 + 884,736 + 884,736 = 2,654,208 parameters.
        2.3 input_layernorm: 576 parameters.
        2.4 post_attention_layernorm: 576 parameters.
    Total per LlamaDecoderLayer: 884,736 + 2,654,208 + 576 + 576 = 3,540,096 parameters.
    Total for 30 layers: 30 * 3,540,096 = 106,202,880 parameters.
3. Final Normalization (model.norm): 576 parameters.
4. Language Model Head (lm_head): 0 parameters, because tie_word_embeddings: true means its weights are shared with the embedding layer and are not counted separately.

Sum: 28,311,552 (embeddings) + 106,202,880 (decoder layers) + 576 (final norm) + 0 (lm_head) = 134,515,008 parameters.


'''
llama165m_granite4tiktoken = print_model("llama165m_granite4tiktoken")




== model_variant: llama165m_granite4tiktoken
== Number of trainable params: 164,006,208
LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(100352, 576)
    (head): Linear(in_features=576, out_features=100352, bias=False)
  )
  (layers): ModuleList(
    (0-29): 30 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (in_proj): FusedQKV(
          (qkv_fused): Linear(in_features=576, out_features=960, bias=False)
        )
        (dense): Linear(in_features=576, out_features=576, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (w1): Linear(in_features=576, out_features=1536, bias=False)
        (wg): Linear(in_features=576, out_features=1536, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=1536, out_features=576, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)


In [13]:

llama2mod_starcoder135M_context8K_doclingV01 = print_model("llama2mod_starcoder135M_context8K_doclingV01")



== model_variant: llama2mod_starcoder135M_context8K_doclingV01
== Number of trainable params: 136,816,000
LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(49152, 640)
    (head): Linear(in_features=640, out_features=49152, bias=False)
  )
  (layers): ModuleList(
    (0-14): 15 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (in_proj): FusedQKV(
          (qkv_fused): Linear(in_features=640, out_features=1920, bias=False)
        )
        (dense): Linear(in_features=640, out_features=640, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (w1): Linear(in_features=640, out_features=1712, bias=False)
        (wg): Linear(in_features=640, out_features=1712, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=1712, out_features=640, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)


In [None]:

llama2mod_starcoder135M_context2K_doclingV02 = print_model("llama2mod_starcoder135M_context2K_doclingV02")



== model_variant: llama2mod_starcoder135M_context2K_doclingV02
== Number of trainable params: 136,816,000
LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(49152, 640)
    (head): Linear(in_features=640, out_features=49152, bias=False)
  )
  (layers): ModuleList(
    (0-14): 15 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (in_proj): FusedQKV(
          (qkv_fused): Linear(in_features=640, out_features=1920, bias=False)
        )
        (dense): Linear(in_features=640, out_features=640, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (w1): Linear(in_features=640, out_features=1712, bias=False)
        (wg): Linear(in_features=640, out_features=1712, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=1712, out_features=640, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)


In [16]:

llama2mod_granite4tiktoken202M_context8K_doclingV04 = print_model("llama2mod_granite4tiktoken202M_context8K_doclingV04")



== model_variant: llama2mod_granite4tiktoken202M_context8K_doclingV04
== Number of trainable params: 202,352,000
LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(100352, 640)
    (head): Linear(in_features=640, out_features=100352, bias=False)
  )
  (layers): ModuleList(
    (0-14): 15 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (in_proj): FusedQKV(
          (qkv_fused): Linear(in_features=640, out_features=1920, bias=False)
        )
        (dense): Linear(in_features=640, out_features=640, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (w1): Linear(in_features=640, out_features=1712, bias=False)
        (wg): Linear(in_features=640, out_features=1712, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=1712, out_features=640, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)


In [17]:

llama2mod_granite4tiktoken138M_context8K_tieheads_doclingV05 = print_model("llama2mod_granite4tiktoken138M_context8K_tieheads_doclingV05")



== model_variant: llama2mod_granite4tiktoken138M_context8K_tieheads_doclingV05
== Number of trainable params: 138,126,720
LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(100352, 640)
    (head): Linear(in_features=640, out_features=100352, bias=False)
  )
  (layers): ModuleList(
    (0-14): 15 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (in_proj): FusedQKV(
          (qkv_fused): Linear(in_features=640, out_features=1920, bias=False)
        )
        (dense): Linear(in_features=640, out_features=640, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (w1): Linear(in_features=640, out_features=1712, bias=False)
        (wg): Linear(in_features=640, out_features=1712, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=1712, out_features=640, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)
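
doclingV04 and doclingV05 print the same module tree but differ by 64,225,280 parameters. That is exactly one (100352, 640) matrix, which is consistent with V05 tying the head to the embedding (as its name suggests) while V04 counts both. A quick check:

```python
embd = 100_352 * 640       # one embedding-sized matrix: 64,225,280
v04_untied = 202_352_000   # llama2mod_granite4tiktoken202M_context8K_doclingV04
v05_tied = 138_126_720     # llama2mod_granite4tiktoken138M_context8K_tieheads_doclingV05

print(f"{v04_untied - v05_tied:,}")  # 64,225,280
assert v04_untied - v05_tied == embd
```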


In [19]:

llama2mod_granite4tiktoken165M_context8K_tieheads_embdim576nh9nl30_doclingV06 = print_model("llama2mod_granite4tiktoken165M_context8K_tieheads_embdim576nh9nl30_doclingV06")



== model_variant: llama2mod_granite4tiktoken165M_context8K_tieheads_embdim576nh9nl30_doclingV06
== Number of trainable params: 177,277,248
LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(100352, 576)
    (head): Linear(in_features=576, out_features=100352, bias=False)
  )
  (layers): ModuleList(
    (0-29): 30 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (in_proj): FusedQKV(
          (qkv_fused): Linear(in_features=576, out_features=1728, bias=False)
        )
        (dense): Linear(in_features=576, out_features=576, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (w1): Linear(in_features=576, out_features=1536, bias=False)
        (wg): Linear(in_features=576, out_features=1536, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=1536, out_features=576, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)
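
doclingV06 has the same hidden size, depth, and MLP as llama165m_granite4tiktoken, but its qkv_fused layer has 1728 = 3 × 576 output features instead of 960 = 576 + 192 + 192. In other words, K and V are full width rather than reduced as under GQA. A sketch of how that accounts for the larger total, using only the printed shapes:

```python
n_layers, dim = 30, 576
gqa_qkv_out = 960    # 576 (Q) + 192 (K) + 192 (V), as in llama165m_granite4tiktoken
mha_qkv_out = 1728   # 3 * 576, as in the doclingV06 variant

# Extra parameters from the wider fused QKV projection across all layers.
extra = n_layers * dim * (mha_qkv_out - gqa_qkv_out)
print(f"{extra:,}")                # 13,271,040
print(f"{164_006_208 + extra:,}")  # 177,277,248
```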


In [20]:

llama165m_granite4tiktoken = print_model("llama165m_granite4tiktoken")



== model_variant: llama165m_granite4tiktoken
== Number of trainable params: 164,006,208
LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(100352, 576)
    (head): Linear(in_features=576, out_features=100352, bias=False)
  )
  (layers): ModuleList(
    (0-29): 30 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (in_proj): FusedQKV(
          (qkv_fused): Linear(in_features=576, out_features=960, bias=False)
        )
        (dense): Linear(in_features=576, out_features=576, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (w1): Linear(in_features=576, out_features=1536, bias=False)
        (wg): Linear(in_features=576, out_features=1536, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=1536, out_features=576, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)
