In [7]:
# Run this once per kernel
%load_ext autoreload
%autoreload 2
import re
import time
import pickle
import numpy as np

from edit_distance import SequenceMatcher
import torch
from dataset import SpeechDataset
import matplotlib.pyplot as plt

from neural_decoder.dataset import getDatasetLoaders
import neural_decoder.lm_utils as lmDecoderUtils
from neural_decoder.model import GRUDecoder
import pickle
import argparse
import matplotlib.pyplot as plt
from neural_decoder.dataset import getDatasetLoaders
import neural_decoder.lm_utils as lmDecoderUtils
from neural_decoder.lm_utils import build_llama_1B
from neural_decoder.model import GRUDecoder
from neural_decoder.bit import BiT_Phoneme
import pickle
import argparse
from lm_utils import _cer_and_wer
import json
import os
import copy
from torch.utils.data import Subset
from torch.utils.data import ConcatDataset
from loss import memo_loss_from_logits, forward_ctc
from g2p_en import G2p

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
saveFolder_data = "/data/willett_data/paper_results_obi/"
saveFolder_transcripts = "/data/willett_data/model_transcriptions_comp/"

output_file = 'leia'
device = "cuda:2"

if output_file == 'obi':
    model_storage_path = '/data/willett_data/outputs/'
elif output_file == 'leia':
    model_storage_path = '/data/willett_data/leia_outputs/'

In [9]:
base_dir = "/home3/skaasyap/willett"

load_lm = False
# LM decoding hyperparameters
acoustic_scale = 0.8
blank_penalty = np.log(2)

run_for_llm = False

if run_for_llm:
    return_n_best = True
    rescore = False
    nbest = 100
    print("RUNNING IN LLM MODE")
else:
    return_n_best = False
    rescore = False
    nbest = 1
    print("RUNNING IN N-GRAM MODE")
    
if load_lm: 
        
    lmDir = base_dir +'/lm/languageModel'
    ngramDecoder = lmDecoderUtils.build_lm_decoder(
        lmDir,
        acoustic_scale=acoustic_scale, #1.2
        nbest=nbest,
        beam=18
    )
    print("loaded LM")
    
    load_lm = False

RUNNING IN N-GRAM MODE


In [19]:
models_to_run = ['neurips_transformer_time_masked_held_out_days_2', 
                 'neurips_transformer_time_masked_held_out_days_1', 
                 'neurips_transformer_time_masked_held_out_days']


shared_output_file = 'entropy_min'
val_save_file = 'entropy_min'
seeds_list = [0,1,2,3]

if len(shared_output_file) > 0:
    print("Writing to shared output file")
    write_mode = "a"
else:
    write_mode = "w"
    
evaluate_comp = True
run_lm = True

tta = True
run_memo = True
run_lang_informed = False

memo_epochs = 1
memo_augs = 0
if memo_augs:
    max_mask_pct = 0.05
    num_masks = 20
else:
    max_mask_pct = 0
    num_masks = 0


nptl_augs = 0
nptl_aug_params = [0.2, 0.05] # white noise, constant offset

memo_lr = [3e-5, 6e-5, 6e-5]

partition = "competition" 
blank_id = 0


Writing to shared output file


In [20]:
def convert_sentence(s):
    s = s.lower()
    charMarks = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z',
                 "'", ' ']
    ans = []
    for i in s:
        if(i in charMarks):
            ans.append(i)
    
    return ''.join(ans)


def get_phonemes(thisTranscription):
    
    phonemes = []
    g2p = G2p()
    
    
    for p in g2p(thisTranscription):
        
        if p == ' ':
            phonemes.append('SIL')
        p = re.sub(r'[0-9]', '', p)  # Remove stress
        if re.match(r'^[A-Z]+$', p):  # Only keep phonemes (uppercase only)
            phonemes.append(p)
    
    phonemes.append('SIL')  # Add trailing SIL
    
    PHONE_DEF = [
        'AA', 'AE', 'AH', 'AO', 'AW',
        'AY', 'B',  'CH', 'D', 'DH',
        'EH', 'ER', 'EY', 'F', 'G',
        'HH', 'IH', 'IY', 'JH', 'K',
        'L', 'M', 'N', 'NG', 'OW',
        'OY', 'P', 'R', 'S', 'SH',
        'T', 'TH', 'UH', 'UW', 'V',
        'W', 'Y', 'Z', 'ZH'
    ]
    PHONE_DEF_SIL = PHONE_DEF + ['SIL']

    phoneme_ids = [PHONE_DEF_SIL.index(p) for p in phonemes]

    return torch.tensor(phoneme_ids, dtype=torch.long), torch.tensor([len(phoneme_ids)], dtype=torch.long)

