In [24]:
%load_ext autoreload
%autoreload 2
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
from nemo.collections.asr.parts.submodules.ctc_batched_beam_decoding import BatchedBeamCTCComputer
from inference_functions import load_bit_phoneme_model
from dataset import getDatasetLoaders
from edit_distance import SequenceMatcher

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
# Location of the language-model archive (.nemo) that is restored in the next cell.
language_model_path = (
    "/workspace/transformers_with_dietcorp/lm/char_6gram_lm.nemo"
)

In [26]:
# Restore the n-gram language model from the .nemo archive at `language_model_path`.
# NOTE(review): vocab_size presumably has to match the number of non-blank classes
# handed to the beam-search decoder later on -- confirm against the LM's vocabulary.
lm = NGramGPULanguageModel.from_nemo(lm_path=language_model_path, vocab_size=26)

[NeMo I 2025-08-26 01:07:58 save_restore_connector:282] Model NGramGPULanguageModel was successfully restored from /workspace/transformers_with_dietcorp/lm/char_6gram_lm.nemo.


In [27]:
# Batched CTC beam-search decoder fused with `lm` (beam size 16, fusion alpha 0.2).
# Note: `decoder` is bound again further down with blank_index=0 and alpha 0.25,
# so this instance is replaced before any decoding happens.
decoder = BatchedBeamCTCComputer(
    blank_index=26,
    beam_size=16,
    return_best_hypothesis=False,
    fusion_models=[lm],
    fusion_models_alpha=[0.2],
)


In [28]:
# Checkpoint directory of the model to evaluate, and the device it should run on.
bit_phoneme_filepath = "/data/models/time_masked_transfomer_characters_80ms_seed_0/"
device = 'cuda'

# The loader hands back the model plus its `args` (read with `.get(...)` during evaluation).
model, args = load_bit_phoneme_model(bit_phoneme_filepath)

# Move the model to `device`.
model = model.to(device)

In [29]:
# Dataset file that provides the train/test loaders and the raw per-day data.
data_file = '/data/neural_data/ptDecoder_ctc_both_char'

# NOTE(review): the positional arguments 8, None, False are defined by
# dataset.getDatasetLoaders -- presumably 8 is the batch size; confirm the other two.
trainLoaders, testLoaders, loadedData = getDatasetLoaders(data_file, 8, None, False)

In [30]:
import re
import numpy as np
import torch
from typing import Dict, Any, List, Tuple
from dataset import SpeechDataset  # adjust if your path differs

def _greedy_ctc_decode(frame_logits: torch.Tensor, blank_id: int = 0) -> np.ndarray:
    """Best-path CTC decode: per-frame argmax, merge consecutive repeats, drop blanks."""
    best_path = torch.argmax(frame_logits, dim=-1)
    collapsed = torch.unique_consecutive(best_path).cpu().numpy()
    return collapsed[collapsed != blank_id]


