## Multi-task auxiliary objectives update
This notebook has been updated to work with the extended multi-task ASR head. You can enable or disable CTC, seq2seq, frame-level phoneme classifiers, speaker embeddings, and pronunciation error classifiers individually through `Configs/config.yml` under the `multi_task` section. Adjust the corresponding `loss_weights` entries when experimenting with these objectives.
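
As a rough illustration of how the per-head toggles and `loss_weights` might interact, the sketch below combines per-head losses into one scalar. The head names (`ctc`, `s2s`, `phoneme`), the `enabled` flag layout, the default weight of 1.0, and the `combine_losses` helper are all assumptions made for this example. The real schema is whatever `Configs/config.yml` defines.

```python
# Hypothetical sketch: weighted sum of the losses of the enabled heads.
def combine_losses(losses, multi_task_cfg, loss_weights):
    """Return sum(weight * loss) over heads that are enabled in multi_task_cfg."""
    total = 0.0
    for head, loss in losses.items():
        head_cfg = multi_task_cfg.get(head) or {}
        if not head_cfg.get('enabled', False):
            continue  # disabled heads contribute nothing
        # Assumption: a head without an explicit weight counts with weight 1.0.
        total += float(loss_weights.get(head, 1.0)) * float(loss)
    return total


# Illustrative values only; they are not taken from the project config.
example_multi_task = {
    'ctc': {'enabled': True},
    's2s': {'enabled': True},
    'phoneme': {'enabled': False},
}
example_weights = {'ctc': 1.0, 's2s': 0.5, 'phoneme': 0.25}
example_losses = {'ctc': 2.0, 's2s': 4.0, 'phoneme': 8.0}

combined = combine_losses(example_losses, example_multi_task, example_weights)
print(f"combined loss --> {combined}")
```

Under these assumptions, setting a weight to zero and disabling a head have the same effect on the total. Disabling is the cheaper option if the trainer also skips that head's forward pass.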

> **Entropy regularisation**: Use the new `regularization.entropy` section in `Configs/config.yml` to enable or disable per-head entropy penalties. You can independently attach the regulariser to the CTC and seq2seq objectives without breaking joint training.

> **Lazy phoneme dictionaries**: Toggle `phoneme_dictionary.lazy_loading` and `phoneme_dictionary.shared_cache` in `Configs/config.yml` to control when the phoneme map is parsed and whether dataloader workers reuse the cached mapping.
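
The two toggles can be pictured with the minimal sketch below. It is not the project's loader: `LazyPhonemeDict`, `_SHARED_CACHE`, and the `symbol,index` CSV layout are assumptions chosen for illustration. The cache here is only shared within one Python process.

```python
import csv
import io

# Hypothetical sketch of a lazily parsed phoneme map with an optional shared cache.
_SHARED_CACHE = {}  # keyed by source id; reused by every loader in this process


class LazyPhonemeDict:
    def __init__(self, source_id, open_fn, lazy_loading=True, shared_cache=True):
        self.source_id = source_id
        self._open_fn = open_fn        # callable returning a text file-like object
        self._shared_cache = shared_cache
        self._mapping = None
        self.parse_count = 0           # how many times this instance parsed the source
        if not lazy_loading:
            self._load()               # eager mode: parse at construction time

    def _load(self):
        if self._shared_cache and self.source_id in _SHARED_CACHE:
            self._mapping = _SHARED_CACHE[self.source_id]  # reuse the parsed map
            return
        with self._open_fn() as handle:
            self._mapping = {
                row[0]: int(row[1]) for row in csv.reader(handle) if len(row) >= 2
            }
        self.parse_count += 1
        if self._shared_cache:
            _SHARED_CACHE[self.source_id] = self._mapping

    def __getitem__(self, symbol):
        if self._mapping is None:      # lazy mode: parse on first lookup
            self._load()
        return self._mapping[symbol]


def demo_source():
    # Stand-in for opening the phoneme map file.
    return io.StringIO('a,1\nb,2\n')


first = LazyPhonemeDict('demo', demo_source)
second = LazyPhonemeDict('demo', demo_source)
print(first['b'], second['a'], first.parse_count, second.parse_count)
```

In this sketch only the first instance parses the source. The second one picks up the cached mapping on its first lookup.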

In [None]:
# Inspect and optionally warm the metadata cache for notebook experiments
import os
import sys

if not os.path.isdir('Configs'):
    os.chdir('..')

_nb_root_dir = os.getcwd()
if _nb_root_dir not in sys.path:
    sys.path.insert(0, _nb_root_dir)

import yaml
from train import prepare_data_list
from utils import get_data_path_list

with open('Configs/config.yml', 'r', encoding='utf-8') as _cfg_file:
    _nb_root_config = yaml.safe_load(_cfg_file)

_metadata_cache_cfg = (_nb_root_config.get('metadata_cache') or {}) if isinstance(_nb_root_config, dict) else {}
print(f"metadata cache enabled --> {bool(_metadata_cache_cfg.get('enabled', False))}")
print(f"metadata cache directory --> {_metadata_cache_cfg.get('directory', 'Data/cache')}")
_dataset_toggles = (_metadata_cache_cfg.get('datasets') or {}) if isinstance(_metadata_cache_cfg, dict) else {}
_normalised_dataset_toggles = {str(k).lower(): bool(v) for k, v in _dataset_toggles.items()} if isinstance(_dataset_toggles, dict) else {}

def _cache_enabled_for(name: str) -> bool:
    if not bool(_metadata_cache_cfg.get('enabled', False)):
        return False
    if not _normalised_dataset_toggles:
        return True
    return bool(_normalised_dataset_toggles.get(name.lower(), True))

(
    _train_raw,
    _val_raw,
    _train_meta_path,
    _val_meta_path,
) = get_data_path_list(
    _nb_root_config.get('train_data'),
    _nb_root_config.get('val_data'),
    return_paths=True,
)

if _cache_enabled_for('train'):
    _train_entries, _train_durations = prepare_data_list(
        _train_raw,
        root_path='',
        cache_config=_metadata_cache_cfg,
        metadata_path=_train_meta_path,
        dataset_name='train',
    )
    print(f"[cache] training entries available: {len(_train_entries)}")

if _cache_enabled_for('val'):
    _val_entries, _val_durations = prepare_data_list(
        _val_raw,
        root_path='',
        cache_config=_metadata_cache_cfg,
        metadata_path=_val_meta_path,
        dataset_name='val',
    )
    print(f"[cache] validation entries available: {len(_val_entries)}")

_ood_metadata_path = _nb_root_config.get('ood_data')
if _ood_metadata_path and _cache_enabled_for('ood'):
    try:
        with open(_ood_metadata_path, 'r', encoding='utf-8') as _ood_file:
            _ood_raw = _ood_file.readlines()
        _ood_entries, _ood_durations = prepare_data_list(
            _ood_raw,
            root_path='',
            cache_config=_metadata_cache_cfg,
            metadata_path=_ood_metadata_path,
            dataset_name='ood',
        )
        print(f"[cache] OOD entries available: {len(_ood_entries)}")
    except FileNotFoundError:
        print(f"[cache] OOD metadata not found at {_ood_metadata_path!r}")


In [None]:
# Inspect mel cache configuration for notebook experiments
_mel_cache_cfg = (_nb_root_config.get('mel_cache') or {}) if isinstance(_nb_root_config, dict) else {}
print(f"mel cache enabled --> {bool(_mel_cache_cfg.get('enabled', False))}")
print(f"mel cache directory --> {_mel_cache_cfg.get('directory', 'Data/mel_cache')}")
print(f"mel cache dtype --> {_mel_cache_cfg.get('dtype', 'float32')}")
_mel_cache_datasets = (_mel_cache_cfg.get('datasets') or {}) if isinstance(_mel_cache_cfg, dict) else {}
if isinstance(_mel_cache_datasets, dict) and _mel_cache_datasets:
    for _name, _flag in sorted(_mel_cache_datasets.items()):
        print(f"[mel cache] {_name}: {bool(_flag)}")


## Memory optimization settings
Lazy creation of decoder masks can now be toggled through the new `memory_optimizations.lazy_masks` section in `Configs/config.yml`. When enabled (the default), the trainer skips preallocating the `future_mask` and `text_mask` tensors unless an experiment explicitly needs them. Set the corresponding flags to `false` to restore the previous eager allocation behaviour.

Gradient checkpointing for the deeper encoder blocks is configurable through `memory_optimizations.gradient_checkpointing`. Setting `enabled: true` activates `torch.utils.checkpoint.checkpoint_sequential` on the selected layer range so activations are recomputed during backpropagation instead of stored. Tune `start_layer`/`end_layer` to target specific encoder depths, `chunk_size` to limit how many stages are bundled into a checkpoint, and `segments` when you need finer control over how each chunk is split. The helper flags `min_sequence_length` and `use_checkpoint_sequential` let you skip short utterances or fall back to the basic checkpoint API if required.
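
To check these settings from the notebook, a small helper along the following lines may be useful. It is a sketch under stated assumptions: `checkpointing_plan` is a hypothetical name, the fallback defaults are guesses rather than the trainer's defaults, and `min_sequence_length` is interpreted as "checkpoint only when the sequence has at least this many frames".

```python
# Hypothetical helper: summarise the gradient checkpointing settings and decide
# whether an utterance of a given length would be checkpointed.
def checkpointing_plan(config, sequence_length):
    memory_opts = config.get('memory_optimizations') or {}
    gc_cfg = memory_opts.get('gradient_checkpointing') or {}
    enabled = bool(gc_cfg.get('enabled', False))             # assumed default
    min_len = int(gc_cfg.get('min_sequence_length', 0))      # assumed default
    return {
        'active': enabled and sequence_length >= min_len,
        'layers': (gc_cfg.get('start_layer'), gc_cfg.get('end_layer')),
        'chunk_size': gc_cfg.get('chunk_size'),
        'segments': gc_cfg.get('segments'),
        'use_checkpoint_sequential': bool(gc_cfg.get('use_checkpoint_sequential', True)),
    }


# Illustrative values only; they are not taken from the project config.
example_config = {
    'memory_optimizations': {
        'gradient_checkpointing': {
            'enabled': True,
            'start_layer': 2,
            'end_layer': 5,
            'chunk_size': 2,
            'segments': 2,
            'min_sequence_length': 200,
            'use_checkpoint_sequential': True,
        }
    }
}

short_plan = checkpointing_plan(example_config, sequence_length=120)
long_plan = checkpointing_plan(example_config, sequence_length=800)
print(f"short utterance checkpointed --> {short_plan['active']}")
print(f"long utterance checkpointed --> {long_plan['active']}")
```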

### CTC/seq2seq head sharing
Enable `multi_task.head_sharing.ctc_seq2seq.enabled` in `Configs/config.yml` to reuse the intermediate projection computed for the CTC logits when running the seq2seq decoder. With the feature turned on, the encoder exposes the shared tensor via `model_outputs["ctc_seq2seq_shared_states"]` and feeds an adapter into the decoder so that both heads reuse the same computation.

Set the flag back to `false` to restore the previous behaviour. The optional `detach_for_seq2seq` switch stops decoder gradients from flowing through the shared branch if you want to isolate the two objectives.
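
When inspecting the shared tensor in a notebook cell, a guard like the one below avoids a `KeyError` if sharing is switched off. `get_shared_states` is a hypothetical helper, and the nested list stands in for the real tensor. Only the config path and the output key come from the description above.

```python
# Hypothetical helper: return the shared CTC/seq2seq states only when sharing is on.
def get_shared_states(config, model_outputs):
    multi_task = config.get('multi_task') or {}
    head_sharing = multi_task.get('head_sharing') or {}
    sharing = head_sharing.get('ctc_seq2seq') or {}
    if not sharing.get('enabled', False):
        return None  # feature disabled: nothing to inspect
    return model_outputs.get('ctc_seq2seq_shared_states')


# Illustrative stand-ins; a real run would pass the loaded config and model outputs.
example_config = {'multi_task': {'head_sharing': {'ctc_seq2seq': {'enabled': True}}}}
example_outputs = {'ctc_seq2seq_shared_states': [[0.1, 0.2], [0.3, 0.4]]}

shared = get_shared_states(example_config, example_outputs)
print(f"shared states available --> {shared is not None}")
```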


# AuxiliaryASR PER Evaluation

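This notebook loads a trained AuxiliaryASR model, prepares the validation dataset, and computes the phoneme error rate (PER). Decoding uses greedy CTC by default and switches to beam search when `build_beam_search_decoder` returns a decoder for the loaded config.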
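
PER is the total edit distance between hypotheses and references divided by the total number of reference phonemes. The toy example below walks through one utterance. The ids and the blank id of 0 are made up for illustration, and `greedy_collapse` and `levenshtein` are standalone helpers rather than the functions defined later in this notebook.

```python
# Toy illustration of greedy CTC collapse followed by PER; all ids are made up.
def greedy_collapse(frame_ids, blank_id):
    """Drop blanks and merge consecutive repeats, as in greedy CTC decoding."""
    out, prev = [], blank_id
    for p in frame_ids:
        if p != blank_id and p != prev:
            out.append(p)
        prev = p
    return out


def levenshtein(a, b):
    """Edit distance between two sequences, computed with one rolling row."""
    row = list(range(len(b) + 1))
    for i in range(1, len(a) + 1):
        diag, row[0] = row[0], i
        for j in range(1, len(b) + 1):
            cost = a[i - 1] != b[j - 1]
            diag, row[j] = row[j], min(row[j] + 1, row[j - 1] + 1, diag + cost)
    return row[-1]


frame_ids = [0, 3, 3, 0, 3, 5, 5, 0]   # per-frame argmax ids, 0 is the blank
reference = [3, 4, 5, 7]

hypothesis = greedy_collapse(frame_ids, blank_id=0)
errors = levenshtein(hypothesis, reference)
per = errors / max(1, len(reference))
print(f"hypothesis --> {hypothesis}")
print(f"PER --> {per:.2f}")
```

Here the hypothesis collapses to `[3, 3, 5]`. The blank between the two runs of `3` keeps them as separate phonemes, and turning the hypothesis into the reference takes one substitution and one insertion.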

## Augmentation configuration
This notebook now honours the extended SpecAugment policies, waveform perturbations, mixup, and phoneme-level dropout toggles defined in `Configs/config.yml`. Adjust those settings in the config file before running these evaluation cells.


In [None]:
# check available CUDA devices
import torch
devices = []
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        device_name = torch.cuda.get_device_name(i)
        devices.append({
            'type': 'CUDA',
            'available': True,
            'name': device_name,
            'index': i
        })
else:
    devices.append({'type': 'CUDA', 'available': False, 'name': 'N/A'})
devices

In [None]:
# change folder into the root of the ASR project
import os

if not os.path.isdir('Configs'):
    %cd ../

!pwd

## Imports and helper utilities

In [None]:
# import packages, define helper utilities
import os
import yaml
import torch
import pandas as pd
from collections import Counter

from models import ASRCNN
from meldataset import build_dataloader
from utils import BatchSizeScheduler, build_beam_search_decoder
from token_map import build_token_map_from_data
from text_utils import TextCleaner

def cfg_get_nested(cfg: dict, path, default=None, sep='.'):
    """Get a nested value from a dict using a list of keys or a dot-separated string."""
    if isinstance(path, str):
        keys = path.split(sep) if path else []
    else:
        keys = path

    cur = cfg
    for k in keys:
        if isinstance(cur, dict) and k in cur:
            cur = cur[k]
        else:
            return default
    return cur

def load_token_map_from_config(config):
    token_src = config.get('phoneme_maps_path')
    if not token_src:
        return build_token_map_from_data(
            config.get('train_data'),
            config.get('val_data'),
            config.get('ood_data'),
            apply_asr_tokenizer=True,
        )
    if isinstance(token_src, dict):
        return token_src
    # Each row of the headerless CSV is expected to hold a (symbol, index) pair.
    rows = pd.read_csv(token_src, header=None).values
    return {symbol: index for symbol, index in rows}

def load_asr_model(model_path, config_path, device):
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    token_map = load_token_map_from_config(config)

    model_params = cfg_get_nested(config, 'model_params', {
        'input_dim': 80,
        'hidden_dim': 256,
        'n_token': len(token_map),
        'token_embedding_dim': 512,
        'n_layers': 5,
        'location_kernel_size': 31,
    })
    if 'n_token' not in model_params:
        model_params['n_token'] = len(token_map)

    model_params.setdefault('stabilization_config', cfg_get_nested(config, 'stabilization', {}))

    model = ASRCNN(**model_params)
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    state_dict = checkpoint.get('model', checkpoint)
    try:
        model.load_state_dict(state_dict)
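    except RuntimeError:
        # Checkpoints saved from a DataParallel-wrapped model prefix every key with 'module.';
        # strip it only when it is a prefix so that inner names containing 'module.' stay intact.
        sanitized_state = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }
        model.load_state_dict(sanitized_state)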

    model.to(device)
    model.eval()
    return model, config, token_map

def build_dev_dataloader(config, device, batch_size=None, num_workers=None):
    val_csv_path = config.get('val_data')
    if val_csv_path is None:
        raise ValueError("Validation CSV path ('val_data') not found in the config.")

    with open(val_csv_path, 'r', encoding='utf-8') as f:
        raw_lines = [line.rstrip('\n') for line in f]

    path_list = []
    for raw in raw_lines:
        if not raw.strip():
            continue
        parts = raw.split('|')
        if len(parts) == 1:
            continue
        path = parts[0]
        if len(parts) == 2:
            text = parts[1]
            speaker = ''
        else:
            text = '|'.join(parts[1:-1])
            speaker = parts[-1]
        path_list.append([path, text, speaker])

    if batch_size is None:
        base_batch_size = int(cfg_get_nested(config, 'eval_params.batch_size', cfg_get_nested(config, 'batch_size', 4)))
        curriculum_cfg = cfg_get_nested(config, 'training_curriculum.batch_size_schedule', {}) or {}
        scheduler = BatchSizeScheduler(curriculum_cfg, default_batch_size=base_batch_size, total_epochs=int(cfg_get_nested(config, 'epochs', 1)))
        if scheduler.enabled:
            eval_epoch = cfg_get_nested(curriculum_cfg, 'evaluation_epoch', None)
            if eval_epoch is None:
                eval_epoch = scheduler.evaluation_epoch if scheduler.evaluation_epoch is not None else scheduler.total_epochs
            eval_epoch = max(1, min(int(eval_epoch), scheduler.total_epochs))
            batch_size = scheduler.batch_size_for_epoch(eval_epoch)
        else:
            batch_size = base_batch_size
    if num_workers is None:
        num_workers = int(cfg_get_nested(config, 'dataloader_params.val_num_workers', 2))

    dataset_params = resolve_dataset_params(config)

    collate_config = {'return_wave': False}
    device_flag = device.type if isinstance(device, torch.device) else str(device)
    loader = build_dataloader(
        path_list=path_list,
        validation=True,
        batch_size=batch_size,
        num_workers=num_workers,
        device=device_flag,
        collate_config=collate_config,
        dataset_config=dataset_params,
        dataset_name="val",
    )
    return loader

if 'config' in globals() and isinstance(config, dict):
    _mp_base_config = config
else:
    with open('Configs/config.yml', 'r', encoding='utf-8') as _f:
        _mp_base_config = yaml.safe_load(_f)

memory_opts = _mp_base_config.get("memory_optimizations", {}) if isinstance(_mp_base_config, dict) else {}
lazy_mask_cfg = memory_opts.get("lazy_masks", {}) if isinstance(memory_opts, dict) else {}
lazy_enabled = bool(lazy_mask_cfg.get("enabled", True))
print(f"lazy mask creation enabled --> {lazy_enabled}")
print(f"skip future mask allocation --> {bool(lazy_mask_cfg.get('future_mask', True))}")
print(f"skip text mask allocation --> {bool(lazy_mask_cfg.get('text_mask', True))}")


In [None]:
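def resolve_dataset_params(config, base_overrides=None):
    """Build the dataset kwargs from the config, then apply `dataset_params` and `base_overrides`."""
    # 'preprocess_parasm' is a misspelled key that is still accepted as a fallback
    # whenever 'preprocess_params' is missing from the config.
    dataset_params = {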
        'dict_path': cfg_get_nested(config, 'phoneme_maps_path', 'Data/word_index_dict.txt'),
        'sr': cfg_get_nested(config, 'preprocess_params.sr', cfg_get_nested(config, 'preprocess_parasm.sr', 24000)),
        'spect_params': cfg_get_nested(
            config,
            'preprocess_params.spect_params',
            cfg_get_nested(config, 'preprocess_parasm.spect_params', {'n_fft': 1024, 'win_length': 1024, 'hop_length': 300}),
        ),
        'mel_params': cfg_get_nested(
            config,
            'preprocess_params.mel_params',
            cfg_get_nested(config, 'preprocess_parasm.mel_params', {'n_mels': 80}),
        ),
    }
    dataset_params['mel_cache'] = cfg_get_nested(config, 'mel_cache', {}) or {}
    dataset_params['phoneme_dictionary_config'] = cfg_get_nested(config, 'phoneme_dictionary', {}) or {}
    dataset_overrides = cfg_get_nested(config, 'dataset_params', {})
    if isinstance(dataset_overrides, dict):
        for key in ('dict_path', 'sr', 'spect_params', 'mel_params', 'phoneme_dictionary_config'):
            if key in dataset_overrides:
                dataset_params[key] = dataset_overrides[key]
        if 'spec_augment' in dataset_overrides:
            dataset_params['spec_augment_params'] = dataset_overrides['spec_augment']
        for key, value in dataset_overrides.items():
            if key in ('dict_path', 'sr', 'spect_params', 'mel_params', 'phoneme_dictionary_config', 'spec_augment'):
                continue
            dataset_params[key] = value
    if base_overrides:
        dataset_params.update(base_overrides)
    return dataset_params



## Load model, configuration, and validation loader

In [None]:
checkpoint_dir = 'Checkpoint'
config_path = 'Checkpoint/config.yml'

if not os.path.isdir(checkpoint_dir):
    raise FileNotFoundError(f"Checkpoint directory '{checkpoint_dir}' not found.")

ckpt_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('epoch_') and f.endswith('.pth')]
if not ckpt_files:
    raise FileNotFoundError(f"No checkpoint files found in '{checkpoint_dir}'.")

ckpt_files = sorted(ckpt_files, key=lambda x: int(x.split('_')[-1].split('.')[0]))
model_path = os.path.join(checkpoint_dir, ckpt_files[-1])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, config, token_map = load_asr_model(model_path, config_path, device)

print(f'model --> {model_path}')
print(f'config --> {config_path}')
# An empty or null 'phoneme_maps_path' also means the map was built from the dataset.
phoneme_source = config.get('phoneme_maps_path') or 'built from dataset'
print(f'dictionary --> {phoneme_source}')

dev_loader = build_dev_dataloader(config, device)
print(f'Validation dataset size: {len(dev_loader.dataset)} samples')

vocab = token_map
if ' ' not in vocab:
    raise KeyError("The vocabulary does not contain the blank symbol ' '.")
BLANK_ID = vocab[' ']
ID2PH = {idx: symbol for symbol, idx in vocab.items()}
print(f'Blank token id: {BLANK_ID}')


decoder = build_beam_search_decoder(config, vocab_size=len(vocab))
if decoder is None:
    print('Decoding strategy: CTC greedy')
else:
    fusion_flags = []
    if decoder.shallow_fusion_lm is not None:
        fusion_flags.append('shallow fusion LM')
    if decoder.cold_fusion_lm is not None:
        fusion_flags.append('cold fusion LM')
    fusion_desc = ' with ' + ' and '.join(fusion_flags) if fusion_flags else ''
    print(f'Decoding strategy: Beam search (beam={decoder.beam_width}){fusion_desc}')


## Greedy CTC decoding and PER computation

In [None]:
# decoding utilities and PER evaluation
from typing import List, Sequence, Optional


def ctc_decode(logits: torch.Tensor, lens: torch.Tensor, decoder: Optional[object] = None) -> List[List[int]]:
    """Decode logits either with beam search or greedy collapse."""
    if logits.dim() != 3:
        raise ValueError(f"Expected logits of shape (B, T, V), got {tuple(logits.shape)}")

    if decoder is None:
        pred_ids = logits.argmax(-1)
        hyps: List[List[int]] = []
        for b in range(pred_ids.size(0)):
            prev = BLANK_ID
            out: List[int] = []
            T = int(lens[b])
            for t in range(T):
                p = int(pred_ids[b, t])
                if p != BLANK_ID and p != prev:
                    out.append(p)
                prev = p
            hyps.append(out)
        return hyps

    return decoder.decode(logits, lens)


def edit_distance(a: Sequence[int], b: Sequence[int]) -> int:
    """Levenshtein distance between two id sequences (unit cost per insert, delete, substitute)."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i in range(len(a) + 1):
        dp[i][0] = i
    for j in range(len(b) + 1):
        dp[0][j] = j
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            dp[i][j] = min(
                dp[i - 1][j] + 1,
                dp[i][j - 1] + 1,
                dp[i - 1][j - 1] + (a[i - 1] != b[j - 1]),
            )
    return dp[-1][-1]


@torch.no_grad()
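def eval_per(model: torch.nn.Module, dev_loader, device=None, max_examples: int = 5, decoder: Optional[object] = None):
    """Compute PER over `dev_loader`.

    Returns `(per, stats)`, where `stats` holds the total edit distance, the total
    number of reference phonemes, reference phoneme frequencies, and up to
    `max_examples` prediction/reference pairs.
    """
    model.eval()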
    if device is None:
        device = next(model.parameters()).device

    tot_err, tot_len = 0, 0
    phoneme_freq: Counter = Counter()
    examples = []
    downsample_factor = 2 ** getattr(model, 'n_down', 1)

    for batch in dev_loader:
        texts, text_lens, mels, mel_lens = batch[:4]
        mels = mels.to(device)
        text_lens = text_lens.to(torch.long)
        mel_lens = mel_lens.to(torch.long)

        outputs = model(mels)
        if isinstance(outputs, dict):
            logits = outputs.get('logits_ctc')
            if logits is None:
                logits = outputs.get('ctc_logits')
            if logits is None:
                logits = outputs.get('primary_logits')
            if logits is None:
                raise KeyError("Model output dict does not contain CTC logits.")
        elif isinstance(outputs, (tuple, list)):
            logits = outputs[0]
        else:
            logits = outputs

        if logits.dim() != 3:
            raise ValueError(f"Unexpected logits shape: {tuple(logits.shape)}")

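        # Convert mel lengths to logit frames, assuming each of the model's `n_down`
        # stages halves the time axis; clamp to the valid range of the logits.
        logit_lens = torch.clamp(mel_lens // downsample_factor, min=1, max=logits.size(1))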
        hyps = ctc_decode(logits.cpu(), logit_lens.cpu(), decoder)

        for hyp, tgt, tgt_len in zip(hyps, texts, text_lens):
            effective_len = int(tgt_len)
            tgt_ids = tgt[:effective_len].tolist()
            tgt_ids = [idx for idx in tgt_ids if idx != BLANK_ID]

            phoneme_freq.update(tgt_ids)
            tot_err += edit_distance(hyp, tgt_ids)
            tot_len += len(tgt_ids)

            if len(examples) < max_examples:
                examples.append({
                    'prediction': hyp,
                    'reference': tgt_ids,
                })

    per = tot_err / max(1, tot_len)
    stats = {
        'total_errors': tot_err,
        'total_phonemes': tot_len,
        'phoneme_frequency': phoneme_freq,
        'examples': examples,
    }
    return per, stats


## Run evaluation

In [None]:
decode_mode = 'beam search' if decoder is not None else 'CTC greedy'
per, per_stats = eval_per(model, dev_loader, device=device, max_examples=5, decoder=decoder)
print(f'Dev PER ({decode_mode}): {per:.3f}')
print(f'Total phonemes evaluated: {per_stats["total_phonemes"]}')
print(f'Total edit distance: {per_stats["total_errors"]}')


## Inspect sample predictions

In [None]:
def ids_to_symbols(ids):
    return ' '.join(ID2PH.get(i, f"<unk:{i}>") for i in ids)

for idx, example in enumerate(per_stats["examples"], 1):
    print(f'Sample {idx}')
    print('Reference :', ids_to_symbols(example['reference']))
    print('Prediction:', ids_to_symbols(example['prediction']))
    print('-' * 60)
    if idx >= 5:
        break


In [None]:
print('Most common phonemes in the validation references:')
for phoneme_id, count in per_stats["phoneme_frequency"].most_common(10):
    symbol = ID2PH.get(phoneme_id, phoneme_id)
    print(f'{symbol}: {count}')


## Self-Conditioned CTC Support

This notebook has been updated to work with the optional self-conditioned CTC feature. To experiment with it, enable the feature in `Configs/config.yml` by setting:

```yaml
stabilization:
  self_conditioned_ctc:
    enabled: true
    layers:
      - index: 2
      - index: 4
    conditioning_strategy: add  # or concat
    detach_conditioning: true
loss_weights:
  self_conditioned_ctc: 0.2
```

With the feature enabled, the trainer exposes `self_conditioned_ctc_logits` and `self_conditioned_ctc_log_probs` in the model outputs so that they can be visualised or scored inside the notebook just like the existing CTC signals.

The per-head entropy regulariser is configured in the same file under `regularization.entropy`, for example:

```yaml
regularization:
  entropy:
    enabled: false
    mode: minimize  # set to "maximize" to encourage smoother distributions
    eps: 1.0e-6
    targets:
      ctc:
        enabled: true
        weight: 0.0
        length_normalize: true
      s2s:
        enabled: true
        weight: 0.0
        length_normalize: true
```
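
To make the fields in the `regularization.entropy` block concrete, here is a minimal sketch of an entropy penalty over per-frame probability distributions. It uses plain Python lists instead of tensors, and the way `eps`, `mode`, `weight`, and `length_normalize` are applied is an assumption about their meaning, not the trainer's actual implementation.

```python
import math


# Hypothetical sketch of a per-head entropy penalty.
def entropy_penalty(frame_probs, mode='minimize', eps=1.0e-6, weight=1.0, length_normalize=True):
    """frame_probs: list of per-frame probability lists, each summing to 1."""
    total = 0.0
    for probs in frame_probs:
        # Shannon entropy of one frame; eps guards against log(0).
        total += -sum(p * math.log(p + eps) for p in probs)
    if length_normalize and frame_probs:
        total /= len(frame_probs)  # average over frames instead of summing
    # Assumption: "minimize" adds the entropy to the loss, "maximize" subtracts it.
    sign = 1.0 if mode == 'minimize' else -1.0
    return weight * sign * total


uniform_frames = [[0.25, 0.25, 0.25, 0.25]] * 2   # maximally uncertain frames
peaked_frames = [[0.97, 0.01, 0.01, 0.01]] * 2    # confident frames

uniform_penalty = entropy_penalty(uniform_frames)
peaked_penalty = entropy_penalty(peaked_frames)
print(f"uniform penalty --> {uniform_penalty:.3f}")
print(f"peaked penalty --> {peaked_penalty:.3f}")
```

With `weight: 0.0`, as in the YAML above, the penalty is zero for that head even when its `enabled` flag is true under this reading.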

In [None]:
import yaml

def _resolve_mixed_precision_from_config(_cfg):
    if 'cfg_get_nested' in globals():
        return cfg_get_nested(_cfg, 'precision.mixed_precision', {})
    precision_cfg = _cfg.get('precision', {}) if isinstance(_cfg, dict) else {}
    return precision_cfg.get('mixed_precision', {}) if isinstance(precision_cfg, dict) else {}

if 'config' in globals():
    _mp_root_config = config
else:
    with open('Configs/config.yml', 'r', encoding='utf-8') as _f:
        _mp_root_config = yaml.safe_load(_f)

mixed_precision_cfg = _resolve_mixed_precision_from_config(_mp_root_config) or {}
grad_scaler_cfg = mixed_precision_cfg.get('grad_scaler', {}) if isinstance(mixed_precision_cfg, dict) else {}
print(f"mixed precision enabled --> {bool(mixed_precision_cfg.get('enabled', False))}")
print(f"mixed precision dtype --> {mixed_precision_cfg.get('dtype', 'float16')}")
print(f"grad scaler enabled --> {bool(grad_scaler_cfg.get('enabled', True))}")
