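# Restore Punctuation In An Unpunctuated Text

This notebook loads trained CamemBERT-based weights, predicts a punctuation label (`SPACE` or `PERIOD`) for every token of an unpunctuated text, and writes each token together with its predicted label to a file.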

In [1]:
import os
import numpy as np
import tensorflow as tf

from utils import loadFile
from dataProcessing import encodeDataInfer, insertTarget, processingScriber, processingOPUS

from transformers import AutoTokenizer
from transformers import TFCamembertForMaskedLM
from datetime import datetime
import json
import sys

In [2]:
### tokenizer matching the pretrained CamemBERT weights used to build the model below
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="jplu/tf-camembert-base",
    do_lower_case=True,
)

In [3]:
### checkpoint (weights prefix) restored into the model before inference
# Alternative, baseline model: "Models/20200530_161559/cp-001.ckpt"
checkpointPath = (
    "Models/20200601_090641/"  # training-run directory (presumably a timestamp)
    "cp-006.ckpt"              # checkpoint prefix (presumably epoch 6) -- confirm
)

In [4]:
### punctuation decoder: str(class index predicted by the model) -> label name
# Keys are strings because predictions are looked up via str(pred).
punDec = {str(index): label for index, label in enumerate(("SPACE", "PERIOD"))}

## Hyperparameters

In [5]:
# vocab_size:   size of the masked-LM output vocabulary (flattened in the model's Reshape)
# segment_size: number of token ids per input window (the model's input length)
# batch_size:   windows per inference batch
vocab_size, segment_size, batch_size = 32005, 32, 1

## Get The Dataset

In [6]:
# name of dataset with sentences (unpunctuated input text)
dataSetName = "./DataScriber/raw.processed.Valid_01_extractNoPun.txt"

data = loadFile(dataSetName)

# X_: encoded token ids of the whole text (restorePunctuation maps X_[i] back to one token).
# X:  windows of `segment_size` ids built from X_ -- presumably one row per element
#     of X_ (the printed shapes below agree); confirm in dataProcessing.insertTarget.
X_ = encodeDataInfer(data, tokenizer)
X = insertTarget(X_, segment_size)

# Optionally run on only the first `n_samples` windows for a quick check
# (e.g. 320); None keeps the full dataset.
n_samples = None
if n_samples is not None:
    X = X[:n_samples]

# instantiate tf.data.Dataset (1-tuples of inputs, no labels) for inference
dataset = tf.data.Dataset.from_tensor_slices((X,)).batch(batch_size)

In [7]:
# Sanity check before inference: the printed shape shows whether X has
# one row of `segment_size` ids per element of X_.
for caption, value in (("Length Of X_ = ", len(X_)), ("Shape Of X   = ", X.shape)):
    print(caption, value)

Length Of X_ =  9155
Shape Of X   =  (9155, 32)


### Build The Model

In [8]:
# Architecture must match the one the checkpoint was trained with:
#   token ids (segment_size,) -> pretrained CamemBERT masked-LM
#   -> its first output (presumably (batch, segment_size, vocab_size) prediction
#      scores, given the Reshape below) flattened
#   -> Dense head with one unactivated score (logit) per punctuation class.
# `shape` must be a tuple; `(segment_size)` is just an int.
bert_input = tf.keras.Input(shape=(segment_size,), dtype='int32', name='bert_input')
x = TFCamembertForMaskedLM.from_pretrained("jplu/tf-camembert-base")(bert_input)[0]
x = tf.keras.layers.Reshape((segment_size*vocab_size,))(x)
dense_out = tf.keras.layers.Dense(len(punDec))(x)
model = tf.keras.Model(bert_input, dense_out, name='CamemBERT')

All model checkpoint weights were used when initializing TFCamembertForMaskedLM.

All the weights of TFCamembertForMaskedLM were initialized from the model checkpoint at jplu/tf-camembert-base.
If your task is similar to the task the model of the ckeckpoint was trained on, you can already use TFCamembertForMaskedLM for predictions without further training.


In [9]:
# Restore the trained weights from `checkpointPath` into the model built above;
# the returned load-status object is left as the cell's displayed value.
model.load_weights(filepath=checkpointPath)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd6b191ac40>

### Calculate Predictions

In [10]:
# Score every window, then keep the index of the highest-scoring class
# (a key into punDec once converted with str()).
preds = model.predict(dataset).argmax(axis=1)

In [11]:
print(preds.shape[0])  # number of predictions (one per input window)

9155


### Write Each Token With Its Predicted (Inferred) Punctuation Label

In [12]:
def restorePunctuation(X, preds, punDec, tokenizer, fileName):
    """Write each token with its predicted punctuation label to a text file.

    One line is written per prediction: the token, " | ", the label name,
    then a trailing space and a newline.

    Args:
        X: Sequence of token ids; ``X[i]`` is paired with ``preds[i]``.
        preds: Sequence of predicted class indices, one per output line.
        punDec: Mapping from ``str(class_index)`` to a label name,
            e.g. ``{"0": "SPACE", "1": "PERIOD"}``.
        tokenizer: Object whose ``convert_ids_to_tokens`` maps a single id
            to its token string.
        fileName: Path of the output file; an existing file is overwritten.

    Raises:
        IndexError: if ``preds`` is longer than ``X``.
        KeyError: if a prediction has no entry in ``punDec``.
    """
    # Explicit UTF-8: tokenizer output may contain non-ASCII characters that
    # the platform's default encoding cannot represent.
    with open(fileName, "w", encoding="utf-8") as out:
        for i, label_id in enumerate(preds):
            token = tokenizer.convert_ids_to_tokens(X[i])
            label = punDec[str(label_id)]
            out.write(token + " | " + label + " \n")

In [13]:
# Write one "token | LABEL" line per prediction to textRestored_01.txt.
restorePunctuation(
    X=X_,
    preds=preds,
    punDec=punDec,
    tokenizer=tokenizer,
    fileName='textRestored_01.txt',
)