In [None]:
from IPython.display import Audio
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import numpy as np
from tqdm import tqdm_notebook
import librosa

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import nn

In [None]:
from zachary.datasets import AtemporalDataset
from zachary.constants import Configuration
from zachary.sampling import sample_z
from zachary.utils import get_torch_device, get_num_trainable_params
from zachary.weight_initializers import initialize_model

In [None]:
import os

BATCH_SIZE = 256
SEED = 0
# Seed numpy and torch up front so weight initialisation, DataLoader shuffling
# and the random initial phase in stft_to_signal are reproducible under
# Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = get_torch_device()
# Audio corpus location. Override with the ZACHARY_AUDIO_DIR environment
# variable rather than editing this machine-specific default.
AUDIO_DIR = os.environ.get('ZACHARY_AUDIO_DIR', '/home/kureta/Music/V24')
conf = Configuration(audio_dir=AUDIO_DIR, hop_length=256)

In [None]:
# Frame-level dataset used for both training and evaluation below. Items are
# unpacked in the training loop as (spectrum, pitch, confidence, loudness).
# NOTE(review): presumably reads the audio under conf.audio_dir — confirm in AtemporalDataset.
dataset = AtemporalDataset(conf)

In [None]:
def get_idx(data, start, stop, idx):
    """Stack field ``idx`` of items ``data[start]`` .. ``data[stop - 1]`` into one tensor.

    ``data`` is any indexable collection whose items are tuples and where
    ``data[i][idx]`` has a ``.shape`` (torch tensor or numpy array/scalar).
    Every item is assumed to have the same shape as ``data[0][idx]``.

    Args:
        data: indexable collection of per-frame tuples (e.g. the dataset).
        start: first index, inclusive.
        stop: last index, exclusive.
        idx: which tuple field to extract.

    Returns:
        Tensor of shape ``(stop - start, *item_shape)`` in torch's default
        float dtype. Scalar fields give a 1-D tensor; n-dimensional fields
        are stacked along a new leading axis.

    Raises:
        ValueError: if ``stop < start``.
    """
    if stop < start:
        raise ValueError(f'stop ({stop}) must be >= start ({start})')

    item_shape = tuple(data[0][idx].shape)
    # Pre-allocate and copy each item in. Numpy inputs and other dtypes are
    # cast to torch's default float dtype on assignment.
    result = torch.zeros((stop - start, *item_shape))
    for offset, i in enumerate(range(start, stop)):
        result[offset] = torch.as_tensor(data[i][idx])
    return result

In [None]:
# Field 0 of each dataset item, for frames between 30 s and 40 s.
specgram = get_idx(dataset, conf.time_to_frames(30), conf.time_to_frames(40), 0)

fig, ax1 = plt.subplots(1, 1, figsize=(18, 4))
# Transpose so frames run along x and feature bins along y (bin 0 at the bottom).
ax1.imshow(specgram.t(), aspect='auto', origin='lower')
ax1.set(title='Dataset spectrogram, 30–40 s', xlabel='frame', ylabel='bin')
plt.show()

In [None]:
def stft_to_signal(S, num_iters=15, hop_length=None, frame_length=None, rng=None):
    """Reconstruct a time-domain signal from a spectrogram by Griffin-Lim iteration.

    The phase is unknown, so it starts out uniformly random. It is then refined
    by alternating inverse STFT and forward STFT, keeping the given magnitudes
    and replacing only the phase each round.

    Args:
        S: spectrogram laid out as (frames, bins); it is transposed internally
            to the (bins, frames) layout librosa expects.
            NOTE(review): assumed to hold linear STFT magnitudes — confirm
            against the dataset's feature extraction.
        num_iters: number of Griffin-Lim iterations; must be >= 1.
        hop_length: STFT hop size. Defaults to ``conf.hop_length``.
        frame_length: STFT window / FFT size. Defaults to ``conf.frame_length``.
        rng: optional seed or ``numpy.random.Generator`` for the initial random
            phase, for reproducible output. ``None`` keeps the previous
            behaviour of drawing from numpy's global RNG.

    Returns:
        The reconstructed signal as returned by ``librosa.istft``.

    Raises:
        ValueError: if ``num_iters`` < 1 (previously ``None`` was returned silently).
    """
    if num_iters < 1:
        raise ValueError(f'num_iters must be >= 1, got {num_iters}')
    if hop_length is None:
        hop_length = conf.hop_length
    if frame_length is None:
        frame_length = conf.frame_length

    S_T = S.T

    # Initial phase drawn uniformly from [-pi, pi).
    if rng is None:
        noise = np.random.random_sample(S_T.shape)
    else:
        noise = np.random.default_rng(rng).random(S_T.shape)
    phase = 2 * np.pi * noise - np.pi

    for it in range(num_iters):
        D = S_T * np.exp(1j * phase)
        signal = librosa.istft(D, hop_length=hop_length, win_length=frame_length)
        # Skip re-analysis on the last iteration: that phase would never be used.
        if it < num_iters - 1:
            phase = np.angle(librosa.stft(signal, n_fft=frame_length, hop_length=hop_length))

    return signal

