## Completing tabular datasets with Tabular LLM

*   This notebook demonstrates how to use Gretel Tabular LLM to fill in missing fields in table columns.
*  To run this notebook, you will need an API key from the Gretel Console.

## Getting Started

In [1]:
%%capture
# Install or upgrade the Gretel client (-U) with pip output silenced (-qq).
!pip install -Uqq gretel-client

## Load and preview the source data

In [None]:
# Load the sparse shoes table that will be completed later in the notebook
import pandas as pd

SPARSE_SHOES_CSV_URL = "https://gretel-datasets.s3.us-west-2.amazonaws.com/sparse/shoes-sparse-500.csv"

# Do not truncate wide text columns when DataFrames are rendered
pd.set_option("display.max_colwidth", None)

data_source = pd.read_csv(SPARSE_SHOES_CSV_URL)
data_source.head()

## Define helper functions

The helper functions below simplify interacting with Gretel Tabular LLM.

In [3]:
from dataclasses import dataclass

import yaml

from gretel_client.gretel.artifact_fetching import fetch_synthetic_data
from gretel_client.helpers import poll
from gretel_client.projects import Project, create_or_get_unique_project
from gretel_client.projects.models import Model

# YAML model configuration parsed by initialize_tabllm() when it has to create
# a new model: a single "tabllm" model entry that emits CSV output.
TABLLM_CONFIG = """
schema_version: 1.0
models:
- tabllm:
        model_id: "gretelai/tabular-v0"
        output_format: csv
"""


@dataclass
class TabLLM:
    """Pair a Gretel project with a Tabular LLM model and run generate jobs on it."""

    # Project that owns the model; used here to build console URLs.
    project: Project
    # Model whose record handlers execute the generation jobs.
    model: Model

    def submit_generate(
        self,
        prompt: str,
        seed_data: pd.DataFrame,
        temperature: float = 0.8,
        top_p: float = 1,
        top_k: int = 50,
        keep_remote_data: bool = False,
        verbose: bool = False,
    ) -> pd.DataFrame:
        """Run one generate job for ``prompt`` and return the generated table.

        Args:
            prompt: Instruction text sent to the model as a one-row ``prompt`` column.
            seed_data: Table (or Series, converted to a one-column frame) passed
                as reference data. One record is requested per seed row. May be
                ``None``, in which case no reference data and no record count
                are sent.
            temperature: Sampling temperature forwarded to the job.
            top_p: ``top_p`` sampling parameter forwarded to the job.
            top_k: ``top_k`` sampling parameter forwarded to the job.
            keep_remote_data: When false, the record handler is deleted after
                the results have been fetched.
            verbose: Forwarded to ``poll`` while waiting for the job.

        Returns:
            The data fetched from the finished record handler.
        """
        # Normalize before counting so the record count always refers to the
        # same frame that is uploaded as reference data.
        if isinstance(seed_data, pd.Series):
            seed_data = seed_data.to_frame()

        params = {"temperature": temperature, "top_p": top_p, "top_k": top_k}
        ref_data = None
        if seed_data is not None:
            # Previously len(seed_data) ran before this None check and raised
            # TypeError, which made the None branch unreachable.
            params["num_records"] = len(seed_data)
            ref_data = {"data": seed_data}
        # NOTE(review): without seed data no "num_records" is sent; confirm the
        # service default is acceptable for prompt-only generation.

        data_processor = self.model.create_record_handler_obj(
            data_source=pd.DataFrame({"prompt": [prompt]}), params=params, ref_data=ref_data
        )

        print("Submitting generate job...")
        print(f"Prompt: {prompt}")
        print(f"Model URL: {self.project.get_console_url()}/models/{self.model.model_id}/data")

        data_processor.submit()
        poll(data_processor, verbose=verbose)

        df_generated = fetch_synthetic_data(data_processor)

        # Clean up the remote record handler unless the caller wants to keep it.
        if not keep_remote_data:
            data_processor.delete()

        return df_generated

def initialize_tabllm(project_name: str) -> "TabLLM":
    """Return a ``TabLLM`` for ``project_name``, reusing an existing model if any.

    The project is created or fetched by name. If the project already contains
    a model, the first one returned by ``search_models()`` is reused; otherwise
    a new model is created from ``TABLLM_CONFIG``, submitted, and polled until
    ``poll`` returns.

    Args:
        project_name: Name passed to ``create_or_get_unique_project``.

    Returns:
        A ``TabLLM`` holding the project and the selected or newly created model.
    """
    project = create_or_get_unique_project(name=project_name)
    print(f"Project URL: {project.get_console_url()}")

    # Only the first search result is ever used, so stop after it instead of
    # materializing the full model listing.
    # NOTE(review): that model is reused as-is; nothing here checks that it is
    # a Tabular LLM model or that it finished successfully.
    model = next(iter(project.search_models()), None)

    if model is not None:
        print(f"Found existing model with id {model.model_id}")
        print(f"Model URL: {project.get_console_url()}/models/{model.model_id}/data")
    else:
        # Parse the config only when a model actually has to be created.
        model_config = yaml.safe_load(TABLLM_CONFIG)
        model = project.create_model_obj(model_config)
        model.submit()
        poll(model, verbose=False)

    return TabLLM(project=project, model=model)

## Configure the Tabular LLM session

In [None]:
# Configure session and initialize Tabular LLM

from gretel_client import configure_session

# api_key="prompt" is expected to ask for the Gretel API key interactively
# instead of hard-coding it here -- see the gretel_client docs for the exact
# meaning of the cache and validate options.
configure_session(api_key="prompt", cache="yes", validate=True)

# Creates or fetches the project, then reuses or creates its model.
tabllm = initialize_tabllm(project_name="gretel-demo-complete-partial-data")

## Create the prompt for Tabular LLM


In [5]:
# Each bullet requests a new column named "<Original>#". The "#" suffix is what
# the results cell later uses to pick out the completed columns.
prompt = """\
Add these columns
* Manufacturer# - If the Manufacturer is empty, replace with the Manufacturer based on the other columns
* Style# - If the Style is empty, replace with the Style based on the other columns
* Color# - If the Color is empty, replace with the Color based on the Name of this specific shoe
* Size# - If the Size is empty, replace with the integer value for a US shoe size based on the other columns. Default to 8 if unknown
* Description# - If the Description is empty, replace with a helpful Description of this specific shoe based on the other columns
* Gender# - If the Gender is empty, replace with the Gender based on the other columns
"""

## Complete the table using Tabular LLM

Running the cell below should take less than 10 minutes to complete.

In [None]:
%%time
# Ask Tabular LLM to complete the table, using the sparse table as seed data.
generate_kwargs = dict(prompt=prompt, seed_data=data_source, temperature=0.2, verbose=True)
data_completed = tabllm.submit_generate(**generate_kwargs)

## Inspect the results

In [None]:
# Keep the shoe name plus every generated "<Original>#" column.
complete_columns = [
    name for name in data_completed.columns if name == 'Name' or name.endswith('#')
]
data_completed = data_completed[complete_columns]

# Show the sparse input and the completed output one after the other.
for title, table in (("Original Table", data_source), ("Completed Table", data_completed)):
    print(title)
    display(table)
    print("----")