In [1]:
import re
import os
import json
import pickle
import random
import requests
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from collections import OrderedDict
from bs4 import BeautifulSoup, element

from transformer.utils.tokenizer import MecabTokenizer, SpmTokenizer
from transformer.data.dataset import DatasetInterface, DatasetFromDir
# from transformer.preprocessors.bert_preprocessor import BertPreprocessor
# from transformer.preprocessors.blender_bot_preprocessor import GeneratorPretrainingPreprocessor
from transformer.preprocessors.utils import split_segment_by_speaker_ids, convert_turn_ids, flatten_sequence
from transformer.utils.common import get_nth_index, get_last_index, init_path

### Load Dataset

In [2]:
# # AIBUD_DEV
# dataset_dir = "/Users/aibud_dev/_jupyter"
# path = "./config/file_path.json"
# file_path = None
# with open(path, "r", encoding="utf-8") as fp:
#     file_path = json.load(fp)

# # AWS
# dataset_dir = "/home/ubuntu/data"
# path = "./config/file_path.json"
# file_path = None
# with open(path, "r", encoding="utf-8") as fp:
#     file_path = json.load(fp)

# # Korea_Server
# dataset_dir = "/home/mnt/guest1"
# path = "./config/file_path.json"
# file_path = None
# with open(path, "r", encoding="utf-8") as fp:
#     file_path = json.load(fp)

# bigshane_local
dataset_dir = "D:\_jupyter"
path = "./config/file_path.json"
file_path = None
with open(path, "r", encoding="utf-8") as fp:
    file_path = json.load(fp)

In [3]:
import re
import emoji
from soynlp.normalizer import repeat_normalize

emojis = ''.join(emoji.UNICODE_EMOJI.keys())
pattern = re.compile(f'[^ .,?!/@$%~％·∼()\x00-\x7Fㄱ-ㅣ가-힣{emojis}]+')
url_pattern = re.compile(
    r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)')

def clean(x):
    x = pattern.sub(' ', x)
    x = url_pattern.sub('', x)
    x = x.strip()
    x = repeat_normalize(x, num_repeats=2)
    return x

def trim_obj(obj):
    obj = list(set(obj))
    obj = [e for e in obj if e!=""]
    obj = sorted(obj)
    return obj
    
def show_five_nums(data):
    quartiles = np.percentile(data, [25, 50, 75])
    min_v = np.min(data)
    max_v = np.max(data)
    avg = np.mean(data)
    print("Min: {min_v:.3f}\tMax: {max_v:.3f}\tAvg: {avg:.3f}\tQ1: {q1:.3f}\tQ2: {q2:.3f}\tQ3: {q3:.3f}".format(min_v=min_v, max_v=max_v, avg=avg, q1=quartiles[0], q2=quartiles[1], q3=quartiles[2]))
    return min_v, max_v, quartiles[0], quartiles[1], quartiles[2], avg



## Wriet as Json format

In [4]:
# with open(root_dir + "/{language}/multi_turn/feed_data_0.json".format(language=language), "w", encoding=encoding) as fp:
#     json.dump(output, fp)

## Make dialog finetuning data

In [91]:
min_num_turns = 2
max_num_turns = 8
shuffle = True
train_ratio = 0.8
val_ratio = 0.15
test_ratio = 0.05
assert train_ratio+val_ratio+test_ratio == 1, "Sum must be equal to 1"


num_samples = 5000
append_condition = True

dataset_name_list = ["SelectStar", "EmpatheticDialogues", "AIHUB_Twitter", "AIHUB_EMOTION", "AIHUB_SSGI", "AIHUB_SNS", "KLUE_DST"]; _dataset_name = "eight"
# dataset_name_list = ["SelectStar", "EmpatheticDialogues", "AIHUB_Twitter", "AIHUB_EMOTION"]; _dataset_name = "four"
# dataset_name_list = ["SelectStar", "EmpatheticDialogues"]; _dataset_name = "selectstar"
# dataset_name_list = ["SelectStarPersona"]; _dataset_name = "persona"
_output_dir = dataset_dir + "/dataset/preprocessed/dialog_finetuning/kor/{_dataset_name}_n{min_v}x{max_v}".format(_dataset_name=_dataset_name, min_v=min_num_turns, max_v=max_num_turns)
_retriever_output_dir = dataset_dir + "/dataset/preprocessed/dialog_finetuning/retriever/{_dataset_name}_n{min_v}x{max_v}".format(_dataset_name=_dataset_name, min_v=min_num_turns, max_v=max_num_turns)
_generator_output_dir = dataset_dir + "/dataset/preprocessed/dialog_finetuning/generator/{_dataset_name}_n{min_v}x{max_v}".format(_dataset_name=_dataset_name, min_v=min_num_turns, max_v=max_num_turns)

In [92]:
language = "kor"
encoding = "UTF-8"
extension = "json"
print("# Load data from datasets...")
stat_dict = OrderedDict()
data = OrderedDict()

for dataset_name in dataset_name_list:
    root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
    if "multi_turn" not in file_path[dataset_name]["feed_data"]: continue

    data_path = file_path[dataset_name]["feed_data"]["multi_turn"].format(root_dir=root_dir, language=language)
    for file_name in os.listdir(data_path):
        if not file_name.endswith(extension): continue
        with open(data_path+file_name, "r", encoding=encoding) as fp:
            _data = json.load(fp)
            _data = [r for r in _data if len(r["speaker_ids"]) > 0]

            if dataset_name not in data: data[dataset_name] = []
            data[dataset_name] += _data

            _turn_size_list = []
            _utt_size_list = []
            for r in _data:
                _turn_size = 0
                last_speaker_id = prev_speaker_id = r["speaker_ids"][-1]
                for speaker_id in r["speaker_ids"][::-1]:
                    if speaker_id != last_speaker_id and speaker_id != prev_speaker_id: _turn_size += 1
                    prev_speaker_id = speaker_id
                _turn_size_list.append(_turn_size)
                _utt_size_list.append(len(r["speaker_ids"]))
            _avg_turn_size = np.mean(_turn_size_list)
            _avg_utt_size = np.mean(_utt_size_list)
            
            print("\t{}: {}\t\tavg(turn_size): {}\t\tavg(utt_size): {}".format(dataset_name, len(_data), _avg_turn_size, _avg_utt_size))
            stat_dict[dataset_name] = dict()
            stat_dict[dataset_name]["dataset_size"] = len(_data)
            stat_dict[dataset_name]["avg_turns"] = _avg_turn_size
            stat_dict[dataset_name]["avg_utterances"] = _avg_utt_size
print()

# Load data from datasets...
	SelectStar: 5692		avg(turn_size): 10.733661278988054		avg(utt_size): 33.66426563598032
	EmpatheticDialogues: 618		avg(turn_size): 5.699029126213592		avg(utt_size): 18.25242718446602
	AIHUB_Twitter: 1999		avg(turn_size): 3.1160580290145075		avg(utt_size): 6.789894947473737
	AIHUB_EMOTION: 84212		avg(turn_size): 2.6633021422125114		avg(utt_size): 5.326818030684463
	AIHUB_SSGI: 5855		avg(turn_size): 6.824935952177626		avg(utt_size): 15.44201537147737
	AIHUB_SNS: 36680		avg(turn_size): 2.788794983642312		avg(utt_size): 11.172001090512541
	KLUE_DST: 9000		avg(turn_size): 7.335111111111111		avg(utt_size): 14.670222222222222