In [None]:
# Rescale by dataset.maxima and invert the spectrogram to audio with 100
# Griffin-Lim iterations.
# NOTE(review): presumably the dataset stores spectra divided by dataset.maxima — confirm.
sig = stft_to_signal((specgram * dataset.maxima).numpy(), num_iters=100)

In [None]:
# Sample times in seconds. np.arange(n) / sr gives exact spacing 1/sr; the
# previous linspace(0, n/sr, n) stretched the axis slightly.
t = np.arange(len(sig)) / conf.sample_rate

fig, ax1 = plt.subplots(1, 1, figsize=(18, 4))
ax1.plot(t, sig)
ax1.set(title='Griffin-Lim reconstruction of the dataset spectrogram',
        xlabel='time (s)', ylabel='amplitude')
plt.show()

In [None]:
# Play back the reconstruction of the dataset spectrogram.
Audio(sig, rate=conf.sample_rate)

In [None]:
class Encoder(nn.Module):
    """MLP encoder mapping one spectrum frame to a latent vector.

    Layer stack: in_features -> 256 -> 128 -> 64 -> latent_dim, with ReLU
    after each hidden layer and no activation on the output.

    Args:
        in_features: size of each input frame (default 513, the spectrum
            width used throughout this notebook).
        latent_dim: size of the latent code (default 8).
    """

    def __init__(self, in_features=513, latent_dim=8):
        super().__init__()

        # Attribute names c1..c4 are the state_dict keys; keep them stable so
        # saved weights still load.
        self.c1 = nn.Linear(in_features, 256)
        self.c2 = nn.Linear(256, 128)
        self.c3 = nn.Linear(128, 64)
        self.c4 = nn.Linear(64, latent_dim)

    def forward(self, x):
        """Encode ``x`` of shape (..., in_features) to (..., latent_dim)."""
        z = F.relu(self.c1(x))
        z = F.relu(self.c2(z))
        z = F.relu(self.c3(z))
        z = self.c4(z)

        return z

In [None]:
class LabelDecoder(nn.Module):
    """Conditional MLP decoder: latent code plus label columns -> spectrum frame.

    The latent code and all label tensors are concatenated along dim 1 and fed
    through (latent_dim + num_labels) -> 64 -> 128 -> 256 -> out_features.
    Hidden layers use ReLU; the output uses a sigmoid, so values are in [0, 1].

    Args:
        latent_dim: size of the latent code (default 8).
        num_labels: total width of the concatenated label tensors (default 3;
            this notebook passes pitch, confidence and loudness columns).
        out_features: size of the reconstructed frame (default 513).
    """

    def __init__(self, latent_dim=8, num_labels=3, out_features=513):
        super().__init__()

        # Attribute names c2..c5 are the state_dict keys; keep them stable.
        self.c2 = nn.Linear(latent_dim + num_labels, 64)
        self.c3 = nn.Linear(64, 128)
        self.c4 = nn.Linear(128, 256)
        self.c5 = nn.Linear(256, out_features)

    def forward(self, z, *labels):
        """Decode latent codes conditioned on labels.

        Args:
            z: latent codes, shape (batch, latent_dim).
            *labels: label tensors of shape (batch, k_i) with sum(k_i) == num_labels.

        Returns:
            Tensor of shape (batch, out_features) with values in [0, 1].
        """
        x = F.relu(self.c2(torch.cat([z, *labels], 1)))
        x = F.relu(self.c3(x))
        x = F.relu(self.c4(x))
        x = torch.sigmoid(self.c5(x))

        return x

