# Fine-tune and conditionally generate text with an LLM

This notebook walks you through fine-tuning a cutting-edge open-source LLM (MosaicML mpt-7b or Llama2) with Gretel's API service, and then using the fine-tuned LLM to generate additional text examples that match a desired class label.

To run this notebook, you will need an API key from the Gretel console at https://console.gretel.cloud. Running the entire notebook should take about 20 minutes for fine-tuning and generation.

In [None]:
!pip install -Uqq gretel-client

To get started with your project, you'll need to set up the following parameters:

* `DATASET_PATH`: Specify the path or URL of the CSV dataset you want to use for fine-tuning.
* `LLM`: Choose the large language model (LLM) to fine-tune. This must be one of the supported models listed at https://docs.gretel.ai/reference/synthetics/models/gretel-gpt.
* `GRETEL_PROJECT`: Define the name of your Gretel project where you'll store the trained model and its results. This should be a unique and descriptive name.
* `TEXT_COLUMN`: Specify the name of the column in your training dataset that contains the text data you want to use for training the model.
* `LABEL_COLUMN`: Identify the corresponding column in your training dataset that contains the class labels or categories for your data.
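
A mismatch between `TEXT_COLUMN`/`LABEL_COLUMN` and the actual CSV headers only surfaces later as a `KeyError`. The sketch below shows one way to check the columns up front; `check_required_columns` is a hypothetical helper, not part of `gretel-client`, and the toy DataFrame stands in for your real dataset.

```python
import pandas as pd


def check_required_columns(df: pd.DataFrame, text_column: str, label_column: str) -> None:
    """Raise ValueError if the text or label column is missing (hypothetical helper)."""
    missing = [col for col in (text_column, label_column) if col not in df.columns]
    if missing:
        raise ValueError(f"Dataset is missing required column(s): {missing}")


# Toy data standing in for the real training dataset
toy_df = pd.DataFrame({"text": ["Where is my card?"], "intent": ["card arrival"]})

# Passes silently when both columns are present
check_required_columns(toy_df, text_column="text", label_column="intent")
```

In the notebook you would call it as `check_required_columns(pd.read_csv(DATASET_PATH), TEXT_COLUMN, LABEL_COLUMN)` before submitting any cloud job.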


In [None]:
import pandas as pd
from gretel_client import configure_session
from gretel_client.helpers import poll
from gretel_client.projects import create_or_get_unique_project
from gretel_client.projects.models import read_model_config, Model


DATASET_PATH = 'https://gretel-public-website.s3.us-west-2.amazonaws.com/datasets/banking77.csv'  # @param {type:"string"}
LLM = "gretelai/mpt-7b"  # @param {type:"string"}
GRETEL_PROJECT = 'banking77'  # @param {type:"string"}
TEXT_COLUMN = "text"  # @param {type:"string"}
LABEL_COLUMN = "intent"  # @param {type:"string"}

In [None]:
# Log into Gretel and configure project

configure_session(api_key="prompt", cache="yes", endpoint="https://api.gretel.cloud", validate=True, clear=True)

project = create_or_get_unique_project(name=GRETEL_PROJECT)

## Load and preview the training dataset
For fine-tuning the LLM, we need to combine the class labels and the text into a single column that we'll add to the dataset. We'll use `,` as a separator.
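
To make the combined format concrete, here is a small sketch on toy rows (illustrative values, not taken from the real dataset). Commas inside the text are harmless as long as the labels themselves never contain the separator, because the label can always be recovered by splitting on the first separator only.

```python
import pandas as pd

SEPARATOR = ","

# Toy rows in the banking77 style; values are illustrative only
toy_df = pd.DataFrame({
    "intent": ["card arrival", "card linking"],
    "text": ["When will my card arrive?", "How do I link my card, please?"],
})

# astype(str) guards against non-string labels such as integer class IDs
toy_df["label_and_text"] = toy_df["intent"].astype(str) + SEPARATOR + toy_df["text"]
print(toy_df["label_and_text"].tolist())
# ['card arrival,When will my card arrive?', 'card linking,How do I link my card, please?']

# Splitting on the first separator only recovers the label and the full text
label, text = toy_df["label_and_text"][1].split(SEPARATOR, 1)
```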


In [None]:
from typing import Optional

import pandas as pd

LABEL_AND_TEXT_COLUMN = 'label_and_text'
SEPARATOR = ','

def create_finetune_dataset(dataset_path: str) -> Optional[pd.DataFrame]:
    """
    Create a dataset for fine-tuning a language model by combining class labels and text.

    Args:
        dataset_path (str): The path or URL of the input dataset in CSV format.

    Returns:
        Optional[pd.DataFrame]: The dataset augmented with a combined label_and_text column,
            or None if the file could not be found.
    """
    try:
        df = pd.read_csv(dataset_path)
    except FileNotFoundError:
        print(f"Error: File not found at '{dataset_path}'")
        return None

    # Cast labels to str so non-string labels (e.g. integer class IDs) can be concatenated
    df[LABEL_AND_TEXT_COLUMN] = df[LABEL_COLUMN].astype(str) + SEPARATOR + df[TEXT_COLUMN]
    return df


# Create the fine-tuning dataset
df = create_finetune_dataset(DATASET_PATH)
df


## Train the synthetic model
In this step, we task a worker running in the Gretel cloud with fine-tuning the selected LLM on the prepared dataset.

In [None]:
%%time

# Download and update base config
config = read_model_config("synthetics/natural-language")
config['models'][0]['gpt_x']['pretrained_model'] = LLM
config['models'][0]['gpt_x']['column_name'] = LABEL_AND_TEXT_COLUMN

# Create and submit model
model = project.create_model_obj(model_config=config, data_source=df)
print(f"Follow along with training in the console: {project.get_console_url()}")
model.name = f"{GRETEL_PROJECT}-{LLM}"
model.submit_cloud()

poll(model, verbose=False)

## Create prompts
Because we fine-tuned the model on examples of the form `<label>,<text>`, we can generate new synthetic text examples for a desired class label by prompting the text-completion model with `<label>,`.
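
The idea is that the prompt is exactly the training format with the text removed, so it is a prefix of every training record for that label. The sketch below makes that relationship explicit; `format_training_record` and `format_prompt` are hypothetical helpers written for illustration.

```python
SEPARATOR = ","


def format_training_record(label: str, text: str) -> str:
    """Format used for fine-tuning: '<label>,<text>'."""
    return f"{label}{SEPARATOR}{text}"


def format_prompt(label: str) -> str:
    """Format used at generation time: '<label>,'."""
    return f"{label}{SEPARATOR}"


record = format_training_record("card arrival", "When will my card arrive?")
prompt = format_prompt("card arrival")

# The prompt is a prefix of every training record with that label, so the
# model's most likely continuation is text belonging to that class.
assert record.startswith(prompt)
print(record[len(prompt):])  # When will my card arrive?
```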

The prompt dataset should have a single column with one prompt record for each synthetic text record we want in the output.


In [None]:
PROMPT_LABEL = "card arrival"  # @param {type:"string"}
NUM_RECORDS = 25  # @param {type:"number"}


In [None]:

pd.set_option('max_colwidth', 300)

def create_prompt_df(prompt_label: str, num_records: int = 25) -> pd.DataFrame:
    """
    Create a prompt DataFrame with the given number of rows, each containing a prompt.

    Args:
        prompt_label (str): The class label to use in the prompt.
        num_records (int): The number of records to generate in the prompt DataFrame.
            The generated synthetic data will have the same number of records.

    Returns:
        pd.DataFrame: A DataFrame with the given number of rows, each containing a class
            label prompt.
    """
    # Note: the column name in this dataframe doesn't matter, as it may only contain a single
    # column anyway.
    # The column name in the generated synthetic data will be taken from the training dataset
    # instead.
    return pd.DataFrame([prompt_label + SEPARATOR] * num_records, columns=["prompt"])


print("Text completion prompts with class labels")
prompt_df = create_prompt_df(PROMPT_LABEL, num_records=NUM_RECORDS)
prompt_df
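
`create_prompt_df` covers a single label. If you want several labels in one generation job, a variant could build the same single-column layout from a mapping of label to record count. This is a sketch: `create_multi_label_prompt_df` is a hypothetical helper, and the labels below are illustrative, so check the actual values in `LABEL_COLUMN` before using it.

```python
from typing import Dict

import pandas as pd

SEPARATOR = ","


def create_multi_label_prompt_df(records_per_label: Dict[str, int]) -> pd.DataFrame:
    """Build a single-column prompt DataFrame covering several labels (hypothetical helper).

    records_per_label maps each class label to the number of synthetic records wanted.
    """
    prompts = []
    for label, count in records_per_label.items():
        # One '<label>,' prompt per requested synthetic record
        prompts.extend([label + SEPARATOR] * count)
    return pd.DataFrame({"prompt": prompts})


multi_prompt_df = create_multi_label_prompt_df({"card arrival": 2, "lost or stolen card": 1})
print(multi_prompt_df)
```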



## Create synthetic data

Next, we prompt the fine-tuned model with the prompt dataset to generate new synthetic text examples for the given class label.



In [None]:
import pandas as pd

def generate_synthetic_data(model: Model, prompt_df: pd.DataFrame) -> pd.DataFrame:
    """
    Generate synthetic data by prompting the fine-tuned model.

    Args:
        model: The fine-tuned LLM used for generating synthetic data.
        prompt_df: A single-column dataframe containing the prompts.

    Returns:
        df: A dataframe containing the synthetic data generated by the model.
    """

    # Create a record handler that runs generation on the prompt dataset
    response_handler = model.create_record_handler_obj(
        params={"maximum_text_length": 50, "temperature": 0.7},
        data_source=prompt_df
    )
    response_handler.submit_cloud()
    poll(response_handler, verbose=False)

    # Read the response into a dataframe
    df = pd.read_csv(response_handler.get_artifact_link("data"), compression='gzip')

    return df

synthetic_data = generate_synthetic_data(model, prompt_df)
synthetic_data
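
Since each generated record should follow the `<label>,<text>` format, you may want to split it back into separate label and text columns and drop malformed rows. The sketch below assumes the output column is named `label_and_text` (matching the training column) and that it contains the label prefix; inspect `synthetic_data.columns` and a few rows first and adjust if your output differs. `split_label_and_text` is a hypothetical helper, and the toy rows are made up.

```python
import pandas as pd

SEPARATOR = ","


def split_label_and_text(df: pd.DataFrame, column: str = "label_and_text") -> pd.DataFrame:
    """Split '<label>,<text>' strings into separate label and text columns.

    str.partition splits on the first separator only, so commas inside the
    text are preserved, and it always yields three columns (before, sep, after).
    """
    parts = df[column].astype(str).str.partition(SEPARATOR)
    out = df.copy()
    out["label"] = parts[0].str.strip()
    out["text"] = parts[2].str.strip()
    # Flag rows where the model did not emit the expected '<label>,' prefix
    out["has_separator"] = parts[1] == SEPARATOR
    return out


# Toy rows standing in for the synthetic_data DataFrame
toy_synthetic = pd.DataFrame({"label_and_text": [
    "card arrival,My card still hasn't shown up, what should I do?",
    "card arrival,How long does delivery take?",
    "no separator here",
]})

parsed = split_label_and_text(toy_synthetic)
# Keep only well-formed records for the requested label
kept = parsed[parsed["has_separator"] & (parsed["label"] == "card arrival")]
print(kept[["label", "text"]])
```

Applied to the notebook's output, this would be `split_label_and_text(synthetic_data)` followed by the same filter with `PROMPT_LABEL`.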