A16W8 with post_scale=True sometimes throws a CUDA error with some models. Reproduction:
# Load the model on CPU
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "Qwen/Qwen3-4B-Instruct-2507"
compute_dtype = torch.float16
device = 'cuda:0'
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype, device_map='cpu')
tokenizer = AutoTokenizer.from_pretrained(model_id)

##########################################
# Patch the model's linear layers with GemLite
from gemlite.helper import *
processor = A16W8_INT8(post_scale=True)   # throws a CUDA error
#processor = A16W8_INT8(post_scale=False) # works fine
patch_model(model, device=device, processor=processor)
print(model)
##########################################

# Run a short generation to trigger the error
prompt = "Give me a short introduction to large language model."
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
generated_ids = model.generate(
    tokenizer([text], return_tensors="pt").to(device).input_ids,
    max_new_tokens=512,
)
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
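
For context, my understanding of post_scale (an assumption on my part, not taken from the gemlite docs): with A16W8 the int8 weight carries a per-output-channel scale, and post_scale decides whether that scale is folded into the weight before the matmul or applied to the output afterwards. A minimal PyTorch sketch of the two orderings, with made-up tensor names:

import torch

# Conceptual sketch only (assumed semantics, not gemlite's actual kernel).
# x: fp16 activations, w_q: int8 weight [out, in], scale: per-output-channel fp16 scale
x = torch.randn(4, 64, dtype=torch.float16)
w_q = torch.randint(-128, 128, (128, 64), dtype=torch.int8)
scale = torch.rand(128, dtype=torch.float16) * 0.01

# pre-scale (post_scale=False): dequantize the weight first, then matmul
y_pre = x @ (w_q.to(torch.float16) * scale[:, None]).T

# post-scale (post_scale=True): matmul the raw int8 values, then scale the output
y_post = (x @ w_q.to(torch.float16).T) * scale

print((y_pre - y_post).abs().max())  # small, up to fp16 rounding

If that reading is right, the two settings should only change the kernel path, not the math, which is why the CUDA error with post_scale=True surprised me.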