In [None]:
# Test-time adaptation (TTA) evaluation sweep.
#
# For every augmentation count in updated_args['repeats'], every model in
# models_to_run and every seed in seeds_list, this cell:
#   1. loads the pickled training args and the saved model weights,
#   2. walks the test days one trial at a time, decoding each trial with
#      get_lm_outputs *before* adapting on it,
#   3. unless tta_mode == 'baseline', takes updated_args['adaptation_steps']
#      optimizer steps on the forward_ctc loss against the decoded
#      pseudo-labels,
#   4. records the per-day and across-day WER and, if val_save_file is set,
#      pickles both dictionaries.
for n_augs in updated_args['repeats']:
    for mn, model_name_str in enumerate(models_to_run):
        if model_name_str in skip_models:
            continue

        # seed -> list of per-day WER, and seed -> WER over all days.
        day_wer_dict, total_wer_dict = {}, {}

        for seed in seeds_list:
            if seed in skip_seeds:
                continue

            print(f"Running model: {model_name_str}_seed_{seed}")
            day_wer_dict[seed] = []

            modelPath = f"{model_storage_path}{model_name_str}_seed_{seed}"
            # Not read again in this cell; kept as a cell-level binding.
            output_file = (
                f"{shared_output_file}_seed_{seed}"
                if shared_output_file
                else f"{model_name_str}_seed_{seed}"
            )

            # Load the training-time args.
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only point this at checkpoints you trust.
            with open(f"{modelPath}/args", "rb") as handle:
                args = pickle.load(handle)

            # Must run before the model is built: the constructor call below
            # reads args['mask_token_zero'], which may be missing from the
            # pickled args.
            args.setdefault('mask_token_zero', False)

            # Architecture comes from the saved args; dropout and masking
            # settings come from updated_args.
            model = BiT_Phoneme(
                patch_size=args['patch_size'],
                dim=args['dim'],
                dim_head=args['dim_head'],
                nClasses=args['nClasses'],
                depth=args['depth'],
                heads=args['heads'],
                mlp_dim_ratio=args['mlp_dim_ratio'],
                dropout=updated_args['dropout'],
                input_dropout=updated_args['input_dropout'],
                gaussianSmoothWidth=args['gaussianSmoothWidth'],
                T5_style_pos=args['T5_style_pos'],
                max_mask_pct=updated_args['max_mask_pct'],
                num_masks=updated_args['num_masks'],
                mask_token_zeros=args['mask_token_zero'],
                max_mask_channels=0,
                num_masks_channels=0,
                dist_dict_path=None,
            ).to(device)

            # NOTE(review): this fallback is sticky. Once data_file has been
            # filled in from the first checkpoint it is reused for every
            # later model/seed, even if their args['datasetPath'] differ.
            # Confirm that all checkpoints share one dataset.
            if data_file is None:
                data_file = args['datasetPath']

            trainLoader, testLoaders, loadedData = getDatasetLoaders(data_file, 64)

            model.load_state_dict(
                torch.load(f"{modelPath}/modelWeights", map_location=device),
                strict=True,
            )

            if tta_mode != 'baseline':
                # One learning rate per entry of models_to_run, indexed by mn.
                print(updated_args['learning_rate'][mn])
                optimizer = torch.optim.AdamW(
                    model.parameters(),
                    lr=updated_args['learning_rate'][mn],
                    weight_decay=updated_args['l2_decay'],
                    betas=(args['beta1'], args['beta2']),
                )

                # Each flag below leaves ONLY the listed parameters trainable
                # and freezes every other parameter. If both flags are set,
                # the freeze_patch pass runs last and overwrites the first.
                if updated_args['freeze_linear']:
                    for name, p in model.named_parameters():
                        p.requires_grad = name in {
                            "dayWeights", "dayBias"
                        }

                if updated_args['freeze_patch']:
                    for name, p in model.named_parameters():
                        p.requires_grad = name in {
                            "to_patch_embedding.1.weight", "to_patch_embedding.1.bias",
                            "to_patch_embedding.2.weight", "to_patch_embedding.2.bias",
                            "to_patch_embedding.3.weight", "to_patch_embedding.3.bias"
                        }

            testDayIdxs = np.arange(len(loadedData['test']))
            print(len(testDayIdxs))

            # Not filled in by this cell; kept as a cell-level binding.
            model_outputs = {"logits": [], "logitLengths": [], "trueSeqs": [], "transcriptions": []}

            decoded_list_all_days = []
            transcripts_all_days = []

            for test_day_idx, testDayIdx in enumerate(testDayIdxs):

                print("day ", test_day_idx)

                val_ds = SpeechDataset([loadedData['test'][test_day_idx]], return_transcript=True)
                data_loader = get_dataloader(val_ds)
                transcriptions_list = []
                decoded_list = []

                for trial_idx, (X, y, X_len, y_len, day_idx, transcript) in enumerate(data_loader):

                    X, y, X_len, y_len = map(lambda x: x.to(device), [X, y, X_len, y_len])

                    # Optionally force every trial through one fixed day index.
                    if updated_args['max_day'] is not None:
                        day_idx = torch.tensor([updated_args['max_day']], dtype=torch.int64).to(device)
                    else:
                        day_idx = torch.tensor([day_idx], dtype=torch.int64).to(device)

                    adjusted_len = model.compute_length(X_len)

                    # The peak-memory counters are reset on every trial, so
                    # the figures printed after the trial loop cover only
                    # the work done since the last trial's reset.
                    torch.cuda.synchronize(device)
                    torch.cuda.reset_peak_memory_stats(device)

                    # Obtain beam search + LM corrected outputs.
                    # Do this before adaptation on that trial to make
                    # sure results are compatible with a streaming system.
                    model.eval()
                    logits_eval = model(X, X_len, day_idx)
                    decoded, y_pseudo, y_len_pseudo = get_lm_outputs(logits_eval)

                    if tta_mode != 'baseline':

                        # Generate multiple versions of the same input.
                        if n_augs > 0:
                            X = X.repeat(n_augs, 1, 1)
                            y = y.repeat(n_augs, 1)
                            y_len = y_len.repeat(n_augs)
                            X_len = X_len.repeat(n_augs)
                            adjusted_len = adjusted_len.repeat(n_augs)
                            y_pseudo = y_pseudo.unsqueeze(0).repeat(n_augs, 1).to(device)
                            y_len_pseudo = y_len_pseudo.repeat(n_augs).to(device)

                        # Add white noise and baseline shift augmentations
                        # to each sample.
                        if updated_args['WN+BS'] == True:
                            X += torch.randn(X.shape, device=device) * updated_args['white_noise']

                            # One offset per sample and per last-axis entry,
                            # broadcast along the middle axis.
                            X += (
                                torch.randn([X.shape[0], 1, X.shape[2]], device=device)
                                * updated_args['baseline_shift']
                            )

                        model.train()

                        for _ in range(updated_args['adaptation_steps']):
                            logits = model(X, X_len, day_idx)

                            # The loss target is the pseudo-label decoded
                            # above, not the ground-truth y.
                            corp_loss = forward_ctc(logits, adjusted_len, y_pseudo, y_len_pseudo)
                            corp_loss_tracker.append(corp_loss.detach().cpu().numpy())
                            optimizer.zero_grad()
                            corp_loss.backward()
                            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                            optimizer.step()

                    model.eval()

                    # Score the pre-adaptation decode for this trial.
                    decoded_list.append(decoded)
                    transcriptions_list.append(clean_transcription(transcript[0]))

                to_gib = 1024**3
                torch.cuda.synchronize(device)
                peak_alloc_gib = torch.cuda.max_memory_allocated(device) / to_gib  # tensors/activations
                peak_res_gib = torch.cuda.max_memory_reserved(device) / to_gib     # allocator footprint

                print(f"PyTorch peak allocated: {peak_alloc_gib:.3f} GiB")
                print(f"PyTorch peak reserved : {peak_res_gib:.3f} GiB")

                # A leftover sys.exit() used to stop the run here, after the
                # first day's memory report. It was removed so that the WER
                # bookkeeping below is reachable again.
                _, wer = _cer_and_wer(decoded_list, transcriptions_list, outputType="speech", returnCI=False)
                print("DAY WER: ", wer)
                day_wer_dict[seed].append(wer)

                decoded_list_all_days.extend(decoded_list)
                transcripts_all_days.extend(transcriptions_list)

                torch.cuda.synchronize(device)

            _, wer_total = _cer_and_wer(decoded_list_all_days, transcripts_all_days, outputType="speech", returnCI=False)
            total_wer_dict[seed] = wer_total
            print("WER ACROSS DAYS: ", wer_total)

        # Persist the results for this model. An unconditional `continue`
        # used to sit here and made this branch unreachable.
        if val_save_file:

            # Encode the augmentation count in the file name.
            val_save_file_updated = val_save_file.replace("dietcorp", f"diet{n_augs}corp")

            print(f"SAVING VAL RESULTS FOR {model_name_str}")
            with open(f"{saveFolder_data}{model_name_str}_{val_save_file_updated}.pkl", "wb") as f:
                pickle.dump(day_wer_dict, f)
            with open(f"{saveFolder_data}{model_name_str}_{val_save_file_updated}_all_days.pkl", "wb") as f:
                pickle.dump(total_wer_dict, f)
                