In [1]:
# Load environment variables from a .env file via python-dotenv; the value
# shown as the cell output is what load_dotenv() returned.
from dotenv import load_dotenv

load_dotenv()

True

In [2]:
import os

# Store an expanded path: "~" is shell syntax that Python's open() does not
# expand, so a consumer reading PATH_TO_ENV verbatim would look for a literal
# "~" directory. expanduser() leaves already-absolute paths untouched, so any
# consumer that expands the value itself keeps working.
os.environ["PATH_TO_ENV"] = os.path.expanduser(
    "~/projects/chatsky-llm-autoconfig/.env"
)

In [3]:
from tqdm import tqdm
import json
import pickle
import pandas as pd

## 1. Error analysis

In [19]:
# Read the v3 augmented dataset from disk.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("../data/gen_dataset_augmented_0-402_v3", "rb") as pickle_fh:
    data = pickle.load(pickle_fh)
len(data)

402

In [20]:
# Scan every augmented dialogue and sort it into one of three buckets:
#   normals    - every turn carries the same number of text variations
#   exceptions - variation counts differ between turns (counts are recorded);
#                a dialogue with no turns also lands here (zero distinct counts)
#   errors     - the record could not be processed at all (exception is recorded)
all_lens = []
normals, exceptions, errors = [], [], []

for i, example in enumerate(tqdm(data)):
    for j, aug_dia in enumerate(example["augmented_dialogues"]):
        try:
            variants_per_turn = [turn["text"] for turn in aug_dia["messages"]]
            lens = list(map(len, variants_per_turn))
            is_uniform = len(set(lens)) == 1
            if is_uniform:
                normals.append((i, j))
            else:
                exceptions.append((i, j, lens))
            all_lens.append(lens)
        except Exception as err:
            errors.append((i, j, err))
len(errors), len(all_lens), len(normals), len(exceptions)

100%|██████████| 402/402 [00:00<00:00, 27423.11it/s]


(49, 4454, 4399, 55)

In [168]:
# Each entry is (graph index, dialogue index, per-turn variation counts) for a
# dialogue whose turns do not all have the same number of variations.
exceptions

