In [1]:
import re
import time
import pickle
import numpy as np

from edit_distance import SequenceMatcher
import torch
from dataset import SpeechDataset
import matplotlib.pyplot as plt

from neural_decoder.dataset import getDatasetLoaders
import neural_decoder.lm_utils as lmDecoderUtils
from neural_decoder.model import GRUDecoder
import pickle
import argparse
import matplotlib.pyplot as plt
from neural_decoder.dataset import getDatasetLoaders
import neural_decoder.lm_utils as lmDecoderUtils
from neural_decoder.lm_utils import build_llama_1B
from neural_decoder.model import GRUDecoder
from neural_decoder.bit import BiT_Phoneme
import pickle
import argparse
from lm_utils import _cer_and_wer
import json
import os
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
output_file = 'leia'
    
device = "cuda:2"

if output_file == 'obi':
    model_storage_path = '/data/willett_data/outputs/'
elif output_file == 'leia':
    model_storage_path = '/data/willett_data/leia_outputs/'
    
models_to_run = ['neurips_transformer_time_masked_held_out_days_2']

shared_output_file = ''

if len(shared_output_file) > 0:
    print("Writing to shared output file")
    write_mode = "a"
else:
    write_mode = "w"
    
seeds_list = [0]
partition = "test" # "test"
run_lm = True

comp_on_reduced = True
fill_max_day = False
memo = True
memo_epochs = 1
memo_augs = 16
memo_lr = [3e-5, 6e-5, 6e-5]
skip_days = [[], [], []]

day_edit_distance = 0
day_seq_length = 0
prev_day = None

memo_seeds = np.arange(1)

if partition == 'test':
    saveFolder_transcripts = "/data/willett_data/model_transcriptions/"
else:
    saveFolder_transcripts = "/data/willett_data/model_transcriptions_comp/"
    

day_cer_dict = {}
total_wer_dict = {}

