# ASR Punctuation Restoration Experiments
By Bart Pleiter S4752740 for the course ASR 2021-2022

In [None]:
# imports
import pandas as pd
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow.keras import mixed_precision
import tensorflow_addons as tfa

from transformers import BertTokenizer
from transformers import TFBertForMaskedLM

from tqdm import tqdm

import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

In [None]:
# use the 'mixed_float16' Keras policy to speed up training on my RTX3050
mixed_precision.set_global_policy('mixed_float16')

# upper bound on the GPU memory handed to TensorFlow (7000 ~ 7.0GB)
memoryLimit = 7000

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # only the first detected GPU is configured, as a single virtual device
    limitedDevice = tf.config.experimental.VirtualDeviceConfiguration(memory_limit=memoryLimit)
    try:
        tf.config.experimental.set_virtual_device_configuration(gpus[0], [limitedDevice])
    except RuntimeError as e:
        # report the problem and carry on with the default device configuration
        # NOTE(review): presumably raised when the GPU was already initialised - confirm
        print(e)

In [None]:
# hyperparameters
batchSize = 64
numLabels = 4
segmentSize = 32 # MUST be an even number
dropout = 0.3
learningRate = 1e-5
epochs = 3
vocabSize = 30522 # of the BERT model

# class names; the position of a name in this list is its numeric label
labelNames = ["O", "COMMA", "PERIOD", "QUESTION"]

# labels (0..3, in the same order as labelNames)
LABEL_NOTHING, LABEL_COMMA, LABEL_PERIOD, LABEL_QUESTION = range(len(labelNames))

# encode the punctuation label as a number
punctEncode = {name: index for index, name in enumerate(labelNames)}

# decode the label for printing purposes
punctDecode = dict(enumerate(labelNames))

# tokenizer that was used to train the pre-trained transformer model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [None]:
def insertTarget(x, segment_size):
    """Build one fixed-width context segment for every token in ``x``.

    Segment ``i`` consists of the ``(segment_size-1)//2 - 1`` tokens preceding
    token ``i``, token ``i`` itself, a zero token ([PAD]) and finally the
    ``segment_size//2`` tokens following token ``i``. Neighbours that fall
    outside ``x`` wrap around, as if the text loops.

    Args:
        x: sequence of token IDs.
        segment_size: length of every produced segment; must be at least 3 so
            that the target token and the [PAD] marker both fit.

    Returns:
        Array of shape ``(len(x), segment_size)``. For an empty ``x`` an empty
        1-D array is returned.

    Raises:
        ValueError: if ``segment_size`` is smaller than 3.
    """
    if segment_size < 3:
        raise ValueError("segment_size must be at least 3, got %r" % (segment_size,))

    n = len(x)
    if n == 0:
        return np.array([])

    left = (segment_size - 1) // 2 - 1  # context tokens in front of the target
    right = segment_size // 2           # context tokens behind the [PAD] marker

    tokens = np.asarray(x)
    positions = np.arange(n)[:, None]

    # Modular indexing implements the wrap-around. Unlike slicing with
    # x[-left:], it stays correct when left == 0 and when x is shorter than
    # the context window, so exactly one segment per token is always produced.
    before = tokens[(positions + np.arange(-left, 1)) % n]      # ..., target
    after = tokens[(positions + np.arange(1, right + 1)) % n]
    pad = np.zeros((n, 1), dtype=tokens.dtype)

    return np.concatenate([before, pad, after], axis=1)

def encodeRawText(text):
    """Split raw text into lowercase words and per-word punctuation labels.

    A word ending in '.', ',' or '?' loses that character and gets the
    matching label; every other word is labelled LABEL_NOTHING. Only the very
    last character of a word is inspected.

    Args:
        text: the raw, punctuated text.

    Returns:
        A tuple ``(X, Y)`` of two equally long lists: the lowercased words
        without their trailing punctuation mark, and their numeric labels.
    """
    trailingLabels = {
        '.': LABEL_PERIOD,
        ',': LABEL_COMMA,
        '?': LABEL_QUESTION,
    }

    X = []
    Y = []

    # split() without an argument splits on any run of whitespace (spaces,
    # newlines, tabs) and never yields empty strings, so words on different
    # lines are no longer glued together and punctuation in front of a
    # newline is still found at the end of its word.
    for word in text.split():
        label = trailingLabels.get(word[-1])
        if label is None:
            X.append(word.lower())
            Y.append(LABEL_NOTHING)
        else:
            # drop the punctuation mark from the word itself
            X.append(word.lower()[:-1])
            Y.append(label)

    return X, Y

