Add support for chat messages in LLM pipeline, closes #718
davidmezzetti committed May 18, 2024
1 parent 498eb7c commit 8bd4d78
Showing 7 changed files with 64 additions and 26 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -59,7 +59,7 @@
 
 extras["pipeline-image"] = ["imagehash>=4.2.1", "pillow>=7.1.2", "timm>=0.4.12"]
 
-extras["pipeline-llm"] = ["litellm>=1.31.2", "llama-cpp-python>=0.2.20"]
+extras["pipeline-llm"] = ["litellm>=1.37.16", "llama-cpp-python>=0.2.75"]
 
 extras["pipeline-text"] = ["fasttext>=0.9.2", "sentencepiece>=0.1.91"]
14 changes: 9 additions & 5 deletions src/python/txtai/pipeline/llm/generation.py
@@ -26,7 +26,10 @@ def __init__(self, path=None, template=None, **kwargs):
 
     def __call__(self, text, maxlength, **kwargs):
         """
-        Generates text using input text
+        Generates text. Supports the following input formats:
+
+          - String or list of strings
+          - List of dictionaries with `role` and `content` key-values or lists of lists
 
         Args:
             text: text|list
@@ -37,8 +40,8 @@ def __call__(self, text, maxlength, **kwargs):
             generated text
         """
 
-        # List of texts
-        texts = text if isinstance(text, list) else [text]
+        # Format inputs
+        texts = [text] if isinstance(text, str) or isinstance(text[0], dict) else text
 
         # Apply template, if necessary
         if self.template:
@@ -51,7 +54,8 @@ def __call__(self, text, maxlength, **kwargs):
         # Clean generated text
         results = [self.clean(texts[x], result) for x, result in enumerate(results)]
 
-        return results[0] if isinstance(text, str) else results
+        # Extract results based on inputs
+        return results[0] if isinstance(text, str) or isinstance(text[0], dict) else results
 
     def execute(self, texts, maxlength, **kwargs):
         """
@@ -78,7 +82,7 @@ def clean(self, prompt, result):
         """
 
         # Replace input prompt
-        text = result.replace(prompt, "")
+        text = result.replace(prompt, "") if isinstance(prompt, str) else result
 
         # Apply text cleaning rules
         return text.replace("$=", "<=").strip()
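
The inline conditionals above are dense. Here is a minimal standalone sketch of the same normalization; the name format_inputs is illustrative and not part of the diff:

# Sketch of the input handling in Generation.__call__: a string or a
# single conversation (list of role/content dicts) is wrapped in an
# outer list; batches pass through unchanged.
def format_inputs(text):
    return [text] if isinstance(text, str) or isinstance(text[0], dict) else text

assert format_inputs("2+2?") == ["2+2?"]
assert format_inputs(["a", "b"]) == ["a", "b"]

chat = [{"role": "user", "content": "2+2?"}]
assert format_inputs(chat) == [chat]    # one conversation becomes a batch of one
assert format_inputs([chat]) == [chat]  # already a batch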
8 changes: 6 additions & 2 deletions src/python/txtai/pipeline/llm/huggingface.py
@@ -39,7 +39,10 @@ def __init__(self, path=None, quantize=False, gpu=True, model=None, task=None, *
 
     def __call__(self, text, prefix=None, maxlength=512, workers=0, **kwargs):
         """
-        Generates text using input text
+        Generates text. Supports the following input formats:
+
+          - String or list of strings
+          - List of dictionaries with `role` and `content` key-values or lists of lists
 
         Args:
             text: text|list
@@ -80,7 +83,8 @@ def extract(self, result):
 
         # Extract output from list, if necessary
         result = result[0] if isinstance(result, list) else result
-        return result["generated_text"]
+        text = result["generated_text"]
+        return text[-1]["content"] if isinstance(text, list) else text
 
     def task(self, path, task, **kwargs):
         """
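For context, a transformers text-generation pipeline returns the reply as a plain string, while chat-style inputs return the whole conversation under generated_text with the reply last. A standalone sketch of what the updated extract handles; the sample outputs are typical shapes, not captured from this diff:

# Plain completion: generated_text is a string
plain = [{"generated_text": "4"}]

# Chat completion: generated_text is the conversation, reply last
chat = [{"generated_text": [{"role": "user", "content": "2+2?"},
                            {"role": "assistant", "content": "4"}]}]

def extract(result):
    result = result[0] if isinstance(result, list) else result
    text = result["generated_text"]
    return text[-1]["content"] if isinstance(text, list) else text

assert extract(plain) == "4"
assert extract(chat) == "4"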
20 changes: 6 additions & 14 deletions src/python/txtai/pipeline/llm/litellm.py
@@ -51,23 +51,15 @@ def __init__(self, path, template=None, **kwargs):
         if not LITELLM:
             raise ImportError('LiteLLM is not available - install "pipeline" extra to enable')
 
-        # Register prompt template
-        self.register(path)
-
     def execute(self, texts, maxlength, **kwargs):
         results = []
         for text in texts:
-            result = api.completion(model=self.path, messages=[{"content": text, "role": "user"}], max_tokens=maxlength, **{**kwargs, **self.kwargs})
+            result = api.completion(
+                model=self.path,
+                messages=[{"content": text, "role": "prompt"}] if isinstance(text, str) else text,
+                max_tokens=maxlength,
+                **{**kwargs, **self.kwargs}
+            )
             results.append(result["choices"][0]["message"]["content"])
 
         return results
-
-    def register(self, path):
-        """
-        Registers a custom prompt template for path.
-
-        Args:
-            path: model path
-        """
-
-        api.register_prompt_template(model=path, roles={"user": {"pre_message": "", "post_message": ""}})
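In effect, execute now builds one of two payloads for litellm.completion instead of registering a blanket prompt template. A hedged sketch of the two payload shapes: the model name is illustrative, real calls need provider credentials, and whether a given provider accepts the "prompt" role is provider-dependent:

import litellm

# String prompt: wrapped in a single "prompt"-role message, deferring
# template handling to LiteLLM rather than the removed register() override
litellm.completion(model="gpt-4o", max_tokens=10,
                   messages=[{"content": "2+2?", "role": "prompt"}])

# Chat input: forwarded to the provider unchanged
litellm.completion(model="gpt-4o", max_tokens=10,
                   messages=[{"role": "user", "content": "2+2?"}])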
33 changes: 31 additions & 2 deletions src/python/txtai/pipeline/llm/llama.py
@@ -51,8 +51,7 @@ def __init__(self, path, template=None, **kwargs):
     def execute(self, texts, maxlength, **kwargs):
         results = []
         for text in texts:
-            result = self.llm(text, max_tokens=maxlength, **kwargs)
-            results.append(result["choices"][0]["text"])
+            results.append(self.messages(text, maxlength, **kwargs) if isinstance(text, list) else self.prompt(text, maxlength, **kwargs))
 
         return results
 
@@ -75,3 +74,33 @@ def download(self, path):
 
         # Download and cache file
         return hf_hub_download(repo_id="/".join(parts[:repo]), filename="/".join(parts[repo:]))
+
+    def messages(self, messages, maxlength, **kwargs):
+        """
+        Processes a list of messages.
+
+        Args:
+            messages: list of dictionaries with `role` and `content` key-values
+            maxlength: maximum sequence length
+            kwargs: additional generation keyword arguments
+
+        Returns:
+            generated text
+        """
+
+        return self.llm.create_chat_completion(messages=messages, max_tokens=maxlength, **kwargs)["choices"][0]["message"]["content"]
+
+    def prompt(self, text, maxlength, **kwargs):
+        """
+        Processes a prompt.
+
+        Args:
+            text: prompt text
+            maxlength: maximum sequence length
+            kwargs: additional generation keyword arguments
+
+        Returns:
+            generated text
+        """
+
+        return self.llm(text, max_tokens=maxlength, **kwargs)["choices"][0]["text"]
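The dispatch above maps directly onto the two llama-cpp-python entry points. A hedged sketch with an illustrative GGUF path:

from llama_cpp import Llama

llm = Llama(model_path="model.Q2_K.gguf", chat_format="chatml", verbose=False)

# prompt() path: raw completion, reply under choices[0]["text"]
text = llm("2 + 2 = ", max_tokens=10)["choices"][0]["text"]

# messages() path: chat completion, reply under choices[0]["message"]["content"]
reply = llm.create_chat_completion(
    messages=[{"role": "user", "content": "2+2?"}], max_tokens=10
)["choices"][0]["message"]["content"]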
5 changes: 4 additions & 1 deletion src/python/txtai/pipeline/llm/llm.py
@@ -35,7 +35,10 @@ def __init__(self, path=None, method=None, **kwargs):
 
     def __call__(self, text, maxlength=512, **kwargs):
         """
-        Generates text using input text
+        Generates text. Supports the following input formats:
+
+          - String or list of strings
+          - List of dictionaries with `role` and `content` key-values or lists of lists
 
         Args:
             text: text|list
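Taken together, this is the pipeline-level contract after the commit. A hedged usage sketch, reusing the GGUF model from the test below; maxlength values are illustrative:

from txtai.pipeline import LLM

llm = LLM("TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
          chat_format="chatml")

# String in, string out
llm("2 + 2 = ", maxlength=10)

# Messages in, assistant reply out
llm([{"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "2+2?"}], maxlength=10)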
8 changes: 7 additions & 1 deletion test/python/testpipeline/testllama.py
@@ -18,5 +18,11 @@ def testGeneration(self):
         """
 
         # Test model generation with llama.cpp
-        model = LLM("TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf")
+        model = LLM("TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf", chat_format="chatml")
+
+        # Test with prompt
         self.assertEqual(model("2 + 2 = ", maxlength=10, seed=0, stop=["."]), "4")
+
+        # Test with list of messages
+        messages = [{"role": "system", "content": "You are a helpful assistant. You answer math problems."}, {"role": "user", "content": "2+2?"}]
+        self.assertEqual(model(messages, maxlength=10, seed=0, stop=["."]), "4")