def evaluate_model(
    model: torch.nn.Module,
    loadedData: Dict[str, List[Dict[str, Any]]],
    args: Dict[str, Any],
    partition: str,               # "test" or "competition"
    device: torch.device,
    fill_max_day: bool = False,   # accepted for backward compatibility; currently unused
    verbose: bool = True,
    num_classes: int = 30,
) -> Tuple[Dict[str, List[Any]], float, List[float]]:
    """
    Run `model` over one partition of `loadedData` and score it with greedy CTC decoding.

    Every trial of every evaluated day is passed through the model on its own
    (batch_size=1). The model output is cut to its first `num_classes` channels,
    decoded greedily (argmax per frame, consecutive repeats merged, class 0 removed
    as the blank) and compared with the reference label sequence by edit distance.

    Args:
        model: Called as ``model(X, X_len, day_tensor)``; switched to eval mode.
            If it exposes ``compute_length`` that is used to obtain output lengths,
            otherwise lengths are derived from ``model.kernelLen`` / ``model.strideLen``.
        loadedData: Maps a partition name to a list of per-day dicts. Each per-day
            dict is wrapped in ``SpeechDataset`` and must contain "transcriptions".
        args: Read for "restricted_days" (iterable of day indices to keep; missing,
            ``None`` or empty means all days) and "ventral_6v_only" (keep only the
            first 128 input features).
        partition: "test" (days numbered 0..N-1) or "competition" (a fixed list of
            day indices, matched by position to ``loadedData["competition"]``).
        device: Device the inputs are moved to; a device string such as "cuda" is
            passed by the notebook as well.
        fill_max_day: Not used by this function; kept so existing callers still work.
        verbose: Print the per-day and overall CER.
        num_classes: Number of leading output channels kept from the model output.

    Returns:
        ``(outputs, cer, per_day_cer)`` where ``outputs`` has the keys "logits"
        (one array per trial, not trimmed to the valid length), "logitLengths",
        "trueSeqs" and "transcriptions" (lower-cased; everything except letters,
        hyphen, space and apostrophe removed, and "--" removed); ``cer`` is the
        total edit distance divided by the total reference length (NaN when there
        are no reference tokens); ``per_day_cer`` holds one value per evaluated day
        that had a non-empty reference, in evaluation order.

    Raises:
        ValueError: If `partition` is neither "test" nor "competition".
    """

    # Decide day indices
    if partition == "competition":
        day_indices = [4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 20]
    elif partition == "test":
        day_indices = list(range(len(loadedData[partition])))
    else:
        raise ValueError(f"Unknown partition '{partition}'")

    # Pull common flags from args with safe defaults.
    # An explicit None (e.g. an unset config field) is treated like "no restriction"
    # instead of crashing in set(None).
    restricted_cfg = args.get("restricted_days")
    restricted_days = set() if restricted_cfg is None else set(restricted_cfg)
    ventral_6v_only = bool(args.get("ventral_6v_only", False))

    # Accumulators
    outputs = {"logits": [], "logitLengths": [], "trueSeqs": [], "transcriptions": []}
    per_day_cer: List[float] = []
    total_edit, total_len = 0, 0

    model.eval()

    for idx_in_enum, day_idx in enumerate(day_indices):
        if restricted_days and (day_idx not in restricted_days):
            continue

        # The data is addressed by position in the partition list, while the model
        # receives the actual day index.
        one_day = loadedData[partition][idx_in_enum]
        loader = torch.utils.data.DataLoader(
            SpeechDataset([one_day]), batch_size=1, shuffle=False, num_workers=0
        )

        day_edit, day_len = 0, 0

        for j, (X, y, X_len, y_len, _) in enumerate(loader):
            X, y, X_len, y_len = X.to(device), y.to(device), X_len.to(device), y_len.to(device)
            day_tensor = torch.tensor([day_idx], dtype=torch.int64, device=device)

            if ventral_6v_only:
                X = X[:, :, :128]

            with torch.no_grad():
                # Call the module itself (not .forward) so registered hooks run.
                pred = model(X, X_len, day_tensor)[:, :, :num_classes]

            # Output lengths
            if hasattr(model, "compute_length"):
                out_lens = model.compute_length(X_len)
            else:
                # fallback: conv-style
                out_lens = ((X_len - model.kernelLen) / model.strideLen).to(torch.int32)

            # Batch loop (batch_size=1, but keep general)
            for b in range(pred.shape[0]):
                tlen = int(y_len[b].item())
                true_seq = np.array(y[b][:tlen].cpu().numpy())

                logits_b = pred[b].detach().cpu().numpy()
                Lb = int(out_lens[b].item())

                outputs["logits"].append(logits_b)
                outputs["logitLengths"].append(Lb)
                outputs["trueSeqs"].append(true_seq)

                # Greedy CTC decode (blank=0) over the valid frames only
                decoded = _greedy_ctc_decode(pred[b, :Lb, :])

                matcher = SequenceMatcher(
                    a=true_seq.tolist(), b=decoded.tolist()
                )

                ed = matcher.distance()
                total_edit += ed
                total_len += len(true_seq)
                day_edit += ed
                day_len += len(true_seq)

            # normalized transcript (indexed by j: one trial per loader step)
            t = one_day["transcriptions"][j].strip()
            t = re.sub(r"[^a-zA-Z\- \']", "", t).replace("--", "").lower()
            outputs["transcriptions"].append(t)

        if day_len > 0:
            day_cer = day_edit / day_len
            per_day_cer.append(day_cer)
            if verbose:
                print(f"CER DAY {day_idx}: {day_cer:.6f}")

    cer = (total_edit / total_len) if total_len > 0 else float("nan")
    if verbose:
        print("Model performance (CER):", cer)

    return outputs, cer, per_day_cer


