In [2]:
import os
import warnings

import pandas as pd
import numpy as np
import torch.utils.checkpoint
from torch import nn
from dotenv import load_dotenv

load_dotenv()

from transformers import AutoTokenizer
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import *

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

# Reference (unmodified) checkpoint; later cells copy weights out of it into the custom model.
model = LlamaForCausalLM.from_pretrained(MODEL_NAME)

# Pad on the left for decoder-only generation, and reuse the EOS token as the padding token
# (presumably because the checkpoint's tokenizer defines no pad token — confirm).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

print("Model and tokenizer loaded successfully!")

Model and tokenizer loaded successfully!


In [3]:
# Show the module tree via the notebook's rich display (last expression of the cell).
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

In [4]:
class CustomLlamaConfig(LlamaConfig):
    """
    ``LlamaConfig`` extended with one additional field, ``rate``.

    All other keyword arguments are forwarded untouched to ``LlamaConfig``.

    Args:
        rate: custom hyper-parameter stored as ``self.rate`` (default ``None``).
            NOTE(review): it is not read anywhere in the visible code yet.
    """

    def __init__(self, rate=None, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate

    def to_dict(self):
        """Serialize exactly like ``LlamaConfig.to_dict`` while guaranteeing ``rate`` is present."""
        return {**super().to_dict(), "rate": self.rate}

In [5]:
def save_array(X, file_name, directory='./custom-llama'):
    """
    Save a tensor/array to ``<directory>/<file_name>.npy`` in NumPy's binary ``.npy`` format.

    The file can be read back with ``np.load``.

    Args:
        X: a ``torch.Tensor`` (any device; may require grad; bfloat16 is upcast to float32,
           which NumPy cannot represent natively), or anything ``np.asarray`` accepts.
           In this notebook it is called with attention query/key tensors.
        file_name: base name of the output file, without the ``.npy`` extension.
        directory: folder to write into; created if missing. Defaults to ``./custom-llama``,
           the location this helper has always written to.
    Returns:
        str: path of the written file.
    """
    if isinstance(X, torch.Tensor):
        # Tensor.numpy() refuses tensors that require grad or live on an accelerator.
        X = X.detach().cpu()
        if X.dtype == torch.bfloat16:
            X = X.float()
        numpy_array = X.numpy()
    else:
        numpy_array = np.asarray(X)

    # np.save does not create missing parent directories.
    os.makedirs(directory, exist_ok=True)
    file = os.path.join(directory, '%s.npy' % (file_name,))
    np.save(file, numpy_array)
    return file
    

In [6]:
"""
Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L85
"""
class CustomLlamaRotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE) table generator, copied from the Hugging Face Llama
    implementation so it can be customised in this notebook.

    ``forward(x, position_ids)`` returns ``(cos, sin)`` for the given positions; ``x`` only
    supplies the device and dtype. The inverse frequencies (and ``attention_scaling``) come
    from ``ROPE_INIT_FUNCTIONS[self.rope_type]``.
    """
    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[CustomLlamaConfig] = None,
    ):
        """
        Args:
            dim, max_position_embeddings, base, scaling_factor, rope_type: legacy arguments,
                only used when ``config`` is None; they are packed into ``self.rope_kwargs``
                and forwarded to the RoPE init function.
            device: device handed to the RoPE init function when building ``inv_freq``.
            config: model config. When given, the RoPE type is read from
                ``config.rope_scaling`` ("rope_type", falling back to the legacy "type" key),
                or "default" when ``rope_scaling`` is None; the cached length starts at
                ``config.max_position_embeddings``.
        """
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.45"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # Look up the frequency initialiser for the selected RoPE variant.
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        
        # inverse frequency llama and attention factor
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        # Non-persistent: the buffer is rebuilt at construction time, not stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Keep a handle on the initial frequencies so _dynamic_frequency_update can restore them.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)

        Only called from `forward` when `self.rope_type` contains "dynamic".
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Compute the RoPE cos/sin tables for ``position_ids`` (no gradients are tracked).

        Args:
            x: any tensor; only its ``device`` and ``dtype`` are used.
            position_ids: 2-D tensor of positions indexed as ``[batch, seq_len]``.
        Returns:
            ``(cos, sin)``, each ``[batch, seq_len, 2 * len(inv_freq)]``, multiplied by
            ``self.attention_scaling`` and cast to ``x.dtype``.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        # inv_freq -> [batch, n_freq, 1]; position_ids -> [batch, 1, seq_len]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        # autocast is disabled with the "cpu" device type when running on MPS.
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            # Outer product of positions and frequencies -> [batch, seq_len, n_freq]
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate along the last axis to match the rotate-half layout -> [batch, seq_len, 2 * n_freq]
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


In [7]:
"""
Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L524
"""
class CustomLlamaAttention(LlamaAttention):
    """
    Llama self-attention (SDPA path) whose RoPE step goes through ``apply_custom_rotary_pos_emb``.

    For inspection, ``forward`` dumps the query/key tensors just before and just after RoPE
    with ``save_array`` under the fixed names ``query_states``, ``key_states``,
    ``query_states_rope_applied`` and ``key_states_rope_applied``. Every call (every layer,
    every generation step) overwrites the same files, so they hold the tensors from the most
    recent call. Set ``save_intermediates = False`` on the class or on an instance to skip
    the dumps.
    """

    # Whether forward() writes the pre/post-RoPE query/key dumps. True keeps the original
    # always-dump behaviour.
    save_intermediates = True

    def __init__(self, config: CustomLlamaConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Use custom rotary embedding; forward() only consults it when the caller does not
        # pass precomputed `position_embeddings`.
        self.rotary_emb = CustomLlamaRotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Self-attention over ``hidden_states`` via ``torch.nn.functional.scaled_dot_product_attention``.

        Args:
            hidden_states: input indexed as ``[batch, seq_len, hidden]``.
            attention_mask: optional mask; sliced to the current key length before use.
            position_ids: positions, used only when ``position_embeddings`` is None.
            past_key_value: optional cache, updated with this call's keys/values.
            output_attentions: if True, return the parent ``LlamaAttention.forward`` result
                instead. That path skips the dumps below and presumably uses the stock RoPE
                rather than ``apply_custom_rotary_pos_emb`` — confirm against transformers.
            use_cache: only forwarded to the parent on the ``output_attentions`` path.
            cache_position: passed to the cache update (needed for the static cache).
            position_embeddings: precomputed ``(cos, sin)``.
        Returns:
            ``(attn_output, None, past_key_value)``; attention weights are never returned on
            the SDPA path.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split into heads and move the head axis in front of the sequence axis.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        # Shapes at this point (from the view/transpose above):
        #   query_states: [batch_size, num_heads,           seq_len, head_dim]
        #   key_states:   [batch_size, num_key_value_heads, seq_len, head_dim]
        #   cos, sin:     [batch_size, seq_len, head_dim]; apply_custom_rotary_pos_emb's
        #                 default unsqueeze_dim=1 broadcasts them over the head axis, so the
        #                 same rotation is applied to every head.
        # Example observed for a 13-token prompt on Llama-3.2-1B:
        #   query [1, 32, 13, 64], key [1, 8, 13, 64], cos/sin [1, 13, 64].
        if self.save_intermediates:
            save_array(query_states, 'query_states')
            save_array(key_states, 'key_states')
        query_states, key_states = apply_custom_rotary_pos_emb(query_states, key_states, cos, sin)
        if self.save_intermediates:
            save_array(query_states, 'query_states_rope_applied')
            save_array(key_states, 'key_states_rope_applied')

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Grouped-query attention: expand the key/value heads to match the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        # Back to [batch, seq_len, heads * head_dim] before the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
    
def custom_rotate_half(x):
    """
    Rotate half of the last axis: ``[x1, x2] -> [-x2, x1]``.

    ``x1`` is the first ``x.shape[-1] // 2`` entries and ``x2`` the remainder, so for an
    odd size ``x2`` is one element longer.
    """
    width = x.shape[-1]
    half = width // 2
    first_half, second_half = x.split([half, width - half], dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)

def apply_custom_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Rotate the query and key tensors with rotary position embeddings.

    For each ``t`` in ``(q, k)`` this returns ``t * cos + custom_rotate_half(t) * sin``,
    after inserting a singleton axis into ``cos`` and ``sin`` at ``unsqueeze_dim`` so they
    broadcast against ``t``.

    Args:
        q (`torch.Tensor`): query tensor.
        k (`torch.Tensor`): key tensor.
        cos (`torch.Tensor`): cosine table, e.g. ``[batch_size, seq_len, head_dim]``.
        sin (`torch.Tensor`): sine table, same shape as ``cos``.
        position_ids (`torch.Tensor`, *optional*): deprecated and ignored.
        unsqueeze_dim (`int`, *optional*, defaults to 1): where to insert the broadcast axis.
            Use 1 when q/k are ``[batch_size, heads, seq_len, head_dim]`` and 2 when they are
            ``[batch_size, seq_len, heads, head_dim]``.
    Returns:
        tuple(`torch.Tensor`, `torch.Tensor`): the rotated ``(q, k)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (custom_rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


class CustomLlamaDecoderLayer(LlamaDecoderLayer):
    """``LlamaDecoderLayer`` whose self-attention module is a ``CustomLlamaAttention``."""

    def __init__(self, config: CustomLlamaConfig, layer_idx: int):
        super(CustomLlamaDecoderLayer, self).__init__(config, layer_idx)
        # Whatever attention module the parent constructor installed is replaced here.
        self.self_attn = CustomLlamaAttention(config=config, layer_idx=layer_idx)

In [8]:
class CustomLlamaModel(LlamaModel):
    """
    ``LlamaModel`` whose decoder stack and model-level rotary embedding use the custom classes.

    Args:
        config: configuration; ``config.num_hidden_layers`` custom decoder layers are built.
    """

    def __init__(self, config: CustomLlamaConfig):
        super().__init__(config)

        # Rebuild the decoder stack so every layer carries a CustomLlamaAttention.
        custom_layers = [
            CustomLlamaDecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(custom_layers)

        # Model-level rotary embedding; presumably LlamaModel.forward uses it to compute the
        # (cos, sin) handed to each layer — confirm against the installed transformers.
        self.rotary_emb = CustomLlamaRotaryEmbedding(config=config)

        # Re-run final processing now that sub-modules were replaced.
        self.post_init()

"""
Top level model we are using
"""
class CustomLlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper whose backbone is a ``CustomLlamaModel``."""

    def __init__(self, config: CustomLlamaConfig):
        super().__init__(config)
        # Swap the backbone built by the parent for the custom one.
        self.model = CustomLlamaModel(config)
        # Re-run final processing because the backbone changed after the parent's post_init.
        self.post_init()

    def save_checkpoint(self, dir):
        """Intentional no-op, meant to bypass the checkpoint call in ``transformers.trainer`` (line 2291)."""
        return None

