In [None]:
"""
You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.

Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GITHUB" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies.
5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect
"""

# Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg libsox-fmt-mp3
!pip install unidecode
!pip install matplotlib>=3.3.2

## Install NeMo
BRANCH = 'main'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]

"""
Remember to restart the runtime for the kernel to pick up any upgraded packages (e.g. matplotlib)!
Alternatively, you can uncomment the exit() below to crash and restart the kernel, in the case
that you want to use the "Run All Cells" (or similar) option.
"""
# exit()

In [None]:
!apt update
!apt-get  install -y sox libsndfile1 ffmpeg libsox-fmt-mp3

In [None]:
import os
import glob
import subprocess
import tarfile
import wget
import copy
from omegaconf import OmegaConf, open_dict

In [None]:
import nemo
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.utils import logging, exp_manager

In [None]:
VERSION = "aihub-2022-07"
LANGUAGE = "ko"

In [None]:
# Root directory holding the dataset (its Training/ and Validation/ sub-trees are used below).
# Set the NEMO_DATA_DIR environment variable to override it instead of editing the notebook.
# The trailing slash is stripped so that f"{nemo_dir}/..." paths do not contain "//".
nemo_dir = os.environ.get("NEMO_DATA_DIR", "/mnt/sdb/jhchang/nemo").rstrip("/")

In [None]:
!apt-get install  sox

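This notebook does not download the dataset; it expects the data to already be on local disk under `nemo_dir`. Let's prepare some paths to easily access the manifest files for the train and test partitions. The validation split is used as the test set, and no dev split is used.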

In [None]:
train_manifest = f"{nemo_dir}/Training/Manifests/korean_train.json"
#dev_manifest = f"{manifest_dir}/commonvoice_dev_manifest.json"
test_manifest = f"{nemo_dir}/Validation/Manifests/korean_val.json"

## Manifest utilities

First, we construct a utility to read manifest files, which contain one JSON object per line. A matching writer is defined later.

In [None]:
# Manifest Utils
from tqdm.auto import tqdm
import json

def read_manifest(path):
    """Read a NeMo manifest file (one JSON object per line).

    Args:
        path: Path to the manifest file.

    Returns:
        A list of dicts, one per non-empty line, in file order.
    """
    manifest = []
    # Transcripts are non-ASCII (Korean), so decode explicitly as UTF-8 instead of
    # relying on the platform's locale-dependent default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Reading manifest data"):
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing inside json.loads.
                continue
            manifest.append(json.loads(line))

    return manifest


In [None]:
# Load the raw (unprocessed) manifests. The dev split is not used in this notebook.
train_manifest_data, test_manifest_data = (
    read_manifest(manifest_path) for manifest_path in (train_manifest, test_manifest)
)

Next, we extract just the text corpus from the manifest.

In [None]:
# Keep only the transcript strings from each manifest entry (no dev split is used).
train_text, test_text = (
    [entry['text'] for entry in manifest_rows]
    for manifest_rows in (train_manifest_data, test_manifest_data)
)

## Character set

Let us calculate the character set, i.e. the set of unique tokens (characters) that occur in the manifest transcripts, together with how often each one occurs.

In [None]:
from collections import defaultdict

def get_charset(manifest_data):
    """Count how often each character occurs across all transcripts.

    Args:
        manifest_data: Iterable of manifest entries, each a dict with a 'text' key.

    Returns:
        A ``defaultdict(int)`` that maps each character to its number of occurrences.
    """
    char_counts = defaultdict(int)
    for entry in tqdm(manifest_data, desc="Computing character set"):
        for ch in entry['text']:
            char_counts[ch] += 1
    return char_counts

In [None]:
# Character -> occurrence count for each split (the dev split is not used).
train_charset, test_charset = map(get_charset, (train_manifest_data, test_manifest_data))

Count the number of unique tokens in each split of this dataset.

In [None]:
# Unique characters per split. Iterating a dict yields its keys, so this is equivalent
# to building the sets from `.keys()`. Only the train split feeds the training vocabulary
# (there is no dev split).
train_dev_set = set(train_charset)
test_set = set(test_charset)

In [None]:
# Show a bounded, deterministic (sorted) sample of the test-set characters rather than
# dumping the whole set in arbitrary hash order.
sorted(test_set)[:100]

In [None]:
print(f"Number of tokens in train+dev set : {len(train_dev_set)}")
print(f"Number of tokens in test set : {len(test_set)}")

## Count number of Out-Of-Vocabulary tokens in the test set

Given that such a large number of tokens exist in the training data, let's make sure that there are no outlier tokens in the test set (remember: Korean text written in Hangul can draw on thousands of distinct precomposed syllables, and Unicode defines 11,172 of them!).

In [None]:
# OOV tokens in test set: characters that appear in the test split but never in training.
train_test_common = train_dev_set & test_set
test_oov = test_set.difference(train_test_common)
print(f"Number of OOV tokens in test set : {len(test_oov)}")
print("")
print(test_oov)

In [None]:
# Populate dictionary mapping count -> list of tokens that occur exactly `count` times
# in the *training* data. Test-set counts are deliberately not added. Mixing them in
# would leak test statistics into the train histogram below, and it would list a token
# that appears in both splits under two different counts.
train_counts = defaultdict(list)
for token, count in train_charset.items():
    train_counts[count].append(token)

# Sorted order of the count keys
count_keys = sorted(train_counts)

For each occurrence count from 1 to `MAX_COUNT`, build paired lists of that count and the number of unique tokens that occur exactly that many times.

In [None]:
MAX_COUNT = 32

# Paired lists for the bar chart: x = an occurrence count k in 1..MAX_COUNT,
# y = how many distinct tokens occur exactly k times. Counts with no tokens are skipped.
TOKEN_COUNT_X = [k for k in range(1, MAX_COUNT + 1) if k in train_counts]
NUM_TOKENS_Y = [len(train_counts[k]) for k in TOKEN_COUNT_X]

In [None]:
import matplotlib.pyplot as plt

# Histogram: for each occurrence count (x), how many distinct training tokens occur
# exactly that many times (y). Uses the explicit Axes API instead of pyplot state.
fig, ax = plt.subplots()
ax.bar(x=TOKEN_COUNT_X, height=NUM_TOKENS_Y)
ax.set_title("Occurrence of unique tokens in train set")
ax.set_xlabel("# of occurrences")
ax.set_ylabel("# of tokens")
ax.set_xlim(0, MAX_COUNT);

In [None]:
UNCOMMON_TOKENS_COUNT = 5

# Union of every token whose occurrence count lies in 1..UNCOMMON_TOKENS_COUNT.
# `.get` is used so that missing counts are not inserted into the defaultdict.
chars_with_infrequent_occurance = set()
for freq in range(1, UNCOMMON_TOKENS_COUNT + 1):
    chars_with_infrequent_occurance.update(train_counts.get(freq, ()))

print(f"Number of tokens with <= {UNCOMMON_TOKENS_COUNT} occurances : {len(chars_with_infrequent_occurance)}")

## Remove Out-of-Vocabulary tokens from the test set

Previously we counted the set of Out-of-Vocabulary tokens that exist in the test set but not in the training data. Now, let's remove them from the combined vocabulary.

In [None]:
all_tokens = train_dev_set | test_set
print(f"Original train+dev+test vocab size : {len(all_tokens)}")

# Drop the test-only (OOV) characters from the combined vocabulary.
extra_kanji = set(test_oov)
train_token_set = all_tokens.difference(extra_kanji)
print(f"New train vocab size : {len(train_token_set)}")

In [None]:
#@title Dakuten normalization
perform_dakuten_normalization = False #@param ["True", "False"] {type:"raw"}
PERFORM_DAKUTEN_NORMALIZATION = bool(perform_dakuten_normalization)
# Colab form cell: the `#@title` / `#@param` comments above drive the form UI, so their
# exact syntax is kept. The flag gates the use of `process_dakuten` in later cells.
# NOTE(review): dakuten marks are Japanese. For Korean, the NFD step in `process_dakuten`
# would also decompose precomposed Hangul syllables into jamo, so this should presumably
# stay False for this dataset.

In [None]:
import unicodedata
def process_dakuten(text):
    """Strip Japanese dakuten / handakuten marks from ``text``.

    The text is first decomposed with Unicode NFD, which turns voiced kana into a base
    character plus a combining mark. The combining dakuten (U+3099) and handakuten
    (U+309A) are then deleted, and the result is returned in NFD form. NFD also
    decomposes precomposed Hangul syllables into conjoining jamo.
    """
    decomposed = unicodedata.normalize('NFD', text)
    return decomposed.translate({0x3099: None, 0x309A: None})

In [None]:
if not PERFORM_DAKUTEN_NORMALIZATION:
    normalized_train_token_set = train_token_set
else:
    # A token can expand into several characters after NFD decomposition,
    # so collect every resulting character.
    normalized_train_token_set = {
        char
        for token in train_token_set
        for char in process_dakuten(str(token))
    }
    print(f"After dakuten normalization, number of train tokens : {len(normalized_train_token_set)}")
    

In [None]:
# Preprocessing steps
import re
import unicodedata

# Punctuation / special symbols removed from every transcript. A raw string is used so the
# backslash escapes reach the regex engine unchanged. In a normal string literal, sequences
# such as "\," are invalid escape sequences (a SyntaxWarning since Python 3.12).
# The pattern is otherwise identical: inside the character class, each "\X" is a literal X.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\…\{\}\【\】\・\。\『\』\、\ー\〜]'  # remove special character tokens


def remove_special_characters(data):
    """Strip special characters from ``data["text"]``, then lower-case and trim it.

    ``data`` is modified in place and returned, so the function can be used as a
    preprocessing step.
    """
    cleaned = re.sub(chars_to_ignore_regex, '', data["text"])
    data["text"] = cleaned.lower().strip()
    return data

def remove_extra_kanji(data, tokens=None):
    """Delete test-set-only (OOV) characters from ``data["text"]``.

    Args:
        data: Manifest entry with a "text" key. It is modified in place.
        tokens: Characters to remove. Defaults to the notebook-level ``extra_kanji``
            set (the test-set OOV characters).

    Returns:
        ``data``, so the function can be used as a preprocessing step.
    """
    if tokens is None:
        tokens = extra_kanji
    if not tokens:
        # An empty character class "[]" is not a valid regex, and there is nothing to remove.
        return data
    # Build the pattern here. The module-level `kanji_removal_regex` that this step used
    # to read is not defined (its definition is commented out), which caused a NameError.
    # Every token is escaped so regex metacharacters such as "]", "^" or "\" are literal.
    removal_regex = '[' + ''.join(sorted(re.escape(token) for token in tokens)) + ']'
    data["text"] = re.sub(removal_regex, '', data["text"])
    return data

def remove_dakuten(data):
    """Apply ``process_dakuten`` to ``data['text']`` when dakuten normalization is enabled.

    Returns ``data``. It is modified in place only if PERFORM_DAKUTEN_NORMALIZATION is set.
    """
    if not PERFORM_DAKUTEN_NORMALIZATION:
        return data
    data['text'] = process_dakuten(data['text'])
    return data

## Process dataset

Now that we have the functions necessary to clean up the transcripts, let's create a small pipeline to clean up the manifests and write new manifests for us. For simplicity's sake, a sequential pipeline is sufficient for our use case, since it is a single pass over each manifest.

In [None]:
# Processing pipeline
def apply_preprocessors(manifest, preprocessors):
    """Run each preprocessing function over every manifest entry, in order.

    Each function receives one entry and returns its replacement. ``manifest`` is
    updated in place, entry by entry, and is also returned.
    """
    for step in preprocessors:
        progress = tqdm(range(len(manifest)), desc=f"Applying {step.__name__}")
        for position in progress:
            manifest[position] = step(manifest[position])

    print("Finished processing manifest !")
    return manifest

def write_processed_manifest(data, original_path):
    """Write ``data`` as a JSON-lines manifest next to ``original_path``.

    The output file name is the original one with ``_processed`` inserted before the
    extension (``korean_train.json`` -> ``korean_train_processed.json``).

    Args:
        data: Iterable of JSON-serialisable manifest entries.
        original_path: Path of the manifest the data was read from.

    Returns:
        The path of the file that was written.
    """
    manifest_dir, original_manifest_name = os.path.split(original_path)
    # Use splitext rather than str.replace(".json", ...). The replace call left names
    # without ".json" unchanged, so the source manifest could be overwritten in place.
    stem, ext = os.path.splitext(original_manifest_name)
    new_manifest_name = f"{stem}_processed{ext}"

    filepath = os.path.join(manifest_dir, new_manifest_name)
    with open(filepath, 'w', encoding='utf-8') as f:
        for datum in tqdm(data, desc="Writing manifest data"):
            f.write(f"{json.dumps(datum)}\n")
    print(f"Finished writing manifest: {filepath}")
    return filepath


In [None]:
# List of pre-processing functions, applied in order by `apply_preprocessors`.
# `remove_extra_kanji` and `remove_dakuten` are defined above but not enabled here.
PREPROCESSORS = [remove_special_characters]

In [None]:
# Build the processed manifests. A processed manifest that already exists is reused, so
# re-running the notebook does not repeat the read / clean / write pass over the full
# dataset. Delete a processed file to force it to be regenerated.
def _processed_manifest_path(original_path):
    # Mirrors the output naming of `write_processed_manifest` for ".json" manifests.
    manifest_dir, name = os.path.split(original_path)
    return os.path.join(manifest_dir, name.replace(".json", "_processed.json"))


def _build_processed_manifest(original_path):
    """Return the processed manifest path for ``original_path``, creating the file if missing."""
    target = _processed_manifest_path(original_path)
    if os.path.exists(target):
        print(f"Reusing existing processed manifest: {target}")
        return target
    raw_data = read_manifest(original_path)
    processed = apply_preprocessors(raw_data, PREPROCESSORS)
    return write_processed_manifest(processed, original_path)


# The dev split is not used in this notebook.
train_manifest_cleaned = _build_processed_manifest(train_manifest)
test_manifest_cleaned = _build_processed_manifest(test_manifest)


In [None]:
train_manifest_cleaned = "/mnt/sdb/jhchang/nemo//Training/Manifests/korean_train_processed.json"
test_manifest_cleaned = "/mnt/sdb/jhchang/nemo//Validation/Manifests/korean_val_processed.json"

## Final character set

After pre-processing the dataset, let's recover the final character set used to train the models.

In [None]:
# Recompute the character set from the processed training manifest. This is the
# vocabulary the model is trained with (no dev split is used).
train_manifest_data = read_manifest(train_manifest_cleaned)
train_charset = get_charset(train_manifest_data)
train_dev_set = set(train_charset)

In [None]:
print(f"Number of tokens in preprocessed train+dev set : {len(train_dev_set)}")

# Character Encoding CTC Model

Now that we have a processed dataset, we can begin training an ASR model on this dataset. The following section will detail how we prepare a CTC model which utilizes a Character Encoding scheme.

This section will utilize a pre-trained [QuartzNet 15x5](https://arxiv.org/abs/1910.10261) model, which was trained on roughly 7,000 hours of English speech. We will modify the decoder layer (thereby changing the model's vocabulary) and then train for a small number of epochs.

In [None]:
# English QuartzNet 15x5 checkpoint, loaded onto the CPU. Its vocabulary is replaced below.
PRETRAINED_NAME = "stt_en_quartznet15x5"
char_model = nemo_asr.models.ASRModel.from_pretrained(PRETRAINED_NAME, map_location='cpu')

## Update the vocabulary

Changing the vocabulary of a character encoding ASR model is as simple as passing the list of new tokens that comprise the vocabulary as input to `change_vocabulary()`.

In [None]:
# Replace the English character vocabulary with the Korean character set. The order comes
# from iterating the set, which stays stable as long as the set is not modified.
new_vocabulary = list(train_dev_set)
char_model.change_vocabulary(new_vocabulary=new_vocabulary)

In [None]:
# Preview the vocabulary size and its first 100 entries, in the same order that was
# passed to the model, instead of dumping the entire list.
vocab_preview = list(train_dev_set)
len(vocab_preview), vocab_preview[:100]

In [None]:
#@title Freeze Encoder { display-mode: "form" }
freeze_encoder = False #@param ["False", "True"] {type:"raw"}
freeze_encoder = bool(freeze_encoder)
# Colab form cell: the `#@title` / `#@param` comments drive the form UI and are kept verbatim.
# When True, the encoder is frozen in a later cell, except for BatchNorm1d and
# SqueezeExcite modules, which `enable_bn_se` switches back to training.

In [None]:
import torch
import torch.nn as nn

def enable_bn_se(m):
    """Put BatchNorm1d / SqueezeExcite modules back into training mode.

    Intended for ``module.apply(...)`` after an encoder has been frozen. A module whose
    exact type is ``nn.BatchNorm1d``, or whose class name contains ``'SqueezeExcite'``,
    is switched to train mode, and all of its parameters get ``requires_grad=True``.
    Any other module is left untouched.
    """
    module_type = type(m)
    is_target = module_type == nn.BatchNorm1d or 'SqueezeExcite' in module_type.__name__
    if not is_target:
        return
    m.train()
    for weight in m.parameters():
        weight.requires_grad_(True)

In [None]:
if not freeze_encoder:
    char_model.encoder.unfreeze()
    logging.info("Model encoder has been un-frozen")
else:
    char_model.encoder.freeze()
    # Switch BatchNorm / SqueezeExcite layers back to training so they keep adapting.
    char_model.encoder.apply(enable_bn_se)
    logging.info("Model encoder has been frozen, and batch normalization has been unfrozen")

## Update config

Each NeMo model has a config embedded in it, which can be accessed via `model.cfg`. In general, this is the config that was used to construct the model.

For pre-trained models, this config generally represents the config used to construct the model when it was trained. A nice benefit to this embedded config is that we can repurpose it to set up new data loaders, optimizers, schedulers, and even data augmentation!

### Updating the character set of the model

The most important step for preparing character encoding models for fine-tuning is to update the model's character set. Remember - the model was trained on some language with some specific dataset that had a certain character set. Character sets would rarely remain the same between training and fine-tuning (though it is still possible).

Each character encoding model has a `model.cfg.labels` attribute, which can be overridden via OmegaConf.

In [None]:
# Store the new character set as the model-level labels. The set is iterated the same way
# as for the vocabulary passed to `change_vocabulary`.
char_model.cfg.labels = [label for label in train_dev_set]

Now, we create a working copy of the model config and update it as needed.

In [None]:
# Working copy of the model config. Because it is a deep copy, edits to `cfg` (including
# nested sections such as train_ds / validation_ds) do not modify `char_model.cfg`.
# The edited sections are handed to the model through the setup_* calls below.
cfg = copy.deepcopy(char_model.cfg)

### Setting up data loaders

Now that the model's character set has been updated let's prepare the model to utilize the new character set even in the data loaders. Note that this is crucial so that the data produced during training/validation matches the new character set, and tokens are encoded/decoded correctly.

**Note**: Two important config parameters are `normalize_transcripts` and `parser`. Some parsers are specific to particular languages for character-based models; currently only `en` is supported. These parsers preprocess the text using the given language's rules. For other languages, it is advised to explicitly set `normalize_transcripts = False`, which prevents the parser from processing the text.

In [None]:
# Setup train and validation configs

# Single source for the data-loader vocabulary: the model-level labels assigned above,
# which `cfg` copied. This makes transcripts encode with the same characters, in the same
# order, as the model's labels.
labels = list(cfg.labels)

with open_dict(cfg):
  # Train dataset: the processed train manifest (no dev split is used in this notebook)
  cfg.train_ds.manifest_filepath = train_manifest_cleaned
  cfg.train_ds.labels = labels
  cfg.train_ds.normalize_transcripts = False
  cfg.train_ds.batch_size = 32
  cfg.train_ds.num_workers = 8
  cfg.train_ds.pin_memory = True
  cfg.train_ds.trim_silence = True

  # Validation dataset: the processed validation manifest, which also serves as the test set
  cfg.validation_ds.manifest_filepath = test_manifest_cleaned
  cfg.validation_ds.labels = labels
  cfg.validation_ds.normalize_transcripts = False
  cfg.validation_ds.batch_size = 8
  cfg.validation_ds.num_workers = 8
  cfg.validation_ds.pin_memory = True
  cfg.validation_ds.trim_silence = True

In [None]:
# Build the train and validation data loaders from the edited configs.
train_ds_cfg, validation_ds_cfg = cfg.train_ds, cfg.validation_ds
char_model.setup_training_data(train_ds_cfg)
char_model.setup_multiple_validation_data(validation_ds_cfg)

### Setting up optimizer and scheduler

When fine-tuning character models, it is generally advised to use a lower learning rate and reduced warmup. A reduced learning rate helps preserve the pre-trained weights of the encoder. Because the fine-tuning dataset is generally smaller than the original training dataset, the original number of warmup steps would be far too large for fine-tuning.

-----
**Note**: When freezing the encoder, it is possible to use the original learning rate as the model was trained on. The original learning rate can be used because the encoder is frozen, so the learning rate is used only to optimize the decoder. However, a very high learning rate would still destabilize training, even with a frozen encoder.

In [None]:
# Original optimizer + scheduler, as stored in the pretrained model's config
original_optim_yaml = OmegaConf.to_yaml(char_model.cfg.optim)
print(original_optim_yaml)

In [None]:
optim_cfg = char_model.cfg.optim
with open_dict(optim_cfg):
  # Lower LR than pre-training to preserve the pretrained encoder (0.01 for a frozen encoder)
  optim_cfg.lr = 0.0004
  optim_cfg.betas = [0.95, 0.5]  # from paper
  optim_cfg.weight_decay = 0.001  # Original weight decay
  # Replace the fixed warmup step count with a ratio so warmup scales with this dataset
  optim_cfg.sched.warmup_steps = None
  optim_cfg.sched.warmup_ratio = 0.05  # 5 % warmup
  optim_cfg.sched.min_lr = 1e-5

### Setting up augmentation

Remember that the model was trained on several thousands of hours of data, so the regularization provided to it might not suit the current dataset. We can easily change it as we see fit.

-----

You might notice that we utilize `char_model.from_config_dict()` to create a new SpectrogramAugmentation object and assign it directly in place of the previous augmentation. This is generally the syntax to follow whenever you notice a `_target_` key inside a model's config.

-----
**Note**: For low resource languages, it might be better to increase augmentation via SpecAugment to reduce overfitting. However, this might, in turn, make it too hard for the model to train in a short number of epochs.

In [None]:
# Current SpecAugment settings from the pretrained model's config
spec_augment_yaml = OmegaConf.to_yaml(char_model.cfg.spec_augment)
print(spec_augment_yaml)

In [None]:
# Rebuild the SpecAugment module from its config. To strengthen augmentation, first edit
# char_model.cfg.spec_augment inside `open_dict`, for example:
#   freq_masks=2, freq_width=25, time_masks=2, time_width=0.05
spec_augment_cfg = char_model.cfg.spec_augment
char_model.spec_augmentation = char_model.from_config_dict(spec_augment_cfg)

## Setup Metrics

Originally, the model was trained on an English dataset corpus. When calculating Word Error Rate, we can easily use the "space" token as a separator for word boundaries. On the other hand, certain languages such as Japanese and Mandarin do not use "space" tokens, instead opting for different ways to annotate the end of the word. Korean, by contrast, separates words (eojeol) with spaces, so WER remains meaningful for this dataset.

In cases where the "space" token is not used to denote a word boundary, we can use the Character Error Rate metric instead, which computes the edit distance at a token level rather than a word level.

We might also be interested in noting model predictions during training and inference. As such, we can enable logging of the predictions.

In [None]:
#@title Metric
use_cer = False #@param ["False", "True"] {type:"raw"}
log_prediction = True #@param ["False", "True"] {type:"raw"}
# Colab form cell: the `#@param` comments are kept verbatim for the form UI.
# use_cer: report character error rate instead of word error rate.
# log_prediction: log sample predictions during training / validation.


In [None]:
# Apply the metric options chosen in the form above to the model's WER metric.
wer_metric = char_model._wer
wer_metric.use_cer = use_cer
wer_metric.log_prediction = log_prediction

## Setup Trainer and Experiment Manager

And that's it! Now we can train the model by simply using the Pytorch Lightning Trainer and NeMo Experiment Manager as always.

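For demonstration purposes, the number of epochs is kept intentionally low. Reasonable results can be obtained in around 100 epochs. The quoted estimate of roughly 25 minutes on Colab GPUs assumes a small dataset; training time grows with the size of the training manifest.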

In [None]:
import torch
import pytorch_lightning as ptl

accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

# GPU index to train on. `gpus=[1]` failed on machines with fewer than two GPUs and
# conflicted with the CPU fallback. `devices` works for both accelerators and falls back
# to a single device when GPU 1 does not exist.
GPU_ID = 1
if accelerator == 'gpu' and torch.cuda.device_count() > GPU_ID:
    devices = [GPU_ID]
else:
    devices = 1

EPOCHS = 20  # 100 epochs would provide better results, but would take an hour to train

trainer = ptl.Trainer(devices=devices,
                      accelerator=accelerator,
                      max_epochs=EPOCHS,
                      accumulate_grad_batches=1,
                      enable_checkpointing=True,
                      logger=False,
                      log_every_n_steps=500,
                      check_val_every_n_epoch=1)
# Setup model with the trainer
char_model.set_trainer(trainer)

# Finally, update the model's internal config
char_model.cfg = char_model._cfg

In [None]:
# Environment variable generally used for multi-node multi-gpu training.
# In notebook environments, this flag is unnecessary and can cause logs of multiple training runs to overwrite each other.
os.environ.pop('NEMO_EXPM_VERSION', None)

# NOTE(review): create_checkpoint_callback=False, so exp_manager presumably ignores these
# checkpoint params. Checkpointing is enabled on the Trainer instead (enable_checkpointing=True).
checkpoint_params = exp_manager.CallbackParams(
    monitor="val_wer",
    mode="min",
    always_save_nemo=True,
    save_best_model=True,
)
exp_config = exp_manager.ExpManagerConfig(
    exp_dir=f'experiments/lang-{LANGUAGE}/',
    name=f"ASR-Char-Model-Language-{LANGUAGE}",
    create_checkpoint_callback=False,
    checkpoint_callback_params=checkpoint_params,
)
config = OmegaConf.structured(exp_config)

logdir = exp_manager.exp_manager(trainer, config)

In [None]:
try:
  from google import colab
  COLAB_ENV = True
except (ImportError, ModuleNotFoundError):
  COLAB_ENV = False

# Load the TensorBoard notebook extension
if COLAB_ENV:
  %load_ext tensorboard
  %tensorboard --logdir /content/experiments/lang-$LANGUAGE/ASR-Char-Model-Language-$LANGUAGE/
else:
  print("To use tensorboard, please use this notebook in a Google Colab environment.")

In [None]:
%%time
# Fine-tune the model with the trainer configured above. `%%time` reports the cost of the run.
trainer.fit(char_model)

In [None]:
import datetime as dt
timestamp = dt.datetime.now().strftime("%H-%M-%d-%B")
# Save under the data directory `nemo_dir` instead of repeating its absolute path here.
save_path = os.path.join(nemo_dir, f"Model-{LANGUAGE}-{EPOCHS}-{timestamp}.nemo")
char_model.save_to(save_path)
print(f"Model saved at path : {save_path}")

In [None]:
%%time
trainer = ptl.Trainer(gpus=[1], 
                      accelerator=accelerator, 
                      max_epochs=4, 
                      accumulate_grad_batches=1,
                      enable_checkpointing=True,
                      logger=False,
                      log_every_n_steps=5000,
                      #callbacks = [checkpoint_callback],
                      check_val_every_n_epoch=1)
trainer.fit(char_model)

import datetime as dt
timestamp = dt.datetime.now().strftime("%H-%M-%d-%B")
save_path = f"/mnt/sdb/jhchang/nemo/Model-{LANGUAGE}-{trainer.max_epochs}-{timestamp}.nemo"
char_model.save_to(f"{save_path}")
print(f"Model saved at path : {save_path}")

In [None]:
# Restore a previously fine-tuned model and set up its test data loader on the processed
# validation manifest.
# NOTE(review): nothing in this cell moves the model to a GPU. Do that explicitly if needed.
RESTORED_MODEL_NAME = 'Model-ko-epoch 20-05-11-06-July_WER(0.0438).nemo'
char_model = nemo_asr.models.ASRModel.restore_from(os.path.join(nemo_dir, RESTORED_MODEL_NAME))
cfg = copy.deepcopy(char_model.cfg)

with open_dict(cfg):
  # Train dataset: the processed train manifest (not used by the test setup below)
  cfg.train_ds.manifest_filepath = train_manifest_cleaned
  cfg.train_ds.labels = None
  cfg.train_ds.normalize_transcripts = False
  cfg.train_ds.batch_size = 32
  cfg.train_ds.num_workers = 8
  cfg.train_ds.pin_memory = True
  cfg.train_ds.trim_silence = True

  # Validation dataset: the processed validation manifest, used here as the test set
  cfg.validation_ds.manifest_filepath = test_manifest_cleaned
  # NOTE(review): labels=None discards the vocabulary stored in the restored config.
  # Confirm that the installed NeMo version fills it in from the model before building the dataset.
  cfg.validation_ds.labels = None
  cfg.validation_ds.normalize_transcripts = False
  cfg.validation_ds.batch_size = 8
  cfg.validation_ds.num_workers = 8
  cfg.validation_ds.pin_memory = True
  cfg.validation_ds.trim_silence = True

char_model.setup_test_data(cfg.validation_ds)
