In [1]:
from dotenv import load_dotenv

# Read a .env file into the process environment (python-dotenv).
# Presumably this is where OPENAI_API_KEY / OPENAI_BASE_URL, read later via
# os.getenv when the chat model is built, come from — TODO confirm.
load_dotenv()

True

In [2]:
import os

# Point PATH_TO_ENV at the project's .env file. os.environ stores the string
# verbatim, so the leading "~" is NOT expanded to the home directory here.
# NOTE(review): presumably read by the dialog2graph package imported in a later
# cell, which would have to expand "~" itself — confirm. The path is also
# specific to one machine's checkout location.
os.environ["PATH_TO_ENV"] = "~/projects/chatsky-llm-autoconfig/.env"
# Last expression of the cell, so its value is displayed: a quick check that
# the EMBEDDER_MODEL setting is visible in the environment.
os.getenv("EMBEDDER_MODEL")

'BAAI/bge-m3'

In [3]:
from dialog2graph.pipelines.core.dialog import DialogMessage
from pydantic import BaseModel
from langchain_community.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from typing import List
from tqdm import tqdm
import pickle
from augmentation_prompts import naive_augmentation_prompt



In [12]:
from datasets import load_dataset

# Load the source dataset from the Hugging Face Hub. Later cells iterate
# dataset["train"] and read the "topic" and "dialogs" fields of each row.
# token=True asks `datasets` to authenticate with the locally stored Hub token
# (e.g. from `huggingface-cli login`) — presumably because the repo is
# private/gated; confirm.
dataset = load_dataset("DeepPavlov/d2g_generated", token=True)

## Naive Dialog Augmentation

In [5]:
# Schema handed to the JSON output parser: an object whose "result" field is a
# list of dialog messages. Documented with comments rather than a class
# docstring so the pydantic model's generated JSON schema is left untouched.
class DialogSequence(BaseModel):
    result: List[DialogMessage]


# The three stages of the naive augmentation pipeline.
augmentation_prompt = PromptTemplate.from_template(naive_augmentation_prompt)

parser = JsonOutputParser(pydantic_object=DialogSequence)

# Credentials and endpoint are taken from the environment, never hard-coded.
model = ChatOpenAI(
    temperature=0.7,
    model="gpt-4o-mini",
    base_url=os.environ.get("OPENAI_BASE_URL"),
    api_key=os.environ.get("OPENAI_API_KEY"),
)

# prompt -> chat model -> JSON parser; invoked with {"topic": ..., "dialog": ...}.
chain = augmentation_prompt | model | parser

In [None]:
# Naive augmentation run.
#
# For every example in dataset["train"], each dialog in example["dialogs"] is
# sent through `chain` together with the example's topic, and the results are
# collected under example["augmented_dialogs"] as {"id", "messages"} records.
# The list of finished examples is checkpointed to disk after every example.
#
# The run is long (the recorded output shows roughly a minute per example), so:
#   * an existing checkpoint is loaded and already-finished examples are
#     skipped, instead of restarting from zero and overwriting the progress;
#     delete the checkpoint file to force a fresh run;
#   * the checkpoint is written to a temporary file and then moved into place,
#     so an interrupt in the middle of pickle.dump cannot truncate the only copy.
NAIVE_CHECKPOINT_PATH = "../data/gen_dataset_augment_naive"

if os.path.exists(NAIVE_CHECKPOINT_PATH):
    # NOTE: pickle.load can execute arbitrary code. This is acceptable only
    # because the file is produced by this very cell; never point it at an
    # untrusted file.
    with open(NAIVE_CHECKPOINT_PATH, "rb") as fp:
        new_data = pickle.load(fp)
    print(f"Resuming from checkpoint: {len(new_data)} examples already augmented")
else:
    new_data = []