def get_data_file(path):
    suffix_map = {
        "data_log_both": "/data/willett_data/ptDecoder_ctc_both",
        "data": "/data/willett_data/ptDecoder_ctc",
        "data_log_both_held_out_days": "/data/willett_data/ptDecoder_ctc_both_held_out_days",
        "data_log_both_held_out_days_1": "/data/willett_data/ptDecoder_ctc_both_held_out_days_1",
        "data_log_both_held_out_days_2": "/data/willett_data/ptDecoder_ctc_both_held_out_days_2",
    }
    suffix = path.rsplit('/', 1)[-1]
    return suffix_map.get(suffix, path)

def reverse_dataset(dataset):
    return Subset(dataset, list(reversed(range(len(dataset)))))

def get_dataloader(dataset, batch_size=1):
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
                                       shuffle=False, num_workers=0)

def decode_sequence(pred, adjusted_len):
    pred = torch.argmax(pred[:adjusted_len], dim=-1)
    pred = torch.unique_consecutive(pred)
    return np.array([i for i in pred.cpu().numpy() if i != 0])

day_edit_distance = 0
day_seq_length = 0

for mn, model_name_str in enumerate(models_to_run):
    day_cer_dict, total_wer_dict = {}, {}

    for seed in seeds_list:
        
        print(f"Running model: {model_name_str}_seed_{seed}")
        
        day_cer_dict[seed], total_wer_dict[seed] = [], []

        modelPath = f"{model_storage_path}{model_name_str}_seed_{seed}"
        output_file = f"{shared_output_file}_seed_{seed}" if shared_output_file else f"{model_name_str}_seed_{seed}"

        with open(f"{modelPath}/args", "rb") as handle:
            args = pickle.load(handle)
            
        model = BiT_Phoneme(
        patch_size=args['patch_size'], dim=args['dim'], dim_head=args['dim_head'],
        nClasses=args['nClasses'], depth=args['depth'], heads=args['heads'],
        mlp_dim_ratio=args['mlp_dim_ratio'], dropout=0, input_dropout=0,
        look_ahead=args['look_ahead'], gaussianSmoothWidth=args['gaussianSmoothWidth'],
        T5_style_pos=args['T5_style_pos'], max_mask_pct=max_mask_pct,
        num_masks=num_masks, mask_token_zeros=args['mask_token_zero'], max_mask_channels=0,
        num_masks_channels=0, dist_dict_path=None
        ).to(device)

        data_file = get_data_file(args['datasetPath'])

        trainLoaders, testLoaders, loadedData = getDatasetLoaders(data_file, 8)
        args.setdefault('mask_token_zero', False)

        model.load_state_dict(torch.load(f"{modelPath}/modelWeights", map_location=device), strict=True)
        model.eval()

        optimizer = torch.optim.AdamW(model.parameters(), lr=memo_lr[mn], weight_decay=0,
                                      betas=(args['beta1'], args['beta2']))

        for name, p in model.named_parameters():
            p.requires_grad = name in {
                "to_patch_embedding.1.weight", "to_patch_embedding.1.bias",
                "to_patch_embedding.2.weight", "to_patch_embedding.2.bias",
                "to_patch_embedding.3.weight", "to_patch_embedding.3.bias"
            }

        testDayIdxs = np.arange(5)
        valDayIdxs = [0, 1, 3, 4, 5] if mn == 2 else [0, 1, 2, 3, 4]

        model_outputs = {"logits": [], "logitLengths": [], "trueSeqs": [], "transcriptions": []}
        
        total_edit_distance = total_seq_length = 0
        nbest_outputs = []
        nbest_outputs_val = []
        
        for i, testDayIdx in enumerate(testDayIdxs):
            
            ve = valDayIdxs[i]
            val_ds = reverse_dataset(SpeechDataset([loadedData['test'][ve]]))
            test_ds = reverse_dataset(SpeechDataset([loadedData['competition'][i]]))
            combined_ds = ConcatDataset([val_ds, test_ds])
            data_loader = get_dataloader(combined_ds)

            if tta:
            
                for trial_idx, (X, y, X_len, y_len, _) in enumerate(data_loader):
                    
                
                    X, y, X_len, y_len = map(lambda x: x.to(device), [X, y, X_len, y_len])
                    
                    dayIdx = torch.tensor([ve], dtype=torch.int64).to(device)
                    
                    model.train()
                    
                    memo_loss = li_loss = torch.tensor(0.0, device=device)
                    for _ in range(memo_epochs):
                        
                        logits_aug = model(X, X_len, ve, memo_augs, nptl_augs, nptl_aug_params)
                        logits_np = logits_aug[0].detach().cpu().numpy()
                        logits = logits_aug[0:1]
                        adjusted_len = model.compute_length(X_len)
                        
                        if run_memo:
                            
                            memo_loss = memo_loss_from_logits(logits_aug, adjusted_len, blank_id)
                            
                        if run_lang_informed:
                            
                            logits_np = np.concatenate([logits_np[:, 1:], logits_np[:, 0:1]], axis=-1)
                            
                            logit_np = lmDecoderUtils.rearrange_speech_logits(logits_np[None, :, :], 
                                                                            has_sil=True)
                            
                            
                            decoded = lmDecoderUtils.lm_decode(ngramDecoder, logits_np, 
                                                        blankPenalty=blank_penalty,
                                                    returnNBest=return_n_best, rescore=rescore)
                            
                            y_pseudo, y_len_pseudo = get_phonemes(decoded)
                            y_pseudo = y_pseudo.to(device)
                            y_len_pseudo = y_len_pseudo.to(device)
                                                        
                            
                            li_loss = forward_ctc(logits, adjusted_len, 
                                                  y_pseudo, y_len_pseudo)
                      
                        loss = memo_loss

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        
                    model.eval()
                    
                    # get validation performance 
                    if trial_idx < len(val_ds):
            
                        for idx in range(logits.shape[0]):
                            trueSeq = y[idx][:y_len[idx]].cpu().numpy()
                            decoded = decode_sequence(logits[idx], adjusted_len[idx])
                            dist = SequenceMatcher(a=trueSeq.tolist(), b=decoded.tolist()).distance()

                            total_edit_distance += dist
                            total_seq_length += len(trueSeq)
                            day_edit_distance += dist
                            day_seq_length += len(trueSeq)
                            
                    # get test set predictions 
                    #else:   
                    #    nbest_outputs.append(decoded)
                        
            print("DAY CER: ", day_edit_distance / day_seq_length)
            day_cer_dict[seed].append(day_edit_distance / day_seq_length)
            day_edit_distance = 0 
            day_seq_length = 0
            
            if val_save_file:
                print(f"SAVING VAL RESULTS FOR {model_name_str}")
                with open(f"{saveFolder_data}{model_name_str}_{val_save_file}.pkl", "wb") as f:
                    pickle.dump(day_cer_dict, f)
            
            
        #out_file = os.path.join(saveFolder_transcripts, output_file)
        #with open(out_file + '.txt', write_mode, encoding="utf-8") as f:
        #    f.write("\n".join(nbest_outputs) + "\n")
            
        #model_outputs["transcriptions"] = [convert_sentence(t.strip()) for t in model_outputs["transcriptions"]]
        #nbest_outputs = [t.strip() for t in nbest_outputs]
        #cer, wer = _cer_and_wer(nbest_outputs, model_outputs["transcriptions"],
        #                        outputType='speech', returnCI=True)
        #total_wer_dict[seed] = wer