In [None]:
class LabelAutoencoder(nn.Module):
    """Encoder followed by a label-conditioned decoder.

    ``forward(x, *labels)`` compresses ``x`` to a latent code with ``Encoder``
    and reconstructs it with ``LabelDecoder``, which also receives ``labels``.
    """

    def __init__(self):
        super().__init__()

        # Creation order (encoder, then decoder) fixes the weight-init RNG sequence.
        self.encoder = Encoder()
        self.decoder = LabelDecoder()

    def forward(self, x, *labels):
        """Return the decoder's reconstruction of ``x`` given ``labels``."""
        return self.decoder(self.encoder(x), *labels)

In [None]:
class Decoder(nn.Module):
    """Unconditional MLP decoder: latent code -> spectrum frame.

    Layer stack: latent_dim -> 64 -> 128 -> 256 -> out_features. Hidden layers
    use ReLU; the output uses a sigmoid, so values are in [0, 1].

    Args:
        latent_dim: size of the latent code (default 8).
        out_features: size of the reconstructed frame (default 513).
    """

    def __init__(self, latent_dim=8, out_features=513):
        super().__init__()

        # Attribute names c2..c5 are the state_dict keys; keep them stable.
        self.c2 = nn.Linear(latent_dim, 64)
        self.c3 = nn.Linear(64, 128)
        self.c4 = nn.Linear(128, 256)
        self.c5 = nn.Linear(256, out_features)

    def forward(self, z):
        """Decode ``z`` of shape (..., latent_dim) to (..., out_features) in [0, 1]."""
        x = F.relu(self.c2(z))
        x = F.relu(self.c3(x))
        x = F.relu(self.c4(x))
        x = torch.sigmoid(self.c5(x))

        return x

In [None]:
class Autoencoder(nn.Module):
    """Plain (unconditional) autoencoder: ``Encoder`` followed by ``Decoder``."""

    def __init__(self):
        super().__init__()

        # Creation order (encoder, then decoder) fixes the weight-init RNG sequence.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Return the reconstruction of ``x`` after a round trip through the latent space."""
        return self.decoder(self.encoder(x))

In [None]:
# Label-conditioned autoencoder: the decoder also receives pitch, confidence
# and loudness (see the training loop).
model = LabelAutoencoder()

In [None]:
# Mean-squared error between reconstructed and target spectra (default 'mean' reduction).
loss_function = F.mse_loss

In [None]:
# Adam with PyTorch's default hyperparameters (no learning rate given explicitly).
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# Shuffled mini-batches of dataset items for training.
# NOTE(review): pin_memory only helps when DEVICE is a CUDA device, and
# num_workers=8 is machine-specific — consider making both configurable.
data_loader = DataLoader(
    dataset,
    pin_memory=True,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8)

In [None]:
NUM_EPOCHS = 5

model.train()
model.to(DEVICE)
for epoch in range(NUM_EPOCHS):
    # Running sums for the epoch-mean loss shown in the progress bar; this is
    # far less noisy than the last single-batch loss.
    running_loss, seen = 0.0, 0
    with tqdm_notebook(total=len(dataset)) as pbar:
        for spectrum, pitch, confidence, loudness in data_loader:
            spectrum = spectrum.to(DEVICE)
            # The decoder concatenates labels with the latent code along dim 1,
            # so each per-frame label becomes a (batch, 1) column.
            pitch = pitch.to(DEVICE).unsqueeze(1)
            confidence = confidence.to(DEVICE).unsqueeze(1)
            loudness = loudness.to(DEVICE).unsqueeze(1)

            optimizer.zero_grad()
            spectrum_hat = model(spectrum, pitch, confidence, loudness)
            loss = loss_function(spectrum_hat, spectrum)
            loss.backward()
            optimizer.step()

            batch_size = spectrum.shape[0]
            running_loss += loss.item() * batch_size
            seen += batch_size
            pbar.set_description(
                f'Epoch: {epoch + 1} - loss: {running_loss / seen:.2E}')
            pbar.update(batch_size)

## Test reconstruction performance (10–20 s excerpt of the training data)

In [None]:
eval_start, eval_stop = conf.time_to_frames(10), conf.time_to_frames(20)
# Index the dataset once per frame and slice every field from the cached list.
# Previously the dataset was re-indexed separately for each of the four fields.
eval_items = [dataset[i] for i in range(eval_start, eval_stop)]
n_eval = len(eval_items)