for seed in seeds_list:
    
    day_cer_dict[seed] = []
    total_wer_dict[seed] = []
            
    for mn, model_name_str in enumerate(models_to_run):
        
        modelPath = f"{model_storage_path}{model_name_str}_seed_{seed}"
        
        if len(shared_output_file) > 0:
            output_file = f"{shared_output_file}_seed_{seed}"
            print(output_file)
        else:
            output_file = f"{model_name_str}_seed_{seed}"
            
        print(f"Running model: {model_name_str}_seed_{seed}")
            
        with open(modelPath + "/args", "rb") as handle:
            args = pickle.load(handle)
            
        if args['datasetPath'].rsplit('/', 1)[-1] == 'data_log_both':
            data_file = '/data/willett_data/ptDecoder_ctc_both'
            
        elif args['datasetPath'].rsplit('/', 1)[-1] == 'data':
            data_file = '/data/willett_data/ptDecoder_ctc'
            
        elif args['datasetPath'].rsplit('/', 1)[-1] == 'data_log_both_held_out_days':
            data_file = '/data/willett_data/ptDecoder_ctc_both_held_out_days'
            
        elif args['datasetPath'].rsplit('/', 1)[-1] == 'data_log_both_held_out_days_1':
            data_file = '/data/willett_data/ptDecoder_ctc_both_held_out_days_1'
            
        elif args['datasetPath'].rsplit('/', 1)[-1] == 'data_log_both_held_out_days_2':
            data_file = '/data/willett_data/ptDecoder_ctc_both_held_out_days_2'
            
        else:
            data_file = args['datasetPath']
            
        trainLoaders, testLoaders, loadedData = getDatasetLoaders(
            data_file, 8
        )
        
        # if true, model is a GRU
        if 'nInputFeatures' in args.keys():
            
            if 'max_mask_pct' not in args:
                args['max_mask_pct'] = 0
            if 'num_masks' not in args:
                args['num_masks'] = 0
            if 'input_dropout' not in args:
                args['input_dropout'] = 0
                
            print("Loading GRU")
            model = GRUDecoder(
                neural_dim=args["nInputFeatures"],
                n_classes=args["nClasses"],
                hidden_dim=args["nUnits"],
                layer_dim=args["nLayers"],
                nDays=args['nDays'],
                dropout=args["dropout"],
                device=device,
                strideLen=args["strideLen"],
                kernelLen=args["kernelLen"],
                gaussianSmoothWidth=args["gaussianSmoothWidth"],
                bidirectional=args["bidirectional"],
                input_dropout=args['input_dropout'], 
                max_mask_pct=args['max_mask_pct'],
                num_masks=args['num_masks']
            ).to(device)

        else:
            
            if 'mask_token_zero' not in args:
                args['mask_token_zero'] = False
                            
            # Instantiate model
            # set training relevant parameters for MEMO, doesn't matter for other runs because they are 
            # only run in eval mode.
            model = BiT_Phoneme(
                patch_size=args['patch_size'],
                dim=args['dim'],
                dim_head=args['dim_head'],
                nClasses=args['nClasses'],
                depth=args['depth'],
                heads=args['heads'],
                mlp_dim_ratio=args['mlp_dim_ratio'],
                dropout=0,
                input_dropout=0,
                look_ahead=args['look_ahead'],
                gaussianSmoothWidth=args['gaussianSmoothWidth'],
                T5_style_pos=args['T5_style_pos'],
                max_mask_pct=0.05,
                num_masks=20, 
                mask_token_zeros=args['mask_token_zero'], 
                num_masks_channels=0, 
                max_mask_channels=0, 
                dist_dict_path=0, 
            ).to(device)
            
            
        ckpt_path = modelPath + '/modelWeights'
        model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=True)
        model = model.to(device)
        
        if memo: 
            
            optimizer = torch.optim.AdamW(model.parameters(), lr=memo_lr[mn], weight_decay=0, 
                                    betas=(args['beta1'], args['beta2']))
            
            for name, p in model.named_parameters():
                
                if name in {
                    "to_patch_embedding.1.weight",
                    "to_patch_embedding.1.bias",
                    "to_patch_embedding.2.weight",
                    "to_patch_embedding.2.bias",
                    "to_patch_embedding.3.weight",
                    "to_patch_embedding.3.bias"
                }:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
        
        model.eval()

        model_outputs = {
            "logits": [],
            "logitLengths": [],
            "trueSeqs": [],
            "transcriptions": [],
        }
        
        total_edit_distance = 0
        total_seq_length = 0

        if partition == "competition":
            
            if comp_on_reduced:
                testDayIdxs = np.arange(5)
            else:
                testDayIdxs = [4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 20]
                
        elif partition == "test":
            
            testDayIdxs = range(len(loadedData[partition])) 
            
        ground_truth_sentences = []
        
        #print("RESTRICTED DAYS: ", args['restricted_days'])
        
        for i, testDayIdx in enumerate(testDayIdxs):
            
            if len(skip_days[mn]) > 0:
                if testDayIdx in skip_days[mn]:
                    print("SKIPPING DAY: ", testDayIdx)
                    continue
            
            if len(args['restricted_days']) > 0:
                if testDayIdx not in args['restricted_days']:
                    continue
                
            test_ds = SpeechDataset([loadedData[partition][i]])
            test_loader = torch.utils.data.DataLoader(
                test_ds, batch_size=1, shuffle=False, num_workers=0
            )
            
            for j, (X, y, X_len, y_len, _) in enumerate(test_loader):
                        
                X, y, X_len, y_len, dayIdx = (
                    X.to(device),
                    y.to(device),
                    X_len.to(device),
                    y_len.to(device),
                    torch.tensor([testDayIdx], dtype=torch.int64).to(device),
                )
                
                if fill_max_day:
                    dayIdx.fill_(args['maxDay'])
                    
                if memo: 
                                        
                    model.train()
                    
                    for _ in range(memo_epochs):
                        
                        logits_aug = model.forward(X, X_len, testDayIdx, memo_augs)  # [memo_augs, T, D]
                        probs_aug = torch.nn.functional.softmax(logits_aug, dim=-1)          # [memo_augs, T, D]
                        marginal_probs = probs_aug.mean(dim=0)                               # [T, D]

                        adjustedLens = model.compute_length(X_len)
                        marginal_probs = marginal_probs[:adjustedLens]

                        loss = - (marginal_probs * marginal_probs.log()).sum(dim=-1).mean()

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        
                    model.eval()
                    
                with torch.no_grad():
                    
                    pred = model.forward(X, X_len, dayIdx)
                
                if hasattr(model, 'compute_length'):
                    adjustedLens = model.compute_length(X_len)
                else:
                    adjustedLens = ((X_len - model.kernelLen) / model.strideLen).to(torch.int32)
                    
                for iterIdx in range(pred.shape[0]):
                    
                    trueSeq = np.array(y[iterIdx][0 : y_len[iterIdx]].cpu().detach())
                    model_outputs["logits"].append(pred[iterIdx].cpu().detach().numpy())
                    
                    model_outputs["logitLengths"].append(
                        adjustedLens[iterIdx].cpu().detach().item()
                    )
                    
                    model_outputs["trueSeqs"].append(trueSeq)
                    
                    decodedSeq = torch.argmax(
                        torch.tensor(pred[iterIdx, 0 : adjustedLens[iterIdx], :]),
                        dim=-1,
                    ) 
                    
                    decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)
                    decodedSeq = decodedSeq.cpu().detach().numpy()
                    decodedSeq = np.array([i for i in decodedSeq if i != 0])
                    
                    matcher = SequenceMatcher(
                        a=trueSeq.tolist(), b=decodedSeq.tolist()
                    )
                    
                    total_edit_distance += matcher.distance()
                    total_seq_length += len(trueSeq)
                    
                    day_edit_distance += matcher.distance()
                    day_seq_length += len(trueSeq)
                    
                transcript = loadedData[partition][i]["transcriptions"][j].strip()
                transcript = re.sub(r"[^a-zA-Z\- \']", "", transcript)
                transcript = transcript.replace("--", "").lower()
                model_outputs["transcriptions"].append(transcript)
                
            cer_day = day_edit_distance / day_seq_length
            day_cer_dict[seed].append(cer_day)
            print("CER DAY: ", cer_day)
            day_edit_distance = 0 
            day_seq_length = 0

        cer = total_edit_distance / total_seq_length
        
        print("Model performance: ", cer)

Writing to shared output file
corrected_memo_transformer_seed_0
Running model: neurips_transformer_time_masked_held_out_days_2_seed_0


  torch.tensor(pred[iterIdx, 0 : adjustedLens[iterIdx], :]),


Model performance:  0.4801822825747414
