In [5]:
import re
import time
import pickle
import numpy as np

from edit_distance import SequenceMatcher
import torch
from dataset import SpeechDataset
import matplotlib.pyplot as plt

from neural_decoder.dataset import getDatasetLoaders
import neural_decoder.lm_utils as lmDecoderUtils
from neural_decoder.model import GRUDecoder
import pickle
import argparse
import matplotlib.pyplot as plt
from neural_decoder.dataset import getDatasetLoaders
import neural_decoder.lm_utils as lmDecoderUtils
from neural_decoder.lm_utils import build_llama_1B
from neural_decoder.model import GRUDecoder
from neural_decoder.bit import BiT_Phoneme
import pickle
import argparse
from lm_utils import _cer_and_wer
import json
import os
import copy

In [6]:
def convert_sentence(s):
    """Normalise a sentence for scoring.

    The input is lower-cased and every character that is not a lowercase
    ASCII letter, an apostrophe, or a space is dropped.

    Args:
        s: The sentence to normalise.

    Returns:
        The filtered, lower-cased sentence as a new string.
    """
    allowed_chars = "abcdefghijklmnopqrstuvwxyz' "
    return ''.join(ch for ch in s.lower() if ch in allowed_chars)

In [7]:
# Root directory under which the language model directory is looked up.
base_dir = "/home3/skaasyap/willett"

# Switches: build the n-gram decoder at all, and whether its output is meant
# for a downstream LLM stage (n-best lists) or used directly (1-best).
load_lm = True
run_for_llm = False

# LM decoding hyperparameters
acoustic_scale = 0.8
blank_penalty = np.log(2)

# Rescoring is disabled in both modes; only the n-best settings differ.
rescore = False
return_n_best = run_for_llm
nbest = 100 if run_for_llm else 1

print("RUNNING IN LLM MODE" if run_for_llm else "RUNNING IN N-GRAM MODE")

RUNNING IN N-GRAM MODE


In [8]:
if load_lm:
    # Build the n-gram decoder once; the evaluation cell reuses `ngramDecoder`.
    lmDir = f"{base_dir}/lm/languageModel"

    # The original cell carried the inline note "1.2" next to acoustic_scale.
    decoder_options = dict(
        acoustic_scale=acoustic_scale,
        nbest=nbest,
        beam=18,
    )
    ngramDecoder = lmDecoderUtils.build_lm_decoder(lmDir, **decoder_options)

    print("loaded LM")
    

# Get model predictions
## Always check that the validation performance printed below matches the value logged in wandb

In [37]:
# ---------------------------------------------------------------------------
# Run configuration for the evaluation loop in this cell.
# ---------------------------------------------------------------------------

# Storage-host tag selecting where the trained models live ('obi' or 'leia').
# NOTE: the evaluation loop later rebinds `output_file` to the transcript
# file stem, so this cell must be run from the top to re-select the host.
output_file = 'obi'

device = "cuda:2"

if output_file == 'obi':
    model_storage_path = '/data/willett_data/outputs/'
elif output_file == 'leia':
    model_storage_path = '/data/willett_data/leia_outputs/'
else:
    # Fail loudly instead of leaving `model_storage_path` undefined (or stale
    # from an earlier run of this cell in the same kernel).
    raise ValueError(
        f"Unknown storage host {output_file!r}; expected 'obi' or 'leia'"
    )

# Model name prefixes; "_seed_<seed>" is appended for every entry of seeds_list.
models_to_run = ['tf_no_time_mask_blue']

# When non-empty, all models of a seed write their transcripts to
# "<shared_output_file>_seed_<seed>.txt", so the file is opened in append mode.
shared_output_file = 'transformer_held_out_days_memo'

if len(shared_output_file) > 0:
    print("Writing to shared output file")
    write_mode = "a"
else:
    write_mode = "w"

seeds_list = [0,1,2,3]
partition = "competition" # "test"
run_lm = True

if partition not in ("test", "competition"):
    # The evaluation loop only knows how to pick day indices for these two.
    raise ValueError(
        f"Unknown partition {partition!r}; expected 'test' or 'competition'"
    )

