# AIDK RNN-T Demo

# Content
* ### [Model Architecture](#Model-Architecture)
* ### [Environment Setup](#Environment-setup)
* ### [Launch training](#Launch-training)
* ### [Optimizations](#Optimizations)
* ### [Performance](#Performance-Overview)

## ASR
![ASR](./img/asr.png)

* The traditional ASR system (top picture) contains acoustic, phonetic and language components that work together as a pipeline
* The end-to-end ASR system is a single neural network that takes the raw audio signal as input and outputs a sequence of words

## Model Architecture
![RNN-T](./img/rnnt_structure.png)

RNN-T is an end-to-end ASR model that directly converts audio into a text representation.

The encoder network is an RNN that maps input acoustic frames into a higher-level representation.
The prediction network is an RNN that is explicitly conditioned on the history of previous non-blank targets predicted by the model.
The joint network is a feed-forward network that combines the outputs of the prediction network and the encoder into logits, which a softmax layer turns into a distribution over the next output symbol.
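
As a rough illustration of how the three networks fit together, the toy sketch below uses plain Python with hand-picked numbers in place of trained RNN outputs. The additive combination and the names `joint`, `enc_out`, and `pred_out` are assumptions made for illustration, not the demo's implementation; a real joint network would normally pass the combined vector through learned layers before the softmax.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def joint(enc_out, pred_out):
    """Combine every encoder frame with every prediction-network state.

    enc_out:  T vectors (one per acoustic frame), each of size V
    pred_out: U vectors (one per label-history state), each of size V
    Returns a T x U grid of probability distributions over V symbols.
    """
    grid = []
    for f_t in enc_out:
        row = []
        for g_u in pred_out:
            # Assumed additive join; stands in for the learned feed-forward layers.
            logits = [a + b for a, b in zip(f_t, g_u)]
            row.append(softmax(logits))
        grid.append(row)
    return grid


# Hypothetical values: 3 frames, 2 history states, 4 symbols (index 0 = blank).
enc_out = [[0.1, 0.5, -0.2, 0.0], [1.0, -1.0, 0.3, 0.2], [0.0, 0.0, 0.0, 0.0]]
pred_out = [[0.2, 0.1, 0.0, -0.1], [-0.5, 0.4, 0.4, 0.0]]
probs = joint(enc_out, pred_out)
print(len(probs), len(probs[0]), len(probs[0][0]))
```

The result holds one distribution per (frame, history state) pair, which is the grid the RNN-T loss and the decoder operate on.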

## Environment setup

Build the Docker image:

```
cd Dockerfile-ubuntu18.04
docker build -t e2eaiok-pytorch110 . -f DockerfilePytorch110
```

```
docker run -itd --name aidk-rnnt --privileged --network host --device=/dev/dri -v ${dataset_path}:/home/vmagent/app/dataset -v ${aidk_code_path}:/home/vmagent/app/e2eaiok -w /home/vmagent/app/ e2eaiok-pytorch110:latest /bin/bash
```
Enter the container with `docker exec -it aidk-rnnt bash`.

Start the Jupyter notebook service:

```
source /opt/intel/oneapi/setvars.sh --ccl-configuration=cpu_icc --force
conda activate pytorch-1.10.0
pip install jupyter
nohup jupyter notebook --notebook-dir=/home/vmagent/app/e2eaiok/ --ip=0.0.0.0 --port=8888 --allow-root &
```

Now you can visit the AIDK RNN-T demo at http://${hostname}:8888/

Note: RNN-T is trained on LibriSpeech train-clean-100 and evaluated on dev-clean. The stock model (based on the MLPerf submission) reaches a final WER of 0.25 when trained on train-clean-100, and all of the following optimizations preserve that 0.25 WER. For comparison, the MLPerf submission took 38.7 minutes with 8x A100 on the LibriSpeech train-960h dataset.

Public references on train-clean-100: https://arxiv.org/pdf/1807.10893.pdf, https://arxiv.org/pdf/1811.00787.pdf
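
WER (word error rate) is conventionally the word-level edit distance between the hypothesis and the reference, divided by the number of reference words. The sketch below is a minimal plain-Python version for illustration only; the function name `wer` is hypothetical, and the demo's own evaluation code may differ in details such as text normalization.

```python
def wer(reference, hypothesis):
    """Word error rate: (substitutions + deletions + insertions) / reference length."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] holds the edit distance between the processed reference prefix and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # match or substitution
        prev = curr
    return prev[-1] / len(ref)


# One substituted word out of four reference words.
score = wer("the cat sat down", "the cat sit down")
print(score)
```

With one wrong word in a four-word reference, the score is 0.25, the same value used as the WER target in this demo.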

## Launch training

Run the training script from a notebook cell:
!cd /home/vmagent/app/e2eaiok/modelzoo/rnnt/pytorch && bash scripts/train.sh

scripts/train.sh: line 25: [: : integer expression expected
STARTING TIMING RUN AT 2022-09-09 07:59:46 AM
running benchmark
scripts/train.sh: line 123: [: -ne: unary operator expected
Distributed training
2022-09-09 07:59:48,825 - __main__ - INFO - MASTER_ADDR=127.0.0.1
2022-09-09 07:59:48,825 - __main__ - INFO - MASTER_PORT=29500
2022-09-09 07:59:48,825 - __main__ - INFO - I_MPI_PIN_DOMAIN=[0xffffffffffff0,0xffffffffffff00000000000000,]
2022-09-09 07:59:48,826 - __main__ - INFO - OMP_NUM_THREADS=48
2022-09-09 07:59:48,826 - __main__ - INFO - Using Intel OpenMP
2022-09-09 07:59:48,827 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-09-09 07:59:48,827 - __main__ - INFO - KMP_BLOCKTIME=1
2022-09-09 07:59:48,827 - __main__ - INFO - LD_PRELOAD=/opt/intel/oneapi/intelpython/latest/envs/pytorch-1.10.0/lib/libiomp5.so
2022-09-09 07:59:48,827 - __main__ - INFO - CCL_WORKER_COUNT=4
2022-09-09 07:59:48,827 - __main__ - INFO - CCL_WORKER_AFFINITY=0,1,2,3,52,53,54,55
2022-09-09 

[0] :::MLLOG {"namespace": "", "time_ms": 1662710412415, "event_type": "POINT_IN_TIME", "key": "model_weights_initialization_scale", "value": 0.5, "metadata": {"file": "/home/vmagent/app/e2eaiok/modelzoo/rnnt/pytorch/train.py", "lineno": 395}}
[0] :::MLLOG {"namespace": "", "time_ms": 1662710412679, "event_type": "POINT_IN_TIME", "key": "weights_initialization", "value": null, "metadata": {"file": "/home/vmagent/app/e2eaiok/modelzoo/rnnt/pytorch/common/rnn.py", "lineno": 87, "tensor": "pre_rnn"}}
[0] :::MLLOG {"namespace": "", "time_ms": 1662710413624, "event_type": "POINT_IN_TIME", "key": "weights_initialization", "value": null, "metadata": {"file": "/home/vmagent/app/e2eaiok/modelzoo/rnnt/pytorch/common/rnn.py", "lineno": 87, "tensor": "post_rnn"}}
[0] :::MLLOG {"namespace": "", "time_ms": 1662710413636, "event_type": "POINT_IN_TIME", "key": "weights_initialization", "value": null, "metadata": {"file": "/home/vmagent/app/e2eaiok/modelzoo/rnnt/pytorch/rnnt/model.py", "lineno": 155, "t

[0] Dataset read by DALI. Number of samples: 73[0] 
[0] Initializing DALI with parameters:[0] 
[0] 	           __class__ : <class 'common.data.dali.pipeline.DaliPipeline'>[0] 
[0] 	          batch_size : 16[0] 
[0] 	           device_id : None[0] 
[0] 	        dither_coeff : 1e-05[0] 
[0] 	       dont_use_mmap : False[0] 
[0] 	           file_root : /home/vmagent/app/dataset/LibriSpeech/valid[0] 
[0] 	    in_mem_file_list : False[0] 
[0] 	        max_duration : inf[0] 
[0] 	           nfeatures : 80[0] 
[0] 	                nfft : 512[0] 
[0] 	         num_threads : 4[0] 
[0] 	       pipeline_type : val[0] 
[0] 	            pre_sort : False[0] 
[0] 	       preemph_coeff : 0.97[0] 
[0] 	preprocessing_device : cpu[0] 
[0] 	      resample_range : None[0] 
[0] 	         sample_rate : 16000[0] 
[0] 	             sampler : <common.data.dali.sampler.SimpleSampler object at 0x7fbfb299dd90>[0] 
[0] 	                seed : 2021[0] 
[0] 	                self : <common.data.dali.pipeline.DaliPipel

[1]   x_lens = (x_lens.int() + self.factor - 1) // self.factor
[0]   x_lens = (x_lens.int() + self.factor - 1) // self.factor
[0] DLL 2022-09-09 08:00:31.622960 - epoch    1 | iter    1/3 | loss 1003.32 | utts/s     2 | took 15.24 s | lrate 4.52e-04[0] 
[0] DLL 2022-09-09 08:00:43.140007 - epoch    1 | iter    2/3 | loss  897.04 | utts/s     3 | took 11.52 s | lrate 6.77e-04[0] 
[0] DLL 2022-09-09 08:00:52.094818 - epoch    1 | iter    3/3 | loss  601.92 | utts/s     4 | took  8.96 s | lrate 9.03e-04[0] 
[0] :::MLLOG {"namespace": "", "time_ms": 1662710452096, "event_type": "INTERVAL_END", "key": "epoch_stop", "value": null, "metadata": {"file": "/home/vmagent/app/e2eaiok/modelzoo/rnnt/pytorch/train.py", "lineno": 784, "epoch_num": 1}}
[0] DLL 2022-09-09 08:00:52.096816 - epoch    1 | avg train utts/s     3 | took 35.86 s
[0] :::MLLOG {"namespace": "", "time_ms": 1662710452097, "event_type": "POINT_IN_TIME", "key": "throughput", "value": 2.6768749609050704, "metadata": {"file": "/home/

[0] :::MLLOG {"namespace": "", "time_ms": 1662711117511, "event_type": "POINT_IN_TIME", "key": "eval_accuracy", "value": 20.504347826086956, "metadata": {"file": "/home/vmagent/app/e2eaiok/modelzoo/rnnt/pytorch/train.py", "lineno": 259, "epoch_num": 2}}
[0] :::MLLOG {"namespace": "", "time_ms": 1662711117512, "event_type": "INTERVAL_END", "key": "eval_stop", "value": null, "metadata": {"file": "/home/vmagent/app/e2eaiok/modelzoo/rnnt/pytorch/train.py", "lineno": 260, "epoch_num": 2}}
[0] DLL 2022-09-09 08:11:57.512776 - epoch    2 |   dev ema wer 2050.43 | took 17.56 s[0] 
[0] Saving /results/RNN-T_epoch2_checkpoint.pt...[0] 
ENDING TIMING RUN AT 2022-09-09 08:12:26 AM
RESULT,RNN_SPEECH_RECOGNITION,760,2022-09-09 07:59:46 AM


## Optimizations

### Model architecture

To democratize the RNN-T model, we enabled distributed training with PyTorch DDP to scale out across multiple nodes, added a time stack layer and increased the time stack factor to reduce the input sequence length, added layer and batch normalization to speed up convergence, and reduced the layer size to obtain a lighter model.

<center>
<img src="./img/model_base.png" width="800"/><figure>base model</figure>
<img src="./img/model_opt.png" width="800"/><figure>democratized model</figure>
</center>

### Distributed training

``` python
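from torch.nn.parallel import DistributedDataParallel as DDP

# Data parallel: wrap the model only when more than one process participates.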
if world_size > 1:
    model = DDP(model, find_unused_parameters=True)
```

### Add time stack layer

In ASR systems, an audio input sequence has far more time frames than the output has text labels. Because an LSTM processes its input sequentially, long sequences such as audio are slow to process. The StackTime layer stacks adjacent audio frames, which shortens the sequence and produces a higher-dimensional input, speeding up training.

```python
import torch
import torch.nn as nn


class StackTime(nn.Module):
    def __init__(self, factor):
        super().__init__()
        self.factor = int(factor)

    def stack(self, x):
        x = x.transpose(0, 1)
        T = x.size(1)
        padded = torch.nn.functional.pad(x, (0, 0, 0, (self.factor - (T % self.factor)) % self.factor))
        B, T, H = padded.size()
        x = padded.reshape(B, T // self.factor, -1)
        x = x.transpose(0, 1)
        return x

    def forward(self, x, x_lens):
        if type(x) is not list:
            x = self.stack(x)
            x_lens = (x_lens.int() + self.factor - 1) // self.factor
            return x, x_lens
        else:
            if len(x) != 2:
                raise NotImplementedError("Only number of seq segments equal to 2 is supported")
            assert x[0].size(1) % self.factor == 0, "The length of the 1st seq segment should be multiple of stack factor"
            y0 = self.stack(x[0])
            y1 = self.stack(x[1])
            x_lens = (x_lens.int() + self.factor - 1) // self.factor
            return [y0, y1], x_lens
```
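
To see the effect on sequence length without PyTorch, the following plain-Python sketch mirrors the padding and reshaping done in `stack` for a single utterance, along with the ceiling division applied to `x_lens`. The helper names `stacked_len` and `stack_frames` are hypothetical and exist only for this illustration.

```python
def stacked_len(num_frames, factor):
    """Ceiling division, the same arithmetic applied to x_lens in forward()."""
    return (num_frames + factor - 1) // factor


def stack_frames(frames, factor):
    """Zero-pad to a multiple of `factor`, then concatenate each group of frames."""
    feat = len(frames[0])
    pad = (factor - len(frames) % factor) % factor
    padded = frames + [[0.0] * feat for _ in range(pad)]
    return [
        [value for frame in padded[i:i + factor] for value in frame]
        for i in range(0, len(padded), factor)
    ]


# Hypothetical utterance: 10 frames with 2 features each.
frames = [[float(t), float(t) + 0.5] for t in range(10)]
shapes = {}
for factor in (2, 8):
    out = stack_frames(frames, factor)
    shapes[factor] = (len(out), len(out[0]))
    print(factor, shapes[factor], stacked_len(len(frames), factor))
```

A factor of 2 turns the 10 frames into 5 steps of width 4, while a factor of 8 turns them into 2 steps of width 16, so the LSTM has far fewer sequential steps to run.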

Increasing the time stack factor from 2 to 8 yields about a 4x speedup.

<center>
<img src="./img/time_stack_2.PNG" width="800"/><figure>time_stack = 2</figure>
<img src="./img/time_stack_8.PNG" width="800"/><figure>time_stack = 8</figure>
</center>

Profiling data shows that less time is spent in the forward and backward passes because the time stack layer shortens the input sequence.

<center>
<img src="./img/stack_profile_base.png" width="800"/><figure>base model profiling</figure>
<img src="./img/stack_profile_democratize.png" width="800"/><figure>democratized model profiling</figure>
</center>

### Add layer normalization and batch normalization

Layer normalization for the LSTM is important to successful RNN-T modeling. Adding layer normalization to the LSTM and batch normalization to the input features helps training converge faster: convergence takes 52 epochs without normalization and 49 epochs with it.

```python
enc_mod["batch_norm"] = nn.BatchNorm1d(pre_rnn_input_size)
```

```python
self.layer_norm = torch.nn.LayerNorm(hidden_size)
```
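
For reference, layer normalization rescales each hidden vector to zero mean and unit variance across its features, then applies a learned scale and shift. The plain-Python sketch below shows that arithmetic for a single vector; the function name `layer_norm` is hypothetical, `gamma` and `beta` stand in for the learned parameters, and the `eps` default is chosen to match the default of `torch.nn.LayerNorm`.

```python
import math


def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize one feature vector, then apply scale (gamma) and shift (beta)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance over the features
    gamma = gamma if gamma is not None else [1.0] * n
    beta = beta if beta is not None else [0.0] * n
    return [g * (v - mean) / math.sqrt(var + eps) + b
            for v, g, b in zip(x, gamma, beta)]


hidden = [2.0, 4.0, 6.0, 8.0]  # hypothetical LSTM hidden state
normed = layer_norm(hidden)
print([round(v, 3) for v in normed])
```

Batch normalization uses the same formula but computes the mean and variance per feature across the batch rather than across the features of one vector.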

<center>
<img src="./img/no_norm.PNG" width="800"/><figure>without normalization</figure>
<img src="./img/norm.PNG" width="800"/><figure>with normalization</figure>
</center>

### HPO with SDA (Smart Democratization Advisor)

SDA config:

```
model_parameter:
  project: sda
  experiment: rnnt
  parameters:
  - bounds:
      max: 1.0e-2
      min: 1.0e-3
    name: learning_rate
    transformation: log
    type: double
  - bounds:
      max: 10
      min: 1
    name: warmup_epochs
    type: int
metrics:
- name: training_time
  objective: minimize
  threshold: 43200
- name: WER
  objective: minimize
  threshold: 0.25
 ```

Request suggestions from SDA:

```python
suggestion = self.conn.experiments(self.experiment.id).suggestions().create()
```
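
The `transformation: log` setting in the config suggests that `learning_rate` is searched on a logarithmic scale, and the `threshold` entries set the bar a trial has to clear. The sketch below illustrates both ideas in plain Python with random sampling. It is not how SDA generates suggestions: `sample_params` and `meets_thresholds` are hypothetical helpers, and treating each threshold as an upper bound on a minimized metric is an assumption.

```python
import math
import random


def sample_params(rng):
    """Draw one candidate from the search space in the SDA config."""
    log_lr = rng.uniform(math.log10(1.0e-3), math.log10(1.0e-2))
    return {
        "learning_rate": 10 ** log_lr,        # uniform in log space
        "warmup_epochs": rng.randint(1, 10),  # inclusive integer bounds
    }


def meets_thresholds(training_time, wer):
    """Assumed reading of the thresholds: both minimized metrics must stay at or below them."""
    return training_time <= 43200 and wer <= 0.25


rng = random.Random(0)
candidates = [sample_params(rng) for _ in range(5)]
for candidate in candidates:
    print(candidate)
print(meets_thresholds(training_time=36000, wer=0.24))
```

Sampling in log space spreads candidates evenly between 1e-3 and 1e-2 by order of magnitude, which is usually what is wanted for a learning rate.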


### Framework related optimization

Leverage IPEX (Intel Extension for PyTorch) for distributed training, and enable socket binding when training on a two-socket system.

```bash
# Use IPEX launch to launch training, enable NUMA binding in two socket system.
${CONDA_PREFIX}/bin/python -m intel_extension_for_pytorch.cpu.launch --distributed --nproc_per_node=2 --nnodes=4 --hostfile hosts train.py ${ARGS}
```

<center>
<img src="./img/no_numa_binding.png" width="600"/><figure>without NUMA binding</figure>
<img src="./img/numa_binding.png" width="600"/><figure>with NUMA binding</figure>
</center>
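
The training log earlier in this document reports `CCL_WORKER_COUNT=4`, `CCL_WORKER_AFFINITY=0,1,2,3,52,53,54,55` and `OMP_NUM_THREADS=48`. One way to read those values is that each socket reserves its first few cores for communication workers and leaves the rest for compute threads. The sketch below reproduces that arithmetic in plain Python. The figure of 52 cores per socket is inferred from the log, `partition_cores` is a hypothetical helper, and none of this is the launcher's actual logic.

```python
def partition_cores(cores_per_socket, sockets, ccl_workers):
    """Reserve the first `ccl_workers` cores of each socket for communication."""
    ccl_affinity = []
    for socket in range(sockets):
        first = socket * cores_per_socket
        ccl_affinity.extend(range(first, first + ccl_workers))
    omp_threads = cores_per_socket - ccl_workers  # compute threads left per socket
    return ccl_affinity, omp_threads


# Assumed topology: 2 sockets with 52 cores each, 4 CCL workers per socket.
ccl_affinity, omp_threads = partition_cores(cores_per_socket=52, sockets=2, ccl_workers=4)
print("CCL_WORKER_AFFINITY=" + ",".join(map(str, ccl_affinity)))
print("OMP_NUM_THREADS=" + str(omp_threads))
```

Under the same assumptions, lowering `ccl_workers` returns cores to the compute threads, which is the trade-off behind the CCL worker tuning listed in the performance overview.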

## Performance Overview

* Distributed training with hardware scaling delivered a 5.16x speedup from 1 node to 4 nodes
* Time stacking plus a smaller LSTM delivered a 1.86x speedup, for 9.63x over baseline
  * Time stack factor: 8, LSTM depth 5 -> 4, LSTM width 1024 -> 512
* Layer normalization for the encoder and decoder LSTMs plus batch normalization for the input audio features delivered a 1.07x speedup, for 10.31x over baseline
* Reducing the number of CCL workers delivered a 1.07x speedup, for 11.06x over baseline
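
The cumulative figures follow from multiplying the per-step speedups. The short check below does that multiplication with the rounded factors quoted in the list; the step labels are shorthand for this sketch. Because the inputs are rounded, the products land slightly below the reported cumulative values, which were presumably computed from unrounded measurements.

```python
# Per-step speedups as quoted in the list above (rounded to two decimals).
steps = [
    ("distributed training, 1 -> 4 nodes", 5.16),
    ("time stacking + smaller LSTM", 1.86),
    ("layer/batch normalization", 1.07),
    ("fewer CCL workers", 1.07),
]

cumulative = 1.0
history = []
for name, factor in steps:
    cumulative *= factor
    history.append(cumulative)
    print(f"{name}: x{factor} -> x{cumulative:.2f} over baseline")
```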

![rnnt_perf_raw](./img/rnnt_perf_raw.png)
![rnnt_perf_norm](./img/rnnt_perf_norm.png)