# Générateur de dialogues


In [None]:
# Reload imported modules (here the ChatbotDS package) automatically before
# each cell runs, so edits to the package are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2


In [None]:
# Folder of dialog templates, shared by the dialog generator and the
# response-code builder further down.
templates_path = "./data/Templates-beta/"


In [None]:
import os
import torch

from ChatbotDS.generator.dialogs_generator import DialogsGenerator
from ChatbotDS.generator.dialog import Dialog
from ChatbotDS.code.code_response import CodeResponse
from ChatbotDS.preprocessing.preprocessing_dialogs import PreprocessingDialogs
from ChatbotDS.utils.voc import Voc
from ChatbotDS.utils.utils import import_template2, Student
from ChatbotDS.utils.utils import import_replace_variable
from ChatbotDS.utils.templates import Templates
from ChatbotDS.chatbot.chatbot import Chatbot
from ChatbotDS.chatbot.trainer import Trainer
from ChatbotDS.chatbot.evaluation import Evaluation


## Générateur

Création du générateur :


In [None]:
# Build the dialog generator from the templates folder defined above.
gen = DialogsGenerator(templates_path)


Génération des dialogues :


In [None]:
# Split to generate: one of 'Full', 'Train', 'Test' (Full, Train, Test).
mode = 'Full'
# Presumably the number of dialogs and the number of turns per dialog —
# TODO confirm against DialogsGenerator.generate_dialogs.
dialogs_len, dialog_len = 20000, 35

gen.generate_dialogs(
    mode=mode,
    dialogs_len=dialogs_len,
    dialog_len=dialog_len,
)


In [None]:
# Generate a single dialog on its own — presumably a quick sample to check the
# templates; the return value is shown as the cell output.
gen.generate_dialog()


Sauvegarde des dialogues :


In [None]:
# Save the generated dialogs under a name derived from the split,
# e.g. ./data/ChatbotDS_P_Full_.tsv (the path the next section reads back).
dialog_name = f'./data/ChatbotDS_P_{mode}_.tsv'
gen.save(dialog_name)


## Code de la réponse

Création du générateur de code des réponses :


In [None]:
# Re-declare the split, presumably so this section can run without
# re-executing the generation cells.
mode = 'Full'
# Dialog file written by the generator section for this split.
dialog_name = f'./data/ChatbotDS_P_{mode}_.tsv'
# Response-code builder over those dialogs, using the same templates folder.
code = CodeResponse(dialog_name, templates_path)


On applique le code "Baseline" :


In [None]:
# Apply the "Baseline" response-code scheme. The returned value is stored in
# `baseline` but never used below; `code.save` is what persists the result
# (presumably `code` keeps the computed codes internally — confirm).
baseline = code.baseline()


Sauvegarde des dialogues avec le code :


In [None]:
# Save the dialogs together with their Baseline codes. This file is the input
# of the pre-processing section (./data/ChatbotDS_<split>_Baseline.tsv).
save_path = f'./data/ChatbotDS_{mode}_Baseline.tsv'
code.save(save_path)


## Pré-processing


Initialisation du pré-processing :


In [None]:
# Split to pre-process ('Full', 'Train' or 'Test'); its Baseline file must
# already have been written by the "Code de la réponse" section.
set_type = "Full"
data = f'./data/ChatbotDS_{set_type}_Baseline.tsv'

# Pre-processing options, forwarded one-to-one to PreprocessingDialogs.
memory_len = 6             # passed as mem_len
preprocess = 'base2'       # variant name; also part of the output file name
process_output = False
unk_token = 'unk_token'    # alternative value: '[UNK]'

prep = PreprocessingDialogs(
    name='prep',
    diags_path=data,
    mem_len=memory_len,
    preprocess=preprocess,
    process_output=process_output,
    unk_token=unk_token,
)


Pré-processing :


In [None]:
# Run the pre-processing configured above (presumably on the dialogs read from
# `diags_path`); results are kept on `prep` and written out by save_diags below.
prep.prepare_data()


Sauvegarde des dialogues avec pré-processing :


In [None]:
# Output: ./data/ChatbotDS_P_<split>_Baseline_<preprocess>.tsv — the naming the
# Training and Evaluation sections rely on (e.g. ..._Train_Baseline_base2.tsv).
save_file = f'./data/ChatbotDS_P_{set_type}_Baseline_{preprocess}.tsv'
to_zip = False  # presumably keeps a plain file instead of a compressed one

prep.save_diags(save_file, to_zip=to_zip)


## Entraînement


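Import des données d'entraînement (fichier produit en exécutant les sections précédentes avec `mode = 'Train'` et `set_type = "Train"`) :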


In [None]:
# Training split. This file only exists if the sections above were run with
# mode = 'Train' and set_type = "Train" (and preprocess = 'base2').
diags_path = './data/ChatbotDS_P_Train_Baseline_base2.tsv'
diags = import_template2(path=diags_path)


Création du vocabulaire :


In [None]:
# Vocabulary built from the training dialogs.
# NOTE(review): presumably data_memory_size must equal memory_len used during
# pre-processing (both are 6 here) — confirm.
data_memory_size = 6
# Presumably must match the unk_token chosen in the pre-processing section.
unk_token = 'unk_token'
# Fallback reply ("I did not understand, please rephrase the question."),
# presumably used when the input is not understood.
unk_text = "Je n'ai pas compris, merci de reformuler la question."

voc = Voc(diags, data_memory_size, unk_token=unk_token, unk_text=unk_text)


Création du chatbot :


In [None]:
# Model hyper-parameters.
hidden_size = 128
attn_method = "general"       # attention variant passed to Chatbot
attn_hidden_size = 128
# Both sizes come from the length of field 2 of the first turn of the first
# dialog. NOTE(review): memory_size and code_size read the same field —
# confirm this is intended and not a copy-paste of the wrong index.
memory_size = code_size = len(diags[0][0][2])
encoder_layers = 1
bidirectional = True
encoder_dropout, decoder_dropout = 0.2, 0.1

chatbot = Chatbot(
    voc,
    hidden_size,
    attn_method,
    attn_hidden_size,
    memory_size,
    code_size,
    encoder_layers=encoder_layers,
    bidirectional=bidirectional,
    encoder_dropout=encoder_dropout,
    decoder_dropout=decoder_dropout,
)


Création des données d'entraînement et initialisation du `Trainer` :


In [None]:
# For every turn of every dialog keep only fields 0 and 2 (field 2 is the one
# whose length sized the model); field 1 is not used for training.
data = [[[p[0], p[2]] for p in d] for d in diags]

# Train on the GPU when one is available; otherwise fall back to the CPU
# instead of crashing on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = Trainer(chatbot, data, device=device)


Entraînement :


In [None]:
# Training hyper-parameters (meanings below are inferred from names — confirm
# against Trainer.__call__).
iterations = 8000        # number of training iterations
learning_rate = 5e-5
clip = 10                # presumably the gradient-clipping threshold
print_every = 100        # presumably: report progress every 100 iterations
progress = True
teacher_forcing = 1      # presumably the teacher-forcing ratio (1 = always)
noise_p_word = 0.025     # presumably per-word noise probability (augmentation)
noise_p = 0.2            # presumably probability of noising a sample

trainer(
    iterations,
    learning_rate,
    clip=clip,
    progress=progress,
    print_every=print_every,
    teacher_forcing=teacher_forcing,
    noise_p_word=noise_p_word,
    noise_p=noise_p,
)


Sauvegarde du modèle :


In [None]:
# Name the checkpoint after the training file, minus directory and extension
# (e.g. ChatbotDS_P_Train_Baseline_base2.tar). splitext strips only the final
# extension, so file names containing extra dots are no longer truncated at
# the first dot.
chatbot_name = os.path.splitext(os.path.basename(diags_path))[0]
chatbot.save(f'./{chatbot_name}.tar')


Chargement d'un modèle précédemment sauvegardé :


In [None]:
# Reload a checkpoint written by the save cell; relies on `chatbot_name` from
# that cell and replaces the in-memory `chatbot`.
chatbot = Chatbot.load(f'./{chatbot_name}.tar')


## Évaluation


In [None]:
# Variable values loaded from var_data.json (presumably substituted into the
# templated answers — confirm in Evaluation).
var_replace = import_replace_variable('./data/var_data.json')

# Use CUDA only when it is available; otherwise run on the CPU instead of
# crashing on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

evaluation = Evaluation(
    chatbot,
    var_replace,
    device=device,
    preprocess='base2',  # should match the pre-processing of the training data
    verbose=False,
)


In [None]:
# Interactive chat session with the model; print_memory=True presumably also
# prints the memory state at each turn. The return value is kept in `res`.
res = evaluation.chat(print_memory=True)


In [None]:
# Evaluate on the held-out test split (file produced by running the earlier
# sections with the 'Test' split and preprocess = 'base2').
evaluation.eval_data('./data/ChatbotDS_P_Test_Baseline_base2.tsv')