comp_on_reduced = True   # competition partition: evaluate day indices 0..4 only
fill_max_day = False     # overwrite the day index with args['maxDay'] at inference
memo = True              # enable the per-sample MEMO adaptation step
memo_epochs = 1          # adaptation steps per sample
memo_augs = 16           # value passed as the fourth argument of model.forward
memo_lr = [6e-5, 6e-5, 6e-5]   # AdamW learning rate, indexed by model position
skip_days = [[], [], []]       # day indices to skip, indexed by model position

# Both per-model lists are indexed with the model's position in models_to_run;
# catch a too-short list here rather than with an IndexError mid-run.
if len(skip_days) < len(models_to_run):
    raise ValueError("skip_days needs one entry per model in models_to_run")
if memo and len(memo_lr) < len(models_to_run):
    raise ValueError("memo_lr needs one entry per model in models_to_run")

# Per-day accumulators for the "CER DAY" printout.
day_edit_distance = 0
day_seq_length = 0
prev_day = None


if partition == 'test':
    saveFolder_transcripts = "/data/willett_data/model_transcriptions/"
else:
    saveFolder_transcripts = "/data/willett_data/model_transcriptions_comp/"
    
# Evaluate every (seed, model) pair: load the checkpoint, optionally adapt it
# per sample with MEMO, compute the greedy phoneme error rate per day and
# overall, then decode with the n-gram LM and write the transcripts.
for seed in seeds_list:

    for mn, model_name_str in enumerate(models_to_run):

        modelPath = f"{model_storage_path}{model_name_str}_seed_{seed}"

        # Stem of the transcript file written at the end of this iteration.
        if len(shared_output_file) > 0:
            output_file = f"{shared_output_file}_seed_{seed}"
            print(output_file)
        else:
            output_file = f"{model_name_str}_seed_{seed}"

        print(f"Running model: {model_name_str}_seed_{seed}")

        # SECURITY NOTE: pickle.load can execute arbitrary code; only point
        # this at checkpoints you produced yourself.
        with open(modelPath + "/args", "rb") as handle:
            args = pickle.load(handle)

        # Map the dataset name recorded at training time to its location on
        # this machine; unknown names fall back to the recorded path.
        dataset_name = args['datasetPath'].rsplit('/', 1)[-1]
        local_dataset_paths = {
            'data_log_both': '/data/willett_data/ptDecoder_ctc_both',
            'data': '/data/willett_data/ptDecoder_ctc',
            'data_log_both_held_out_days': '/data/willett_data/ptDecoder_ctc_both_held_out_days',
            'data_log_both_held_out_days_1': '/data/willett_data/ptDecoder_ctc_both_held_out_days_1',
            'data_log_both_held_out_days_2': '/data/willett_data/ptDecoder_ctc_both_held_out_days_2',
        }
        data_file = local_dataset_paths.get(dataset_name, args['datasetPath'])

        # Second argument is presumably the batch size -- TODO confirm.
        trainLoaders, testLoaders, loadedData = getDatasetLoaders(
            data_file, 8
        )

        # if true, model is a GRU
        if 'nInputFeatures' in args.keys():

            # Older checkpoints may lack these keys; default them to 0.
            if 'max_mask_pct' not in args:
                args['max_mask_pct'] = 0
            if 'num_masks' not in args:
                args['num_masks'] = 0
            if 'input_dropout' not in args:
                args['input_dropout'] = 0

            print("Loading GRU")
            model = GRUDecoder(
                neural_dim=args["nInputFeatures"],
                n_classes=args["nClasses"],
                hidden_dim=args["nUnits"],
                layer_dim=args["nLayers"],
                nDays=args['nDays'],
                dropout=args["dropout"],
                device=device,
                strideLen=args["strideLen"],
                kernelLen=args["kernelLen"],
                gaussianSmoothWidth=args["gaussianSmoothWidth"],
                bidirectional=args["bidirectional"],
                input_dropout=args['input_dropout'],
                max_mask_pct=args['max_mask_pct'],
                num_masks=args['num_masks']
            ).to(device)

        else:

            if 'mask_token_zero' not in args:
                args['mask_token_zero'] = False

            print("Loading TRANSFORMER")

            # Instantiate model.
            # The masking/dropout arguments below are set for MEMO; they do
            # not matter for other runs because those only use eval mode.
            model = BiT_Phoneme(
                patch_size=args['patch_size'],
                dim=args['dim'],
                dim_head=args['dim_head'],
                nClasses=args['nClasses'],
                depth=args['depth'],
                heads=args['heads'],
                mlp_dim_ratio=args['mlp_dim_ratio'],
                dropout=0,
                input_dropout=0,
                look_ahead=args['look_ahead'],
                gaussianSmoothWidth=args['gaussianSmoothWidth'],
                T5_style_pos=args['T5_style_pos'],
                max_mask_pct=0.05,
                num_masks=20,
                mask_token_zeros=args['mask_token_zero'],
                num_masks_channels=0,
                max_mask_channels=0,
                dist_dict_path=0,
            ).to(device)

        # SECURITY NOTE: torch.load unpickles the file; trusted checkpoints only.
        ckpt_path = modelPath + '/modelWeights'
        model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=True)
        model = model.to(device)

        if memo:
            optimizer = torch.optim.AdamW(model.parameters(), lr=memo_lr[mn], weight_decay=0,
                                    betas=(args['beta1'], args['beta2']))
            # NOTE(review): the pristine weights are captured here but never
            # restored in this cell, so adaptation accumulates across samples.
            original_state_dict = copy.deepcopy(model.state_dict())

            for p in model.parameters():
                p.requires_grad = False

            # Unfreeze patch-embedding linear projection (assumed third module).
            # NOTE(review): `to_patch_embedding` is only known to exist on the
            # transformer here; MEMO with a GRU checkpoint may fail -- confirm.
            for p in model.to_patch_embedding[2].parameters():
                p.requires_grad = True

        model.eval()

        model_outputs = {
            "logits": [],
            "logitLengths": [],
            "trueSeqs": [],
            "transcriptions": [],
        }

        total_edit_distance = 0
        total_seq_length = 0

        if partition == "competition":

            if comp_on_reduced:
                testDayIdxs = np.arange(5)
            else:
                testDayIdxs = [4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 20]

        elif partition == "test":

            testDayIdxs = range(len(loadedData[partition]))

        else:
            # Previously this left testDayIdxs undefined (or stale).
            raise ValueError(
                f"Unknown partition {partition!r}; expected 'test' or 'competition'"
            )

        ground_truth_sentences = []

        # Checkpoints saved without this key are treated as unrestricted,
        # matching how the other optional args are defaulted above.
        restricted_days = args.get('restricted_days', [])
        print("RESTRICTED DAYS: ", restricted_days)

        # `i` indexes loadedData[partition]; `testDayIdx` is the day index
        # handed to the model.
        for i, testDayIdx in enumerate(testDayIdxs):

            if len(skip_days[mn]) > 0:
                if testDayIdx in skip_days[mn]:
                    continue

            if len(restricted_days) > 0:
                if testDayIdx not in restricted_days:
                    continue

            # Start every day from clean accumulators so an interrupted
            # earlier run cannot leak into this day's CER.
            day_edit_distance = 0
            day_seq_length = 0

            test_ds = SpeechDataset([loadedData[partition][i]])
            test_loader = torch.utils.data.DataLoader(
                test_ds, batch_size=1, shuffle=False, num_workers=0
            )

            for j, (X, y, X_len, y_len, _) in enumerate(test_loader):

                X, y, X_len, y_len, dayIdx = (
                    X.to(device),
                    y.to(device),
                    X_len.to(device),
                    y_len.to(device),
                    torch.tensor([testDayIdx], dtype=torch.int64).to(device),
                )

                if fill_max_day:
                    dayIdx.fill_(args['maxDay'])

                if memo:

                    # Train mode for the adaptation step; eval is restored
                    # below before the actual prediction.
                    model.train()

                    for _ in range(memo_epochs):

                        # NOTE(review): this call passes the raw `testDayIdx`
                        # rather than the `dayIdx` tensor used for inference
                        # (and so ignores fill_max_day) -- confirm intended.
                        logits_aug = model.forward(X, X_len, testDayIdx, memo_augs)  # [memo_augs, T, D]
                        probs_aug = torch.nn.functional.softmax(logits_aug, dim=-1)          # [memo_augs, T, D]
                        marginal_probs = probs_aug.mean(dim=0)                               # [T, D]

                        adjustedLens = model.compute_length(X_len)
                        marginal_probs = marginal_probs[:adjustedLens]

                        # Entropy of the averaged distribution. The log input
                        # is clamped so a probability that underflows to 0
                        # contributes 0 rather than NaN (0 * -inf), which would
                        # otherwise poison the weights through the optimizer.
                        loss = - (marginal_probs * marginal_probs.clamp_min(1e-12).log()).sum(dim=-1).mean()

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    model.eval()

                with torch.no_grad():

                    pred = model.forward(X, X_len, dayIdx)

                # Number of valid output frames per sample.
                if hasattr(model, 'compute_length'):
                    adjustedLens = model.compute_length(X_len)
                else:
                    adjustedLens = ((X_len - model.kernelLen) / model.strideLen).to(torch.int32)

                for iterIdx in range(pred.shape[0]):

                    trueSeq = np.array(y[iterIdx][0 : y_len[iterIdx]].cpu().detach())
                    model_outputs["logits"].append(pred[iterIdx].cpu().detach().numpy())

                    model_outputs["logitLengths"].append(
                        adjustedLens[iterIdx].cpu().detach().item()
                    )

                    model_outputs["trueSeqs"].append(trueSeq)

                    # Greedy decode: argmax per frame, collapse repeats, drop
                    # class 0. `pred` is already a tensor, so it is indexed
                    # directly; wrapping it in torch.tensor() only made a copy
                    # and triggered a UserWarning.
                    decodedSeq = torch.argmax(
                        pred[iterIdx, 0 : adjustedLens[iterIdx], :],
                        dim=-1,
                    )

                    decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)
                    decodedSeq = decodedSeq.cpu().detach().numpy()
                    decodedSeq = np.array([tok for tok in decodedSeq if tok != 0])

                    matcher = SequenceMatcher(
                        a=trueSeq.tolist(), b=decodedSeq.tolist()
                    )

                    total_edit_distance += matcher.distance()
                    total_seq_length += len(trueSeq)

                    day_edit_distance += matcher.distance()
                    day_seq_length += len(trueSeq)

                transcript = loadedData[partition][i]["transcriptions"][j].strip()
                transcript = re.sub(r"[^a-zA-Z\- \']", "", transcript)
                transcript = transcript.replace("--", "").lower()
                model_outputs["transcriptions"].append(transcript)

            # A day without any label tokens cannot be scored; report NaN
            # instead of aborting the whole run with ZeroDivisionError.
            cer_day = day_edit_distance / day_seq_length if day_seq_length > 0 else float("nan")
            print("CER DAY: ", cer_day)
            day_edit_distance = 0
            day_seq_length = 0

        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else float("nan")

        print("Model performance: ", cer)

        if run_lm:

            print("Running n-gram LM")

            llm_outputs = []
            start_t = time.time()
            nbest_outputs = []

            for j in range(len(model_outputs["logits"])):

                logits = model_outputs["logits"][j]

                logits = np.concatenate(
                    [logits[:, 1:], logits[:, 0:1]], axis=-1
                )  # Blank is last token

                logits = lmDecoderUtils.rearrange_speech_logits(logits[None, :, :], has_sil=True)

                # Kept in its own name: assigning the result to `nbest` would
                # overwrite the integer `nbest` setting that the decoder-building
                # cell reads, breaking a re-run of that cell.
                decoded = lmDecoderUtils.lm_decode(
                    ngramDecoder,
                    logits[0],
                    blankPenalty=blank_penalty,
                    returnNBest=return_n_best,
                    rescore=rescore,
                )

                nbest_outputs.append(decoded)

            # max(..., 1) avoids ZeroDivisionError when every day was skipped.
            time_per_sample = (time.time() - start_t) / max(len(model_outputs["logits"]), 1)
            print(f"N-gram decoding took {time_per_sample} seconds per sample")

            if run_for_llm:
                print("SAVING OUTPUTS FOR LLM")
                with open(f"{saveFolder_transcripts}{model_name_str}_seed_{seed}_model_outputs.pkl", "wb") as f:
                    pickle.dump(model_outputs, f)

                with open(f"{saveFolder_transcripts}{model_name_str}_seed_{seed}_nbest.pkl", "wb") as f:
                    pickle.dump(nbest_outputs, f)

            else:
                # Score the LM output directly against the reference text:
                # strip both sides, then lower-case the references and remove
                # punctuation from them.
                for i in range(len(model_outputs["transcriptions"])):
                    model_outputs["transcriptions"][i] = convert_sentence(
                        model_outputs["transcriptions"][i].strip()
                    )
                    nbest_outputs[i] = nbest_outputs[i].strip()

                cer, wer = _cer_and_wer(nbest_outputs, model_outputs["transcriptions"],
                                    outputType='speech', returnCI=True)

                print("CER and WER after 3-gram LM: ", cer, wer)

                # ".txt" is appended below; `output_file` itself has no extension.
                out_file = os.path.join(saveFolder_transcripts, output_file)

                with open(out_file + '.txt', write_mode, encoding="utf-8") as f:
                    f.write("\n".join(nbest_outputs)+ "\n")   # one decoded sentence per line