for i, example in enumerate(dataset["train"]):
    # Examples are appended in dataset order starting at index 0, so the
    # checkpoint length is the index of the first example still to be done.
    if i < len(new_data):
        continue

    print(f"Augmenting example {i}:")
    topic = example["topic"]
    all_dialogs = example["dialogs"]

    augmented_dialogs = []
    for element in tqdm(all_dialogs):
        orig_dialog = element["messages"]
        aug_dialog = chain.invoke({"topic": topic, "dialog": orig_dialog})
        augmented_dialogs.append({"id": element["id"], "messages": aug_dialog})

    example["augmented_dialogs"] = augmented_dialogs
    new_data.append(example)

    # Write-then-rename: the previous checkpoint stays intact until the new
    # one has been written completely.
    tmp_checkpoint_path = NAIVE_CHECKPOINT_PATH + ".tmp"
    with open(tmp_checkpoint_path, "wb") as fp:
        pickle.dump(new_data, fp)
    os.replace(tmp_checkpoint_path, NAIVE_CHECKPOINT_PATH)

Augmenting example 0:


100%|██████████| 6/6 [00:33<00:00,  5.61s/it]


Augmenting example 1:


100%|██████████| 12/12 [01:01<00:00,  5.14s/it]


Augmenting example 2:


100%|██████████| 12/12 [00:49<00:00,  4.15s/it]


Augmenting example 3:


100%|██████████| 9/9 [00:42<00:00,  4.69s/it]


Augmenting example 4:


100%|██████████| 13/13 [01:02<00:00,  4.77s/it]


Augmenting example 5:


100%|██████████| 12/12 [00:57<00:00,  4.76s/it]


Augmenting example 6:


100%|██████████| 5/5 [00:32<00:00,  6.43s/it]


Augmenting example 7:


100%|██████████| 17/17 [01:16<00:00,  4.48s/it]


Augmenting example 8:


100%|██████████| 10/10 [00:56<00:00,  5.67s/it]


Augmenting example 9:


100%|██████████| 6/6 [00:46<00:00,  7.68s/it]


Augmenting example 10:


100%|██████████| 16/16 [01:28<00:00,  5.51s/it]


Augmenting example 11:


100%|██████████| 5/5 [00:30<00:00,  6.16s/it]


Augmenting example 12:


100%|██████████| 5/5 [00:27<00:00,  5.46s/it]


Augmenting example 13:


100%|██████████| 4/4 [00:38<00:00,  9.73s/it]


Augmenting example 14:


100%|██████████| 10/10 [01:08<00:00,  6.83s/it]


Augmenting example 15:


100%|██████████| 4/4 [00:27<00:00,  6.83s/it]


Augmenting example 16:


100%|██████████| 10/10 [01:00<00:00,  6.02s/it]


Augmenting example 17:


100%|██████████| 20/20 [02:03<00:00,  6.16s/it]


Augmenting example 18:


100%|██████████| 19/19 [01:45<00:00,  5.55s/it]


Augmenting example 19:


100%|██████████| 7/7 [01:10<00:00, 10.04s/it]


Augmenting example 20:


100%|██████████| 9/9 [00:36<00:00,  4.02s/it]


Augmenting example 21:


100%|██████████| 31/31 [01:51<00:00,  3.61s/it]


Augmenting example 22:


100%|██████████| 17/17 [01:26<00:00,  5.08s/it]


Augmenting example 23:


100%|██████████| 18/18 [01:09<00:00,  3.85s/it]


Augmenting example 24:


100%|██████████| 5/5 [00:27<00:00,  5.40s/it]


Augmenting example 25:


100%|██████████| 8/8 [00:38<00:00,  4.86s/it]


Augmenting example 26:


100%|██████████| 4/4 [00:44<00:00, 11.17s/it]


Augmenting example 27:


100%|██████████| 19/19 [01:45<00:00,  5.54s/it]


Augmenting example 28:


100%|██████████| 9/9 [00:58<00:00,  6.48s/it]


Augmenting example 29:


100%|██████████| 9/9 [00:51<00:00,  5.67s/it]


Augmenting example 30:


100%|██████████| 7/7 [00:41<00:00,  5.98s/it]


Augmenting example 31:


100%|██████████| 9/9 [00:59<00:00,  6.63s/it]


Augmenting example 32:


100%|██████████| 8/8 [01:04<00:00,  8.10s/it]


