In [1]:
import sys
sys.path.append('..')

%load_ext autoreload
%autoreload 2

In [2]:
from pprint import pprint

import torch
import transformers

from tqdm.auto import tqdm

from new_semantic_parsing import EncoderDecoderWPointerModel, Seq2SeqTrainer
from new_semantic_parsing.schema_tokenizer import TopSchemaTokenizer
from new_semantic_parsing.utils import compute_metrics, get_src_pointer_mask
from new_semantic_parsing.data import PointerDataset, Seq2SeqDataCollator

In [3]:
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-cased')

vocab = {'[', ']', 'IN:', 'SL:', 'GET_DIRECTIONS', 'DESTINATION',
         'DATE_TIME_DEPARTURE', 'GET_ESTIMATED_ARRIVAL'}
schema_tokenizer = TopSchemaTokenizer(vocab, tokenizer)

source_texts = [
    'Directions to Lowell',
    'Get directions to Mountain View',
]
schema_texts = [
    '[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]',
    '[IN:GET_DIRECTIONS Get directions to [SL:DESTINATION Mountain View]]'
]

source_ids = [tokenizer.encode(t) for t in source_texts]
source_pointer_masks = [get_src_pointer_mask(i, tokenizer) for i in source_ids]

schema_ids = []
schema_pointer_masks = []

for src_id, schema in zip(source_ids, schema_texts):
    item = schema_tokenizer.encode_plus(schema, src_id)
    schema_ids.append(item.ids)
    schema_pointer_masks.append(item.pointer_mask)

dataset = PointerDataset(source_ids, schema_ids, source_pointer_masks, schema_pointer_masks)
dataset.torchify()

In [4]:
vars(dataset[0])

{'input_ids': tensor([  101, 17055,  1116,  1106, 16367,   102]),
 'decoder_input_ids': tensor([ 1,  9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10]),
 'attention_mask': None,
 'decoder_attention_mask': None,
 'pointer_mask': tensor([0., 1., 1., 1., 1., 0.]),
 'decoder_pointer_mask': tensor([0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0., 0.]),
 'labels': tensor([ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10,  2])}

In [22]:
src_maxlen, _ = dataset.get_max_len()

model = EncoderDecoderWPointerModel.from_parameters(
    layers=3, hidden=128, heads=2, max_src_len=src_maxlen,
    src_vocab_size=tokenizer.vocab_size, tgt_vocab_size=schema_tokenizer.vocab_size
)

train_args = transformers.TrainingArguments(
    output_dir='output_dir',
    do_train=True,
    num_train_epochs=30,
    seed=42,
    learning_rate=1e-3,
)

# doesn't work, patch transformers?
transformers.trainer.is_wandb_available = lambda: False  # workaround

trainer = Seq2SeqTrainer(
    model,
    train_args,
    train_dataset=dataset,
    data_collator=Seq2SeqDataCollator(model.encoder.embeddings.word_embeddings.padding_idx),
    eval_dataset=dataset,
    compute_metrics=compute_metrics,
)
# a trick to reduce the amount of logging
trainer.is_local_master = lambda: False

# random.seed(42)
# torch.manual_seed(42)
# np.random.seed(42)

train_out = trainer.train()
eval_out = trainer.evaluate()

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=1.0, style=ProgressStyle(description_wid…


{"eval_loss": 0.12566250562667847, "eval_accuracy": 1.0, "eval_exact_match": 1.0, "epoch": 30.0, "step": 30}


In [6]:
print(dataset[0].input_ids)
print(dataset[0].decoder_input_ids)
print(dataset[0].labels)

print()
print(tokenizer.decode(dataset[0].input_ids))
print(schema_tokenizer.decode(dataset[0].decoder_input_ids, dataset[0].input_ids))
print(schema_tokenizer.decode(dataset[0].labels, dataset[0].input_ids))

tensor([  101, 17055,  1116,  1106, 16367,   102])
tensor([ 1,  9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10])
tensor([ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10,  2])

[CLS] Directions to Lowell [SEP]
[BOS] [IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell ] ]
[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell ] ] [EOS]


In [7]:
dl = torch.utils.data.DataLoader(
    dataset, batch_size=2, collate_fn=Seq2SeqDataCollator(model.encoder.embeddings.word_embeddings.padding_idx).collate_batch
)

batch = next(iter(dl))

In [8]:
print(batch.keys())
print(batch['input_ids'])
print(batch['decoder_input_ids'])
print()
print(batch['labels'])

dict_keys(['input_ids', 'decoder_input_ids', 'pointer_mask', 'decoder_pointer_mask', 'labels', 'attention_mask', 'decoder_attention_mask'])
tensor([[  101, 17055,  1116,  1106, 16367,   102,     0],
        [  101,  3949,  7768,  1106,  3757, 10344,   102]])
tensor([[ 1,  9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10,  0],
        [ 1,  9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 16, 10, 10]])