'''
            model.eval()
            day_edit_distance = day_seq_length = 0

            with torch.no_grad():
                
                for i, (X, y, X_len, y_len, _) in enumerate(data_loader):
                    
                    X, y, X_len, y_len = map(lambda x: x.to(device), [X, y, X_len, y_len])
                    
                    dayIdx = torch.tensor([ve], dtype=torch.int64).to(device)
                    pred = model(X, X_len, dayIdx)
                    adjustedLens = model.compute_length(X_len)
                
                    if i < len(val_ds):
            
                        for idx in range(pred.shape[0]):
                            trueSeq = y[idx][:y_len[idx]].cpu().numpy()
                            decoded = decode_sequence(pred[idx], adjustedLens[idx])
                            dist = SequenceMatcher(a=trueSeq.tolist(), b=decoded.tolist()).distance()

                            total_edit_distance += dist
                            total_seq_length += len(trueSeq)
                            day_edit_distance += dist
                            day_seq_length += len(trueSeq)
                            
                    else: 
                    
                        for idx in range(pred.shape[0]):
                            
                            decoded = decode_sequence(pred[idx], adjustedLens[idx])
                            
                            transcript = loadedData[partition][i]["transcriptions"][j].strip()
                            transcript = re.sub(r"[^a-zA-Z\- \']", "", transcript).replace("--", "").lower()

                            model_outputs["logits"].append(pred[idx].cpu().numpy())
                            model_outputs["logitLengths"].append(adjustedLens[idx].item())
                            model_outputs["trueSeqs"].append(y[idx][:y_len[idx]].cpu().numpy())
                            model_outputs["transcriptions"].append(transcript)
                        



        if run_lm:
            print("Running LM decoding...")
            nbest_outputs = []
            for logits in model_outputs["logits"]:
                logits = np.concatenate([logits[:, 1:], logits[:, 0:1]], axis=-1)
                logits = lmDecoderUtils.rearrange_speech_logits(logits[None, :, :], has_sil=True)
                decoded = lmDecoderUtils.lm_decode(ngramDecoder, logits[0], blankPenalty=blank_penalty,
                                                   returnNBest=return_n_best, rescore=rescore)
                nbest_outputs.append(decoded)

            model_outputs["transcriptions"] = [convert_sentence(t.strip()) for t in model_outputs["transcriptions"]]
            nbest_outputs = [t.strip() for t in nbest_outputs]
            cer, wer = _cer_and_wer(nbest_outputs, model_outputs["transcriptions"],
                                    outputType='speech', returnCI=True)
            total_wer_dict[seed] = wer

           
    if val_save_file:
        print(f"SAVING VAL RESULTS FOR {model_name_str}")
        with open(f"{saveFolder_data}{model_name_str}_{val_save_file}.pkl", "wb") as f:
            pickle.dump(day_cer_dict, f)
'''


