In [1]:
import transformers
import torch
import numpy as np
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from transformers import AutoTokenizer, AutoProcessor
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import *

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

model = LlamaForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

print("Model and tokenizer loaded successfully!")

Model and tokenizer loaded successfully!


In [2]:
print(model.config)

LlamaConfig {
  "_name_or_path": "meta-llama/Llama-3.2-1B-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": [
    128001,
    128008,
    128009
  ],
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.44.1",
  "use_cache": true,
  "vocab_size": 128256
}



In [3]:
print(model)
# We want to use our own LlamaRotaryEmbedding
# We want to use our own LlamaSdpaAttention
# We want to use our own LlamaDecoderLayer

# We want to use our own LlamaRotaryEmbedding
# We want to use our own LlamaRMSNorm

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

In [4]:
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

generate_ids = model.generate(inputs.input_ids, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


"Hey, are you conscious? Can you talk to me? I'm trying to understand your behavior. I have some questions for you.\n\n**Update"

In [5]:
class CondensedLlamaConfig(LlamaConfig):
    def __init__(
        self,
        #position_embeddings=None,
        **kwargs
    ):
        super().__init__(**kwargs)

    # self.position_embeddings = position_embeddings
    # Override the `to_dict` method to include the new parameters
    def to_dict(self):
        base_dict = super().to_dict()
        config_dict = {
            "position_embeddings": self.position_embeddings,
        }
        base_dict.update(config_dict)
        return base_dict

"""
Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L524
"""
class CondensedLlamaAttention(LlamaAttention):
    def __init__(self, config: CondensedLlamaConfig, layer_idx: Optional[int] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        super().__init__(config, layer_idx)
        self.position_embeddings = position_embeddings

        # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            #cos, sin = position_embeddings
            cos, sin = self.position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value



class CondensedLlamaDecoderLayer(LlamaDecoderLayer):

    def __init__(self, config: CondensedLlamaConfig, layer_idx: Optional[int] = None):  # Add layer_idx as an argument
        super().__init__(config, layer_idx)

        #self.position_embeddings = position_embeddings
        DIMENSION = 1
        SEQ_LEN = config.max_position_embeddings
        position_embeddings = np.ones((DIMENSION, SEQ_LEN))

        # Replace self_attn with your new attention module
        self.self_attn = CondensedLlamaAttention(config, layer_idx, position_embeddings)
    
    #TODO
    def rotary_position_embedding(max_seq_len, dim):
        # Calculate the angle rates based on dimension indices.
        angle_rates = 1 / torch.pow(10000, torch.arange(0, dim, 2).float() / dim)
        # Calculate the angles for each position for half of the dimensions (sine and cosine)
        angles = (torch.arange(max_seq_len).unsqueeze(1) * angle_rates.unsqueeze(0))
        # Cosines and sines of the angles to get the RoPE for each position
        position_encodings = torch.stack((angles.cos(), angles.sin()), dim=2).flatten(1)
        return position_encodings


class CondensedLlamaModel(LlamaModel):
    def __init__(self, config: CondensedLlamaConfig):
        super().__init__(config)

        self.layers = nn.ModuleList([CondensedLlamaDecoderLayer(config,layer_idx=None) for _ in range(config.num_hidden_layers)])
        # Initialize weights and apply final processing
        self.post_init()


In [6]:

class CondensedLlamaModel(LlamaModel):
    def __init__(self, config: CondensedLlamaConfig):
        super().__init__(config)

        self.layers = nn.ModuleList([CondensedLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Initialize weights and apply final processing
        self.post_init()

class CondensedLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config: CondensedLlamaConfig):
        super().__init__(config)

        self.model = CondensedLlamaModel(config)
        # Initialize weights and apply final processing
        self.post_init()

    def save_checkpoint(self, dir):
        # to bypass the code line 2291 in transformers.trainer
        pass


In [7]:
model_1 = CondensedLlamaModel.from_pretrained(MODEL_NAME)
model_1

Instantiating LlamaSdpaAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.
Instantiating CondensedLlamaAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.


CondensedLlamaModel(
  (embed_tokens): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-15): 16 x CondensedLlamaDecoderLayer(
      (self_attn): CondensedLlamaAttention(
        (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2048, out_features=512, bias=False)
        (v_proj): Linear(in_features=2048, out_features=512, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
    )
  )
  (norm): LlamaRMSNorm((2048,), eps=1e-05)
  (rotary_emb)

In [8]:
print(model_1.config)

LlamaConfig {
  "_name_or_path": "meta-llama/Llama-3.2-1B-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": [
    128001,
    128008,
    128009
  ],
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.44.1",
  "use_cache": true,
  "vocab_size": 128256
}



In [13]:
model_1.config.to_json_file("config.json")