# Generating a Dataset and Writing to Databricks

## Generating Chat Log Data using Navigator

In [0]:
%pip install gretel-client
# The line above installs the Gretel Python client that the following cells import.
# Next, upgrade typing-extensions in the same environment.
# NOTE(review): presumably required for compatibility with gretel-client's dependencies — confirm.
%pip install --upgrade typing-extensions
# Restart the Python process so the packages installed above can be imported.
# NOTE(review): assumes a Databricks runtime, where `dbutils` is predefined.
dbutils.library.restartPython()

In [0]:
import logging
import yaml
from getpass import getpass

from gretel_client import Gretel, create_or_get_unique_project
from gretel_client.config import configure_session, get_session_config
from gretel_client.rest_v1.api.connections_api import ConnectionsApi
from gretel_client.rest_v1.api.logs_api import LogsApi
from gretel_client.rest_v1.api.workflows_api import WorkflowsApi
from gretel_client.rest_v1.models import (
    CreateConnectionRequest,
    CreateWorkflowRequest,
    CreateWorkflowRunRequest,
)
from gretel_client.workflows.logs import print_logs_for_workflow_run

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, to_timestamp
from pyspark.sql.types import (
    IntegerType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

# High-level Gretel client; used below to initialize the Navigator tabular API.
# NOTE(review): api_key="prompt" presumably makes the client ask for the key
# interactively instead of using the literal string — confirm against gretel-client docs.
gretel = Gretel(api_key="prompt")

In [0]:
# How many synthetic chat-log rows to request from Navigator.
NUM_CHAT_LOG_RECORDS = 150

# Tabular Navigator interface. `backend_model` is optional and defaults to
# "gretelai/auto"; it is passed explicitly here for visibility.
tabular = gretel.factories.initialize_navigator_api(
    "tabular",
    backend_model="gretelai/auto",
)

# Natural-language description of the table to generate: one line per column.
prompt = """\
Generate customer support chatbot data:
customer_id: A unique numeric identifier for each customer interaction (e.g, 1234).
timestamp: The date and time of the customer interaction (e.g. 2024-03-15 10:01:17).
interaction_type: Type of interaction (e.g., chat, email, phone).
customer_query: The text of the customer’s question or request.
intent: The identified intent of the customer query (e.g., account balance inquiry, transaction dispute).
response: The response provided to the customer.
resolution_status: Whether the query was resolved successfully.
agent_notes: Additional notes from the support agent.
sentiment: The sentiment score of the interaction (e.g., positive, neutral, negative).

"""

# Turn the prompt into tabular data; the result is consumed by the Spark cell below.
df = tabular.generate(prompt, num_records=NUM_CHAT_LOG_RECORDS)

In [0]:
# Reuse the active Spark session, or create one when none exists.
spark = SparkSession.builder.getOrCreate()

# Declared layout of the chat-log table. It has its own name so that it is no
# longer clobbered by the `schema` string read from the user further down, and
# it is the single source of truth for the type casts applied below.
chat_log_schema = StructType([
    StructField("customer_id", IntegerType(), True),
    StructField("timestamp", TimestampType(), True),
    StructField("interaction_type", StringType(), True),
    StructField("customer_query", StringType(), True),
    StructField("intent", StringType(), True),
    StructField("response", StringType(), True),
    StructField("resolution_status", StringType(), True),
    StructField("agent_notes", StringType(), True),
    StructField("sentiment", StringType(), True)
])

# Convert the generated DataFrame (`df`) into a Spark DataFrame.
spark_df = spark.createDataFrame(df)

# Cast the two non-string columns to their declared types
# (customer_id -> IntegerType, timestamp -> TimestampType).
for column_name in ("customer_id", "timestamp"):
    spark_df = spark_df.withColumn(
        column_name,
        col(column_name).cast(chat_log_schema[column_name].dataType),
    )

# Destination of the Delta table. Surrounding whitespace is dropped and empty
# answers are rejected early; otherwise the table name below would be malformed
# (e.g. "..chat_bot_logs") and fail later with a less helpful error.
catalog = input('Catalog to write to:').strip()
schema = input('Schema to write to:').strip()
if not catalog or not schema:
    raise ValueError("Both a catalog and a schema name are required.")

# Built once so the DROP and the write always target the same table.
# NOTE: the names are interpolated verbatim into the SQL statement below.
target_table = f"{catalog}.{schema}.chat_bot_logs"

# Remove any previous copy of the table before writing.
spark.sql(f"DROP TABLE IF EXISTS {target_table}")

# Save the DataFrame as a Delta table.
spark_df.write.format("delta").mode("overwrite").saveAsTable(target_table)

### Creating a Gretel NavFT Workflow and Writing Synthetic Data to Databricks

In [0]:
# Basic logging configuration, with this notebook's logger set to DEBUG.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Configure the Gretel session, then fetch the resulting session config.
setup = configure_session(api_key="prompt")
session = get_session_config()

# REST API handles used by the workflow cells: connections, workflows and logs.
connection_api, workflow_api, log_api = (
    session.get_v1_api(api_cls)
    for api_cls in (ConnectionsApi, WorkflowsApi, LogsApi)
)

# Project that the connections and the workflow below are attached to.
project = create_or_get_unique_project(name="databricks-demo-navft")

# Last expression of the cell, so the project's console URL is displayed.
project.get_console_url()

In [0]:
""" 
Creates source and destination connections for databricks
"""

def _databricks_connection_request(name, label, pat_prompt):
    """Build a Databricks connection request, prompting for every setting.

    Args:
        name: Name given to the connection.
        label: Prefix used in the interactive prompts ("Source"/"Destination").
        pat_prompt: Exact prompt text shown when asking for the access token.

    Returns:
        A CreateConnectionRequest for the current project.
    """
    return CreateConnectionRequest(
        name=name,
        project_id=project.project_guid,
        type="databricks",
        # Prompts are issued in this key order.
        config={
            setting: input(f"{label} Connection({setting}):")
            for setting in ("server_hostname", "http_path", "catalog", "schema")
        },
        # getpass keeps the token from being echoed.
        credentials={"personal_access_token": getpass(prompt=pat_prompt)},
    )


# Connection the workflow reads from.
source_conn = connection_api.create_connection(
    _databricks_connection_request(
        "databricks-source",
        "Source",
        'Source Connection(Personal Access Token (PAT)):',
    )
)

# Connection the workflow writes to.
dest_conn = connection_api.create_connection(
    _databricks_connection_request(
        "databricks-dest",
        "Destination",
        'Destination Connection(Personal Access Token (PAT)): ',
    )
)

In [0]:
""" 
Sample config for a Gretel Workflow that
1. Reads data from databricks
2. Generates synthetic data using our Navigator Fine Tuning (https://docs.gretel.ai/create-synthetic-data/models/synthetics/gretel-navigator-fine-tuning) model.
3. Writes generated synthetic data back to a Databricks Destination

Note: the volume name is set on the 'databricks-write' action (type: databricks_destination).
"""

# Ask for the destination volume before building the config. The answer is
# attached to the parsed config further down rather than interpolated into the
# YAML text, so a name containing quotes, backslashes or newlines cannot break
# or alter the document.
volume_name = input("Provide name for the volume: ")

# Three actions: read from the source connection, train/generate with
# navigator_ft, write the result to the destination connection.
# Inside this f-string, doubled braces ({{ ... }}) produce literal single
# braces, so the parsed config contains "{outputs.<action>.dataset}".
workflow_config = yaml.safe_load(f"""
name: databricks-navft-worflow
actions:
  - name: databricks-read
    type: databricks_source
    connection: {source_conn.id}
    config:
      sync:
        mode: full
  - name: model-train-run
    type: gretel_tabular
    input: databricks-read
    config:
      project_id: {project.project_guid}
      train:
        model_config:
          schema_version: "1.0"
          name: navigator_ft
          models:
            - navigator_ft:
                data_source: __tmp__
                group_training_examples_by: null
                order_training_examples_by: null
                generate:
                  num_records: 5000
                params:
                  num_input_records_to_sample: 25000
        dataset: "{{outputs.databricks-read.dataset}}"
  - name: databricks-write
    type: databricks_destination
    connection: {dest_conn.id}
    input: model-train-run
    config:
      sync:
        mode: replace
      dataset: "{{outputs.model-train-run.dataset}}"
"""
)

# Set the volume on the write action as a plain string value.
write_action = next(
    action
    for action in workflow_config["actions"]
    if action["name"] == "databricks-write"
)
write_action["config"]["volume"] = volume_name

# Creates a workflow with the config above
workflow = workflow_api.create_workflow(
    CreateWorkflowRequest(
        name="Databricks NavFT Demo",
        project_id=project.project_guid,
        config=workflow_config,
    )
)

In [0]:
# Start one run of the workflow referenced by `workflow`.
run_request = CreateWorkflowRunRequest(workflow_id=workflow.id)
workflow_run = workflow_api.create_workflow_run(run_request)

# Print the logs for that run using the configured session.
print_logs_for_workflow_run(workflow_run.id, session)