In [None]:
from transformers import AutoModelForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
from transformers.activations import ACT2FN
import torch.nn as nn

In [None]:
# Hyperparameters for the Qwen2 model built later in this notebook; every
# option not listed here is left to Qwen2Config.
config = Qwen2Config(
    vocab_size=151_936,
    hidden_size=4_096,
    intermediate_size=22_016,
    num_hidden_layers=32,
    num_attention_heads=32,
    max_position_embeddings=32_768,
)

In [None]:
class Qwen2MLP2(nn.Module):
    """Stand-alone gated feed-forward block.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` using three
    bias-free ``nn.Linear`` layers. Layer widths come from
    ``config.hidden_size`` and ``config.intermediate_size``; the activation is
    looked up in ``ACT2FN`` under the key ``config.hidden_act``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        width = config.hidden_size
        inner = config.intermediate_size
        self.hidden_size = width
        self.intermediate_size = inner

        # Keep this creation order (gate, up, down): it fixes both the
        # parameter registration order and the order random init is drawn in.
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the down-projected result."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

class Qwen2MLP3(Qwen2MLP):
    """``Qwen2MLP`` subclass whose ``forward`` prints a marker on every call.

    Only ``forward`` is overridden; ``gate_proj``, ``up_proj``, ``down_proj``
    and ``act_fn`` are read from ``self`` and are expected to be set up by the
    parent class.
    """

    def forward(self, x):
        """Print the marker, then return ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
        # The printed line makes it visible at runtime that this override is in use.
        print("edit by guofeng")
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
    

    
class Qwen2MLP4(Qwen2MLP):
    """``Qwen2MLP`` subclass selected by ``apply_mlp('liger_kernel')``.

    Overrides ``forward`` only: it prints a marker and then evaluates the
    gated MLP with the projections and activation found on ``self``.
    """

    def forward(self, x):
        """Print the marker, then return ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
        print("edit by guofeng")
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

In [None]:
def apply_mlp(type : str) -> None:
    """Rebind ``modeling_qwen2.Qwen2MLP`` to one of this notebook's MLP classes.

    Args:
        type: Which replacement to install:
            ``'simple'`` -> ``Qwen2MLP2``,
            ``'v2'`` -> ``Qwen2MLP3``,
            ``'liger_kernel'`` -> ``Qwen2MLP4``.
            (The parameter name shadows the ``type`` builtin; it is kept so
            existing keyword callers keep working.)

    Returns:
        None. The only effect is the attribute assignment on the
        ``transformers.models.qwen2.modeling_qwen2`` module.

    Warns:
        UserWarning: if ``type`` is not one of the values above. The module is
        left untouched in that case (same as before), but the caller is now
        told about it instead of the call silently doing nothing.

    NOTE(review): presumably only models constructed after this call pick up
    the replacement class — verify against the transformers version in use.
    """
    import warnings

    from transformers.models.qwen2 import modeling_qwen2

    if type == 'simple':
        modeling_qwen2.Qwen2MLP = Qwen2MLP2

    elif type == 'v2':
        modeling_qwen2.Qwen2MLP = Qwen2MLP3

    elif type == 'liger_kernel':
        modeling_qwen2.Qwen2MLP = Qwen2MLP4

    else:
        # A typo in the selector used to be indistinguishable from a
        # successful patch; surface it without changing the no-op outcome.
        warnings.warn(
            f"apply_mlp: unrecognized type {type!r}; Qwen2MLP left unchanged "
            "(expected 'simple', 'v2' or 'liger_kernel')",
            stacklevel=2,
        )

In [None]:
# Build a Qwen2ForCausalLM from the `config` object defined above.
# NOTE(review): apply_mlp() is not called in any visible cell before this
# line, so none of the custom MLP classes is patched in here — confirm that
# is intended.
model = Qwen2ForCausalLM(config=config)