In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
from transformers.activations import ACT2FN
import torch.nn as nn
import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous

In [None]:

# Device-side helper used by the forward kernel: SiLU, computed as x * sigmoid(x).
@triton.jit
def silu(x):
    return x * tl.sigmoid(x)


# Triton kernel: each program instance writes c = silu(a) * b for one row.
#
# Arguments:
#   a_ptr, b_ptr, c_ptr: base pointers of the two inputs and of the output.
#   stride: element offset between consecutive rows; the same value is applied to a, b and c.
#   n_cols: number of valid columns per row (compile-time constant).
#   BLOCK_SIZE: width of the column tile (compile-time constant). The row is processed as a
#       single tile, so columns at index >= BLOCK_SIZE are never touched -- presumably the
#       launcher guarantees BLOCK_SIZE >= n_cols; confirm against calculate_settings.
@triton.jit
def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # The row index is widened to int64 before being multiplied by the stride
    # (presumably to keep the pointer offset from overflowing 32 bits).
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    a_ptr += program_id * stride
    b_ptr += program_id * stride
    c_ptr += program_id * stride

    # Lanes at or beyond n_cols are masked: they load 0 and are skipped by the store.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # sigmoid requires type float32
    # Only `a` is cast here; `b` keeps whatever dtype it was loaded with.
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
    c_row = silu(a_row) * b_row
    tl.store(c_ptr + col_offsets, c_row, mask=mask)


# Triton kernel: given the upstream gradient dc of c = silu(a) * b, each program instance
# computes da and db for one row and stores them INTO a_ptr and b_ptr, i.e. the forward
# inputs are overwritten by their gradients.
#
# Arguments:
#   dc_ptr: base pointer of the upstream gradient.
#   a_ptr, b_ptr: base pointers of the forward inputs; read first, then overwritten with da / db.
#   stride: element offset between consecutive rows; the same value is applied to dc, a and b.
#   n_cols: number of valid columns per row (compile-time constant).
#   BLOCK_SIZE: width of the column tile (compile-time constant); columns at index >= BLOCK_SIZE
#       are never touched -- presumably the launcher guarantees BLOCK_SIZE >= n_cols; confirm.
@triton.jit
def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # The row index is widened to int64 before being multiplied by the stride.
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    dc_ptr += program_id * stride
    a_ptr += program_id * stride
    b_ptr += program_id * stride

    # Lanes at or beyond n_cols are masked: they load 0 and are skipped by the stores.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
    # sigmoid requires type float32
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)

    # recomputation to save memory
    # sigmoid(a) and silu(a) are recomputed from `a` rather than read from a saved buffer.
    sig_a = tl.sigmoid(a_row)
    silu_a = a_row * sig_a
    # db = dc * silu(a)
    db_row = dc_row * silu_a
    # da = dc * b * d/da[silu(a)], with d/da[silu(a)] = silu(a) * (1 - sigmoid(a)) + sigmoid(a)
    da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row

    # In-place result: the gradients replace the contents of a and b.
    tl.store(a_ptr + col_offsets, da_row, mask=mask)
    tl.store(b_ptr + col_offsets, db_row, mask=mask)


