In [1]:
import os
import numpy as np
import tensorflow as tf
from dataProcessing import load_file, insert_target, encodeDataInfer
from transformers import BertTokenizer
from transformers import TFBertForMaskedLM
from datetime import datetime
import json
import sys

In [2]:
# WordPiece tokenizer for the pretrained 'bert-base-uncased' checkpoint;
# lower-casing matches the uncased vocabulary.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

In [3]:
# TF-format checkpoint prefix holding the trained weights of the model built below.
checkpoint_path = "Models/20200425_142515/cp-008.ckpt"

In [4]:
### punctuation decoder: class id (as a string) -> punctuation label
punDec = {
    str(class_id): label
    for class_id, label in enumerate(("SPACE", "COMMA", "PERIOD", "QUESTION"))
}

In [5]:
# Inference configuration:
#   n            - rows kept for an optional quick run on a fraction of the data
#   vocab_size   - bert-base-uncased vocabulary size, i.e. the width of the MLM logits
#   segment_size - token ids per model input window
#   batch_size   - windows per inference batch
n, vocab_size, segment_size, batch_size = 20, 30522, 32, 2

In [6]:
# Dataset whose sentences get their punctuation restored.
data_name = "IWSLT12"
infSet_01 = "Data{}/extractInfer_01.txt".format(data_name)
data = load_file(infSet_01)

# Token ids for the whole text (X_), then fixed-size model inputs built from them
# (presumably one window of segment_size ids per target token -- see
# dataProcessing.insert_target).
X_ = encodeDataInfer(data, tokenizer)
X = insert_target(X_, segment_size)

# For a quick run on a fraction of the dataset, slice X here: X = X[:n]

# Batched tf.data pipeline; each element is a 1-tuple holding a batch of windows.
dataset = tf.data.Dataset.from_tensor_slices((X,)).batch(batch_size)

In [7]:
# Sanity-check the model inputs: overall shape, row type, and the first window.
for info in (X.shape, type(X[0]), X[0]):
    print(info)

(8575, 32)
<class 'numpy.ndarray'>
[ 2017  3305  2054  1005  1055  2725  2009  2009  2987  1005  1056  2191
  2017  5191  2079     0  2023  5852  2080  2057  2024  2469  2009  2573
  2079  2009  2017  1005  2222  7523 15398  1998]


### Build the model

In [8]:
# Keras functional model: pretrained BERT masked-LM, its first output flattened,
# then a Dense layer producing 4 scores (one per punDec class).
# `shape` must be a tuple: `(segment_size)` is just an int, whereas
# `(segment_size,)` is the 1-D per-example input shape intended here.
bert_input = tf.keras.Input(shape=(segment_size,), dtype='int32', name='bert_input')
# [0] takes the first output of TFBertForMaskedLM (the MLM prediction scores); the
# Reshape below assumes it holds segment_size * vocab_size values per example.
x = TFBertForMaskedLM.from_pretrained('bert-base-uncased')(bert_input)[0]
x = tf.keras.layers.Reshape((segment_size*vocab_size,))(x)
dense_out = tf.keras.layers.Dense(4)(x)


model = tf.keras.Model(bert_input, dense_out, name='BertModel')
# model.summary()  # prints the layer table itself (wrapping it in print() adds "None")

In [9]:
# Restore the trained weights; the returned CheckpointLoadStatus is shown as the cell output.
model.load_weights(filepath=checkpoint_path)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd61ccae820>

### Calculate predictions

In [10]:
# feats = next(iter(dataset))

In [11]:
preds = model.predict(dataset).argmax(axis=1)  # highest-scoring class per window

In [12]:
# Sanity-check the predictions: shape, values, and element type.
for info in (preds.shape, preds, type(preds[0])):
    print(info)

(8575,)
[0 0 0 ... 0 0 2]
<class 'numpy.int64'>


In [16]:
np.nonzero(preds == 3)  # windows predicted as class 3 ("QUESTION" in punDec)

(array([  10,  657, 1256, 1274, 1716, 2131, 2687, 2972, 2985, 3168, 3392,
        4088, 4179, 4185, 4191, 4193, 4195, 4196, 4199, 4201, 4207, 4267,
        4971, 5071, 5100, 5118, 5122, 5151, 5234, 5249, 5508, 5523, 5563,
        5842, 5999, 6035, 6144, 6306, 6382, 6470, 6477, 6740, 7225, 7248,
        7317, 7685]),)

In [14]:
# print(len(X_))
# print(X_)

### Write the text with restored (inferred) punctuation to a file

In [None]:
def restorePunctuation(X, preds, punDec, tokenizer, fileName):
    """Write one ``"<token> | <LABEL> "`` line per prediction to ``fileName``.

    Args:
        X: sequence of token ids indexed in step with ``preds``; ``X[i]`` is the
            token labelled by ``preds[i]``.
        preds: predicted class ids, one per token.
        punDec: maps ``str(class_id)`` to a punctuation label, e.g. ``'2' -> "PERIOD"``.
        tokenizer: object providing ``convert_ids_to_tokens`` (e.g. a BertTokenizer).
        fileName: path of the output text file; an existing file is overwritten.

    Raises:
        IndexError: if ``X`` has fewer entries than ``preds``.
        KeyError: if a predicted class id has no entry in ``punDec``.
    """
    # Index the ``X`` argument rather than a notebook global, so the function
    # works for whatever ids the caller passes. ``with`` closes the file even if
    # a lookup raises; utf-8 lets non-ASCII WordPiece tokens be written on any OS.
    with open(fileName, 'w', encoding='utf-8') as file:
        for i, pred in enumerate(preds):
            # convert_ids_to_tokens returns a single token only for a Python int;
            # cast in case X holds numpy integer ids.
            word = tokenizer.convert_ids_to_tokens(int(X[i]))
            pun = punDec[str(pred)]
            file.write(word + " | " + pun + " \n")

In [None]:
# Token ids of the full text paired with their predicted punctuation labels.
restorePunctuation(
    X=X_,
    preds=preds,
    punDec=punDec,
    tokenizer=tokenizer,
    fileName='textRestored.txt',
)