In [None]:
# Widen the notebook's main container to the full browser width.
# `display` and `HTML` come from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Automatically reload edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2
# Use matplotlib's interactive "notebook" backend for figures.
%matplotlib notebook

In [None]:
import matplotlib.pylab as plt
import numpy as np
import pickle, os, warnings, sys, random, logging, librosa, json, nemo
import soundfile as sf
from tqdm.auto import tqdm
from ruamel.yaml import YAML
import nemo.collections.asr as nemo_asr
from omegaconf import DictConfig
import pytorch_lightning as pl
from Cfg import Cfg
from RecordingCorpus import RecordingCorpus
from multiprocessing import Pool
from contextlib import closing

# Silence every Python warning for the rest of the session.
warnings.filterwarnings(action="ignore")

# Set the root logger's threshold to INFO.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

In [None]:
## 

## Train Phase 1: Start from scratch

In [None]:
# Shared experiment configuration object.
# NOTE(review): the arguments look like (corpus, sample rate in Hz, language) — confirm against Cfg.
C = Cfg('NIST', 8000, 'amharic')

# Directory that receives one WAV file per transcript segment.
audio_split_dir = f'{C.build_dir}/audio_split'

In [None]:
# Create the output directory for the per-segment WAV files (no error if it already exists).
# Done in Python rather than via `!mkdir -p` so that paths containing spaces or shell
# metacharacters are handled safely and the cell does not depend on a POSIX shell.
os.makedirs(audio_split_dir, exist_ok=True)

In [None]:
# Load the corpus, split it along transcript boundaries, then write every segment
# as a WAV file and record it in a JSON-lines manifest (one manifest per split).
# NOTE(review): the shuffle below is unseeded, so the 80/20 split differs on every run.
if __name__ == '__main__':
    # The worker pool is only needed while the corpus is being loaded.
    with closing(Pool(16)) as pool:
        recordings = RecordingCorpus(C, pool)

from SplitCorpus import SplitCorpus
splits = SplitCorpus.transcript_split(C, recordings)

# Shuffle in place, then take the first 80% for training and the rest for testing.
random.shuffle(splits.artifacts)
n_samples = len(splits.artifacts)
samples = splits.artifacts
n_train = int(0.8 * n_samples)
train_samples = samples[:n_train]
test_samples = samples[n_train:]

for split_name, subset in (('train', train_samples), ('test', test_samples)):
    manifest_path = f'{C.build_dir}/{split_name}_manifest.json'
    with open(manifest_path, 'w', encoding='utf-8') as manifest_file:
        for sample in tqdm(subset):
            # The key unpacks to (ignored, root, (start, end)); root/start/end name the WAV file.
            _, root, (start, end) = sample.key
            clip = sample.source
            audio = clip.value
            duration = clip.n_seconds
            transcript = sample.target.value

            # Write the audio segment at the configured sample rate.
            audio_path = f'{audio_split_dir}/{root}_{start}_{end}.wav'
            sf.write(audio_path, audio, C.sample_rate)

            # One JSON object per line: audio path, duration, transcript.
            entry = {"audio_filepath": audio_path, "duration": duration, "text": transcript}
            json.dump(entry, manifest_file)
            manifest_file.write('\n')

In [None]:
# Directory (relative to the working directory) where checkpoints are written.
model_save_dir = 'save/nemo_amharic'

In [None]:
# Create the checkpoint directory (no error if it already exists).
# Done in Python rather than via `!mkdir -p` so that paths containing spaces or shell
# metacharacters are handled safely and the cell does not depend on a POSIX shell.
os.makedirs(model_save_dir, exist_ok=True)

In [None]:
# List the current contents of the checkpoint directory.
!ls {model_save_dir}

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# File-name template for checkpoints; the {placeholders} are left unformatted here.
checkpoint_name = 'amharic_{epoch:02d}_loss_{current:.4f}_trainWER_{training_batch_wer:.4f}_valWER_{validation_wer:.4f}'

# Phase 1 checkpointing: weights only, monitoring 'loss' (lower is better),
# with save_top_k=-1 and period=1.
checkpoint_callback = ModelCheckpoint(
    filepath='/'.join((model_save_dir, checkpoint_name)),
    monitor='loss',
    mode='min',
    save_top_k=-1,
    period=1,
    save_weights_only=True,
    verbose=True,
)

In [None]:
# Phase 1 trainer: GPU 0, 16-bit precision with AMP level O1, at most 1000 epochs.
trainer = pl.Trainer(
    gpus=[0],
    max_epochs=1000,
    precision=16,
    amp_level='O1',
    checkpoint_callback=checkpoint_callback,
)

In [None]:
# Load the model/training configuration from the YAML file.
config_path = 'quartznet_15x5_amharic.yaml'
yaml = YAML(typ='safe')
# Decode explicitly as UTF-8 instead of relying on the platform's locale default,
# so any non-ASCII content in the config is read identically on every machine
# (the manifests in this notebook are written with encoding='utf-8' as well).
with open(config_path, encoding='utf-8') as f:
    params = yaml.load(f)