def swiglu_forward(a, b):
    """Launch the SwiGLU forward kernel and return ``(a, b, c)``.

    ``a`` and ``b`` are flattened to 2-D over their last dimension, a fresh
    output buffer shaped like the flattened ``a`` is allocated, and one kernel
    program is launched per row.

    Returns:
        The 2-D views of ``a`` and ``b`` together with the output ``c`` viewed
        in ``a``'s original shape.
    """
    input_shape = a.shape
    width = input_shape[-1]

    # Collapse every leading dimension into a single row axis.
    a = a.view(-1, width)
    b = b.view(-1, width)
    c = torch.empty_like(a)

    block_size, warps = calculate_settings(width)
    launch_grid = (a.shape[0],)

    # The row stride of `c` is the single stride handed to the kernel for a, b and c.
    _swiglu_forward_kernel[launch_grid](
        a,
        b,
        c,
        c.stride(-2),
        n_cols=width,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return a, b, c.view(*input_shape)


def swiglu_backward(a, b, dc):
    """Launch the SwiGLU backward kernel and return the gradients for ``a`` and ``b``.

    ``dc`` is flattened to 2-D over its last dimension and one kernel program
    is launched per row. The kernel stores the gradients into the memory of
    ``a`` and ``b`` themselves, so both inputs are overwritten.

    Returns:
        ``a`` and ``b`` (now holding the gradients), each viewed in ``dc``'s
        original shape.
    """
    grad_shape = dc.shape
    width = grad_shape[-1]
    dc = dc.view(-1, width)

    block_size, warps = calculate_settings(width)
    launch_grid = (dc.shape[0],)

    # The row stride of `dc` is the single stride handed to the kernel for dc, a and b.
    _swiglu_backward_kernel[launch_grid](
        dc,
        a,
        b,
        dc.stride(-2),
        n_cols=width,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return a.view(*grad_shape), b.view(*grad_shape)


class LigerSiLUMulFunction(torch.autograd.Function):
    """Autograd function computing ``silu(a) * b`` through the Triton kernels.

    Both static methods are wrapped with ``ensure_contiguous`` from
    ``liger_kernel`` (presumably it makes the tensor arguments contiguous
    before the call -- confirm against the liger_kernel source).
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        """Run the forward kernel and keep the flattened inputs for backward."""
        flat_a, flat_b, out = swiglu_forward(a, b)
        ctx.save_for_backward(flat_a, flat_b)
        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        """Return the gradients w.r.t. ``a`` and ``b`` for upstream gradient ``dc``.

        The saved tensors are handed to ``swiglu_backward``, which returns them
        holding the gradients.
        """
        saved_a, saved_b = ctx.saved_tensors
        grad_a, grad_b = swiglu_backward(saved_a, saved_b, dc)
        return grad_a, grad_b


In [None]:
class Qwen2MLP2(nn.Module):
    """Stand-alone gated MLP: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free linear layers. The activation is looked
    up in ``ACT2FN`` by ``config.hidden_act``.
    """

    def __init__(self, config):
        """Build the projections from ``config.hidden_size`` / ``config.intermediate_size``."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated feed-forward transformation to ``x``."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

class Qwen2MLP3(Qwen2MLP):
    """Subclass of ``Qwen2MLP`` whose forward pass prints a marker line first.

    The computation is ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``; the
    print makes it visible that this replacement class is the one being run.
    """

    def forward(self, x):
        print("edit by guofeng")
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
    
    
class Qwen2MLP4(Qwen2MLP):
    """Subclass of ``Qwen2MLP`` that computes the gated product with the Triton kernel.

    The forward pass is ``down_proj(LigerSiLUMulFunction.apply(gate_proj(x), up_proj(x)))``.
    ``LigerSiLUMulFunction`` always applies SiLU, so this assumes the model's
    configured activation is SiLU -- TODO confirm for the checkpoint in use.
    """

    def forward(self, x):
        print("edit by guofeng")
        # Both projections must be arguments of apply(). The previous code closed the
        # parenthesis after the gate projection, which called apply() with a single
        # tensor and produced a tuple that was then handed to down_proj.
        gated = LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
        return self.down_proj(gated)

In [None]:
def apply_mlp(type : str) -> None:
    """Replace ``transformers.models.qwen2.modeling_qwen2.Qwen2MLP`` with a local variant.

    Args:
        type: Which replacement to install -- ``'simple'`` (Qwen2MLP2),
            ``'v2'`` (Qwen2MLP3) or ``'liger_kernel'`` (Qwen2MLP4). The name
            shadows the builtin ``type`` inside this function; it is kept so
            that keyword callers keep working.

    Raises:
        ValueError: If ``type`` is not one of the supported names. Previously an
            unknown name was ignored silently, leaving the module unpatched.
    """
    from transformers.models.qwen2 import modeling_qwen2

    replacements = {
        'simple': Qwen2MLP2,
        'v2': Qwen2MLP3,
        'liger_kernel': Qwen2MLP4,
    }
    if type not in replacements:
        supported = ", ".join(repr(name) for name in replacements)
        raise ValueError(f"Unknown MLP type {type!r}; expected one of: {supported}")

    modeling_qwen2.Qwen2MLP = replacements[type]

# Install the Liger-kernel variant (Qwen2MLP4) as modeling_qwen2.Qwen2MLP.
# NOTE(review): presumably this has to run before the model is constructed to take effect -- confirm.
apply_mlp('liger_kernel')

In [None]:
# Local checkpoint directory (machine-specific absolute path; adjust when running elsewhere).
model_path = 'D:/pretrained_model/models--Qwen--Qwen2.5-0.5B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load through the concrete Qwen2ForCausalLM class onto the GPU.
# NOTE(review): torch_dtype='auto' presumably takes the dtype from the checkpoint -- confirm in the transformers docs.
model = Qwen2ForCausalLM.from_pretrained(
    model_path,
    device_map='cuda',
    torch_dtype='auto'
    )
# Switch the module to evaluation mode; as the last expression, the returned model is also displayed.
model.eval()

In [None]:
# User question (in Chinese): "Tell me about Shenzhen's development over the next 10 years."
prompt = "给我介绍下深圳未来10年的发展"
# Chat history: a system message (roughly "you are a national leader; give your professional
# view from a global perspective") followed by the user question.
messages = [
    {"role": "system", "content": "你是一个国家领导人，站在全球的视野，给出你专业的看法"},
    {"role": "user", "content": prompt}
]
# Render the messages into one prompt string (tokenize=False) with the generation prompt appended.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
# Tokenize as a batch of one and move the tensors to the model's device.
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate at most 512 new tokens.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)
# Slice off the first len(input_ids) tokens of every output so only the continuation is kept.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

# Decode the first (and only) sequence. Note: `response` is assigned but not displayed by this cell.
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]



