
Commit

Relax tolerance for ALBERT large and xlarge
zheyuye committed Aug 11, 2020
1 parent d651730 commit edd6655
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions scripts/conversion_toolkits/convert_tf_hub_model.py
@@ -464,17 +464,18 @@ def convert_qkv_weights(tf_prefix, mx_prefix, is_mlm):
     else:
         raise NotImplementedError

+    tolerance = 1E-2 if cfg.MODEL.num_layers == 24 else 1E-3
     def check_backbone(tested_model, tf_token_outputs_np):
         # test conversion results for backbone model
         tf_contextual_embedding = tf_token_outputs_np['sequence_output']
         tf_pooled_output = tf_token_outputs_np['pooled_output']
         contextual_embedding, pooled_output = \
             tested_model(mx_input_ids, mx_token_types, mx_valid_length)
-        assert_allclose(pooled_output.asnumpy(), tf_pooled_output, 1E-3, 1E-3)
+        assert_allclose(pooled_output.asnumpy(), tf_pooled_output, tolerance, tolerance)
         for i in range(batch_size):
             ele_valid_length = valid_length[i]
             assert_allclose(contextual_embedding[i, :ele_valid_length, :].asnumpy(),
-                            tf_contextual_embedding[i, :ele_valid_length, :], 1E-3, 1E-3)
+                            tf_contextual_embedding[i, :ele_valid_length, :], tolerance, tolerance)

     if not has_mlm:
         if test_conversion:
@@ -493,12 +494,12 @@ def check_backbone(tested_model, tf_token_outputs_np):
             tf_mlm_scores = tf_mlm_outputs_np['mlm_logits'].reshape((batch_size, num_mask, -1))
             contextual_embedding, pooled_output, mlm_scores = \
                 model(mx_input_ids, mx_token_types, mx_valid_length, mx_masked_positions)
-            assert_allclose(pooled_output.asnumpy(), tf_pooled_output, 1E-3, 1E-3)
-            assert_allclose(mlm_scores.asnumpy(), tf_mlm_scores, 1E-3, 1E-3)
+            assert_allclose(pooled_output.asnumpy(), tf_pooled_output, tolerance, tolerance)
+            assert_allclose(mlm_scores.asnumpy(), tf_mlm_scores, tolerance, tolerance)
             for i in range(batch_size):
                 ele_valid_length = valid_length[i]
                 assert_allclose(contextual_embedding[i, :ele_valid_length, :].asnumpy(),
-                                tf_contextual_embedding[i, :ele_valid_length, :], 1E-3, 1E-3)
+                                tf_contextual_embedding[i, :ele_valid_length, :], tolerance, tolerance)
         model.backbone_model.save_parameters(os.path.join(
             save_dir, 'model.params'), deduplicate=True)
         logging.info('Convert the backbone model in {} to {}/{}'.format(hub_model_dir,

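The relaxed check amounts to picking one rtol/atol value from the model depth and passing it to numpy.testing.assert_allclose. Below is a minimal, self-contained sketch of that pattern; the NumPy arrays, their shape, the simulated drift, and the hard-coded num_layers are illustrative stand-ins for the converter's real MXNet and TF Hub outputs and its cfg, not code taken from the repository.

# Minimal sketch of the tolerance check above, assuming plain NumPy arrays stand in
# for the converted MXNet outputs and the TF Hub reference outputs; num_layers, the
# array shape, and the simulated drift are illustrative values, not read from cfg.
import numpy as np
from numpy.testing import assert_allclose

num_layers = 24                                 # e.g. the 24-layer ALBERT large/xlarge variants
tolerance = 1E-2 if num_layers == 24 else 1E-3  # looser bound for the deeper models

tf_pooled_output = np.random.rand(2, 768).astype(np.float32)   # pretend TF Hub output
# Simulate the small numerical drift expected between the two frameworks.
pooled_output = tf_pooled_output + \
    np.random.uniform(-1E-4, 1E-4, size=tf_pooled_output.shape).astype(np.float32)

# assert_allclose(actual, desired, rtol, atol): the same value is passed for both the
# relative and the absolute tolerance, mirroring the calls in the diff.
assert_allclose(pooled_output, tf_pooled_output, tolerance, tolerance)
print('outputs agree within tolerance', tolerance)

assert_allclose fails when |actual - desired| exceeds atol + rtol * |desired| for any element, so passing the same value for both terms roughly doubles the allowed error on unit-scale outputs.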