In [None]:
import os, time
import nemo
import nemo.collections.asr as nemo_asr
import torch
import copy
import pytorch_lightning as ptl
from scripts.tools import *

from omegaconf import DictConfig, open_dict
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

try:
    from ruamel.yaml import YAML
except ModuleNotFoundError:
    from ruamel_yaml import YAML

In [None]:
# Load the pretrained model that the rest of the notebook fine-tunes.
mname = 'QuartzNet15x5Base-En'
quartznet_model = nemo_asr.models.EncDecCTCModel.from_pretrained(
    model_name=mname,
)

In [None]:
# Project root; every manifest file lives in the `data_manifest/` folder below it.
basedir = '/workspace/nemo/TFM/'
manifest_dir = basedir + 'data_manifest/'

# Automated medical transcription [1411 files, 2.44 hours]
amt_train_manifest = manifest_dir + 'amt_train_metafile.json'
amt_val_manifest = manifest_dir + 'amt_val_metafile.json'
amt_test_manifest = manifest_dir + 'amt_test_metafile.json'
# Medical speech transcription [6661 files, 8.46 hours]
mst_train_manifest = manifest_dir + 'mst_train_metafile.json'
mst_val_manifest = manifest_dir + 'mst_val_metafile.json'
mst_test_manifest = manifest_dir + 'mst_test_metafile.json'
# Primock57 [6712 files, 8.31 hours]
prim_train_manifest = manifest_dir + 'prim_train_metafile.json'
prim_val_manifest = manifest_dir + 'prim_val_metafile.json'
prim_test_manifest = manifest_dir + 'prim_test_metafile.json'

# Per-corpus test manifests, kept as a list so each one can be evaluated separately.
test_list = [amt_test_manifest, mst_test_manifest, prim_test_manifest]

# One comma-separated string per split, combining the three corpora.
train_manifest = ",".join((amt_train_manifest, mst_train_manifest, prim_train_manifest))
val_manifest = ",".join((amt_val_manifest, mst_val_manifest, prim_val_manifest))
test_manifest = ",".join(test_list)


In [None]:
# Parse the base ASR configuration from the project's YAML file.
config_path = basedir + 'config_asr.yaml'

yaml = YAML(typ='safe')
with open(config_path) as cfg_file:
    params = yaml.load(cfg_file)

# Override the batch sizes and point every split at the combined manifests.
# (The original cell noted 32 as an alternative training batch size.)
model_cfg = params['model']
model_cfg['train_ds'].update(batch_size=16, manifest_filepath=train_manifest)
model_cfg['validation_ds'].update(batch_size=8, manifest_filepath=val_manifest)
model_cfg['test_ds'].update(manifest_filepath=test_manifest)

In [None]:
# Baseline evaluation: runs on the pretrained model, before any fine-tuning
# step in this notebook, once for each manifest in `test_list`.
# NOTE(review): `print_metric_medic` is not defined in this notebook; presumably
# it comes from the `scripts.tools` star import and prints the metrics — confirm.
print_metric_medic(quartznet_model, test_list)

In [None]:
# Change the learning rate: start from the optimizer section of the YAML
# config and override only `lr`.
new_opt = copy.deepcopy(params['model']['optim'])
new_opt.update(lr=0.001)
quartznet_model.setup_optimization(optim_config=DictConfig(new_opt))

# Change the vocabulary: space, lowercase letters, apostrophe and basic
# punctuation, followed by the uppercase letters.
new_labels = [
    ' ',
    *'abcdefghijklmnopqrstuvwxyz',
    "'", "!", "?", ".",
    *'ABCDEFGHIJKLMNOPQRSTUVWXYZ',
]
quartznet_model.change_vocabulary(new_vocabulary=new_labels)

# Load the datasets for the three splits from the edited configuration.
data_cfg = params['model']
quartznet_model.setup_training_data(train_data_config=data_cfg['train_ds'])
quartznet_model.setup_validation_data(val_data_config=data_cfg['validation_ds'])
quartznet_model.setup_test_data(test_data_config=data_cfg['test_ds'])

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

torch.set_float32_matmul_precision('high')

# Metrics are written as CSV under the current working directory; the
# checkpoint callback tracks the validation WER.
csv_logger = pl_loggers.CSVLogger(save_dir="./")
checkpoint_callback = ModelCheckpoint(monitor='val_wer')

trainer_kwargs = dict(
    devices=1,
    accelerator=accelerator,
    max_epochs=10,
    accumulate_grad_batches=1,
    enable_checkpointing=True,
    logger=csv_logger,
    log_every_n_steps=5,
    check_val_every_n_epoch=1,
    callbacks=[checkpoint_callback],
)
trainer = ptl.Trainer(**trainer_kwargs)

# Attach the trainer to the model before fitting.
quartznet_model.set_trainer(trainer)

In [None]:
# Fine-tune the model and measure the wall-clock duration of the run.
start = time.time()
trainer.fit(quartznet_model)
end = time.time()

elapsed_hours = (end - start) / 3600  # seconds -> hours
print("TIME:", elapsed_hours, "h")
print("BEST MODEL PATH:", checkpoint_callback.best_model_path)

# Test

In [None]:
# Post-training evaluation: first on the medical test manifests, then on the
# LibriSpeech dev-other folder.
# NOTE(review): both helpers are presumably provided by the `scripts.tools`
# star import — confirm.
libri_dev_other_dir = '/workspace/nemo/TFM/LibriSpeech/dev-other'

print_metric_medic(quartznet_model, test_list)
print_metric_libri(quartznet_model, libri_dev_other_dir)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import pandas as pd

# Read the metrics from the directory the CSV logger reports as its own.
# The logger was created with save_dir="./" (relative to the working
# directory), so rebuilding the path from `basedir` only worked when the
# notebook happened to run from that folder; otherwise it failed or picked up
# a different run that shared the same version number.
log_version = str(csv_logger.version)
print('log_version:', log_version)
metrics_path = os.path.join(csv_logger.log_dir, 'metrics.csv')
data = pd.read_csv(metrics_path)

# Keep only the rows that carry a validation WER value.
val_data = data.loc[pd.notna(data.val_wer)]

# Loss: training curve over every logged step, validation as markers.
# The plot labels below are user-facing strings and are kept in Spanish.
plt.plot(data.global_step, data.train_loss, label='Entrenamiento')
plt.plot(val_data.global_step, val_data.val_loss, label='Validación', marker='o')

plt.xticks(rotation = 25)
plt.xlabel('Pasos')
plt.ylabel('Pérdida')
plt.title('Valor de la pérdida', fontsize = 20)
plt.grid()
plt.legend()
plt.show()

# WER: same layout as the loss figure.
plt.plot(data.global_step, data.training_batch_wer, label='Entrenamiento')
plt.plot(val_data.global_step, val_data.val_wer, label='Validación', marker='o')

# Show the y axis as a percentage: tick values are multiplied by 100 and
# rendered without decimals.
plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda x,pos:format(x*100, ".0f")))
plt.xticks(rotation = 25)
plt.xlabel('Pasos')
plt.ylabel('WER (%)')
plt.title('Word Error Rate', fontsize = 20)
plt.grid()
plt.legend()
plt.show()