Augmenting example 33:


100%|██████████| 9/9 [00:56<00:00,  6.30s/it]


Augmenting example 34:


100%|██████████| 4/4 [00:35<00:00,  8.81s/it]


Augmenting example 35:


100%|██████████| 9/9 [01:04<00:00,  7.16s/it]


Augmenting example 36:


100%|██████████| 13/13 [01:17<00:00,  5.95s/it]


Augmenting example 37:


100%|██████████| 8/8 [01:02<00:00,  7.80s/it]


Augmenting example 38:


100%|██████████| 10/10 [01:04<00:00,  6.44s/it]


Augmenting example 39:


100%|██████████| 14/14 [01:28<00:00,  6.35s/it]


Augmenting example 40:


100%|██████████| 19/19 [01:38<00:00,  5.16s/it]


Augmenting example 41:


100%|██████████| 7/7 [00:53<00:00,  7.59s/it]


Augmenting example 42:


100%|██████████| 8/8 [00:54<00:00,  6.81s/it]


Augmenting example 43:


100%|██████████| 5/5 [00:55<00:00, 11.03s/it]


Augmenting example 44:


100%|██████████| 12/12 [01:36<00:00,  8.04s/it]


Augmenting example 45:


100%|██████████| 13/13 [01:12<00:00,  5.60s/it]


Augmenting example 46:


100%|██████████| 7/7 [00:54<00:00,  7.72s/it]


Augmenting example 47:


100%|██████████| 14/14 [01:05<00:00,  4.71s/it]


Augmenting example 48:


100%|██████████| 10/10 [00:37<00:00,  3.73s/it]


Augmenting example 49:


100%|██████████| 12/12 [01:14<00:00,  6.18s/it]


Augmenting example 50:


100%|██████████| 11/11 [00:45<00:00,  4.14s/it]


Augmenting example 51:


100%|██████████| 18/18 [01:49<00:00,  6.11s/it]


Augmenting example 52:


100%|██████████| 5/5 [00:28<00:00,  5.63s/it]


Augmenting example 53:


100%|██████████| 38/38 [03:18<00:00,  5.22s/it]


Augmenting example 54:


100%|██████████| 10/10 [01:10<00:00,  7.01s/it]


Augmenting example 55:


100%|██████████| 11/11 [01:10<00:00,  6.41s/it]


Augmenting example 56:


100%|██████████| 14/14 [01:22<00:00,  5.90s/it]


Augmenting example 57:


100%|██████████| 12/12 [00:45<00:00,  3.77s/it]


Augmenting example 58:


100%|██████████| 10/10 [00:43<00:00,  4.31s/it]


Augmenting example 59:


100%|██████████| 11/11 [02:22<00:00, 12.92s/it]


Augmenting example 60:


100%|██████████| 7/7 [00:29<00:00,  4.15s/it]


Augmenting example 61:


100%|██████████| 7/7 [00:26<00:00,  3.80s/it]


Augmenting example 62:


100%|██████████| 22/22 [01:59<00:00,  5.45s/it]


Augmenting example 63:


100%|██████████| 7/7 [00:44<00:00,  6.31s/it]


Augmenting example 64:


100%|██████████| 7/7 [00:43<00:00,  6.24s/it]


Augmenting example 65:


100%|██████████| 6/6 [00:43<00:00,  7.25s/it]


Augmenting example 66:


100%|██████████| 4/4 [00:35<00:00,  8.93s/it]


Augmenting example 67:


100%|██████████| 17/17 [01:26<00:00,  5.09s/it]


Augmenting example 68:


100%|██████████| 6/6 [00:37<00:00,  6.25s/it]


Augmenting example 69:


100%|██████████| 7/7 [00:32<00:00,  4.64s/it]


Augmenting example 70:


100%|██████████| 10/10 [00:56<00:00,  5.65s/it]


Augmenting example 71:


100%|██████████| 6/6 [00:39<00:00,  6.58s/it]


Augmenting example 72:


100%|██████████| 9/9 [00:37<00:00,  4.22s/it]


Augmenting example 73:


100%|██████████| 12/12 [01:06<00:00,  5.51s/it]


Augmenting example 74:


100%|██████████| 10/10 [00:49<00:00,  4.95s/it]


Augmenting example 75:


100%|██████████| 17/17 [01:15<00:00,  4.41s/it]


Augmenting example 76:


100%|██████████| 7/7 [00:45<00:00,  6.48s/it]


Augmenting example 77:


100%|██████████| 15/15 [01:07<00:00,  4.52s/it]


Augmenting example 78:


100%|██████████| 14/14 [01:32<00:00,  6.59s/it]


Augmenting example 79:


100%|██████████| 8/8 [00:44<00:00,  5.54s/it]


Augmenting example 80:


100%|██████████| 12/12 [00:59<00:00,  4.98s/it]


Augmenting example 81:


100%|██████████| 5/5 [00:22<00:00,  4.47s/it]


Augmenting example 82:


100%|██████████| 3/3 [00:20<00:00,  6.78s/it]


Augmenting example 83:


100%|██████████| 11/11 [01:28<00:00,  8.02s/it]


Augmenting example 84:


100%|██████████| 9/9 [00:54<00:00,  6.11s/it]


Augmenting example 85:


100%|██████████| 9/9 [00:42<00:00,  4.69s/it]


Augmenting example 86:


100%|██████████| 8/8 [00:44<00:00,  5.53s/it]


Augmenting example 87:


100%|██████████| 6/6 [00:47<00:00,  7.84s/it]


Augmenting example 88:


100%|██████████| 10/10 [00:48<00:00,  4.84s/it]


Augmenting example 89:


100%|██████████| 12/12 [01:10<00:00,  5.90s/it]


Augmenting example 90:


100%|██████████| 9/9 [00:52<00:00,  5.81s/it]


Augmenting example 91:


100%|██████████| 6/6 [00:41<00:00,  6.86s/it]


Augmenting example 92:


100%|██████████| 7/7 [01:05<00:00,  9.41s/it]


Augmenting example 93:


100%|██████████| 9/9 [00:51<00:00,  5.72s/it]


Augmenting example 94:


100%|██████████| 7/7 [00:45<00:00,  6.43s/it]


Augmenting example 95:


100%|██████████| 7/7 [00:43<00:00,  6.23s/it]


Augmenting example 96:


100%|██████████| 15/15 [01:24<00:00,  5.64s/it]


Augmenting example 97:


100%|██████████| 5/5 [00:34<00:00,  6.93s/it]


Augmenting example 98:


100%|██████████| 9/9 [01:36<00:00, 10.72s/it]


Augmenting example 99:


100%|██████████| 13/13 [01:27<00:00,  6.74s/it]


Augmenting example 100:


100%|██████████| 48/48 [04:31<00:00,  5.65s/it]


Augmenting example 101:


100%|██████████| 9/9 [00:52<00:00,  5.86s/it]


Augmenting example 102:


100%|██████████| 8/8 [00:37<00:00,  4.73s/it]


Augmenting example 103:


100%|██████████| 6/6 [00:32<00:00,  5.42s/it]


Augmenting example 104:


100%|██████████| 4/4 [00:38<00:00,  9.54s/it]


Augmenting example 105:


100%|██████████| 11/11 [01:24<00:00,  7.64s/it]


Augmenting example 106:


100%|██████████| 5/5 [00:25<00:00,  5.08s/it]


Augmenting example 107:


100%|██████████| 5/5 [00:37<00:00,  7.59s/it]


Augmenting example 108:


100%|██████████| 5/5 [00:24<00:00,  4.97s/it]


Augmenting example 109:


100%|██████████| 10/10 [01:29<00:00,  8.99s/it]


Augmenting example 110:


100%|██████████| 8/8 [00:44<00:00,  5.53s/it]


Augmenting example 111:


100%|██████████| 8/8 [01:23<00:00, 10.48s/it]


Augmenting example 112:


100%|██████████| 8/8 [01:17<00:00,  9.66s/it]


Augmenting example 113:


100%|██████████| 8/8 [00:40<00:00,  5.11s/it]