Writing to shared output file
transformer_held_out_days_memo_seed_0
Running model: neurips_transformer_time_masked_held_out_days_2_seed_0
Loading TRANSFORMER
RESTRICTED DAYS:  []


  torch.tensor(pred[iterIdx, 0 : adjustedLens[iterIdx], :]),


CER DAY:  2.5265625
CER DAY:  2.7
CER DAY:  2.4390625
CER DAY:  2.3625
CER DAY:  2.1125
Model performance:  2.428125
Running n-gram LM
N-gram decoding took 0.03492524087429047 seconds per sample
CER and WER after 3-gram LM:  (2.3834375, 2.2868671875000004, 2.479375) (2.64, 2.5425, 2.73625)
transformer_held_out_days_memo_seed_0
Running model: neurips_transformer_time_masked_held_out_days_1_seed_0
Loading TRANSFORMER
RESTRICTED DAYS:  []
CER DAY:  3.24375
CER DAY:  2.828125
CER DAY:  2.5765625
CER DAY:  2.696875
CER DAY:  2.9375
Model performance:  2.8565625
Running n-gram LM
N-gram decoding took 0.019733198881149293 seconds per sample
CER and WER after 3-gram LM:  (3.0825, 2.9684375, 3.1971875) (3.12, 3.0162187499999997, 3.22625)
transformer_held_out_days_memo_seed_0
Running model: neurips_transformer_time_masked_held_out_days_seed_0
Loading TRANSFORMER
RESTRICTED DAYS:  []
CER DAY:  2.903125
CER DAY:  2.8546875
CER DAY:  3.021875
CER DAY:  3.2265625
CER DAY:  2.8546875
Model performanc