[(37, 2, [3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 1, 3]),
 (37, 3, [3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3]),
 (53, 7, [4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]),
 (57, 0, [1, 3, 1, 3, 1, 3, 1, 1, 1]),
 (57, 6, [2, 3, 2, 3, 2, 3, 2]),
 (57, 7, [1, 3, 2, 3, 2, 3, 2]),
 (57, 8, [1, 3, 2, 3, 2, 3, 2]),
 (67, 4, [1, 3, 1, 3, 1, 3, 1, 3, 1]),
 (67, 15, [2, 3, 3, 3, 3]),
 (74, 0, [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3]),
 (74, 4, [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3]),
 (84, 5, [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3]),
 (85, 4, [3, 3, 3, 2, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3]),
 (90, 4, [1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1]),
 (90, 5, [1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1]),
 (100, 10, [3, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 3]),
 (114, 0, [1, 3, 1, 3, 1, 2, 1, 2, 1, 2, 1, 2, 1]),
 (114, 1, [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1]),
 (114, 4, [1, 2, 1, 2, 1, 1, 1, 1, 1]),
 (114, 5, [1, 3, 1, 2, 1, 1, 1, 1, 1]),
 (114, 6, [1, 3, 1, 2, 1, 1, 1, 1, 1]),
 (114, 7, [1, 2, 1, 2, 1, 2, 1, 2, 

In [21]:
# Distinct graph indices that own at least one dialogue with a length mismatch.
graphs = {graph_idx for graph_idx, _dialogue, _lens in exceptions}
print(len(graphs), sorted(graphs))

25 [37, 53, 57, 67, 74, 84, 85, 90, 100, 114, 207, 214, 220, 221, 222, 223, 224, 228, 230, 244, 255, 331, 343, 351, 400]


In [22]:
# Number of original dialogues for each affected graph, in ascending graph order.
dialogues_total = [
    (graph, len(data[graph]["dialogues"])) for graph in sorted(graphs)
]
tmp = pd.DataFrame(dialogues_total, columns=["graph", "dialogues_total"])
tmp

Unnamed: 0,graph,dialogues_total
0,37,8
1,53,38
2,57,12
3,67,17
4,74,10
5,84,9
6,85,9
7,90,9
8,100,48
9,114,11


In [23]:
# One row per problematic dialogue: graph index, dialogue index, per-turn counts.
df = pd.DataFrame(exceptions, columns=["graph", "dialogue", "length_error"])
# Per-graph count of problematic dialogues next to the graph's total dialogues.
# Merge on the "graph" key instead of joining on the positional index: the
# index-based join only lines rows up while both frames are sorted identically.
df.groupby("graph").count().reset_index().merge(tmp, on="graph", how="left")

Unnamed: 0,graph,dialogue,length_error,dialogues_total
0,37,2,2,8
1,53,1,1,38
2,57,4,4,12
3,67,2,2,17
4,74,2,2,10
5,84,1,1,9
6,85,1,1,9
7,90,2,2,9
8,100,1,1,48
9,114,6,6,11


In [24]:
# Rows of the error table that belong to graph 57.
df[df["graph"] == 57]

Unnamed: 0,graph,dialogue,length_error
3,57,0,"[1, 3, 1, 3, 1, 3, 1, 1, 1]"
4,57,6,"[2, 3, 2, 3, 2, 3, 2]"
5,57,7,"[1, 3, 2, 3, 2, 3, 2]"
6,57,8,"[1, 3, 2, 3, 2, 3, 2]"


In [None]:
# NOTE(review): the dataset loaded above holds 402 graphs, so index 23355 cannot
# come from it; the stored output looks like a leftover from another session — confirm or remove.
df.loc[df["graph"] == 23355, "length_error"].values

array([list([1, 2, 1, 2, 1])], dtype=object)

In [31]:
# Original (non-augmented) dialogue 2 of graph 37; its augmented version is one
# of the entries flagged in `exceptions` above.
data[37]["dialogues"][2]

{'id': 'Canceling a service due to a duplicate purchase_1_2',
 'messages': [{'participant': 'assistant',
   'text': 'Hello! How can I assist you today?'},
  {'participant': 'user', 'text': 'I want to cancel my service.'},
  {'participant': 'assistant',
   'text': "I'm sorry to hear you'd like to cancel a service. Could you please provide your account number?"},
  {'participant': 'user', 'text': '123456'},
  {'participant': 'assistant',
   'text': 'Thank you. I see that there are two purchases for the same service on your account. How would you like to proceed?'},
  {'participant': 'user', 'text': 'Please cancel my duplicate service.'},
  {'participant': 'assistant',
   'text': 'Your duplicate purchase has been canceled. Is there anything else I can help you with?'},
  {'participant': 'user',
   'text': 'Actually, I need help with something else.'},
  {'participant': 'assistant',
   'text': 'Sure, what else would you like help with?'},
  {'participant': 'user',
   'text': 'Actually, I n

In [165]:
# Augmented dialogue 1 of graph 400 (a graph listed among those with length
# errors): each turn's "text" is a list of variations rather than one string.
data[400]["augmented_dialogues"][1]

{'id': 'Providing emergency contact details._1_1',
 'messages': [{'participant': 'assistant',
   'text': ['Hi there! How can I help you today?',
    'Hello! What can I do for you today?',
    'Greetings! How may I assist you?']},
  {'participant': 'user',
   'text': ['I want to update my emergency contact information.',
    'I need to change my emergency contact details.',
    'I’d like to modify my emergency contact info.']},
  {'participant': 'assistant',
   'text': ['Of course, I can assist with that. Can you share the name of your emergency contact?',
    'Absolutely! Please tell me the name of your emergency contact.',
    'Sure thing! What is the name of your emergency contact?']},
  {'participant': 'user', 'text': ['John D.']},
  {'participant': 'assistant',
   'text': ['Awesome! Can you give me their phone number?',
    'Perfect. What’s their phone number?',
    'Great! Please provide their contact number.']},
  {'participant': 'user', 'text': ['555-1234']},
  {'participant': '

## 2. Re-augmentation

In [7]:
from augmentation_prompts import variations_augmentation_prompt_10
from dialogue_augmentation import augment_dialogue

In [4]:
# Start the re-augmentation from the v3 dataset on disk.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("../data/gen_dataset_augmented_0-402_v3", "rb") as pickle_fh:
    data = pickle.load(pickle_fh)
len(data)

402

In [5]:
# (graph index, dialogue index) pairs to augment again. Each pair (i, j) is
# later used as data[i]["dialogues"][j] (source dialogue) and
# data[i]["augmented_dialogues"][j] (slot that receives the new result).
examples_for_reaugmentation = [
    (57, 0),
    (57, 6),
    (57, 7),
    (57, 8),
    (67, 15),
    (74, 0),
    (74, 4),
    (84, 5),
    (85, 4),
    (100, 10),
    (114, 0),
    (207, 6),
    (214, 79),
    (214, 88),
    (214, 90),
    (214, 91),
    (221, 22),
    (221, 24),
    (221, 27),
    (221, 32),
    (221, 33),
    (221, 34),
    (221, 35),
    (228, 3),
    (230, 5),
    (255, 6),
    (331, 1),
    (351, 2),
]
len(examples_for_reaugmentation)

28

In [10]:
# Re-run the augmentation for the selected (graph, dialogue) pairs and save the
# updated dataset as v4.
# Failed calls are collected here instead of being written into the dataset:
# storing the exception object as "messages" would overwrite the previous
# augmentation and persist a non-dialogue value into the pickle.
failed_reaugmentations = []

for i, j in examples_for_reaugmentation:
    print(f"Augmenting example {i} dialogue {j}")
    topic = data[i]["topic"]
    orig_dialogue = data[i]["dialogues"][j]["messages"]

    try:
        aug_dialogue = augment_dialogue(
            orig_dialogue,
            topic,
            variations_augmentation_prompt_10,
            "gpt-4o-mini-2024-07-18",
        )
    except Exception as e:
        # Best effort: keep the existing augmented messages and move on.
        failed_reaugmentations.append((i, j, repr(e)))
        continue

    data[i]["augmented_dialogues"][j]["messages"] = aug_dialogue

if failed_reaugmentations:
    print(
        f"Re-augmentation failed for {len(failed_reaugmentations)} dialogue(s): "
        f"{failed_reaugmentations}"
    )

with open("../data/gen_dataset_augmented_0-402_v4", "wb") as file:
    pickle.dump(data, file)

Augmenting example 57 dialogue 0
Augmenting example 57 dialogue 6
Augmenting example 57 dialogue 7
Augmenting example 57 dialogue 8
Augmenting example 67 dialogue 15
Augmenting example 74 dialogue 0
Augmenting example 74 dialogue 4
Augmenting example 84 dialogue 5
Augmenting example 85 dialogue 4
Augmenting example 100 dialogue 10
Augmenting example 114 dialogue 0
Augmenting example 207 dialogue 6
Augmenting example 214 dialogue 79
Augmenting example 214 dialogue 88
Augmenting example 214 dialogue 90
Augmenting example 214 dialogue 91
Augmenting example 221 dialogue 22
Augmenting example 221 dialogue 24
Augmenting example 221 dialogue 27
Augmenting example 221 dialogue 32
Augmenting example 221 dialogue 33
Augmenting example 221 dialogue 34
Augmenting example 221 dialogue 35
Augmenting example 228 dialogue 3
Augmenting example 230 dialogue 5
Augmenting example 255 dialogue 6
Augmenting example 331 dialogue 1
Augmenting example 351 dialogue 2


In [None]:
# Read back the v4 dataset written by the re-augmentation step.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("../data/gen_dataset_augmented_0-402_v4", "rb") as pickle_fh:
    data = pickle.load(pickle_fh)
len(data)

402

In [7]:
# Same scan as in the error-analysis section, run on the reloaded dataset.
# Mismatching dialogues are recorded with their dialogue id instead of the
# per-turn counts. The id lookup stays inside the try block, so a record
# without an "id" is counted under `errors`.
all_lens = []
normals, exceptions, errors = [], [], []

for i, example in enumerate(tqdm(data)):
    for j, aug_dia in enumerate(example["augmented_dialogues"]):
        try:
            variants_per_turn = [turn["text"] for turn in aug_dia["messages"]]
            lens = list(map(len, variants_per_turn))
            is_uniform = len(set(lens)) == 1
            if is_uniform:
                normals.append((i, j))
            else:
                exceptions.append((i, j, aug_dia["id"]))
            all_lens.append(lens)
        except Exception as err:
            errors.append((i, j, err))
len(errors), len(all_lens), len(normals), len(exceptions)

100%|██████████| 402/402 [00:00<00:00, 20863.05it/s]


(49, 4454, 4421, 33)

In [8]:
# Print the re-augmented (graph, dialogue) pairs that still have a length mismatch.
# The lookup set is built once; rebuilding a list inside the loop repeated the
# same work for every pair and made each membership test a linear scan.
unstable_pairs = {(graph, dialogue) for graph, dialogue, _ in exceptions}

for i, j in examples_for_reaugmentation:
    if (i, j) in unstable_pairs:
        print(i, j)

57 6
57 7
57 8
74 0
221 33
221 35


In [9]:
# graphs with length errors
graphs = set([graph for graph, _, _ in exceptions])
print(len(graphs), sorted(graphs))

16 [37, 53, 57, 67, 74, 90, 114, 220, 221, 222, 223, 224, 228, 244, 343, 400]


In [10]:
# Number of original dialogues for each still-affected graph, in ascending graph order.
dialogues_total = [
    (graph, len(data[graph]["dialogues"])) for graph in sorted(graphs)
]
tmp = pd.DataFrame(dialogues_total, columns=["graph", "dialogues_total"])
tmp

Unnamed: 0,graph,dialogues_total
0,37,8
1,53,38
2,57,12
3,67,17
4,74,10
5,90,9
6,114,11
7,220,69
8,221,46
9,222,47


In [21]:
# One row per still-problematic dialogue: graph index, dialogue index, dialogue id.
df = pd.DataFrame(exceptions, columns=["graph", "dialogue", "uttr_list_error"])
# Per-graph count of problematic dialogues next to the graph's total dialogues.
# Merge on the "graph" key instead of joining on the positional index: the
# index-based join only lines rows up while both frames are sorted identically.
df.groupby("graph").count().reset_index().merge(tmp, on="graph", how="left")

Unnamed: 0,graph,dialogue,uttr_list_error,dialogues_total
0,37,2,2,8
1,53,1,1,38
2,57,3,3,12
3,67,1,1,17
4,74,1,1,10
5,90,2,2,9
6,114,5,5,11
7,220,1,1,69
8,221,7,7,46
9,222,1,1,47


In [22]:
# Rows of the error table that belong to graph 53.
df[df["graph"] == 53]

Unnamed: 0,graph,dialogue,uttr_list_error
2,53,7,Reporting an incorrect product allergen listin...


In [16]:
# Load the entries previously marked for removal.
# The encoding is given explicitly so the read does not depend on the platform
# locale, matching the UTF-8 used when this notebook writes JSON.
with open("../data/idx_and_damaged_topics.json", "r", encoding="utf-8") as file:
    data_to_remove = json.load(file)
len(data_to_remove)

25

In [17]:
# Graph indices taken from the first element of each removal entry.
graph_idx_to_remove = {graph_idx for graph_idx, _ in data_to_remove}
print(sorted(graph_idx_to_remove))

[2, 17, 20, 65, 75, 95, 117, 152, 189, 197, 256, 300, 314, 317, 319, 324, 326, 329, 342, 350, 352, 356, 359, 392, 398]


In [18]:
# Graph indices that still have at least one length-mismatch record.
graph_idx_exceptions = {graph_idx for graph_idx, _dialogue, _detail in exceptions}
print(sorted(graph_idx_exceptions))

[37, 53, 57, 67, 74, 90, 114, 220, 221, 222, 223, 224, 228, 244, 343, 400]


In [19]:
# Graphs that are both marked for removal and still unstable (empty set = no overlap).
graph_idx_to_remove.intersection(graph_idx_exceptions)

set()

In [82]:
# Persist the current `exceptions` records as JSON (tuples are written as arrays).
unstable_path = "../data/unstable_augmentation_examples.json"
with open(unstable_path, "w", encoding="utf-8") as out_fh:
    json.dump(exceptions, out_fh, indent=4)