# Train Conformer

Related paper: [Conformer: Convolution-augmented Transformer for Speech Recognition](https://arxiv.org/abs/2005.08100)

<div class="alert alert-info">

This tutorial is available as an IPython notebook at [malaya-speech/example/train-asr](https://github.com/huseinzol05/malaya-speech/tree/master/example/train-asr).
    
</div>

<div class="alert alert-warning">

This example was trained on a very small dataset, so do not use it in production.
    
</div>

### Model Interface

To initialize a Conformer model,

```python
import malaya_speech.train.model.conformer as conformer

x = tf.placeholder(tf.float32, [None, None, num_features, 1])
config = malaya_speech.config.conformer_small_encoder_config
model = conformer.Model(**config)
logits = model(x)
```

This model interface does not require the input length, because the output length can be calculated using `ctc.utils.calculate_input_length_deep_speech`.

We also include `base` and `large` configs based on the paper:

```python
malaya_speech.config.conformer_base_encoder_config
malaya_speech.config.conformer_large_encoder_config
```

In [1]:
import os

# Make only GPU 0 visible to CUDA / TensorFlow (training below uses num_gpus = 1).
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [3]:
import tensorflow as tf
from glob import glob
import malaya_speech.config
import malaya_speech.train.model.conformer as conformer
import malaya_speech.train.model.ctc as ctc
import malaya_speech.train as train
import numpy as np


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



### Load Featurizer

We prefer to use `malaya_speech.tf_featurization.STTFeaturizer` for ASR.

In [4]:
# Featurizer shared by the input pipeline; its config also fixes the
# per-frame feature width used to reshape features downstream.
config = malaya_speech.config.transducer_featurizer_config
featurizer = malaya_speech.tf_featurization.STTFeaturizer(**config)
n_mels = config['num_feature_bins']

# Training hyper-parameters.
# NOTE(review): initial_learning_rate and max_gradient_norm are not referenced
# by model_fn, which uses train.schedule.transformer_schedule and no clipping.
initial_learning_rate, max_gradient_norm, train_steps = 1e-3, 5.0, 1000

In [5]:
def preprocess_inputs(example):
    """Featurize ``example['waveforms']`` and attach the model inputs.

    Mutates ``example`` in place and returns it. Two keys are added:

    - ``'inputs'``: featurizer output reshaped to ``(-1, n_mels)``.
    - ``'inputs_length'``: a 1-element int32 tensor holding the number of
      rows (time steps) in ``'inputs'``.
    """
    features = tf.reshape(
        featurizer.vectorize(example['waveforms']), (-1, n_mels)
    )
    num_rows = tf.cast(tf.shape(features)[0], tf.int32)
    # Stored as a length-1 vector rather than a scalar.
    example.update(
        inputs = features,
        inputs_length = tf.expand_dims(num_rows, 0),
    )
    return example

### Define Data Pipeline

In [6]:
def parse(serialized_example):
    """Decode one serialized TFRecord example into model-ready features.

    The record must contain variable-length ``waveforms`` (float32) and
    ``targets`` (int64). The waveform is featurized by ``preprocess_inputs``.
    Only ``inputs``, ``inputs_length`` and ``targets`` are kept in the
    returned dict.
    """
    parsed = tf.parse_single_example(
        serialized_example,
        features = {
            'waveforms': tf.VarLenFeature(tf.float32),
            'targets': tf.VarLenFeature(tf.int64),
        },
    )
    # VarLenFeature yields sparse tensors; keep only their dense value vectors.
    dense = {name: sparse.values for name, sparse in parsed.items()}
    dense = preprocess_inputs(dense)

    wanted = ('inputs', 'inputs_length', 'targets')
    for name in [k for k in dense if k not in wanted]:
        del dense[name]
    return dense


def get_dataset(path, batch_size = 32, shuffle_size = 32, thread_count = 24):
    """Return a zero-argument input function over the TFRecords matching ``path``.

    Args:
        path: glob pattern selecting the TFRecord shards.
        batch_size: number of examples per padded batch.
        shuffle_size: shuffle buffer size, in records.
        thread_count: ``num_parallel_calls`` used when mapping ``parse``.

    Returns:
        A callable that builds an endlessly repeating ``tf.data`` pipeline of
        zero-padded batches with keys ``inputs``, ``inputs_length`` and
        ``targets``.
    """

    def input_fn():
        # Every tf object is created here, at call time, so that it belongs
        # to whichever graph is active when the caller invokes input_fn.
        dataset = (
            tf.data.TFRecordDataset(glob(path))
            .shuffle(shuffle_size)
            .map(parse, num_parallel_calls = thread_count)
        )
        shapes = {
            'inputs': tf.TensorShape([None, n_mels]),
            'inputs_length': tf.TensorShape([None]),
            'targets': tf.TensorShape([None]),
        }
        pad_values = {
            'inputs': tf.constant(0, dtype = tf.float32),
            'inputs_length': tf.constant(0, dtype = tf.int32),
            'targets': tf.constant(0, dtype = tf.int64),
        }
        dataset = dataset.padded_batch(
            batch_size, padded_shapes = shapes, padding_values = pad_values
        )
        return dataset.repeat()

    return input_fn

### Define Model

1. Use malaya_speech CTC loss,

```python
mean_error, sum_error, sum_weight = ctc.loss.ctc_loss(
    logits, targets_int32, seq_lens
)
```

2. Use malaya_speech CTC sequence accuracy,

```python
accuracy = ctc.metrics.ctc_sequence_accuracy(
    logits, targets_int32, seq_lens
)
```

This will automatically be recorded in TensorBoard.

3. Use malaya_speech output length calculation,

```python
seq_lens = ctc.utils.calculate_input_length_deep_speech(
    features['inputs'], logits
)
```

4. Use malaya_speech metrics for Evaluation session,

```python
elif mode == tf.estimator.ModeKeys.EVAL:

    estimator_spec = tf.estimator.EstimatorSpec(
        mode = tf.estimator.ModeKeys.EVAL,
        loss = loss,
        eval_metric_ops = {
            'accuracy': ctc.metrics.ctc_sequence_accuracy_estimator(
                logits, targets_int32, seq_lens
            ),
            'WER': ctc.metrics.word_error_rate_estimator(
                logits, targets_int32
            ),
        },
    )
```

If we pass plain tensor values directly to `eval_metric_ops`, TensorFlow throws an error, because each metric must be a `(metric_tensor, update_op)` pair that includes an update operation. So just stick with the `*_estimator` implementations above.

In [7]:
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` training a small Conformer encoder with CTC loss.

    Args:
        features: dict from the input function built by ``get_dataset``. It
            must contain ``'inputs'`` (padded features) and ``'targets'``
            (padded int64 label ids).
        labels: unused; targets travel inside ``features``.
        mode: a ``tf.estimator.ModeKeys`` value. Only TRAIN and EVAL are
            supported.
        params: unused.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for the requested mode.

    Raises:
        ValueError: if ``mode`` is neither TRAIN nor EVAL (e.g. PREDICT).
            This function always needs ``targets`` and builds no predictions.
            Without this check, such modes used to fail later with a
            KeyError or UnboundLocalError.
    """
    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        raise ValueError(
            'model_fn only supports TRAIN and EVAL modes, got {!r}'.format(mode)
        )

    # Named model_config so it does not shadow the module-level featurizer
    # `config`.
    model_config = malaya_speech.config.conformer_small_encoder_config
    model = conformer.Model(**model_config)
    # The model takes a trailing channel axis: (batch, time, features, 1).
    inputs = tf.expand_dims(features['inputs'], axis = -1)
    logits = model(inputs)
    # Project encoder outputs onto the character vocabulary for CTC.
    logits = tf.layers.dense(logits, malaya_speech.char.VOCAB_SIZE)

    # Output lengths are derived from the input and logit shapes, so
    # features['inputs_length'] is not needed here.
    seq_lens = ctc.utils.calculate_input_length_deep_speech(
        features['inputs'], logits
    )

    targets_int32 = tf.cast(features['targets'], tf.int32)

    mean_error, sum_error, sum_weight = ctc.loss.ctc_loss(
        logits, targets_int32, seq_lens
    )

    loss = mean_error
    accuracy = ctc.metrics.ctc_sequence_accuracy(
        logits, targets_int32, seq_lens
    )

    # Named so the LoggingTensorHook in train_hooks can look them up.
    tf.identity(loss, name = 'train_loss')
    tf.identity(accuracy, name = 'train_accuracy')

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        step = tf.cast(global_step, tf.float32)
        dmodel = tf.cast(model_config['dmodel'], tf.float32)
        # Learning rate comes from the transformer schedule. The module-level
        # initial_learning_rate / max_gradient_norm are not used here.
        learning_rate = train.schedule.transformer_schedule(
            step, dmodel, warmup_steps = 100
        )

        optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
        train_op = optimizer.minimize(loss, global_step = global_step)
        return tf.estimator.EstimatorSpec(
            mode = mode, loss = loss, train_op = train_op
        )

    # EVAL: metrics must be estimator-style (value, update_op) pairs.
    return tf.estimator.EstimatorSpec(
        mode = tf.estimator.ModeKeys.EVAL,
        loss = loss,
        eval_metric_ops = {
            'accuracy': ctc.metrics.ctc_sequence_accuracy_estimator(
                logits, targets_int32, seq_lens
            ),
            'WER': ctc.metrics.word_error_rate_estimator(
                logits, targets_int32
            ),
        },
    )

In [8]:
# Log the tensors named via tf.identity inside model_fn on every iteration.
train_hooks = [
    tf.train.LoggingTensorHook(
        tensors = ['train_accuracy', 'train_loss'],
        every_n_iter = 1,
    ),
]

# Input functions for the train and dev TFRecord shards.
train_dataset, dev_dataset = (
    get_dataset(pattern)
    for pattern in (
        'tolong-sebut/data/tolong-sebut-train*',
        'tolong-sebut/data/tolong-sebut-dev*',
    )
)

In [9]:
# Delete the model directory used below ('asr-smallconformer') and the older
# 'smallconformer' directory, so training does not start from files left by an
# earlier run.
!rm -rf asr-smallconformer smallconformer

In [10]:
# Train for `train_steps` steps on one GPU, logging every step, saving a
# checkpoint every 200 steps, and using `dev_dataset` as the evaluation input.
train.run_training(
    model_fn = model_fn,
    model_dir = 'asr-smallconformer',
    train_fn = train_dataset,
    eval_fn = dev_dataset,
    train_hooks = train_hooks,
    max_steps = train_steps,
    save_checkpoint_step = 200,
    log_step = 1,
    num_gpus = 1,
)



INFO:tensorflow:Using config: {'_model_dir': 'smallconformer', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 1, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f0496a505f8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluatio

### Reference

1. Tensorflow ASR, https://github.com/TensorSpeech/TensorFlowASR/tree/main/examples/conformer