In [1]:
import sys
sys.path.append('..')

%load_ext autoreload
%autoreload 2

In [2]:
import torch

from new_semantic_parsing import EncoderDecoderWPointerModel

In [3]:
output_dir = 'output_dir'

In [85]:
src_vocab_size = 23
tgt_vocab_size = 17

model = EncoderDecoderWPointerModel.from_parameters(
    layers=1,
    hidden=32,
    heads=2,
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    max_src_len=7,
    hidden_dropout_prob=0,
    attention_probs_dropout_prob=0,
)

input_ids = torch.randint(src_vocab_size, size=(3, 7))
tgt_sequence = torch.randint(tgt_vocab_size, size=(3, 11))
decoder_input_ids = tgt_sequence[:, :-1].contiguous()
labels = tgt_sequence[:, 1:].contiguous()

expected_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)

# os.mkdir(output_dir)
model.save_pretrained(output_dir)

random_model = EncoderDecoderWPointerModel.from_parameters(
    layers=1,
    hidden=32,
    heads=2,
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    max_src_len=7,
)
loaded_model, info = EncoderDecoderWPointerModel.from_pretrained(output_dir, output_loading_info=True)

output = loaded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)
random_output = random_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)

In [86]:
info

{'missing_keys': [], 'unexpected_keys': [], 'error_msgs': []}

### Check that the configs are equal

In [87]:
c1 = model.config.to_dict()
c2 = loaded_model.config.to_dict()

In [88]:
type(model), type(loaded_model)

(new_semantic_parsing.modeling_encoder_decoder_wpointer.EncoderDecoderWPointerModel,
 new_semantic_parsing.modeling_encoder_decoder_wpointer.EncoderDecoderWPointerModel)

In [89]:
dkeys = set(c1.keys()).symmetric_difference(set(c2.keys()))
print([(k, c1.get(k, None), c2.get(k, None)) for k in dkeys])

[]


In [90]:
for k in c1.keys():
    if c1[k] != c2[k]:
        print(k, c1[k], c2[k])

In [91]:
### 

In [92]:
print(len(output) == len(expected_output))
print(torch.allclose(expected_output[0], output[0]))
print(torch.allclose(expected_output[1], output[1]))

True
True
True


In [93]:
expected_output[1][0][0], output[1][0][0], random_output[1][0][0]

(tensor([-0.3304, -0.4807, -0.3249, -0.3067, -1.0206,  0.7124,  0.0791,  0.3639,
          0.7486,  0.1944, -0.5413, -0.4551,  0.0462,  0.2680,  1.3111,  0.1761,
          0.8181,  2.2675, -0.4928, -2.3798, -1.4215,  2.4018, -0.4288, -2.0853],
        grad_fn=<SelectBackward>),
 tensor([-0.3304, -0.4807, -0.3249, -0.3067, -1.0206,  0.7124,  0.0791,  0.3639,
          0.7486,  0.1944, -0.5413, -0.4551,  0.0462,  0.2680,  1.3111,  0.1761,
          0.8181,  2.2675, -0.4928, -2.3798, -1.4215,  2.4018, -0.4288, -2.0853],
        grad_fn=<SelectBackward>),
 tensor([-0.1568, -0.1964,  0.0501,  1.2691, -0.0596, -0.3503, -0.5219,  0.4210,
         -0.3945, -0.4205, -0.3819,  0.3941, -0.5649, -0.3812, -1.0649,  0.0159,
         -1.1970, -1.1131, -1.3002, -1.9903, -2.3364, -2.4054, -2.1643,  2.3802],
        grad_fn=<SelectBackward>))

In [94]:
print(model.encoder.embeddings.word_embeddings.weight[5]),
print(loaded_model.encoder.embeddings.word_embeddings.weight[5])

tensor([ 0.0057, -0.0163, -0.0330, -0.0056,  0.0153,  0.0070, -0.0099,  0.0141,
         0.0017,  0.0096, -0.0309,  0.0080, -0.0171, -0.0065,  0.0125,  0.0205,
         0.0158,  0.0038,  0.0385, -0.0065, -0.0018,  0.0198, -0.0163, -0.0304,
         0.0154,  0.0156,  0.0013,  0.0186, -0.0045,  0.0077,  0.0108,  0.0017],
       grad_fn=<SelectBackward>)
tensor([ 0.0057, -0.0163, -0.0330, -0.0056,  0.0153,  0.0070, -0.0099,  0.0141,
         0.0017,  0.0096, -0.0309,  0.0080, -0.0171, -0.0065,  0.0125,  0.0205,
         0.0158,  0.0038,  0.0385, -0.0065, -0.0018,  0.0198, -0.0163, -0.0304,
         0.0154,  0.0156,  0.0013,  0.0186, -0.0045,  0.0077,  0.0108,  0.0017],
       grad_fn=<SelectBackward>)


In [95]:
parameter = next(model.parameters())
loaded_parameter = next(loaded_model.parameters())

In [96]:
for i, (p1, p2) in enumerate(zip(model.parameters(), loaded_model.parameters())):
    assert torch.allclose(p1, p2)