In [1]:
#This notebook uses the publicly available GPT-2 language model from OpenAI (https://github.com/openai/gpt-2) to
#rescore the candidate sentences generated by the kaldi bigram language model (Step 6). This improves performance because
#GPT-2 is a much more powerful language model than the bigram model. 

#To run this you will need to download the GPT-2 model (1558M version)

In [2]:
import tensorflow as tf
import os

#silence tensorflow's warnings (most of them concern compatibility with v2)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

#top-level dataset directory; the trailing slash is kept because later paths
#are built from it by plain string concatenation
rootDir = '{}/handwritingBCIData/'.format(os.path.expanduser('~'))

#code directory (the current working directory, with a trailing slash)
repoDir = '{}/'.format(os.getcwd())

#which train/test partition to use
cvPart = 'HeldOutTrials'

#directory the 1558M GPT-2 model was downloaded into
gptModelDir = os.path.expanduser(os.path.expandvars('~/gpt-2/models'))

#datasets to process, one entry per line
dataDirs = [
    't5.2019.05.08',
    't5.2019.11.25',
    't5.2019.12.09',
    't5.2019.12.11',
    't5.2019.12.18',
    't5.2019.12.20',
    't5.2020.01.06',
    't5.2020.01.08',
    't5.2020.01.13',
    't5.2020.01.15',
]

#keep tensorflow from taking over more than one gpu on a multi-gpu machine
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": '0'})

In [3]:
#load GPT-2: build the model graph and restore the downloaded weights into a session
import json
import gpt2.model, gpt2.encoder

batch_size = 1
model_name = '1558M'

#everything belonging to this model size lives in one sub-directory
modelPath = os.path.join(gptModelDir, model_name)

#encoder that turns text into the model's token ids
enc = gpt2.encoder.get_encoder(model_name, gptModelDir)

#start from the default hyperparameters, then apply the ones shipped with the model
hparams = gpt2.model.default_hparams()
with open(os.path.join(modelPath, 'hparams.json')) as hparamsFile:
    hparamOverrides = json.load(hparamsFile)
hparams.override_from_dict(hparamOverrides)

sess = tf.Session()

#token ids go in through X; the second dimension is left open for variable-length input
X = tf.placeholder(tf.int32, shape=[batch_size, None])
logits = gpt2.model.model(hparams, X, past=None, scope='model', reuse=False)

#restore the weights from the most recent checkpoint found in the model directory
saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(modelPath)
saver.restore(sess, ckpt)

In [4]:
#Rescore the bigram language model's candidate sentences and choose the new best sentence.
#This takes a little while, as we are processing ALL sentences from all datasets (both the train AND test sentences).
#We also run each candidate sentence one at a time through GPT-2 
#(but could probably speed this up by running multiple through at the same time).

from kaldiReadWrite import readKaldiLatticeFile
import numpy as np
import scipy.io
from rnnEval import wer
import warnings

#this stops scipy.io.savemat from throwing a warning about empty entries
warnings.simplefilter(action='ignore', category=FutureWarning)

#weight applied to the GPT-2 score when it is added to the kaldi acoustic score
GPT2_SCORE_WEIGHT = 2.0

#special 'endoftext' symbol, placed at both the start and the end of every sentence
END_OF_TEXT_TOKEN = 50256

