# Synthesizing dialogs for better conversational AI

* This notebook demonstrates how to use Gretel GPT to generate synthetic conversations.
* To run this notebook, you will need an API key from the [Gretel Console](https://console.gretel.ai/).

## Getting Started

In [None]:
%%capture
# Install the Gretel client, upgrading any existing copy (-U). The %%capture
# cell magic above keeps pip's output out of the notebook.
!pip install -U gretel-client

In [None]:
import json

import pandas as pd
import numpy as np

from datasets import load_dataset
from gretel_client import configure_session
from gretel_client.helpers import poll
from gretel_client.projects import create_or_get_unique_project, get_project

In [None]:
# Set up the Gretel client session used by every later cell.
# NOTE(review): api_key="prompt" presumably asks for the key interactively
# instead of hard-coding it here -- confirm against the gretel_client docs.

from gretel_client import configure_session

session_options = {
    "api_key": "prompt",
    "endpoint": "https://api.gretel.cloud",
    "validate": True,
    "clear": True,
}
configure_session(**session_options)

## Load and preview training data

In [None]:
# Pick the corpus to work with: DATASET must be one of the keys of
# `datasets` below (commonsense, counselchat or dailydialog).
DATASET = "dailydialog"
MAX_NUMBER_RECORDS = 1000
SPLIT = "train"

# Short corpus name -> identifier handed to load_dataset().
datasets = dict(
    commonsense="mvansegb123/commonsense-dialogues",
    counselchat="nbertagnolli/counsel-chat",
    dailydialog="daily_dialog",
)

dataset_id = datasets[DATASET]
dataset = load_dataset(dataset_id)

# Keep only the requested split, as a DataFrame.
df = pd.DataFrame(dataset[SPLIT])

# Show five random records as a sanity check.
display(df.sample(n=5))

## Define helper functions
These functions convert the structured source data into the free-text format used to train Gretel GPT, and convert generated text back into the original structure.

In [None]:
# convert from structure to paragraph
def process_text_data(df, dataset):
    """Flatten structured dialog records into one text paragraph per record.

    Args:
        df: DataFrame with the raw records of the selected corpus.
        dataset: Corpus name. "commonsense", "counselchat" and "dailydialog"
            are handled; any other value produces an empty result.

    Returns:
        A DataFrame with a single "text" column holding one paragraph per
        converted record.
    """
    paragraphs = []

    if dataset == "commonsense":
        for _, record in df.iterrows():
            pieces = [f"The context of the following conversation is {record['context']}:  "]
            # Turns alternate between the named speaker and "The friend".
            voices = (record['speaker'], "The friend")
            for pos, utterance in enumerate(record['turns']):
                verb = "says" if pos == 0 else "responds"
                pieces.append(f"{voices[pos % 2]} {verb} {utterance}  ")
            paragraphs.append("".join(pieces))

    if dataset == "counselchat":
        for _, record in df.iterrows():
            # Records missing either the question or the answer are dropped.
            if not (record["questionText"] and record["answerText"]):
                continue
            topic_part = f"Within the topic of {record['topic']}"
            title_part = f"the following patient-therapist question is in the theme of \"{record['questionTitle']}\""
            asked = record["questionText"].replace("\n", "")
            answered = record["answerText"].replace("\n", "")
            paragraphs.append(
                f"{topic_part}, {title_part}:  The patient asks \"{asked}\".  The therapist responds \"{answered}\"."
            )

    if dataset == "dailydialog":

        MIN_EMOTIONS_IN_DIALOG = 3
        ACTS = {0: "says", 1: "informs", 2: "questions", 3: "says in a directive tone", 4: "says in a commissive manner"}
        EMOTION = {0: None, 1: "anger", 2: "disgust", 3: "fear", 4: "happiness", 5: "sadness", 6: "surprise"}

        for _, record in df.iterrows():
            # Only keep dialogs with enough turns carrying a non-zero emotion code.
            if np.count_nonzero(record["emotion"]) < MIN_EMOTIONS_IN_DIALOG:
                continue
            sentences = []
            for pos, utterance in enumerate(record["dialog"]):
                act = ACTS[record["act"][pos]]
                emotion = EMOTION[record["emotion"][pos]]
                emotion_str = (
                    f"while revealing the emotion of {emotion}"
                    if emotion
                    else "while revealing no emotion"
                )
                # The two participants alternate: Person 1, Person 2, Person 1, ...
                sentences.append(f"Person {pos % 2 + 1} {act} \"{utterance}\" {emotion_str}.   ")
            paragraphs.append("".join(sentences))

    return pd.DataFrame(paragraphs, columns=['text'])

# convert from paragraph to structure
def _commonsense_to_struct(df):
    """Parse commonsense-style paragraphs into context / speaker / turns columns."""
    intro = "The context of the following conversation is "
    contexts, speakers, turns = [], [], []

    for text in df['text']:
        # Generated text can be empty (NaN after a CSV round-trip) or lack the
        # "context:" header; such records are skipped instead of crashing.
        if not isinstance(text, str):
            continue
        # Split on the FIRST colon only, so colons inside the turns are kept.
        context_text, separator, turns_text = text.partition(":")
        if not separator:
            continue

        sentences = [chunk.strip() for chunk in turns_text.split(".  ")]
        # The first word of the first turn is taken as the speaker's name.
        speaker = sentences[0].split(" ")[0]

        turn_list = []
        for sentence in sentences:
            sentence = (
                sentence.replace(f"{speaker} says ", "")
                .replace(f"{speaker} responds ", "")
                .replace("The friend responds ", "")
            )
            if sentence:
                turn_list.append(sentence)

        contexts.append(context_text.replace(intro, ""))
        speakers.append(speaker)
        turns.append(turn_list)

    return pd.DataFrame({"context": contexts, "speaker": speakers, "turns": turns})


def _counselchat_to_struct(df):
    """Parse counselchat-style paragraphs into title / question / topic / answer columns."""
    topics, titles, questions, answers = [], [], [], []

    for text in df['text']:
        if not isinstance(text, str):
            continue
        # Turn the fixed phrases between the fields into "|" separators.
        text = text.replace(".  The therapist responds with:", ".  The therapist responds with ")
        text = text.replace("Within the topic of ", "")
        text = text.replace(", the following patient-therapist question is in the theme of ", "|")
        text = text.replace(":  The patient asks ", "|")
        text = text.replace(".  The therapist responds", "|")
        text = text.replace("\"", "")

        # Keep only records in which all four fields were found.
        fields = text.split("|")
        if len(fields) != 4:
            continue
        question_topic, question_title, question_text, answer_text = fields
        topics.append(question_topic)
        titles.append(question_title)
        questions.append(question_text)
        answers.append(answer_text)

    return pd.DataFrame({
        "questionTitle": titles,
        "questionText": questions,
        "topic": topics,
        "answerText": answers,
    })


def _dailydialog_to_struct(df):
    """Parse dailydialog-style paragraphs into parallel dialog / act / emotion lists."""
    ACTS = {0: "says", 1: "informs", 2: "questions", 3: "says in a directive tone", 4: "says in a commissive manner"}
    EMOTION = {0: None, 1: "anger", 2: "disgust", 3: "fear", 4: "happiness", 5: "sadness", 6: "surprise"}

    # Exact-match reverse lookups. Substring matching would make "says" also
    # match "says in a directive tone", yielding two act codes for one sentence.
    act_codes = {label: code for code, label in ACTS.items()}
    emotion_codes = {
        (f"the emotion of {label}" if label else "no emotion"): code
        for code, label in EMOTION.items()
    }

    dialog, act, emotion = [], [], []

    for text in df['text']:
        if not isinstance(text, str):
            continue
        dialog_sentences, act_sentences, emotion_sentences = [], [], []

        for chunk in text.split(".  "):
            sentence = chunk.strip()
            if not sentence:
                continue
            # Expected shape: Person N <act> "<utterance>" while revealing <emotion>
            prefix, opened, remainder = sentence.partition(" \"")
            utterance, closed, suffix = remainder.rpartition("\" while revealing ")
            if not (opened and closed):
                continue  # malformed sentence

            act_label = prefix.replace("Person 1", "").replace("Person 2", "").strip()
            # The last sentence may still carry its closing period.
            emotion_label = suffix.strip().rstrip(".").strip()

            act_code = act_codes.get(act_label)
            emotion_code = emotion_codes.get(emotion_label)
            # Skip sentences with an unknown act or emotion so that the three
            # lists always stay the same length.
            if act_code is None or emotion_code is None:
                continue

            dialog_sentences.append(utterance.strip() + " ")
            act_sentences.append(act_code)
            emotion_sentences.append(emotion_code)

        dialog.append(dialog_sentences)
        act.append(act_sentences)
        emotion.append(emotion_sentences)

    return pd.DataFrame({"dialog": dialog, "act": act, "emotion": emotion})


def convert_to_struct(df, dataset):
    """Convert generated paragraphs back into the structure of the source corpus.

    The result is written to ``synth_<dataset>.json`` (commonsense,
    dailydialog) or ``synth_<dataset>.csv`` (counselchat) in the current
    working directory, and its first rows are displayed.

    Args:
        df: DataFrame with a "text" column of generated paragraphs.
        dataset: One of "commonsense", "counselchat" or "dailydialog".

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == "commonsense":
        synth_df = _commonsense_to_struct(df)
        # DataFrame.to_json spells this option force_ascii (not ensure_ascii).
        synth_df.T.to_json(f'synth_{dataset}.json', indent=4, force_ascii=False)
    elif dataset == "counselchat":
        synth_df = _counselchat_to_struct(df)
        synth_df.to_csv(f'synth_{dataset}.csv', index=None)
    elif dataset == "dailydialog":
        synth_df = _dailydialog_to_struct(df)
        synth_df.T.to_json(f'synth_{dataset}.json', indent=4)
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'commonsense', 'counselchat' or 'dailydialog'"
        )

    display(synth_df.head())

# Helper functions for the Gretel-GPT config
def calc_steps(num_rows, batch_size, minutes=60) -> int:
    """Estimate how many training steps fit into a time budget.

    The estimate uses a fixed throughput of 102 rows per minute.

    Args:
        num_rows: Number of training records. Kept for backward
            compatibility: the rows processed within the budget
            (epochs * num_rows) always equal rows_per_minute * minutes, so
            the result does not depend on it.
        batch_size: Number of rows consumed per training step.
        minutes: Training time budget in minutes.

    Returns:
        The number of steps, truncated to a whole number.
    """
    rows_per_minute = 102.0
    # Computed directly rather than via epochs = budget / num_rows, which
    # added float rounding error and divided by zero for empty data.
    trainable_rows = rows_per_minute * minutes
    return int(trainable_rows / batch_size)

def calc_text_length(df, max_tokens=2048) -> int:
    """Estimate the generation length needed to cover the longest text.

    Args:
        df: pandas Series of strings (its ``.str`` accessor is used).
        max_tokens: Upper bound for the returned value.

    Returns:
        The longest character count divided by three, rounded up to the next
        multiple of 100 and capped at ``max_tokens``, as a built-in ``int``.

    Raises:
        ValueError: If no string length can be determined (for example an
            empty series).
    """
    # Heuristic: the character count is divided by 3 to approximate tokens.
    chars_per_token = 3

    longest = df.str.len().max()
    if pd.isna(longest):
        raise ValueError("cannot estimate a text length from an empty series")

    estimated_tokens = int(longest / chars_per_token)
    rounded_up = int(np.ceil(estimated_tokens / 100) * 100)
    # Coerce to a built-in int so that a numpy scalar passed as max_tokens
    # does not leak into the model configuration.
    return int(min(rounded_up, max_tokens))

## Convert the source data 

In [None]:
# Flatten the structured records into one training paragraph per record.
data_source = process_text_data(df, dataset=DATASET)

# Conversion can drop records, so fewer than MAX_NUMBER_RECORDS rows may be
# left. Sampling more rows than exist (without replacement) raises a
# ValueError, so the request is capped at the number of available rows.
sample_size = min(MAX_NUMBER_RECORDS, len(data_source))
data_source = data_source.sample(n=sample_size, ignore_index=True)
display(data_source.head(n=5))

# Character-length statistics of the training paragraphs.
MAX_STRING_LENGTH = data_source['text'].str.len().max()
AVG_STRING_LENGTH = data_source['text'].str.len().mean()
print(f"Nb records in training data: {len(data_source)}")
print(f"Average string length: {AVG_STRING_LENGTH:.0f}")
print(f"Maximum string length: {MAX_STRING_LENGTH}")

## Configure the Gretel-GPT Model

In this example, we will fine-tune Gretel GPT to generate synthetic dialogs.

In [None]:
from gretel_client.projects.models import read_model_config

# Start from the "synthetics/natural-language" template and adapt its
# gpt_x section to this dataset.
config = read_model_config("synthetics/natural-language")
config['models'][0]['gpt_x']['pretrained_model'] = "gretelai/mpt-7b"
config['models'][0]['gpt_x']['steps'] = calc_steps(len(data_source), config['models'][0]['gpt_x']['batch_size'])
config['models'][0]['gpt_x']['generate'] = {
    'num_records': 3,
    'num_beams': 5,
    # Use calc_text_length's default cap. The second argument is the upper
    # bound on the result, so passing the maximum character count there
    # would replace that cap with an unrelated number.
    'maximum_text_length': calc_text_length(data_source["text"]),
    }
config

## Train the synthetic model

In [None]:
# Create project
# The project name is derived from the selected dataset, e.g. "project-dailydialog".
# NOTE(review): create_or_get_unique_project presumably reuses an existing
# project of that name -- confirm against the gretel_client documentation.
GRETEL_PROJECT = f'project-{DATASET}'
project = create_or_get_unique_project(name=GRETEL_PROJECT)

# Create and submit model
# The model is built from the configuration prepared above, with the converted
# paragraphs as training data, and then submitted to Gretel Cloud.
model = project.create_model_obj(model_config=config, data_source=data_source)
model.submit_cloud()

# NOTE(review): poll presumably blocks until the job finishes; verbose=False
# should keep the progress log quiet -- confirm.
poll(model, verbose=False)

## Generate synthetic conversations

In [None]:
# Generation settings passed to the trained model's record handler.
params = dict(
    maximum_text_length=calc_text_length(data_source["text"]),
    top_p=0.95,
    num_records=100,
    num_beams=5,
)

# Create the generation job, submit it to Gretel Cloud and poll it.
record_handler = model.create_record_handler_obj(params=params)
record_handler.submit_cloud()
poll(record_handler, verbose=False)

## Inspect the synthetic data results

In [None]:
# Read the gzip-compressed CSV of generated records from the job's "data" artifact.
artifact_link = record_handler.get_artifact_link("data")
gpt_output = pd.read_csv(artifact_link, compression='gzip')
display(gpt_output.head())

# Convert the generated paragraphs back into the structure of the source corpus.
convert_to_struct(gpt_output, dataset=DATASET)