# Punctuation Restoration

#### Kai Luo

# Introduction

## Punctuation Restoration Problem

Imagine that you are building software for transcribing speech to text. The speech transcription part works perfectly but does not produce punctuation. The task is to train a predictive model that ingests a sequence of text and adds punctuation (period, comma or question mark) in the appropriate locations. This task is important for all downstream data processing jobs. <br /><br />

#### Example input:
`this is a string of text with no punctuation this is a new sentence`

<br />

#### Example output:
`this is a string of text with no punctuation <period> this is a new sentence <period>`
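
The annotated output can be turned into conventionally punctuated text with a small post-processing step. The following is a minimal sketch of that idea; the helper name `render_punctuation` and the token-to-symbol mapping are assumptions made for illustration and are not part of the project code.

```python
# Hypothetical helper: map the model's annotation tokens to plain punctuation.
TOKEN_TO_SYMBOL = {"<period>": ".", "<comma>": ",", "<questionmark>": "?"}


def render_punctuation(annotated_text):
    """Attach each punctuation token to the word that precedes it."""
    words = []
    for token in annotated_text.split():
        symbol = TOKEN_TO_SYMBOL.get(token)
        if symbol is None:
            # Ordinary word: keep it as is.
            words.append(token)
        elif words:
            # Punctuation token: glue the symbol onto the previous word.
            words[-1] += symbol
    return " ".join(words)


print(render_punctuation(
    "this is a string of text with no punctuation <period> this is a new sentence <period>"
))
```

Restoring capitalization is a separate problem and is not attempted here.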


## Requirements

- Input is text with punctuation removed.
- The training data is case insensitive.
- The model only predicts period, comma or question mark.

## Assumptions

- Punctuation restoration is performed on a batch of transcriptions.
- Output includes punctuation without annotations.
- ASR output does not contain broken sentences.
- The model is only trained to output period, comma or question mark.

## Solution Overview

Because sentences adhere to specific grammar rules, one might try to build a rule-based system for punctuation restoration. However, such an approach has three weaknesses:
1. Most people do not follow grammar rules strictly when speaking.
2. Writing a rule-based system for punctuation restoration requires prior knowledge of the specific grammar rules of a language.
3. A rule-based system built for one language cannot generalize to other languages.

Therefore, the following solution addresses these weaknesses by combining [bidirectional recurrent neural networks (BRNN)](https://deeplearning.cs.cmu.edu/S20/document/readings/Bidirectional%20Recurrent%20Neural%20Networks.pdf) with an [attention mechanism](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf) for punctuation restoration in unsegmented text.

My solution is largely based on [Bidirectional Recurrent Neural Network with Attention Mechanism for Punctuation Restoration](https://www.isca-speech.org/archive/Interspeech_2016/pdfs/1517.PDF).

The architecture is defined as follows:
1. Obtain word embeddings from [GloVe](https://nlp.stanford.edu/projects/glove/).
2. The word embeddings are then processed by densely connected [Bi-LSTM](https://arxiv.org/pdf/1303.5778.pdf) layers.
3. These Bi-LSTM layers are followed by an RNN with an attention mechanism and a [conditional random field (CRF)](https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers) log-likelihood loss.
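
To illustrate what a CRF output layer contributes at prediction time, here is a minimal sketch of Viterbi decoding for a linear-chain CRF in plain Python. It is not the project's implementation: the function name, the score layout (`emissions[t][k]`, `transitions[i][j]`) and the toy numbers are assumptions chosen for illustration.

```python
def viterbi_decode(emissions, transitions):
    """Return the highest-scoring tag sequence for one sequence.

    emissions[t][k]   -- score of tag k at position t (e.g. the network logits)
    transitions[i][j] -- score of moving from tag i to tag j
    """
    num_tags = len(transitions)
    scores = list(emissions[0])   # best score of any path ending in each tag
    backpointers = []             # best previous tag for every later position

    for step in emissions[1:]:
        new_scores, pointers = [], []
        for j in range(num_tags):
            best_prev = max(range(num_tags),
                            key=lambda i: scores[i] + transitions[i][j])
            new_scores.append(scores[best_prev] + transitions[best_prev][j] + step[j])
            pointers.append(best_prev)
        scores = new_scores
        backpointers.append(pointers)

    # Follow the back-pointers from the best final tag to the start.
    path = [max(range(num_tags), key=lambda k: scores[k])]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path


# Toy example with two tags: 0 = no punctuation, 1 = period.
emissions = [[1.0, 0.9], [0.2, 1.0], [0.3, 1.0]]
transitions = [[0.0, 0.0], [0.0, -5.0]]  # strongly discourage period -> period
print(viterbi_decode(emissions, transitions))
```

Picking the best tag at each position independently would place a period after both the second and the third word in this toy example; the transition scores let the decoder avoid two periods in a row.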

The reasons are as follows:
1. GloVe provides vector representations for words. Its training objective encourages the word vectors to capture linear substructures of the vector space, such as analogy relationships between words.
2. A BRNN lets the model use variable-length context both before and after the current position in the text. In the recurrent layers, I use [LSTM](https://www.bioinf.jku.at/publications/older/2604.pdf), which is well suited to capturing long-range dependencies on multiple time scales.
3. An attention mechanism further increases the model's capacity to find the parts of the context that are relevant for punctuation decisions. For example, the model might focus on words that indicate a question but are relatively far from the current word. An attention mechanism could nudge the model towards ending the sentence with a question mark instead of a period.
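
The attention idea in the last point can be sketched in a few lines of plain Python. The example below uses scaled dot-product scoring, as in the linked attention paper; the model described here may use a different scoring function, and the function name `attend` and the toy vectors are assumptions for illustration only.

```python
import math


def attend(query, context):
    """Return (weights, summary) for one query vector over a list of context vectors."""
    scale = math.sqrt(len(query))
    # Similarity between the query and every context position.
    scores = [sum(q * c for q, c in zip(query, vec)) / scale for vec in context]

    # Numerically stable softmax turns the scores into weights that sum to 1.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # The summary is the weighted average of the context vectors.
    dim = len(context[0])
    summary = [sum(w * vec[d] for w, vec in zip(weights, context)) for d in range(dim)]
    return weights, summary


# Toy example: the first and third context vectors point in the query's direction.
weights, summary = attend([2.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(weights)
print(summary)
```

Positions whose vectors align with the query receive larger weights, which is how a distant question word can still influence the punctuation decision at the end of a sentence.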

 

# Import Settings, Libraries, Configurations and Datasets

## Settings

In [1]:
import warnings
warnings.filterwarnings('ignore')

## Libraries

In [2]:
import codecs
import os
import tensorflow as tf
import time

from models.punctuation_restoration_model import PunctuationRestorationModel
from utils.data_utils import split_to_batches
from utils.data_utils import inhour
from utils.punctuation_preprocess import process_data

## Configurations

In [3]:
# dataset parameters
tf.flags.DEFINE_string("raw_path", "data/raw/LREC_converted", "path to raw dataset")
tf.flags.DEFINE_string("save_path", "data/dataset/lrec", "path to save dataset")
tf.flags.DEFINE_string("glove_name", "840B", "glove embedding name")

# glove word embedding
tf.flags.DEFINE_string("glove_path", "data/external/embeddings/glove.{}.{}d.txt", "glove embedding path")
tf.flags.DEFINE_integer("max_vocab_size", 50000, "maximal vocabulary size")
tf.flags.DEFINE_integer("max_sequence_len", 200, "maximal sequence length allowed")
tf.flags.DEFINE_integer("min_word_count", 1, "minimal word count in word vocabulary")

# dataset for train, validation and test
tf.flags.DEFINE_string("vocab", "data/dataset/lrec/vocab.json", "path to the word and tag vocabularies")
tf.flags.DEFINE_string("train_word_counter", "data/dataset/lrec/train_word_counter.json", "path to the word counter "
                                                                                          "in training datasets")
tf.flags.DEFINE_string("train_punct_counter", "data/dataset/lrec/train_punct_counter.json", "path to the punctuation "
                                                                                            "counter in training "
                                                                                            "datasets")
tf.flags.DEFINE_string("dev_word_counter", "data/dataset/lrec/dev_word_counter.json", "path to the word counter in "
                                                                                      "development datasets")
tf.flags.DEFINE_string("dev_punct_counter", "data/dataset/lrec/dev_punct_counter.json", "path to the punctuation "
                                                                                        "counter in development "
                                                                                        "datasets")
tf.flags.DEFINE_string("ref_word_counter", "data/dataset/lrec/ref_word_counter.json", "path to the word counter in "
                                                                                      "ref test datasets")
tf.flags.DEFINE_string("ref_punct_counter", "data/dataset/lrec/ref_punct_counter.json", "path to the punctuation "
                                                                                        "counter in ref test datasets")
tf.flags.DEFINE_string("asr_word_counter", "data/dataset/lrec/asr_word_counter.json", "path to the word counter in "
                                                                                      "asr test datasets")
tf.flags.DEFINE_string("asr_punct_counter", "data/dataset/lrec/asr_punct_counter.json", "path to the punctuation "
                                                                                        "counter in asr test datasets")
tf.flags.DEFINE_string("train_set", "data/dataset/lrec/train.json", "path to the training datasets")
tf.flags.DEFINE_string("dev_set", "data/dataset/lrec/dev.json", "path to the development datasets")
tf.flags.DEFINE_string("dev_text", "data/raw/LREC_converted/dev.txt", "path to the development text")
tf.flags.DEFINE_string("ref_set", "data/dataset/lrec/ref.json", "path to the ref test datasets")
tf.flags.DEFINE_string("ref_text", "data/raw/LREC_converted/ref.txt", "path to the ref text")
tf.flags.DEFINE_string("asr_set", "data/dataset/lrec/asr.json", "path to the asr test datasets")
tf.flags.DEFINE_string("asr_text", "data/raw/LREC_converted/asr.txt", "path to the asr text")
tf.flags.DEFINE_string("pretrained_emb", "data/dataset/lrec/glove_emb.npz", "pretrained embeddings")

# network parameters
tf.flags.DEFINE_string("cell_type", "lstm", "RNN cell for encoder and decoder: [lstm | gru], default: lstm")
tf.flags.DEFINE_integer("num_layers", 4, "number of rnn layers")
tf.flags.DEFINE_multi_integer("num_units_list", [50, 50, 50, 300], "number of units for each rnn layer")
tf.flags.DEFINE_boolean("use_pretrained", True, "use pretrained word embedding")
tf.flags.DEFINE_boolean("tuning_emb", False, "tune pretrained word embedding while training")
tf.flags.DEFINE_integer("emb_dim", 300, "embedding dimension for encoder and decoder input words/tokens")
tf.flags.DEFINE_boolean("use_highway", True, "use highway network")
tf.flags.DEFINE_integer("highway_layers", 2, "number of layers for highway network")
tf.flags.DEFINE_boolean("use_crf", True, "use CRF decoder")

# training parameters
tf.flags.DEFINE_float("lr", 0.001, "learning rate")
tf.flags.DEFINE_string("optimizer", "adam", "optimizer: [adagrad | sgd | rmsprop | adadelta | adam], default: adam")
tf.flags.DEFINE_boolean("use_lr_decay", True, "apply learning rate decay for each epoch")
tf.flags.DEFINE_float("lr_decay", 0.05, "learning rate decay factor")
tf.flags.DEFINE_float("l2_reg", None, "L2 norm regularization")
tf.flags.DEFINE_float("minimal_lr", 1e-5, "minimal learning rate")
tf.flags.DEFINE_float("grad_clip", 2.0, "maximal gradient norm")
tf.flags.DEFINE_float("keep_prob", 0.75, "dropout keep probability for embedding while training")
tf.flags.DEFINE_integer("batch_size", 32, "batch size")
tf.flags.DEFINE_integer("epochs", 5, "train epochs")
tf.flags.DEFINE_integer("max_to_keep", 3, "maximum trained models to be saved")
tf.flags.DEFINE_integer("no_imprv_tolerance", 10, "no improvement tolerance")
tf.flags.DEFINE_string("checkpoint_path", "ckpt/punctuator/", "path to save models checkpoints")
tf.flags.DEFINE_string("summary_path", "ckpt/punctuator/summary/", "path to save summaries")
tf.flags.DEFINE_string("model_name", "punctuation_restoration_model", "models name")

# convert parameters to dict
config = tf.flags.FLAGS.flag_values_dict()
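
The `glove_path` flag contains two `{}` placeholders. A plausible reading, stated here as an assumption rather than as the behaviour of `process_data`, is that they are filled with the embedding name and the embedding dimension:

```python
# Values copied from the flags above.
glove_path = "data/external/embeddings/glove.{}.{}d.txt"
glove_name = "840B"
emb_dim = 300

# Assumption: the placeholders are filled with the embedding name and dimension.
resolved_path = glove_path.format(glove_name, emb_dim)
print(resolved_path)  # data/external/embeddings/glove.840B.300d.txt
```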

## Datasets

I test the model on the IWSLT datasets, which were originally used to evaluate ASR or SLT output. The IWSLT datasets consist of TED talks, which are openly available online. I use the same training, development and test sets as the models published in [Punctuation Prediction for Unsegmented Transcript Based on Word Vector](https://hpi.de/fileadmin/user_upload/fachgebiete/meinel/papers/2016_Che_LREC.pdf) and [Bidirectional Recurrent Neural Network with Attention Mechanism for Punctuation Restoration](https://www.isca-speech.org/archive/Interspeech_2016/pdfs/1517.PDF), which makes it possible to compare results against those models in a later section. The training and development sets come from the IWSLT2012 machine translation track training data. IWSLT2011 punctuated reference transcripts and unpunctuated but segmented ASR transcripts are used for testing. Details about each dataset can be found in the Exploratory Data Analysis section.

### Preprocess Datasets

In preprocessing, I perform the following steps:
1. I follow the same procedure as [Bidirectional Recurrent Neural Network with Attention Mechanism for Punctuation Restoration](https://www.isca-speech.org/archive/Interspeech_2016/pdfs/1517.PDF) to deal with other punctuation symbols. They are either mapped to one of the punctuation marks in the output vocabulary or removed from the corpora. More specifically, exclamation marks and semicolons are mapped to periods, while colons and dashes are mapped to commas.
2. I build a vocabulary from the words that occur in the training set.
3. I chunk the text into smaller observations.
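
The mapping and chunking steps can be sketched as follows. This is a simplified illustration, not the code in `utils.punctuation_preprocess`: the function names, the `O` label for "no punctuation" and the whitespace-based tokenization are all assumptions, and the tokenization would mishandle decimals and abbreviations.

```python
import re

# Punctuation symbols mapped to the three output classes, as described above.
PUNCT_TO_LABEL = {
    ".": "<period>", "!": "<period>", ";": "<period>",
    ",": "<comma>", ":": "<comma>", "-": "<comma>", "--": "<comma>",
    "?": "<questionmark>",
}
NO_PUNCT = "O"  # hypothetical label for a word not followed by punctuation


def to_word_label_pairs(text):
    """Lower-case the text and pair every word with the punctuation that follows it."""
    # Separate punctuation from the word it is attached to.
    spaced = re.sub(r"([.,?!;:])", r" \1 ", text.lower())
    pairs = []
    for token in spaced.split():
        if token in PUNCT_TO_LABEL:
            if pairs:  # punctuation before the first word is dropped
                pairs[-1] = (pairs[-1][0], PUNCT_TO_LABEL[token])
        else:
            pairs.append((token, NO_PUNCT))
    return pairs


def chunk(pairs, max_len):
    """Split a long sequence of pairs into observations of at most max_len words."""
    return [pairs[i:i + max_len] for i in range(0, len(pairs), max_len)]


pairs = to_word_label_pairs("Stop! Wait: is this right?")
print(pairs)
print(chunk(pairs, 2))
```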

In [4]:
# preprocess data from raw data files
if not os.path.exists(config["save_path"]) or not os.listdir(config["save_path"]):
    process_data(config)
if not os.path.exists(config["pretrained_emb"]) and config["use_pretrained"]:
    process_data(config)

### Load Dataset

In [5]:
# used for training
train_set = split_to_batches(config["train_set"], config["batch_size"], shuffle=True)
# used for computing validation loss
valid_data = split_to_batches(config["dev_set"], config["batch_size"], shuffle=True)[0]
valid_text = config["dev_text"]
# used for evaluation metrics
test_texts = [config["ref_text"], config["asr_text"]]

# Exploratory Data Analysis

In [6]:
from utils.data_utils import load_dataset
import pandas as pd

train_words_counter = load_dataset('data/dataset/lrec/train_word_counter.json')
dev_words_counter = load_dataset('data/dataset/lrec/dev_word_counter.json')
ref_words_counter = load_dataset('data/dataset/lrec/ref_word_counter.json')
asr_words_counter = load_dataset('data/dataset/lrec/asr_word_counter.json')

words_counter = [{'Dataset': 'Train', 'Total Number of Words': sum(train_words_counter.values()), 'Number of Unique Words': len(train_words_counter)},
                 {'Dataset': 'Dev', 'Total Number of Words': sum(dev_words_counter.values()), 'Number of Unique Words': len(dev_words_counter)},
                 {'Dataset': 'Ref', 'Total Number of Words': sum(ref_words_counter.values()), 'Number of Unique Words': len(ref_words_counter)},
                 {'Dataset': 'ASR', 'Total Number of Words': sum(asr_words_counter.values()), 'Number of Unique Words': len(asr_words_counter)}]

word_df = pd.DataFrame(words_counter)

In [7]:
train_punct_counter = load_dataset('data/dataset/lrec/train_punct_counter.json')
dev_punct_counter = load_dataset('data/dataset/lrec/dev_punct_counter.json')
ref_punct_counter = load_dataset('data/dataset/lrec/ref_punct_counter.json')
asr_punct_counter = load_dataset('data/dataset/lrec/asr_punct_counter.json')

In [8]:
punct_df = pd.concat([pd.DataFrame([train_punct_counter]),pd.DataFrame([dev_punct_counter]),pd.DataFrame([ref_punct_counter]),pd.DataFrame([asr_punct_counter])],ignore_index=True)
punct_df.columns = ['Number of Comma', 'Number of Period', 'Number of Question Mark']

In [9]:
words_punct_df = pd.concat([word_df, punct_df], axis=1)
words_punct_df

| | Dataset | Total Number of Words | Number of Unique Words | Number of Comma | Number of Period | Number of Question Mark |
|---|---|---|---|---|---|---|
| 0 | Train | 2089286 | 44514 | 158369 | 132330 | 9901 |
| 1 | Dev | 293677 | 16150 | 22449 | 18910 | 1515 |
| 2 | Ref | 12539 | 2317 | 830 | 806 | 46 |
| 3 | ASR | 12822 | 2317 | 798 | 809 | 35 |


**Observation:**
- The training set contains by far the most data.
- Both test sets (Ref and ASR) are similar in size and punctuation counts.
- It is reasonable that commas are the most frequent punctuation mark when people talk, while question marks are the least frequent.
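
The counts in the table also show how unbalanced the classes are. The sketch below computes the share of words that are followed by a punctuation mark, assuming that the word totals do not include the punctuation tokens themselves; the variable names are illustrative.

```python
# Counts copied from the table above: (total words, commas, periods, question marks).
counts = {
    "Train": (2089286, 158369, 132330, 9901),
    "Dev": (293677, 22449, 18910, 1515),
    "Ref": (12539, 830, 806, 46),
    "ASR": (12822, 798, 809, 35),
}

punctuated_share = {
    name: (commas + periods + questions) / total
    for name, (total, commas, periods, questions) in counts.items()
}

for name, share in punctuated_share.items():
    print("{}: {:.1%} of words are followed by a punctuation mark".format(name, share))
```

Under that assumption, roughly one word in seven is followed by a punctuation mark in the training set, so the "no punctuation" class dominates in every dataset.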

# Model Building

## Hyperparameters

For hyperparameter tuning, I could use [NNI (Neural Network Intelligence)](https://github.com/microsoft/nni) to tune the hyperparameters shown in the Configurations section efficiently and automatically.
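
As a rough illustration, a search space for such a tuner could look like the sketch below, written in NNI's JSON search-space format (`_type` and `_value` keys). The parameter names match the flags in the Configurations section, but the candidate values and ranges are assumptions for illustration, not tuned settings.

```python
import json

# Hypothetical search space; ranges and choices are illustrative only.
search_space = {
    "lr": {"_type": "loguniform", "_value": [1e-5, 1e-2]},
    "keep_prob": {"_type": "uniform", "_value": [0.5, 0.9]},
    "batch_size": {"_type": "choice", "_value": [16, 32, 64]},
    "highway_layers": {"_type": "choice", "_value": [1, 2, 3]},
}

# NNI reads the search space as JSON, so serialize it for a search_space.json file.
print(json.dumps(search_space, indent=2))
```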

## Model Training and Evaluation

In [10]:
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

model = PunctuationRestorationModel(config)

# if the model is already trained, restore the model
# model.restore_last_session()

model.train(train_set, valid_data, valid_text, test_texts)

word embedding shape: [None, None, 300]
densely connected bi_rnn output shape: [None, None, 600]


attention output shape: [None, None, 300]
logits shape: [None, None, 4]
params number: 5049620


Start training...
Epoch 1/5:





Evaluate on data/raw/LREC_converted/ref.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          56.21     38.19     45.48    
<period>         70.67     61.37     65.69    
<questionmark>   47.62     21.74     29.85    
----------------------------------------------
Overall          63.94     48.84     55.38    
ERR: 8.14%
SER: 61.2%


Evaluate on data/raw/LREC_converted/asr.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          48.11     33.58     39.56    
<period>         67.63     57.67     62.26    
<questionmark>   52.63     28.57     37.04    
----------------------------------------------
Overall          58.81     45.34     51.20    
ERR: 9.01%
SER: 70.4%

 -- new BEST score on ref dataset: 55.38, on asr dataset: 51.20
Epoch 2/5:





Evaluate on data/raw/LREC_converted/ref.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          61.63     45.66     52.46    
<period>         71.09     73.91     72.47    
<questionmark>   66.67     60.87     63.64    
----------------------------------------------
Overall          67.07     59.61     63.12    
ERR: 7.0%
SER: 52.6%


Evaluate on data/raw/LREC_converted/asr.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          51.45     42.36     46.46    
<period>         66.46     66.71     66.58    
<questionmark>   50.00     40.00     44.44    
----------------------------------------------
Overall          59.56     54.30     56.81    
ERR: 8.41%
SER: 65.8%

 -- new BEST score on ref dataset: 63.12, on asr dataset: 56.81
Epoch 3/5:





Evaluate on data/raw/LREC_converted/ref.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          60.41     56.27     58.27    
<period>         74.00     71.06     72.50    
<questionmark>   77.14     58.70     66.67    
----------------------------------------------
Overall          67.43     63.41     65.36    
ERR: 6.6%
SER: 49.6%


Evaluate on data/raw/LREC_converted/asr.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          48.84     50.25     49.54    
<period>         68.10     64.48     66.24    
<questionmark>   58.33     40.00     47.46    
----------------------------------------------
Overall          58.14     57.04     57.58    
ERR: 8.45%
SER: 66.1%

 -- new BEST score on ref dataset: 65.36, on asr dataset: 57.58
Epoch 4/5:





Evaluate on data/raw/LREC_converted/ref.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          60.89     57.59     59.20    
<period>         73.83     72.55     73.18    
<questionmark>   77.78     60.87     68.29    
----------------------------------------------
Overall          67.62     64.84     66.20    
ERR: 6.46%
SER: 48.6%


Evaluate on data/raw/LREC_converted/asr.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          48.25     51.75     49.94    
<period>         67.69     65.59     66.62    
<questionmark>   62.96     48.57     54.84    
----------------------------------------------
Overall          57.62     58.50     58.06    
ERR: 8.45%
SER: 66.1%

 -- new BEST score on ref dataset: 66.20, on asr dataset: 58.06
Epoch 5/5:





Evaluate on data/raw/LREC_converted/ref.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          61.80     57.11     59.36    
<period>         72.69     72.42     72.56    
<questionmark>   75.68     60.87     67.47    
----------------------------------------------
Overall          67.56     64.54     66.02    
ERR: 6.53%
SER: 49.1%


Evaluate on data/raw/LREC_converted/asr.txt:
----------------------------------------------
PUNCTUATION      PRECISION RECALL    F-SCORE  
<comma>          48.88     52.01     50.39    
<period>         68.72     67.70     68.20    
<questionmark>   56.67     48.57     52.31    
----------------------------------------------
Overall          58.45     59.66     59.05    
ERR: 8.39%
SER: 65.6%



#### Observation:
- From the Exploratory Data Analysis section, it is clear that most words are not followed by any punctuation mark. The model classifies the majority of these words correctly, which pushes the general accuracy (general accuracy = 1 - ERR) of the classification above 90%.
- For the purpose of predicting punctuation marks, we care more about the performance on predicting period, comma and question mark. Therefore, the words that are not followed by any punctuation mark are ignored. The model is evaluated in terms of per-punctuation and overall precision, recall and F1 score. I also report the overall [slot error rate (SER)](http://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=46A91ADB72135D6B2D6589AA5A1D6AD3?doi=10.1.1.27.4637&rep=rep1&type=pdf), as F1 score has been shown to have some undesirable properties in [Performance Measures for Information Extraction](http://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=46A91ADB72135D6B2D6589AA5A1D6AD3?doi=10.1.1.27.4637&rep=rep1&type=pdf).
- There are three types of errors: substitution, deletion and insertion.
- Precision measures the percentage of correctly predicted punctuation marks in all predicted punctuation marks. It deals with substitution and insertion errors. Higher is better.
- Recall measures the percentage of correctly predicted punctuation marks in all expected punctuation marks. It deals with substitution and deletion errors. Higher is better.
- F1 score and SER are used to have a single measure of performance that deals with all three types of errors simultaneously. For F1 score, higher is better. In contrast, lower SER is better.
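- Within the 5 training epochs, the best overall F-score on the reference transcripts (66.20) is reached at epoch 4, while the overall F-score on the ASR output is still improving at epoch 5 (59.05).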
- The results show that comma restoration is a harder task than period and question mark restoration. I think this is expected because a pause is generally more grammatically ambiguous than a full stop, especially in less formal text such as the transcripts of TED talks.
- The results show that punctuation restoration is harder on ASR output than on reference transcripts. I think this is expected because the ASR output likely contains more noise and errors.
- As we can see, the model achieves performance comparable to the models published in [Punctuation Prediction for Unsegmented Transcript Based on Word Vector](https://hpi.de/fileadmin/user_upload/fachgebiete/meinel/papers/2016_Che_LREC.pdf) and [Bidirectional Recurrent Neural Network with Attention Mechanism for Punctuation Restoration](https://www.isca-speech.org/archive/Interspeech_2016/pdfs/1517.PDF).
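
The metrics discussed above can be computed from per-word labels as in the following sketch. It follows the usual definitions in terms of correct, substituted, deleted and inserted slots; the function name, the `O` label for "no punctuation" and the reading of ERR as the share of all words classified incorrectly are assumptions, and the project's evaluation code may differ in detail.

```python
def punctuation_metrics(reference, predicted, no_punct="O"):
    """Score predicted punctuation labels against reference labels (one per word)."""
    correct = substitutions = deletions = insertions = 0
    for ref, pred in zip(reference, predicted):
        if ref == pred:
            if ref != no_punct:
                correct += 1
        elif ref == no_punct:
            insertions += 1      # predicted a mark where none is expected
        elif pred == no_punct:
            deletions += 1       # missed an expected mark
        else:
            substitutions += 1   # predicted the wrong mark

    predicted_marks = correct + substitutions + insertions
    expected_marks = correct + substitutions + deletions
    errors = substitutions + deletions + insertions

    precision = correct / predicted_marks if predicted_marks else 0.0
    recall = correct / expected_marks if expected_marks else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "ser": errors / expected_marks if expected_marks else 0.0,
        "err": errors / len(reference) if reference else 0.0,
    }


reference = ["<comma>", "O", "<period>", "O", "<questionmark>", "O"]
predicted = ["<comma>", "<period>", "O", "O", "<period>", "O"]
print(punctuation_metrics(reference, predicted))
```

Because SER divides all errors by the number of expected marks, it can exceed 100% when a model inserts many spurious marks, whereas ERR is diluted by the large number of words without punctuation.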

# Case Study

#### Input Text 1:
`this is a string of text with no punctuation this is a new sentence`

In [11]:
sentence = 'this is a string of text with no punctuation this is a new sentence'
start_time = time.time()
print(model.inference(sentence))
print("Inference Running Time: {0}".format(inhour(time.time() - start_time)))

this is a string of text with no punctuation <period> this is a new sentence <period>
Inference Running Time: 00:00:00


#### Input Text 2:
`hi this is wealthsimple customer service my name is john how can i help you today`

In [12]:
sentence = 'hi this is wealthsimple customer service my name is john how can i help you today'
start_time = time.time()
print(model.inference(sentence))
print("Inference Running Time: {0}".format(inhour(time.time() - start_time)))

hi <comma> this is wealthsimple customer service <period> my name is john <period> how can i help you today <questionmark>
Inference Running Time: 00:00:00


#### Input Text 3:
`our purpose is to make sure everyone has the ability to exercise that right it is great to wave a flag and talk about power to the people but how do you actually do it`

In [13]:
sentence = 'our purpose is to make sure everyone has the ability to exercise that right it is great to wave a flag and talk about power to the people but how do you actually do it'
start_time = time.time()
print(model.inference(sentence))
print("Inference Running Time: {0}".format(inhour(time.time() - start_time)))

our purpose is to make sure everyone has the ability to exercise that right <period> it is great to wave a flag and talk about power to the people <period> but how do you actually do it <questionmark>
Inference Running Time: 00:00:00


#### Input Text 4:
`account minimums are a barrier to entry that is why we do not have them that is why we make our fees as low as possible we use technology and force of will to bring down costs for our clients so they have control of more of their money`

In [14]:
sentence = 'account minimums are a barrier to entry that is why we do not have them that is why we make our fees as low as possible we use technology and force of will to bring down costs for our clients so they have control of more of their money'
start_time = time.time()
print(model.inference(sentence))
print("Inference Running Time: {0}".format(inhour(time.time() - start_time)))

account minimums are a barrier to entry <period> that is why we do not have them <period> that is why we make our fees as low as possible <period> we use technology and force of will to bring down costs for our clients <period> so they have control of more of their money <period>
Inference Running Time: 00:00:00


# Conclusions

## Future Improvements

1. If I had more time, I would try to add and train a classification layer on top of [pre-trained BERT](https://github.com/huggingface/transformers) for punctuation restoration, for the following reasons:
    - BERT uses the transformer neural network architecture, whereas the current approach uses LSTM.
    - Pre-trained BERT has more capacity and was trained on a much larger corpus.
2. I would like the model to take in more inputs, such as pauses during conversations. In this way, the model can leverage richer information in training and prediction.
3. In the current approach, the word embeddings are fixed the whole time. I would try allowing the pretrained word embeddings to be tuned during training.
4. According to [Character-Word LSTM Language Models](https://www.aclweb.org/anthology/E17-1040.pdf), character information can reveal structural (dis)similarities between words and can even be used when a word is out-of-vocabulary, thus improving the modeling of infrequent and unknown words. I would try to concatenate word and character embeddings and feed the resulting character-word embedding to the LSTM in order to build a richer sentence representation.