Local backends fix #18

Merged: 11 commits, Nov 27, 2023
27 changes: 19 additions & 8 deletions backends/huggingface_local_api.py
@@ -6,6 +6,7 @@
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import copy

logger = backends.get_logger(__name__)

@@ -153,17 +154,23 @@ def generate_response(self, messages: List[Dict], model: str,
logger.info(f"Finished loading huggingface model: {model}")
logger.info(f"Model device map: {self.model.hf_device_map}")

# deepcopy messages to prevent reference issues:
current_messages = copy.deepcopy(messages)

# flatten consecutive user messages:
for msg_idx, message in enumerate(messages):
if msg_idx > 0 and message['role'] == "user" and messages[msg_idx - 1]['role'] == "user":
messages[msg_idx - 1]['content'] += f" {message['content']}"
del messages[msg_idx]
for msg_idx, message in enumerate(current_messages):
if msg_idx > 0 and message['role'] == "user" and current_messages[msg_idx - 1]['role'] == "user":
current_messages[msg_idx - 1]['content'] += f" {message['content']}"
del current_messages[msg_idx]
elif msg_idx > 0 and message['role'] == "assistant" and current_messages[msg_idx - 1]['role'] == "assistant":
current_messages[msg_idx - 1]['content'] += f" {message['content']}"
del current_messages[msg_idx]

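The loop above now merges consecutive same-role turns (user and assistant alike) and works on a deep copy, so the caller's messages list is never mutated. A minimal standalone sketch of that flattening step; the helper name is illustrative, and a while loop is used so in-place deletions are handled cleanly:

import copy
from typing import Dict, List

def flatten_consecutive_turns(messages: List[Dict]) -> List[Dict]:
    """Merge consecutive messages that share the same role ("user" or "assistant").

    Works on a deep copy so the caller's message list is left untouched.
    Illustrative helper only; the backend inlines this logic.
    """
    flattened = copy.deepcopy(messages)
    msg_idx = 1
    while msg_idx < len(flattened):
        prev, cur = flattened[msg_idx - 1], flattened[msg_idx]
        if cur["role"] == prev["role"] and cur["role"] in ("user", "assistant"):
            prev["content"] += f" {cur['content']}"
            del flattened[msg_idx]
        else:
            msg_idx += 1
    return flattened

# Two consecutive user turns collapse into a single message:
print(flatten_consecutive_turns([
    {"role": "user", "content": "Hello."},
    {"role": "user", "content": "Are you still there?"},
]))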
# apply chat template & tokenize:
prompt_tokens = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
prompt_tokens = self.tokenizer.apply_chat_template(current_messages, return_tensors="pt")
prompt_tokens = prompt_tokens.to(self.device)

prompt_text = self.tokenizer.apply_chat_template(messages, tokenize=False)
prompt_text = self.tokenizer.batch_decode(prompt_tokens)[0]
prompt = {"inputs": prompt_text, "max_new_tokens": max_new_tokens,
"temperature": self.temperature, "return_full_text": return_full_text}

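Note that prompt_text is now recovered by decoding the already-tokenized prompt instead of re-applying the chat template with tokenize=False, so the recorded prompt matches exactly what the model receives, special tokens included. A small sketch of the pattern; the checkpoint name is illustrative, and any tokenizer with a chat template behaves the same way:

from transformers import AutoTokenizer

# Illustrative checkpoint; any chat-template-capable tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]

# Tokenize once, then decode the same ids for the record, so the logged prompt
# is exactly what the model sees (including special tokens such as <s>).
prompt_tokens = tokenizer.apply_chat_template(messages, return_tensors="pt")
prompt_text = tokenizer.batch_decode(prompt_tokens)[0]
print(prompt_text)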
@@ -189,15 +196,19 @@ def generate_response(self, messages: List[Dict], model: str,
do_sample=do_sample
)

model_output = self.tokenizer.batch_decode(model_output_ids, skip_special_tokens=True)[0]
model_output = self.tokenizer.batch_decode(model_output_ids)[0]

response = {'response': model_output}

# cull input context; equivalent to transformers.pipeline method:
if not return_full_text:
response_text = model_output.replace(prompt_text, '').strip()
# remove llama2 EOS token at the end of output:
if response_text[-4:len(response_text)] == "</s>":
response_text = response_text[:-4]
else:
response_text = model_output.strip()

response = {'response': model_output}
return prompt, response, response_text

def supports(self, model_name: str):
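In the output handling above, decoding without skip_special_tokens keeps the special tokens that are also present in prompt_text, which is what lets the prompt be culled with a plain string replace; the trailing Llama-2 </s> is then stripped explicitly. A hedged sketch of that post-processing as a standalone function (name and signature are illustrative, not part of the backend API):

def cull_prompt(model_output: str, prompt_text: str, return_full_text: bool = False) -> str:
    """Remove the echoed prompt and a trailing Llama-2 EOS token from decoded output.

    Illustrative helper; the backend performs these steps inline.
    """
    if return_full_text:
        return model_output.strip()
    response_text = model_output.replace(prompt_text, "").strip()
    if response_text.endswith("</s>"):
        response_text = response_text[:-len("</s>")]
    return response_text

# The prompt echo and the EOS token are stripped, leaving only the reply:
print(cull_prompt("<s>[INST] Hi [/INST] Hello there!</s>", "<s>[INST] Hi [/INST]"))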
40 changes: 24 additions & 16 deletions backends/llama2_hf_local_api.py
@@ -6,6 +6,7 @@
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import copy

logger = backends.get_logger(__name__)

@@ -89,18 +90,24 @@ def generate_response(self, messages: List[Dict], model: str,
# turn off redundant transformers warnings:
transformers.logging.set_verbosity_error()

# deepcopy messages to prevent reference issues:
current_messages = copy.deepcopy(messages)

if model in self.chat_models: # chat completion
# flatten consecutive user messages:
for msg_idx, message in enumerate(messages):
if msg_idx > 0 and message['role'] == "user" and messages[msg_idx - 1]['role'] == "user":
messages[msg_idx - 1]['content'] += f" {message['content']}"
del messages[msg_idx]
for msg_idx, message in enumerate(current_messages):
if msg_idx > 0 and message['role'] == "user" and current_messages[msg_idx - 1]['role'] == "user":
current_messages[msg_idx - 1]['content'] += f" {message['content']}"
del current_messages[msg_idx]
elif msg_idx > 0 and message['role'] == "assistant" and current_messages[msg_idx - 1]['role'] == "assistant":
current_messages[msg_idx - 1]['content'] += f" {message['content']}"
del current_messages[msg_idx]

# apply chat template & tokenize
prompt_tokens = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
prompt_tokens = self.tokenizer.apply_chat_template(current_messages, return_tensors="pt")
prompt_tokens = prompt_tokens.to(self.device)
# apply chat template for records:
prompt_text = self.tokenizer.apply_chat_template(messages, tokenize=False)
prompt_text = self.tokenizer.batch_decode(prompt_tokens)[0]
prompt = {"inputs": prompt_text, "max_new_tokens": max_new_tokens,
"temperature": self.temperature}

Expand All @@ -119,18 +126,20 @@ def generate_response(self, messages: List[Dict], model: str,
max_new_tokens=max_new_tokens
)

model_output = self.tokenizer.batch_decode(model_output_ids, skip_special_tokens=True,
clean_up_tokenization_spaces=False)[0]
model_output = self.tokenizer.batch_decode(model_output_ids)[0]

response = {"response": model_output}

# cull prompt from output:
response_text = model_output.replace(prompt_text, "").strip()
# remove EOS token at the end of output:
if response_text[-4:len(response_text)] == "</s>":
response_text = response_text[:-4]

response = {
"role": "assistant",
"content": model_output.replace(prompt_text, ''),
}

response_text = model_output.replace(prompt_text, '').strip()

else: # default (text completion)
prompt = "\n".join([message["content"] for message in messages])
prompt = "\n".join([message["content"] for message in current_messages])

prompt_tokens = self.tokenizer.encode(
prompt,
@@ -153,8 +162,7 @@ def generate_response(self, messages: List[Dict], model: str,
max_new_tokens=max_new_tokens
)

model_output = self.tokenizer.batch_decode(model_output_ids, skip_special_tokens=True,
clean_up_tokenization_spaces=False)[0]
model_output = self.tokenizer.batch_decode(model_output_ids)[0]

response_text = model_output.replace(prompt, '').strip()

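For non-chat models this second backend falls back to plain text completion: message contents are joined with newlines, generation runs over that prompt, and the echoed prompt is culled from the decoded output. A self-contained sketch of that path, assuming a small ungated checkpoint (gpt2 here purely for illustration):

from typing import Dict, List
from transformers import AutoModelForCausalLM, AutoTokenizer

def complete_text(messages: List[Dict], model_name: str = "gpt2", max_new_tokens: int = 20) -> str:
    """Sketch of the default (non-chat) completion path.

    Helper name and checkpoint are illustrative, not part of the backend API.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Join all message contents into a single plain-text prompt.
    prompt = "\n".join(message["content"] for message in messages)
    prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")
    output_ids = model.generate(prompt_tokens, max_new_tokens=max_new_tokens)
    model_output = tokenizer.batch_decode(output_ids)[0]
    # Cull the echoed prompt, keeping only the newly generated continuation.
    return model_output.replace(prompt, "").strip()

print(complete_text([{"role": "user", "content": "Once upon a time"}]))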