In [9]:
# Start from the pretrained checkpoint's hyper-parameters (hidden size, layer count, GQA heads,
# vocab, rope_scaling, ...) so the custom config describes the same architecture as MODEL_NAME.
# A bare CustomLlamaConfig(rate=1) would silently use LlamaConfig's defaults (4096 hidden,
# 32 layers, 32000 vocab), which do not match the 1B checkpoint loaded above.
custom_lamma_config = CustomLlamaConfig(rate=1, rope_type='default', **model.config.to_dict())

In [10]:
custom_model = CustomLlamaForCausalLM.from_pretrained(MODEL_NAME)
# Attach the custom field to the config the checkpoint was loaded with instead of replacing
# `custom_model.config` wholesale: the loaded config describes the weights just loaded, whereas
# a separately constructed config may not (a bare CustomLlamaConfig uses LlamaConfig defaults).
custom_model.config.rate = custom_lamma_config.rate

In [None]:
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

# Pass the attention mask explicitly: pad_token == eos_token here, so generate() cannot
# reliably infer the mask from input_ids alone. pad_token_id avoids the fallback warning.
generate_ids = custom_model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=30,
    pad_token_id=tokenizer.pad_token_id,
)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

# Transferring Weights to Custom Model

In [15]:
# Copy only the attention projection weights (q/k/v/o) from the original checkpoint into the
# custom model. When the custom architecture changes a tensor's shape, that tensor cannot be
# copied: it is skipped with a warning and keeps the custom model's fresh initialisation, so
# new weights tailored to the custom architecture must be trained for it.
module_patterns_to_transfer = ["q_proj", "k_proj", "v_proj", "o_proj"]
def transfer_weights(original_model, custom_model, module_patterns_to_transfer):
    """
    Copy selected tensors from ``original_model`` into ``custom_model`` in place.

    A custom-model tensor is copied when its state-dict key contains any of
    ``module_patterns_to_transfer`` and a tensor of the same shape exists in the original
    model under the same key, or under that key with a leading ``"model."`` added or removed.
    The prefix tolerance lets a bare backbone (e.g. ``CustomLlamaModel``, keys ``layers.N...``)
    receive weights from a causal-LM wrapper (e.g. ``LlamaForCausalLM``, keys
    ``model.layers.N...``); with exact-key matching nothing would be copied.

    Args:
        original_model: source ``nn.Module``.
        custom_model: destination ``nn.Module``; modified in place.
        module_patterns_to_transfer: substrings selecting which keys to copy.
    Returns:
        list[str]: the ``custom_model`` state-dict keys that were copied.
    """
    original_dict = original_model.state_dict()
    custom_dict = custom_model.state_dict()

    def _find_source_key(key):
        # Same path, or the same path differing only by a leading "model." wrapper prefix.
        candidates = [key, "model." + key]
        if key.startswith("model."):
            candidates.append(key[len("model."):])
        return next((c for c in candidates if c in original_dict), None)

    transferred = []
    with torch.no_grad():
        for key, target in custom_dict.items():
            # any() ensures a key matching several patterns is copied once.
            if not any(pattern in key for pattern in module_patterns_to_transfer):
                continue
            source_key = _find_source_key(key)
            if source_key is None:
                continue
            source = original_dict[source_key]
            if source.shape != target.shape:
                warnings.warn(
                    f"transfer_weights: skipping {key!r}: shape {tuple(target.shape)} does not match "
                    f"{source_key!r} {tuple(source.shape)} in the original model"
                )
                continue
            # state_dict() tensors share storage with the parameters, so this writes into the model.
            target.copy_(source)
            transferred.append(key)

    # Load the updated state dictionary to the model
    custom_model.load_state_dict(custom_dict)
    return transferred

In [None]:
# Transfer weights from the original model to the custom backbone
model_2 = CustomLlamaModel(custom_lamma_config)
transfer_weights(model, model_2, module_patterns_to_transfer)

# List the custom model's tensors by name and shape only; printing every tensor's values
# for a 1B-parameter model floods the notebook output.
for key, parameter in model_2.state_dict().items():
    print(f"{key}: {tuple(parameter.size())}")


In [66]:
# Persist the custom backbone (config + weights) to a local folder.
model_2.save_pretrained(save_directory='./custom-llama-weights/')