This project contains implementations of FloNet, described in the paper "End-to-End Learning of Flowchart Grounded Task Oriented Dialogs".
-
Clone this repo
-
Set up a Python environment using requirements.txt
-
Download the pre-trained GloVe embeddings and unzip the contents into the folder code/glove6B/
-
Download the pretrained checkpoints. The compressed file contains the pretrained retriever and generator models for both the S-Flo and the U-Flo settings.
-
Run the inference script:
a. S-Flo setting:
python flonet.py --save-name='FlonetInferValS' --retriever_checkpoint=path-to-the-sflo-pretrained-retriever-checkpoint.pth.tar --gpt_model_checkpoint=path-to-the-sflo-pretrained-generator-folder --dialog-dir='../data/dialogs/' --cached-dialog-path='../data/saved_data/cached_in_domain_hard_dialogs.pkl' --domain='in_domain_hard' --saved-glove-path=./glove6B/ --inference=1 --num-epochs=0 --max_length=60
b. U-Flo setting:
python flonet.py --save-name='FlonetInferValU' --si_model_checkpoint=path-to-the-uflo-pretrained-retriever-checkpoint.pth.tar --gpt_model_checkpoint=path-to-the-uflo-pretrained-generator-folder --dialog-dir='../data/dialogs/' --cached-dialog-path='../data/saved_data/cached_out_domain_dialogs.pkl' --domain='out_domain' --saved-glove-path=./glove6B/ --inference=1 --max_length=60 --num-epochs=0 --emb-size=200 --hidden-size=600
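The two inference commands above differ only in a few options, so it can help to assemble them from a shared set of settings. The following is a minimal sketch, not part of the repository: `build_command`, `COMMON`, `S_FLO` and `U_FLO` are hypothetical names, and every flag and value is copied from the commands above (the checkpoint paths are still placeholders to fill in).

```python
import shlex


def build_command(script, options):
    """Return a shell command string for `script` with --key=value options."""
    parts = ["python", script]
    parts += [f"--{key}={value}" for key, value in options.items()]
    # shlex.join quotes any argument that needs quoting for the shell.
    return shlex.join(parts)


# Options shared by both settings, copied from the commands above.
COMMON = {
    "dialog-dir": "../data/dialogs/",
    "saved-glove-path": "./glove6B/",
    "inference": 1,
    "num-epochs": 0,
    "max_length": 60,
}

# S-Flo setting (placeholder checkpoint paths as in the command above).
S_FLO = {
    "save-name": "FlonetInferValS",
    "retriever_checkpoint": "path-to-the-sflo-pretrained-retriever-checkpoint.pth.tar",
    "gpt_model_checkpoint": "path-to-the-sflo-pretrained-generator-folder",
    "cached-dialog-path": "../data/saved_data/cached_in_domain_hard_dialogs.pkl",
    "domain": "in_domain_hard",
    **COMMON,
}

# U-Flo setting: also passes the embedding and hidden sizes.
U_FLO = {
    "save-name": "FlonetInferValU",
    "si_model_checkpoint": "path-to-the-uflo-pretrained-retriever-checkpoint.pth.tar",
    "gpt_model_checkpoint": "path-to-the-uflo-pretrained-generator-folder",
    "cached-dialog-path": "../data/saved_data/cached_out_domain_dialogs.pkl",
    "domain": "out_domain",
    "emb-size": 200,
    "hidden-size": 600,
    **COMMON,
}

if __name__ == "__main__":
    print(build_command("flonet.py", S_FLO))
    print(build_command("flonet.py", U_FLO))
```

The printed strings can be pasted into a shell from the code/ folder once the placeholder checkpoint paths are replaced.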
-
Clone this repo
-
Set up a Python environment using requirements.txt
-
Download the pre-trained GloVe embeddings and unzip the contents into the folder code/glove6B/
-
Pre-train the retriever using retriever.py. Example command shown below:
python retriever.py --cached-dialog-path='../data/saved_data/flodial_out.pkl' --domain=out_domain --hidden-size=600 --emb-size=200
-
Pre-train the generator using generator.py (the input is the file in data/gpt_data/ generated by retriever.py). Example command shown below:
python generator.py --dataset_path="../data/gpt_data/Retriever_out_domain.json" --dataset_cache="../data/saved_data/flonet_out_cache"
-
Rename the generator's last checkpoint (in code/generator/*model_name*/) to pytorch_model.bin
-
Feed the pre-trained retriever checkpoint (in data/model/proxybest_checkpoint...), the pre-trained generator checkpoint (in code/generator/*model_name*/) and the retriever input to flonet.py:
python flonet.py --cached-dialog-path='../data/saved_data/flodial_out.pkl' --domain=out_domain --hidden-size=600 --emb-size=200 --si_model_checkpoint='../data/model/Retriever_checkpoint_out_domain_600_0.0001_16.pth.tar' --gpt_model_checkpoint='../data/generator/_gpt2_flowchart_out_cache_BLEU_1628354260/'
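The renaming step in the list above (making the generator's last checkpoint available as pytorch_model.bin) could be scripted roughly as follows. This is a sketch under stated assumptions, not the repository's tooling: `promote_last_checkpoint` is a hypothetical helper, the `*.pt` file pattern is a guess at how checkpoints are named, and "last" is taken to mean "most recently modified". It copies rather than renames, so the original checkpoint file is kept.

```python
import shutil
from pathlib import Path


def promote_last_checkpoint(model_dir, pattern="*.pt", target_name="pytorch_model.bin"):
    """Copy the most recently modified checkpoint in model_dir to target_name.

    `pattern` is an assumption about checkpoint file names; adjust it to
    whatever generator.py actually writes into the model folder.
    """
    model_dir = Path(model_dir)
    candidates = [p for p in model_dir.glob(pattern) if p.name != target_name]
    if not candidates:
        raise FileNotFoundError(f"no files matching {pattern!r} in {model_dir}")
    # "Last" is interpreted here as the newest file by modification time.
    last = max(candidates, key=lambda p: p.stat().st_mtime)
    target = model_dir / target_name
    shutil.copy2(last, target)
    return target
```

Calling `promote_last_checkpoint("generator/<model_name>/")` from the code/ folder would then leave a pytorch_model.bin next to the original checkpoints, where `<model_name>` stands for the actual folder name.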
-
retriever.py : Used to pretrain the retriever of FloNet.
- The following arguments need to be changed:
  - cached-dialog-path : path to a processed copy of the input data
  - domain : in_domain (S-Flo) or out_domain (U-Flo)
  - save-name : prefix for the saved data and logs
  - dialog-dir : path to the dataset folder, ../data/flodial/
  - cached-scores-path : path for storing the proxy scores
  - saved-glove-path : path to the GloVe embedding folder
  - hidden-size and emb-size : set according to the domain, as explained in the paper
- Saves the data required for pretraining GPT in the data/gpt_data/ folder
- The model checkpoint is saved in the path given by the mode-dir argument
- Logs are saved in the path given by the log-dir argument
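Before pointing saved-glove-path at the embedding folder, it can be useful to confirm that the unzipped file parses and has the expected dimensionality. The sketch below reads the standard GloVe text format (one word per line followed by space-separated floats). It is not the loader used by retriever.py; `load_glove` is a hypothetical helper, and the file name glove.6B.200d.txt together with the idea that emb-size equals the GloVe dimension are assumptions.

```python
from pathlib import Path


def load_glove(path, expected_dim=None):
    """Read a GloVe text file into a dict mapping word -> list of floats."""
    vectors = {}
    with open(path, encoding="utf-8") as handle:
        for line_no, line in enumerate(handle, start=1):
            parts = line.rstrip("\n").split(" ")
            if len(parts) < 2:
                continue  # skip blank or malformed lines
            word = parts[0]
            values = [float(x) for x in parts[1:]]
            if expected_dim is not None and len(values) != expected_dim:
                raise ValueError(
                    f"line {line_no}: expected {expected_dim} values, got {len(values)}"
                )
            vectors[word] = values
    return vectors


if __name__ == "__main__":
    # Assumed file name inside the unzipped GloVe 6B archive.
    glove_file = Path("./glove6B/") / "glove.6B.200d.txt"
    if glove_file.exists():
        vectors = load_glove(glove_file, expected_dim=200)
        print(len(vectors), "vectors loaded")
    else:
        print("GloVe file not found at", glove_file)
```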
-
generator.py : Used to pretrain the generator. It needs the GPT input generated by the retriever for training (to train with only flowchart + dialog history, or with only dialog history, use the GPT input file generated by generate_data_for_generator.py).
- dataset_path : path of the GPT input file
- dataset_cache : path for saving a cache of the processed GPT input (tokenization and other preprocessing)
- use_flowchart : set to 1 to use the flowchart, or to 0 when training on dialog history only
- max_length : maximum decode length
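The relationship between dataset_path and dataset_cache described above follows a common load-or-build pattern: process the input once, store the result, and reuse it on later runs. The sketch below illustrates that pattern only. It is not the code in generator.py; `load_or_build_cache` and the `process` callback are hypothetical, and the use of JSON input with a pickle cache is an assumption.

```python
import json
import pickle
from pathlib import Path


def load_or_build_cache(dataset_path, cache_path, process):
    """Return processed data, reusing cache_path when it already exists.

    `process` is any callable that turns the raw JSON content into the
    processed form (for example, tokenized examples).
    """
    cache_path = Path(cache_path)
    if cache_path.exists():
        # Cache hit: skip reading and processing the raw input.
        with cache_path.open("rb") as handle:
            return pickle.load(handle)
    with open(dataset_path, encoding="utf-8") as handle:
        raw = json.load(handle)
    processed = process(raw)
    with cache_path.open("wb") as handle:
        pickle.dump(processed, handle)
    return processed
```

With a pattern like this, a cache built from an older input file is silently reused, so the cache file has to be deleted whenever the input or the preprocessing changes.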
-
flonet.py : Trains FloNet. If the retriever and generator checkpoints are not given, it trains the NoPretrain version. It accepts the combined arguments of the generator and the retriever, plus the following two arguments:
- si_model_checkpoint : retriever checkpoint file path
- gpt_model_checkpoint : generator checkpoint folder path
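Since si_model_checkpoint expects a file and gpt_model_checkpoint expects a folder, a small pre-flight check can catch path mistakes before a long run starts. This is a minimal sketch and not part of flonet.py: `check_checkpoints` is a hypothetical helper, and requiring pytorch_model.bin inside the generator folder is an assumption based on the renaming step described earlier.

```python
from pathlib import Path


def check_checkpoints(si_model_checkpoint=None, gpt_model_checkpoint=None):
    """Return a list of problems; an empty list means the paths look usable.

    Arguments left as None are not checked, matching the case where
    flonet.py is run without pretrained checkpoints.
    """
    problems = []
    if si_model_checkpoint is not None and not Path(si_model_checkpoint).is_file():
        problems.append(f"retriever checkpoint is not a file: {si_model_checkpoint}")
    if gpt_model_checkpoint is not None:
        folder = Path(gpt_model_checkpoint)
        if not folder.is_dir():
            problems.append(f"generator checkpoint is not a folder: {gpt_model_checkpoint}")
        elif not (folder / "pytorch_model.bin").is_file():
            # Assumed requirement, based on the renaming step in the training instructions.
            problems.append(f"pytorch_model.bin not found in: {gpt_model_checkpoint}")
    return problems
```

An empty result for both arguments left unset corresponds to the NoPretrain case, where no checkpoint paths are passed at all.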