This notebook shows how the operations on the DelayedStack work, paying special attention to tensor shapes and to the correctness of the shape transformations.

# General Information
In general, each stack has a hidden state $h[l]$ at layer $l$, with the following shapes:
* Time-delayed stack $h^t[l]$: [B, FREQ, FRAMES, HIDDEN_SIZE].
* Frequency-delayed stack $h^f[l]$: [B, FREQ, FRAMES, HIDDEN_SIZE].
* Centralized stack $h^c[l]$: [B, 1, FRAMES, HIDDEN_SIZE].
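
One consequence of these shapes is that the centralized stack, with size 1 on the FREQ axis, can be added to the other two stacks through broadcasting. The following is a minimal sketch of that idea; the sizes and the random tensors are illustrative assumptions, not values from the model.

```python
import torch

B, FREQ, FRAMES, HIDDEN_SIZE = 2, 3, 4, 5  # illustrative sizes only

h_t = torch.randn(B, FREQ, FRAMES, HIDDEN_SIZE)  # time-delayed stack
h_f = torch.randn(B, FREQ, FRAMES, HIDDEN_SIZE)  # frequency-delayed stack
h_c = torch.randn(B, 1, FRAMES, HIDDEN_SIZE)     # centralized stack

# The size-1 FREQ axis of h_c is broadcast over all frequency bins.
combined = h_t + h_f + h_c
print(combined.shape)

# Broadcasting matches explicitly expanding h_c along the FREQ axis.
explicit = h_t + h_f + h_c.expand(B, FREQ, FRAMES, HIDDEN_SIZE)
print(torch.allclose(combined, explicit))
```

The same vector of the centralized stack is therefore shared by every frequency bin of a given frame.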

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

HIDDEN_SIZE = 2

# Initial Layer

### Time-delayed stack
To ensure that the output $h^{t}_{ij}[l]$ is only a function of frames that lie in the context $x_{<ij}$,
the inputs to the time-delayed stack are shifted backwards one step in time: $h^{t}_{ij}[0] = W^{t}_0 x_{i-1, j}$ [MelNet formula (7)].

For this reason, we have to "invent" the first frame. Here, we assume that the first frame is all zeros.

In [2]:
spectrogram = torch.Tensor([[[11,12,13,14],
                             [15,16,17,18],
                             [19,20,21,22]]])
spectrogram.shape

torch.Size([1, 3, 4])

In [3]:
x_time_pad = F.pad(spectrogram,(1,-1))  # pad one zero frame on the left; -1 crops the last frame so FRAMES stays 4
x_time_pad.shape, x_time_pad

(torch.Size([1, 3, 4]),
 tensor([[[ 0., 11., 12., 13.],
          [ 0., 15., 16., 17.],
          [ 0., 19., 20., 21.]]]))

We could also use other values for the invented frame, such as:
* random values
* average values of the first frame of the spectrograms in our dataset
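
As a rough sketch of the second option, the invented frame can be prepended with `torch.cat` instead of `F.pad`. The tensor `dataset_first_frame_mean` is a hypothetical placeholder with made-up values; in practice it would be computed from the dataset.

```python
import torch

spectrogram = torch.tensor([[[11., 12., 13., 14.],
                             [15., 16., 17., 18.],
                             [19., 20., 21., 22.]]])  # [B=1, FREQ=3, FRAMES=4]

# Hypothetical per-frequency statistic (made-up values), shape [1, FREQ, 1].
dataset_first_frame_mean = torch.tensor([1., 2., 3.]).view(1, 3, 1)

# Drop the last frame and prepend the invented one, so FRAMES stays 4.
x_time_shift = torch.cat([dataset_first_frame_mean, spectrogram[:, :, :-1]], dim=2)
print(x_time_shift.shape)
print(x_time_shift)
```

Only the first column differs from the zero-padded version; the remaining columns are the original frames shifted one step to the right.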

Now, we implement the linear transformation [MelNet formula (7)].

In [4]:
# First, we change shape from [B, FREQ, FRAMES] to [B, FREQ, FRAMES, 1]
x_time_pad = x_time_pad.unsqueeze(-1)
print(x_time_pad.shape)

# Linear transformation
W_t_0 = nn.Linear(in_features=1, out_features=HIDDEN_SIZE)
h_t_0 = W_t_0(x_time_pad)
print(h_t_0.shape)

torch.Size([1, 3, 4, 1])
torch.Size([1, 3, 4, 2])


The output of layer zero of the time-delayed stack has shape [B=1, FREQ=3, FRAMES=4, HIDDEN_SIZE=2], as expected.
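
Beyond the shape, we can also check the property that motivated the shift. The sketch below perturbs a single input frame and looks at which output frames change; the helper `time_delayed_input` and the random input are assumptions made for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
W_t_0 = nn.Linear(in_features=1, out_features=2)

def time_delayed_input(x):
    # [B, FREQ, FRAMES] -> shift one frame back in time -> [B, FREQ, FRAMES, 2]
    return W_t_0(F.pad(x, (1, -1)).unsqueeze(-1))

x = torch.randn(1, 3, 4)
x_perturbed = x.clone()
x_perturbed[:, :, 2] += 1.0  # change input frame 2 only

# For each frame, flag whether any hidden value differs between the two runs.
diff = time_delayed_input(x) != time_delayed_input(x_perturbed)
changed = diff.any(dim=-1).any(dim=1)  # [B, FRAMES]
print(changed)
```

If the shift is correct, only frame 3 is flagged: the hidden state at frame $i$ sees input frame $i-1$ and nothing later.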

### Frequency-delayed stack
To ensure that the output $h^{f}_{ij}[l]$ is only a function of elements that lie in the context $x_{<ij}$,
the inputs to the frequency-delayed stack are shifted backwards one step along the frequency axis: $h^{f}_{ij}[0] = W^{f}_0 x_{i, j-1}$ [MelNet formula (9)].

For this reason, we have to "invent" the "first (lowest)" frequency for all frames. Here, we assume that the "first (lowest)" frequency is all zeros.

In [5]:
x_freq_pad = F.pad(spectrogram,(0,0,1,-1))  # pad one zero row at the lowest frequency; -1 crops the highest so FREQ stays 3
x_freq_pad.shape, x_freq_pad

(torch.Size([1, 3, 4]),
 tensor([[[ 0.,  0.,  0.,  0.],
          [11., 12., 13., 14.],
          [15., 16., 17., 18.]]]))

We could also use other values for the invented frequency bin, such as:
* random values
* average values of the "first (lowest)" frequency of the spectrograms in our dataset

Now, we implement the linear transformation [MelNet formula (9)].

In [6]:
# First, we change shape from [B, FREQ, FRAMES] to [B, FREQ, FRAMES, 1]
x_freq_pad = x_freq_pad.unsqueeze(-1)
print(x_freq_pad.shape)

# Linear transformation
W_f_0 = nn.Linear(in_features=1, out_features=HIDDEN_SIZE)
h_f_0 = W_f_0(x_freq_pad)  # use the frequency-delayed weights, not W_t_0
print(h_f_0.shape)

torch.Size([1, 3, 4, 1])
torch.Size([1, 3, 4, 2])


The output of layer zero of the frequency-delayed stack has shape [B=1, FREQ=3, FRAMES=4, HIDDEN_SIZE=2], as expected.
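
Both shifts rely on how `F.pad` reads its padding tuple: pairs are consumed from the last dimension backwards, and a negative entry crops instead of padding. As a small sanity check, the sketch below compares the two shifts with explicit slicing and concatenation; the input tensor is an arbitrary example.

```python
import torch
import torch.nn.functional as F

x = torch.arange(1., 13.).reshape(1, 3, 4)  # [B=1, FREQ=3, FRAMES=4]

# Pad tuple order: (frames_left, frames_right, freq_low, freq_high).
time_shift = F.pad(x, (1, -1))        # shift along FRAMES
freq_shift = F.pad(x, (0, 0, 1, -1))  # shift along FREQ

# Reference versions built by hand: prepend zeros, drop the last slice.
time_ref = torch.cat([torch.zeros(1, 3, 1), x[:, :, :-1]], dim=2)
freq_ref = torch.cat([torch.zeros(1, 1, 4), x[:, :-1, :]], dim=1)

print(torch.equal(time_shift, time_ref), torch.equal(freq_shift, freq_ref))
```

The explicit version is longer, but it makes it easier to swap the zeros for another invented frame or frequency bin.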

### Centralized stack
To ensure that the output $h^{c}_{i}[l]$ is only a function of frames that lie in the context $x_{<ij}$,
the inputs to the centralized stack are shifted backwards one step along the time axis: $h^{c}_{i}[0] = W^{c}_0 x_{i-1, *}$ [MelNet formula (11)].

For this reason, we have to "invent" the first frame. Here, we assume that the first frame is all zeros.

In [7]:
x_central_pad = F.pad(spectrogram, (1,-1))  # pad one zero frame on the left; -1 crops the last frame so FRAMES stays 4
x_central_pad.shape, x_central_pad

(torch.Size([1, 3, 4]),
 tensor([[[ 0., 11., 12., 13.],
          [ 0., 15., 16., 17.],
          [ 0., 19., 20., 21.]]]))

We could also use other values for the invented frame, such as:
* random values
* average values of the first frame of the spectrograms in our dataset

Now, we are going to implement the linear transformation [MelNet formula (11)].

At each timestep, the centralized stack takes an entire frame as input and outputs a single vector consisting of the RNN hidden state. `nn.Linear` acts on the last axis, so we first transpose the tensor to [B, FRAMES, FREQ].

In [8]:
x_central_pad = x_central_pad.transpose(1,2)
x_central_pad.shape, x_central_pad

(torch.Size([1, 4, 3]),
 tensor([[[ 0.,  0.,  0.],
          [11., 15., 19.],
          [12., 16., 20.],
          [13., 17., 21.]]]))

In [9]:
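# Linear transformation: each frame (a vector of FREQ values) is mapped to HIDDEN_SIZE values
FREQ = 3
W_c_0 = nn.Linear(in_features=FREQ, out_features=HIDDEN_SIZE)
h_c_0 = W_c_0(x_central_pad)
# Add a size-1 FREQ axis: [B, FRAMES, HIDDEN_SIZE] -> [B, 1, FRAMES, HIDDEN_SIZE]
h_c_0 = h_c_0.unsqueeze(dim=1)
print(h_c_0.shape)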

torch.Size([1, 1, 4, 2])


The output of layer zero of the centralized stack has shape [B=1, 1, FRAMES=4, HIDDEN_SIZE=2], as expected.
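
To make the role of the transpose concrete, the sketch below compares the batched linear layer with a manual loop that applies the same weights to one frame at a time. The random input and the variable names `frames` and `manual` are assumptions for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, FREQ, FRAMES, HIDDEN_SIZE = 1, 3, 4, 2

x = torch.randn(B, FREQ, FRAMES)
W_c_0 = nn.Linear(in_features=FREQ, out_features=HIDDEN_SIZE)

# nn.Linear acts on the last axis, so FREQ must come last: [B, FRAMES, FREQ].
h_c_0 = W_c_0(x.transpose(1, 2)).unsqueeze(1)  # [B, 1, FRAMES, HIDDEN_SIZE]

# Reference: apply the same weights to each frame (a column of x) by hand.
frames = [W_c_0.weight @ x[0, :, i] + W_c_0.bias for i in range(FRAMES)]
manual = torch.stack(frames).view(B, 1, FRAMES, HIDDEN_SIZE)

print(h_c_0.shape, torch.allclose(h_c_0, manual, atol=1e-6))
```

Without the transpose, the layer would mix values across frames instead of across frequency bins.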

# Other layers

TODO: explain computations of other layers

In [10]:
hidden_size = 2
rnn_l = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=1, batch_first=True)
W_0 = nn.Linear(in_features=1, out_features=hidden_size)
W_0_c = nn.Linear(in_features=3, out_features=hidden_size)
W_l = nn.Linear(in_features=hidden_size, out_features=hidden_size)

In [11]:
data = torch.Tensor([[[[10],[11],[12],[13]],
                      [[20],[21],[22],[23]],
                      [[30],[31],[32],[33]]],
                     [[[40],[41],[42],[43]],
                      [[50],[51],[52],[53]],
                      [[60],[61],[62],[63]]]
                    ])
# Note: this rebinds HIDDEN_SIZE to 1, the size of the trailing input-feature axis
B, FREQ, FRAMES, HIDDEN_SIZE = data.size()
B, FREQ, FRAMES, HIDDEN_SIZE

(2, 3, 4, 1)

In [17]:
data = torch.Tensor([[[10,11,12,13],
                      [20,21,22,23],
                      [30,31,32,33]],
                     [[40,41,42,43],
                      [50,51,52,53],
                      [60,61,62,63]]
                    ])
B, FREQ, FRAMES = data.size()
B, FREQ, FRAMES

(2, 3, 4)

In [20]:
res = W_0(data.unsqueeze(-1))
print(res.shape)
print(data.transpose(1,2).shape)
res_c = W_0_c(data.transpose(1,2))
res_c = res_c.unsqueeze(1)
res.shape, res_c.shape

torch.Size([2, 3, 4, 2])
torch.Size([2, 4, 3])


(torch.Size([2, 3, 4, 2]), torch.Size([2, 1, 4, 2]))

In [21]:
res + res_c

tensor([[[[ -1.5819, -14.4504],
          [ -1.9500, -14.6376],
          [ -2.3181, -14.8248],
          [ -2.6862, -15.0120]],

         [[  0.0225,  -6.4117],
          [ -0.3455,  -6.5989],
          [ -0.7136,  -6.7861],
          [ -1.0817,  -6.9733]],

         [[  1.6270,   1.6270],
          [  1.2589,   1.4398],
          [  0.8908,   1.2526],
          [  0.5228,   1.0654]]],


        [[[-12.6244, -20.0658],
          [-12.9924, -20.2530],
          [-13.3605, -20.4402],
          [-13.7286, -20.6274]],

         [[-11.0199, -12.0271],
          [-11.3880, -12.2143],
          [-11.7560, -12.4015],
          [-12.1241, -12.5887]],

         [[ -9.4154,  -3.9884],
          [ -9.7835,  -4.1756],
          [-10.1516,  -4.3628],
          [-10.5197,  -4.5499]]]], grad_fn=<AddBackward0>)

In [None]:
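# [B, FREQ, FRAMES] -> [B, FRAMES, FREQ]
data_backwardfreq = data.transpose(1,2)
print(data_backwardfreq.size())
# Merge batch and frames: each frame becomes a sequence of FREQ scalars, [B*FRAMES, FREQ, 1]
data_backwardfreq = data_backwardfreq.contiguous().view(-1, FREQ, 1)
print(data_backwardfreq.shape)
# Reverse the frequency axis so that an RNN runs from the highest bin to the lowest
data_backwardfreq = data_backwardfreq.flip(1)
data_backwardfreq.shape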

In [None]:
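res = W_0(data_backwardfreq)  # [B*FRAMES, FREQ, hidden_size]
res, hid = rnn_l(res)
# Restore [B, FREQ, FRAMES, hidden_size]; the frequency axis is still reversed here
res.contiguous().view(B, FRAMES, FREQ, hidden_size).transpose(1,2).shape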
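
The reshape, flip and RNN steps above can be put together in one self-contained sketch of a backward pass along the frequency axis. This is only an illustration under stated assumptions: the input `h` is random, and flipping the output back so that it lines up with the original frequency order is an added step that the cells above do not perform.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, FREQ, FRAMES, hidden_size = 2, 3, 4, 2

h = torch.randn(B, FREQ, FRAMES, hidden_size)  # stand-in for a hidden state
rnn = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, batch_first=True)

# Treat every frame as an independent sequence over frequency bins.
seq = h.transpose(1, 2).reshape(B * FRAMES, FREQ, hidden_size)

# Run the LSTM from the highest bin to the lowest, then undo the flip
# (assumption) so outputs line up with the original frequency order.
out, _ = rnn(seq.flip(1))
out = out.flip(1)

# Back to [B, FREQ, FRAMES, hidden_size].
out = out.reshape(B, FRAMES, FREQ, hidden_size).transpose(1, 2)
print(out.shape)
```

With this layout, the output at a given frequency bin depends only on that bin and the bins above it within the same frame.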