In [29]:
import re
import time
import pickle
import numpy as np

from edit_distance import SequenceMatcher
import torch
from dataset import SpeechDataset
import matplotlib.pyplot as plt

from neural_decoder.dataset import getDatasetLoaders
import neural_decoder.lm_utils as lmDecoderUtils
from neural_decoder.model import GRUDecoder
import pickle
import argparse
import matplotlib.pyplot as plt
from neural_decoder.dataset import getDatasetLoaders
import neural_decoder.lm_utils as lmDecoderUtils
from neural_decoder.lm_utils import build_llama_1B
from neural_decoder.model import GRUDecoder
from neural_decoder.bit import BiT_Phoneme
import pickle
import argparse
from lm_utils import _cer_and_wer
import json
import os
import copy

In [30]:
def convert_sentence(s):
    """Normalise a sentence to the decoder's character vocabulary.

    The input is lower-cased, then every character that is not one of the
    26 lowercase ASCII letters, an apostrophe, or a space is dropped.

    Args:
        s: The raw sentence string.

    Returns:
        The lower-cased string containing only the kept characters, in
        their original order.
    """
    allowed_chars = frozenset("abcdefghijklmnopqrstuvwxyz' ")
    return ''.join(ch for ch in s.lower() if ch in allowed_chars)

In [31]:
# ---------------------------------------------------------------------------
# Load the trained GRU decoder and run it over one data partition, collecting
# per-utterance logits, logit lengths, true label sequences and transcripts
# in `model_outputs`.
# ---------------------------------------------------------------------------
data_file = '/data/willett_data/ptDecoder_ctc_both'
device = "cuda:2"

modelPath = "/data/willett_data/outputs/neurips_gru_baseline_seed_1/"

# NOTE(review): pickle.load can execute arbitrary code; only point this at
# args files produced by your own training runs.
with open(os.path.join(modelPath, "args"), "rb") as handle:
    args = pickle.load(handle)

# Older runs may not have stored these hyper-parameters; default them to 0.
for optional_key in ('max_mask_pct', 'num_masks', 'input_dropout'):
    if optional_key not in args:
        args[optional_key] = 0

model = GRUDecoder(
    neural_dim=args["nInputFeatures"],
    n_classes=args["nClasses"],
    hidden_dim=args["nUnits"],
    layer_dim=args["nLayers"],
    nDays=args['nDays'],
    dropout=args["dropout"],
    device=device,
    strideLen=args["strideLen"],
    kernelLen=args["kernelLen"],
    gaussianSmoothWidth=args["gaussianSmoothWidth"],
    bidirectional=args["bidirectional"],
    input_dropout=args['input_dropout'],
    max_mask_pct=args['max_mask_pct'],
    num_masks=args['num_masks'],
).to(device)

# Without this step the model would be evaluated with randomly initialised
# parameters.
# NOTE(review): assumes training saved the state dict as "modelWeights" next
# to "args" in the run directory -- confirm against the training script.
modelWeightPath = os.path.join(modelPath, "modelWeights")
model.load_state_dict(torch.load(modelWeightPath, map_location=device))

# Inference mode: disables dropout (and any other train-only behaviour).
model.eval()

model_outputs = {
    "logits": [],
    "logitLengths": [],
    "trueSeqs": [],
    "transcriptions": [],
}

trainLoaders, testLoaders, loadedData = getDatasetLoaders(data_file, 8)
partition = "competition"  # "test"
if partition == "competition":
    testDayIdxs = [4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 20]
elif partition == "test":
    testDayIdxs = range(len(loadedData[partition]))
else:
    raise ValueError(
        f"Unknown partition {partition!r}; expected 'competition' or 'test'"
    )