In [None]:
Running model: neurips_transformer_time_masked_held_out_days_1_seed_0
CER DAY:  0.18701323074035844
CER DAY:  0.20014556040756915
CER DAY:  0.2676017601760176
CER DAY:  0.30618311533888226
CER DAY:  0.341395752058951
Model performance:  0.26254497254307896
Running n-gram LM
N-gram decoding took 0.01895871174578764 seconds per sample
CER and WER after 3-gram LM:  (0.24947886266989078, 0.24185914701801067, 0.25716652364444587) (0.3511564287490674, 0.33995621145518856, 0.36229499823919814)

In [None]:
# with memo
# without
Running model: neurips_transformer_time_masked_held_out_days_2_seed_0
Output exceeds the size limit. Open the full output data in a text editor
CER DAY:  0.3284115035707392
CER DAY:  0.380046403712297
CER DAY:  0.5224416517055656
CER DAY:  0.5491719863077066
CER DAY:  0.5892558916311004
Model performance:  0.47584572132081365
Running n-gram LM
N-gram decoding took 0.04170605278015137 seconds per sample
CER and WER after 3-gram LM:  (0.4804508971108348, 0.4715959430873275, 0.48947669469917093) (0.6677740863787376, 0.6558902882064046, 0.6799936068139665)
Running model: neurips_transformer_time_masked_held_out_days_1_seed_0
Loading TRANSFORMER
RESTRICTED DAYS:  []
CER DAY:  0.1893591066904382
CER DAY:  0.2011852776044916
CER DAY:  0.27035203520352036
CER DAY:  0.33125247720967105
CER DAY:  0.38925010836584306
Model performance:  0.27901912516568833
Running n-gram LM
N-gram decoding took 0.02109781369871023 seconds per sample
CER and WER after 3-gram LM:  (0.25845076294505126, 0.25050590982146115, 0.2666196134722149) (0.36906242228301417, 0.35733628898284037, 0.3811317908198835)
Running model: neurips_transformer_time_masked_held_out_days_seed_0
CER DAY:  0.1958174904942966
CER DAY:  0.21266887357967254
CER DAY:  0.2108843537414966
CER DAY:  0.2558579384259746
CER DAY:  0.25327416387054874
Model performance:  0.22340826085319476
Running n-gram LM
N-gram decoding took 0.019061621614530976 seconds per sample
CER and WER after 3-gram LM:  (0.19087837837837837, 0.1840513671786001, 0.19784976888843056) (0.28044250645994834, 0.2702872083074558, 0.2907064326560632)

