# Scratch
This is a scratch notebook for the audio generator.

In [6]:
import torch
import torch.nn
import corpus

In [4]:
torch.cuda.is_available()

True

In [22]:
from torchaudio.transforms import Spectrogram
class Featurizer(torch.nn.Module):
    def __init__(self, n_fft=1024, n_hop=512):
        super().__init__()
        self.n_fft = n_fft
        self.n_hop = n_hop
        self.spectrogram = Spectrogram(self.n_fft, hop_length=self.n_hop, power=None)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        complex_spectrum = self.spectrogram(waveform)
        magnitudes = torch.abs(complex_spectrum)
        phases = torch.angle(complex_spectrum)
        # [:, :self.n_fft/2+1, :] is the magnitude spectrum
        # [:, self.n_fft/2+1:, :] is the phase spectrum
        return torch.cat((magnitudes, phases), dim=1)

In [23]:
audio, sr = torchaudio.load("../data/grains.wav")

In [24]:
featurizer = Featurizer(1024, 512)

In [26]:
features = featurizer(audio)

In [27]:
print(features.shape)

torch.Size([1, 1026, 1423])


In [30]:
class Predictor(torch.nn.Module):
    def __init__(self, dim=513*2, num_layers=2, hidden_size=2048):
        super().__init__()
        self.dim = dim
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = torch.nn.LSTM(dim, hidden_size, num_layers)
        self.output = torch.nn.Linear(hidden_size, dim)

    def forward(self, x) -> torch.Tensor:
        output, hidden = self.lstm(x)
        logits = self.output(output[-1, :, :])
        return logits, hidden

    def init_hidden(self, batch_size=1):
        return torch.zeros(self.num_layers, batch_size, self.hidden_size), torch.zeros(self.num_layers, batch_size, self.hidden_size)
    

In [31]:
model = Predictor(513*2, 2, 2048)

In [39]:
# need to reshape so that the FFT dimensions are the last dimension
# dim1: channel
# dim2: STFT frames
# dim3: FFT bins
output = model(features.reshape((features.shape[0], features.shape[2], features.shape[1])))

In [41]:
output[0].shape

torch.Size([1423, 1026])