In [93]:
retriever_data = dict()
generator_data = dict()
for dataset_name, _data in data.items():
    random.shuffle(_data)
    split_idx = len(_data)//2
    retriever_data[dataset_name] = _data[:split_idx]
    generator_data[dataset_name] = _data[split_idx:]

In [101]:
data = retriever_data
_output_dir = _retriever_output_dir

# data = generator_data
# _output_dir = _generator_output_dir

In [105]:
print("# Extract utterances from data...")
train_output = []
val_output = []
test_output = []

model_speaker_id = 1
user_speaker_id = 0

for dataset_name, _data in data.items():
    _output = []
    _train_output = []
    _val_output = []
    _test_output = []
    
    train_idx = int(len(_data) * train_ratio)
    val_idx = int(len(_data) * (train_ratio+val_ratio))
    
    empty_row_cnt = 0
    for row_idx, row in enumerate(tqdm(_data, initial=0, total=len(_data), desc=dataset_name)):
        if len([_utterance for _utterance in row["utterances"] if _utterance!=""]) < 1:
            empty_row_cnt += 1
            continue

        if "labels" not in row:
            labels = [""] * len(row["utterances"])
            row["labels"] = labels

        for begin_idx in range(0, max((len(row["speaker_ids"]) - min_num_turns), (len(row["speaker_ids"]) - max_num_turns))):
            for end_idx in range(begin_idx+1, len(row["speaker_ids"])):
                cur_num_turns = 0
                _prev_speaker_id = row["speaker_ids"][begin_idx]
                for _speaker_id in row["speaker_ids"][begin_idx+1:end_idx+1]:
                    if _speaker_id != _prev_speaker_id:
                        _prev_speaker_id = _speaker_id
                        if _speaker_id == user_speaker_id: cur_num_turns += 1

                if begin_idx <= 0:
                    if cur_num_turns < min_num_turns: continue # 대화의 첫 발화이면 min_num_turns턴 미만인 경우에 continue
                    if cur_num_turns > max_num_turns: continue # 대화의 첫 발화이면 max_num_turns턴 초과인 경우에 continue
                else:
                    # if cur_num_turns < min_num_turns: continue # 대화의 첫 발화가 아니면 min_num_turns턴 미만인 경우에 continue
                    # if cur_num_turns > max_num_turns: continue # 대화의 첫 발화가 아니면 max_num_turns턴 초과인 경우에 continue
                    if cur_num_turns < max_num_turns: continue # 대화의 첫 발화가 아니면 max_num_turns에 맞게 윈도윙
                if cur_num_turns > max_num_turns: break # max_num_turns보다 크면 break
                if row["speaker_ids"][end_idx] != model_speaker_id: continue # model_id로 발화가 끝나지 않는 경우 continue
                if len(row["speaker_ids"]) > end_idx+1 and row["speaker_ids"][end_idx] == row["speaker_ids"][end_idx+1]: continue # 다음 발화도 model_id의 것이면 continue

                _utterances = row["utterances"][begin_idx:end_idx+1]
                _speaker_ids = row["speaker_ids"][begin_idx:end_idx+1]
                _labels = row["labels"][begin_idx:end_idx+1]
                _persona = row["persona"] if "persona" in row else None
                if isinstance(_persona, dict): _persona = list(_persona.values())
                _entities = dict()
                if "entities" in row:
                    for k,v in row["entities"].items():
                        _entities[k] = []
                        for _v in v:
                            _entitiy_row_idx, _entitiy_begin_idx, _entitiy_end_idx, _entitiy_span = _v
                            entity_row_idx = _entitiy_row_idx - begin_idx
                            if entity_row_idx < 0: continue
                            if entity_row_idx > (end_idx - begin_idx): continue
                            _v_row = [entity_row_idx, _entitiy_begin_idx, _entitiy_end_idx, _entitiy_span]
                            _entities[k].append(_v_row)

                # utterance_length constraints: timestep보다 길면 탈락 (128 - num_special_tokens)
#                 concat_utterances = " ".join([str(_utterance) for _utterance in _utterances])
#                 if preprocessor.get_src_token_length(sentence=concat_utterances) >= 125: continue

                row_dict = row.copy()
                row_dict["utterances"] = _utterances
                row_dict["speaker_ids"] = _speaker_ids
                row_dict["labels"] = _labels
                row_dict["entities"] = _entities
                row_dict["persona"] = _persona
#                 _conditions = get_condition(utterances=_utterances, speaker_ids=_speaker_ids, personas=_persona, entities=_entities, user_speaker_id=user_speaker_id)
                _conditions = None
                row_dict["conditions"] = _conditions

#                 u_reply = True if "name" in row_dict["entities"] and len(row_dict["entities"]["name"]) > 0 and row_dict["entities"]["name"][-1][0] == len(row_dict["utterances"])-1 else False
                if row_idx < train_idx:
                    _train_output.append(row_dict)
                elif row_idx < val_idx:
                    _val_output.append(row_dict)
                else:
                    _test_output.append(row_dict)
    train_output += _train_output
    val_output += _val_output
    test_output += _test_output
    print("\t{} : (train: {}, val: {}, test: {})".format(dataset_name, len(_train_output), len(_val_output), len(_test_output)), "(empty_row_cnt: {})".format(empty_row_cnt))
    stat_dict[dataset_name]["dialog_size"] = (len(_train_output), len(_val_output), len(_test_output))

random.shuffle(train_output)
random.shuffle(val_output)
random.shuffle(test_output)
output = train_output + val_output + test_output
random.shuffle(output)

# Extract utterances from data...


SelectStar: 100%|█████████████████████████████████████████████████████████████████| 2846/2846 [00:05<00:00, 536.42it/s]


	SelectStar : (train: 29951, val: 5637, test: 1904) (empty_row_cnt: 0)


EmpatheticDialogues: 100%|█████████████████████████████████████████████████████████| 309/309 [00:00<00:00, 3153.14it/s]


	EmpatheticDialogues : (train: 876, val: 164, test: 56) (empty_row_cnt: 35)


AIHUB_Twitter: 100%|██████████████████████████████████████████████████████████████| 999/999 [00:00<00:00, 16933.23it/s]


	AIHUB_Twitter : (train: 1257, val: 348, test: 66) (empty_row_cnt: 0)


AIHUB_EMOTION: 100%|██████████████████████████████████████████████████████████| 42106/42106 [00:00<00:00, 50844.98it/s]


	AIHUB_EMOTION : (train: 22386, val: 4119, test: 1429) (empty_row_cnt: 0)


AIHUB_SSGI: 100%|██████████████████████████████████████████████████████████████████| 2927/2927 [00:31<00:00, 92.55it/s]


	AIHUB_SSGI : (train: 13519, val: 2664, test: 937) (empty_row_cnt: 0)


AIHUB_SNS: 100%|██████████████████████████████████████████████████████████████| 18340/18340 [00:01<00:00, 12698.63it/s]


	AIHUB_SNS : (train: 10041, val: 1870, test: 609) (empty_row_cnt: 0)


KLUE_DST: 100%|██████████████████████████████████████████████████████████████████| 4500/4500 [00:00<00:00, 4896.70it/s]


	KLUE_DST : (train: 21151, val: 4209, test: 1328) (empty_row_cnt: 0)


