##### WaveRNN Matrix Calculation

In [4]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from IPython.core.debugger import Pdb
import torch
import torch.nn as nn
import torch.nn.functional as F

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


```
math::
    xt = [ct-1, ft-1, ct]  # input
    ut = σ(Ru ht-1 + Iu*xt + bu)  # update gate
    rt = σ(Rr ht-1 + Ir*xt + br)  # reset gate
    et = tanh(rt∘(Re ht-1) + Ie*xt + be)  # recurrent unit
    ht = ut∘ht-1 + (1-u)∘et # next hidden state
    yc, yf = split(ht)  # coarse, fine
    P(ct) = softmax(O2 relu(O1 yc))  # coarse distribution
    P(ft) = softmax(O4 relu(O3 yf))  # fine distribution
```

In [5]:
# hidden state size 896 (56 num of GPU multi-processors *8 minimum warps size assigned to each processor *2)
hidden_size = 896

In [6]:
# category size 256 (8bit)
category_size = 256

In [7]:
# ht-1 (previous hidden) 
ht_1 = torch.Tensor(1, hidden_size)
print(f'ht-1 = {ht_1.shape}')

ht-1 = torch.Size([1, 896])


In [8]:
# ct-1 (previous coarse), ft-1 (previous fine), ct(current coarse)
ct_1 = torch.Tensor(1, 1)
ft_1 = torch.Tensor(1, 1)
ct = torch.Tensor(1, 1)
print(f'ct-1 = {ct_1.shape}')
print(f'ft-1 = {ft_1.shape}')
print(f'ct = {ct.shape}')

ct-1 = torch.Size([1, 1])
ft-1 = torch.Size([1, 1])
ct = torch.Size([1, 1])


In [9]:
# input xt = [ct-1, ft-1, ct]
xt = [ct_1, ft_1, ct]
print(f'xt = [ct-1, ft-1, ct] {xt[0].shape} {xt[1].shape} {xt[2].shape}')

xt = [ct-1, ft-1, ct] torch.Size([1, 1]) torch.Size([1, 1]) torch.Size([1, 1])


In [10]:
# gating unit R (U)
R = nn.Linear(hidden_size, hidden_size*3)
print(f'R = {R}')

R = Linear(in_features=896, out_features=2688, bias=True)


