In [3]:
import os
import warnings
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

from mpvn.data.grad.lit_data_module import LightningGradDataModule
from mpvn.metric import WordErrorRate
from mpvn.model import ConformerLSTMModel

from mpvn.configs import DictConfig

In [4]:
# --- Configuration and reproducibility -------------------------------------
configs = DictConfig()

# Seed the RNGs through Lightning before any data or model object is built.
pl.seed_everything(configs.seed)
# NOTE: this silences *every* Python warning for the rest of the kernel
# session, including deprecation warnings from the libraries used below.
warnings.filterwarnings('ignore')

# --- Logging ----------------------------------------------------------------
# Event files are written under ./tensorboard/<name>/version_<n>/.
logger = TensorBoardLogger("tensorboard", name="Pronunciation for Vietnamese")
num_devices = 1  # forwarded to the Trainer as `gpus`

# --- Data -------------------------------------------------------------------
# The vocabulary has to exist before `setup`, which takes it as an argument.
data_module = LightningGradDataModule(configs)
vocab = data_module.get_vocab(configs.dataset_download, configs.vocab_size)
data_module.setup(vocab=vocab)

# --- Model ------------------------------------------------------------------
# The output layer is sized from the vocabulary; the same vocab is shared with
# the error-rate metric passed in as `per_metric`.
model = ConformerLSTMModel(
    configs=configs,
    num_classes=len(vocab),
    vocab=vocab,
    per_metric=WordErrorRate(vocab),
)

# --- Trainer ----------------------------------------------------------------
# Every tunable comes from `configs`; only the device count is set locally.
# NOTE(review): `gpus`, `amp_backend` and `auto_select_gpus` look like
# Lightning 1.x Trainer arguments - confirm the pinned version before upgrading.
trainer = pl.Trainer(
    precision=configs.precision,
    accelerator=configs.accelerator,
    gpus=num_devices,
    accumulate_grad_batches=configs.accumulate_grad_batches,
    amp_backend=configs.amp_backend,
    auto_select_gpus=configs.auto_select_gpus,
    check_val_every_n_epoch=configs.check_val_every_n_epoch,
    gradient_clip_val=configs.gradient_clip_val,
    logger=logger,
    max_epochs=configs.max_epochs,
)

# --- Train ------------------------------------------------------------------
trainer.fit(model, data_module)

Global seed set to 1


KeyboardInterrupt: 