-
Notifications
You must be signed in to change notification settings - Fork 25.7k
/
llm_engine.py
92 lines (73 loc) · 3.2 KB
/
llm_engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional

from huggingface_hub import InferenceClient
class MessageRole(str, Enum):
    """Roles a chat message may carry; each member's value is the role string."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL_CALL = "tool-call"
    TOOL_RESPONSE = "tool-response"

    @classmethod
    def roles(cls):
        """Return the string value of every role, in declaration order."""
        return list(member.value for member in cls)
def get_clean_message_list(
    message_list: List[Dict[str, str]], role_conversions: Optional[Dict[str, str]] = None
) -> List[Dict[str, str]]:
    """
    Validate chat messages, optionally remap their roles, and merge consecutive same-role messages.

    Subsequent messages with the same role (compared after role conversion) are concatenated into a
    single message, their contents joined by a newline, a `===` line, and another newline.

    Args:
        message_list (`List[Dict[str, str]]`): List of chat messages. Each message must have exactly
            the keys "role" and "content". The input list and its messages are deep-copied, so they
            are never modified.
        role_conversions (`Dict[str, str]`, *optional*): Mapping from an original role to the role it
            should be replaced with. Roles not present in the mapping are kept. Defaults to no
            conversion.

    Returns:
        `List[Dict[str, str]]`: A new list of messages in which no two consecutive messages share a role.

    Raises:
        ValueError: If a message's keys are not exactly {"role", "content"}, or if its (original) role
            is not one of `MessageRole.roles()`.
    """
    # A `None` default avoids sharing one mutable dict across calls.
    if role_conversions is None:
        role_conversions = {}
    # Compute the allowed roles once rather than once per message.
    valid_roles = MessageRole.roles()

    final_message_list = []
    message_list = deepcopy(message_list)  # Avoid modifying the original list
    for message in message_list:
        if set(message.keys()) != {"role", "content"}:
            raise ValueError("Message should contain only 'role' and 'content' keys!")

        # Validate the role as supplied, before any conversion is applied.
        role = message["role"]
        if role not in valid_roles:
            raise ValueError(f"Incorrect role {role}, only {valid_roles} are supported for now.")

        if role in role_conversions:
            message["role"] = role_conversions[role]

        if final_message_list and message["role"] == final_message_list[-1]["role"]:
            # Safe to mutate: the entries are deep copies owned by this function.
            final_message_list[-1]["content"] += "\n===\n" + message["content"]
        else:
            final_message_list.append(message)
    return final_message_list
# Role remapping passed to `get_clean_message_list` by `HfEngine`:
# system and tool-response messages are both sent as user messages.
llama_role_conversions = dict.fromkeys(
    (MessageRole.SYSTEM, MessageRole.TOOL_RESPONSE),
    MessageRole.USER,
)
class HfEngine:
    """Chat engine backed by `huggingface_hub.InferenceClient`.

    Calling an instance sends chat messages to the configured model and returns the text of the
    first completion choice.
    """

    def __init__(self, model: str = "meta-llama/Meta-Llama-3-8B-Instruct"):
        """
        Args:
            model (`str`): Model id handed to `InferenceClient`.
                Defaults to "meta-llama/Meta-Llama-3-8B-Instruct".
        """
        self.model = model
        self.client = InferenceClient(model=self.model, timeout=120)

    def __call__(self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None) -> str:
        """
        Send `messages` to the model and return its answer.

        Args:
            messages (`List[Dict[str, str]]`): Chat messages, each with exactly "role" and "content"
                keys. They are validated and merged by `get_clean_message_list` using
                `llama_role_conversions` before being sent.
            stop_sequences (`List[str]`, *optional*): Strings at which generation should stop. The
                caller's list is never modified. For "Meta-Llama-3" models, "<|eot_id|>" and "!!!!!"
                are added if missing.

        Returns:
            `str`: The completion text. Each stop sequence is checked once, in order, and removed
            if the text ends with it.

        Raises:
            ValueError: Propagated from `get_clean_message_list` for malformed messages.
        """
        # Work on a private copy. The previous `stop_sequences=[]` default was shared by every call
        # (and every instance) and was appended to below, so Llama stop tokens leaked into later calls.
        stop_sequences = list(stop_sequences) if stop_sequences is not None else []
        if "Meta-Llama-3" in self.model:
            # NOTE(review): the reason for also stopping on "!!!!!" is not visible here;
            # presumably a guard against degenerate output — confirm.
            for extra_stop in ("<|eot_id|>", "!!!!!"):
                if extra_stop not in stop_sequences:
                    stop_sequences.append(extra_stop)

        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

        # Get answer
        response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
        response = response.choices[0].message.content

        # Remove stop sequences from the answer
        for stop_seq in stop_sequences:
            # An empty stop sequence can never be meaningfully stripped; skip it.
            if stop_seq and response.endswith(stop_seq):
                response = response[: -len(stop_seq)]
        return response