Augmenting example 114:


100%|██████████| 11/11 [01:01<00:00,  5.60s/it]


Augmenting example 115:


100%|██████████| 13/13 [01:12<00:00,  5.56s/it]


Augmenting example 116:


100%|██████████| 11/11 [01:23<00:00,  7.61s/it]


Augmenting example 117:


100%|██████████| 13/13 [01:33<00:00,  7.20s/it]


Augmenting example 118:


100%|██████████| 10/10 [00:31<00:00,  3.13s/it]


Augmenting example 119:


100%|██████████| 9/9 [00:40<00:00,  4.54s/it]


Augmenting example 120:


100%|██████████| 7/7 [00:48<00:00,  6.95s/it]


Augmenting example 121:


100%|██████████| 17/17 [01:25<00:00,  5.00s/it]


Augmenting example 122:


100%|██████████| 11/11 [00:48<00:00,  4.44s/it]


Augmenting example 123:


100%|██████████| 8/8 [00:42<00:00,  5.37s/it]


Augmenting example 124:


100%|██████████| 26/26 [01:53<00:00,  4.37s/it]


Augmenting example 125:


100%|██████████| 13/13 [00:56<00:00,  4.32s/it]


Augmenting example 126:


100%|██████████| 5/5 [00:29<00:00,  5.81s/it]


Augmenting example 127:


100%|██████████| 8/8 [00:52<00:00,  6.54s/it]


Augmenting example 128:


100%|██████████| 9/9 [00:45<00:00,  5.08s/it]


Augmenting example 129:


100%|██████████| 8/8 [00:54<00:00,  6.85s/it]


Augmenting example 130:


100%|██████████| 16/16 [01:09<00:00,  4.33s/it]


Augmenting example 131:


100%|██████████| 14/14 [01:39<00:00,  7.11s/it]


Augmenting example 132:


100%|██████████| 7/7 [00:49<00:00,  7.03s/it]


Augmenting example 133:


100%|██████████| 9/9 [00:27<00:00,  3.07s/it]


Augmenting example 134:


100%|██████████| 7/7 [00:53<00:00,  7.68s/it]


Augmenting example 135:


100%|██████████| 10/10 [00:51<00:00,  5.19s/it]


Augmenting example 136:


100%|██████████| 8/8 [00:37<00:00,  4.73s/it]


Augmenting example 137:


100%|██████████| 10/10 [01:07<00:00,  6.75s/it]


Augmenting example 138:


100%|██████████| 9/9 [00:43<00:00,  4.78s/it]


Augmenting example 139:


100%|██████████| 10/10 [01:01<00:00,  6.20s/it]


Augmenting example 140:


100%|██████████| 6/6 [00:43<00:00,  7.26s/it]


Augmenting example 141:


100%|██████████| 17/17 [01:52<00:00,  6.64s/it]


Augmenting example 142:


100%|██████████| 9/9 [00:54<00:00,  6.01s/it]


Augmenting example 143:


100%|██████████| 6/6 [00:43<00:00,  7.24s/it]


Augmenting example 144:


100%|██████████| 8/8 [00:40<00:00,  5.11s/it]


Augmenting example 145:


100%|██████████| 6/6 [00:45<00:00,  7.56s/it]


Augmenting example 146:


100%|██████████| 9/9 [00:44<00:00,  4.91s/it]


Augmenting example 147:


100%|██████████| 8/8 [00:56<00:00,  7.06s/it]


Augmenting example 148:


100%|██████████| 11/11 [00:46<00:00,  4.25s/it]


Augmenting example 149:


100%|██████████| 8/8 [00:37<00:00,  4.74s/it]


Augmenting example 150:


100%|██████████| 9/9 [00:50<00:00,  5.61s/it]


Augmenting example 151:


100%|██████████| 9/9 [00:58<00:00,  6.46s/it]


Augmenting example 152:


100%|██████████| 6/6 [00:39<00:00,  6.66s/it]


Augmenting example 153:


100%|██████████| 10/10 [01:10<00:00,  7.07s/it]


Augmenting example 154:


