Commit
Add template task, closes #448
davidmezzetti committed Mar 2, 2023
1 parent c533be5 commit cb9e5b4
Showing 4 changed files with 136 additions and 4 deletions.
src/python/txtai/workflow/base.py (4 changes: 2 additions & 2 deletions)
@@ -20,7 +20,6 @@

# Logging configuration
logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)


class Workflow:
@@ -162,7 +161,8 @@ def process(self, elements, executor):
"""

# Run elements through each task
for task in self.tasks:
for x, task in enumerate(self.tasks):
logger.debug("Running Task #%d", x)
elements = task(elements, executor)

# Yield results processed by all tasks
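With the hard-coded setLevel call removed, log verbosity is now left to the calling application. A minimal sketch of how the new per-task debug messages could be surfaced, assuming a standard library logging setup on the application side (the logger name simply follows the module path above):

    import logging

    # Opt in to workflow tracing; the library no longer forces a log level
    logging.basicConfig(format="%(asctime)s %(name)s: %(message)s")
    logging.getLogger("txtai.workflow").setLevel(logging.DEBUG)

    # Each task run by Workflow.process then logs "Running Task #0", "Running Task #1", ...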
src/python/txtai/workflow/task/__init__.py (1 change: 1 addition & 0 deletions)
@@ -11,5 +11,6 @@
from .retrieve import RetrieveTask
from .service import ServiceTask
from .storage import StorageTask
+from .template import ExtractorTask, TemplateTask
from .url import UrlTask
from .workflow import WorkflowTask
src/python/txtai/workflow/task/base.py (14 changes: 12 additions & 2 deletions)
@@ -2,12 +2,16 @@
Task module
"""

+import logging
import re
import types

import numpy as np
import torch

+# Logging configuration
+logger = logging.getLogger(__name__)
+

class Task:
    """
@@ -207,7 +211,7 @@ def upack(self, element, force=False):
"""

# Extract data from (id, data, tag) formatted elements
if (self.unpack or force) and isinstance(element, tuple):
if (self.unpack or force) and isinstance(element, tuple) and len(element) > 1:
return element[1]

return element
@@ -225,7 +229,7 @@ def pack(self, element, data):
"""

# Pack data into (id, data, tag) formatted elements
if self.unpack and isinstance(element, tuple):
if self.unpack and isinstance(element, tuple) and len(element) > 1:
# If new data is a (id, data, tag) tuple use that except for multi-action "hstack" merges which produce tuples
if isinstance(data, tuple) and (len(self.action) <= 1 or self.merge != "hstack"):
return data
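The new len(element) > 1 check means one-element tuples are no longer mistaken for (id, data, tag) rows. An illustrative mirror of the guarded unpack logic, using a hypothetical standalone helper rather than the Task method itself:

    def upack_like(element, unpack=True, force=False):
        # Mirrors the guarded check: only tuples with at least two fields are unpacked
        if (unpack or force) and isinstance(element, tuple) and len(element) > 1:
            return element[1]
        return element

    print(upack_like((0, "txtai builds workflows", None)))  # -> "txtai builds workflows"
    print(upack_like(("single value",)))                    # -> ("single value",) passes through unchanged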
@@ -314,13 +318,19 @@ def process(self, action, inputs):
            action outputs
        """

+        # Log inputs
+        logger.debug("Inputs: %s", inputs)
+
        # Execute action and get outputs
        outputs = action(inputs)

        # Consume generator output, if necessary
        if isinstance(outputs, types.GeneratorType):
            outputs = list(outputs)

+        # Log outputs
+        logger.debug("Outputs: %s", outputs)
+
        return outputs

    def postprocess(self, outputs):
src/python/txtai/workflow/task/template.py (121 changes: 121 additions & 0 deletions)
@@ -0,0 +1,121 @@
"""
Template module
"""

from string import Formatter

from .file import Task


class TemplateTask(Task):
"""
Task that generates text from a template and task inputs. Templates can be used to prepare data for a number of tasks
including generating large languge model (LLM) prompts.
"""

    def register(self, template=None, rules=None, strict=True):
        """
        Read template parameters.

        Args:
            template: prompt template
            rules: parameter rules
            strict: requires all task inputs to be consumed by template, defaults to True
        """

        # pylint: disable=W0201
        # Template text
        self.template = template if template else self.defaulttemplate()

        # Template processing rules
        self.rules = rules if rules else self.defaultrules()

        # Create formatter
        self.formatter = TemplateFormatter() if strict else Formatter()

    def prepare(self, element):
        # Check if element matches any processing rules
        match = self.match(element)
        if match:
            return match

        # Apply template processing, if applicable
        if self.template:
            # Pass dictionary as named prompt template parameters
            if isinstance(element, dict):
                return self.formatter.format(self.template, **element)

            # Pass tuple as prompt template parameters (arg0 - argN)
            if isinstance(element, tuple):
                return self.formatter.format(self.template, **{f"arg{i}": x for i, x in enumerate(element)})

            # Default behavior is to use input as {text} parameter in prompt template
            return self.formatter.format(self.template, text=element)

        # Return original inputs when no prompt provided
        return element

    def defaulttemplate(self):
        """
        Generates a default template for this task. Base method returns None.

        Returns:
            default template
        """

        return None

    def defaultrules(self):
        """
        Generates default rules for this task. Base method returns an empty dictionary.

        Returns:
            default rules
        """

        return {}

    def match(self, element):
        """
        Check if element matches any processing rules.

        Args:
            element: input element

        Returns:
            matching value if found, None otherwise
        """

        if self.rules and isinstance(element, dict):
            # Check if any rules are matched
            for key, value in self.rules.items():
                if element[key] == value:
                    return element[key]

        return None


class ExtractorTask(TemplateTask):
    """
    Template task that prepares input for an extractor pipeline.
    """

    def prepare(self, element):
        # Allow partial input with the "question" field used to complete prompt template
        if isinstance(element, dict):
            element["question"] = super().prepare(element["question"])
            return element

        # Default mode is to use element text for both query and question
        return {"query": element, "question": super().prepare(element)}


class TemplateFormatter(Formatter):
    """
    Helper class that applies strict formatting checks to templates, raising an error when any input is not consumed.
    """

    def check_unused_args(self, used_args, args, kwargs):
        difference = set(kwargs).difference(used_args)
        if difference:
            raise KeyError(difference)
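A minimal sketch of how the new task could be exercised directly through the methods added in this file. It assumes the base Task constructor can be called without arguments and that register() may be invoked explicitly; in a real workflow these parameters would normally flow through the task constructor or workflow configuration:

    from txtai.workflow.task.template import TemplateTask

    task = TemplateTask()

    # Plain input fills the {text} parameter
    task.register(template="Answer the following question.\n{text}")
    print(task.prepare("What is the speed of light?"))

    # Dictionary input maps keys to named template parameters
    task.register(template="Translate '{text}' to {language}")
    print(task.prepare({"text": "hello", "language": "French"}))

    # Tuple input maps positions to arg0..argN
    task.register(template="{arg0} is {arg1}")
    print(task.prepare(("txtai", "an embeddings database")))

    # With the default strict=True, inputs not consumed by the template raise a KeyError
    task.register(template="{text}")
    try:
        task.prepare({"text": "hello", "extra": "unused"})
    except KeyError as error:
        print("Unused template arguments:", error)

In a workflow, the prepared text would then feed the task action (for example, an LLM prompt), while ExtractorTask wraps each element into the query/question fields expected by an extractor pipeline.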
