## Multi-task auxiliary objectives update
This notebook has been updated to work with the extended multi-task ASR head. You can enable or disable CTC, seq2seq, frame-level phoneme classifiers, speaker embeddings, and pronunciation error classifiers individually through `Configs/config.yml` under the `multi_task` section. Adjust the corresponding `loss_weights` entries when experimenting with these objectives.
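
As a rough illustration of how the per-head toggles and `loss_weights` entries interact, the sketch below sums only the losses of enabled heads, each scaled by its weight. The function name `combine_losses`, the head names, and the `multi_task[head]["enabled"]` layout are assumptions made for this example and may not match the trainer's actual schema.

```python
def combine_losses(losses, multi_task_cfg, loss_weights):
    """Weighted sum over the heads that are switched on.

    `losses` maps a head name to its scalar loss. The key layout
    (`multi_task_cfg[head]["enabled"]`, `loss_weights[head]`) is assumed.
    """
    total = 0.0
    for head, value in losses.items():
        head_cfg = multi_task_cfg.get(head) or {}
        if not head_cfg.get("enabled", False):
            continue  # disabled heads contribute nothing
        total += float(loss_weights.get(head, 1.0)) * value
    return total


# Hypothetical config fragments, shaped like the result of yaml.safe_load.
multi_task_cfg = {"ctc": {"enabled": True}, "s2s": {"enabled": True}, "speaker": {"enabled": False}}
loss_weights = {"ctc": 1.0, "s2s": 0.5, "speaker": 0.1}
losses = {"ctc": 2.0, "s2s": 4.0, "speaker": 10.0}

print(combine_losses(losses, multi_task_cfg, loss_weights))  # 4.0
```

Disabling a head removes its term entirely, whereas setting its weight to zero keeps the head running without letting it influence the total.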

> **Entropy regularisation**: Use the new `regularization.entropy` section in `Configs/config.yml` to enable or disable per-head entropy penalties. You can independently attach the regulariser to the CTC and seq2seq objectives without breaking joint training.

> **Lazy phoneme dictionaries**: Toggle `phoneme_dictionary.lazy_loading` and `phoneme_dictionary.shared_cache` in `Configs/config.yml` to control when the phoneme map is parsed and whether dataloader workers reuse the cached mapping.

In [None]:
# Inspect and optionally warm the metadata cache for notebook experiments
import os
import sys

if not os.path.isdir('Configs'):
    os.chdir('..')

_nb_root_dir = os.getcwd()
if _nb_root_dir not in sys.path:
    sys.path.insert(0, _nb_root_dir)

import yaml
from train import prepare_data_list
from utils import get_data_path_list

with open('Configs/config.yml', 'r', encoding='utf-8') as _cfg_file:
    _nb_root_config = yaml.safe_load(_cfg_file)

_metadata_cache_cfg = (_nb_root_config.get('metadata_cache') or {}) if isinstance(_nb_root_config, dict) else {}
print(f"metadata cache enabled --> {bool(_metadata_cache_cfg.get('enabled', False))}")
print(f"metadata cache directory --> {_metadata_cache_cfg.get('directory', 'Data/cache')}")
_dataset_toggles = (_metadata_cache_cfg.get('datasets') or {}) if isinstance(_metadata_cache_cfg, dict) else {}
_normalised_dataset_toggles = {str(k).lower(): bool(v) for k, v in _dataset_toggles.items()} if isinstance(_dataset_toggles, dict) else {}

def _cache_enabled_for(name: str) -> bool:
    if not bool(_metadata_cache_cfg.get('enabled', False)):
        return False
    if not _normalised_dataset_toggles:
        return True
    return bool(_normalised_dataset_toggles.get(name.lower(), True))

(
    _train_raw,
    _val_raw,
    _train_meta_path,
    _val_meta_path,
) = get_data_path_list(
    _nb_root_config.get('train_data'),
    _nb_root_config.get('val_data'),
    return_paths=True,
)

if _cache_enabled_for('train'):
    _train_entries, _train_durations = prepare_data_list(
        _train_raw,
        root_path='',
        cache_config=_metadata_cache_cfg,
        metadata_path=_train_meta_path,
        dataset_name='train',
    )
    print(f"[cache] training entries available: {len(_train_entries)}")

if _cache_enabled_for('val'):
    _val_entries, _val_durations = prepare_data_list(
        _val_raw,
        root_path='',
        cache_config=_metadata_cache_cfg,
        metadata_path=_val_meta_path,
        dataset_name='val',
    )
    print(f"[cache] validation entries available: {len(_val_entries)}")

_ood_metadata_path = _nb_root_config.get('ood_data')
if _ood_metadata_path and _cache_enabled_for('ood'):
    try:
        with open(_ood_metadata_path, 'r', encoding='utf-8') as _ood_file:
            _ood_raw = _ood_file.readlines()
        _ood_entries, _ood_durations = prepare_data_list(
            _ood_raw,
            root_path='',
            cache_config=_metadata_cache_cfg,
            metadata_path=_ood_metadata_path,
            dataset_name='ood',
        )
        print(f"[cache] OOD entries available: {len(_ood_entries)}")
    except FileNotFoundError:
        print(f"[cache] OOD metadata not found at {_ood_metadata_path!r}")


In [None]:
# Inspect mel cache configuration for notebook experiments
_mel_cache_cfg = (_nb_root_config.get('mel_cache') or {}) if isinstance(_nb_root_config, dict) else {}
print(f"mel cache enabled --> {bool(_mel_cache_cfg.get('enabled', False))}")
print(f"mel cache directory --> {_mel_cache_cfg.get('directory', 'Data/mel_cache')}")
print(f"mel cache dtype --> {_mel_cache_cfg.get('dtype', 'float32')}")
_mel_cache_datasets = (_mel_cache_cfg.get('datasets') or {}) if isinstance(_mel_cache_cfg, dict) else {}
if isinstance(_mel_cache_datasets, dict) and _mel_cache_datasets:
    for _name, _flag in sorted(_mel_cache_datasets.items()):
        print(f"[mel cache] {_name}: {bool(_flag)}")


## Memory optimization settings
Lazy creation of decoder masks can now be toggled through the new `memory_optimizations.lazy_masks` section in `Configs/config.yml`. When enabled (the default), the trainer skips preallocating the `future_mask` and `text_mask` tensors unless an experiment explicitly needs them. Set the corresponding flags to `false` to restore the previous eager allocation behaviour.
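
The idea behind lazy mask creation can be sketched in plain Python: each mask is built on first access and cached afterwards, so an experiment that never reads a mask never pays for it. `DecoderMasks` is a hypothetical class for this example, and nested lists stand in for the tensors the trainer would use.

```python
from functools import cached_property


class DecoderMasks:
    """Build masks on first access instead of up front (illustrative only)."""

    def __init__(self, text_lengths):
        self.text_lengths = list(text_lengths)
        self.max_len = max(self.text_lengths)

    @cached_property
    def future_mask(self):
        # True marks positions a decoder step must not attend to (j > i).
        return [[j > i for j in range(self.max_len)] for i in range(self.max_len)]

    @cached_property
    def text_mask(self):
        # True marks padding beyond each sequence's length.
        return [[j >= n for j in range(self.max_len)] for n in self.text_lengths]


masks = DecoderMasks([2, 3])
print("future_mask" in vars(masks))  # False: nothing allocated yet
print(masks.text_mask)               # [[False, False, True], [False, False, False]]
print("text_mask" in vars(masks))    # True: cached after first access
```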

Gradient checkpointing for the deeper encoder blocks is configurable through `memory_optimizations.gradient_checkpointing`. Setting `enabled: true` activates `torch.utils.checkpoint.checkpoint_sequential` on the selected layer range so activations are recomputed during backpropagation instead of stored. Tune `start_layer`/`end_layer` to target specific encoder depths, `chunk_size` to limit how many stages are bundled into a checkpoint, and `segments` when you need finer control over how each chunk is split. The helper flags `min_sequence_length` and `use_checkpoint_sequential` let you skip short utterances or fall back to the basic checkpoint API if required.
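
One way to picture how `start_layer`, `end_layer`, `chunk_size`, and `min_sequence_length` might interact is the planning sketch below. The helper names and the conventions (an exclusive `end_layer`, a sequence-length threshold measured in frames) are assumptions for illustration; the trainer's real semantics may differ.

```python
def plan_checkpoint_chunks(n_layers, start_layer=0, end_layer=None, chunk_size=2):
    """Split the selected layer range into consecutive (lo, hi) chunks.

    Assumes `end_layer` is exclusive and defaults to the last layer.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    end = n_layers if end_layer is None else min(end_layer, n_layers)
    start = max(0, start_layer)
    return [(lo, min(lo + chunk_size, end)) for lo in range(start, end, chunk_size)]


def should_checkpoint(seq_len, cfg):
    """Skip recomputation for short inputs, where little memory is saved."""
    return bool(cfg.get("enabled", False)) and seq_len >= cfg.get("min_sequence_length", 0)


# Hypothetical values for the gradient_checkpointing section.
cfg = {"enabled": True, "start_layer": 2, "end_layer": 7, "chunk_size": 2, "min_sequence_length": 200}

print(plan_checkpoint_chunks(8, cfg["start_layer"], cfg["end_layer"], cfg["chunk_size"]))
# [(2, 4), (4, 6), (6, 7)]
print(should_checkpoint(120, cfg))  # False
```

In a real model, each `(lo, hi)` chunk would presumably be handed to `torch.utils.checkpoint.checkpoint_sequential(layers[lo:hi], segments, x)`, with `segments` controlling how that chunk is subdivided.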

### CTC/seq2seq head sharing
Enable `multi_task.head_sharing.ctc_seq2seq.enabled` in `Configs/config.yml` to reuse the intermediate projection computed for the CTC logits when running the seq2seq decoder. With the feature turned on, the encoder exposes the shared tensor via `model_outputs["ctc_seq2seq_shared_states"]` and feeds an adapter into the decoder so both heads reuse the same computation.

Set the flag back to `false` to restore the previous behaviour. The optional `detach_for_seq2seq` switch stops decoder gradients from flowing through the shared branch if you want to isolate the two objectives.


In [None]:
# check available CUDA devices
import torch

devices = []
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        device_name = torch.cuda.get_device_name(i)
        devices.append({
            "type": "CUDA",
            "available": True,
            "name": device_name,
            "index": i,
        })
else:
    devices.append({"type": "CUDA", "available": False, "name": "N/A"})
devices

## Augmentation configuration
This notebook now honours the extended SpecAugment policies, waveform perturbations, mixup, and phoneme-level dropout toggles defined in `Configs/config.yml`. Adjust those settings in the config file before running these evaluation cells.


In [None]:
# change folder into the root of the ASR project
import os

if not os.path.isdir("Configs"):
    %cd ../

!pwd

In [None]:
# import packages, define common functions
import itertools
import os
import os.path as osp
import re

import pandas as pd
import torch
import torch.nn.functional as F
import yaml
from jiwer import wer

from meldataset import build_dataloader
from models import ASRCNN
from text_utils import TextCleaner
from token_map import build_token_map_from_data
from utils import select_logits_from_output

def cfg_get_nested(cfg: dict, path, default=None, sep="."):
    """
    Get a nested value from a dict using a list of keys or a dot-separated string.

    Examples:
        cfg_get_nested(config, ["model_params", "input_dim"], 80)
        cfg_get_nested(config, "model_params.input_dim", 80)
        cfg_get_nested(config, "top_key", 80)
    """
    if isinstance(path, str):
        keys = path.split(sep)
    else:
        keys = path

    cur = cfg
    for k in keys:
        if isinstance(cur, dict) and k in cur:
            cur = cur[k]
        else:
            return default
    return cur

def length_to_mask(lengths):
    """Return a (batch, max_len) boolean mask that is True at padded positions."""
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask

def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
    def _load_model(model_config, model_path):
        model = ASRCNN(**model_config)
        params = torch.load(model_path, map_location='cpu', weights_only=False)['model']
        try:
            model.load_state_dict(params)
        except Exception:
            # Checkpoints saved from a DataParallel wrapper typically prefix keys with "module."
            new_state_dict = {k.replace("module.", ""): v for k, v in params.items()}
            model.load_state_dict(new_state_dict)
        return model

    with open(ASR_MODEL_CONFIG) as f:
        config = yaml.safe_load(f)

    token_src = config.get('phoneme_maps_path')
    if not token_src:
        token_map = build_token_map_from_data(config.get('train_data'), config.get('val_data'), config.get('ood_data'), apply_asr_tokenizer=True)
    elif isinstance(token_src, dict):
        token_map = token_src
    else:
        csv = pd.read_csv(token_src, header=None).values
        token_map = {word: index for word, index in csv}

    default_model_params = {
        'input_dim': 80,
        'hidden_dim': 256,
        'n_token': len(token_map),
        'token_embedding_dim': 512,
        'n_layers': 5,
        'location_kernel_size': 31
    }
    model_params = cfg_get_nested(config, 'model_params', default_model_params)
    if isinstance(model_params, dict):
        model_params = dict(model_params)
    else:
        model_params = dict(default_model_params)

    model_params.setdefault('n_token', len(token_map))
    model_params.setdefault('stabilization_config', cfg_get_nested(config, 'stabilization', {}))

    multi_task_config = cfg_get_nested(config, 'multi_task', {}) or {}
    model_params['multi_task_config'] = multi_task_config

    asr_model = _load_model(model_params, ASR_MODEL_PATH)

    return asr_model


if 'config' in globals() and isinstance(config, dict):
    _mp_base_config = config
else:
    with open('Configs/config.yml') as _f:
        _mp_base_config = yaml.safe_load(_f)

memory_opts = _mp_base_config.get("memory_optimizations", {}) if isinstance(_mp_base_config, dict) else {}
lazy_mask_cfg = memory_opts.get("lazy_masks", {}) if isinstance(memory_opts, dict) else {}
lazy_enabled = bool(lazy_mask_cfg.get("enabled", True))
print(f"lazy mask creation enabled --> {lazy_enabled}")
print(f"skip future mask allocation --> {bool(lazy_mask_cfg.get('future_mask', True))}")
print(f"skip text mask allocation --> {bool(lazy_mask_cfg.get('text_mask', True))}")


In [None]:
def resolve_dataset_params(config, base_overrides=None):
    dataset_params = {
        'dict_path': cfg_get_nested(config, 'phoneme_maps_path', 'Data/word_index_dict.txt'),
        'sr': cfg_get_nested(config, 'preprocess_params.sr', cfg_get_nested(config, 'preprocess_parasm.sr', 24000)),
        'spect_params': cfg_get_nested(
            config,
            'preprocess_params.spect_params',
            cfg_get_nested(config, 'preprocess_parasm.spect_params', {'n_fft': 1024, 'win_length': 1024, 'hop_length': 300}),
        ),
        'mel_params': cfg_get_nested(
            config,
            'preprocess_params.mel_params',
            cfg_get_nested(config, 'preprocess_parasm.mel_params', {'n_mels': 80}),
        ),
    }
    dataset_params['mel_cache'] = cfg_get_nested(config, 'mel_cache', {}) or {}
    dataset_params['phoneme_dictionary_config'] = cfg_get_nested(config, 'phoneme_dictionary', {}) or {}
    dataset_overrides = cfg_get_nested(config, 'dataset_params', {})
    if isinstance(dataset_overrides, dict):
        for key in ('dict_path', 'sr', 'spect_params', 'mel_params', 'phoneme_dictionary_config'):
            if key in dataset_overrides:
                dataset_params[key] = dataset_overrides[key]
        if 'spec_augment' in dataset_overrides:
            dataset_params['spec_augment_params'] = dataset_overrides['spec_augment']
        for key, value in dataset_overrides.items():
            if key in ('dict_path', 'sr', 'spect_params', 'mel_params', 'phoneme_dictionary_config', 'spec_augment'):
                continue
            dataset_params[key] = value
    if base_overrides:
        dataset_params.update(base_overrides)
    return dataset_params



In [None]:
checkpoint_dir = "Checkpoint"

files = [f for f in os.listdir( checkpoint_dir + "/") if f.startswith('epoch_') and f.endswith('.pth')]
sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))

model_path = checkpoint_dir + "/" + sorted_files[-1]
#model_path = "Checkpoint/epoch_00080.pth"

config_path = 'Checkpoint/config.yml'

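# Use a context manager so the config file handle is closed promptly
with open(config_path, 'r', encoding='utf-8') as _cfg_file:
    config = yaml.safe_load(_cfg_file)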
phoneme_map = config.get('phoneme_maps_path')
if not phoneme_map:
    phoneme_map = build_token_map_from_data(config.get('train_data'), config.get('val_data'), config.get('ood_data'), apply_asr_tokenizer=True)

test_csv_path = config['val_data']

def _dict_desc(obj):
    return obj if isinstance(obj, str) else 'built from dataset'

print( "model --> " + model_path )
print( "config --> " + config_path)
print( "dictionary --> " + _dict_desc(phoneme_map))
print( "test --> " + test_csv_path)

model = load_ASR_models(model_path, config_path)
model.eval()

print( "All OK ✓")


In [None]:
text_cleaner = TextCleaner(phoneme_map)
device = "cpu"

with open(test_csv_path, "r", encoding="utf-8") as f:
    test_lines = f.readlines()

dataset_params = resolve_dataset_params(config)

test_loader = build_dataloader(
    path_list=[l[:-1].split('|') for l in test_lines],
    validation=True,
    batch_size=1,
    num_workers=2,
    device=device,
    collate_config={"return_wave": False},
    dataset_config=dataset_params,
    dataset_name="val",
)

if isinstance(phoneme_map, str):
    csv = pd.read_csv(phoneme_map, header=None).values
    wlist = {word: index for word, index in csv}
else:
    wlist = phoneme_map
index2phoneme = {v: k for k, v in wlist.items()}

predictions = []
references = []
cleartexts = []

model.eval()
log_interval = 10
total = len(test_lines)
cntr = 0
maxtestsize = 1
#maxtestsize = 0

with torch.no_grad():
    for batch in test_loader:
        cleartexts.append(test_lines[cntr])

        texts, input_lengths, mels, output_lengths = batch  # from Collater

        mels = mels.to(device)
        output = model(mels)
        logits = select_logits_from_output(output)
        predicted_ids = torch.argmax(logits, dim=-1)

        print(f"Batch {cntr} - Expected text: {test_lines[cntr].strip().split('|')[1]}")
        predicted_text = ''.join([text_cleaner.inverse_mapping.get(phoneme.item(), '<unk>') for phoneme in predicted_ids[0]])
        print(f"Batch {cntr} - Predicted text: {predicted_text}")

        for i in range(predicted_ids.size(0)):
            pred_seq = predicted_ids[i][:output_lengths[i]//3]
            ref_seq = texts[i][:input_lengths[i]]

            pred_phonemes = [index2phoneme.get(p.item(), '') for p in pred_seq]
            ref_phonemes = [index2phoneme.get(r.item(), '') for r in ref_seq]

            predictions.append(pred_phonemes)
            references.append(ref_phonemes)

        cntr += 1
        if (cntr)%log_interval == 0:
            print(f"{cntr} of {total} sentences tested")

        if maxtestsize > 0 and cntr >= maxtestsize:
            print(f"early stop reached at {maxtestsize} sentences")
            break

print(f"Done - {cntr} of {total} sentences tested.")


In [None]:
# Clean extra quotes and join tokens into space-separated strings
references_cleaned = [' '.join(token.strip('"') for token in seq) for seq in references]
predictions_cleaned = [' '.join(token.strip('"') for token in seq) for seq in predictions]

# Now use jiwer
per = wer(references_cleaned, predictions_cleaned)
print(f'Phoneme Error Rate: {per:.4f}')
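
The cell above reuses `jiwer.wer` by treating each phoneme as a "word". For readers without `jiwer`, the same corpus-level quantity (total edit operations divided by total reference phonemes) can be sketched in plain Python. The function names and the sample phoneme sequences below are made up for illustration, and the result should match `jiwer` only insofar as it aggregates errors over the whole corpus in the same way.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(
                prev[j] + 1,             # deletion
                cur[j - 1] + 1,          # insertion
                prev[j - 1] + (r != h),  # substitution or match
            ))
        prev = cur
    return prev[-1]


def phoneme_error_rate(references, predictions):
    """Corpus-level PER: total edits divided by total reference phonemes."""
    edits = sum(edit_distance(r, p) for r, p in zip(references, predictions))
    total = sum(len(r) for r in references)
    return edits / total if total else 0.0


refs = [["h", "ə", "l", "oʊ"], ["w", "ɜː", "l", "d"]]
hyps = [["h", "ə", "l", "oʊ"], ["w", "ɜː", "d"]]
print(phoneme_error_rate(refs, hyps))  # 0.125
```

Here one deleted phoneme out of eight reference phonemes gives a PER of 0.125.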

In [None]:
###########################################
# Find the best AUX model (with best PER) #
###########################################

#checkpoint_dir = "Checkpoint-en-test"

files = [f for f in os.listdir( checkpoint_dir + "/") if f.startswith('epoch_') and f.endswith('.pth')]
sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))

model_path = checkpoint_dir + "/" + sorted_files[-1]
#model_path = "Checkpoint/epoch_00040.pth"
#config_path = "Configs/config.yml"

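# Use a context manager so the config file handle is closed promptly
with open(config_path, 'r', encoding='utf-8') as _cfg_file:
    config = yaml.safe_load(_cfg_file)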
phoneme_map = config.get('phoneme_maps_path')
if not phoneme_map:
    phoneme_map = build_token_map_from_data(config.get('train_data'), config.get('val_data'), config.get('ood_data'), apply_asr_tokenizer=True)

test_csv_path = config['val_data']

def _dict_desc(obj):
    return obj if isinstance(obj, str) else 'built from dataset'

print( "config --> " + config_path )
print( "dictionary --> " + _dict_desc(phoneme_map) )
print( "test --> " + test_csv_path )

text_cleaner = TextCleaner(phoneme_map)
#text_cleaner = TextCleaner()
device = "cpu"

with open(test_csv_path, "r", encoding="utf-8") as f:
    test_lines = f.readlines()

dataset_params = resolve_dataset_params(config)

test_loader = build_dataloader(
    #path_list=test_lines,
    path_list=[l[:-1].split('|') for l in test_lines],
    validation=True,
    batch_size=1,
    num_workers=2,
    device=device,
    collate_config={"return_wave": False},
    dataset_config=dataset_params,
    dataset_name="val",
)

if isinstance(phoneme_map, str):
    csv = pd.read_csv(phoneme_map, header=None).values
    wlist = {word: index for word, index in csv}
else:
    wlist = phoneme_map
index2phoneme = {v: k for k, v in wlist.items()}

best_model = ""
best_model_per = 100.0

model_cntr = 1
total_files = len(sorted_files)
results = []
#maxtestsize = 0 # will test the full validation set - may take a while and is usually not necessary
maxtestsize = 25

for aux_model_file in sorted_files:
    model_path = checkpoint_dir + "/" + aux_model_file
    model = load_ASR_models(model_path, config_path)
    model.eval()

    print(f"[{model_cntr}/{total_files}] Now evaluating AUX model: {aux_model_file}")
    model_cntr += 1

    predictions = []
    references = []
    first_ref = ""
    first_pred = ""

    test_iter = iter(test_loader)
    with torch.no_grad():
        for sample_idx in range(len(test_loader)):
            if maxtestsize > 0 and sample_idx >= maxtestsize:
                break

            batch = next(test_iter)
            texts, input_lengths, mels, output_lengths = batch
            mels = mels.to(device)
            output = model(mels)
            logits = select_logits_from_output(output)
            predicted_ids = torch.argmax(logits, dim=-1)

            for i in range(predicted_ids.size(0)):
                pred_seq = predicted_ids[i][:output_lengths[i] // 3]
                ref_seq = texts[i][:input_lengths[i]]

                pred_phonemes = [index2phoneme.get(p.item(), '') for p in pred_seq]
                ref_phonemes = [index2phoneme.get(r.item(), '') for r in ref_seq]

                predictions.append(pred_phonemes)
                references.append(ref_phonemes)

                if first_ref == "":
                    first_ref = ' '.join(ref_phonemes)
                    first_pred = ' '.join(pred_phonemes)

    references_cleaned = [' '.join(seq) for seq in references]
    predictions_cleaned = [' '.join(seq) for seq in predictions]
    per = wer(references_cleaned, predictions_cleaned)

    print(f'Phoneme Error Rate: {per:.4f} {"✓" if per < best_model_per else "✗"}')
    if per < best_model_per:
        best_model_per = per
        best_model = aux_model_file

    results.append({
        'model': aux_model_file,
        'per': per,
        'first_ref': first_ref,
        'first_pred': first_pred
    })

results_sorted = sorted(results, key=lambda x: x['per'])

print("===================")
print("PERFORMANCE SUMMARY")
print("===================")
for res in results_sorted:
    print(f"Model: {res['model']}")
    print(f"PER: {res['per']:.4f}")
    print(f"Reference: {res['first_ref']}")
    print(f"Prediction: {res['first_pred']}")
    print("------------------------")

best = results_sorted[0]
print(f"✅ Best model: {best['model']} with PER = {best['per']:.4f}")


## Self-conditioned CTC support

This notebook has been updated to work with the optional self-conditioned CTC feature. To experiment with it, enable the feature in `Configs/config.yml` by setting:

```yaml
stabilization:
  self_conditioned_ctc:
    enabled: true
    layers:
      - index: 2
      - index: 4
    conditioning_strategy: add  # or concat
    detach_conditioning: true
loss_weights:
  self_conditioned_ctc: 0.2
```

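With the feature enabled, the trainer exposes `self_conditioned_ctc_logits` and `self_conditioned_ctc_log_probs` in the model outputs so they can be visualised or scored inside the notebook just like the existing CTC signals.

The entropy regulariser described at the top of this notebook is configured in the same file under `regularization.entropy`. The fragment below leaves it disabled (`enabled: false`, zero weights):
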
```yaml
regularization:
  entropy:
    enabled: false
    mode: minimize  # set to "maximize" to encourage smoother distributions
    eps: 1.0e-6
    targets:
      ctc:
        enabled: true
        weight: 0.0
        length_normalize: true
      s2s:
        enabled: true
        weight: 0.0
        length_normalize: true
```
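
To make the `mode`, `eps`, `weight`, and `length_normalize` keys concrete, here is a minimal pure-Python sketch of an entropy penalty over per-frame probability distributions. The function name `entropy_penalty` and the sign convention (`minimize` adds the entropy to the loss, `maximize` subtracts it) are assumptions inferred from the config comments, not a description of the trainer's code.

```python
import math


def entropy_penalty(frame_probs, mode="minimize", eps=1.0e-6, weight=0.0, length_normalize=True):
    """Weighted entropy term over a sequence of per-frame distributions.

    mode="minimize" returns +weight * H (rewards sharper outputs);
    mode="maximize" returns -weight * H (rewards smoother outputs).
    """
    # eps keeps log() finite when a probability is exactly zero.
    entropies = [-sum(p * math.log(p + eps) for p in frame) for frame in frame_probs]
    total = sum(entropies)
    if length_normalize and entropies:
        total /= len(entropies)  # average over frames instead of summing
    sign = 1.0 if mode == "minimize" else -1.0
    return sign * weight * total


uniform = [[0.25, 0.25, 0.25, 0.25]] * 3
peaked = [[0.97, 0.01, 0.01, 0.01]] * 3
print(entropy_penalty(uniform, weight=0.1) > entropy_penalty(peaked, weight=0.1))  # True
```

With `weight: 0.0`, as in the fragment above, the term vanishes even when a target is marked `enabled: true`, so a non-zero weight is needed before the regulariser has any effect.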

In [None]:
import yaml

def _resolve_mixed_precision_from_config(_cfg):
    if 'cfg_get_nested' in globals():
        return cfg_get_nested(_cfg, 'precision.mixed_precision', {})
    precision_cfg = _cfg.get('precision', {}) if isinstance(_cfg, dict) else {}
    return precision_cfg.get('mixed_precision', {}) if isinstance(precision_cfg, dict) else {}

if 'config' in globals():
    _mp_root_config = config
else:
    with open('Configs/config.yml') as _f:
        _mp_root_config = yaml.safe_load(_f)

mixed_precision_cfg = _resolve_mixed_precision_from_config(_mp_root_config) or {}
grad_scaler_cfg = mixed_precision_cfg.get('grad_scaler', {}) if isinstance(mixed_precision_cfg, dict) else {}
print(f"mixed precision enabled --> {bool(mixed_precision_cfg.get('enabled', False))}")
print(f"mixed precision dtype --> {mixed_precision_cfg.get('dtype', 'float16')}")
print(f"grad scaler enabled --> {bool(grad_scaler_cfg.get('enabled', True))}")
