Add generator pipeline for text generation, closes #416
davidmezzetti committed Jan 31, 2023
1 parent 27be4fa commit 0bfb596
Showing 7 changed files with 165 additions and 46 deletions.
1 change: 1 addition & 0 deletions docs/pipeline/index.md
@@ -25,6 +25,7 @@ The following is a list of the default pipelines available in txtai.
- Text
- [Entity](text/entity)
- [Extractive QA](text/extractor)
- [Generator](text/generator)
- [Labeling](text/labels)
- [Sequences](text/sequences)
- [Similarity](text/similarity)
62 changes: 62 additions & 0 deletions docs/pipeline/text/generator.md
@@ -0,0 +1,62 @@
# Generator

![pipeline](../../images/pipeline.png#only-light)
![pipeline](../../images/pipeline-dark.png#only-dark)

The Generator pipeline takes an input prompt and generates follow-on text.

## Example

The following shows a simple example using this pipeline.

```python
from txtai.pipeline import Generator

# Create and run pipeline
generator = Generator()
generator("Hello, how are you?")
```
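
The constructor also accepts a Hugging Face model path, and `__call__` supports `prefix`, `maxlength` and `workers` arguments. A minimal sketch, assuming `gpt2` as an illustrative model path (any causal language model works):

```python
from txtai.pipeline import Generator

# Load a specific causal language model (gpt2 is only an example path)
generator = Generator("gpt2")

# String input returns a string; maxlength caps the generated sequence length
print(generator("Hello, how are you?", maxlength=64))

# List input returns a list; prefix is prepended to each element
print(generator(["Hello, how are you?", "The weather today is"], prefix="Continue this text: "))
```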

## Configuration-driven example

Pipelines are run with Python or configuration. Pipelines can be instantiated in [configuration](../../../api/configuration/#pipeline) using the lower case name of the pipeline. Configuration-driven pipelines are run with [workflows](../../../workflow/#configuration-driven-example) or the [API](../../../api#local-instance).

### config.yml
```yaml
# Create pipeline using lower case class name
generator:

# Run pipeline with workflow
workflow:
  generator:
    tasks:
      - action: generator
```
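
Constructor arguments can also be set in configuration. The variant below is a hedged sketch: the `path` key and `gpt2` value are assumptions for illustration, not part of this commit.

```yaml
# Assumed variant: keys under the pipeline name map to constructor arguments
generator:
  path: gpt2

workflow:
  generator:
    tasks:
      - action: generator
```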

### Run with Workflows

```python
from txtai.app import Application

# Create and run pipeline with workflow
app = Application("config.yml")
list(app.workflow("generator", ["Hello, how are you?"]))
```

### Run with API

```bash
CONFIG=config.yml uvicorn "txtai.api:app" &

curl \
-X POST "http://localhost:8000/workflow" \
-H "Content-Type: application/json" \
-d '{"name":"generator", "elements": ["Hello, how are you?"]}'
```
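
The same request can be sent from Python; a minimal sketch using the `requests` package (an assumption, any HTTP client works):

```python
import requests

# POST the workflow request shown in the curl example above
response = requests.post(
    "http://localhost:8000/workflow",
    json={"name": "generator", "elements": ["Hello, how are you?"]},
)

print(response.json())
```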

## Methods

Python documentation for the pipeline.

### ::: txtai.pipeline.Generator.__init__
### ::: txtai.pipeline.Generator.__call__
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -73,6 +73,7 @@ nav:
- Text:
- Entity: pipeline/text/entity.md
- Extractor: pipeline/text/extractor.md
- Generator: pipeline/text/generator.md
- Labels: pipeline/text/labels.md
- Sequences: pipeline/text/sequences.md
- Similarity: pipeline/text/similarity.md
1 change: 1 addition & 0 deletions src/python/txtai/pipeline/text/__init__.py
@@ -5,6 +5,7 @@
from .crossencoder import CrossEncoder
from .entity import Entity
from .extractor import Extractor
from .generator import Generator
from .labels import Labels
from .questions import Questions
from .sequences import Sequences
72 changes: 72 additions & 0 deletions src/python/txtai/pipeline/text/generator.py
@@ -0,0 +1,72 @@
"""
Generator module
"""

from ..hfpipeline import HFPipeline


class Generator(HFPipeline):
    """
    Generate text with a causal language model.
    """

    def __init__(self, path=None, quantize=False, gpu=True, model=None):
        super().__init__(self.task(), path, quantize, gpu, model)

    def __call__(self, text, prefix=None, maxlength=512, workers=0):
        """
        Generates text using input text.

        Args:
            text: text|list
            prefix: optional prefix to prepend to text elements
            maxlength: maximum sequence length
            workers: number of concurrent workers to use for processing data, defaults to 0

        Returns:
            generated text
        """

        # List of texts
        texts = text if isinstance(text, list) else [text]

        # Add prefix, if necessary
        if prefix:
            texts = [f"{prefix}{x}" for x in texts]

        # Run pipeline
        results = self.pipeline(texts, max_length=maxlength, num_workers=workers)

        # Get generated text
        results = [self.clean(x) for x in results]

        return results[0] if isinstance(text, str) else results

    def clean(self, result):
        """
        Applies a series of rules to clean generated text.

        Args:
            result: input result

        Returns:
            clean text
        """

        # Extract output from list, if necessary
        result = result[0] if isinstance(result, list) else result

        # Get generated text field
        text = result["generated_text"]

        return text.replace("$=", "<=")

    def task(self):
        """
        Get the pipeline task name.

        Returns:
            pipeline task name
        """

        return "text-generation"
50 changes: 4 additions & 46 deletions src/python/txtai/pipeline/text/sequences.py
@@ -2,55 +2,13 @@
Sequences module
"""

from ..hfpipeline import HFPipeline
from .generator import Generator


class Sequences(HFPipeline):
class Sequences(Generator):
    """
    Runs text through a sequence-sequence model.
    """

    def __init__(self, path=None, quantize=False, gpu=True, model=None):
        super().__init__("text2text-generation", path, quantize, gpu, model)

    def __call__(self, text, prefix=None, maxlength=512, workers=0):
        """
        Runs a sequence-sequence model for input texts.

        Args:
            text: text|list
            prefix: optional prefix to prepend to text elements
            maxlength: maximum sequence length
            workers: number of concurrent workers to use for processing data, defaults to None

        Returns:
            generated text
        """

        # List of texts
        texts = text if isinstance(text, list) else [text]

        # Add prefix, if necessary
        if prefix:
            texts = [f"{prefix}{x}" for x in texts]

        # Run text2text pipeline
        results = self.pipeline(texts, max_length=maxlength, num_workers=workers)

        # Get generated text
        results = [self.clean(x["generated_text"]) for x in results]

        return results[0] if isinstance(text, str) else results

    def clean(self, text):
        """
        Applies a series of rules to clean generated text.

        Args:
            text: input text

        Returns:
            clean text
        """

        return text.replace("$=", "<=")

    def task(self):
        return "text2text-generation"
24 changes: 24 additions & 0 deletions test/python/testpipeline/testgenerator.py
@@ -0,0 +1,24 @@
"""
Generator module tests
"""

import unittest

from txtai.pipeline import Generator


class TestGenerator(unittest.TestCase):
    """
    Generator tests.
    """

    def testGeneration(self):
        """
        Test text pipeline generation
        """

        model = Generator("hf-internal-testing/tiny-random-gpt2")
        start = "Hello, how are"

        # Test that text is generated
        self.assertGreater(len(model(start)), len(start))
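
One possible way to run this test locally; the discover invocation below is an assumption based on the repository's test layout, not part of the commit:

```bash
# Run the new Generator test from the repository root (assumed invocation)
python -m unittest discover -s test/python -p testgenerator.py -v
```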
