## Multi-task auxiliary objectives update
This notebook has been updated to work with the extended multi-task ASR head. You can enable or disable CTC, seq2seq, frame-level phoneme classifiers, speaker embeddings, and pronunciation error classifiers individually through `Configs/config.yml` under the `multi_task` section. Adjust the corresponding `loss_weights` entries when experimenting with these objectives.
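
As a rough illustration of how the per-head toggles and `loss_weights` interact, the sketch below sums only the enabled objectives, each scaled by its weight. The key names (`ctc`, `seq2seq`, `frame_phoneme`) and the helper `combine_losses` are hypothetical; the actual layout of `Configs/config.yml` and the trainer's combination logic may differ.

```python
# Hypothetical config layout: the real key names in Configs/config.yml may differ.
config = {
    "multi_task": {
        "ctc": {"enabled": True},
        "seq2seq": {"enabled": True},
        "frame_phoneme": {"enabled": False},
    },
    "loss_weights": {"ctc": 1.0, "seq2seq": 0.5, "frame_phoneme": 0.3},
}


def combine_losses(per_head_losses, config):
    """Sum the losses of enabled heads, each scaled by its loss weight."""
    heads = config.get("multi_task") or {}
    weights = config.get("loss_weights") or {}
    total = 0.0
    for name, loss in per_head_losses.items():
        head_cfg = heads.get(name) or {}
        if not head_cfg.get("enabled", False):
            continue  # disabled heads contribute nothing
        total += float(weights.get(name, 1.0)) * loss
    return total


losses = {"ctc": 2.0, "seq2seq": 4.0, "frame_phoneme": 10.0}
total = combine_losses(losses, config)
print(total)  # 1.0 * 2.0 + 0.5 * 4.0; the disabled head is ignored
```

Under this reading, disabling a head removes its term from the total regardless of the weight left in `loss_weights`.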

> **Entropy regularisation**: Use the new `regularization.entropy` section in `Configs/config.yml` to enable or disable per-head entropy penalties. You can independently attach the regulariser to the CTC and seq2seq objectives without breaking joint training.

> **Lazy phoneme dictionaries**: Toggle `phoneme_dictionary.lazy_loading` and `phoneme_dictionary.shared_cache` in `Configs/config.yml` to control when the phoneme map is parsed and whether dataloader workers reuse the cached mapping.
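
The lazy-loading behaviour can be pictured with the minimal sketch below, which parses the phoneme map on the first lookup and caches the result. The class `PhonemeDictionary`, the function `load_phoneme_map`, and the whitespace-separated `symbol index` file format are assumptions made for illustration, not the project's implementation.

```python
import functools
import os
import tempfile


@functools.lru_cache(maxsize=None)
def load_phoneme_map(path):
    """Parse a `symbol index` file once per process; later calls hit the cache."""
    mapping = {}
    with open(path, "r", encoding="utf-8") as handle:
        for line in handle:
            parts = line.split()
            if len(parts) != 2:
                continue  # skip blank or malformed lines
            symbol, index = parts
            mapping[symbol] = int(index)
    return mapping


class PhonemeDictionary:
    def __init__(self, path, lazy_loading=True):
        self.path = path
        # Eager mode parses immediately; lazy mode waits for the first lookup.
        if not lazy_loading:
            load_phoneme_map(self.path)

    def __getitem__(self, symbol):
        return load_phoneme_map(self.path)[symbol]


with tempfile.TemporaryDirectory() as tmp_dir:
    dict_path = os.path.join(tmp_dir, "phonemes.txt")
    with open(dict_path, "w", encoding="utf-8") as handle:
        handle.write("a 0\nb 1\n")

    phonemes = PhonemeDictionary(dict_path, lazy_loading=True)
    misses_before = load_phoneme_map.cache_info().misses  # nothing parsed yet
    first = phonemes["b"]   # first lookup triggers the parse
    second = phonemes["a"]  # served from the cached mapping
    misses_after = load_phoneme_map.cache_info().misses
```

Note that `functools.lru_cache` is per-process, so this sketch does not model how `shared_cache` makes the mapping available across dataloader workers.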

In [None]:
# Inspect and optionally warm the metadata cache for notebook experiments
import os
import sys

if not os.path.isdir('Configs'):
    os.chdir('..')

_nb_root_dir = os.getcwd()
if _nb_root_dir not in sys.path:
    sys.path.insert(0, _nb_root_dir)

import yaml
from train import prepare_data_list
from models import ASRCNN
from utils import get_data_path_list, load_asr_model_from_config, build_dev_dataloader_from_config

with open('Configs/config.yml', 'r', encoding='utf-8') as _cfg_file:
    _nb_root_config = yaml.safe_load(_cfg_file)

_metadata_cache_cfg = (_nb_root_config.get('metadata_cache') or {}) if isinstance(_nb_root_config, dict) else {}
print(f"metadata cache enabled --> {bool(_metadata_cache_cfg.get('enabled', False))}")
print(f"metadata cache directory --> {_metadata_cache_cfg.get('directory', 'Data/cache')}")
_dataset_toggles = (_metadata_cache_cfg.get('datasets') or {}) if isinstance(_metadata_cache_cfg, dict) else {}
_normalised_dataset_toggles = {str(k).lower(): bool(v) for k, v in _dataset_toggles.items()} if isinstance(_dataset_toggles, dict) else {}

def _cache_enabled_for(name: str) -> bool:
    if not bool(_metadata_cache_cfg.get('enabled', False)):
        return False
    if not _normalised_dataset_toggles:
        return True
    return bool(_normalised_dataset_toggles.get(name.lower(), True))

(
    _train_raw,
    _val_raw,
    _train_meta_path,
    _val_meta_path,
) = get_data_path_list(
    _nb_root_config.get('train_data'),
    _nb_root_config.get('val_data'),
    return_paths=True,
)

if _cache_enabled_for('train'):
    _train_entries, _train_durations = prepare_data_list(
        _train_raw,
        root_path='',
        cache_config=_metadata_cache_cfg,
        metadata_path=_train_meta_path,
        dataset_name='train',
    )
    print(f"[cache] training entries available: {len(_train_entries)}")

if _cache_enabled_for('val'):
    _val_entries, _val_durations = prepare_data_list(
        _val_raw,
        root_path='',
        cache_config=_metadata_cache_cfg,
        metadata_path=_val_meta_path,
        dataset_name='val',
    )
    print(f"[cache] validation entries available: {len(_val_entries)}")

_ood_metadata_path = _nb_root_config.get('ood_data')
if _ood_metadata_path and _cache_enabled_for('ood'):
    try:
        with open(_ood_metadata_path, 'r', encoding='utf-8') as _ood_file:
            _ood_raw = _ood_file.readlines()
        _ood_entries, _ood_durations = prepare_data_list(
            _ood_raw,
            root_path='',
            cache_config=_metadata_cache_cfg,
            metadata_path=_ood_metadata_path,
            dataset_name='ood',
        )
        print(f"[cache] OOD entries available: {len(_ood_entries)}")
    except FileNotFoundError:
        print(f"[cache] OOD metadata not found at {_ood_metadata_path!r}")


In [None]:
# Inspect mel cache configuration for notebook experiments
_mel_cache_cfg = (_nb_root_config.get('mel_cache') or {}) if isinstance(_nb_root_config, dict) else {}
print(f"mel cache enabled --> {bool(_mel_cache_cfg.get('enabled', False))}")
print(f"mel cache directory --> {_mel_cache_cfg.get('directory', 'Data/mel_cache')}")
print(f"mel cache dtype --> {_mel_cache_cfg.get('dtype', 'float32')}")
_mel_cache_datasets = (_mel_cache_cfg.get('datasets') or {}) if isinstance(_mel_cache_cfg, dict) else {}
if isinstance(_mel_cache_datasets, dict) and _mel_cache_datasets:
    for _name, _flag in sorted(_mel_cache_datasets.items()):
        print(f"[mel cache] {_name}: {bool(_flag)}")


## Memory optimization settings
Lazy creation of decoder masks can now be toggled through the new `memory_optimizations.lazy_masks` section in `Configs/config.yml`. When enabled (the default), the trainer skips preallocating the `future_mask` and `text_mask` tensors unless an experiment explicitly needs them. Set the corresponding flags to `false` to restore the previous eager allocation behaviour.
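
The idea behind lazy mask creation can be sketched in plain Python as follows. The class `DecoderMasks` is hypothetical, and the convention that `True` marks a blocked position is an assumption; the trainer's tensors may use a different layout.

```python
from functools import cached_property


class DecoderMasks:
    """Builds decoder masks on first access instead of at construction time."""

    def __init__(self, text_lens):
        self.text_lens = list(text_lens)
        self.max_len = max(self.text_lens)
        self.built = []  # records which masks were actually materialised

    @cached_property
    def future_mask(self):
        # True marks positions a decoder step must not attend to (j > i).
        self.built.append("future_mask")
        return [[j > i for j in range(self.max_len)] for i in range(self.max_len)]

    @cached_property
    def text_mask(self):
        # True marks padding positions beyond each sequence's length.
        self.built.append("text_mask")
        return [[j >= n for j in range(self.max_len)] for n in self.text_lens]


masks = DecoderMasks(text_lens=[3, 2])
built_before = list(masks.built)  # nothing allocated yet
padding = masks.text_mask         # built now, then cached on the instance
padding_again = masks.text_mask   # no rebuild
```

In this sketch `future_mask` is never allocated because nothing asks for it, which is the saving the lazy setting aims for.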

Gradient checkpointing for the deeper encoder blocks is configurable through `memory_optimizations.gradient_checkpointing`. Setting `enabled: true` activates `torch.utils.checkpoint.checkpoint_sequential` on the selected layer range so activations are recomputed during backpropagation instead of stored. Tune `start_layer`/`end_layer` to target specific encoder depths, `chunk_size` to limit how many stages are bundled into a checkpoint, and `segments` when you need finer control over how each chunk is split. The helper flags `min_sequence_length` and `use_checkpoint_sequential` let you skip short utterances or fall back to the basic checkpoint API if required.
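
One plausible reading of how these options partition the encoder is sketched below. The helpers `plan_checkpoint_chunks` and `should_checkpoint` are hypothetical, and the sketch assumes `start_layer` is inclusive and `end_layer` is exclusive, which the config may define differently. In the real trainer each chunk would presumably be passed to `torch.utils.checkpoint.checkpoint_sequential(functions, segments, input)`.

```python
def plan_checkpoint_chunks(num_layers, start_layer=0, end_layer=None, chunk_size=None):
    """Return (start, stop) index pairs, one per checkpointed chunk of layers."""
    end = num_layers if end_layer is None else min(end_layer, num_layers)
    start = max(0, start_layer)
    if start >= end:
        return []
    # Without a chunk size, the whole selected range forms a single chunk.
    size = (end - start) if not chunk_size else max(1, chunk_size)
    return [(lo, min(lo + size, end)) for lo in range(start, end, size)]


def should_checkpoint(sequence_length, min_sequence_length=0):
    """Short utterances store few activations, so recomputation may not pay off."""
    return sequence_length >= min_sequence_length


chunks = plan_checkpoint_chunks(num_layers=8, start_layer=2, end_layer=8, chunk_size=4)
print(chunks)  # [(2, 6), (6, 8)]
```

Smaller chunks lower peak activation memory at the cost of more recomputation during the backward pass.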

### CTC/seq2seq head sharing
Enable `multi_task.head_sharing.ctc_seq2seq.enabled` in `Configs/config.yml` to reuse the intermediate projection computed for the CTC logits when running the seq2seq decoder. With the feature turned on, the encoder exposes the shared tensor via `model_outputs["ctc_seq2seq_shared_states"]` and passes it to the decoder through an adapter, so both heads reuse the same computation.

Set the flag back to `false` to restore the previous behaviour. The optional `detach_for_seq2seq` switch stops decoder gradients from flowing through the shared branch if you want to isolate the two objectives.


# Auxiliary ASR Diagonal Attention Evaluation

## Augmentation configuration
This notebook now honours the extended SpecAugment policies, waveform perturbations, mixup, and phoneme-level dropout toggles defined in `Configs/config.yml`. Adjust those settings in the config file before running these evaluation cells.


In [None]:

# check available CUDA devices
import torch
devices = []
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        device_name = torch.cuda.get_device_name(i)
        devices.append({
            'type': 'CUDA',
            'available': True,
            'name': device_name,
            'index': i
        })
else:
    devices.append({'type': 'CUDA', 'available': False, 'name': 'N/A'})
devices


In [None]:

# change folder into the root of the ASR project
import os

if not os.path.isdir('Configs'):
    %cd ../

!pwd


In [None]:
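# Fallback in case `cfg_get_nested` is not already defined in this session.
# Assumption: the project's own helper resolves dotted keys the same way.
if 'cfg_get_nested' not in globals():
    def cfg_get_nested(config, dotted_key, default=None):
        node = config
        for part in dotted_key.split('.'):
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node


def resolve_dataset_params(config, base_overrides=None):
    # `preprocess_parasm` is checked as a fallback for configs that spell the section that way.
    dataset_params = {
        'dict_path': cfg_get_nested(config, 'phoneme_maps_path', 'Data/word_index_dict.txt'),
        'sr': cfg_get_nested(config, 'preprocess_params.sr', cfg_get_nested(config, 'preprocess_parasm.sr', 24000)),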
        'spect_params': cfg_get_nested(
            config,
            'preprocess_params.spect_params',
            cfg_get_nested(config, 'preprocess_parasm.spect_params', {'n_fft': 1024, 'win_length': 1024, 'hop_length': 300}),
        ),
        'mel_params': cfg_get_nested(
            config,
            'preprocess_params.mel_params',
            cfg_get_nested(config, 'preprocess_parasm.mel_params', {'n_mels': 80}),
        ),
    }
    dataset_params['mel_cache'] = cfg_get_nested(config, 'mel_cache', {}) or {}
    dataset_params['phoneme_dictionary_config'] = cfg_get_nested(config, 'phoneme_dictionary', {}) or {}
    dataset_overrides = cfg_get_nested(config, 'dataset_params', {})
    if isinstance(dataset_overrides, dict):
        for key in ('dict_path', 'sr', 'spect_params', 'mel_params', 'phoneme_dictionary_config'):
            if key in dataset_overrides:
                dataset_params[key] = dataset_overrides[key]
        if 'spec_augment' in dataset_overrides:
            dataset_params['spec_augment_params'] = dataset_overrides['spec_augment']
        for key, value in dataset_overrides.items():
            if key in ('dict_path', 'sr', 'spect_params', 'mel_params', 'phoneme_dictionary_config', 'spec_augment'):
                continue
            dataset_params[key] = value
    if base_overrides:
        dataset_params.update(base_overrides)
    return dataset_params



In [None]:
# load model, config, and validation dataloader
checkpoint_dir = 'Checkpoint'
config_path = 'Checkpoint/config.yml'

if not os.path.isdir(checkpoint_dir):
    raise FileNotFoundError(f"Checkpoint directory '{checkpoint_dir}' not found.")

ckpt_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('epoch_') and f.endswith('.pth')]
if not ckpt_files:
    raise FileNotFoundError(f"No checkpoint files found in '{checkpoint_dir}'.")

ckpt_files = sorted(ckpt_files, key=lambda x: int(x.split('_')[-1].split('.')[0]))
model_path = os.path.join(checkpoint_dir, ckpt_files[-1])

with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print(f'model --> {model_path}')
print(f'config --> {config_path}')
phoneme_source = config.get('phoneme_maps_path', 'built from dataset')
print(f'dictionary --> {phoneme_source}')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, token_map = load_asr_model_from_config(config, model_path, device)

dev_loader, val_entries = build_dev_dataloader_from_config(config, device)
print(f'Validation dataset size: {len(val_entries)} samples')
print(f'Batch size --> {dev_loader.batch_size}')

if 'config' in globals() and isinstance(config, dict):
    _mp_base_config = config
else:
    with open('Configs/config.yml') as _f:
        _mp_base_config = yaml.safe_load(_f)

# `or {}` guards against sections that are present in the YAML but left empty (None).
memory_opts = (_mp_base_config.get("memory_optimizations") or {}) if isinstance(_mp_base_config, dict) else {}
lazy_mask_cfg = (memory_opts.get("lazy_masks") or {}) if isinstance(memory_opts, dict) else {}
lazy_enabled = bool(lazy_mask_cfg.get("enabled", True))
print(f"lazy mask creation enabled --> {lazy_enabled}")
print(f"skip future mask allocation --> {bool(lazy_mask_cfg.get('future_mask', True))}")
print(f"skip text mask allocation --> {bool(lazy_mask_cfg.get('text_mask', True))}")


In [None]:
# attention alignment diagnostics
import numpy as np
from typing import List, Optional

@torch.no_grad()
def diagonal_attention_score(model: ASRCNN, dev_loader, device: Optional[torch.device] = None, band: float = 0.1, max_batches: Optional[int] = None):
    """Compute the average diagonal attention concentration score.

    Args:
        model: ASR model returning alignment matrices when teacher-forced.
        dev_loader: validation dataloader yielding (texts, text_lens, mels, mel_lens).
        device: device for computation; defaults to model parameters' device.
        band: allowable deviation from the diagonal (0-1 range).
        max_batches: limit the number of batches to evaluate (useful for quick checks).

    Returns:
        mean_score: average ratio of attention mass inside the diagonal band.
        scores: list of per-utterance scores.
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    diag_scores: List[float] = []
    downsample = 2 ** getattr(model, 'n_down', 1)

    for batch_idx, batch in enumerate(dev_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break

        texts, text_lens, mels, mel_lens = batch[:4]
        mels = mels.to(device)
        texts = texts.to(device)
        text_lens = text_lens.to(torch.long)
        mel_lens = mel_lens.to(device=device, dtype=torch.long)

        reduced_mel_lens = torch.clamp(mel_lens // downsample, min=1)
        mel_mask = model.length_to_mask(reduced_mel_lens)

        outputs = model(mels, src_key_padding_mask=mel_mask, text_input=texts)
        attn = None
        if isinstance(outputs, dict):
            for key in ('s2s_attn', 'attn', 'attention', 'alignments'):
                tensor = outputs.get(key)
                if tensor is not None:
                    attn = tensor
                    break
            if attn is None:
                available = ', '.join(outputs.keys())
                raise KeyError(f'Model output dictionary does not contain attention matrices. Available keys: {available}')
        elif isinstance(outputs, (tuple, list)):
            if len(outputs) < 3:
                raise ValueError('Model forward output does not include attention tensors.')
            attn = outputs[2]
        else:
            raise TypeError('Unsupported model output type for attention extraction.')

        attn = attn.detach()
        time_axis = attn.size(-1)
        output_axis = attn.size(1)

        text_lens_list = text_lens.tolist()
        mel_lens_list = reduced_mel_lens.tolist()

        for b in range(attn.size(0)):
            To = min(int(text_lens_list[b]), output_axis)
            Te = min(int(mel_lens_list[b]), time_axis)
            if To <= 1 or Te <= 1:
                continue

            a = attn[b, :To, :Te]
            total_mass = a.sum()
            # Compare against a zero of the same dtype and device (matters for half-precision outputs).
            if torch.isclose(total_mass, torch.zeros_like(total_mass)):
                continue

            t = torch.arange(To, device=a.device, dtype=torch.float32).unsqueeze(1)
            e = torch.arange(Te, device=a.device, dtype=torch.float32).unsqueeze(0)
            # To and Te are both > 1 here (shorter alignments were skipped above).
            t = t / (To - 1)
            e = e / (Te - 1)
            diag = t - e
            mask = (diag.abs() <= band).to(a.dtype)

            score = (a * mask).sum() / total_mass.clamp_min(1e-8)
            diag_scores.append(float(score))

    mean_score = float(np.mean(diag_scores)) if diag_scores else float('nan')
    return mean_score, diag_scores


### Diagonal Attention Score
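- The score is the fraction of attention mass that falls within `band` of the length-normalised diagonal, so it ranges from 0 to 1. Scores above roughly 0.6-0.7 usually indicate good alignment, and a score that rises over the course of training is a healthy sign.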
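
To make the metric concrete, here is a small pure-Python version of the same calculation applied to two toy alignments. The function `diagonal_score` is an illustrative re-implementation using nested lists, not the tensor code in the cell above, and it assumes each alignment has more than one row and one column.

```python
def diagonal_score(attn, band=0.1):
    """Fraction of attention mass within `band` of the normalised diagonal."""
    n_out, n_in = len(attn), len(attn[0])
    inside = total = 0.0
    for i, row in enumerate(attn):
        for j, weight in enumerate(row):
            total += weight
            # Positions are rescaled to [0, 1] on both axes before comparison.
            if abs(i / (n_out - 1) - j / (n_in - 1)) <= band:
                inside += weight
    return inside / total


# 3 output tokens attending over 5 encoder frames.
sharp = [
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
]
uniform = [[0.2] * 5 for _ in range(3)]

sharp_score = diagonal_score(sharp)      # all mass sits on the diagonal
uniform_score = diagonal_score(uniform)  # only 3 of 15 cells fall inside the band
print(sharp_score, uniform_score)
```

A perfectly monotonic alignment scores 1.0, while attention spread evenly over all frames scores about 0.2 on this grid, which is why values well above that range suggest the decoder has learned to track the input.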

In [None]:
# evaluate the diagonal attention score
mean_score, scores = diagonal_attention_score(model, dev_loader, device=device, band=0.1, max_batches=None)
print(f'Diagonal attention score: {mean_score:.4f}')
print(f'Evaluated {len(scores)} alignments')

if scores:
    scores_array = np.array(scores)
    print('Score statistics:')
    print(f'  min:  {scores_array.min():.4f}')
    print(f'  25%:  {np.percentile(scores_array, 25):.4f}')
    print(f'  median: {np.median(scores_array):.4f}')
    print(f'  75%:  {np.percentile(scores_array, 75):.4f}')
    print(f'  max:  {scores_array.max():.4f}')


## Self-Conditioned CTC Support

This notebook has been updated to work with the optional self-conditioned CTC feature. To experiment with it, enable the feature in `Configs/config.yml` by setting:

```yaml
stabilization:
  self_conditioned_ctc:
    enabled: true
    layers:
      - index: 2
      - index: 4
    conditioning_strategy: add  # or concat
    detach_conditioning: true
loss_weights:
  self_conditioned_ctc: 0.2
```

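With the feature enabled, the trainer exposes `self_conditioned_ctc_logits` and `self_conditioned_ctc_log_probs` in the model outputs so they can be visualised or scored inside the notebook just like the existing CTC signals.

The per-head entropy penalties mentioned at the top of this notebook are configured in the same file under `regularization.entropy`. The snippet below shows that section with the regulariser switched off: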

```yaml
regularization:
  entropy:
    enabled: false
    mode: minimize  # set to "maximize" to encourage smoother distributions
    eps: 1.0e-6
    targets:
      ctc:
        enabled: true
        weight: 0.0
        length_normalize: true
      s2s:
        enabled: true
        weight: 0.0
        length_normalize: true
```
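
As a minimal sketch of what these fields could control, the function below computes an entropy term for one head from per-frame probability rows. The function `entropy_penalty` is hypothetical, and how the trainer actually applies `mode`, `eps`, and `length_normalize` is an assumption here.

```python
import math


def entropy_penalty(frame_probs, weight, mode="minimize", eps=1.0e-6, length_normalize=True):
    """Entropy regulariser for one head, given per-frame probability rows."""
    total = 0.0
    for probs in frame_probs:
        # eps keeps log() finite when a probability is exactly zero.
        total += -sum(p * math.log(p + eps) for p in probs)
    if length_normalize and frame_probs:
        total /= len(frame_probs)
    # Minimising adds the entropy to the loss; maximising subtracts it.
    sign = 1.0 if mode == "minimize" else -1.0
    return sign * weight * total


peaked = [[0.98, 0.01, 0.01]] * 4
flat = [[1 / 3, 1 / 3, 1 / 3]] * 4

peaked_penalty = entropy_penalty(peaked, weight=0.1)
flat_penalty = entropy_penalty(flat, weight=0.1)
flat_reward = entropy_penalty(flat, weight=0.1, mode="maximize")
print(peaked_penalty, flat_penalty, flat_reward)
```

With `weight: 0.0`, as in the YAML above, the term vanishes even when the target is enabled, so a non-zero weight is needed before the regulariser has any effect.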

In [None]:
import yaml

def _resolve_mixed_precision_from_config(_cfg):
    if 'cfg_get_nested' in globals():
        return cfg_get_nested(_cfg, 'precision.mixed_precision', {})
    precision_cfg = _cfg.get('precision', {}) if isinstance(_cfg, dict) else {}
    return precision_cfg.get('mixed_precision', {}) if isinstance(precision_cfg, dict) else {}

if 'config' in globals():
    _mp_root_config = config
else:
    with open('Configs/config.yml') as _f:
        _mp_root_config = yaml.safe_load(_f)

mixed_precision_cfg = _resolve_mixed_precision_from_config(_mp_root_config) or {}
grad_scaler_cfg = mixed_precision_cfg.get('grad_scaler', {}) if isinstance(mixed_precision_cfg, dict) else {}
print(f"mixed precision enabled --> {bool(mixed_precision_cfg.get('enabled', False))}")
print(f"mixed precision dtype --> {mixed_precision_cfg.get('dtype', 'float16')}")
print(f"grad scaler enabled --> {bool(grad_scaler_cfg.get('enabled', True))}")
