# Comparing the TensorFlow (original) and PyTorch models on the SQuAD task

You can use this small notebook to compare the loss computed by the original TensorFlow model with the loss computed by the PyTorch model. Below, we compare the total loss computed by both models starting from identical initializations (position-prediction linear layers with weights set to 1 and biases set to 0).
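
To see why this initialization makes the comparison deterministic, note that a linear layer whose weights are all 1 and whose bias is 0 simply sums the hidden features of each token, so the start and end logits coincide. A minimal pure-Python sketch with toy values; `ones_linear` is a hypothetical helper and not part of either implementation:

```python
def ones_linear(hidden_states, num_outputs=2):
    """Apply a linear layer with all weights set to 1 and bias set to 0.

    hidden_states: list of per-token feature vectors (lists of floats).
    Returns one list of `num_outputs` logits per token.
    """
    logits = []
    for token_vector in hidden_states:
        # With unit weights and zero bias, every output unit is the plain sum
        # of the token's features.
        token_sum = sum(token_vector)
        logits.append([token_sum] * num_outputs)
    return logits


# Toy input: 3 tokens with hidden size 4 (illustrative values only).
toy_hidden = [[0.5, -1.0, 2.0, 0.5], [1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]
toy_logits = ones_linear(toy_hidden)
start_logits = [pair[0] for pair in toy_logits]
end_logits = [pair[1] for pair in toy_logits]
```

Any remaining difference between the two frameworks then has to come from the encoder or from the loss computation, not from random initialization of the output layer.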

To run this notebook, follow these instructions:
- make sure that your Python environment has both TensorFlow and PyTorch installed,
- download the original TensorFlow implementation,
- download a pre-trained TensorFlow model as indicated in the TensorFlow implementation readme,
- run the script `convert_tf_checkpoint_to_pytorch.py` as indicated in the `README` to convert the pre-trained TensorFlow model to PyTorch.

If needed, change the relative paths indicated in this notebook (at the beginning of Sections 1 and 2) to point to the relevant models and code.
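
Before running the cells, it can help to confirm that the configured paths resolve from the notebook's working directory. A minimal sketch; `missing_paths` is a hypothetical helper, and the paths in the example call are placeholders to replace with your own:

```python
import os


def missing_paths(paths):
    """Return the subset of `paths` that does not exist on disk."""
    return [path for path in paths if not os.path.exists(path)]


# Example call with placeholder paths; replace them with your own locations.
example_missing = missing_paths([".", "./this_path_should_not_exist_12345"])
```

Note that a TensorFlow checkpoint prefix such as `bert_model.ckpt` usually corresponds to several files on disk (for instance an `.index` file), so it is safer to check for those than for the bare prefix.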

In [1]:
import os
os.chdir('../')

## 1/ TensorFlow code

In [2]:
original_tf_inplem_dir = "./tensorflow_code/"
model_dir = "../google_models/uncased_L-12_H-768_A-12/"

vocab_file = model_dir + "vocab.txt"
bert_config_file = model_dir + "bert_config.json"
init_checkpoint = model_dir + "bert_model.ckpt"

input_file = "../data/squad_data/train-v1.1.json"
max_seq_length = 384
outside_pos = max_seq_length + 10
doc_stride = 128
max_query_length = 64
max_answer_length = 30
output_dir = "/tmp/squad_base/"
learning_rate = 3e-5

In [3]:
import importlib.util
import sys

spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/modeling.py')
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules['modeling_tensorflow'] = module

spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/run_bert_squad.py')
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules['run_squad_tensorflow'] = module
import modeling_tensorflow
from run_squad_tensorflow import *

In [4]:
bert_config = modeling_tensorflow.BertConfig.from_json_file(bert_config_file)
tokenizer = tokenization.BertTokenizer(
    vocab_file=vocab_file, do_lower_case=True)

eval_examples = read_squad_examples(
    input_file=input_file, is_training=True, max_num=16)

eval_features = convert_examples_to_features(
    examples=eval_examples,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    doc_stride=doc_stride,
    max_query_length=max_query_length,
    is_training=True)

# You can use this to test how the models behave when the targets fall outside the model input sequence
# for feature in eval_features:
#     feature.start_position = outside_pos
#     feature.end_position = outside_pos

INFO:tensorflow:*** Example ***
INFO:tensorflow:unique_id: 1000000000
INFO:tensorflow:example_index: 0
INFO:tensorflow:doc_span_index: 0
INFO:tensorflow:tokens: [CLS] to whom did the virgin mary allegedly appear in 1858 in lou ##rdes france ? [SEP] architectural ##ly , the school has a catholic character . atop the main building ' s gold dome is a golden statue of the virgin mary . immediately in front of the main building and facing it , is a copper statue of christ with arms up ##rai ##sed with the legend " ve ##ni ##te ad me om ##nes " . next to the main building is the basilica of the sacred heart . immediately behind the basilica is the gr ##otto , a marian place of prayer and reflection . it is a replica of the gr ##otto at lou ##rdes , france where the virgin mary reputed ##ly appeared to saint bern ##ade ##tte so ##ub ##iro ##us in 1858 . at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ) , is a simple , modern stone statue o

INFO:tensorflow:token_is_max_context: 13:True 14:True 15:True 16:True 17:True 18:True 19:True 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:T

INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

INFO:tensorflow:unique_id: 1000000004
INFO:tensorflow:example_index: 4
INFO:tensorflow:doc_span_index: 0
INFO:tensorflow:tokens: [CLS] what sits on top of the main building at notre dame ? [SEP] architectural ##ly , the school has a catholic character . atop the main building ' s gold dome is a golden statue of the virgin mary . immediately in front of the main building and facing it , is a copper statue of christ with arms up ##rai ##sed with the legend " ve ##ni ##te ad me om ##nes " . next to the main building is the basilica of the sacred heart . immediately behind the basilica is the gr ##otto , a marian place of prayer and reflection . it is a replica of the gr ##otto at lou ##rdes , france where the virgin mary reputed ##ly appeared to saint bern ##ade ##tte so ##ub ##iro ##us in 1858 . at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ) , is a simple , modern stone statue of mary . [SEP]
INFO:tensorflow:token_to_orig_map: 14:0

INFO:tensorflow:token_is_max_context: 13:True 14:True 15:True 16:True 17:True 18:True 19:True 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:T

INFO:tensorflow:token_is_max_context: 14:True 15:True 16:True 17:True 18:True 19:True 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:True 130:

INFO:tensorflow:token_is_max_context: 13:True 14:True 15:True 16:True 17:True 18:True 19:True 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:T

INFO:tensorflow:token_is_max_context: 13:True 14:True 15:True 16:True 17:True 18:True 19:True 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:T

INFO:tensorflow:token_is_max_context: 17:True 18:True 19:True 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:True 130:True 131:True 132:True 1

INFO:tensorflow:input_ids: 101 2073 2003 1996 4075 1997 1996 7769 1997 1996 4151 2892 1029 102 1996 2118 2003 1996 2350 2835 1997 1996 7769 1997 4151 2892 1006 12167 2025 2049 2880 4075 1010 2029 2024 1999 4199 1007 1012 2049 2364 8705 1010 2062 4887 8705 1010 2003 2284 2006 1996 3721 2408 2358 1012 3312 2697 2013 1996 2364 2311 1012 2214 2267 1010 1996 4587 2311 2006 3721 1998 2284 2379 1996 5370 1997 2358 1012 2984 2697 1010 3506 8324 18014 7066 1012 3394 8656 1998 3428 13960 1999 27596 2160 1006 1037 2280 7822 2415 1007 1010 4151 2892 2160 1010 2004 2092 2004 8902 25438 2050 2534 2379 1996 24665 23052 1012 1996 2118 2083 1996 2062 4887 8705 2038 7208 2000 17200 5406 20934 15937 3678 1012 2096 2025 3234 1010 20934 15937 3678 2038 5868 4898 2013 10289 8214 1998 2062 4887 8705 2580 1037 20934 15937 3678 3396 2005 17979 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 

INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:start_position: 44
INFO:tensorflow:end_position: 46
INFO:tensorflow:answer: more ##au seminary
INFO:tensorflow:*** Example ***
INFO:tensorflow:unique_id: 1000000012
INFO:tensorflow:exampl

INFO:tensorflow:token_is_max_context: 12:True 13:True 14:True 15:True 16:True 17:True 18:True 19:True 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:Tr

INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

INFO:tensorflow:token_to_orig_map: 15:0 16:1 17:2 18:3 19:4 20:5 21:6 22:7 23:7 24:8 25:8 26:9 27:10 28:11 29:12 30:13 31:14 32:15 33:16 34:17 35:18 36:19 37:20 38:21 39:22 40:23 41:24 42:25 43:26 44:26 45:27 46:28 47:29 48:29 49:30 50:31 51:32 52:33 53:33 54:34 55:34 56:34 57:35 58:36 59:36 60:36 61:36 62:36 63:36 64:36 65:37 66:38 67:39 68:39 69:40 70:41 71:42 72:43 73:44 74:45 75:46 76:47 77:48 78:49 79:49 80:50 81:51 82:52 83:52 84:52 85:52 86:53 87:53 88:54 89:55 90:56 91:57 92:58 93:58 94:59 95:60 96:61 97:62 98:62 99:63 100:64 101:65 102:66 103:67 104:68 105:69 106:69 107:69 108:69 109:70 110:71 111:71 112:72 113:72 114:73 115:74 116:75 117:76 118:76 119:76 120:77 121:78 122:79 123:80 124:81 125:82 126:83 127:84 128:85 129:86 130:87 131:88 132:89 133:90 134:91 135:92 136:92 137:92 138:92 139:93 140:94 141:95 142:96 143:97 144:98 145:98 146:98 147:99 148:99 149:100 150:100
INFO:tensorflow:token_is_max_context: 15:True 16:True 17:True 18:True 19:True 20:True 21:True 22:True 23:Tru

INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

INFO:tensorflow:token_to_orig_map: 20:0 21:1 22:2 23:3 24:4 25:5 26:6 27:7 28:7 29:8 30:8 31:9 32:10 33:11 34:12 35:13 36:14 37:15 38:16 39:17 40:18 41:19 42:20 43:21 44:22 45:23 46:24 47:25 48:26 49:26 50:27 51:28 52:29 53:29 54:30 55:31 56:32 57:33 58:33 59:34 60:34 61:34 62:35 63:36 64:36 65:36 66:36 67:36 68:36 69:36 70:37 71:38 72:39 73:39 74:40 75:41 76:42 77:43 78:44 79:45 80:46 81:47 82:48 83:49 84:49 85:50 86:51 87:52 88:52 89:52 90:52 91:53 92:53 93:54 94:55 95:56 96:57 97:58 98:58 99:59 100:60 101:61 102:62 103:62 104:63 105:64 106:65 107:66 108:67 109:68 110:69 111:69 112:69 113:69 114:70 115:71 116:71 117:72 118:72 119:73 120:74 121:75 122:76 123:76 124:76 125:77 126:78 127:79 128:80 129:81 130:82 131:83 132:84 133:85 134:86 135:87 136:88 137:89 138:90 139:91 140:92 141:92 142:92 143:92 144:93 145:94 146:95 147:96 148:97 149:98 150:98 151:98 152:99 153:99 154:100 155:100
INFO:tensorflow:token_is_max_context: 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 2
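
If the commented-out lines in cell 4 are enabled, every target index becomes `outside_pos`, which exceeds the sequence length. The TensorFlow loss in this notebook encodes targets with `tf.one_hot`, which returns an all-zero row for an index outside `[0, depth)`. The sketch below mimics that behavior in pure Python; `one_hot` here is a hypothetical stand-in, not the TensorFlow function:

```python
def one_hot(index, depth):
    """Pure-Python stand-in for the out-of-range behavior of `tf.one_hot`.

    An index outside [0, depth) produces an all-zero row.
    """
    return [1.0 if position == index else 0.0 for position in range(depth)]


max_seq_length = 384
outside_pos = max_seq_length + 10

in_range_row = one_hot(44, max_seq_length)
outside_row = one_hot(outside_pos, max_seq_length)
```

Because the loss multiplies this row by the log-probabilities, an example whose target lies outside the sequence contributes zero to the sum in this formulation.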

In [5]:
eval_unique_id_to_feature = {}
for eval_feature in eval_features:
    eval_unique_id_to_feature[eval_feature.unique_id] = eval_feature

In [6]:
def input_fn_builder(features, seq_length, drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    all_unique_ids = []
    all_input_ids = []
    all_input_mask = []
    all_segment_ids = []
    all_start_positions = []
    all_end_positions = []

    for feature in features:
        all_unique_ids.append(feature.unique_id)
        all_input_ids.append(feature.input_ids)
        all_input_mask.append(feature.input_mask)
        all_segment_ids.append(feature.segment_ids)
        all_start_positions.append(feature.start_position)
        all_end_positions.append(feature.end_position)

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        num_examples = len(features)

        # This is for demo purposes and does NOT scale to large data sets. We do
        # not use Dataset.from_generator() because that uses tf.py_func which is
        # not TPU compatible. The right way to load data is with TFRecordReader.
        feature_map = {
            "unique_ids":
                tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
            "input_ids":
                tf.constant(
                    all_input_ids, shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "input_mask":
                tf.constant(
                    all_input_mask,
                    shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "segment_ids":
                tf.constant(
                    all_segment_ids,
                    shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "start_positions":
                tf.constant(
                    all_start_positions,
                    shape=[num_examples],
                    dtype=tf.int32),
            "end_positions":
                tf.constant(
                    all_end_positions,
                    shape=[num_examples],
                    dtype=tf.int32),
        }

        d = tf.data.Dataset.from_tensor_slices(feature_map)
        d = d.repeat()
        d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
        return d

    return input_fn
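
The `repeat()` followed by `batch()` pattern above produces an endless stream of batches that can straddle the end of the data. This is presumably why the prediction loop later in the notebook stops with `break`. A pure-Python analogy, assuming nothing about TensorFlow internals; `repeat_and_batch` is a hypothetical helper:

```python
import itertools


def repeat_and_batch(examples, batch_size):
    """Yield fixed-size batches from an endlessly repeated list of examples.

    Rough analogy of `Dataset.repeat()` followed by `Dataset.batch()`: since
    the stream never ends, no short final batch is ever produced.
    """
    stream = itertools.cycle(examples)
    while True:
        yield list(itertools.islice(stream, batch_size))


batches = repeat_and_batch(["a", "b", "c"], batch_size=2)
first_three = [next(batches) for _ in range(3)]
```

With three examples and a batch size of 2, the second batch already wraps around to the first example.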

In [7]:
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (start_logits, end_logits) = create_model(
            bert_config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = modeling_tensorflow.get_assigment_map_from_checkpoint(
                tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            seq_length = modeling_tensorflow.get_shape_list(input_ids)[1]

            def compute_loss(logits, positions):
                one_hot_positions = tf.one_hot(
                    positions, depth=seq_length, dtype=tf.float32)
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                loss = -tf.reduce_mean(
                    tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
                return loss

            start_positions = features["start_positions"]
            end_positions = features["end_positions"]

            start_loss = compute_loss(start_logits, start_positions)
            end_loss = compute_loss(end_logits, end_positions)

            total_loss = (start_loss + end_loss) / 2.0

            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            batch_size = modeling_tensorflow.get_shape_list(start_logits)[0]
            seq_length = modeling_tensorflow.get_shape_list(input_ids)[1]

            def compute_loss(logits, positions):
                one_hot_positions = tf.one_hot(
                    positions, depth=seq_length, dtype=tf.float32)
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                loss = -tf.reduce_mean(
                    tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
                return loss

            start_positions = features["start_positions"]
            end_positions = features["end_positions"]

            start_loss = compute_loss(start_logits, start_positions)
            end_loss = compute_loss(end_logits, end_positions)

            total_loss = (start_loss + end_loss) / 2.0

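            # total_loss, start_loss and end_loss are scalars (batch means), so
            # reshaping them to [batch_size, 1] only works with a batch size of 1,
            # which is what predict_batch_size is set to in this notebook.
            predictions = {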
                "unique_ids": unique_ids,
                "start_logits": start_logits,
                "end_logits": end_logits,
                "total_loss": tf.reshape(total_loss, [batch_size, 1]),
                "start_loss": tf.reshape(start_loss, [batch_size, 1]),
                "end_loss": tf.reshape(end_loss, [batch_size, 1]),
            }
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        else:
            raise ValueError(
                "Only TRAIN and PREDICT modes are supported: %s" % (mode))

        return output_spec

    return model_fn
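
The `compute_loss` helper above is a cross-entropy over sequence positions: take the log-softmax of the logits, pick the log-probability at the target position, and average the negated values over the batch. The total loss is the mean of the start and end losses. A minimal pure-Python sketch of the same arithmetic; `log_softmax` and `position_loss` are hypothetical names and the toy logits are illustrative only:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    max_logit = max(logits)
    log_sum = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return [x - log_sum for x in logits]


def position_loss(batch_logits, positions):
    """Mean negative log-likelihood of the target positions over a batch.

    A target outside the sequence selects no log-probability (all-zero
    one-hot row), so it adds 0 to the sum but still counts in the mean.
    """
    total = 0.0
    for logits, position in zip(batch_logits, positions):
        log_probs = log_softmax(logits)
        if 0 <= position < len(logits):
            total -= log_probs[position]
    return total / len(batch_logits)


# Toy batch: one example with uniform logits over 4 positions.
toy_logits = [[1.0, 1.0, 1.0, 1.0]]
start_loss = position_loss(toy_logits, [2])
end_loss = position_loss(toy_logits, [3])
total_loss = (start_loss + end_loss) / 2.0
```

With uniform logits over 4 positions, each position has probability 1/4, so both losses equal log(4).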

In [8]:
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig(
    cluster=None,
    master=None,
    model_dir=output_dir,
    save_checkpoints_steps=1000,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=1000,
        num_shards=8,
        per_host_input_for_training=is_per_host))

model_fn = model_fn_builder(
    bert_config=bert_config,
    init_checkpoint=init_checkpoint,
    learning_rate=learning_rate,
    num_train_steps=None,
    num_warmup_steps=None,
    use_tpu=False,
    use_one_hot_embeddings=False)

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=False,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=12,
    predict_batch_size=1)

predict_input_fn = input_fn_builder(
    features=eval_features,
    seq_length=max_seq_length,
    drop_remainder=True)

INFO:tensorflow:Using config: {'_model_dir': '/tmp/squad_base/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 1000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11fd09630>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=1000, num_shards=8, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input

In [9]:
tensorflow_all_out = []
tensorflow_all_results = []
for result in estimator.predict(predict_input_fn, yield_single_examples=True):
    unique_id = int(result["unique_ids"])
    eval_feature = eval_unique_id_to_feature[unique_id]
    start_logits = result["start_logits"]
    end_logits = result["end_logits"]
    total_loss = result["total_loss"]
    start_loss = result["start_loss"]
    end_loss = result["end_loss"]

    output_json = collections.OrderedDict()
    output_json["linex_index"] = unique_id
    output_json["tokens"] = list(eval_feature.tokens)
    output_json["start_logits"] = [round(float(x), 6) for x in start_logits.flat]
    output_json["end_logits"] = [round(float(x), 6) for x in end_logits.flat]
    output_json["total_loss"] = [round(float(x), 6) for x in total_loss.flat]
    output_json["start_loss"] = [round(float(x), 6) for x in start_loss.flat]
    output_json["end_loss"] = [round(float(x), 6) for x in end_loss.flat]
    tensorflow_all_out.append(output_json)
    tensorflow_all_results.append(RawResult(
                                    unique_id=unique_id,
                                    start_logits=start_logits,
                                    end_logits=end_logits))
    break

INFO:tensorflow:Could not find trained model in model_dir: /tmp/squad_base/, running initialization to predict.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Running infer on CPU
INFO:tensorflow:*** Features ***
INFO:tensorflow:  name = end_positions, shape = (1,)
INFO:tensorflow:  name = input_ids, shape = (1, 384)
INFO:tensorflow:  name = input_mask, shape = (1, 384)
INFO:tensorflow:  name = segment_ids, shape = (1, 384)
INFO:tensorflow:  name = start_positions, shape = (1,)
INFO:tensorflow:  name = unique_ids, shape = (1,)
INFO:tensorflow:**** Trainable Variables ****
INFO:tensorflow:  name = bert/embeddings/word_embeddings:0, shape = (30522, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/embeddings/token_type_embeddings:0, shape = (2, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/embeddings/position_embeddings:0, shape = (512, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/embeddings/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = 

INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_4/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_4/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_4/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_4/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKP

INFO:tensorflow:  name = bert/encoder/layer_8/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_8/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*
INFO:tensorflow:  name = bert/

In [10]:
import collections
import math


def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes

def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs


def compute_predictions(all_examples, all_features, all_results, n_best_size,
                        max_answer_length, do_lower_case):
    """Compute final predictions."""
    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]

            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))

        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"])

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]

            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
            orig_doc_start = feature.token_to_orig_map[pred.start_index]
            orig_doc_end = feature.token_to_orig_map[pred.end_index]
            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
            tok_text = " ".join(tok_tokens)

            # De-tokenize WordPieces that have been split off.
            tok_text = tok_text.replace(" ##", "")
            tok_text = tok_text.replace("##", "")

            # Clean whitespace
            tok_text = tok_text.strip()
            tok_text = " ".join(tok_text.split())
            orig_text = " ".join(orig_tokens)

            final_text = get_final_text(tok_text, orig_text, do_lower_case)
            if final_text in seen_predictions:
                continue

            seen_predictions[final_text] = True
            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        all_predictions[example.qas_id] = nbest_json[0]["text"]
        all_nbest_json[example.qas_id] = nbest_json

    return all_predictions, all_nbest_json
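
The span-selection logic inside `compute_predictions` is easy to lose among the bookkeeping. The following is a minimal sketch of the same idea on toy logits: take the n-best start and end indexes, keep only spans that are ordered and short enough, rank them by the sum of their logits, and normalise the scores with a softmax. The helper names (`top_n_indexes`, `softmax`, `best_spans`) and the toy values are assumptions made for illustration; the sketch also leaves out the `token_to_orig_map` and `token_is_max_context` filters used above.

```python
import math


def top_n_indexes(logits, n):
    """Indexes of the n largest logits, best first."""
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:n]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    if not scores:
        return []
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def best_spans(start_logits, end_logits, n_best_size, max_answer_length):
    """Return (start, end, score) candidates sorted by score, best first."""
    spans = []
    for s in top_n_indexes(start_logits, n_best_size):
        for e in top_n_indexes(end_logits, n_best_size):
            # Discard spans that end before they start or that are too long.
            if e < s or e - s + 1 > max_answer_length:
                continue
            spans.append((s, e, start_logits[s] + end_logits[e]))
    return sorted(spans, key=lambda x: x[2], reverse=True)


# Toy logits (hypothetical values, not model output).
toy_start = [0.1, 2.0, 0.3, 1.5]
toy_end = [0.2, 0.1, 3.0, 0.4]
toy_spans = best_spans(toy_start, toy_end, n_best_size=2, max_answer_length=3)
toy_probs = softmax([score for _, _, score in toy_spans])
print([(s, e) for s, e, _ in toy_spans])
```

When every candidate span is filtered out, the list is empty, which is the case the `"empty"` fallback prediction in `compute_predictions` guards against.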

In [11]:
all_predictions, all_nbest_json = compute_predictions(eval_examples[:1], eval_features[:1], tensorflow_all_results, 20, max_answer_length, True)

In [12]:
all_nbest_json

OrderedDict([('5733be284776f41900661182',
              [OrderedDict([('text', 'empty'),
                            ('probability', 1.0),
                            ('start_logit', 0.0),
                            ('end_logit', 0.0)])])])

In [13]:
print(len(tensorflow_all_out))
print(len(tensorflow_all_out[0]))
print(tensorflow_all_out[0].keys())
print("number of tokens", len(tensorflow_all_out[0]['tokens']))
print("number of start_logits", len(tensorflow_all_out[0]['start_logits']))
print("number of end_logits", len(tensorflow_all_out[0]['end_logits']))

1
7
odict_keys(['linex_index', 'tokens', 'start_logits', 'end_logits', 'total_loss', 'start_loss', 'end_loss'])
number of tokens 176
number of start_logits 384
number of end_logits 384


In [14]:
tensorflow_outputs = [tensorflow_all_out[0]['start_logits'], tensorflow_all_out[0]['end_logits'],
                      tensorflow_all_out[0]['total_loss'], tensorflow_all_out[0]['start_loss'],
                      tensorflow_all_out[0]['end_loss']]

## 2/ PyTorch code

In [15]:
import modeling
from run_squad import *

In [16]:
init_checkpoint_pt = "../google_models/uncased_L-12_H-768_A-12/pytorch_model.bin"

In [17]:
device = torch.device("cpu")
model = modeling.BertForQuestionAnswering(bert_config)
model.bert.load_state_dict(torch.load(init_checkpoint_pt, map_location='cpu'))
model.to(device)
model.qa_outputs.weight.data.fill_(1.0)
model.qa_outputs.bias.data.zero_()

tensor([0., 0.])
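
With the `qa_outputs` weights set to 1 and the bias set to 0, both output units compute the same quantity: the sum of the hidden state of each token. The start and end logits are therefore expected to be identical. The sketch below illustrates this with a plain-Python linear layer; `linear`, the toy hidden size of 4 and the toy hidden states are assumptions chosen for illustration, not values taken from the model.

```python
def linear(hidden, weight, bias):
    """Apply y[t][k] = sum_j hidden[t][j] * weight[k][j] + bias[k]."""
    return [
        [sum(h * w for h, w in zip(row, weight[k])) + bias[k]
         for k in range(len(weight))]
        for row in hidden
    ]


hidden_size = 4  # toy size; the model in this notebook uses 768
toy_hidden = [[0.5, -1.0, 2.0, 0.25],   # hidden state of token 0
              [1.0, 1.0, -3.0, 0.5]]    # hidden state of token 1

ones_weight = [[1.0] * hidden_size for _ in range(2)]  # like weight.data.fill_(1.0)
zero_bias = [0.0, 0.0]                                  # like bias.data.zero_()

toy_logits = linear(toy_hidden, ones_weight, zero_bias)
toy_start_logits = [row[0] for row in toy_logits]
toy_end_logits = [row[1] for row in toy_logits]
print(toy_start_logits, toy_end_logits)
```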

In [18]:
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
all_start_positions = torch.tensor([[f.start_position] for f in eval_features], dtype=torch.long)
all_end_positions = torch.tensor([[f.end_position] for f in eval_features], dtype=torch.long)

eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                          all_start_positions, all_end_positions, all_example_index)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)

model.eval()
None

In [19]:
batch = next(iter(eval_dataloader))
input_ids, input_mask, segment_ids, start_positions, end_positions, example_index = batch
print([t.shape for t in batch])
start_positions.size()

[torch.Size([1, 384]), torch.Size([1, 384]), torch.Size([1, 384]), torch.Size([1, 1]), torch.Size([1, 1]), torch.Size([1])]


torch.Size([1, 1])

In [20]:
pytorch_all_out = []
for batch in tqdm(eval_dataloader, desc="Evaluating"):
    input_ids, input_mask, segment_ids, start_positions, end_positions, example_index = batch
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    start_positions = start_positions.to(device)
    end_positions = end_positions.to(device)

    total_loss, (start_logits, end_logits) = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
    
    eval_feature = eval_features[example_index.item()]

    output_json = collections.OrderedDict()
    # Use the id of the current feature instead of a variable left over from an earlier cell.
    output_json["linex_index"] = eval_feature.unique_id
    output_json["tokens"] = list(eval_feature.tokens)
    output_json["total_loss"] = total_loss.detach().cpu().numpy()
    output_json["start_logits"] = start_logits.detach().cpu().numpy()
    output_json["end_logits"] = end_logits.detach().cpu().numpy()
    pytorch_all_out.append(output_json)
    break

Evaluating:   0%|          | 0/270 [00:00<?, ?it/s]


In [21]:
print(len(pytorch_all_out))
print(len(pytorch_all_out[0]))
print(pytorch_all_out[0].keys())
print("number of tokens", len(pytorch_all_out[0]['tokens']))
print("number of start_logits", len(pytorch_all_out[0]['start_logits']))
print("number of end_logits", len(pytorch_all_out[0]['end_logits']))

1
5
odict_keys(['linex_index', 'tokens', 'total_loss', 'start_logits', 'end_logits'])
number of tokens 176
number of start_logits 1
number of end_logits 1
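
The count of 1 reported for the PyTorch logits is the batch dimension: the arrays have shape `(1, 384)`, so `len()` returns the size of the first axis rather than the number of logits. A small sketch with nested lists makes the difference explicit; `shape_of` is a hypothetical helper written for this illustration.

```python
def shape_of(nested):
    """Shape of a rectangular nested list, e.g. [[0.0] * 384] -> (1, 384)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


batched_logits = [[0.0] * 384]        # one example in the batch, 384 positions
unbatched_logits = batched_logits[0]  # drop the batch dimension

print(len(batched_logits), shape_of(batched_logits))
print(len(unbatched_logits), shape_of(unbatched_logits))
```

Indexing the first element, as done here, is one way to obtain a per-position count comparable to the TensorFlow output.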


In [22]:
pytorch_outputs = [pytorch_all_out[0]['start_logits'], pytorch_all_out[0]['end_logits'], pytorch_all_out[0]['total_loss']]

## 3/ Comparing the start_logits, end_logits and loss of both models (root-mean-square difference)

In [23]:
import numpy as np

In [24]:
print('shape tensorflow layer, shape pytorch layer, root-mean-square difference')
print('\n'.join(str((np.array(tensorflow_outputs[i]).shape,
                     np.array(pytorch_outputs[i]).shape,
                     np.sqrt(np.mean((np.array(tensorflow_outputs[i]) - np.array(pytorch_outputs[i]))**2.0)))) for i in range(3)))

shape tensorflow layer, shape pytorch layer, root-mean-square difference
((384,), (1, 384), 5.244962470555037e-06)
((384,), (1, 384), 5.244962470555037e-06)
((1,), (), 4.560241698925438e-06)
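
The last value in each row is computed as `sqrt(mean((tf - pt)**2))`, which is the root-mean-square difference between the two outputs. Unlike a standard deviation of the differences, it does not subtract the mean, so a constant offset between the two models would still show up. The sketch below restates the metric in plain Python on toy numbers; the function names and values are assumptions for illustration.

```python
import math


def rms_difference(reference, candidate):
    """Root-mean-square difference between two equal-length sequences."""
    if len(reference) != len(candidate):
        raise ValueError("sequences must have the same length")
    squared = [(r - c) ** 2 for r, c in zip(reference, candidate)]
    return math.sqrt(sum(squared) / len(squared))


def max_abs_difference(reference, candidate):
    """Largest element-wise absolute difference."""
    return max(abs(r - c) for r, c in zip(reference, candidate))


toy_tf = [1.0, 2.0, 3.0]
toy_pt = [1.0, 2.0, 5.0]
toy_rms = rms_difference(toy_tf, toy_pt)      # sqrt(4 / 3)
toy_max = max_abs_difference(toy_tf, toy_pt)  # 2.0

# A constant offset of 1.0 gives an RMS difference of 1.0,
# although the differences themselves have zero spread.
offset_rms = rms_difference(toy_tf, [x + 1.0 for x in toy_tf])
print(toy_rms, toy_max, offset_rms)
```

Reporting the maximum absolute difference alongside the RMS value can help spot a single position that disagrees strongly.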


In [27]:
print("Total loss of the TF model {} - Total loss of the PT model {}".format(tensorflow_outputs[2][0], pytorch_outputs[2]))

Total loss of the TF model 9.06024 - Total loss of the PT model 9.0602445602417
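
For readers who want to relate the total loss to the logits, the sketch below assumes the usual SQuAD formulation, in which the total loss is the average of a cross-entropy over start positions and a cross-entropy over end positions. This formula is not shown in this notebook, so treat it as an assumption about the underlying scripts; the function names and the uniform toy logits are illustrative only.

```python
import math


def position_loss(logits, position):
    """Cross-entropy for one gold position: -log softmax(logits)[position]."""
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_norm - logits[position]


def total_span_loss(start_logits, end_logits, start_position, end_position):
    """Average of the start and end cross-entropies (assumed formulation)."""
    start_loss = position_loss(start_logits, start_position)
    end_loss = position_loss(end_logits, end_position)
    return (start_loss + end_loss) / 2.0


# With uniform logits over 384 positions, the loss equals log(384).
uniform_logits = [0.0] * 384
uniform_loss = total_span_loss(uniform_logits, uniform_logits, 10, 12)
print(uniform_loss, math.log(384))
```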