100%|██████████| 9/9 [00:40<00:00,  4.50s/it]


Augmenting example 155:


100%|██████████| 8/8 [00:32<00:00,  4.05s/it]


Augmenting example 156:


100%|██████████| 7/7 [00:57<00:00,  8.21s/it]


Augmenting example 157:


100%|██████████| 8/8 [00:54<00:00,  6.77s/it]


Augmenting example 158:


100%|██████████| 12/12 [00:59<00:00,  4.93s/it]


Augmenting example 159:


100%|██████████| 11/11 [00:47<00:00,  4.33s/it]


Augmenting example 160:


  8%|▊         | 1/12 [00:18<03:24, 18.57s/it]


KeyboardInterrupt: 

## One-shot Dialog Augmentation

In [5]:
from augmentation_prompts import one_shot_augmentation_prompt

In [26]:
# Schema handed to the JSON output parser: an object whose "result" field is a
# list of dialog messages. Documented with comments rather than a class
# docstring so the pydantic model's generated JSON schema is left untouched.
class DialogSequence(BaseModel):
    result: List[DialogMessage]


# Same pipeline shape as the naive section, but driven by the one-shot prompt.
# Running this cell rebinds augmentation_prompt, model, parser and chain.
augmentation_prompt = PromptTemplate.from_template(one_shot_augmentation_prompt)

parser = JsonOutputParser(pydantic_object=DialogSequence)

# Credentials and endpoint are taken from the environment, never hard-coded.
model = ChatOpenAI(
    temperature=0.7,
    model="gpt-4o-mini",
    base_url=os.environ.get("OPENAI_BASE_URL"),
    api_key=os.environ.get("OPENAI_API_KEY"),
)

# prompt -> chat model -> JSON parser; invoked with {"topic": ..., "dialog": ...}.
chain = augmentation_prompt | model | parser

In [None]:
# One-shot augmentation trial run on a small slice of the training split.
#
# Each dialog of the selected examples is sent through `chain`; results are
# collected under example["augmented_dialogs"] and the list of finished
# examples is checkpointed to disk after every example.
#
# Failure handling: a failed chain call does not abort the run. The record for
# that dialog gets "messages": None plus an "error" string describing the
# exception. The exception object itself is deliberately NOT stored: the
# records are pickled below, and exception objects are not reliably picklable,
# which would make the checkpoint write fail.
ONE_SHOT_CHECKPOINT_PATH = "../data/gen_dataset_augment_one-shot_2"
FIRST_EXAMPLE = 2  # examples 0 and 1 are skipped
LAST_EXAMPLE = 3  # inclusive; the loop stops after this example

new_data = []

for i, example in enumerate(dataset["train"]):
    if i < FIRST_EXAMPLE:
        continue

    print(f"Augmenting example {i}:")
    topic = example["topic"]
    all_dialogs = example["dialogs"]

    augmented_dialogs = []
    for element in tqdm(all_dialogs):
        # Field access stays outside the try block so that a malformed dataset
        # row surfaces as an error instead of being logged as an LLM failure.
        orig_dialog = element["messages"]
        record = {"id": element["id"]}
        try:
            record["messages"] = chain.invoke({"topic": topic, "dialog": orig_dialog})
        except Exception as e:
            record["messages"] = None
            record["error"] = f"{type(e).__name__}: {e}"
        augmented_dialogs.append(record)

    example["augmented_dialogs"] = augmented_dialogs
    new_data.append(example)

    # Write-then-rename: the previous checkpoint stays intact until the new
    # one has been written completely.
    tmp_checkpoint_path = ONE_SHOT_CHECKPOINT_PATH + ".tmp"
    with open(tmp_checkpoint_path, "wb") as fp:
        pickle.dump(new_data, fp)
    os.replace(tmp_checkpoint_path, ONE_SHOT_CHECKPOINT_PATH)

    if i >= LAST_EXAMPLE:
        break

Augmenting example 2:


100%|██████████| 12/12 [00:57<00:00,  4.83s/it]


Augmenting example 3:


100%|██████████| 9/9 [00:33<00:00,  3.71s/it]