sample = get_idx(eval_items, 0, n_eval, 0).to(DEVICE)
# Label fields become (frames, 1) columns, matching the training loop.
f0s = get_idx(eval_items, 0, n_eval, 1).to(DEVICE).unsqueeze(1)
confidences = get_idx(eval_items, 0, n_eval, 2).to(DEVICE).unsqueeze(1)
loudnesses = get_idx(eval_items, 0, n_eval, 3).to(DEVICE).unsqueeze(1)

In [None]:
# Sanity check: expected (frames, spectrum_width) and (frames, 1).
sample.shape, f0s.shape

In [None]:
# Reconstruct the evaluation excerpt without tracking gradients.
model.eval()
with torch.no_grad():
    sample_hat = model(sample, f0s, confidences, loudnesses)

# Move to CPU and multiply by dataset.maxima to undo the scaling applied to the
# spectra (presumably division by maxima — confirm in AtemporalDataset).
sample_hat = sample_hat.cpu() * dataset.maxima
sample_hat_np = sample_hat.numpy()

In [None]:
# Invert the model's reconstructed spectrogram to audio (100 Griffin-Lim iterations).
signal_hat = stft_to_signal(sample_hat_np, num_iters=100)

In [None]:
# Sample times in seconds. Use the configured sample rate rather than a
# hard-coded 44100, and exact 1/sr spacing instead of linspace.
t = np.arange(len(signal_hat)) / conf.sample_rate

fig, ax1 = plt.subplots(1, 1, figsize=(18, 4))
ax1.plot(t, signal_hat)
ax1.set(title='Model reconstruction of the 10–20 s excerpt',
        xlabel='time (s)', ylabel='amplitude')
plt.show()

In [None]:
# Play back at the configured sample rate (previously hard-coded to 44100).
Audio(signal_hat, rate=conf.sample_rate)

In [None]:
# Parameters for latent-trajectory sampling below.
# NOTE(review): exact meaning depends on zachary.sampling.sample_z — presumably
# num_cv control values spaced `resolution` frames apart, covering ~10 s of frames.
resolution = 50
num_cv = conf.time_to_frames(10) // resolution

In [None]:
# Sample latent vectors; 8 matches the encoder's latent size.
# NOTE(review): the remaining positional arguments (0., 1., num_cv, resolution,
# 2, True) are sample_z-specific — confirm their meaning in zachary.sampling and
# consider passing them by keyword.
zs = sample_z(8, 0., 1., num_cv, resolution, 2, True)

In [None]:
# Convert to a float32 tensor on DEVICE so it can be fed to the decoder.
zs_t = torch.from_numpy(zs.astype('float32')).to(DEVICE)

In [None]:
# Fixed label values used to decode the sampled latents, in the same scale as
# the dataset's label fields. Order matches the training call
# model(spectrum, pitch, confidence, loudness).
PITCH_VALUE = 0.3
CONFIDENCE_VALUE = 0.9
LOUDNESS_VALUE = 0.5

n_latents = zs_t.shape[0]
# One constant label column per sampled latent row, created directly on DEVICE.
constant = torch.full((n_latents, f0s.shape[1]), PITCH_VALUE, device=DEVICE)
constant1 = torch.full((n_latents, confidences.shape[1]), CONFIDENCE_VALUE, device=DEVICE)
constant2 = torch.full((n_latents, loudnesses.shape[1]), LOUDNESS_VALUE, device=DEVICE)

In [None]:
# Decode the sampled latents directly (bypassing the encoder), conditioned on
# the fixed pitch / confidence / loudness columns.
model.eval()
with torch.no_grad():
    y = model.decoder(zs_t, constant, constant1, constant2)

# Multiply by dataset.maxima to undo the dataset scaling, as for the evaluation excerpt.
y_hat = y.cpu() * dataset.maxima
y_hat_np = y_hat.numpy()

In [None]:
# Invert the decoded spectrogram to audio (100 Griffin-Lim iterations).
s_hat = stft_to_signal(y_hat_np, num_iters=100)

In [None]:
# Sample times in seconds. Use the configured sample rate rather than a
# hard-coded 44100, and exact 1/sr spacing instead of linspace.
t = np.arange(len(s_hat)) / conf.sample_rate

fig, ax1 = plt.subplots(1, 1, figsize=(18, 4))
ax1.plot(t, s_hat)
ax1.set(title='Audio decoded from sampled latents with fixed labels',
        xlabel='time (s)', ylabel='amplitude')
plt.show()

In [None]:
# Play back at the configured sample rate (previously hard-coded to 44100).
Audio(s_hat, rate=conf.sample_rate)