-
Notifications
You must be signed in to change notification settings - Fork 532
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add generator pipeline for text generation, closes #416
- Loading branch information
1 parent
27be4fa
commit 0bfb596
Showing
7 changed files
with
165 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Generator | ||
|
||
![pipeline](../../images/pipeline.png#only-light) | ||
![pipeline](../../images/pipeline-dark.png#only-dark) | ||
|
||
The Generator pipeline takes an input prompt and generates follow-on text. | ||
|
||
## Example | ||
|
||
The following shows a simple example using this pipeline. | ||
|
||
```python | ||
from txtai.pipeline import Generator | ||
|
||
# Create and run pipeline | ||
generator = Generator() | ||
generator("Hello, how are you?") | ||
``` | ||
|
||
## Configuration-driven example | ||
|
||
Pipelines are run with Python or configuration. Pipelines can be instantiated in [configuration](../../../api/configuration/#pipeline) using the lower case name of the pipeline. Configuration-driven pipelines are run with [workflows](../../../workflow/#configuration-driven-example) or the [API](../../../api#local-instance). | ||
|
||
### config.yml | ||
```yaml | ||
# Create pipeline using lower case class name | ||
generator: | ||
|
||
# Run pipeline with workflow | ||
workflow: | ||
generator: | ||
tasks: | ||
- action: generator | ||
``` | ||
|
||
### Run with Workflows | ||
|
||
```python | ||
from txtai.app import Application | ||
|
||
# Create and run pipeline with workflow | ||
app = Application("config.yml") | ||
list(app.workflow("generator", ["Hello, how are you?"])) | ||
``` | ||
|
||
### Run with API | ||
|
||
```bash | ||
CONFIG=config.yml uvicorn "txtai.api:app" & | ||
|
||
curl \ | ||
-X POST "http://localhost:8000/workflow" \ | ||
-H "Content-Type: application/json" \ | ||
-d '{"name":"generator", "elements": ["Hello, how are you?"]}' | ||
``` | ||
|
||
## Methods | ||
|
||
Python documentation for the pipeline. | ||
|
||
### ::: txtai.pipeline.Generator.__init__ | ||
### ::: txtai.pipeline.Generator.__call__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
""" | ||
Generator module | ||
""" | ||
|
||
from ..hfpipeline import HFPipeline | ||
|
||
|
||
class Generator(HFPipeline): | ||
""" | ||
Generate text with a causal language model. | ||
""" | ||
|
||
def __init__(self, path=None, quantize=False, gpu=True, model=None): | ||
super().__init__(self.task(), path, quantize, gpu, model) | ||
|
||
def __call__(self, text, prefix=None, maxlength=512, workers=0): | ||
""" | ||
Generates text using input text | ||
Args: | ||
text: text|list | ||
prefix: optional prefix to prepend to text elements | ||
maxlength: maximum sequence length | ||
workers: number of concurrent workers to use for processing data, defaults to None | ||
Returns: | ||
generated text | ||
""" | ||
|
||
# List of texts | ||
texts = text if isinstance(text, list) else [text] | ||
|
||
# Add prefix, if necessary | ||
if prefix: | ||
texts = [f"{prefix}{x}" for x in texts] | ||
|
||
# Run pipeline | ||
results = self.pipeline(texts, max_length=maxlength, num_workers=workers) | ||
|
||
# Get generated text | ||
results = [self.clean(x) for x in results] | ||
|
||
return results[0] if isinstance(text, str) else results | ||
|
||
def clean(self, result): | ||
""" | ||
Applies a series of rules to clean generated text. | ||
Args: | ||
result: input result | ||
Returns: | ||
clean text | ||
""" | ||
|
||
# Extract output from list, if necessary | ||
result = result[0] if isinstance(result, list) else result | ||
|
||
# Get generated text field | ||
text = result["generated_text"] | ||
|
||
return text.replace("$=", "<=") | ||
|
||
def task(self): | ||
""" | ||
Get the pipeline task name. | ||
Returns: | ||
pipeline task name | ||
""" | ||
|
||
return "text-generation" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
""" | ||
Generator module tests | ||
""" | ||
|
||
import unittest | ||
|
||
from txtai.pipeline import Generator | ||
|
||
|
||
class TestGenerator(unittest.TestCase): | ||
""" | ||
Sequences tests. | ||
""" | ||
|
||
def testGeneration(self): | ||
""" | ||
Test text pipeline generation | ||
""" | ||
|
||
model = Generator("hf-internal-testing/tiny-random-gpt2") | ||
start = "Hello, how are" | ||
|
||
# Test that text is generator | ||
self.assertGreater(len(model(start)), len(start)) |