In [1]:
import sys
sys.path.append('..')

%load_ext autoreload
%autoreload 2

In [111]:
from pprint import pprint

import torch
import transformers

from tqdm.auto import tqdm

from new_semantic_parsing import utils
from new_semantic_parsing import EncoderDecoderWPointerModel, Seq2SeqTrainer
from new_semantic_parsing.schema_tokenizer import TopSchemaTokenizer
from new_semantic_parsing.utils import compute_metrics, get_src_pointer_mask
from new_semantic_parsing.data import PointerDataset, Seq2SeqDataCollator

from cli.predict import make_test_dataset
from cli.preprocess import make_dataset

In [113]:
MODEL_PATH = '../output_dir/debug'
TOKENIZER_PATH = '../output_dir/debug/tokenizer'
DATA_PATH = '../data/top-dataset-semantic-parsing/eval.tsv'

In [3]:
schema_tokenizer = TopSchemaTokenizer.load(TOKENIZER_PATH)
text_tokenizer: transformers.PreTrainedTokenizer = schema_tokenizer.src_tokenizer

model = EncoderDecoderWPointerModel.from_pretrained(MODEL_PATH)
model.eval()

2020-06-30 14:09:26 | INFO | transformers.configuration_utils | loading configuration file ../output_dir/debug/tokenizer/config.json
2020-06-30 14:09:26 | INFO | transformers.configuration_utils | Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

2020-06-30 14:09:26 | INFO | transformers.tokenization_utils | Model name '../output_dir/debug/tokenizer' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-ma

In [4]:
dataset: PointerDataset = make_test_dataset(DATA_PATH, text_tokenizer, max_len=63)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    collate_fn=Seq2SeqDataCollator(pad_id=text_tokenizer.pad_token_id).collate_batch,
    num_workers=8,
)