In [11]:
# hidden state I (W) (coarse, fine)
Ic = nn.Linear(2, 3*hidden_size//2)
If = nn.Linear(3, 3*hidden_size//2)
print(f'Ic = {Ic}')
print(f'If = {Ic}')

Ic = Linear(in_features=2, out_features=1344, bias=True)
If = Linear(in_features=2, out_features=1344, bias=True)


In [12]:
# transoform matrices into categorical distributions
O1 = nn.Linear(hidden_size//2, hidden_size//2)
O2 = nn.Linear(hidden_size//2, hidden_size//2)
O3 = nn.Linear(hidden_size//2, category_size)
O4 = nn.Linear(hidden_size//2, category_size)

In [13]:
# fully connected (previous hidden ht-1) x (gating unit R (U))
Rht_1 = R(ht_1)
print(f'R ht_1 => {Rht_1} {Rht_1.shape}')

R ht_1 => tensor([[nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward>) torch.Size([1, 2688])


In [14]:
# split R (U)
Ruht_1, Rrht_1, Reht_1 = torch.split(Rht_1, hidden_size, dim=1)
print(f'Ru ht_1 => {Ruht_1.shape}')
print(f'Rr ht_1 => {Ruht_1.shape}')
print(f'Re ht_1 => {Ruht_1.shape}')

Ru ht_1 => torch.Size([1, 896])
Rr ht_1 => torch.Size([1, 896])
Re ht_1 => torch.Size([1, 896])


In [15]:
# fully connected (xt) x (hidden state If (W)) (coarse part)
Icxt = Ic(torch.cat((xt[0], xt[1]), dim=1))
print(f'Ic*xt => {Icxt} {Icxt.shape}')

Ic*xt => tensor([[ 0.2323,  0.4750,  0.6151,  ..., -0.0697, -0.6429, -0.4561]],
       grad_fn=<AddmmBackward>) torch.Size([1, 1344])


In [16]:
# fully connected (xt) x (hidden state Ic (W)) (fine part)
Ifxt = If(torch.cat((xt[0], xt[1], xt[2]), dim=1))
print(f'If*xt => {Ifxt} {Ifxt.shape}')

If*xt => tensor([[ 0.1895, -0.2494,  0.0175,  ..., -0.1684,  0.1494, -0.4013]],
       grad_fn=<AddmmBackward>) torch.Size([1, 1344])


In [17]:
# split I
Iuxt = torch.cat((Icxt[:,:hidden_size//2], Ifxt[:,:hidden_size//2]), dim=1)
Irxt = torch.cat((Icxt[:,hidden_size//2:2*hidden_size//2], Ifxt[:,hidden_size//2:2*hidden_size//2]), dim=1)
Iext = torch.cat((Icxt[:,2*hidden_size//2:3*hidden_size//2], Ifxt[:,2*hidden_size//2:3*hidden_size//2]), dim=1)
print(f'Iu*xt => {Iuxt.shape}')
print(f'Ir*xt => {Irxt.shape}')
print(f'Ie*xt => {Irxt.shape}')

Iu*xt => torch.Size([1, 896])
Ir*xt => torch.Size([1, 896])
Ie*xt => torch.Size([1, 896])


In [18]:
# bias terms
bu = nn.Parameter(torch.zeros(hidden_size))
br = nn.Parameter(torch.zeros(hidden_size))
be = nn.Parameter(torch.zeros(hidden_size))
print(f'bu = {bu.shape}, br = {br.shape}, be = {be.shape}')

bu = torch.Size([896]), br = torch.Size([896]), be = torch.Size([896])


In [19]:
# ut (update gate)
ut = torch.sigmoid(Ruht_1 + Iuxt + bu)
print(f'ut = σ(Ru ht-1 + Iu*xt + bu) {ut.shape}')

ut = σ(Ru ht-1 + Iu*xt + bu) torch.Size([1, 896])


In [20]:
# rt (reset gate)
rt = torch.sigmoid(Rrht_1 + Irxt + br)
print(f'rt = σ(Rr ht-1 + Ir*xt + br) {rt.shape}')

rt = σ(Rr ht-1 + Ir*xt + br) torch.Size([1, 896])


In [21]:
# et (recurrent unit)
et = torch.tanh(rt*(Reht_1) + Iext + be)
print(f'et = tanh(rt∘(Re ht-1) + Ie*xt + be) {et.shape}')

et = tanh(rt∘(Re ht-1) + Ie*xt + be) torch.Size([1, 896])


In [22]:
# ht (next hidden state)
ht = ut*ht_1 + (1-ut)*et
print(f'ht = ut∘ht-1 + (1-u)*et {ht.shape}')

ht = ut∘ht-1 + (1-u)*et torch.Size([1, 896])


In [23]:
# yc, yf (coarse, fine)
yc, yf = torch.split(ht, hidden_size//2, dim=1)
print(f'yc, yf = split(ht)  {yc.shape} {yf.shape}')

yc, yf = split(ht)  torch.Size([1, 448]) torch.Size([1, 448])


In [24]:
# P(ct) coarse distribution
Pct = O3(torch.relu(O1(yc)))
print(f'P(ct) = softmax(O2 relu(O1 yc)) {Pct.shape}')

P(ct) = softmax(O2 relu(O1 yc)) torch.Size([1, 256])


In [25]:
# P(ft) fine distribution
Pft = O4(torch.relu(O2(yf)))
print(f'P(ft) = softmax(O4 relu(O2 yf)) {Pft.shape}')

P(ft) = softmax(O4 relu(O2 yf)) torch.Size([1, 256])