# No gradients are needed for evaluation; skipping the autograd graph saves
# memory and time.
with torch.no_grad():
    # `i` indexes the entries of loadedData[partition]; `testDayIdx` is the
    # day index passed to the model for that entry.
    for i, testDayIdx in enumerate(testDayIdxs):
        test_ds = SpeechDataset([loadedData[partition][i]])
        # batch_size=1 with shuffle=False means the batch counter `j` is also
        # the sentence index within this day (used for the transcript below).
        test_loader = torch.utils.data.DataLoader(
            test_ds, batch_size=1, shuffle=False, num_workers=0
        )
        for j, (X, y, X_len, y_len, _) in enumerate(test_loader):
            X, y, X_len, y_len, dayIdx = (
                X.to(device),
                y.to(device),
                X_len.to(device),
                y_len.to(device),
                torch.tensor([testDayIdx], dtype=torch.int64).to(device),
            )
            pred = model(X, X_len, dayIdx)
            # Output length after the model's kernelLen/strideLen windowing.
            adjustedLens = ((X_len - model.kernelLen) / model.strideLen).to(torch.int32)

            for iterIdx in range(pred.shape[0]):
                # Trim the label sequence to its true length.
                trueSeq = np.array(y[iterIdx][0 : y_len[iterIdx]].cpu().detach())

                model_outputs["logits"].append(pred[iterIdx].cpu().detach().numpy())
                model_outputs["logitLengths"].append(
                    adjustedLens[iterIdx].cpu().detach().item()
                )
                model_outputs["trueSeqs"].append(trueSeq)

            # Keep only letters, hyphens, spaces and apostrophes, then drop
            # double hyphens and lower-case.
            transcript = loadedData[partition][i]["transcriptions"][j].strip()
            transcript = re.sub(r"[^a-zA-Z\- \']", "", transcript)
            transcript = transcript.replace("--", "").lower()
            model_outputs["transcriptions"].append(transcript)


In [32]:
# Build the n-gram language-model decoder used for beam-search decoding.

# Decoder settings.
nbest = 10
acoustic_scale = 0.8  # the original cell carried a "#1.2" note next to this argument

# Directory holding the language-model files.
base_dir = "/home3/skaasyap/willett"
lmDir = base_dir +'/lm/languageModel'

ngramDecoder = lmDecoderUtils.build_lm_decoder(
    lmDir, acoustic_scale=acoustic_scale, nbest=nbest, beam=18
)
print("loaded LM")


loaded LM


I0628 11:38:01.393167 3990267 brain_speech_decoder.h:52] Reading fst /home3/skaasyap/willett/lm/languageModel/TLG.fst
I0628 11:40:34.564733 3990267 brain_speech_decoder.h:81] Reading symbol table /home3/skaasyap/willett/lm/languageModel/words.txt


In [33]:
# Generate n-best hypotheses from the 5-gram LM for every utterance in
# `model_outputs["logits"]`, and report the mean decoding time per sample.
blank_penalty = np.log(2)
llm_outputs = []
nbest_outputs = []

n_utterances = len(model_outputs["logits"])
start_t = time.time()
for j in range(n_utterances):
    logits = model_outputs["logits"][j]
    logits = np.concatenate(
        [logits[:, 1:], logits[:, 0:1]], axis=-1
    )  # Blank is last token
    # Add a leading batch axis for rearrange_speech_logits; index [0] below
    # removes it again.
    logits = lmDecoderUtils.rearrange_speech_logits(logits[None, :, :], has_sil=True)
    # Stored under its own name so the integer `nbest` setting used when the
    # decoder was built is not overwritten by a list of hypotheses.
    nbest_hyps = lmDecoderUtils.lm_decode(
        ngramDecoder,
        logits[0],
        blankPenalty=blank_penalty,
        returnNBest=True,
        rescore=True,
    )
    nbest_outputs.append(nbest_hyps)

# Guard against an empty logits list so the timing report cannot divide by zero.
time_per_sample = (time.time() - start_t) / max(n_utterances, 1)
print(f"5gram decoding took {time_per_sample} seconds per sample")


: 

: 