tensor([[ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10,  2,  0],
        [ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 16, 10, 10,  2]])


In [9]:
out = model(**batch)
logits = out[1]

logits.max(-1).indices

tensor([[ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10,  2,  2],
        [ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 16, 10, 10,  2]])

In [10]:
out = model(input_ids=batch['input_ids'], decoder_input_ids=batch['decoder_input_ids'])
logits = out[0]

print(logits.max(-1).indices)
print(schema_tokenizer.decode(logits.max(-1).indices[0], batch['input_ids'][0]))
print(schema_tokenizer.decode(logits.max(-1).indices[1], batch['input_ids'][1]))

tensor([[ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10, 10,  2],
        [ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 16, 10, 10,  2]])
[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell ] ] ] [EOS]
[IN:GET_DIRECTIONS Get directions to [SL:DESTINATION Mountain View ] ] [EOS]


In [11]:
out = model(input_ids=batch['input_ids'], decoder_input_ids=batch['decoder_input_ids'], pointer_mask=batch['pointer_mask'])
logits = out[0]

print(logits.max(-1).indices)
print(schema_tokenizer.decode(logits.max(-1).indices[0], batch['input_ids'][0]))
print(schema_tokenizer.decode(logits.max(-1).indices[1], batch['input_ids'][1]))

tensor([[ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10, 10,  2],
        [ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 16, 10, 10,  2]])
[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell ] ] ] [EOS]
[IN:GET_DIRECTIONS Get directions to [SL:DESTINATION Mountain View ] ] [EOS]


In [12]:
_ = schema_tokenizer.decode(logits.max(-1).indices[0], batch['input_ids'][0].numpy())
print(_)

[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell ] ] ] [EOS]


In [13]:
_ = schema_tokenizer.decode(logits.max(-1).indices[1], batch['input_ids'][1].numpy())
print(_)

[IN:GET_DIRECTIONS Get directions to [SL:DESTINATION Mountain View ] ] [EOS]


In [24]:
example = dataset[0]
input_ids = example.input_ids
labels = example.labels
pointer_mask = example.pointer_mask
max_len = len(example.decoder_input_ids)

print('Input: ')
print(input_ids)
print(tokenizer.decode(input_ids))
print()
print('Expected: ')
print()

generated = model.generate(
    input_ids=input_ids.unsqueeze(0),
#     attention_mask=...,  # for batched decoding
    max_length=max_len+2,
    num_beams=4,
#     pad_token_id=tokenizer.pad_token_id,
    bos_token_id=schema_tokenizer.bos_token_id,
#     eos_token_id=schema_tokenizer.eos_token_id,
#     model_specific_kwargs,  # just in case
    pointer_mask=pointer_mask.unsqueeze(0),
).squeeze()

decoded = schema_tokenizer.decode(generated, input_ids)

print
print(generated)
print(decoded)

Input: 
tensor([  101, 17055,  1116,  1106, 16367,   102])
[CLS] Directions to Lowell [SEP]

Expected: 

tensor([ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10,  2,  2])
[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell ] ] [EOS] [EOS]


In [27]:
# batched generation

example = next(iter(dl))
input_ids = example['input_ids']
labels = example['labels']
pointer_mask = example['pointer_mask']
max_len = max(map(len, example['decoder_input_ids']))

print('Input: ')
print(input_ids)
print(tokenizer.decode(input_ids[0]))
print()
print('Expected: ')
print(labels[0])
print(schema_tokenizer.decode(labels[0], input_ids[0]))
print()

generated = model.generate(
    input_ids=input_ids,
#     attention_mask=...,  # for batched decoding
    max_length=max_len+2,
#     num_beams=4,
#     pad_token_id=tokenizer.pad_token_id,
    bos_token_id=schema_tokenizer.bos_token_id,
#     eos_token_id=schema_tokenizer.eos_token_id,
#     model_specific_kwargs,  # just in case
    pointer_mask=pointer_mask,
).squeeze()

decoded = schema_tokenizer.decode(generated[0], input_ids[0])

print()
print(generated)
print(decoded)

Input: 
tensor([[  101, 17055,  1116,  1106, 16367,   102,     0],
        [  101,  3949,  7768,  1106,  3757, 10344,   102]])
[CLS] Directions to Lowell [SEP] [PAD]

Expected: 
tensor([ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10,  2,  0])
[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell ] ] [EOS] [PAD]


tensor([[ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 10, 10,  2,  2, 10],
        [ 9,  7,  5, 12, 13, 14,  9,  8,  4, 15, 16, 10, 10,  2, 10]])
[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell ] ] [EOS] [EOS] ]


In [26]:
pointer_mask.repeat_interleave(repeats=4, dim=0)

tensor([[0., 1., 1., 1., 1., 0., 0.],
        [0., 1., 1., 1., 1., 0., 0.],
        [0., 1., 1., 1., 1., 0., 0.],
        [0., 1., 1., 1., 1., 0., 0.],
        [0., 1., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 1., 0.]])