In [1]:
import argparse
import functools
import json
import math
import pickle

from os.path import isfile
from os.path import join as pjoin
from glob import glob
from tqdm import tqdm
from time import time

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AutoModel, AutoTokenizer, BertConfig

from utils_parsing import *
from utils_caip import *

from train_model import *

from pprint import pprint
from random import choice
from time import time

# Path prefix of the trained checkpoint: '<prefix>_args.pk' (training args)
# and '<prefix>.pth' (weights) are read below.
model_name = '../data/ttad_transformer_model/caip_model_dir/caip_test_model'

# Training-time argument namespace saved next to the weights.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted checkpoints.
with open(model_name + '_args.pk', 'rb') as args_file:
    args = pickle.load(args_file)

# Override the data directory recorded at training time with the local path.
args.data_dir = '../data/ttad_transformer_model/annotated_data/'

tokenizer = AutoTokenizer.from_pretrained(args.pretrained_encoder_name)
with open(args.tree_voc_file) as tree_voc_file:
    full_tree, tree_i2w = json.load(tree_voc_file)
dataset = CAIPDataset(tokenizer, args, prefix="", full_tree_voc=(full_tree, tree_i2w))

# Rebuild the encoder-decoder with the configuration stored in `args` so the
# saved state dict matches the module structure.
enc_model = AutoModel.from_pretrained(args.pretrained_encoder_name)
bert_config = BertConfig.from_pretrained("bert-base-uncased")
bert_config.is_decoder = True
# NOTE(review): the extra 8 ids presumably reserve special tokens on top of the
# tree vocabulary — confirm against the training code before changing.
bert_config.vocab_size = len(tree_i2w) + 8
bert_config.num_hidden_layers = args.num_decoder_layers
dec_with_loss = DecoderWithLoss(bert_config, args, tokenizer)
encoder_decoder = EncoderDecoderWithLoss(enc_model, dec_with_loss, args)
# Load the weights onto the CPU first: load_state_dict copies them into the
# (CPU-resident) model anyway. This works regardless of which GPU the
# checkpoint was saved from.
encoder_decoder.load_state_dict(torch.load(model_name + '.pth', map_location='cpu'))

# Move to the GPU and switch to inference mode. .eval() returns the module,
# so nothing is assigned to `_`, which would shadow IPython's output history.
encoder_decoder = encoder_decoder.cuda().eval()

def get_beam_tree(chat, noop_thres=0.95, beam_size=5, well_formed_pen=1e2):
    """Parse ``chat`` with beam search and return the selected tree.

    The top-ranked beam entry is returned. The exception is a top-ranked
    ``NOOP`` whose probability ``exp(score)`` is below ``noop_thres``: then the
    runner-up is returned instead, so a low-confidence NOOP does not hide an
    actual parse. If there is no runner-up, the top entry is returned as is.

    Args:
        chat: the input utterance.
        noop_thres: minimum probability for a top-ranked NOOP to be kept.
        beam_size: beam width, passed through to ``beam_search``.
        well_formed_pen: passed through to ``beam_search``; presumably a
            penalty for ill-formed trees — TODO confirm.

    Returns:
        The selected tree (a dict, as returned in the beam entries).
    """
    beam = beam_search(chat, encoder_decoder, tokenizer, dataset, beam_size, well_formed_pen)
    best_tree, best_score = beam[0][0], beam[0][1]
    # The score is treated as a log-probability, hence math.exp before comparing
    # against the probability threshold.
    is_weak_noop = (
        best_tree.get('dialogue_type', 'NONE') == 'NOOP'
        and math.exp(best_score) < noop_thres
    )
    # Only fall back when a runner-up exists; a one-entry beam (e.g. beam_size=1)
    # would otherwise raise IndexError.
    if is_weak_noop and len(beam) > 1:
        return beam[1][0]
    return best_tree


To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [2]:
# Single-word command: "come" should be parsed as a MOVE action.
get_beam_tree(chat="come")

{'dialogue_type': 'HUMAN_GIVE_COMMAND',
 'action_sequence': [{'action_type': 'MOVE'}]}

In [3]:
# Robustness to a small typo ("buildz" for "build"): expected to still parse as a BUILD action.
get_beam_tree(chat="buildz an elevator between the mountains")

{'dialogue_type': 'HUMAN_GIVE_COMMAND',
 'action_sequence': [{'action_type': 'BUILD',
   'location': {'location_type': 'REFERENCE_OBJECT',
    'relative_direction': 'BETWEEN',
    'reference_object': {'has_name': [0, [5, 5]]}},
   'schematic': {'has_name': [0, [2, 2]]}}]}

In [4]:
# Known failure case: the selected parse is a bare GET_MEMORY that drops the
# size / colour / name information in the question.
get_beam_tree(chat="how big is the blue house?")

{'dialogue_type': 'GET_MEMORY'}

In [5]:
# Same question, rephrased closer to the wording used in the templates,
# to check whether the phrasing is what trips up the parser.
get_beam_tree(chat="what is the size of this blue house?")

{'dialogue_type': 'GET_MEMORY'}

In [6]:
# Inspect the full beam (tree, score) for the failing question: the correct
# parse, with has_size / has_colour / reference_object, is present but scored
# far below the bare GET_MEMORY that wins.
[
    (tree, score)
    for tree, score, _aux in beam_search(
        "how big is the blue house?", encoder_decoder, tokenizer, dataset
    )
]

[({'dialogue_type': 'GET_MEMORY'}, -0.20432472229003906),
 ({'dialogue_type': 'NOOP'}, -6.626201152801514),
 ({'dialogue_type': 'GET_MEMORY',
   'has_size': [0, [1, 1]],
   'has_colour': [0, [4, 4]],
   'reference_object': {'has_name': [0, [5, 5]]},
   'repeat': {'stop_condition': {'condition_type': 'NEVER'}}},
  -318.8774108886719),
 ({'dialogue_type': 'GET_MEMORY',
   'has_size': [0, [1, 1]],
   'has_colour': [0, [4, 4]],
   'reference_object': {'has_name': [0, [5, 5]]},
   'repeat': {'stop_condition': {'condition_type': 'NEVER'}}},
  -325.29071044921875),
 ({'dialogue_type': 'GET_MEMORY',
   'has_size': [0, [1, 1]],
   'has_colour': [0, [4, 4]],
   'reference_object': {'has_name': [0, [5, 5]]},
   'repeat': {'stop_condition': {'condition_type': 'NEVER'}}},
  -418.1733703613281)]