In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Audio

In [None]:
import numpy as np
import librosa

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [None]:
from zachary.datasets import MagPhaseSTFT
from zachary.utils import get_torch_device, get_num_trainable_params
from zachary.modules import AutoEncoder
from zachary.weight_initializers import initialize_model
from zachary.plotting import plot_mag_phase

In [None]:
BATCH_SIZE = 128
DEVICE = get_torch_device()

# Seed the RNGs this notebook draws from so Restart & Run All reproduces the
# same run: torch's default generator drives DataLoader(shuffle=True) and,
# presumably, the weight initialization in initialize_model; numpy is seeded
# as well in case any project code draws from it.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# STFT training dataset. Each item is an (example, target) pair; the training
# loop below unpacks batches that way.
# NOTE(review): the per-item channel layout is defined in zachary.datasets —
# presumably magnitude/phase given the class name; confirm.
dataset = MagPhaseSTFT()

In [None]:
# Sanity-check the data: plot the first item's input after undoing the
# dataset's normalization.
first_example, first_target = dataset[0]
plot_mag_phase(dataset.denormalize(first_example))

In [None]:
model = AutoEncoder()
initialize_model(model)
# Move the model to its target device *before* the optimizer is constructed in
# a following cell, as the torch.optim documentation recommends; the
# `model.to(DEVICE)` in the training cell is then a no-op.
model = model.to(DEVICE)

In [None]:
# Mean-squared error between the model's output and the target tensor
# (called as loss_fn(prediction, target) in the training loop).
loss_fn = F.mse_loss

In [None]:
# Adam over every model parameter with the library's default hyperparameters
# (no learning rate or betas are passed here).
optimizer = torch.optim.Adam(model.parameters())

### Set `example_length` (increase it and re-run the training loop as training progresses)

In [None]:
# Length of each training example; bump this and re-run the cells below to
# train on progressively longer examples.
example_length = 7
dataset.example_length = example_length

loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)
data_loader = DataLoader(dataset, **loader_options)

### Training loop

In [None]:
# One pass over `data_loader`, printing the current batch loss every
# LOG_EVERY batches.
LOG_EVERY = 100

model.to(DEVICE)
model.train()
for batch, (example, target) in enumerate(data_loader, start=1):
    # non_blocking copies let the host->device transfer overlap with compute;
    # they pair with pin_memory=True on the DataLoader (no-op on CPU).
    example = example.to(DEVICE, non_blocking=True)
    target = target.to(DEVICE, non_blocking=True)

    optimizer.zero_grad()
    loss = loss_fn(model(example), target)
    loss.backward()
    optimizer.step()

    if batch % LOG_EVERY == 0:
        # .item() yields a plain float instead of the tensor repr (device, grad_fn).
        print(f"batch {batch}: loss {loss.item():.6f}")

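### Inspect reconstruction quality

Reconstruct a slice of `dataset.audio` — taken from the same dataset used for training, so this is a qualitative check rather than a held-out evaluation — then compare the plots and listen to the result.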

In [None]:
# Slice the first 512 entries along axis 1 of the full dataset tensor and add a
# leading batch axis, presumably to match the batched input the model saw in
# training. NOTE(review): axis 1 is presumably STFT frames — confirm.
test_length = 512
sample = dataset.audio[:, :test_length, :]
sample = sample.unsqueeze(0).to(DEVICE)

In [None]:
# Put any train/eval-dependent layers into eval mode and run the test sample
# through the model without recording gradients.
model.eval()
with torch.no_grad():
    sample_hat = model(sample)

In [None]:
# Drop the batch axis, bring the reconstruction to host memory, undo the
# dataset normalization, and convert to a NumPy array for plotting / ISTFT.
reconstruction = sample_hat.squeeze(0).cpu()
sample_hat_np = dataset.denormalize(reconstruction).numpy()

In [None]:
# Plot the model's (denormalized) reconstruction of the test sample.
plot_mag_phase(sample_hat_np)

In [None]:
# Same conversion as for the reconstruction, applied to the original input so
# the two can be compared directly.
original_cpu = sample.squeeze(0).cpu()
sample_np = dataset.denormalize(original_cpu).numpy()

In [None]:
# Plot the original (denormalized) input for comparison with the reconstruction.
plot_mag_phase(sample_np)

In [None]:
def istft(x, hop_length=512, win_length=1024, center=False, polar=False):
    """Invert a two-channel STFT array back to a time-domain signal via librosa.

    Parameters
    ----------
    x : array indexed as ``x[:, :, c]``
        STFT representation with the two channels on the last axis. By default
        channel 0 is read as the real part and channel 1 as the imaginary part.
        NOTE(review): this notebook's data come from ``MagPhaseSTFT``; if those
        channels are (magnitude, phase) rather than (real, imaginary), call with
        ``polar=True`` — confirm against ``zachary.datasets``.
    hop_length, win_length, center
        Forwarded to ``librosa.istft``. The defaults are hop 512, window 1024,
        un-centered. They must match the settings of the forward STFT.
    polar : bool
        If True, build the complex matrix as ``x[:, :, 0] * exp(1j * x[:, :, 1])``
        (magnitude/phase) instead of ``x[:, :, 0] + 1j * x[:, :, 1]``.

    Returns
    -------
    np.ndarray
        The signal returned by ``librosa.istft``.
    """
    first, second = x[:, :, 0], x[:, :, 1]
    if polar:
        stft_matrix = first * np.exp(1j * second)
    else:
        stft_matrix = first + 1j * second
    return librosa.istft(stft_matrix, hop_length=hop_length, win_length=win_length, center=center)

In [None]:
# Resynthesize a waveform from the model's reconstructed STFT.
audio = istft(sample_hat_np)

In [None]:
# Play the resynthesized audio with the first and last 1024 samples dropped
# (NOTE(review): presumably to hide edge artifacts of the un-centered ISTFT).
# NOTE(review): 44100 Hz is assumed to be the dataset's sample rate — confirm.
edge = 1024
Audio(audio[edge:-edge], rate=44100)