In [31]:
# Evaluate the "test" partition on the same `device` the model was moved to,
# so the inputs can never end up on a different device than the model.
outputs, cer, per_day_cer = evaluate_model(
    model,
    loadedData,
    args,
    partition='test',
    device=device,
)

CER DAY 0: 0.373233
CER DAY 1: 0.293769
CER DAY 2: 0.276964
CER DAY 3: 0.276510
CER DAY 4: 0.149733
CER DAY 5: 0.118876
CER DAY 6: 0.148660
CER DAY 7: 0.176110
CER DAY 8: 0.167794
CER DAY 9: 0.212660
CER DAY 10: 0.168666
CER DAY 11: 0.208599
CER DAY 12: 0.185535
CER DAY 13: 0.150725
CER DAY 14: 0.122137
CER DAY 15: 0.129672
CER DAY 16: 0.133152
CER DAY 17: 0.186555
CER DAY 18: 0.153346
CER DAY 19: 0.129005
CER DAY 20: 0.111397
CER DAY 21: 0.143902
CER DAY 22: 0.125197
CER DAY 23: 0.166225
Model performance (CER): 0.17459088509590004


In [32]:
import torch.nn.functional as F

# Stack the per-trial logits into one zero-padded (trials, max_frames, classes) array.
num_classes = 30
max_frames = max(outputs['logitLengths'])
logits = np.zeros((len(outputs['logits']), max_frames, num_classes))
for idx, (trial_logits, l_length) in enumerate(zip(outputs['logits'], outputs['logitLengths'])):
    # Copy only the valid frames: the stored arrays are not trimmed to their
    # reported length, and assigning a longer array would fail to broadcast.
    logits[idx, :l_length, :] = trial_logits[:l_length]

# Padded frames are all-zero logits (uniform after log_softmax); the decoder is
# given the true lengths alongside so it can tell where each trial ends.
logits_torch = torch.from_numpy(logits)
log_probs = F.log_softmax(logits_torch, dim=-1)
log_probs_length = torch.from_numpy(np.array(outputs['logitLengths']))

In [33]:
# Rebind `decoder`: blank at index 0 (the blank id used by the greedy decode in
# evaluate_model), beam size 16, LM fusion alpha 0.25.
decoder = BatchedBeamCTCComputer(
    blank_index=0,
    beam_size=16,
    return_best_hypothesis=False,
    fusion_models=[lm],
    fusion_models_alpha=[0.25],
)


In [34]:
# Run batched beam search over the padded log-probabilities and their valid lengths.
# NOTE(review): the recorded run of this cell failed with
#   RuntimeError: The size of tensor a (29) must match the size of tensor b (26) ...
# log_probs carries 30 classes and the LM was loaded with vocab_size=26, so presumably
# the 29 non-blank classes do not line up with the LM vocabulary (and the decoder may
# expect the blank in a different position) -- confirm the model's character set and
# blank convention against the LM before relying on this call.
transcripts = decoder.batched_beam_search_torch(log_probs, log_probs_length)

RuntimeError: The size of tensor a (29) must match the size of tensor b (26) at non-singleton dimension 2