# Main

Since we train our model on Google Colab, we made this notebook to start or resume model training.

### Init

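Initialize the objects used to train the model: the data loader, the preprocessor, the training/validation datasets, the model, and the trainer.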

In [None]:
from util.data import DataLoader, DataPrepocessor
from util.model import Model
from util.train import Trainer

# Tunable settings for this run, forwarded to the preprocessor, dataset builder and model below.
FRAME_LENGTH = 256
FRAME_STEP = 160
FFT_LENGTH = 384
# Choose the batch size carefully: too low means too much training time,
# too high means heavy computation (memory) per step.
BATCH_SIZE = 30
RNN_UNITS = 512

data_loader = DataLoader()
preprocessor = DataPrepocessor(
    frame_length=FRAME_LENGTH,
    frame_step=FRAME_STEP,
    fft_length=FFT_LENGTH,
    audio_path=data_loader.audio_path,
)

train_dataset, validation_dataset = preprocessor.create_dataset_objets(
    data_loader=data_loader,
    batch_size=BATCH_SIZE,
)

# fft_length // 2 + 1 is the number of bins produced by a real-valued FFT of that length;
# the output size matches the preprocessor's character vocabulary.
ds_model = Model(
    input_dim=preprocessor.fft_length // 2 + 1,
    output_dim=preprocessor.char_to_num.vocabulary_size(),
    rnn_units=RNN_UNITS,
)
ds_model.model.summary(line_length=110)

trainer = Trainer(
    model=ds_model.model,
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    preprocessor=preprocessor,
)

### Training

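Choose the number of epochs (default: 10). Run the first cell below to start training from scratch, or the second one to resume training from previously saved weights.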

In [None]:
# Start a fresh training run (no previously saved weights are loaded).
# `save_every_n_hours` is presumably the interval between saved checkpoints — see Trainer.train.
epochs, save_every_n_hours = 10, 3
trainer.train(save_every_n_hours=save_every_n_hours, epochs=epochs)

In [None]:
# Resume training from previously saved weights.
# Before running this cell, replace `<latest_version>` with the name of the saved
# version under `trainings/` that you want to continue from.
CHECKPOINT_PATH = "trainings/<latest_version>"
epochs = 10
save_every_n_hours = 3

# Fail fast with an actionable message instead of an opaque loading error
# when the placeholder has not been filled in.
if "<latest_version>" in CHECKPOINT_PATH:
    raise ValueError(
        "Set CHECKPOINT_PATH to an actual saved version under 'trainings/' "
        "before resuming training."
    )

trainer.model.load_weights(CHECKPOINT_PATH)
trainer.train(epochs=epochs, save_every_n_hours=save_every_n_hours)

### Inference
Let's evaluate the model on the whole validation set (Word Error Rate) and print a few random samples.

In [None]:
import tensorflow as tf
import numpy as np
from jiwer import wer

# Decode model predictions and reference transcripts for every validation batch.
predictions = []
targets = []
for X, y in validation_dataset:
    # verbose=0: avoid printing one progress bar per batch.
    batch_predictions = trainer.model.predict(X, verbose=0)
    predictions.extend(trainer.decode_batch_predictions(batch_predictions))
    for label in y:
        # Map label ids back to characters and join them into a single string.
        target_text = tf.strings.reduce_join(
            trainer.preprocessor.num_to_char(label)).numpy().decode("utf-8")
        targets.append(target_text)

# Every prediction must line up with its reference transcript for WER to be meaningful.
assert len(predictions) == len(targets), (
    f"{len(predictions)} predictions vs {len(targets)} targets")

wer_score = wer(targets, predictions)
print("-" * 100)
print(f"Word Error Rate: {wer_score:.4f}")
print("-" * 100)

# Show a few distinct random samples; the fixed seed makes the selection reproducible,
# and the size is capped so a small validation set does not fail.
NUM_SAMPLES_TO_SHOW = 10
SAMPLE_SEED = 42
rng = np.random.default_rng(SAMPLE_SEED)
sample_indices = rng.choice(
    len(predictions), size=min(NUM_SAMPLES_TO_SHOW, len(predictions)), replace=False)
for i in sample_indices:
    print(f"Target    : {targets[i]}")
    print(f"Prediction: {predictions[i]}")
    print("-" * 100)

### Testing
To test the model on speech that is outside the dataset, you can try our web app in `model_tester/`.