Running model: neurips_transformer_time_masked_held_out_days_2_seed_0
DAY CER:  0.3284115035707392
SAVING VAL RESULTS FOR neurips_transformer_time_masked_held_out_days_2
DAY CER:  0.380046403712297
SAVING VAL RESULTS FOR neurips_transformer_time_masked_held_out_days_2
DAY CER:  0.5224416517055656
SAVING VAL RESULTS FOR neurips_transformer_time_masked_held_out_days_2
DAY CER:  0.5491719863077066
SAVING VAL RESULTS FOR neurips_transformer_time_masked_held_out_days_2
DAY CER:  0.5892558916311004
SAVING VAL RESULTS FOR neurips_transformer_time_masked_held_out_days_2
Running model: neurips_transformer_time_masked_held_out_days_2_seed_1
DAY CER:  0.3249372707971434
SAVING VAL RESULTS FOR neurips_transformer_time_masked_held_out_days_2
DAY CER:  0.36751740139211136
SAVING VAL RESULTS FOR neurips_transformer_time_masked_held_out_days_2
DAY CER:  0.5191929554586646
SAVING VAL RESULTS FOR neurips_transformer_time_masked_held_out_days_2
DAY CER:  0.5494495327967435
SAVING VAL RESULTS FOR neurips_

'\n            model.eval()\n            day_edit_distance = day_seq_length = 0\n\n            with torch.no_grad():\n                \n                for i, (X, y, X_len, y_len, _) in enumerate(data_loader):\n                    \n                    X, y, X_len, y_len = map(lambda x: x.to(device), [X, y, X_len, y_len])\n                    \n                    dayIdx = torch.tensor([ve], dtype=torch.int64).to(device)\n                    pred = model(X, X_len, dayIdx)\n                    adjustedLens = model.compute_length(X_len)\n                \n                    if i < len(val_ds):\n            \n                        for idx in range(pred.shape[0]):\n                            trueSeq = y[idx][:y_len[idx]].cpu().numpy()\n                            decoded = decode_sequence(pred[idx], adjustedLens[idx])\n                            dist = SequenceMatcher(a=trueSeq.tolist(), b=decoded.tolist()).distance()\n\n                            total_edit_distance += dist\n       