In [103]:
# model_speaker_id = 0 # user_speaker_id = 1
final_output = [] 
final_train_output = [] 
final_val_output = [] 
final_test_output = [] 
final_output += output
final_train_output += train_output
final_val_output += val_output
final_test_output += test_output
output_dir = _output_dir + "_one/"
print("{}: {}, {}, {}, {}".format(output_dir, len(output), len(train_output), len(val_output), len(test_output)))

D:\_jupyter/dataset/preprocessed/dialog_finetuning/retriever/eight_n2x8_one/: 132682, 105754, 20174, 6754


In [106]:
# model_speaker_id = 1 # user_speaker_id = 0
final_output += output
final_train_output += train_output
final_val_output += val_output
output = final_output
train_output = final_train_output
val_output = final_val_output
test_output = final_test_output
output_dir = _output_dir + "_both/"
print("{}: {}, {}, {}, {}".format(output_dir, len(output), len(train_output), len(val_output), len(test_output)))

D:\_jupyter/dataset/preprocessed/dialog_finetuning/retriever/eight_n2x8_both/: 257203, 204935, 39185, 6754


In [107]:
output_filename_template = "feed_data_{idx}.json"
size_per_file = 100000

print("# Write extracted dialogs to '{}'...".format(output_dir))
init_path(output_dir, reset=True)
if shuffle: random.shuffle(output)
for idx in range(0, len(output)//size_per_file+1):
    begin_idx = idx * size_per_file
    end_idx = min(len(output), (idx+1) * size_per_file)
    _output = output[begin_idx:end_idx]
    output_filename = output_filename_template.format(idx=str(idx).zfill(2))
    with open(output_dir+output_filename, "w", encoding="utf-8") as fp:
        print("\t{}: {}".format(output_filename, len(_output)))
        json.dump(_output, fp)

if not os.path.isdir(output_dir + "train"): os.mkdir(output_dir + "train")
if shuffle: random.shuffle(train_output)
for idx in range(0, len(train_output)//size_per_file+1):
    begin_idx = idx * size_per_file
    end_idx = min(len(train_output), (idx+1) * size_per_file)
    _output = train_output[begin_idx:end_idx]
    output_filename = output_filename_template.format(idx=str(idx).zfill(2))
    with open(output_dir+"train/"+output_filename, "w", encoding="utf-8") as fp:
        print("\ttrain {}: {}".format(output_filename, len(_output)))
        json.dump(_output, fp)
        
if not os.path.isdir(output_dir + "val"): os.mkdir(output_dir + "val")
if shuffle: random.shuffle(val_output)
for idx in range(0, len(val_output)//size_per_file+1):
    begin_idx = idx * size_per_file
    end_idx = min(len(val_output), (idx+1) * size_per_file)
    _output = val_output[begin_idx:end_idx]
    output_filename = output_filename_template.format(idx=str(idx).zfill(2))
    with open(output_dir+"val/"+output_filename, "w", encoding="utf-8") as fp:
        print("\tval {}: {}".format(output_filename, len(_output)))
        json.dump(_output, fp)
        
if not os.path.isdir(output_dir + "test"): os.mkdir(output_dir + "test")
if shuffle: random.shuffle(test_output)
for idx in range(0, len(test_output)//size_per_file+1):
    begin_idx = idx * size_per_file
    end_idx = min(len(test_output), (idx+1) * size_per_file)
    _output = test_output[begin_idx:end_idx]
    output_filename = output_filename_template.format(idx=str(idx).zfill(2))
    with open(output_dir+"test/"+output_filename, "w", encoding="utf-8") as fp:
        print("\ttest {}: {}".format(output_filename, len(_output)))
        json.dump(_output, fp)
        
# write sample data
if not os.path.isdir(output_dir + "sample"): os.mkdir(output_dir + "sample")
sample_data_path = output_dir + "sample/feed_data_sample.json"
with open(sample_data_path, "w", encoding="utf-8") as fp:
    json.dump(output[:num_samples], fp)
    
# write DESC.md
with open(output_dir + "DESC.md", "w", encoding="utf-8") as fp:
    _str = "## DESC\n" +\
        "- Korean Dialog dataset for Language Model training\n" +\
        "- total: {size}, avg_turns: {avg:.3f}\n".format(size=len(output), avg=np.mean([len(row["utterances"]) for row in output])) +\
        "- min turns: {mn_t}, max turns: {mx_t}\n".format(mn_t=min_num_turns, mx_t=max_num_turns) +\
        "\n" +\
        "## dataset statistics\n" +\
        "### dataset_name: (avg_turns, avg_utterances, num_dialogues, num_total_rows, num_train_rows, num_val_rows, num_test_rows)\n"
    for dataset_name, stat in stat_dict.items():
        _row_stat_template = "\t- {dataset_name}: ({avg_turns:.3f}, {avg_utterances:.3f}, {num_dialogues}, {num_total_rows}, {num_train_rows}, {num_val_rows}, {num_test_rows})\n"
        _row_stat = _row_stat_template.format(dataset_name=dataset_name, avg_turns=stat["avg_turns"], avg_utterances=stat["avg_utterances"], num_dialogues=stat["dataset_size"], num_total_rows=sum(stat["dialog_size"]), num_train_rows=stat["dialog_size"][0], num_val_rows=stat["dialog_size"][1], num_test_rows=stat["dialog_size"][2])
        _str += _row_stat
    fp.write(_str)

# Write extracted dialogs to 'D:\_jupyter/dataset/preprocessed/dialog_finetuning/retriever/eight_n2x8_both/'...
	feed_data_00.json: 100000
	feed_data_01.json: 100000
	feed_data_02.json: 57203
	train feed_data_00.json: 100000
	train feed_data_01.json: 100000
	train feed_data_02.json: 4935
	val feed_data_00.json: 39185
	test feed_data_00.json: 6754


In [None]:
u_utterance = {
    "P": "즐거움",
    "J": "기쁨",
    "S": "슬픔",
    "A": "분노",
    "O": "보통"
}
g_utterance = {
    "U": "유저정보",
    "R": "페르소나",
    "K": "외부지식",
    "E": "공감형대화",
    "N": "일반/기타"
}

from collections import Counter
labels = [label for row in output for label in row["labels"]]
Counter(labels)

### Insert Condition

In [None]:
from transformer.services.dialog_retriever.poly_encoder import PolyEncoderDialogRetriever
dataset_name = "four_n2x8_both"
_epoch = 10
_model_dir = dataset_dir + "/model/poly_encoder/v3/concat/{dataset_name}/epoch_{_epoch}/".format(dataset_name=dataset_name, _epoch=_epoch)

service = PolyEncoderDialogRetriever()
service.verbose = False
service.set_device(device="cuda:0")
service.load_model(model_dir=_model_dir)

def get_condition_from_retriever(utterances, speaker_ids, min_length=10, top_n=5, weight_bm25=True, prev_utterance=None, intersection_tolerance=0.9, max_retry=5):
    conditions = None
    try:
#         outputs = service.infer_next_utterance_bm25(utterances, min_length=min_length, top_n=top_n, prev_utterance=prev_utterance, intersection_tolerance=intersection_tolerance)
        outputs = service.infer_next_utterance(utterances, speaker_ids, min_length=min_length, top_n=top_n, weight_bm25=weight_bm25, prev_utterance=prev_utterance, intersection_tolerance=intersection_tolerance, max_retry=max_retry)
        conditions = [outputs[0][0]]
    except Exception as ex:
        print("{}: {}".format(type(ex), ex))
#         conditions = utterances[last_user_idx+1:]
    return conditions

def get_condition(utterances, speaker_ids, persona, entities, user_speaker_id):
    is_user_condition = False
    name_in_utterance = None
    conditions = None
    last_index = get_last_index(speaker_ids, value=user_speaker_id)

    if persona is not None and entities is not None:
        user_persona = [_persona for _persona in persona if _persona["id"]==user_speaker_id][0]
        if "name" in entities and len(entities["name"]) > 0:
            last_entity_info = entities["name"][-1]
            entity_row_idx, entity_begin_idx, entity_end_idx, entity_span = last_entity_info
            name_in_utterance = utterances[entity_row_idx][entity_begin_idx:entity_end_idx]
            if entity_row_idx in range(last_index+1, len(utterances)) and name_in_utterance != "" and name_in_utterance in user_persona["name"]:
                is_user_condition = True

    if is_user_condition:
        # 유저 정보를 이용하면, 컨디션으로 유저 정보를 넣어줘야지
        # user_name = user_persona["name"]
        # name_condition = {"name": name_in_utterance}
        # name_condition = name_condition.__str__()
        name_condition = "상대의 이름은 {name}입니다.".format(name=name_in_utterance)
        conditions = [name_condition]
    else:
#         conditions = None
        conditions = get_condition_from_retriever(utterances=utterances[:last_index+1], speaker_ids=speaker_ids[:last_index+1])
    return conditions

def insert_condition(row, user_speaker_id, use_condition):
    row_dict = row.copy()
    utterances = row["utterances"]
    speaker_ids = row["speaker_ids"]
    last_index = get_last_index(speaker_ids, value=user_speaker_id)

    persona = None
    entities = None
    if use_condition and "persona" in row: persona = row["perosna"]
    if use_condition and "entities" in row: persona = row["entities"]
    conditions = get_condition(utterances=utterances[:last_index+1], speaker_ids=speaker_ids[:last_index+1], persona=persona, entities=entities, user_speaker_id=user_speaker_id)
    row_dict["conditions"] = conditions
    return row_dict

In [None]:
input_dir = dataset_dir + "/dataset/preprocessed/dialog_finetuning/generator/{dataset_name}/".format(dataset_name=dataset_name)
output_dir = dataset_dir + "/dataset/preprocessed/dialog_finetuning/generator/condition/{dataset_name}/".format(dataset_name=dataset_name)
# dir_postfix_list = ["train/", "val/", "test/", "sample/", ""]
dir_postfix_list = ["train/", "sample/"]
# dir_postfix_list = ["val/", "test/", ""]

model_speaker_id = 0
user_speaker_id = 1
use_condition = False

for dir_postfix in dir_postfix_list:
    print("data_dir: {}".format(input_dir+dir_postfix))
    for filename in os.listdir(input_dir+dir_postfix):
        print("\tfilename: {}".format(filename))
        if os.path.isdir(input_dir+dir_postfix + filename): continue
        if not filename.endswith(".json"): continue

        _data = None
        with open(input_dir + dir_postfix + filename, "r", encoding="utf-8") as fp:
            _data = json.load(fp)

        if _data is None or len(_data) <= 0: continue
        data = []
        for row_idx, row in tqdm(enumerate(_data), initial=0, total=len(_data)):
            row.pop("conditions")
            row = insert_condition(row=row, user_speaker_id=user_speaker_id, use_condition=use_condition)
            assert row is not None, "Row is None!"
            data.append(row)

        init_path(output_dir + dir_postfix, reset=False)
        with open(output_dir + dir_postfix + filename, "w", encoding="utf-8") as fp:
            json.dump(data, fp)

<br><br><br><hr><br><br><br>

## SelectStar

In [49]:
u_utterance = {
    "P": "즐거움",
    "J": "기쁨",
    "S": "슬픔",
    "A": "분노",
    "O": "보통"
}
g_utterance = {
    "U": "유저정보",
    "R": "페르소나",
    "K": "외부지식",
    "E": "공감형대화",
    "N": "일반/기타"
}

In [50]:
dataset_name = "SelectStar"
language = "kor"
encoding = "UTF-8"
extension = "json"

root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

data = []
for filename in os.listdir(raw_dir):
    if not filename.endswith(extension): continue
    if filename == 'SDS_final_2nd_20210901.json':  continue
    with open(raw_dir+filename, "r", encoding=encoding) as fp:
        __data = json.load(fp)
        
        _data = []
        for row in __data:
            if "speakers_ids" in row:
                row["speaker_ids"] = row.pop("speakers_ids")
            if isinstance(row["persona"], dict) and "g" in row["persona"] and "u" in row["persona"]:
                row["persona"] = list(row["persona"].values())
            _data.append(row)
        data += _data

In [51]:
name_entity_pattern = "\$\{[^(\$\{).]*\}"
with open(raw_dir + "SDS_final_2nd_20210901.json", "r", encoding="utf-8") as fp:
    _data = json.load(fp)
    
new_output = []
for row_idx, row in enumerate(tqdm(_data, initial=0, total=len(_data))):
    row = row[0]
    utterances = []
    speaker_ids = []
    entities = {"name": []}
    for utterance, speaker_id in zip(row["utterances"], row["speakers_ids"]):
        _utterances = [_utterance for _utterance in utterance.split("\n") if _utterance.strip()!=""]
        utterances += _utterances
        speaker_ids += len(_utterances) * [speaker_id]
    
    labels = []
    for utterance_idx, (utterance, speaker_id) in enumerate(zip(utterances, speaker_ids)):
        counter_speaker_id = (speaker_id + 1) % 2
        counter_persona = None
        for k,v in row["persona"].items():
            if v["id"] == counter_speaker_id:
                counter_persona = v
                break
        
        label = ""
        search_result = re.search(name_entity_pattern, utterance)
        while search_result is not None:
            begin_idx = search_result.start()
            end_idx = search_result.end()
            name_to_replace = utterance[begin_idx+2:end_idx-1]
            utterance = utterance[:begin_idx] + name_to_replace + utterance[end_idx:]
            utterances[utterance_idx] = utterance
            
            if name_to_replace in counter_persona["name"]:
                end_idx = end_idx - 3
                span = end_idx - begin_idx
                entitiy_row = (utterance_idx, begin_idx, end_idx, span)
                entities["name"].append(entitiy_row)
                label = "U"

            # if speaker_id == 0: label = "U"
            search_result = re.search(name_entity_pattern, utterance)
        labels.append(label)
    
    row.pop("speakers_ids")
    row["idx"] = row_idx
    row["utterances"] = utterances
    row["speaker_ids"] = speaker_ids
    row["persona"] = list(row["persona"].values())
    row["labels"] = labels
    row["entities"] = entities
    assert len(row["utterances"])==len(row["speaker_ids"])==len(row["labels"]), "length does not match"
    new_output.append(row)

100%|██████████| 3351/3351 [00:00<00:00, 10488.66it/s]


In [52]:
data += new_output
random.shuffle(data)

In [53]:
print("data size:", len(data))

error_indice = []
for row_idx, row in enumerate(data):
    error_list = []
    if len(row["persona"])!=2: 
        error_list.append("persona length: {}".format(len(row["persona"])))
    if len(row["utterances"])!=len(row["speaker_ids"]) or len(row["utterances"])!=len(row["labels"]) or len(row["speaker_ids"])!=len(row["labels"]): 
        error_list.append("length difference: {}/{}/{}".format(len(row["utterances"]),len(row["speaker_ids"]),len(row["labels"])))

#     for speaker_id, label in zip(row["speaker_ids"], row["labels"]):
#         if speaker_id == 0 and label not in g_utterance:
#             error_list.append("wrong label for g: '{}'".format(label))
#         if speaker_id == 1 and label not in u_utterance:
#             error_list.append("wrong label for u: '{}'".format(label))

    if len(error_list) > 0:
        error_indice.append((row_idx, error_list))
        print("error: {}".format(row_idx))
        
string_output = ""
for row in error_indice:
    _str = "{}:\n".format(row[0])
    for error in row[1]:
        _str += "\t{}\n".format(error)
    string_output += _str

print("error nums:", len(error_indice))
if len(error_indice) <= 0:
#     random.shuffle(output)
    with open(root_dir + "/{language}/multi_turn/feed_data_0.json".format(language=language), "w", encoding=encoding) as fp:
        json.dump(data, fp)
else:
    with open("./error_contents.txt", "w", encoding="utf-8") as fp:
        fp.write(string_output)

data size: 5693
error nums: 0


## AIHUB_SSGI

In [None]:
dataset_name = "AIHUB_SSGI"
language = "kor"
encoding = "UTF-8"
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

final_output = []
file_name_list = [file_name for file_name in os.listdir(raw_dir+"dialog") if file_name.endswith(".xlsx")]
for file_name in file_name_list:
    df = pd.read_excel(raw_dir+"dialog/"+file_name, engine="openpyxl")
    df = df.fillna("")
    
    speaker_info = dict()
    domain = None
    utterances = []
    speaker_ids = []
    categories = []
    tasks = []
    subtasks = []

    idx = 0
    output = []
    for row in df[["SPEAKER", "SENTENCE", "DOMAIN", "CATEGORY", "MAIN", "SUB", "SPEAKERID", "SENTENCEID"]].values.tolist():
        speaker, sentence, _domain, category, main, sub, speaker_id, sentence_id = row

        if int(sentence_id) == 1 and len(utterances) > 0:
            # append
            output_row = OrderedDict()
            output_row["idx"] = idx
            output_row["speakers"] = dict(OrderedDict(sorted(speaker_info.items())))
            output_row["domain"] = domain
            output_row["utterances"] = utterances
            output_row["speaker_ids"] = speaker_ids
            output_row["category"] = trim_obj(categories)
            output_row["task"] = trim_obj(tasks)
            output_row["subtask"] = trim_obj(subtasks)
            output_row = dict(output_row)
            output.append(output_row)

            # reset
            idx += 1
            speaker_info = dict()
            domain = None
            utterances = []
            speaker_ids = []
            categories = []
            tasks = []
            subtasks = []

        speaker_info[speaker_id] = speaker
        domain = _domain
        utterances.append(sentence)
        speaker_ids.append(speaker_id)
        categories.append(category)
        tasks.append(main)
        subtasks.append(sub)

    # append
    output_row = OrderedDict()
    output_row["idx"] = idx
    output_row["speakers"] = speaker_info
    output_row["domain"] = domain
    output_row["utterances"] = utterances
    output_row["speaker_ids"] = speaker_ids
    output_row["category"] = categories
    output_row["task"] = tasks
    output_row["subtask"] = subtasks
    output_row = dict(output_row)
    output.append(output_row)

    final_output += output
    print("output_size:", len(output), "final_output size:", len(final_output))

## AIHUB_Twitter

In [None]:
dataset_name = "AIHUB_Twitter"
language = "kor"
encoding = "UTF-8"
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

df = pd.read_excel(raw_dir+"트위터_대화시나리오DB_2000Set.xlsx", engine="openpyxl")
df = df.fillna("")

output = []
for row in df.values.tolist():
    utterances = []
    for utterance in row:
        if utterance == "": break
        utterances.append(utterance)
    speaker_ids = [i%2 for i in range(0, len(utterances))]
    
    output_row = OrderedDict()
    output_row["utterances"] = utterances
    output_row["speaker_ids"] = speaker_ids
    output_row = dict(output_row)
    output.append(output_row)

## AIHUB_EMOTION

In [None]:
language = "kor"
encoding = "UTF-8"
extension = "xlsx"

dataset_name = "AIHUB_EMOTION"
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

df = pd.DataFrame()
for file_name in os.listdir(raw_dir):
    if not file_name.endswith(extension): continue
    _df = pd.read_excel(raw_dir+file_name)
    df = pd.concat([df, _df], axis=0)
df = df.fillna("")
df = df.drop_duplicates()

output = []
domain = "emotional_dialog"
for row_idx, row in enumerate(df.values.tolist()):
    _, _, age_range, gender, category, healt_status, task, subtask, u1, s1, u2, s2, u3, s3 = row
    if gender == "여성": gender = "F"
    elif gender == "남성": gender = "M"
    else: gender = "O"
    speakers = {"0":"시스템", "1":"사용자"}
    utterances = [u1, s1, u2, s2, u3, s3]
    utterances = [utterance for utterance in utterances if utterance!=""]
    speaker_ids = [(i+1)%2 for i in range(0, len(utterances))]
    
    output_row = OrderedDict()
    output_row["speakers"] = speakers
    output_row["domain"] = domain
    output_row["utterances"] = utterances
    output_row["speaker_ids"] = speaker_ids
    output_row["category"] = category
    output_row["task"] = task
    output_row["subtask"] = subtask
    output_row = dict(output_row)
    output.append(output_row)

## AIHUB_SNS

In [123]:
language = "kor"
encoding = "UTF-8"
extension = "json"

dataset_name = "AIHUB_SNS"
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

data = []
for file_name in os.listdir(raw_dir):
    if not file_name.endswith(extension): continue
    with open(raw_dir+file_name, "r", encoding=encoding) as fp:
        _data = json.load(fp)
        data += _data["data"]
    break

output = []
for row_idx, row in enumerate(data):
    domain = row["header"]["dialogueInfo"]["type"]
    category = row["header"]["dialogueInfo"]["topic"]
    speakers = OrderedDict()
    speaker_mapping = dict()
    for speaker_id, speaker in enumerate(row["header"]["participantsInfo"]):
        speaker_mapping[speaker["participantID"]] = speaker_id
        speakers[speaker_id] = OrderedDict()
        speakers[speaker_id]["id"] = speaker["participantID"]
        speakers[speaker_id]["sex"] = speaker["gender"]
        speakers[speaker_id]["age"] = speaker["age"]
        speakers[speaker_id]["residence"] = speaker["residentialProvince"]
        speakers[speaker_id] = dict(speakers[speaker_id])
    speakers = dict(speakers)
    summary = row["body"]["summary"]
    
    utterances = []
    speaker_ids = []
    datetime_list = []
    for _row in row["body"]["dialogue"]:
        utterance = _row["utterance"]
        utterances.append(utterance)
        speaker_id = speaker_mapping[_row["participantID"]]
        speaker_ids.append(speaker_id)
        _datetime = _row["date"] + " " + _row["time"]
        datetime_list.append(_datetime)
    
    output_row = OrderedDict()
    output_row["idx"] = row_idx
    output_row["speakers"] = speakers
    output_row["domain"] = domain
    output_row["utterances"] = utterances
    output_row["speaker_ids"] = speaker_ids
    output_row["datetime"] = datetime_list
    output_row["category"] = category
    output_row["summary"] = summary
    output_row = dict(output_row)
    output.append(output_row)

## KLUE_DST

In [None]:
dataset_name = "KLUE_DST" # 국립국어원
language = "kor"
encoding = "UTF-8"
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

file_name_list = [file_name for file_name in os.listdir(raw_dir) if file_name.endswith(".json") and file_name.startswith("wos")]

output = []
kor_regex = "[ㄱ-ㅣ가-힣]"

for file_name in file_name_list:
    with open(raw_dir+file_name, "r", encoding=encoding) as fp:
        data = json.load(fp)

    for dialog_idx, dialog_row in enumerate(data):
        dialog = dialog_row["dialogue"]

        output_row = OrderedDict()
        utterances = []
        speaker_ids = []
        category = []
        task = []
        subtask = []
        _speaker_info = dict()
        for row in dialog:
            speaker = row["role"]
            if speaker not in _speaker_info:
                _speaker_info[speaker] = len(_speaker_info)
            utterance = row["text"]
            utterances.append(utterance)
            speaker_id = _speaker_info[speaker]
            speaker_ids.append(speaker_id)

            if "state" not in row: continue
            state = row["state"]
            for _state in state:
                _state = _state.split("-")
                if len(_state) > 0: 
                    _category = _state[0]
                    category.append(_category)
                if len(_state) > 1:
                    _task = _state[1]
                    task.append(_task)
                if len(_state) > 2:
                    _subtask = _state[2]
                    if re.search(kor_regex, _subtask) is not None:
                        subtask.append(_subtask)
            category = trim_obj(category)
            task = trim_obj(task)
            subtask = trim_obj(subtask)

        speaker_info = {v:k for k,v in _speaker_info.items()}

        output_row["idx"] = dialog_idx
        output_row["speakers"] = speaker_info
        output_row["utterances"] = utterances
        output_row["speaker_ids"] = speaker_ids
        output_row["category"] = category
        output_row["task"] = task
        output_row["subtask"] = subtask
        output_row = dict(output_row)
        output.append(output_row)

## NIKL_2020

In [None]:
dataset_name = "NIKL_2020"
language = "kor"
encoding = "UTF-8"
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

file_name_list = [file_name for file_name in os.listdir(raw_dir) if file_name.endswith(".json")]
output = []
for idx, file_name in enumerate(file_name_list):
    with open(raw_dir+file_name, "r", encoding=encoding) as fp:
        data = json.load(fp)

    # date
    # date = datetime.strptime(data["document"][0]["metadata"]["date"], "%Y%M%d").strftime("%Y-%M-%d")
    _date = data["document"][0]["metadata"]["date"]
    date = "-".join([_date[0:4], _date[4:6], _date[6:8]])
    # tasks
    tasks = data["document"][0]["metadata"]["topic"]
    tasks = tasks.split(">")[-1]
    tasks = [task.strip() for task in tasks.split(",")]
    tasks = trim_obj(tasks)
    # speaker_info
    spkid2id = {speaker["id"]:_idx for _idx, speaker in enumerate(data["document"][0]["metadata"]["speaker"])}
    speaker_info = {spkid2id[speaker["id"]]:speaker for speaker in data["document"][0]["metadata"]["speaker"]}
    speaker_info = dict(OrderedDict(sorted(speaker_info.items())))
    # utterances & speaker_ids
    utterances = []
    speaker_ids = []
    elements = data["document"][0]["utterance"]
    prev_utterance = elements[0]["form"] # 철자 전사
    # utterance = elements[0]["original_form"] # 발음 전사
    prev_speaker_id = elements[0]["speaker_id"]
    for element_idx in range(1, len(elements)):
        cur_utterance = elements[element_idx]["form"] # 철자 전사
        # cur_utterance = elements[element_idx]["original_form"] # 발음 전사
        cur_speaker_id = elements[element_idx]["speaker_id"]

        if cur_speaker_id == prev_speaker_id:
            if prev_utterance.endswith("."): 
                utterances.append(prev_utterance)
                speaker_ids.append(prev_speaker_id)
                prev_utterance = cur_utterance
            else:
                prev_utterance = prev_utterance + cur_utterance
        else:
            for _prev_utterance in prev_utterance.split(". "):
                if _prev_utterance.strip() == "": continue
                if prev_utterance.strip().endswith(".") and not _prev_utterance.strip().endswith("."): 
                    _prev_utterance = _prev_utterance + "."
                utterances.append(_prev_utterance)
                speaker_ids.append(prev_speaker_id)
            # reset
            prev_utterance = cur_utterance
            prev_speaker_id = cur_speaker_id
    utterances.append(prev_utterance)
    speaker_ids.append(prev_speaker_id)
    speaker_ids = [spkid2id[speaker_id] for speaker_id in speaker_ids]


    output_row = OrderedDict()
    output_row["idx"] = idx
    output_row["document_id"] = data["document"][0]["id"]
    output_row["date"] = date
    output_row["task"] = tasks
    output_row["speaker_info"] = speaker_info
    output_row["setting"] = data["document"][0]["metadata"]["setting"]
    output_row["utterances"] = utterances
    output_row["speaker_ids"] = speaker_ids
    output_row = dict(output_row)
    output.append(output_row)

## NIKL_Dialog

In [293]:
def get_output_row(idx, soup):
    output_row = OrderedDict()
    speaker_info = dict()
    _speaker_id_mapping = dict()
    utterances = []
    speaker_ids = []

    header = soup.find("body").find("teiheader")
    _file_desc = header.find("filedesc")
    _profile_desc = header.find("profiledesc")
    speakers = _profile_desc.findAll("person")
    for speaker_id, speaker in enumerate(speakers):
        _speaker_info = OrderedDict(speaker.attrs)
        if len(splited) > 0:
            _speaker_info["job"] = splited[0].strip()
        if len(splited) > 1:
            _speaker_info["residence"] = splited[1].strip()
        speaker_info[speaker_id] = dict(_speaker_info)
        _speaker_id_mapping[_speaker_info["id"]] = speaker_id

    utterances = []
    speaker_ids = []

    prev_speaker_id = -1
    prev_utterance = ""
    _utterances = soup.findAll("u")
    for _utterance in _utterances:
        _speaker_id = _utterance.attrs["who"]
        if _speaker_id not in _speaker_id_mapping: continue
        speaker_id = _speaker_id_mapping[_speaker_id]
        # utterance = [u.text.strip() for u in _utterance.findAll("s")]
        utterance = [child.strip() for s in _utterance.findAll("s") for child in s.children if isinstance(child, str)]
        utterance = " ".join(utterance)
        utterance = utterance.replace("::", "")
        if utterance == "": continue

        if speaker_id == prev_speaker_id:
            prev_utterance = prev_utterance + " " + utterance
        else:
            if prev_speaker_id != -1:
                utterances.append(prev_utterance)
                speaker_ids.append(prev_speaker_id)
            prev_utterance = utterance
            prev_speaker_id = speaker_id

    if speaker_id == prev_speaker_id:
        prev_utterance = prev_utterance + utterance
    utterances.append(prev_utterance)
    speaker_ids.append(prev_speaker_id)
    
    assert len(utterances) == len(speaker_ids), "{l1} vs {l2}".format(l1=len(utterances), l2=len(speaker_ids))

    output_row["idx"] = idx
    output_row["speakers"] = speaker_info
    output_row["title"] = _file_desc.find("title").text
    output_row["domain"] = _profile_desc.find("settingdesc").text
    output_row["utterances"] = utterances
    output_row["speaker_ids"] = speaker_ids
    output_row["project_name"] = header.find("projectdesc").text
    output_row["distributor"] = _file_desc.find("distributor").text
    extent = _file_desc.find("extent").text
    output_row["num_eojeols"] = int(re.sub("[^0-9]", "", extent))
    output_row = dict(output_row)
    return output_row

file_path_list = []
for sub_dir in os.listdir(raw_dir):
    file_path = raw_dir+sub_dir+"/원시/"
    if not os.path.isdir(file_path): continue
    file_name_list = os.listdir(file_path)
    if len(file_name_list) != 1: continue
    file_name = file_name_list[0]
    file_path_list.append(file_path+file_name)
    
output = []
for idx, file_path in enumerate(file_path_list):
    with open(file_path, 'rb') as f:
        contents = f.read()
    contents = contents.decode("utf-16")
    contents = re.sub("\r\n", "", contents)
    soup = BeautifulSoup(contents, 'lxml')
    output_row = get_output_row(idx=idx, soup=soup)
    output.append(output_row)

## OpenSubtitles

In [None]:
dataset_name = "OpenSubtitles"
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

pickle_path = file_path[dataset_name]["pickle"].format(root_dir=root_dir, language=language)
data = None
with open(pickle_path, "rb") as fp:
    data = pickle.load(fp)

output = []
kor_regex = "[ㄱ-ㅣ가-힣]"

_data = data["train"] + data["test"]
idx = 0
for row in _data:
    output_row = OrderedDict()
    output_row["idx"] = idx
    output_row["file_id"] = row.pop("file_id")
    context = row.pop("context")
    response = row.pop("response")
    
    _utterances = [row[k] for k in sorted(row.keys(), reverse=True)]
    _utterances.append(context)
    _utterances.append(response)

    utterances = []
    speaker_ids = []
    _idx = 0
    
    not_korean = True
    for utterance in _utterances:
        speaker_id = -1 * ((_idx+1) % 2)
        _idx += 1
        utterance = utterance.replace("'", "")
        utterance = utterance.replace('"', '')
        utterance = re.sub("^.* :", "", utterance).strip()
        utterance = utterance.replace("nbsp;", "").strip()
        if utterance=="": continue
        utterances.append(utterance)
        speaker_ids.append(speaker_id)
        if re.search(kor_regex, utterance) is not None: not_korean = False
    
    if not_korean: continue
        
    output_row["utterances"] = utterances
    output_row["speaker_ids"] = speaker_ids
    output_row = dict(output_row)
    output.append(output_row)
    idx += 1

## EmpatheticDialogues

#### kor

In [69]:
dataset_name = "SelectStar"
language = "kor"
encoding = "UTF-8"
extension = "json"

root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

data = []
for filename in os.listdir(raw_dir):
    if not filename.endswith(extension): continue
    with open(raw_dir+filename, "r", encoding=encoding) as fp:
        __data = json.load(fp)
        
        _data = []
        for row in __data:
            if "speakers_ids" in row:
                row["speaker_ids"] = row.pop("speakers_ids")
            _data.append(row)
        data += _data
print("data size:", len(data))
total_personas = [persona for row in data for persona in row["persona"].values()]
persona_df = pd.DataFrame(total_personas)
persona_df = persona_df.fillna("")

pd.options.mode.chained_assignment = None  # default='warn'
def find_most_similar_persona(age, gender, default_name="당신"):
    target_df = persona_df
    target_persona = {"name": default_name}
    try:
        target_df = target_df[target_df["gender"]==gender]
        age_diff = abs(target_df["age"] - age)
        target_df["age_diff"] = list(age_diff)
        target_age_diff = target_df.sort_values("age_diff").values.tolist()[0][-1]
        target_df = target_df.loc[(target_df["age_diff"]==target_age_diff)]
        target_persona = target_df.sample(1).to_dict("records")[0]
    except:
        target_persona = target_df.sample(1).to_dict("records")[0]
    return target_persona

data size: 2342


In [70]:
dataset_name = "EmpatheticDialogues"
language = "kor"
encoding = "UTF-8"
extension = "xlsx"
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

filename_list = os.listdir(raw_dir)
df = pd.DataFrame()
for filename in filename_list:
    if not filename.endswith(extension): continue
    _df = pd.read_excel(raw_dir + filename)
    df = pd.concat([df, _df], axis=0)
df = df.fillna("")

In [71]:
# make preprocessed raw_data
data = []

name_entity_pattern = "\$\$[ㄱ-ㅣ가-힣]+\$\$"
prev_row_idx = -1
utterances = []
speaker_ids = []
personas = []
entities = dict()
turn_idx = 0

for _idx, row in enumerate(df.values.tolist()):    
    row_idx, persona_idx, eng, kor_before, kor_after, gender, age, married, job = row
    
    if _idx == 0: 
        prev_row_idx = row_idx
    if _idx > 0 and row_idx != "" and prev_row_idx != int(row_idx):
        # append dialog
        output_row = OrderedDict()
        output_row["idx"] = int(prev_row_idx)
        output_row["utterances"] = utterances
        output_row["speaker_ids"] = speaker_ids
        output_row["persona"] = personas
        output_row["entities"] = entities
        output_row = dict(output_row)
        data.append(output_row)
        # reset
        prev_row_idx = row_idx
        utterances = []
        speaker_ids = []
        personas = []
        entities = dict()
        turn_idx = 0
        
    _utterances = clean(str(kor_after))
    speaker_id = (turn_idx+1) % 2
    for utterance in _utterances.split(". "):
        utterances.append(utterance)
        speaker_ids.append(speaker_id)
        
    if gender!="":
        persona = OrderedDict()
        persona["id"] = speaker_id
        persona["persona_idx"] = int(persona_idx)
        persona["gender"] = gender
        if age != "": 
            persona["age"] = int(age)
        if married != "": 
            persona["married"] = int(married)
        if job != "": 
            persona["job"] = int(job)
        persona = dict(persona)
        personas.append(persona)
    turn_idx += 1
    
# append dialog
output_row = OrderedDict()
output_row["idx"] = int(prev_row_idx)
output_row["utterances"] = utterances
output_row["speaker_ids"] = speaker_ids
output_row["persona"] = personas
output_row["entities"] = entities
output_row = dict(output_row)
data.append(output_row)

In [72]:
# 이름 채워넣기
_data = data.copy()
data = []
for row in _data: 
    new_personas = []
    for persona in row["persona"]:
        new_persona = find_most_similar_persona(age=persona["age"], gender=persona["gender"])
        persona["name"] = new_persona["name"]
        new_personas.append(persona)
    row["persona"] = new_personas
    data.append(row)

In [73]:
# 이름 치환 & entities 정보 수집
name_entity_pattern = "\$\$[ㄱ-ㅣ가-힣]+\$\$"

_data = data.copy()
data = []

for _row in _data:
    row = _row.copy()
    
    utterances = []
    speaker_ids = []
    labels = []
    entities = dict()
    for utterance_idx, (_utterance, _speaker_id) in enumerate(zip(row["utterances"], row["speaker_ids"])):
        label = "" if _speaker_id == 1 else "N"
        
        while True:
            name_entity_search = re.search(name_entity_pattern, _utterance)
            if name_entity_search is not None:
                cur_persona = [persona for persona in row["persona"] if persona["id"]==(_speaker_id+1)%2][0]
                entity_start_idx = name_entity_search.start()
                entity_end_idx = name_entity_search.end()

                name = cur_persona["name"]
                if len(name) == 3: name = name[1:]
                _utterance = _utterance[:entity_start_idx] + name + _utterance[entity_end_idx:]
                entity_span = len(name)
                entity_end_idx = entity_start_idx + entity_span

                entity = [utterance_idx, entity_start_idx, entity_end_idx, entity_span]
                if "name" not in entities: entities["name"] = []
                entities["name"].append(entity)
                if _speaker_id == 0: label = "U"
            else:
                break

        utterances.append(_utterance)
        speaker_ids.append(_speaker_id)
        labels.append(label)

    row["utterances"] = utterances
    row["speaker_ids"] = speaker_ids
    row["labels"] = labels
    row["entities"] = entities
    
    data.append(row)

# with open(raw_dir + "empathetic_dialogues.pickle", "wb") as fp:
#     pickle.dump(data, fp)
# with open(root_dir + "/{language}/multi_turn/feed_data_0.json".format(language=language), "w", encoding=encoding) as fp:
#     json.dump(data, fp)

#### eng

In [None]:
dataset_name = "EmpatheticDialogues"
language = "eng"
encoding = "UTF-8"
root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

_data = None
with open(raw_dir+"empathetic_dialogues.pickle", "rb") as fp:
    _data = pickle.load(fp)

output = []
row_idx = 0
for k,v in _data.items():
    print("v size:", len(v))
    print("v conv size:", len(pd.DataFrame(v)["conv_id"].unique()))
    rows = []

    prev_conv_id = None
    output_row = OrderedDict()
    for row in v:
        cur_conv_id = row["conv_id"]
        if prev_conv_id is None:
            output_row["idx"] = row_idx
            output_row["context"] = row["context"]
            output_row["prompt"] = row["prompt"]
            output_row["utterances"] = []
            output_row["speaker_ids"] = []
            output_row["selfeval"] = row["selfeval"]
            output_row["tags"] = []
            prev_conv_id = cur_conv_id
            continue

        if cur_conv_id != prev_conv_id:
            tags = sorted(set(output_row["tags"]))
            output_row["tags"] = tags
            prompt = output_row.pop("prompt")
            output_row["utterances"].insert(0, prompt)
            if len(output_row["speaker_ids"]) == 0:
                previous_speaker_id = 0
            elif len(output_row["speaker_ids"]) > 1:
                previous_speaker_id = output_row["speaker_ids"][1]
            else:
                if output_row["speaker_ids"][0] == 0:
                    previous_speaker_id = 1
                else:
                    previous_speaker_id = 0
            output_row["speaker_ids"].insert(0, previous_speaker_id)
            output_row = dict(output_row)
            rows.append(output_row)
            prev_conv_id = cur_conv_id
            row_idx += 1

            output_row = OrderedDict()
            output_row["idx"] = row_idx
            output_row["context"] = row["context"]
            output_row["prompt"] = row["prompt"]
            output_row["utterances"] = []
            output_row["speaker_ids"] = []
            output_row["selfeval"] = row["selfeval"]
            output_row["tags"] = []
        else:
            output_row["utterances"].append(row["utterance"])
            output_row["speaker_ids"].append(int(row["speaker_idx"]))
            output_row["tags"].append(row["tags"])

    tags = sorted(set(output_row["tags"]))
    output_row["tags"] = tags
    prompt = output_row.pop("prompt")
    output_row["utterances"].insert(0, prompt)
    if len(output_row["speaker_ids"]) > 1:
        previous_speaker_id = output_row["speaker_ids"][1]
    else:
        if output_row["speaker_ids"][0] == 0:
            previous_speaker_id = 1
        else:
            previous_speaker_id = 0
    output_row["speaker_ids"].insert(0, previous_speaker_id)
    output_row = dict(output_row)
    rows.append(output_row)
    prev_conv_id = cur_conv_id
    row_idx += 1
    
    print("rows size:", len(rows), "\n")
    output += rows

## KaggleConversation

In [None]:
dataset_name = "KaggleConversation"
language = "kor"
encoding = "UTF-8"

root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

titles = pd.read_csv(raw_dir + "conversation_titles.csv")
data_df = pd.read_csv(raw_dir + "conversations.csv")

output = []
for row_idx, _date in enumerate(titles["date"].unique().tolist()):
    title = titles[titles["date"]==_date]["kor_title"].tolist()[0]
    translated_title = titles[titles["date"]==_date]["eng_title"].tolist()[0]
    utterances = data_df[data_df["date"]==_date]["kor_sent"].tolist()
    translated_utterances = data_df[data_df["date"]==_date]["eng_sent"].tolist()
    speaker_ids = [(i+1)%2 for i in range(0, len(utterances))]
    
    output_row = OrderedDict()
    output_row["idx"] = row_idx
    output_row["title"] = title
    output_row["utterances"] = utterances
    output_row["speaker_ids"] = speaker_ids
    output_row["translated_title"] = translated_title
    output_row["translated_utterances"] = translated_utterances
    output_row = dict(output_row)
    output.append(output_row)

## SelectStarPersona

In [151]:
dataset_name = "SelectStarPersona"
language = "kor"
encoding = "UTF-8"
name_entity_pattern = "\$\{[^(\$\{).]*\}"

root_dir = file_path[dataset_name]["root_dir"].format(dataset_dir=dataset_dir)
raw_dir = file_path[dataset_name]["raw_dir"].format(root_dir=root_dir, language=language)

file_name_list = [file_name for file_name in os.listdir(raw_dir) if file_name.endswith(".json")]
_output = []
for idx, file_name in enumerate(file_name_list):
    with open(raw_dir+file_name, "r", encoding=encoding) as fp:
        data = json.load(fp)
        for rows in data:
            for row in rows:
                _output.append(row)
                
output = []
for row in _output:
    row["speaker_ids"] = row.pop("speakers_ids")

    utterances = []
    speaker_ids = []
    for speaker_id, utterance in zip(row["speaker_ids"], row["utterances"]):
        for _utterance in utterance.strip().split("\n"):
            utterances.append(_utterance)
            speaker_ids.append(speaker_id)

    _utterances = []
    _labels = []
    entities = dict()
    entities["name"] = []

    for _utterance_idx, (_utterance, _speaker_id) in enumerate(zip(utterances, speaker_ids)):
        _label = ""
        search_result = re.search(name_entity_pattern, _utterance)
        while search_result is not None:
            begin_idx = search_result.start()
            end_idx = search_result.end()
            _utterance = _utterance[:begin_idx] + _utterance[begin_idx+2:end_idx-1] + _utterance[end_idx:]

            end_idx = end_idx - 3
            span = end_idx - begin_idx
            entitiy_row = (_utterance_idx, begin_idx, end_idx, span)
            entities["name"].append(entitiy_row)
            if _speaker_id == 0: _label = "U"            

            search_result = re.search(name_entity_pattern, _utterance)
        _utterances.append(_utterance)
        _labels.append(_label)

    row["utterances"] = _utterances
    row["speaker_ids"] = speaker_ids
    row["entities"] = entities
    row["labels"] = _labels
    persona = list(row["persona"].values())
    row["persona"] = persona
    output.append(row)