Replacing the LlamaDecoderLayer Class in Hugging Face With New LongNet #29962

Open
4 of 6 tasks
younesselbrag opened this issue Mar 30, 2024 · 0 comments

Labels
Feature request Request for a new feature

System Info

I am working on the CodeLlama model, which uses a decoder-only Transformer, following the architecture shown below.

The main task is to replace the decoder-only layers, which use masked self-attention and a KV cache, with my own encoder-only layers that use the dilated attention from LongNet.
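
For context, here is a minimal sketch of the sparsification pattern behind LongNet's dilated attention (my own simplification, not the actual LongNet code): the sequence is split into segments of length w, every r-th token in each segment is kept, and standard attention runs over the kept tokens. The function name and shapes below are illustrative only.

import torch
import torch.nn.functional as F

def dilated_attention_sketch(q, k, v, segment_length, dilation_rate):
    # q, k, v: (batch, seq_len, num_heads, head_dim); seq_len is assumed to be
    # a multiple of segment_length for this simplified sketch.
    b, n, h, d = q.shape
    w, r = segment_length, dilation_rate

    def sparsify(x):
        x = x.view(b, n // w, w, h, d)   # split into segments of length w
        return x[:, :, ::r]              # keep every r-th token in each segment

    qs, ks, vs = sparsify(q), sparsify(k), sparsify(v)
    # Move heads before the token dimension so attention runs within each segment
    qs, ks, vs = (t.permute(0, 1, 3, 2, 4) for t in (qs, ks, vs))
    out = F.scaled_dot_product_attention(qs, ks, vs)
    # The full LongNet operator scatters these outputs back to their original
    # positions and mixes several (w, r) pairs across attention heads.
    return out

With q = k = v of shape (1, 8192, 32, 128), segment_length=2048 and dilation_rate=4, each query attends to 512 keys instead of 8192, which is what keeps the attention cost roughly linear in sequence length.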

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaForCausalLM
# MultiheadDilatedAttention is assumed to come from a separate LongNet implementation
# (e.g. a dilated-attention package); it is not part of transformers and must be imported separately.


model_id = "codellama/CodeLlama-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16
).to("cpu")

class CondensedLlamaConfig(LlamaConfig):
    def __init__(
        self,
        dilation_rates=None,
        segment_lengths=None,
        is_causal=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.dilation_rates = dilation_rates
        self.segment_lengths = segment_lengths
        self.is_causal = is_causal

    # Override the `to_dict` method to include the new parameters
    def to_dict(self):
        base_dict = super().to_dict()
        config_dict = {
            "dilation_rates": self.dilation_rates,
            "segment_lengths": self.segment_lengths,
            "is_causal": self.is_causal
        }
        base_dict.update(config_dict)
        return base_dict


import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaDecoderLayer
from transformers.modeling_utils import ModuleUtilsMixin

class CondensedLlamaAttention(LlamaAttention):
    def __init__(self, config: CondensedLlamaConfig, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)

        # Replace the standard masked self-attention with LongNet's dilated attention
        self.LongNetAttention = MultiheadDilatedAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.dilation_rates,
            config.segment_lengths
        )
        self.is_causal = config.is_causal

    def forward(self, input, is_causal=None):
        if is_causal is None:
            is_causal = self.is_causal
        # Self-attention: the same tensor is used as query, key, and value
        x, _ = self.LongNetAttention(input, input, input, is_causal=is_causal)
        return x


class CondensedLlamaDecoderLayer(LlamaDecoderLayer):

    def __init__(self, config: CondensedLlamaConfig, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)  # pass layer_idx through to the parent constructor
        # Replace self_attn with the dilated-attention module
        self.self_attn = MultiheadDilatedAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.dilation_rates,
            config.segment_lengths
        )
        self.is_causal = config.is_causal

    def forward(self, input, is_causal=None):
        if is_causal is None:
            is_causal = self.is_causal
        # Note: this override only runs the dilated attention; the parent layer's
        # layer norms, residual connections, and MLP are not applied here.
        x, _ = self.self_attn(input, input, input, is_causal=is_causal)
        return x


class CondensedLlamaModel(LlamaModel):
    def __init__(self, config: CondensedLlamaConfig):
        super().__init__(config)

        # Rebuild the layer stack with the condensed decoder layers, passing each layer's index
        self.layers = nn.ModuleList(
            [CondensedLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Initialize weights and apply final processing
        self.post_init()

# model.model is the bare LlamaModel (without the LM head) of the pretrained checkpoint
model_2 = model.model

import torch
module_patterns_to_transfer = ["q_proj", "k_proj", "v_proj", "o_proj"]
def transfer_weights(original_model, custom_model, module_patterns_to_transfer):
    original_dict = original_model.state_dict()
    custom_dict = custom_model.state_dict()

    # Copy weights for keys that match one of the patterns and exist in both models
    for key in custom_dict.keys():
        for pattern in module_patterns_to_transfer:
            if pattern in key and key in original_dict:
                with torch.no_grad():
                    custom_dict[key].copy_(original_dict[key])
                break  # key handled; no need to test the remaining patterns

    # Load the updated state dictionary back into the custom model
    custom_model.load_state_dict(custom_dict)

# Segment lengths and dilation rates as in the LongNet paper:
# segments of 2048..32768 tokens with dilation rates 1, 2, 4, 6, 12
config = CondensedLlamaConfig(
    dilation_rates=[1, 2, 4, 6, 12],
    segment_lengths=[2048, 4096, 8192, 16384, 32768],
    is_causal=False
)
config.num_hidden_layers = 2
model_1 = CondensedLlamaModel(config)



# Transfer weights from the original model to the custom model
transfer_weights(model_2, model_1, module_patterns_to_transfer)

# Inspect the transferred weights in the custom model
for key, parameter in model_1.state_dict().items():
    print(key)
    print(parameter.size())
    print(parameter)
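
A quick sanity check, as a sketch: it assumes the q/k/v/o projection key names survive in both state dicts, which depends on the dilated-attention implementation, so the key matching below may need adjusting.

# Report which projection keys exist in both models and whether their values match
src, dst = model_2.state_dict(), model_1.state_dict()
for key in dst:
    if any(p in key for p in module_patterns_to_transfer) and key in src:
        same = torch.equal(src[key].cpu(), dst[key].cpu())
        print(f"{key}: shape={tuple(dst[key].shape)} copied={same}")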

Expected behavior

Yes, I am aware of that.

Checklist

@ArthurZucker added the Feature request (Request for a new feature) label on Mar 30, 2024