In this notebook we will load and use a generative language model that can produce a continuation for a given text. Learn more about the Text Generation task <a href="https://huggingface.co/tasks/text-generation" target="_blank" rel="noopener">here</a>.

We will be using a generic prediction UDF script. To execute queries and load data from the Exasol database, we will use the <a href="https://github.com/exasol/pyexasol" target="_blank" rel="noopener">`pyexasol`</a> module.

Before using this notebook, complete the following steps:
1. [Create the database schema](../setup_db.ipynb).
2. [Initialize the Transformer Extension](te_init.ipynb).

In [34]:
# TODO: Move this to a separate configuration notebook. Here we just need to load this configuration from a store.
from dataclasses import dataclass

@dataclass
class SandboxConfig:
    EXTERNAL_HOST_NAME = "192.168.124.93"
    HOST_PORT = "8888"

    @property
    def EXTERNAL_HOST(self):
        return f"""{self.EXTERNAL_HOST_NAME}:{self.HOST_PORT}"""

    USER = "sys"
    PASSWORD = "exasol"
    BUCKETFS_PORT = "6666"
    BUCKETFS_USER = "w"
    BUCKETFS_PASSWORD = "write"
    BUCKETFS_USE_HTTPS = False
    BUCKETFS_SERVICE = "bfsdefault"
    BUCKETFS_BUCKET = "default"

    @property
    def EXTERNAL_BUCKETFS_HOST(self):
        return f"""{self.EXTERNAL_HOST_NAME}:{self.BUCKETFS_PORT}"""

    @property
    def BUCKETFS_URL_PREFIX(self):
        return "https://" if self.BUCKETFS_USE_HTTPS else "http://"

    @property
    def BUCKETFS_PATH(self):
        # Filesystem-Path to the read-only mounted BucketFS inside the running UDF Container
        return f"/buckets/{self.BUCKETFS_SERVICE}/{self.BUCKETFS_BUCKET}"

    SCRIPT_LANGUAGE_NAME = "PYTHON3_60"
    UDF_FLAVOR = "python3-ds-EXASOL-6.0.0"
    UDF_RELEASE = "20190116"
    UDF_CLIENT = "exaudfclient" # or for newer versions of the flavor exaudfclient_py3
    SCHEMA = "IDA"

    @property
    def SCRIPT_LANGUAGES(self):
        # Concatenate the pieces so the resulting value contains no line breaks or indentation.
        return (
            f"{self.SCRIPT_LANGUAGE_NAME}=localzmq+protobuf:///{self.BUCKETFS_SERVICE}/"
            f"{self.BUCKETFS_BUCKET}/{self.UDF_FLAVOR}?lang=python#buckets/{self.BUCKETFS_SERVICE}/"
            f"{self.BUCKETFS_BUCKET}/{self.UDF_FLAVOR}/exaudf/{self.UDF_CLIENT}"
        )

    @property
    def connection_params(self):
        return {"dsn": self.EXTERNAL_HOST, "user": self.USER, "password": self.PASSWORD, "compression": True}

    @property
    def params(self):
        return {
            "script_languages": self.SCRIPT_LANGUAGES,
            "script_language_name": self.SCRIPT_LANGUAGE_NAME,
            "schema": self.SCHEMA,
            "BUCKETFS_PORT": self.BUCKETFS_PORT,
            "BUCKETFS_USER": self.BUCKETFS_USER,
            "BUCKETFS_PASSWORD": self.BUCKETFS_PASSWORD,
            "BUCKETFS_USE_HTTPS": self.BUCKETFS_USE_HTTPS,
            "BUCKETFS_BUCKET": self.BUCKETFS_BUCKET,
            "BUCKETFS_PATH": self.BUCKETFS_PATH
        }

    # Name of the BucketFS connection
    BFS_CONN = 'MyBFSConn'

    # Name of a sub-directory of the bucket root
    BFS_DIR = 'my_storage'

    # We will store all models in this sub-directory in BucketFS
    TE_MODELS_DIR = 'models'

    # Cached models are saved in this sub-directory, relative to the current directory on the local machine.
    TE_MODELS_CACHE_DIR = 'models_cache'

conf = SandboxConfig()
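
To illustrate how the derived settings above fit together, here is a minimal sketch that assembles the HTTP address of a path in BucketFS from the same pieces as `SandboxConfig`. The helper name `bucketfs_file_url` is hypothetical, and the `<prefix><host>:<port>/<bucket>/<path>` layout is an assumption about the usual BucketFS address scheme rather than something this notebook states.

```python
def bucketfs_file_url(host: str, port: str, bucket: str, path: str,
                      use_https: bool = False) -> str:
    """Build a BucketFS address from its parts (assumed layout, see the note above)."""
    # Same rule as SandboxConfig.BUCKETFS_URL_PREFIX
    prefix = "https://" if use_https else "http://"
    # Drop a leading slash so the result never contains "//" after the bucket name
    return f"{prefix}{host}:{port}/{bucket}/{path.lstrip('/')}"


# Values copied from the sandbox configuration in this notebook
print(bucketfs_file_url("192.168.124.93", "6666", "default", "my_storage/models"))
```

Keeping the address logic in one function makes it easier to switch between HTTP and HTTPS without touching every call site.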

First we need to download a model from the Huggingface Hub and put it into BucketFS.

There are two ways of doing this.
1. Using the `TE_MODEL_DOWNLOADER_UDF` UDF.
2. Downloading a model to a local drive and subsequently uploading it into BucketFS using a CLI.

In this notebook we will use the second method.

To demonstrate the text generation task we will use [Open Pretrained Transformers (OPT)](https://huggingface.co/facebook/opt-125m), a decoder-only pre-trained transformer from Facebook.

This is a public model, therefore the last parameter - the name of the Huggingface token connection - can be an empty string.

In [36]:
# This is the name of the model at the Huggingface Hub
MODEL_NAME = 'facebook/opt-125m'

In [None]:
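# OPT is a decoder-only (causal) language model, so it is loaded with AutoModelForCausalLM.
from transformers import AutoTokenizer, AutoModelForCausalLM

# Both calls download the files into the local cache directory if they are not already there.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=conf.TE_MODELS_CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=conf.TE_MODELS_CACHE_DIR)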

Now we can upload the model into BucketFS from the command line. Unfortunately, we cannot tell exactly when this process has finished: the notebook's hourglass is not a reliable indicator, because BucketFS may still be working after the call issued by the notebook returns. Please wait a few moments after that before querying the model.

In [23]:
upload_command = f"""python -m exasol_transformers_extension.upload_model \
    --bucketfs-name {conf.BUCKETFS_SERVICE} \
    --bucketfs-host {conf.EXTERNAL_HOST_NAME} \
    --bucketfs-port {conf.BUCKETFS_PORT} \
    --bucketfs-user {conf.BUCKETFS_USER} \
    --bucketfs-password {conf.BUCKETFS_PASSWORD} \
    --bucket {conf.BUCKETFS_BUCKET} \
    --path-in-bucket {conf.BFS_DIR} \
    --model-name {MODEL_NAME}  \
    --sub-dir {conf.TE_MODELS_DIR} \
    --local-model-path {conf.TE_MODELS_CACHE_DIR}
    """
!{upload_command}
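
Because the upload call can return before BucketFS has finished its work, one alternative to waiting by hand is to poll a readiness check with a timeout. The sketch below is generic: `wait_until` and the `is_ready` callback are hypothetical names, and what `is_ready` actually checks (for example, whether a test query against the model succeeds) is left to the reader.

```python
import time
from typing import Callable


def wait_until(is_ready: Callable[[], bool],
               timeout_s: float = 60.0,
               interval_s: float = 2.0) -> bool:
    """Call is_ready() repeatedly until it returns True or timeout_s has elapsed.

    Returns True if the check succeeded in time, False otherwise.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        if is_ready():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

The check always runs at least once, so a timeout of zero still reports whether the resource is ready right now.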

Let's put the start of our text in a variable.

In [1]:
MY_TEXT = 'The bar-headed goose can fly at much'

# Escape single quotes so the text can be embedded in an SQL string literal.
MY_TEXT = MY_TEXT.replace("'", "''")
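
The escaping step above can also be wrapped in a small helper, which keeps the quoting rule in one place. This is only a sketch: `sql_string_literal` is a hypothetical name, and it covers just the single-quote doubling used in this notebook, not general protection against SQL injection.

```python
def sql_string_literal(text: str) -> str:
    """Return text as a single-quoted SQL string literal with embedded quotes doubled."""
    return "'" + text.replace("'", "''") + "'"


print(sql_string_literal("The goose's wings"))
```

Note that the helper should be applied to raw text exactly once; applying it to text that is already escaped would double the quotes again.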

In [3]:
# Let's put a limit on the length of text the model can generate in one call.
# The limit is passed to the model as max_length, which Hugging Face pipelines
# measure in tokens rather than characters.
MAX_LENGTH = 30

We will be updating the `MY_TEXT` variable on every call to the model.
Run the next cell multiple times to see how the text evolves.

In [None]:
import pyexasol

sql = f"""
SELECT {conf.SCHEMA}.TE_TEXT_GENERATION_UDF(
    NULL,
    '{conf.BFS_CONN}',
    NULL,
    '{conf.TE_MODELS_DIR}',
    '{MODEL_NAME}',
    '{MY_TEXT}',
    {MAX_LENGTH},
    True
)
"""

with pyexasol.connect(dsn=conf.EXTERNAL_HOST, user=conf.USER, password=conf.PASSWORD, compression=True) as conn:
    result = conn.export_to_pandas(query_or_table=sql, query_params=conf.params).squeeze()
    MY_TEXT = result['GENERATED_TEXT']
    # If the call fails, the error can be inspected in result['ERROR_MESSAGE']

print(MY_TEXT)
MY_TEXT = MY_TEXT.replace("'", "''")
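
Re-running the cell by hand amounts to feeding each output back in as the next input. The loop below sketches that pattern without a database: `extend_text` is a hypothetical helper, and `fake_generate` is a stand-in for the UDF call, assumed (like the query above) to return the full text rather than only the newly generated part.

```python
from typing import Callable, List


def extend_text(generate: Callable[[str], str], start: str, rounds: int) -> List[str]:
    """Feed each generated text back into generate() and collect every intermediate result."""
    history = [start]
    for _ in range(rounds):
        history.append(generate(history[-1]))
    return history


def fake_generate(text: str) -> str:
    # Stand-in for the model call: appends a fixed continuation to the input
    return text + " more"


for step, text in enumerate(extend_text(fake_generate, "The bar-headed goose can fly at much", 2)):
    print(step, text)
```

To use it with the real model, `fake_generate` would be replaced by a function that escapes the text, runs the query, and returns `result['GENERATED_TEXT']`.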