HBox(children=(FloatProgress(value=0.0, description='tokenization', max=4462.0, style=ProgressStyle(descriptio…




## Beam search (k=1)

In [103]:
beam_search_preds_ids, beam_search_preds = utils.iterative_prediction(
    model=model,
    dataloader=dataloader,
    schema_tokenizer=schema_tokenizer,
    max_len=63,
    num_beams=1,
)

HBox(children=(FloatProgress(value=0.0, description='generation', max=140.0, style=ProgressStyle(description_w…




In [11]:
beam_search_preds[:10]

["[IN:GET_EVENT what's to do ]",
 '[IN:GET_EVENT What are they best place I could use to [SL:DESTINATION book a trip ] ]',
 '[IN:GET_EVENT Where can we take the [SL:DESTINATION [IN:GET_LOCATION [SL:LOCATION kids ] ] ] ]',
 '[IN:GET_EVENT Any [SL:DATE_TIME festivals ] [SL:DATE_TIME this weekend ] ]',
 '[IN:GET_EVENT Are there any [SL:DATE_TIME Christmas parties ] [SL:DATE_TIME this weekend ] ]',
 "[IN:GET_ESTIMATED_DURATION I need a restaurant that seems classy but is really [SL:DESTINATION [IN:GET_LOCATION [SL:DESTINATION Manhattan but doesn't require dinner time reservations ] ' t require dinner time reservations ] but doesn't require dinner time reservations [SL:DESTINATION Manhattan but doesn't require dinner time reservations ]",
 '[IN:GET_EVENT Any [SL:CATEGORY_EVENT live music events ] on [SL:DATE_TIME friday ] ]',
 '[IN:GET_EVENT [SL:CATEGORY_EVENT concerts ] by [SL:LOCATION sia ] ]',
 '[IN:GET_INFO_TRAFFIC when is the [SL:DATE_TIME next showing of the nutcracker ] ]',
 '[IN:GET

In [13]:
ex = dataset[0]
ex

InputDataClass(input_ids=tensor([ 101, 1184,  112,  188, 1106, 1202,  102]), decoder_input_ids=None, attention_mask=None, decoder_attention_mask=None, pointer_mask=tensor([0., 1., 1., 1., 1., 1., 0.]), decoder_pointer_mask=None, labels=None)

In [45]:
input_ids = ex.input_ids.unsqueeze(0)
pointer_mask = ex.pointer_mask.unsqueeze(0)
decoder_input_ids = (torch.ones(1, 1) * schema_tokenizer.bos_token_id).long()

print(input_ids,'\n', pointer_mask, '\n', decoder_input_ids)

tensor([[ 101, 1184,  112,  188, 1106, 1202,  102]]) 
 tensor([[0., 1., 1., 1., 1., 1., 0.]]) 
 tensor([[1]])


In [46]:
out = model(input_ids=input_ids, pointer_mask=pointer_mask, decoder_input_ids=decoder_input_ids)
logits = out[0]
print(logits.shape)
print(schema_tokenizer.vocab_size, input_ids.shape[1])

torch.Size([1, 1, 74])
67 7


In [47]:
next_symb = logits.max(-1).indices
next_symb

tensor([[65]])

In [53]:
decoder_input_ids = torch.cat([decoder_input_ids, next_symb], axis=-1)

In [54]:
next_symb.item()

65

In [63]:
logits[:, -1, :].unsqueeze(1).shape

torch.Size([1, 1, 74])

In [64]:
input_ids = ex.input_ids.unsqueeze(0)
pointer_mask = ex.pointer_mask.unsqueeze(0)
decoder_input_ids = (torch.ones(1, 1) * schema_tokenizer.bos_token_id).long()

for _ in range(63):
    out = model(input_ids=input_ids, pointer_mask=pointer_mask, decoder_input_ids=decoder_input_ids)
    logits = out[0]

    next_symb = logits[:, -1, :].max(-1).indices.unsqueeze(1)

    if next_symb.item() in [schema_tokenizer.eos_token_id, schema_tokenizer.pad_token_id]:
        break
    
    decoder_input_ids = torch.cat([decoder_input_ids, next_symb], axis=-1)

print(decoder_input_ids.squeeze())
print(schema_tokenizer.decode(decoder_input_ids.squeeze(), input_ids.squeeze()))

tensor([ 1, 65, 34, 21, 68, 69, 70, 71, 72, 66])
[BOS] [IN:GET_EVENT what's to do ]


In [56]:
next_symb

tensor([[65, 34]])

## Greedy search

maximally simple implementation

In [136]:
def iterative_greedy_prediction(model, dataset, schema_tokenizer, max_len):
    # we use dataset instead of dataloader here to simplify stuff

    predictions_ids = []
    predictions_str = []

    for i, ex in enumerate(tqdm(dataset, desc='prediction')):
        input_ids = ex.input_ids.unsqueeze(0)
        pointer_mask = ex.pointer_mask.unsqueeze(0)
        decoder_input_ids = (torch.ones(1, 1) * schema_tokenizer.bos_token_id).long()

        for _ in range(63):
            out = model(input_ids=input_ids, pointer_mask=pointer_mask, decoder_input_ids=decoder_input_ids)
            logits = out[0]

            next_symb = logits[:, -1, :].max(-1).indices.unsqueeze(1)

            decoder_input_ids = torch.cat([decoder_input_ids, next_symb], axis=-1)
            
            if next_symb.item() in [schema_tokenizer.eos_token_id, schema_tokenizer.pad_token_id]:
                break


        prediction = decoder_input_ids.squeeze()[1:]
        predictions_ids.append(prediction)
        predictions_str.append(schema_tokenizer.decode(prediction, input_ids.squeeze(), skip_special_tokens=True))

    return predictions_ids, predictions_str

In [131]:
greedy_preds_ids, greedy_preds = iterative_greedy_prediction(model, dataset, schema_tokenizer, 63)

HBox(children=(FloatProgress(value=0.0, description='prediction', max=4462.0, style=ProgressStyle(description_…




## Beam-1 vs greedy

sanity check passes

In [132]:
beam_search_preds[:10]

["[IN:GET_EVENT what's to do ]",
 '[IN:GET_EVENT What are they best place I could use to [SL:DESTINATION book a trip ] ]',
 '[IN:GET_EVENT Where can we take the [SL:DESTINATION [IN:GET_LOCATION [SL:LOCATION kids ] ] ] ]',
 '[IN:GET_EVENT Any [SL:DATE_TIME festivals ] [SL:DATE_TIME this weekend ] ]',
 '[IN:GET_EVENT Are there any [SL:DATE_TIME Christmas parties ] [SL:DATE_TIME this weekend ] ]',
 "[IN:GET_ESTIMATED_DURATION I need a restaurant that seems classy but is really [SL:DESTINATION [IN:GET_LOCATION [SL:DESTINATION Manhattan but doesn't require dinner time reservations ] ' t require dinner time reservations ] but doesn't require dinner time reservations [SL:DESTINATION Manhattan but doesn't require dinner time reservations ]",
 '[IN:GET_EVENT Any [SL:CATEGORY_EVENT live music events ] on [SL:DATE_TIME friday ] ]',
 '[IN:GET_EVENT [SL:CATEGORY_EVENT concerts ] by [SL:LOCATION sia ] ]',
 '[IN:GET_INFO_TRAFFIC when is the [SL:DATE_TIME next showing of the nutcracker ] ]',
 '[IN:GET

In [133]:
greedy_preds[:10]

["[IN:GET_EVENT what's to do ]",
 '[IN:GET_EVENT What are they best place I could use to [SL:DESTINATION book a trip ] ]',
 '[IN:GET_EVENT Where can we take the [SL:DESTINATION [IN:GET_LOCATION [SL:LOCATION kids ] ] ] ]',
 '[IN:GET_EVENT Any [SL:DATE_TIME festivals ] [SL:DATE_TIME this weekend ] ]',
 '[IN:GET_EVENT Are there any [SL:DATE_TIME Christmas parties ] [SL:DATE_TIME this weekend ] ]',
 "[IN:GET_ESTIMATED_DURATION I need a restaurant that seems classy but is really [SL:DESTINATION [IN:GET_LOCATION [SL:DESTINATION Manhattan but doesn't require dinner time reservations ] ' t require dinner time reservations ] but doesn't require dinner time reservations [SL:DESTINATION Manhattan but doesn't require dinner time reservations ]",
 '[IN:GET_EVENT Any [SL:CATEGORY_EVENT live music events ] on [SL:DATE_TIME friday ] ]',
 '[IN:GET_EVENT [SL:CATEGORY_EVENT concerts ] by [SL:LOCATION sia ] ]',
 '[IN:GET_INFO_TRAFFIC when is the [SL:DATE_TIME next showing of the nutcracker ] ]',
 '[IN:GET

In [138]:
beam_search_preds == greedy_preds

True

In [139]:
n_errors = 0
_beam_search_preds = beam_search_preds
_beam_search_preds_ids = beam_search_preds_ids

for i in range(len(beam_search_preds)):
    if _beam_search_preds[i] != greedy_preds[i]:
        n_errors += 1
        print('Mismatch ', n_errors)
        print(f'Beam-1: ', _beam_search_preds[i])
        print('Greedy len: ', len(greedy_preds_ids[i]), 'Beam-1 len: ', len(_beam_search_preds_ids[i]))
        print(f'Greedy: ', greedy_preds[i])
        print()

## Teacher forcing

In [116]:
dataset_with_labels: PointerDataset = make_dataset(DATA_PATH, text_tokenizer, schema_tokenizer)
dataloader_with_labels = torch.utils.data.DataLoader(
    dataset_with_labels,
    batch_size=32,
    collate_fn=Seq2SeqDataCollator(pad_id=text_tokenizer.pad_token_id).collate_batch,
    num_workers=8,
)

100%|██████████| 4462/4462 [00:00<00:00, 5113.48it/s]
100%|██████████| 4462/4462 [00:02<00:00, 1736.62it/s]


In [127]:
def teacher_forcing_prediction(model, dataset, schema_tokenizer, max_len):
    # we use dataset instead of dataloader here to simplify stuff

    predictions_ids = []
    predictions_str = []

    for i, ex in enumerate(tqdm(dataset, desc='prediction')):
        input_ids = ex.input_ids.unsqueeze(0)
        pointer_mask = ex.pointer_mask.unsqueeze(0)
        decoder_input_ids = ex.decoder_input_ids.unsqueeze(0)

        out = model(input_ids=input_ids, pointer_mask=pointer_mask, decoder_input_ids=decoder_input_ids)
        logits = out[0]
        prediction = logits.squeeze().max(-1).indices

        predictions_ids.append(prediction)
        predictions_str.append(schema_tokenizer.decode(prediction, input_ids.squeeze(), skip_special_tokens=True))

    return predictions_ids, predictions_str

In [130]:
tf_preds_ids, tf_preds = teacher_forcing_prediction(model, dataset_with_labels, schema_tokenizer, 63)

HBox(children=(FloatProgress(value=0.0, description='prediction', max=4462.0, style=ProgressStyle(description_…




## Teacher forced EM vs gready EM

comparing indices:

In [141]:
tf_preds_ids[0]

tensor([65, 34, 21, 68, 69, 70, 71, 72, 66,  2])

In [163]:
def tensors_equal(t1, t2):
    # strip EOS token
    if t1[-1] == 2:
        t1 = t1[:-1]
    if t2[-1] == 2:
        t2 = t2[:-1]

    return int(t1.shape == t2.shape and torch.all(t1 == t2))

In [167]:
tf_em = sum(tensors_equal(pred, ex.labels) for pred, ex in zip(tf_preds_ids, dataset_with_labels))
greedy_em = sum(tensors_equal(pred, ex.labels) for pred, ex in zip(greedy_preds_ids, dataset_with_labels))

print(tf_em, greedy_em)
print(tf_em == greedy_em)

303 303
True


comparing decoded strings:

In [170]:
import pandas as pd
data_df = pd.read_table(DATA_PATH, names=['text', 'tokens', 'schema'])
targets_str = list(data_df.schema)
targets_str[:3]

["[IN:GET_EVENT what 's to do ]",
 '[IN:UNSUPPORTED What are they best place I could use to book a trip ]',
 '[IN:GET_EVENT Where can we take [SL:ATTRIBUTE_EVENT the kids ] ]']

In [172]:
tf_em = sum(int(pred == target) for pred, target in zip(tf_preds, targets_str))
greedy_em = sum(int(pred == target) for pred, target in zip(greedy_preds, targets_str))

print(tf_em, greedy_em)

274 273


In [173]:
tf_matches = [int(pred == target) for pred, target in zip(tf_preds, targets_str)]
greedy_matches = [int(pred == target) for pred, target in zip(greedy_preds, targets_str)]

delta = [i for i, (x, y) in enumerate(zip(tf_matches, greedy_matches)) if x != y]

for d in delta:
    print('Traget: ', targets_str[d])
    print('Forced: ', tf_preds[d])
    print('Greedy: ', greedy_preds[d])

Traget:  [IN:GET_EVENT [SL:CATEGORY_EVENT College events ] ]
Forced:  [IN:GET_EVENT [SL:CATEGORY_EVENT College events ] ]
Greedy:  [IN:GET_EVENT [SL:CATEGORY_EVENT College events ] ] events [SL:DESTINATION [IN:GET_LOCATION [SL:LOCATION [IN:GET_LOCATION [SL:LOCATION [IN:GET_LOCATION [SL:LOCATION ] ] ] ] ] ] ] ] ]


^^ makes sence

## The issue

decoding back to text breaks something, probably it merges tokens like "what 's"

In [174]:
n_errors = 0

ids_preds = tf_preds_ids
decoded_preds = tf_preds

for i in range(len(targets_str)):
    if tensors_equal(ids_preds[i], dataset_with_labels[i].labels) and decoded_preds[i] != targets_str[i]:
        n_errors += 1
        print('Mismatch ', n_errors)

        print('Target str: ', targets_str[i])
        print('Decoded   : ', decoded_preds[i])

        print('Target ids : ', dataset_with_labels[i].labels)
        print('Predictions: ', ids_preds[i])
        print()

Mismatch  1
Target str:  [IN:GET_EVENT what 's to do ]
Decoded   :  [IN:GET_EVENT what's to do ]
Target ids :  tensor([65, 34, 21, 68, 69, 70, 71, 72, 66,  2])
Predictions:  tensor([65, 34, 21, 68, 69, 70, 71, 72, 66,  2])

Mismatch  2
Target str:  [IN:GET_EVENT What 's going on [SL:DATE_TIME Saturday ] ]
Decoded   :  [IN:GET_EVENT What's going on [SL:DATE_TIME Saturday ] ]
Target ids :  tensor([65, 34, 21, 68, 69, 70, 71, 72, 65, 53, 11, 73, 66, 66,  2])
Predictions:  tensor([65, 34, 21, 68, 69, 70, 71, 72, 65, 53, 11, 73, 66, 66,  2])

Mismatch  3
Target str:  [IN:GET_EVENT What 's going on [SL:DATE_TIME today ] ]
Decoded   :  [IN:GET_EVENT What's going on [SL:DATE_TIME today ] ]
Target ids :  tensor([65, 34, 21, 68, 69, 70, 71, 72, 65, 53, 11, 73, 66, 66,  2])
Predictions:  tensor([65, 34, 21, 68, 69, 70, 71, 72, 65, 53, 11, 73, 66, 66,  2])

Mismatch  4
Target str:  [IN:GET_EVENT What 's going on [SL:DATE_TIME tonight ] ]
Decoded   :  [IN:GET_EVENT What's going on [SL:DATE_TIME ton

## ~~Solution~~ Crutch: postprocessing

In [184]:
def top_postprocess(predicted_str):
    predicted_str = predicted_str.replace("'s", " 's")
    predicted_str = predicted_str.replace(".", " .")
    predicted_str = predicted_str.replace(",", " ,")
    return predicted_str

In [187]:
n_errors = 0

ids_preds = tf_preds_ids
decoded_preds = [top_postprocess(p) for p in tf_preds]

for i in range(len(targets_str)):
    if tensors_equal(ids_preds[i], dataset_with_labels[i].labels) and decoded_preds[i] != targets_str[i]:
        n_errors += 1
        print('Mismatch ', n_errors)

        print('Target str: ', targets_str[i])
        print('Decoded   : ', decoded_preds[i])

        print('Target ids : ', dataset_with_labels[i].labels)
        print('Predictions: ', ids_preds[i])
        print()

if n_errors == 0:
    print('Hooray!')

Hooray!
