In [1]:
import os
import numpy as np
import tensorflow as tf
from dataProcessing import load_file, encodeDataInfer, insert_target
from transformers import AutoTokenizer
from transformers import TFCamembertForMaskedLM
from datetime import datetime
import json
import sys

In [2]:
### Instantiate the tokenizer from the pretrained "jplu/tf-camembert-base" checkpoint.
# The same tokenizer object is later used both to encode the input text and to
# map token ids back to token strings when writing the restored output.
# NOTE(review): do_lower_case=True is forwarded to from_pretrained; confirm it
# actually takes effect for this tokenizer and matches the training setup.
tokenizer = AutoTokenizer.from_pretrained("jplu/tf-camembert-base", do_lower_case=True)

In [3]:
### Path to the saved model weights; passed to model.load_weights() further down.
# NOTE(review): relative path, presumably resolved against the notebook's
# working directory; confirm before running from another location.
checkpointPath = "Models/20200428_180956/cp-004.ckpt"

In [4]:
### Punctuation decoder: stringified predicted class index -> punctuation label.
# NOTE(review): only classes "0" and "1" are mapped here, while the model's
# final Dense layer has 4 outputs; a predicted class of 2 or 3 would raise
# KeyError on lookup. Confirm the intended label set.
punDec = {"0": "SPACE", "1": "PERIOD"}

## Hyperparameters

In [5]:
# Size of the optional dataset slice (in this notebook it is referenced only
# by the commented-out truncation in the data-loading cell).
n = 512

# Model / data dimensions: vocab_size and segment_size fix the flattened size
# fed to the classification head; batch_size is used to batch the dataset.
vocab_size = 32005
segment_size = 32
batch_size = 8

# Training-related settings, kept alongside the others; of these only learat
# is read below (for the summary dict).
train_layer_ind = 0  # 0 for all model, -2 for only top layer
learat = 1e-4
num_epochs = 10

# Summary of the main hyperparameters, collected in a single dict.
hyperparameters = dict(
    vocab_size=vocab_size,
    segment_size=segment_size,
    learning_rate=learat,
    batch_size=batch_size,
)

## Get the dataset

In [6]:
# Name of the dataset with sentences; it is concatenated directly after "Data"
# to form the directory name of the inference file.
# NOTE(review): this yields "DataScriber/toyInfer_01.txt" (no separator between
# "Data" and the dataset name) - confirm that this is the intended layout.
data_name = "Scriber"
infSet_01 = f"Data{data_name}/toyInfer_01.txt"

# Load the raw text, encode it with the tokenizer, then build the model inputs.
# X_ (the encoded data) is kept because it is reused when writing the output.
# NOTE(review): presumably insert_target produces one segment_size-long input
# per token of X_ - verify against dataProcessing.
data = load_file(infSet_01)
X_ = encodeDataInfer(data, tokenizer)
X = insert_target(X_, segment_size)

# Optional: keep only a fraction of the dataset with `X = X[0:n]`.

# Wrap the inputs in a tf.data.Dataset (as a 1-tuple, i.e. features only) and
# group them into batches of batch_size.
dataset = tf.data.Dataset.from_tensor_slices((X,))
dataset = dataset.batch(batch_size)

## Build the model

In [7]:
# Inference network: token ids -> pretrained CamemBERT masked-LM -> flatten ->
# 4-unit dense classification head.
# NOTE(review): the layer structure presumably has to match the network that
# produced the checkpoint loaded afterwards - do not change it independently.

# `shape` must be a tuple; `(segment_size)` without the trailing comma is just
# an int that Keras happens to tolerate.
bert_input = tf.keras.Input(shape=(segment_size,), dtype='int32', name='bert_input')

# [0] selects the first output of the masked-LM model; the Reshape below
# flattens it to segment_size * vocab_size values per sample.
x = TFCamembertForMaskedLM.from_pretrained("jplu/tf-camembert-base")(bert_input)[0]
x = tf.keras.layers.Reshape((segment_size * vocab_size,))(x)

# One output score per punctuation class; argmax is taken at prediction time.
dense_out = tf.keras.layers.Dense(4)(x)

model = tf.keras.Model(bert_input, dense_out, name='CamemBERT')
# model.summary()

In [8]:
# Restore the weights stored at checkpointPath into the model built above.
# The call returns a checkpoint load-status object (its repr is this cell's
# output).
# NOTE(review): the status object is not inspected here; if silent mismatches
# between checkpoint and model are a concern, consider checking it - confirm
# against the TensorFlow version in use.
model.load_weights(checkpointPath)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fe2c75e04c0>

## Calculate predictions

In [9]:
# First batch of the dataset, kept for manual inspection (not used below).
feats = next(iter(dataset))

# Run the model over every batch, then pick the highest-scoring class index
# along axis 1 for each input row.
logits = model.predict(dataset)
preds = np.argmax(logits, axis=1)

## Return the text with restored (inferred) punctuation

In [10]:
def restorePunctuation(X, preds, punDec, tokenizer, fileName):
    """Write each token next to its predicted punctuation label.

    One line is written per prediction, formatted as ``<token> | <label> ``.

    Args:
        X: indexable collection of token ids; ``X[i]`` is the id whose
            punctuation is given by ``preds[i]``.
        preds: iterable of predicted class indices.
        punDec: mapping from the stringified class index to its label.
        tokenizer: object exposing ``convert_ids_to_tokens(token_id)`` that
            returns the token string for one id.
        fileName: path of the text file to create (overwritten if present).

    Raises:
        KeyError: if a predicted class has no entry in ``punDec``.
        IndexError: if ``preds`` is longer than ``X``.
    """
    # Use the X argument (the original read a notebook global instead) and a
    # context manager so the file is closed even if a lookup fails midway.
    # UTF-8 is set explicitly because the tokens are not plain ASCII.
    with open(fileName, 'w', encoding='utf-8') as file:
        for i, pred in enumerate(preds):
            word = tokenizer.convert_ids_to_tokens(X[i])
            pun = punDec[str(pred)]
            file.write(word + " | " + pun + " \n")

In [11]:
# Write the encoded tokens with their predicted punctuation to a text file.
restorePunctuation(X_, preds, punDec, tokenizer, fileName='textRestored.txt')