def prepareDataForModel(words, labels, tokenizer):
    """Turn words and their labels into model-ready arrays.

    Every word is split into wordpieces and converted to token IDs. When a
    word yields several pieces, only the last piece carries the word's label;
    the pieces before it get label 0. Words for which tokenization produces
    nothing are skipped together with their label.

    Returns:
        A tuple of
        - the token-ID segments (one per token) as a Numpy array,
        - the label of each token as a Numpy array,
        - the plain, unsegmented list of token IDs, for easier reconstruction.
    """
    tokenIds = []
    tokenLabels = []

    for word, label in zip(words, labels):
        pieces = tokenizer.wordpiece_tokenizer.tokenize(word)
        ids = tokenizer.convert_tokens_to_ids(pieces)

        # do not add anything if tokenize failed
        if len(ids) == 0:
            continue

        tokenIds.extend(ids)
        # leading pieces get 0, the last piece gets the real label
        tokenLabels.extend([0] * (len(ids) - 1))
        tokenLabels.append(label)

    segments = np.array(insertTarget(tokenIds, segmentSize))
    return segments, np.array(tokenLabels), tokenIds

def getTokenFromSegment(segment):
    """Return the target token of a segment built by insertTarget.

    insertTarget places the [PAD] marker at index ``(segmentSize-1)//2`` and
    the target token directly in front of it. Using that same formula here
    keeps both functions in agreement for odd segment sizes as well; for even
    sizes it equals the previous ``segmentSize//2 - 2``.
    """
    return segment[(segmentSize - 1) // 2 - 1]

def reconstructText(tokenList, labels, tokenizer):
    """Rebuild readable text from token IDs and per-token punctuation labels.

    Args:
        tokenList: token IDs, converted back with ``tokenizer.convert_ids_to_tokens``.
        labels: one numeric label per token; LABEL_COMMA, LABEL_PERIOD and
            LABEL_QUESTION append ',', '.' and '?' after the token, any other
            value appends nothing.
        tokenizer: the tokenizer that produced the IDs.

    Returns:
        The detokenized text with the punctuation applied.
    """
    tokens = tokenizer.convert_ids_to_tokens(tokenList)

    pieces = []
    for tok, label in zip(tokens, labels):
        if tok.startswith("##"):
            # continuation piece: glue it to the previous piece without a space
            pieces.append(tok[2:])
        else:
            pieces.append(" " + tok)

        # add the punctuation from the label
        if label == LABEL_COMMA:
            pieces.append(",")
        elif label == LABEL_PERIOD:
            pieces.append(".")
        elif label == LABEL_QUESTION:
            pieces.append("?")

    reconstructedText = "".join(pieces)

    # Drop the separator in front of the first word. Only do so when it is
    # really there: if the token list starts with a '##' continuation piece,
    # no space was added and the first character belongs to the text.
    if reconstructedText.startswith(" "):
        reconstructedText = reconstructedText[1:]

    return reconstructedText
    
def loadDataFromFile(path):
    """Create a dataset for the model from a plain text file.

    The file content is passed through encodeRawText and prepareDataForModel,
    using the module-level ``tokenizer``.

    Args:
        path: path of the text file to read.

    Returns:
        A tuple of
        - the token-ID segments for X,
        - the label of each token for y, corresponding to the segments in X,
        - the unsegmented token IDs, for easier reconstruction.
    """
    # Read as UTF-8 explicitly (the TED talk loader decodes UTF-8 as well) so
    # the result does not depend on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as file:
        rawText = file.read()

    words, labels = encodeRawText(rawText)
    dataX, dataY, dataTokens = prepareDataForModel(words, labels, tokenizer)

    return dataX, dataY, dataTokens

def loadDataFromTEDtalkDataset(path, tokenizer):
    """Load the pre-processed TED talk dataset.

    Every non-empty line of the file holds a word and a punctuation label
    name (a key of ``punctEncode``), separated by a tab.

    Args:
        path: path of the dataset file.
        tokenizer: tokenizer used to split the words into token IDs.

    Returns:
        A tuple of
        - the token-ID segments for X as a Numpy array,
        - the label of each token for y as a Numpy array,
        - the unsegmented token IDs, for easier reconstruction.

    Raises:
        KeyError: if a line carries a label name that is not in ``punctEncode``.
        ValueError: if a non-empty line does not consist of exactly two
            tab-separated fields.
    """
    words = []
    labels = []
    with open(path, "rb") as file:
        for rawLine in file:
            # the dataset uses \r\n for newlines; stripping both characters
            # also makes files with plain \n line endings load correctly
            line = rawLine.decode('utf-8', errors='ignore').rstrip('\r\n')
            if not line:
                # tolerate blank lines (e.g. at the end of the file)
                continue
            word, punc = line.split('\t')
            words.append(word)
            labels.append(punctEncode[punc])

    # tokenization, label expansion and segment creation are shared with the
    # raw-text pipeline
    return prepareDataForModel(words, labels, tokenizer)

def predictText(text, model, tokenizer):
    """Predict the punctuation of ``text`` with ``model`` and print the result.

    The text is first stripped of its punctuation and segmented; the
    punctuation that was found in the input is not used for the prediction.
    """
    # pre-process
    words, textLabels = encodeRawText(text)
    segments, _, tokenIds = prepareDataForModel(words, textLabels, tokenizer)

    # one row of class scores per token; keep the best class of each row
    classScores = model.predict(segments)
    predictedLabels = classScores.argmax(axis=1)

    # reconstruct the text with the predicted labels and show it
    print(reconstructText(tokenIds, predictedLabels, tokenizer))

In [None]:
def evaluateResults(Ytrue, Ypred):
    """Print a classification report and plot the confusion matrix.

    Args:
        Ytrue: the reference label of every token.
        Ypred: the predicted label of every token.
    """
    # Pass the full label set explicitly. Without it scikit-learn derives the
    # classes from the data, and a class that is absent from both Ytrue and
    # Ypred makes the number of classes disagree with labelNames.
    classLabels = list(range(len(labelNames)))

    print(classification_report(Ytrue, Ypred, labels=classLabels, target_names=labelNames))

    confusionMatrix = confusion_matrix(Ytrue, Ypred, labels=classLabels)
    disp = ConfusionMatrixDisplay(confusion_matrix=confusionMatrix, display_labels=labelNames)
    print("Confusion matrix:")
    disp.plot()
    plt.show()

In [None]:
# construct the network
# the input is one segment of token IDs; shape has to be a tuple, and
# (segmentSize) without the trailing comma is just an int
bert_input = tf.keras.Input(shape=(segmentSize,), dtype='int32', name='bert_input')
# [0] selects the first output of the masked-LM model
# NOTE(review): presumably the per-token vocabulary logits, given the reshape below - confirm
x = TFBertForMaskedLM.from_pretrained("bert-base-uncased")(bert_input)[0]
print(f'dtype BERT layer: {x.dtype.name}')
# flatten the BERT output of a segment into a single vector
x = tf.keras.layers.Reshape((segmentSize*vocabSize,))(x)
print(f'dtype reshape layer: {x.dtype.name}')
x = tf.keras.layers.Dropout(dropout, name="dropout")(x)
print(f'dtype dropout layer: {x.dtype.name}')
# one output unit per punctuation class
x = tf.keras.layers.Dense(len(punctEncode), name='dense')(x)
print(f'dtype dense layer: {x.dtype.name}')
# dtype float32 because of mixed precision
dense_out = tf.keras.layers.Activation('softmax', dtype='float32', name='softmax')(x)
print(f'dtype softmax layer: {dense_out.dtype.name}')

model = tf.keras.Model(bert_input, dense_out, name='model')

In [None]:
# plot the model architecture, including the shapes of the layers, at a low dpi
tf.keras.utils.plot_model(model, show_shapes=True, dpi=48)

In [None]:
# loading the data for training
# Xtrain holds the token-ID segments, Ytrain the label of each token; the
# unsegmented token list is not needed for training and is discarded
Xtrain, Ytrain, _ = loadDataFromTEDtalkDataset("IWSLT2012data/train2012", tokenizer)

In [None]:
# loading the data for validation; valTokens keeps the unsegmented token IDs
Xval, Yval, valTokens = loadDataFromTEDtalkDataset("IWSLT2012data/dev2012", tokenizer)

In [None]:
# compile the network
optimizer = tf.keras.optimizers.Adam(learning_rate=learningRate)
# the labels are plain class indices; from_logits is False because the loss is
# fed the model output as-is, not raw logits
lossFunction = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
accuracyMetric = tf.keras.metrics.SparseCategoricalAccuracy()

model.compile(optimizer=optimizer, loss=lossFunction, metrics=[accuracyMetric])

In [None]:
# train the model on the training data; the dev set is passed as validation data
# the returned History object is kept for the loss plot
history = model.fit(
    Xtrain,
    Ytrain,
    epochs=epochs,
    batch_size = batchSize,
    validation_data=(Xval, Yval)
)

In [None]:
# show the fitting history
# (printed explicitly: a bare expression that is not the last statement of a
# cell is evaluated but never displayed)
print(history.history)

# plot the loss on the training and on the validation data
for curve in ('loss', 'val_loss'):
    plt.plot(history.history[curve])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

# Validation set

In [None]:
# generate output on validation data: one row of class scores per token
valPredict = model.predict(Xval)

# the index of the highest score in a row is the predicted class of that token
resultYval = np.argmax(valPredict, axis=1)

In [None]:
# classification report and confusion matrix for the validation predictions
evaluateResults(Yval, resultYval)

# Test set

In [None]:
# loading the data for testing
# NOTE(review): test2011 / test2011asr are presumably the reference transcripts
# and the ASR output of the same test set - confirm against the dataset
Xtest, Ytest, testTokens = loadDataFromTEDtalkDataset("IWSLT2012data/test2011", tokenizer)
Xtestasr, Ytestasr, testasrTokens = loadDataFromTEDtalkDataset("IWSLT2012data/test2011asr", tokenizer)

In [None]:
# generate output on test data and select the best class for each token
testPredict = model.predict(Xtest)
resultYtest = np.argmax(testPredict, axis=1)

# same for the ASR variant of the test data
testPredictasr = model.predict(Xtestasr)
resultYtestasr = np.argmax(testPredictasr, axis=1)

In [None]:
# report on both test sets, each under its own heading
for heading, referenceLabels, predictedLabels in (
    ("Test set:", Ytest, resultYtest),
    ("Test ASR set:", Ytestasr, resultYtestasr),
):
    print(heading)
    evaluateResults(referenceLabels, predictedLabels)

In [None]:
# evaluate on test data
#resultsTest = model.evaluate(Xtest, Ytest, batch_size=batchSize)
#print("test loss, test acc:", resultsTest)

#resultsTestasr = model.evaluate(Xtestasr, Ytestasr, batch_size=batchSize)
#print("test (asr) loss, test (asr) acc:", resultsTestasr)

In [None]:
# show the restoration on the test sets
# get results
#restoreTest = model.predict(Xtest)
#restoreTestasr = model.predict(Xtestasr)

# select best class for each token
#resultYtest = restoreTest.argmax(axis=1)
#resultYtestasr = restoreTestasr.argmax(axis=1)

# reconstruct text with predicted tokens
#resTextTest = reconstructText(testTokens, resultYtest, tokenizer)
#resTextTestasr = reconstructText(testasrTokens, resultYtestasr, tokenizer)

In [None]:
#print(resTextTest)

In [None]:
#print(resTextTestasr)

In [None]:
# example restoration on paragraph from a story
#toPredict = "Then he sat down and began to reflect. In the morning he must find seconds. Whom should he choose? He searched his mind for the most important and celebrated names of his acquaintance. At last he decided on the Marquis de la Tour-Noire and Colonel Bourdin, an aristocrat and a soldier; they would do excellently. Their names would look well in the papers. He realised that he was thirsty, and drank three glasses of water one after the other; then he began to walk up and down again. He felt full of energy. If he played the gallant, showed himself determined, insisted on the most strict and dangerous arrangements, demanded a serious duel, a thoroughly serious duel, a positively terrible duel, his adversary would probably retire an apologist."
#predictText(toPredict, model, tokenizer)