In [22]:
%load_ext autoreload
%autoreload 2
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
from nemo.collections.asr.parts.submodules.ctc_batched_beam_decoding import BatchedBeamCTCComputer
from inference_functions import load_bit_phoneme_model
from dataset import getDatasetLoaders
from edit_distance import SequenceMatcher

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Location of the .nemo n-gram LM checkpoint restored in the next cell
# (a character 6-gram model, per the file name).
language_model_path = (
    "/workspace/transformers_with_dietcorp/lm/"
    "char_6gram_lm.nemo"
)

In [3]:
# Restore the GPU n-gram LM from its .nemo checkpoint.
# NOTE(review): vocab_size=26 presumably has to match the LM's token inventory — confirm.
lm = NGramGPULanguageModel.from_nemo(lm_path=language_model_path, vocab_size=26)

[NeMo I 2025-08-23 22:02:54 save_restore_connector:282] Model NGramGPULanguageModel was successfully restored from /workspace/transformers_with_dietcorp/lm/char_6gram_lm.nemo.


In [4]:
# Batched beam-search CTC decoder that fuses `lm` into the search.
# NOTE(review): blank_index=26 assumes 27 output classes (26 LM tokens + trailing blank).
# The greedy decoding in evaluate_model treats index 0 as blank, and the saved logits
# have 41 classes, so confirm the label mapping before passing model logits here.
# fusion_models_alpha presumably weights the LM score — confirm against NeMo docs.
decoder = BatchedBeamCTCComputer(
    blank_index=26,
    beam_size=16,
    return_best_hypothesis=False,
    fusion_models=[lm],
    fusion_models_alpha=[0.2],
)


In [10]:
device = 'cuda'  # evaluation device, also passed to evaluate_model
bit_phoneme_filepath = (
    "/data/models/"
    "time_masked_transfomer_characters_80ms_seed_0/"
)
# The loader returns a (model, args) pair; args is later read via args.get(...).
model, args = load_bit_phoneme_model(bit_phoneme_filepath)
model = model.to(device)

In [11]:
data_file = '/data/neural_data/ptDecoder_ctc_both_char'
# NOTE(review): the positional arguments (8, None, False) are forwarded unchanged;
# presumably batch size and loader options — confirm against getDatasetLoaders.
trainLoaders, testLoaders, loadedData = getDatasetLoaders(data_file, 8, None, False)

In [32]:
import re
import numpy as np
import torch
from typing import Dict, Any, List, Tuple
from dataset import SpeechDataset  # adjust if your path differs

def _resolve_day_indices(
    loadedData: Dict[str, List[Dict[str, Any]]], partition: str
) -> List[int]:
    """Return the day ids to evaluate for `partition`, in data order.

    Position i of the returned list is paired with ``loadedData[partition][i]``.
    Raises ValueError for any partition other than "test" or "competition".
    """
    if partition == "competition":
        # Fixed, non-contiguous day ids used for the competition split.
        return [4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 20]
    if partition == "test":
        return list(range(len(loadedData[partition])))
    raise ValueError(f"Unknown partition '{partition}'")


def _greedy_ctc_decode(logits: torch.Tensor, length: int, blank_id: int) -> np.ndarray:
    """Best-path CTC decode of one utterance.

    Takes the per-frame argmax over the first `length` frames of a (frames, classes)
    tensor, collapses consecutive repeats, then drops every `blank_id` label.
    """
    best_path = torch.argmax(logits[:length, :], dim=-1)
    collapsed = torch.unique_consecutive(best_path).cpu().numpy()
    return collapsed[collapsed != blank_id]


def _normalize_transcript(text: str) -> str:
    """Strip, keep only letters / hyphen / space / apostrophe, drop '--', lowercase."""
    cleaned = re.sub(r"[^a-zA-Z\- \']", "", text.strip())
    return cleaned.replace("--", "").lower()