## AutoModel的lazy加载方式

In [29]:
# Approach 1: obtain the model classes with a regular import from the transformers package.
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM, Qwen2ForQuestionAnswering
from transformers import AutoModelForCausalLM

# Bare expression so the notebook displays the imported class.
Qwen2ForCausalLM

transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM

In [30]:
# Approach 2: name the module with a string and resolve it to the actual module object at run time.
import importlib

model_name = 'qwen2.modeling_qwen2'
# Relative import: ".qwen2.modeling_qwen2" is resolved against the package 'transformers.models'.
model_package = importlib.import_module(name=f".{model_name}", package='transformers.models')

In [31]:
# Display the module object returned by the dynamic import.
model_package

<module 'transformers.models.qwen2.modeling_qwen2' from 'c:\\Users\\49207\\.conda\\envs\\py311_langchainchat\\Lib\\site-packages\\transformers\\models\\qwen2\\modeling_qwen2.py'>

In [32]:
# Look the class up by its string name on the dynamically imported module, then display it.
Qwen2ForCausalLM_gf = getattr(model_package, 'Qwen2ForCausalLM')
Qwen2ForCausalLM_gf

transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM

In [33]:
# Same string-based lookup for a second class from the same module, then display it.
Qwen2ForQuestionAnswering_gf = getattr(model_package, 'Qwen2ForQuestionAnswering')
Qwen2ForQuestionAnswering_gf

transformers.models.qwen2.modeling_qwen2.Qwen2ForQuestionAnswering

AutoModelForCausalLM


