# RNN & CTC

- Implement a model, using an RNN and CTC loss, that can be trained even when the audio and the text are not aligned.

In [None]:
import torch
import torchaudio  #feature extraction
import torch.nn as nn   # RNN module 
import IPython.display as ipd
import matplotlib.pyplot as plt  #for visualization
%matplotlib inline

In [2]:
!wget https://github.com/dbstj1231/2023_AI_Academy_ASR/raw/main/set.wav

--2023-07-11 10:14:30--  https://github.com/dbstj1231/2023_AI_Academy_ASR/raw/main/set.wav
Resolving github.com (github.com)... 20.200.245.247
Connecting to github.com (github.com)|20.200.245.247|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/dbstj1231/2023_AI_Academy_ASR/main/set.wav [following]
--2023-07-11 10:14:30--  https://raw.githubusercontent.com/dbstj1231/2023_AI_Academy_ASR/main/set.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18998 (19K) [audio/wav]
Saving to: ‘set.wav.1’


2023-07-11 10:14:31 (25.3 MB/s) - ‘set.wav.1’ saved [18998/18998]



In [None]:
ipd.Audio("set.wav")

In [None]:
y,sr = torchaudio.load("set.wav")
y,sr

In [None]:
y.shape

In [None]:
y.shape[1]/ sr

In [None]:
n_fft = 512
hop_length = n_fft // 2
n_mels = 64
sr = 16000  # assumed sampling rate of set.wav; this overrides the sr returned by torchaudio.load

mel_converter = torchaudio.transforms.MelSpectrogram(n_fft=n_fft,
                                                     n_mels=n_mels,
                                                     sample_rate=sr,
                                                     hop_length=hop_length)

db_converter = torchaudio.transforms.AmplitudeToDB()

In [None]:
spec = mel_converter(y)
plt.imshow(spec[0],origin="lower",interpolation='nearest',aspect='auto')

In [None]:
db_spec = db_converter(spec)
plt.imshow(db_spec[0],origin="lower",interpolation='nearest',aspect='auto')

In [None]:
plt.plot(y[0])

In [None]:
x = torch.arange(len(y[0]))/sr
plt.plot(x,y[0])

### Generate the character dict
```
'a' : 1 , 'b' : 2 , 'c':3 ... 'z':26
```

In [None]:
# gen character dict
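
One way to fill in this cell, as a minimal sketch. It assumes index 0 is reserved for the CTC blank (written `_` here), which is why the letters start at 1. The name `char_to_idx` is illustrative and not taken from the notebook; `char_list` is the name the later plotting cells expect.

```python
import string

# Assumption: index 0 is reserved for the CTC blank, shown as '_'.
BLANK = "_"

# 'a' -> 1, 'b' -> 2, ..., 'z' -> 26
char_to_idx = {ch: i for i, ch in enumerate(string.ascii_lowercase, start=1)}

# Inverse lookup as a list: position i holds the character for class i.
char_list = [BLANK] + list(string.ascii_lowercase)

# Encode the word spoken in set.wav as class indices.
target = [char_to_idx[ch] for ch in "set"]
print(target)  # [19, 5, 20]
```

With the blank included, the list has 27 entries, which matches `n_classes = 27` in the model below.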


### torch.nn.RNN

- input: tensor of shape $(L,H_{in})$ for unbatched input, $(L,N,H_{in})$ when `batch_first=False` or $(N,L,H_{in})$ when `batch_first=True` containing the features of the input sequence. The input can also be a packed variable length sequence. See `torch.nn.utils.rnn.pack_padded_sequence()` or `torch.nn.utils.rnn.pack_sequence()` for details.

$$
\begin{aligned}
N &= \text{batch size} \\
L &= \text{sequence length} \\
H_{in} &= \text{input_size}
\end{aligned}
$$



### Arguments of nn.RNN
> `input_size`: The number of expected features in the input `x`
>
> `hidden_size`: The number of features in the hidden state `h`
>
> `num_layers`: Number of recurrent layers. E.g., setting `num_layers=2`
> would mean stacking two RNNs together to form a stacked RNN,
> with the second RNN taking in outputs of the first RNN and
> computing the final results. Default: 1
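
The shapes described above can be checked with a small example. This is only a sketch: the sizes (64 features, 75 frames, 2 layers) are illustrative assumptions, and the input is random rather than a real spectrogram.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions): 64 input features, 64 hidden units, 2 stacked layers.
rnn = nn.RNN(input_size=64, hidden_size=64, num_layers=2, batch_first=True)

x = torch.randn(1, 75, 64)   # (N, L, H_in): 1 clip, 75 frames, 64 features
out, h_n = rnn(x)

print(out.shape)  # torch.Size([1, 75, 64]) -> (N, L, hidden_size)
print(h_n.shape)  # torch.Size([2, 1, 64])  -> (num_layers, N, hidden_size)
```

Note that `batch_first=True` only changes the layout of the input and of `out`; the final hidden state `h_n` keeps the layer dimension first.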

In [None]:
class RNNModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.n_classes = 27   # size of the character set: 26 letters + 1 (index 0, presumably the CTC blank)
    self.n_fft = 512
    self.n_mels = 64      # input_size of nn.RNN
    self.hop_length = self.n_fft // 2
    self.sr = 16000
    self.hidden_dim = 64
    self.n_char = 27

    self.mel_converter = torchaudio.transforms.MelSpectrogram(n_fft =self.n_fft,
                                                              n_mels = self.n_mels,
                                                              hop_length=self.hop_length,
                                                              sample_rate=self.sr
                                                              )
    self.db_converter  = torchaudio.transforms.AmplitudeToDB()


    # TODO:define rnn layer and output layer


  def forward(self,x):
    # TODO: define forward

    return x

model = RNNModel()
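
One possible way to complete the two TODOs is sketched below. It is not the notebook's reference solution: the class name `RNNModelSketch` is hypothetical, a single `nn.RNN` layer followed by `nn.Linear` is an assumed design, and the model takes mel features directly (a random tensor here) so the example runs without the audio file.

```python
import torch
import torch.nn as nn

class RNNModelSketch(nn.Module):
    """Hypothetical completion of the TODOs; takes mel features directly."""

    def __init__(self, n_mels=64, hidden_dim=64, n_classes=27):
        super().__init__()
        # Recurrent layer over time frames; batch_first gives (N, L, H_in) input.
        self.rnn = nn.RNN(input_size=n_mels, hidden_size=hidden_dim, batch_first=True)
        # Per-frame projection from the hidden state to character scores.
        self.out = nn.Linear(hidden_dim, n_classes)

    def forward(self, feats):
        # feats: (N, L, n_mels), e.g. a dB mel spectrogram transposed
        # so that time comes before the mel axis.
        hidden, _ = self.rnn(feats)
        logits = self.out(hidden)
        # nn.CTCLoss expects log-probabilities over the classes.
        return logits.log_softmax(dim=-1)

sketch = RNNModelSketch()
feats = torch.randn(1, 75, 64)   # stand-in for db_converter(mel_converter(y)).transpose(1, 2)
out = sketch(feats)
print(out.shape)                 # torch.Size([1, 75, 27])
```

The output has one row of 27 log-probabilities per time frame, which is the shape check the next cell asks for.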

In [None]:
# check shape

In [None]:
# get char_list


In [None]:
# show relation between time frame and character
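
A rough sketch of this step: each output frame corresponds to `hop_length` audio samples, so frame `i` sits at roughly `i * hop_length / sr` seconds. The model output is replaced by a random tensor here, so the printed characters are meaningless; only the frame-to-time mapping is the point.

```python
import torch

hop_length, sr = 256, 16000               # values used above: n_fft // 2 and 16 kHz
char_list = ["_"] + list("abcdefghijklmnopqrstuvwxyz")

out = torch.randn(1, 75, 27).log_softmax(dim=-1)   # random stand-in for the model output
best = out[0].argmax(dim=-1)              # most likely class index per frame, shape (L,)

for frame in range(5):                    # first few frames only
    t = frame * hop_length / sr           # approximate time of the frame in seconds
    print(f"frame {frame:2d}  t={t:.3f}s  char={char_list[best[frame].item()]}")
```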


### CTCLoss
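- Measures how far the model's predicted output is from the ground-truth output, as the negative log-likelihood of the target sequence.
- CTC considers every alignment path over the input time steps that collapses to the target, and sums the probabilities of those paths to obtain the probability of the target.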
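
A minimal sketch of calling `nn.CTCLoss`, assuming the blank is class 0 and the target is "set" encoded as `'a' = 1 ... 'z' = 26`. The log-probabilities are random stand-ins for a model output.

```python
import torch
import torch.nn as nn

T, N, C = 75, 1, 27                     # frames, batch size, classes (blank = 0)

# Random stand-in for the model output; CTCLoss wants log-probs shaped (T, N, C).
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)

targets = torch.tensor([[19, 5, 20]])   # "set" with 'a' = 1 ... 'z' = 26
input_lengths = torch.tensor([T])       # number of frames per clip
target_lengths = torch.tensor([3])      # number of characters per transcript

ctc = nn.CTCLoss(blank=0)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(loss.item())                      # a positive scalar; the value depends on the random input
```

No frame-level alignment is passed in anywhere: only the character sequence and the two lengths are needed.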

In [None]:
??nn.CTCLoss

In [None]:
# check  "set" char_set and shape

In [None]:
# define training loop
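
A sketch of what such a loop might look like. Everything here is an assumption made for illustration: a tiny RNN plus linear layer instead of the notebook's model, random features instead of the mel spectrogram of `set.wav`, Adam with `lr=1e-3`, and 50 steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins (assumptions): a tiny RNN + linear model and random "mel" features.
rnn = nn.RNN(input_size=64, hidden_size=64, batch_first=True)
proj = nn.Linear(64, 27)
feats = torch.randn(1, 75, 64)              # (N, L, n_mels)
targets = torch.tensor([[19, 5, 20]])       # "set"
input_lengths = torch.tensor([75])
target_lengths = torch.tensor([3])

criterion = nn.CTCLoss(blank=0)
params = list(rnn.parameters()) + list(proj.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

losses = []
for step in range(50):
    hidden, _ = rnn(feats)
    log_probs = proj(hidden).log_softmax(dim=-1)    # (N, L, C)
    # CTCLoss expects (L, N, C), so swap the batch and time axes.
    loss = criterion(log_probs.transpose(0, 1), targets, input_lengths, target_lengths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])   # the loss should fall as the model overfits this one example
```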



In [None]:
# define plot_ctc 
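
One possible shape for `plot_ctc`, as a sketch: it draws the per-frame probability of every class as a line. The signature (log-probabilities of shape `(L, C)` plus a `char_list`) is an assumption, and the demo input is random.

```python
import matplotlib.pyplot as plt
import torch

def plot_ctc(log_probs, char_list):
    """Plot per-frame class probabilities. log_probs: (L, C) log-probabilities."""
    probs = log_probs.detach().exp().cpu().numpy()
    fig, ax = plt.subplots(figsize=(10, 5))
    for idx, ch in enumerate(char_list):
        ax.plot(probs[:, idx], label=ch)     # one line per character class
    ax.set_xlabel("time frame")
    ax.set_ylabel("probability")
    ax.legend(ncol=3, fontsize="small")
    return fig

char_list = ["_"] + list("abcdefghijklmnopqrstuvwxyz")
demo = torch.randn(75, 27).log_softmax(dim=-1)   # random stand-in for out[0]
fig = plot_ctc(demo, char_list)
```

Calling it every few steps inside the training loop shows how the blank and the target characters come to dominate over time.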

In [None]:
# show output by plot_ctc

In [None]:
# define training loop with plot_ctc

Visualization

In [None]:
plt.figure(figsize=(10,5),dpi=100)
plt.plot(out[0].cpu().detach().numpy())
plt.legend(char_list)

plt.show()

In [None]:
??torch.exp

In [None]:
import plotly.graph_objects as go

data = torch.exp(out[0]).cpu().detach().numpy()
fig = go.Figure()
for idx, char_prob in enumerate(data.T):
  # go.Line is deprecated; a Scatter trace in "lines" mode draws the same curve
  fig.add_trace(go.Scatter(y=char_prob, mode="lines", name=char_list[idx]))
fig.show()


## Q&A
```
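Q : In CTC, why is the probability of blk so high even though it is not in the target?
A : In CTC, blk ('_') can absorb the probability of short pauses and silence, and
when the same letter occurs twice in a row, as in "apple" or "sorry", it also
separates the two letters, as in pp_p or rr_r. For both reasons its probability
tends to be estimated high.


To address this, Facebook proposed a new loss for its wav2letter model,
the Auto Segmentation Criterion (ASG).
That criterion is not widely used, however.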
```
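
The separating role of the blank can be seen with a small sketch of the CTC collapse rule (merge consecutive repeats, then drop blanks). The helper name `ctc_collapse` and the example frame strings are made up for illustration.

```python
from itertools import groupby

def ctc_collapse(frames, blank="_"):
    """Merge consecutive repeats, then drop blanks."""
    merged = [ch for ch, _ in groupby(frames)]
    return "".join(ch for ch in merged if ch != blank)

print(ctc_collapse("aapp_pllee"))  # apple
print(ctc_collapse("aappplle"))    # aple: without a blank the double p merges
print(ctc_collapse("s_ee__t"))     # set
```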