In [1]:
import sys
sys.path.append('..')

%load_ext autoreload
%autoreload 2

In [2]:
import torch

import transformers
import tokenizers

from new_semantic_parsing import EncoderDecoderWPointerModel
from new_semantic_parsing import TopSchemaTokenizer

In [3]:
ENCODER_NAME = 'distilbert-base-uncased'
HIDDEN = 768

In [4]:
tokenizer = transformers.AutoTokenizer.from_pretrained(ENCODER_NAME, use_fast=True)
encoder = transformers.AutoModel.from_pretrained(ENCODER_NAME)

vocab = {'[', ']', 'IN:', 'SL:', 'GET_DIRECTIONS', 'DESTINATION',
         'DATE_TIME_DEPARTURE', 'GET_ESTIMATED_ARRIVAL'}
schema_tokenizer = TopSchemaTokenizer(vocab, tokenizer)

print(len(vocab) + 1)  # plus padding
print(schema_tokenizer.vocab_size)

# BERTConfig is a generic transformer and is only decoder Transformers support by now
decoder_config = transformers.BertConfig(
    vocab_size=schema_tokenizer.vocab_size + encoder.config.vocab_size,
    hidden_size=HIDDEN,
    is_decoder=True,  # adds cross-attention modules and enables causal masking
)

decoder = transformers.BertModel(decoder_config)

9
9


In [5]:
model = EncoderDecoderWPointerModel(encoder, decoder)

In [6]:
source_text = 'Directions to Lowell'
schema_text = '[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]'

source_ids = tokenizer.encode(source_text)
schema_ids = schema_tokenizer.encode(schema_text, source_ids)

print(source_ids)
print(schema_ids)

[101, 7826, 2000, 15521, 102]
[2, 4, 5, 10, 11, 2, 6, 7, 12, 1, 1]


In [7]:
source_text = 'Directions to Lowell'
schema_text = '[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]'

source_ids = tokenizer.encode(source_text)
schema_ids = schema_tokenizer.encode(schema_text, source_ids)

print(source_ids)
print(schema_ids)

[101, 7826, 2000, 15521, 102]
[2, 4, 5, 10, 11, 2, 6, 7, 12, 1, 1]


In [8]:
x = torch.tensor([source_ids])
y = torch.tensor([schema_ids])

mask = torch.ones_like(x)
mask[:, 0] = 0.
mask[source_ids == tokenizer.sep_token_id] = 0.

print(x, x.dtype)
print(y, y.dtype)

tensor([[  101,  7826,  2000, 15521,   102]]) torch.int64
tensor([[ 2,  4,  5, 10, 11,  2,  6,  7, 12,  1,  1]]) torch.int64


In [18]:
mask, mask.shape

(tensor([[0, 1, 1, 1, 1]]), torch.Size([1, 5]))

In [19]:
source_ids == tokenizer.sep_token_id

False

In [20]:
combined_logits = model(input_ids=x, decoder_input_ids=y, pointer_attention_mask=mask)[0]

In [21]:
combined_logits.shape == (1, 11, 14)

True

In [22]:
combined_logits.shape[2] == schema_tokenizer.vocab_size + len(source_ids)

True