Merged
58 changes: 58 additions & 0 deletions llm/inference/qwen1.5-0.5b/app.py
@@ -0,0 +1,58 @@
import gradio as gr
import mindspore
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
from mindnlp.transformers import TextIteratorStreamer
from threading import Thread

# Loading the tokenizer and model from Hugging Face's model hub.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)

system_prompt = "You are a helpful and friendly chatbot"

def build_input_from_chat_history(chat_history, msg: str):
    messages = [{'role': 'system', 'content': system_prompt}]
    for user_msg, ai_msg in chat_history:
        messages.append({'role': 'user', 'content': user_msg})
        messages.append({'role': 'assistant', 'content': ai_msg})
    messages.append({'role': 'user', 'content': msg})
    return messages

# Function to generate model predictions.
def predict(message, history):
    # Formatting the input for the model.
    messages = build_input_from_chat_history(history, message)
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="ms",
        tokenize=True
    )
    streamer = TextIteratorStreamer(tokenizer, timeout=120, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.9,
        temperature=0.1,
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
            break
        yield partial_message


# Setting up the Gradio chat interface.
gr.ChatInterface(predict,
title="Qwen1.5-0.5b-Chat",
description="问几个问题",
examples=['你是谁?', '介绍一下华为公司']
).launch() # Launching the web interface.
62 changes: 62 additions & 0 deletions llm/inference/tinyllama/app_quant.py
@@ -0,0 +1,62 @@
import gradio as gr
import mindspore
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
from mindnlp.transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from mindnlp.quant.smooth_quant import quantize, w8x8
from threading import Thread

mindspore.set_context(pynative_synchronize=True)
# Loading the tokenizer and model from Hugging Face's model hub.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", ms_dtype=mindspore.float16)

quantize_cfg = w8x8(model.model.config)
quantize(model, cfg=quantize_cfg)

# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        stop_ids = [2]  # IDs of tokens where the generation should stop.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
                return mindspore.Tensor(True)
        return mindspore.Tensor(False)


# Function to generate model predictions.
def predict(message, history):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Formatting the input for the model.
    messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
                            for item in history_transformer_format])
    model_inputs = tokenizer([messages], return_tensors="ms")
    streamer = TextIteratorStreamer(tokenizer, timeout=3600, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=10,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
            break
        yield partial_message


# Setting up the Gradio chat interface.
gr.ChatInterface(predict,
title="Tinyllama_chatBot",
description="Ask Tiny llama any questions",
examples=['How to cook a fish?', 'Who is the president of US now?']
).launch() # Launching the web interface.
4 changes: 1 addition & 3 deletions mindnlp/core/ops/comparison.py
@@ -2,7 +2,7 @@
import numpy as np
import mindspore
from mindspore import ops
-from mindnlp.configs import use_pyboost, ON_ORANGE_PI
+from mindnlp.configs import use_pyboost

# allclose
def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
@@ -128,8 +128,6 @@ def not_equal(input, other):

# sort
def sort(input, *, dim=-1, descending=False, stable=False):
-    if ON_ORANGE_PI:
-        return topk(input, input.shape[dim], dim, descending)
    if use_pyboost():
        return mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
    return ops.sort(input, dim, descending)
Empty file added mindnlp/quant/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions mindnlp/quant/smooth_quant/__init__.py
@@ -0,0 +1,3 @@
"""smooth quant"""
from .quant import *
from .configs import *
135 changes: 135 additions & 0 deletions mindnlp/quant/smooth_quant/configs.py
@@ -0,0 +1,135 @@
"""quant configs"""
# No quantization (baseline config).
def no(model_cfg, act_max):
    return {}


# Static mixed-precision decomposition.
def sd(model_cfg, act_max):
    quant_cfg = {}
    h_mx, d_mx = findN(0.04 * model_cfg.hidden_size), findN(
        0.1 * model_cfg.intermediate_size
    )
    scale, step = 4, 4 / model_cfg.num_hidden_layers
    for i in range(model_cfg.num_hidden_layers):
        scale = max(0, scale - step)
        h_cur, d_cur = max(16, h_mx >> int(scale)), max(32, d_mx >> int(scale))
        for name in ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"]:
            quant_cfg[str(i) + "." + name] = {
                "type": "W8SD",
                "act_scale": True,
                "alpha": h_cur,
            }
        quant_cfg[str(i) + ".down_proj"] = {
            "type": "W8SD",
            "act_scale": True,
            "alpha": d_cur,
        }
    quant_cfg["lm_head"] = {"type": "W8SD"}
    quant_cfg["act_scales_path"] = act_max
    return quant_cfg


def findN(N):
    # Returns the largest power of two that does not exceed N (for N >= 1).
    sum = 1
    while True:
        if sum * 2 > N:
            return sum
        sum = sum * 2


# Smooth activations (SmoothQuant).
def smooth(model_cfg, act_max):
    quant_cfg = {}
    for i in range(model_cfg.num_hidden_layers):
        for name in ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"]:
            quant_cfg[str(i) + "." + name] = {"type": "W8X8"}
        # Effect of act_scale on an individual layer: with W8X8 the layer is smoothed;
        # with W8SD the act_scale is used for mixed-precision decomposition.
        quant_cfg[str(i) + ".down_proj"] = {
            "type": "W8X8",
            "act_scale": True,
            "alpha": 0.85,
        }
    quant_cfg["lm_head"] = {"type": "W8X8", "act_scale": True, "alpha": 0.85}
    quant_cfg["act_scales_path"] = act_max
    quant_cfg["alpha"] = 0.85  # SmoothQuant migration strength.
    # The global "smooth" switch fuses the activation scaling into RMSNorm, so it adds
    # no extra overhead, but it cannot be applied to the down_proj layer.
    quant_cfg["smooth"] = True
    return quant_cfg


# Mixed-precision decomposition for down_proj, smoothed activations for the rest.
def smsd(model_cfg, act_max):
    quant_cfg = {}
    d_mx = findN(0.1 * model_cfg.intermediate_size)
    scale, step = 4, 4 / model_cfg.num_hidden_layers
    for i in range(model_cfg.num_hidden_layers):
        scale = max(0, scale - step)
        d_cur = max(32, d_mx >> int(scale))
        for name in ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"]:
            quant_cfg[str(i) + "." + name] = {"type": "W8X8"}
        quant_cfg[str(i) + ".down_proj"] = {
            "type": "W8SD",
            "act_scale": True,
            "alpha": d_cur,
        }
    quant_cfg["lm_head"] = {"type": "W8SD", "act_scale": True, "alpha": 64}
    quant_cfg["act_scales_path"] = act_max
    quant_cfg["smooth"] = True
    return quant_cfg


# Weight-only int8 quantization.
def w8(model_cfg, act_max):
    quant_cfg = {}
    for i in range(model_cfg.num_hidden_layers):
        for name in [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]:
            quant_cfg[str(i) + "." + name] = {"type": "W8"}
    quant_cfg["lm_head"] = {"type": "W8"}
    return quant_cfg


# Dynamic mixed-precision decomposition.
def w8dx(model_cfg, act_max):
    quant_cfg = {}
    for i in range(model_cfg.num_hidden_layers):
        for name in [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]:
            quant_cfg[str(i) + "." + name] = {"type": "W8DX"}
    # quant_cfg["lm_head"] = {"type": "W8DX"}      # uncomment if needed
    # quant_cfg["act_scales_path"] = act_max       # uncomment if needed
    # quant_cfg["smooth"] = True                   # uncomment if needed
    return quant_cfg


# Per-token absmax quantization.
def w8x8(model_cfg):
    quant_cfg = {}
    for i in range(model_cfg.num_hidden_layers):
        for name in [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]:
            quant_cfg[str(i) + "." + name] = {"type": "W8X8"}
    quant_cfg["lm_head"] = {"type": "W8X8"}
    return quant_cfg
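
For context, a minimal usage sketch (not part of this diff) of how these config builders feed the quantize() entry point, following the pattern in llm/inference/tinyllama/app_quant.py. Only w8x8() is exercised in this PR; the smooth() call and the "act_scales.json" path below are illustrative assumptions.

import mindspore
from mindnlp.transformers import AutoModelForCausalLM
from mindnlp.quant.smooth_quant import quantize, w8x8

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ms_dtype=mindspore.float16
)

# Per-token absmax W8X8 config for every linear layer, as in app_quant.py.
cfg = w8x8(model.model.config)

# Alternative (assumption): a SmoothQuant-style config built from pre-collected
# activation maxima; "act_scales.json" is a hypothetical placeholder path.
# from mindnlp.quant.smooth_quant import smooth  # assumed re-export from configs.py
# cfg = smooth(model.model.config, "act_scales.json")

quantize(model, cfg=cfg)  # quantize the model in place, as done in app_quant.py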