# Training an Acoustic Model with Subword Tokenization

In this notebook, we train a German ASR model by fine-tuning an English Citrinet model with cross-language transfer learning. The workflow is illustrated in the figure below.

![png](./imgs/german-transfer-learning.PNG)

We first demonstrate the training process with NeMo on 1 GPU in this notebook. To speed up training, use multiple GPUs with the more efficient DDP (distributed data parallel) protocol, which must run from a separate [training script](./train.py).

This notebook can be run from within the NeMo container, such as:

```
docker run  --ipc=host --gpus=all --net=host --rm -it -v $PWD:/myworkspace nvcr.io/nvidia/nemo:22.08 bash
```

Note: PyTorch uses shared memory to share data between processes. If torch multiprocessing is used (e.g., for data loaders with multiple workers), the default shared memory segment size of the container is not enough. Increase it with either the `--ipc=host` or the `--shm-size` command-line option to `docker run` (or `nvidia-docker run`), as in the command above.
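
As a quick sanity check, the shared memory available inside the container can be inspected from Python. The snippet below is a minimal sketch that assumes a Linux container with shared memory mounted at `/dev/shm`; the helper name `shared_memory_gib` is introduced here for illustration and is not part of NeMo or PyTorch.

```python
import os
import shutil


def shared_memory_gib(path="/dev/shm"):
    """Return (total, free) size of the mount at `path` in GiB, or None if it does not exist."""
    if not os.path.isdir(path):
        return None
    usage = shutil.disk_usage(path)
    gib = 1024 ** 3
    return usage.total / gib, usage.free / gib


sizes = shared_memory_gib()
if sizes is None:
    # Assumption: /dev/shm only exists on Linux hosts and containers
    print("/dev/shm not found")
else:
    total, free = sizes
    print(f"/dev/shm: {total:.2f} GiB total, {free:.2f} GiB free")
```

Without `--ipc=host` or `--shm-size`, Docker's default shared memory segment is 64 MB, which is usually too small for data loaders with many workers.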


In [1]:
import nemo
import nemo.collections.asr as nemo_asr

print(nemo.__version__)

from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf, open_dict
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl

[NeMo W 2022-10-25 00:26:25 optimizers:77] Could not import distributed_fused_adam optimizer from Apex
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


1.12.0


## Cross-Language Transfer Learning

Transfer learning is an important machine learning technique that uses a model’s knowledge of one task to perform better on another. Fine-tuning is one of the techniques to perform transfer learning. It is an essential part of the recipe for many state-of-the-art results where a base model is first pretrained on a task with abundant training data and then fine-tuned on different tasks of interest where the training data is less abundant or even scarce.

Transfer learning with NeMo is simple.


First, let's load the pretrained NeMo Citrinet model, which was trained on ~6000 hours of English data.

In [2]:
import nemo.collections.asr as nemo_asr
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="stt_en_citrinet_1024")

[NeMo I 2022-10-25 00:26:35 cloud:56] Found existing object /root/.cache/torch/NeMo/NeMo_1.12.0/stt_en_citrinet_1024/86acfaf495a53383369fb6c9c547b8dd/stt_en_citrinet_1024.nemo.
[NeMo I 2022-10-25 00:26:35 cloud:62] Re-using file from: /root/.cache/torch/NeMo/NeMo_1.12.0/stt_en_citrinet_1024/86acfaf495a53383369fb6c9c547b8dd/stt_en_citrinet_1024.nemo
[NeMo I 2022-10-25 00:26:35 common:910] Instantiating model from pre-trained checkpoint
[NeMo I 2022-10-25 00:26:39 mixins:170] Tokenizer SentencePieceTokenizer initialized with 1024 tokens


[NeMo W 2022-10-25 00:26:39 modelPT:142] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: null
    sample_rate: 16000
    batch_size: 32
    trim_silence: true
    max_duration: 16.7
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    use_start_end_token: false
    
[NeMo W 2022-10-25 00:26:39 modelPT:149] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: null
    sample_rate: 16000
    batch_size: 32
    shuffle: false
    use_start_end_token: false
    
[NeMo W 2022-10-25 00:26:39 modelPT:155] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a va

[NeMo I 2022-10-25 00:26:40 features:225] PADDING: 16
[NeMo I 2022-10-25 00:26:44 save_restore_connector:243] Model EncDecCTCModelBPE was successfully restored from /root/.cache/torch/NeMo/NeMo_1.12.0/stt_en_citrinet_1024/86acfaf495a53383369fb6c9c547b8dd/stt_en_citrinet_1024.nemo.


### Update vocabulary
Next, check which vocabulary the model currently uses:

In [3]:
print(asr_model.decoder.vocabulary)

['<unk>', 's', '▁the', 't', '▁a', '▁i', "'", '▁and', '▁to', 'ed', 'd', '▁of', 'e', '▁in', 'ing', '.', '▁it', '▁you', 'n', '▁that', 'm', 'y', 'er', '▁he', 're', 'r', '▁was', '▁is', '▁for', '▁know', 'a', 'p', 'c', ',', '▁be', 'o', '▁but', '▁they', 'g', '▁so', 'ly', 'b', '▁s', '▁yeah', '▁we', '▁have', '▁re', '▁like', 'l', '▁on', 'll', 'u', '▁with', '▁do', 'al', '▁not', '▁are', 'or', 'ar', 'le', '▁this', '▁as', 'es', '▁c', '▁de', 'f', 'in', 'i', 've', '▁uh', 'ent', '▁or', '▁what', '▁me', '▁t', '▁at', '▁my', '▁his', '▁there', 'w', '▁all', '▁just', 'h', '▁can', 'ri', 'il', 'k', 'ic', '▁e', '▁', '▁um', '▁don', '▁b', '▁had', 'ch', 'ation', 'en', 'th', '▁no', '▁she', 'it', '▁one', '▁think', '▁st', '▁if', '▁from', 'ter', '▁an', 'an', 'ur', '▁out', 'on', '▁go', 'ck', '▁would', '▁were', '▁w', '▁will', '▁about', '▁right', 'ment', '▁her', 'te', 'ion', '▁well', '▁by', 'ce', '▁g', '▁oh', '▁up', 'ro', 'ra', '▁when', '▁some', '▁also', '▁their', 'ers', 'ow', '▁more', '▁time', 'ate', '▁has', '▁people', '▁
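
The `▁` prefix in the tokens above is the SentencePiece word-boundary marker: a token that starts with `▁` begins a new word, while a token without it continues the previous word. The sketch below shows how a sequence of subword pieces maps back to text. `pieces_to_text` is an illustrative helper, not a NeMo function; in practice the model's tokenizer handles this step.

```python
WORD_BOUNDARY = "\u2581"  # "▁", the SentencePiece word-boundary marker


def pieces_to_text(pieces):
    """Join subword pieces and turn word-boundary markers into spaces."""
    return "".join(pieces).replace(WORD_BOUNDARY, " ").strip()


# All pieces below appear in the English vocabulary printed above
print(pieces_to_text(["▁i", "▁know", "▁the", "▁time"]))  # i know the time
print(pieces_to_text(["▁think", "ing"]))                 # thinking
```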

Now let's update the vocabulary in this model, using the German tokenizer that we trained in the data preparation step.

In [4]:
# Change the tokenizer vocabulary by passing the path to the new tokenizer directory
asr_model.change_vocabulary(
    new_tokenizer_dir="../data_preparation/data/processed/tokenizer/tokenizer_spe_bpe_v1024/",
    new_tokenizer_type="bpe"
)

[NeMo W 2022-10-25 00:26:44 modelPT:217] You tried to register an artifact under config key=tokenizer.model_path but an artifact for it has already been registered.
[NeMo W 2022-10-25 00:26:44 modelPT:217] You tried to register an artifact under config key=tokenizer.vocab_path but an artifact for it has already been registered.


[NeMo I 2022-10-25 00:26:44 mixins:170] Tokenizer SentencePieceTokenizer initialized with 1024 tokens
[NeMo I 2022-10-25 00:26:44 ctc_bpe_models:259] 
    Replacing old number of classes (1024) with new number of classes - 1024
[NeMo I 2022-10-25 00:26:45 ctc_bpe_models:301] Changed tokenizer to ['<unk>', 'en', 'er', '▁d', 'ch', 'ei', 'un', 'ie', '▁w', '▁a', '▁s', '▁i', 'st', '▁die', '▁un', '▁m', 'ge', 'ich', '▁da', 'ein', 'ss', '▁b', '▁h', 'sch', '▁v', 'on', 'an', '▁k', '▁z', '▁n', '▁und', 'gen', '▁f', '▁e', 'ir', '▁au', 'ti', '▁ein', '▁der', 'll', 'in', '▁wir', 'te', '▁in', 'or', 'ur', 'ten', '▁ge', 'ung', 'ra', 'it', 're', 'ar', '▁zu', 'den', '▁g', 'der', '▁p', 'al', 'ür', 'lich', 'hr', 'icht', 'es', '▁ha', 'men', '▁das', 'ben', '▁ver', 'eit', 'em', '▁ist', 'ier', '▁den', 'tz', '▁l', 'ber', '▁be', '▁dass', '▁an', '▁auch', 'om', '▁nicht', 'de', '▁es', 'isch', '▁mit', 'ter', 'se', '▁ich', 'au', 'op', '▁er', '▁t', 'oll', 'ach', '▁j', '▁eur', 'ig', 'um', '▁für', '▁auf', '▁europ', '▁sie'

After this, our decoder has completely changed, but our encoder (where most of the weights are) remains intact.
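
To see how the weights are split between the two components, one can count parameters per submodule. The helper below is a minimal sketch that only assumes PyTorch-style modules exposing `.parameters()`; `count_parameters` and `parameter_share` are hypothetical names introduced here, not NeMo APIs.

```python
def count_parameters(module):
    """Total number of scalar weights in a PyTorch-style module."""
    return sum(p.numel() for p in module.parameters())


def parameter_share(model, names=("encoder", "decoder")):
    """Map each named submodule to (parameter count, fraction of the listed total)."""
    counts = {name: count_parameters(getattr(model, name)) for name in names}
    total = sum(counts.values())
    return {name: (n, n / total) for name, n in counts.items()}


# Usage in this notebook (assumes `asr_model` from the cells above):
# for name, (n, frac) in parameter_share(asr_model).items():
#     print(f"{name}: {n:,} parameters ({frac:.1%})")
```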

### Update Config

Each NeMo model has a config embedded in it, which can be accessed via `model.cfg`. In general, this is the config that was used to construct the model.

For pre-trained models, this config generally represents the config used to construct the model when it was trained. A nice benefit to this embedded config is that we can repurpose it to set up new data loaders, optimizers, schedulers, and even data augmentation!

In [5]:
!ln -s ../data_preparation/data .

ln: failed to create symbolic link './data': File exists


In [6]:
DATA_ROOT = "./data"
USE_TARRED_DATASET = True

# Set up the train and validation configs
with open_dict(asr_model.cfg):
    # Train dataset (concatenation of the cleaned train and dev manifests)
    if USE_TARRED_DATASET:
        asr_model.cfg.train_ds.manifest_filepath = f'{DATA_ROOT}/processed/tar/train/tarred_audio_manifest.json'
        asr_model.cfg.train_ds.is_tarred = True
        # Plain concatenation instead of an f-string so the {0..127} shard range is kept literally
        asr_model.cfg.train_ds.tarred_audio_filepaths = DATA_ROOT + '/processed/tar/train/audio_{0..127}.tar'
    else:
        asr_model.cfg.train_ds.manifest_filepath = f'{DATA_ROOT}/processed/train_manifest_merged.json'

    asr_model.cfg.train_ds.batch_size = 32
    asr_model.cfg.train_ds.num_workers = 32
    asr_model.cfg.train_ds.pin_memory = True
    asr_model.cfg.train_ds.trim_silence = True

    # Validation datasets (test and dev manifests); identical for both training modes
    asr_model.cfg.validation_ds.manifest_filepath = [
        f'{DATA_ROOT}/processed/test_manifest_merged.json',
        f'{DATA_ROOT}/processed/dev_manifest_merged.json',
    ]
    asr_model.cfg.validation_ds.batch_size = 32
    asr_model.cfg.validation_ds.num_workers = 32
    asr_model.cfg.validation_ds.pin_memory = True
    asr_model.cfg.validation_ds.trim_silence = True

# Point to the new train and validation data for fine-tuning
asr_model.setup_training_data(train_data_config=asr_model.cfg.train_ds)
asr_model.setup_validation_data(val_data_config=asr_model.cfg.validation_ds)



[NeMo W 2022-10-25 00:26:45 audio_to_text_dataset:179] dataset does not have explicitly defined labels


[NeMo I 2022-10-25 00:26:46 collections:194] Dataset loaded with 9029 files totalling 18.10 hours
[NeMo I 2022-10-25 00:26:46 collections:195] 427 files were filtered totalling 2.17 hours


[NeMo W 2022-10-25 00:26:46 ctc_models:434] Model Trainer was not set before constructing the dataset, incorrect number of training batches will be used. Please set the trainer and rebuild the dataset.


[NeMo I 2022-10-25 00:26:51 collections:194] Dataset loaded with 4077 files totalling 9.90 hours
[NeMo I 2022-10-25 00:26:51 collections:195] 0 files were filtered totalling 0.00 hours
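
The log lines above report how many utterances were loaded and how many were filtered out. A rough version of that bookkeeping can be reproduced directly from a manifest, assuming the usual NeMo manifest layout of one JSON object per line with a `duration` field in seconds. The sketch below filters only on a maximum duration (the pretrained train config shows `max_duration: 16.7`); NeMo's own loader may apply further criteria, so the numbers need not match exactly. `summarize_manifest` is a hypothetical helper.

```python
import json


def summarize_manifest(path, max_duration=None):
    """Count files and hours in a JSON-lines manifest, split by a max-duration filter."""
    kept_files = dropped_files = 0
    kept_seconds = dropped_seconds = 0.0
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            duration = float(json.loads(line)["duration"])
            if max_duration is not None and duration > max_duration:
                dropped_files += 1
                dropped_seconds += duration
            else:
                kept_files += 1
                kept_seconds += duration
    return {
        "kept_files": kept_files,
        "kept_hours": kept_seconds / 3600,
        "dropped_files": dropped_files,
        "dropped_hours": dropped_seconds / 3600,
    }


# Usage in this notebook (path taken from the non-tarred branch above):
# print(summarize_manifest(f"{DATA_ROOT}/processed/train_manifest_merged.json", max_duration=16.7))
```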


### Setting up optimizer and scheduler

When fine-tuning, it is generally advised to use a lower learning rate and a shorter warmup. A lower learning rate helps preserve the pre-trained weights of the encoder. Because the fine-tuning dataset is usually much smaller than the original training dataset, the original number of warmup steps would be too large for it.


In [7]:
# Original optimizer + scheduler
print(OmegaConf.to_yaml(asr_model.cfg.optim))

name: novograd
lr: 0.05
betas:
- 0.8
- 0.25
weight_decay: 0.001
sched:
  name: CosineAnnealing
  warmup_steps: 1000
  warmup_ratio: null
  min_lr: 1.0e-05
  last_epoch: -1



In [8]:
# Switch to AdamW and use a smaller learning rate than the original 0.05
with open_dict(asr_model.cfg.optim):
  asr_model.cfg.optim.name="adamw"
  asr_model.cfg.optim.lr = 0.01
  asr_model.cfg.optim.betas = [0.8, 0.25]  # from paper
  asr_model.cfg.optim.weight_decay = 0.001  # Original weight decay
  asr_model.cfg.optim.sched.warmup_steps = None  # Remove default number of steps of warmup
  asr_model.cfg.optim.sched.warmup_ratio = 0.05  # 5 % warmup
  asr_model.cfg.optim.sched.min_lr = 1e-5
  asr_model.cfg.optim.sched.max_steps = 50000
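
To make the effect of these settings concrete, the sketch below computes the number of warmup steps implied by `warmup_ratio` and `max_steps`, then evaluates a linear-warmup plus cosine-annealing schedule at a few steps. It is an approximation written for illustration, assuming the warmup length is derived as `int(warmup_ratio * max_steps)`; it is not NeMo's `CosineAnnealing` implementation, and the function names are introduced here.

```python
import math


def warmup_steps_from_ratio(warmup_ratio, max_steps):
    """Warmup length implied by a ratio (assumed rule: truncate ratio * max_steps)."""
    return int(warmup_ratio * max_steps)


def lr_at_step(step, peak_lr, min_lr, warmup_steps, max_steps):
    """Linear warmup to peak_lr, then cosine decay down to min_lr at max_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / (warmup_steps + 1)
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


# Values taken from the optimizer config set above
PEAK_LR, MIN_LR, MAX_STEPS = 0.01, 1e-5, 50000
warmup = warmup_steps_from_ratio(0.05, MAX_STEPS)
print("warmup steps:", warmup)
for step in (0, warmup, MAX_STEPS // 2, MAX_STEPS):
    print(step, f"{lr_at_step(step, PEAK_LR, MIN_LR, warmup, MAX_STEPS):.6f}")
```

Under this assumption, 5% of 50,000 steps gives 2,500 warmup steps, which is more than the original `warmup_steps: 1000`. Lower `warmup_ratio` or `max_steps` if a shorter warmup is intended.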

## Training the model

Now we can create a PyTorch Lightning trainer and call `fit`. To increase training speed, we use mixed-precision training (`precision=16`). In this notebook, we demonstrate training with 1 GPU. To train with 8 GPUs, execute the [train.py](train.py) script in a shell terminal.

Notes:
- Even with cross-language transfer learning, the model still takes a few hundred epochs to train to convergence.
- To stabilize training and avoid NaN loss issues, increase the global batch size to the range of [256, 2048]. On devices with limited memory, this can be achieved by setting `accumulate_grad_batches` to an appropriate value.
- `asr_model.cfg.train_ds.batch_size` denotes the per-device batch size. The global batch size is `batch_size * number of nodes * GPUs per node * accumulate_grad_batches`.
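
The arithmetic in the notes above can be sketched as follows. The two helpers are hypothetical conveniences introduced here, not part of NeMo or PyTorch Lightning; they evaluate the global batch size formula and solve it for `accumulate_grad_batches` given a target.

```python
import math


def global_batch_size(per_device_batch, num_nodes, gpus_per_node, accumulate_grad_batches):
    """Number of samples contributing to each optimizer step."""
    return per_device_batch * num_nodes * gpus_per_node * accumulate_grad_batches


def accumulation_for_target(target_global_batch, per_device_batch, num_nodes=1, gpus_per_node=1):
    """Smallest accumulate_grad_batches that reaches at least the target global batch size."""
    per_step = per_device_batch * num_nodes * gpus_per_node
    return max(1, math.ceil(target_global_batch / per_step))


# This notebook: batch size 32 on 1 node with 1 GPU, 32 accumulation steps
print(global_batch_size(32, 1, 1, 32))                       # 1024
# The same global batch size on 8 GPUs needs fewer accumulation steps
print(accumulation_for_target(1024, 32, gpus_per_node=8))    # 4
```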

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl

checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="val_wer",
    mode="min",
    dirpath="./checkpoint-dir",
    filename="citrinet-DE-{epoch:02d}",
    save_on_train_epoch_end=True,
)

trainer = pl.Trainer(precision=16, 
                     devices=1, 
                     accelerator='gpu',                        
                     max_epochs=500,                      
                     default_root_dir="./checkpoint/",
                     accumulate_grad_batches=32, # For a global batch size of 32*1*32 = 1024
                     callbacks=[checkpoint_callback])
    
trainer.fit(asr_model)
asr_model.save_to('de-asr-model.nemo')