def evaluate_model(
    model: torch.nn.Module,
    loadedData: Dict[str, List[Dict[str, Any]]],
    args: Dict[str, Any],
    partition: str,               # "test" or "competition"
    device: "str | torch.device",
    fill_max_day: bool = False,
    verbose: bool = True,
    blank_id: int = 0,
) -> Tuple[Dict[str, List[Any]], float, List[float]]:
    """Run `model` over every trial of `partition` and compute greedy-CTC CER.

    Args:
        model: called as ``model(X, X_len, day_tensor)``; its output is indexed as
            ``pred[b, :length, :]``, i.e. (batch, frames, classes). If it has
            ``compute_length`` that gives output lengths, otherwise the conv-style
            fallback ``(X_len - kernelLen) / strideLen`` is used.
        loadedData: ``loadedData[partition]`` is a list of per-day dicts, each
            wrapped in ``SpeechDataset`` and carrying a ``"transcriptions"`` list
            aligned with that day's trials.
        args: optional keys ``"restricted_days"`` (only these day ids are
            evaluated) and ``"ventral_6v_only"`` (keep only the first 128 features).
        partition: "test" or "competition".
        device: torch device or device string (e.g. ``"cuda"``).
        fill_max_day: unused; accepted only so existing call sites keep working.
        verbose: print per-day and overall CER.
        blank_id: CTC blank label removed during greedy decoding (default 0).

    Returns:
        ``(outputs, cer, per_day_cer)``.
        ``outputs`` has one entry per trial in each of ``"logits"``,
        ``"logitLengths"``, ``"trueSeqs"`` and ``"transcriptions"``.
        ``cer`` is total edit distance / total target length (NaN if no targets).
        ``per_day_cer`` only contains days with at least one target label, so its
        positions do not necessarily correspond to day ids.

    Raises:
        ValueError: for an unknown partition.
    """
    day_indices = _resolve_day_indices(loadedData, partition)

    restricted_days = set(args.get("restricted_days", []))
    ventral_6v_only = bool(args.get("ventral_6v_only", False))

    outputs: Dict[str, List[Any]] = {
        "logits": [], "logitLengths": [], "trueSeqs": [], "transcriptions": []
    }
    per_day_cer: List[float] = []
    total_edit, total_len = 0, 0

    model.eval()

    with torch.no_grad():
        for idx_in_enum, day_idx in enumerate(day_indices):
            if restricted_days and day_idx not in restricted_days:
                continue

            one_day = loadedData[partition][idx_in_enum]
            loader = torch.utils.data.DataLoader(
                SpeechDataset([one_day]), batch_size=1, shuffle=False, num_workers=0
            )
            batch_size = loader.batch_size

            day_edit, day_len = 0, 0

            for j, (X, y, X_len, y_len, _) in enumerate(loader):
                X, y = X.to(device), y.to(device)
                X_len, y_len = X_len.to(device), y_len.to(device)
                # One day id per batch element (equals torch.tensor([day_idx]) for batch 1).
                day_tensor = torch.full(
                    (X.shape[0],), day_idx, dtype=torch.int64, device=device
                )

                if ventral_6v_only:
                    X = X[:, :, :128]

                # Call the module (not .forward) so registered hooks run.
                pred = model(X, X_len, day_tensor)

                if hasattr(model, "compute_length"):
                    out_lens = model.compute_length(X_len)
                else:
                    # Fallback: conv-style output length.
                    out_lens = ((X_len - model.kernelLen) / model.strideLen).to(torch.int32)

                for b in range(pred.shape[0]):
                    true_len = int(y_len[b].item())
                    true_seq = np.array(y[b][:true_len].cpu().numpy())
                    logit_len = int(out_lens[b].item())

                    outputs["logits"].append(pred[b].detach().cpu().numpy())
                    outputs["logitLengths"].append(logit_len)
                    outputs["trueSeqs"].append(true_seq)
                    # Index of this trial within the day, so transcriptions stay aligned
                    # with logits/trueSeqs for any batch size.
                    trial_idx = j * batch_size + b
                    outputs["transcriptions"].append(
                        _normalize_transcript(one_day["transcriptions"][trial_idx])
                    )

                    decoded = _greedy_ctc_decode(pred[b], logit_len, blank_id)
                    edit = SequenceMatcher(
                        a=true_seq.tolist(), b=decoded.tolist()
                    ).distance()
                    total_edit += edit
                    total_len += len(true_seq)
                    day_edit += edit
                    day_len += len(true_seq)

            if day_len > 0:
                day_cer = day_edit / day_len
                per_day_cer.append(day_cer)
                if verbose:
                    print(f"CER DAY {day_idx}: {day_cer:.6f}")

    cer = (total_edit / total_len) if total_len > 0 else float("nan")
    if verbose:
        print("Model performance (CER):", cer)

    return outputs, cer, per_day_cer


In [33]:
# NOTE(review): the saved output of this cell ends in
# "TypeError: cannot unpack non-iterable int object". The traceback was not kept,
# so the failing line is unknown — re-run with %xmode Verbose to locate it.
outputs, cer, per_day_cer = evaluate_model(
    model,
    loadedData,
    args,
    partition="test",
    device=device,
)

torch.Size([1, 64, 41])


TypeError: cannot unpack non-iterable int object

In [31]:
# Shape of the stored logits for trial 10, presumably (frames, classes).
# NOTE(review): this cell ran (In [31]) before evaluate_model was redefined and called,
# so its output reflects an earlier kernel state.
outputs["logits"][10].shape

(123, 41)