_BaseAutoModelClass
```python
elif type(config) in cls._model_mapping.keys():
    model_class = _get_model_class(config, cls._model_mapping)  # 获取模型具体的类，比如Qwen2ForCausalLM
    return model_class.from_pretrained(
        pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    )
```

MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)



In [None]:
def __getitem__(self, key):
    """Return the object registered for the config class ``key``.

    Lookup order:
      1. an entry in ``self._extra_content``;
      2. the model type that ``self._reverse_config_mapping`` gives for
         ``key.__name__``, if it is present in ``self._model_mapping``;
      3. otherwise the first model type in ``self._config_mapping`` whose value
         equals ``key.__name__`` and which is present in ``self._model_mapping``.

    The winning entry is loaded through ``self._load_attr_from_module``.

    Raises:
        KeyError: If none of the lookups succeeds, or if ``key.__name__`` is
            missing from ``self._reverse_config_mapping``.
    """
    extras = self._extra_content
    if key in extras:
        return extras[key]

    config_name = key.__name__
    registry = self._model_mapping
    chosen_type = self._reverse_config_mapping[config_name]

    if chosen_type not in registry:
        # Several model types may share this config; take the first one that is registered.
        for candidate, mapped_config in self._config_mapping.items():
            if mapped_config == config_name and candidate in registry:
                chosen_type = candidate
                break
        else:
            raise KeyError(key)

    return self._load_attr_from_module(chosen_type, registry[chosen_type])

def keys(self):
    """Return the loaded keys of the mapping as a list.

    Every ``(key, name)`` pair of ``self._config_mapping`` whose key is also in
    ``self._model_mapping`` is loaded through ``self._load_attr_from_module``;
    the keys of ``self._extra_content`` are appended afterwards.
    """
    loaded = []
    for model_type, config_name in self._config_mapping.items():
        if model_type in self._model_mapping.keys():
            loaded.append(self._load_attr_from_module(model_type, config_name))
    loaded.extend(self._extra_content.keys())
    return loaded

def get(self, key, default):
    """Return ``self.__getitem__(key)``, or ``default`` if that raises ``KeyError``."""
    try:
        found = self.__getitem__(key)
    except KeyError:
        return default
    return found

def __bool__(self):
    """Return the truthiness of ``self.keys()``."""
    available = self.keys()
    return True if available else False

def values(self):
    """Return the loaded values of the mapping as a list.

    Every ``(key, name)`` pair of ``self._model_mapping`` whose key is also in
    ``self._config_mapping`` is loaded through ``self._load_attr_from_module``;
    the values of ``self._extra_content`` are appended afterwards.
    """
    loaded = []
    for model_type, model_name in self._model_mapping.items():
        if model_type in self._config_mapping.keys():
            loaded.append(self._load_attr_from_module(model_type, model_name))
    loaded.extend(self._extra_content.values())
    return loaded

def _load_attr_from_module(self, model_type, attr):
    """Fetch ``attr`` from the module that belongs to ``model_type``.

    The module name comes from ``model_type_to_module_name``. The module is
    imported relative to ``transformers.models`` the first time it is needed
    and kept in ``self._modules`` for later calls.
    """
    cache = self._modules
    mod_name = model_type_to_module_name(model_type)
    if mod_name not in cache:
        # First request for this module: import it and remember it.
        cache[mod_name] = importlib.import_module(f".{mod_name}", "transformers.models")
    module = cache[mod_name]
    return getattribute_from_module(module, attr)

In [34]:
# A plain dict, used in the following cells to show the built-in counterparts of the
# mapping methods defined above (__getitem__, get, keys, values).
a = {'a':1, 'b':2}
a

{'a': 1, 'b': 2}

In [None]:
a['a']  # subscript access goes through __getitem__

1

In [None]:
a.get('a') # get: lookup by key through the get() method

1

In [None]:
a.keys() # keys: view of all keys in the dict

dict_keys(['a', 'b'])

In [None]:
a.values() # values: view of all values in the dict

dict_values([1, 2])