diff --git a/README.md b/README.md
index 93b5e8e..04f4c2d 100644
--- a/README.md
+++ b/README.md
@@ -68,8 +68,8 @@ Once the predictative files are generated, we will depend on the [AirDialogue to
 We are currently working on the scoring script.
 ```
 airdialogue score --pred_data ./data/out_dir/dev_inference_out.txt \
-    --true_data ./data/airdialogue/json/dev_data.json \
-    --true_kb ./data/airdialogue/json/dev_kb.json \
+    --true_data ./data/airdialogue/tokenized/dev.infer.tar.data \
+    --true_kb ./data/airdialogue/tokenized/dev.infer.kb \
     --task infer \
     --output ./data/out_dir/dev_bleu.json
 ```
@@ -155,8 +155,8 @@ bash ./scripts/evaluate.sh -p dev -a ood1 -m ./data/synthesized_out_dir -o ./dat
 ###### Scoring
 ```
 airdialogue score --pred_data ./data/synthesized_out_dir/dev_inference_out.txt \
-    --true_data ./data/synthesized/json/dev_data.json \
-    --true_kb ./data/airdialogue/json/dev_kb.json \
+    --true_data ./data/synthesized/tokenized/dev.infer.tar.data \
+    --true_kb ./data/airdialogue/tokenized/dev.infer.kb \
     --task infer \
     --output ./data/synthesized_out_dir/dev_bleu.json
 ```
diff --git a/airdialogue_model_tf.py b/airdialogue_model_tf.py
index 6636fba..3752f2b 100644
--- a/airdialogue_model_tf.py
+++ b/airdialogue_model_tf.py
@@ -358,6 +358,11 @@ def add_arguments(parser):
       type=int,
       default=50,
       help="maximum sentence length for dialogue inference")
+  parser.add_argument(
+      "--self_play_start_turn",
+      type=str,
+      default=None,
+      help="Force self-play to run for an agent/customer start. [agent | customer]")
   parser.add_argument(
       "--num_kb_fields_per_entry",
       type=int,
@@ -592,6 +597,7 @@ def create_hparams(flags):
       vocab_file=flags.vocab_file,
       max_dialogue_len=flags.max_dialogue_len,
       max_inference_len=flags.max_inference_len,
+      self_play_start_turn=flags.self_play_start_turn,
       num_kb_fields_per_entry=flags.num_kb_fields_per_entry,
       len_action=flags.len_action,
       # selfplay
diff --git a/scripts/evaluate.sh b/scripts/evaluate.sh
index fc51570..c67ac05 100644
--- a/scripts/evaluate.sh
+++ b/scripts/evaluate.sh
@@ -63,9 +63,9 @@ echo "out_dir", ${out_dir}
 echo "num_gpus", ${num_gpus}
 
 # run in foreground once and display the results
-# python airdialogue_model_tf.py --task_type INFER --eval_prefix $partition --num_gpus $num_gpus \
-#   --input_dir ${input_dir} --out_dir ${out_dir} \
-#   --inference_output_file ${out_dir}/dev_inference_out.txt
+python airdialogue_model_tf.py --task_type INFER --eval_prefix $partition --num_gpus $num_gpus \
+  --input_dir ${input_dir} --out_dir ${out_dir} \
+  --inference_output_file ${out_dir}/dev_inference_out.txt
 
 # run in foreground once and display the results
 python airdialogue_model_tf.py --task_type SP_EVAL --eval_prefix $partition --num_gpus $num_gpus \
diff --git a/utils/dialogue_utils.py b/utils/dialogue_utils.py
index 41e6c45..c30aa50 100644
--- a/utils/dialogue_utils.py
+++ b/utils/dialogue_utils.py
@@ -16,6 +16,7 @@
 
 import codecs
 import random
+import re
 import numpy as np
 import tensorflow.compat.v1 as tf
 from airdialogue.evaluator.metrics import f1
@@ -217,7 +218,7 @@ def decode_and_evaluate(name,
     trans_f.write('')  # Write empty string to ensure file is created.
     while True:
       try:
-        ut1, ut2, _ = model.generate_infer_utterance(sess,
+        ut1, ut2, action = model.generate_infer_utterance(sess,
                                                      data_iterator_handle)
         batch_size = ut1.shape[0]
         for sent_id in range(batch_size):
@@ -226,7 +227,19 @@
           nmt_outputs = [ut1, ut2][speaker]
           translation = get_translation_cut_both(nmt_outputs, sent_id,
                                                  hparams.t1.encode(), hparams.t2.encode())
-          trans_f.write((translation + b'\n').decode('utf-8'))
+          translation = translation.decode('utf-8')
+          if hparams.self_play_start_turn == 'agent':
+            if '' in translation:
+              ac_arr = [w.decode('utf-8') for w in action[sent_id]]
+              name = ac_arr[0] + ' ' + ac_arr[1]
+              flight = re.match(r'', ac_arr[2])
+              flight = flight.group(1) if flight else ''
+              status = re.match(r'', ac_arr[3])
+              status = status.group(1) if status else ''
+              translation += '|' + '|'.join([name, flight, status])
+            else:
+              translation += '|||'
+          trans_f.write(translation + '\n')
           cnt += 1
           if last_cnt - cnt >= 10000:  # 400k in total
             utils.print_out('cnt= ' + str(cnt))
diff --git a/utils/iterator_utils.py b/utils/iterator_utils.py
index 7ed86bc..5b91cef 100644
--- a/utils/iterator_utils.py
+++ b/utils/iterator_utils.py
@@ -150,7 +150,7 @@ def get_sub_items_supervised(data, kb):
 
 
 def get_sub_items_infer(data, kb):
   """process procedure for inference."""
-  all_data = tf.string_split([data], sep="|").values
+  all_data = tf.string_split([data], sep="|", skip_empty=False).values
   intent, dialogue_context = all_data[0], all_data[1]
   return intent, dialogue_context, kb