In [None]:
# Point the train/validation datasets at the manifests under the build directory
# and override the optimizer's learning rate.
model_params = params['model']
model_params['train_ds']['manifest_filepath'] = f'{C.build_dir}/train_manifest.json'
model_params['validation_ds']['manifest_filepath'] = f'{C.build_dir}/test_manifest.json'
model_params['optim']['lr'] = 0.001

In [None]:
# Build the NeMo EncDecCTCModel from the 'model' section of the config and attach the trainer.
asr_cfg = DictConfig(params['model'])
model = nemo_asr.models.EncDecCTCModel(cfg=asr_cfg, trainer=trainer)

In [None]:
from load_recent import load_recent
# NOTE(review): presumably restores the most recently saved checkpoint into `model` — confirm in load_recent.
load_recent(C, model)

In [None]:
# Start training with the trainer configured for this phase.
trainer.fit(model)

## Train Phase 2: Approximate k-fold validation (repeatedly reshuffle the samples and retrain)

In [None]:
from reshuffle_samples import reshuffle_samples
import nemo.collections.asr as nemo_asr
from load_recent import load_recent
from ruamel.yaml import YAML
import pytorch_lightning as pl
from Cfg import Cfg
from omegaconf import DictConfig

In [None]:
# Shared experiment configuration object for phase 2.
# NOTE(review): the arguments look like (corpus, sample rate in Hz, language) — confirm against Cfg.
C = Cfg('NIST', 8000, 'amharic')

In [None]:
# Directory (relative to the working directory) where checkpoints are written.
model_save_dir = 'save/nemo_amharic'

In [None]:
# Load the model/training configuration from the YAML file.
config_path = 'quartznet_15x5_amharic.yaml'
yaml = YAML(typ='safe')
# Decode explicitly as UTF-8 instead of relying on the platform's locale default,
# so any non-ASCII content in the config is read identically on every machine.
with open(config_path, encoding='utf-8') as f:
    params = yaml.load(f)

# Point the train/validation datasets at the manifests under the build directory
# and override the optimizer's learning rate.
params['model']['train_ds']['manifest_filepath'] = f'{C.build_dir}/train_manifest.json'
params['model']['validation_ds']['manifest_filepath'] = f'{C.build_dir}/test_manifest.json'
params['model']['optim']['lr'] = 0.001

In [None]:
class ModelCheckpointAtEpochEnd(pl.callbacks.ModelCheckpoint):
    """ModelCheckpoint variant that also runs its validation-end hook at the end of every epoch."""

    def on_epoch_end(self, trainer, pl_module):
        """Record the current epoch in the callback metrics and run the validation-end hook.

        Args:
            trainer: the Trainer driving the run; its ``callback_metrics`` dict is
                updated in place with an ``'epoch'`` entry.
            pl_module: the module being trained, forwarded unchanged to the hook.
        """
        # Presumably consumed by the `{epoch:02d}` placeholder in the filepath template.
        trainer.callback_metrics['epoch'] = trainer.current_epoch
        # Delegate to this instance rather than `trainer.checkpoint_callback`: the two are
        # the same object when this callback is passed as `checkpoint_callback=...`, but
        # using `self` also works when it is registered some other way (or when the
        # trainer's checkpoint callback is a different object or disabled).
        self.on_validation_end(trainer, pl_module)

In [None]:
import os

# Process id of this kernel; embedded in checkpoint file names.
pid = os.getpid()

In [None]:
import datetime

# Timestamp (YYYYMMDD_HHMMSS) taken when this cell runs; embedded in checkpoint file names.
dt = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"

In [None]:
# Phase 2 checkpointing: full checkpoints (save_weights_only=False) with
# save_top_k=-1 and period=1. The timestamp and pid in the file name keep
# separate runs from colliding; `{epoch:02d}` is left unformatted here.
checkpoint_callback = ModelCheckpointAtEpochEnd(
    filepath=f'{model_save_dir}/amharic_{dt}_{pid}' + '_{epoch:02d}',
    save_top_k=-1,
    period=1,
    save_weights_only=False,
    verbose=True,
)

In [None]:
# Phase 2 trainer: GPU 0, 16-bit precision with AMP level O1, at most 20 epochs.
trainer = pl.Trainer(
    gpus=[0],
    max_epochs=20,
    precision=16,
    amp_level='O1',
    checkpoint_callback=checkpoint_callback,
)

In [None]:
# Repeat indefinitely: reshuffle the samples, build a fresh model, call load_recent on it, and train.
# There is no exit condition, so this cell runs until it is interrupted.
while True:
    # NOTE(review): presumably regenerates the train/test manifests with a new random split — confirm.
    reshuffle_samples(C)
    model = nemo_asr.models.EncDecCTCModel(cfg=DictConfig(params['model']), trainer=trainer)
    # NOTE(review): presumably loads the latest checkpoint so training continues across rounds — confirm.
    load_recent(C, model)
    # NOTE(review): the same Trainer instance is reused every round; confirm that it
    # restarts its epoch counter on each fit() call.
    trainer.fit(model)