def _logSoftmax(rawLogits):
    """Return the log-softmax of rawLogits along the last axis.

    The row maximum is subtracted before exponentiating, so np.exp can neither
    overflow (which would give nan) nor underflow every entry to zero (which
    would give log(0) = -inf). A plain exp(x)/sum(exp(x)) followed by log has
    both problems.
    """
    shifted = rawLogits - np.max(rawLogits, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

def _negLogLikelihood(rawLogits, tokenIds):
    """Negative log-likelihood of a token sequence under the model.

    rawLogits -- array indexed as [batch, position, token id]; only batch
                 element 0 is used. Position t holds the scores for the token
                 that follows tokenIds[t].
    tokenIds  -- the token ids that were fed to the model to produce rawLogits.

    Returns minus the sum of log P(tokenIds[t+1]) read at position t, for
    every t except the last position (nothing follows the final token).
    """
    #float64 keeps the sum over a long sentence accurate
    logProbs = _logSoftmax(np.asarray(rawLogits[0], dtype=np.float64))
    nextTokens = np.asarray(tokenIds[1:], dtype=np.int64)
    return -np.sum(logProbs[np.arange(len(nextTokens)), nextTokens])

#the partition file does not depend on the dataset, so load it once
cvPartFile = scipy.io.loadmat(rootDir+'RNNTrainingSteps/trainTestPartitions_'+cvPart+'.mat')

#create the output directory up front so a long run cannot fail at the final save
saveDir = rootDir + 'RNNTrainingSteps/Step7_GPT2Rescore/' + cvPart + '/'
os.makedirs(saveDir, exist_ok=True)

for dataDir in dataDirs:
    #process ALL sentences from this dataset (both train & test)
    print(' --' + dataDir + '-- ')

    sentenceDat = scipy.io.loadmat(rootDir+'Datasets/'+dataDir+'/sentences.mat')
    valIdx = cvPartFile[dataDir+'_test']
    
    kaldiDir = rootDir+'RNNTrainingSteps/Step6_ApplyBigramLM/'+cvPart+'/KaldiOutput/'+dataDir 
    #NOTE(review): assumes the kaldi output directory holds exactly 9 files per sentence -- confirm
    nFiles = int(len(os.listdir(kaldiDir))/9)   
    
    decSentences = []
    allErrCounts = []
    
    for fileIdx in range(nFiles):
        #load kaldi bigram output
        nbestFile = kaldiDir+'/'+str(fileIdx)+'_transcript.txt'
        acFile = kaldiDir+'/'+str(fileIdx)+'_best_acscore.ark'
        lmFile = kaldiDir+'/'+str(fileIdx)+'_best_lmscore.ark'

        nums, content = readKaldiLatticeFile(nbestFile, 'string')
        _, acScore = readKaldiLatticeFile(acFile, 'numeric')
        _, lmScore = readKaldiLatticeFile(lmFile, 'numeric')
        
        #rescoring: replace the bigram LM score of each candidate with its GPT-2 score
        lmRescore = np.zeros(lmScore.shape)
        for y in range(len(content)):
            #an empty candidate would break the capitalization below, so use a single space instead
            if content[y]=='':
                content[y] = ' '
                
            #capitalize the first letter before encoding
            encText = enc.encode(content[y][0].upper() + content[y][1:])
            encText.insert(0, END_OF_TEXT_TOKEN)
            encText.append(END_OF_TEXT_TOKEN)

            out = sess.run(logits, feed_dict={X: [encText]})

            lmRescore[y] = _negLogLikelihood(out['logits'], encText)

        #the new best candidate minimizes the combined acoustic + weighted GPT-2 score
        newBest = np.argmin(acScore + GPT2_SCORE_WEIGHT*lmRescore)
        decSent = content[newBest]
        decSentences.append(decSent)

        #compute char/word error rates
        #(the prompt's special symbols are first mapped to plain text: '>' -> ' ', '~' -> '.', '#' removed)
        trueText = sentenceDat['sentencePrompt'][fileIdx,0][0]
        trueText = trueText.replace('>',' ')
        trueText = trueText.replace('~','.')
        trueText = trueText.replace('#','')
        
        #wer on the raw strings counts character errors; on the split word lists it counts word errors
        charErrs = wer(trueText, decSent)
        wordErrs = wer(trueText.split(), decSent.split())
        allErrCounts.append(np.array([charErrs, len(trueText), wordErrs, len(trueText.split())]))

        #print results from the held-out sentences
        if fileIdx in valIdx:
            print('#' + str(fileIdx))
            print('True:    ' + trueText)
            print('Decoded: ' + decSent)
            print('')

    #save error rates & decoded sentences
    #(columns of concatCounts: char errors, char count, word errors, word count)
    concatCounts = np.stack(allErrCounts, axis=0)
    
    saveDict = {}
    saveDict['decSentences'] = decSentences
    saveDict['trueSentences'] = sentenceDat['sentencePrompt']
    saveDict['charCounts'] = concatCounts[:,1]
    saveDict['charErrors'] = concatCounts[:,0]
    saveDict['wordCounts'] = concatCounts[:,3]
    saveDict['wordErrors'] = concatCounts[:,2]

    scipy.io.savemat(saveDir + dataDir + '_errCounts.mat', saveDict)

 --t5.2019.05.08-- 
#2
True:    you want me to sing?
Decoded: you want me to sing?

#5
True:    have you ever seen a large cat fold itself into a tiny shoe box?
Decoded: have you ever seen a large cat fold itself into a tiny shoe box?

#21
True:    the jeep was thirsty so i stopped for gas on the edge of town.
Decoded: the jeep was thirsty so i stopped for gas on the edge of town.

#33
True:    because less than a quarter mile from where i'm standing right now is where alla's body was dumped.
Decoded: because less than a quarter mile from where i'm standing right now is where alla's body was dumped.

#34
True:    there are moments when gentle background singing brings the song close to gospel.
Decoded: there are moments when gentle background singing brings the song close to gospel.

#39
True:    so you could literally see here's what every single person in the company gets paid.
Decoded: so you could literally see here's what every single person in the company gets paid.

#54
True:   

#121
True:    residents were warned to prepare flood defences.
Decoded: residents were warned to prepare flood defences.

 --t5.2020.01.15-- 
#22
True:    imagine a star with a mass ten times that of the sun.
Decoded: imagine a star with a mass ten times that of the sun.

#25
True:    the elder boys are not deterred, however.
Decoded: the elder boys are not deterred, however.

#121
True:    you never used to swear, you know.
Decoded: you never used to swear, you know.



In [5]:
#Summarize the character error rate and the word error rate across all sessions,
#using only the sentences listed in each dataset's '_test' partition
allErrCounts = []

#the partition file does not depend on the dataset, so load it once
cvPartFile = scipy.io.loadmat(rootDir+'RNNTrainingSteps/trainTestPartitions_'+cvPart+'.mat')

for dataDir in dataDirs:
    dat = scipy.io.loadmat(rootDir + 'RNNTrainingSteps/Step7_GPT2Rescore/' + cvPart + '/' + dataDir + '_errCounts.mat')
    valIdx = cvPartFile[dataDir+'_test']
    
    #skip datasets without test sentences (.size also catches an empty array whose len() is not 0)
    if valIdx.size==0:
        continue
        
    valIdx = valIdx[0,:]
    #one row per test sentence: [char count, char errors, word count, word errors]
    allErrCounts.append(np.stack([dat['charCounts'][0,valIdx],
                         dat['charErrors'][0,valIdx],
                         dat['wordCounts'][0,valIdx],
                         dat['wordErrors'][0,valIdx]],axis=0).T)
    
#no squeeze here: with a single test sentence it would drop the row axis and break the column indexing below
concatErrCounts = np.concatenate(allErrCounts, axis=0)

#the results are deliberately not named 'wer': that would overwrite the wer() function
#imported from rnnEval and break any re-run of the rescoring cell
charErrRate = 100*(np.sum(concatErrCounts[:,1]) / np.sum(concatErrCounts[:,0]))
wordErrRate = 100*(np.sum(concatErrCounts[:,3]) / np.sum(concatErrCounts[:,2]))

print('Character error rate: %1.2f%%' % float(charErrRate))
print('Word error rate: %1.2f%%' % float(wordErrRate))    

Character error rate: 0.34%
Word error rate: 1.97%
