In [29]:
%env TF_KERAS=1
import codecs
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras_bert import load_trained_model_from_checkpoint, Tokenizer
from pyknp import Juman

env: TF_KERAS=1


In [32]:
SEQ_LEN = 64
BATCH_SIZE = 64
OUTPUT_LAYER_NUM = 4
LEARNING_RATE = 1e-5

PRETRAINED_PATH = 'bert-master/Japanese_L-12_H-768_A-12_E-30_BPE/'
CONFIG_PATH = PRETRAINED_PATH + 'bert_config.json'
CHECKPOINT_PATH = PRETRAINED_PATH + 'bert_model.ckpt'
VOCAB_PATH = PRETRAINED_PATH + 'vocab.txt'
jumanpp = Juman()

In [3]:
model = load_trained_model_from_checkpoint(
  CONFIG_PATH,
  CHECKPOINT_PATH,
  training=False,
  trainable=False,
  output_layer_num=OUTPUT_LAYER_NUM,
  seq_len=SEQ_LEN
)

In [4]:
print(model.inputs, model.outputs)

[<tf.Tensor 'Input-Token:0' shape=(None, 64) dtype=float32>, <tf.Tensor 'Input-Segment:0' shape=(None, 64) dtype=float32>] [<tf.Tensor 'Encoder-Output/Identity:0' shape=(None, 64, 3072) dtype=float32>]


In [40]:
# prepare token->idx dictionary
def make_token_dict(vocab_path):
  token_dict = {}
  with codecs.open(vocab_path, 'r', 'utf8') as reader:
    for line in reader:
      if line != ' \n':
        token = line.strip()
      else:
        token = line.strip('\n')
      token_dict[token] = len(token_dict)
  return token_dict
token_dict = make_token_dict(VOCAB_PATH)

In [41]:
print(token_dict['[CLS]'])
print(token_dict['日本'])
print(token_dict['##ｏｓ'])
print(token_dict['好調な'])

2
49
2451
32005


In [48]:
def preprocess_jpn_sentence(w):
    result = jumanpp.analysis(w)
    w = '[cls]'
    for mrph in result.mrph_list():
      w = w + '　' + mrph.midasi
    w += '　[sep]'
    return w

In [44]:
def dict_lookup(ids, token_dict):
  assert isinstance(ids, list)
  for word_embedding in ids:
    if word_embedding != 0:
      print('{}->{}'.format(word_embedding, [k for k, v in token_dict.items() if v == word_embedding]))

In [49]:
text = '今回の事件が貿易紛争に発展しかねない。'
texts = ['日本語テスト。', '明日の天気はどうですか？']

In [50]:
preprocess_jpn_sentence(text)

'[cls]\u3000今回\u3000の\u3000事件\u3000が\u3000貿易\u3000紛争\u3000に\u3000発展\u3000し\u3000かね\u3000ない\u3000。\u3000[sep]'

In [51]:
tokenizer = Tokenizer(token_dict, cased=True)

In [52]:
ids = []
segments = []
# for text in texts:
id, segment = tokenizer.encode(text, max_len=SEQ_LEN)
ids.append(id)
segments.append(segment)
print(ids, '\n', segments)

[[2, 774, 111, 5, 270, 2531, 11, 1, 15447, 16983, 24121, 8, 1318, 5134, 790, 3861, 5342, 7, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] 
 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]


In [53]:
print(dict_lookup(ids[0], token_dict))

2->['[CLS]']
774->['今']
111->['回']
5->['の']
270->['事']
2531->['件']
11->['が']
1->['[UNK]']
15447->['易']
16983->['紛']
24121->['争']
8->['に']
1318->['発']
5134->['展']
790->['しか']
3861->['##ね']
5342->['##ない']
7->['。']
3->['[SEP]']
None


In [8]:
ids = []
segments = []
for text in texts:
  id, segment = tokenizer.encode(text, max_len=SEQ_LEN)
  ids.append(id)
  segments.append(segment)
print(ids, '\n', segments)

In [28]:
print(dict_lookup(ids[0]))
print(dict_lookup(ids[1]))

2->['[CLS]']
29->['日']
97->['本']
156->['語']
3003->['テスト']
7->['。']
3->['[SEP]']
None
2->['[CLS]']
1503->['明']
29->['日']
5->['の']
866->['天']
1482->['気']
9->['は']
5272->['##どう']
12323->['##です']
856->['##か']
1566->['？']
3->['[SEP]']
None


In [9]:
result = model.predict([ids, segments], verbose=True)



In [10]:
result.shape

(2, 64, 3072)

In [12]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input-Token (InputLayer)        [(None, 64)]         0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      [(None, 64)]         0                                            
__________________________________________________________________________________________________
Embedding-Token (TokenEmbedding [(None, 64, 768), (3 24580608    Input-Token[0][0]                
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, 64, 768)      1536        Input-Segment[0][0]              
____________________________________________________________________________________________