In [None]:
# without memo
Running model: neurips_transformer_time_masked_held_out_days_2_seed_0
CER DAY:  0.3284115035707392
CER DAY:  0.380046403712297
CER DAY:  0.5224416517055656
CER DAY:  0.5491719863077066
CER DAY:  0.5892558916311004
Model performance:  0.47584572132081365
neurips_transformer_time_masked_held_out_days_1_seed_0
CER DAY:  0.1893591066904382
CER DAY:  0.2011852776044916
CER DAY:  0.27035203520352036
CER DAY:  0.33125247720967105
CER DAY:  0.38925010836584306
Model performance:  0.27901912516568833
Running model: neurips_transformer_time_masked_held_out_days_seed_0
Loading TRANSFORMER
RESTRICTED DAYS:  []
CER DAY:  0.1958174904942966
CER DAY:  0.21266887357967254
CER DAY:  0.2041424760709242
CER DAY:  0.2108843537414966
CER DAY:  0.2558579384259746
CER DAY:  0.25327416387054874
Model performance:  0.22134440503605587

# with memo (6e-5)
Running model: neurips_transformer_time_masked_held_out_days_2_seed_0
CER DAY:  0.31596216946535416
CER DAY:  0.34960556844547563
CER DAY:  0.47755834829443444
CER DAY:  0.5047645480618004
CER DAY:  0.5534421970681017
Model performance:  0.4418882416714136


# with memo (1e-4)
neurips_transformer_time_masked_held_out_days_1_seed_0
CER DAY:  0.18710706577836164
CER DAY:  0.1988978997712622
CER DAY:  0.2666850018335167
CER DAY:  0.3053904082441538
CER DAY:  0.34442999566536625
Model performance:  0.26265858738875214
Running model: neurips_transformer_time_masked_held_out_days_seed_0
Loading TRANSFORMER
RESTRICTED DAYS:  []
CER DAY:  0.19218804009678533
CER DAY:  0.20622707345441532
CER DAY:  0.19708143731366703
CER DAY:  0.2010677688797038
CER DAY:  0.24188294630660923
CER DAY:  0.2463470072518671
Model performance:  0.21344405036055875

# with memo (1.5e-4): no help

# with memo (9e-5): no help