In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Audio

In [None]:
import numpy as np
from tqdm import tqdm_notebook
import librosa

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import nn

In [None]:
from zachary.datasets import AtemporalDataset, GANDataset, f0_midi_from_stft_frame, pseudo_one_hot, sample_z
from zachary.utils import get_torch_device, get_num_trainable_params
from zachary.weight_initializers import initialize_model

In [None]:
# Mini-batch size handed to the training DataLoader.
BATCH_SIZE = 256
# Device returned by the project helper; the models and tensors below are moved to it with .to(DEVICE).
DEVICE = get_torch_device()

In [None]:
# NOTE(review): absolute, machine-specific path — point this at your local copy of the audio corpus.
AUDIO_DIRECTORY = '/home/kureta/Music/chorales/'
dataset = AtemporalDataset(audio_directory=AUDIO_DIRECTORY)

In [None]:
# Shapes of the four fields of the first dataset item (the training loop unpacks them
# as absolute, f0, conf, loud).
first_item = dataset[0]
tuple(first_item[field].shape for field in range(4))

In [None]:
from functools import partial

In [None]:
# dur(seconds) -> STFT frame count, with the analysis settings used throughout this notebook
# (44.1 kHz sample rate, hop of 512 samples, 1024-point FFT) pre-bound.
dur = partial(librosa.time_to_frames, sr=44100, hop_length=512, n_fft=1024)

In [None]:
plt.rcParams['figure.figsize'] = (18, 4)
# First dur(10) frames of the dataset, i.e. roughly the first ten seconds.
specgram = dataset[:dur(10)]

fig, ax1 = plt.subplots(1, 1)
# Transposed so the first axis of `specgram` runs along x and the second along y.
# assumes `specgram` is (frames, frequency_bins) — TODO confirm against AtemporalDataset.
ax1.imshow(specgram.t(), aspect='auto', origin='lower')
ax1.set_xlabel('STFT frame')
ax1.set_ylabel('Frequency bin')
# Trailing no-op keeps the cell from echoing the last expression's repr.
pass

In [None]:
def stft_to_signal(S, num_iters=15, n_fft=1024, hop_length=512):
    """Estimate a waveform from a magnitude spectrogram by iterative phase retrieval.

    Starts from uniformly random phase in [-pi, pi) and alternates between
    inverting the complex spectrogram and re-analysing the resulting signal.
    The magnitudes in ``S`` are kept fixed; only the phase estimate is updated
    (the Griffin-Lim scheme).

    Parameters
    ----------
    S : np.ndarray
        Magnitude spectrogram laid out as (frames, frequency_bins). It is
        transposed before being handed to librosa.
    num_iters : int
        Number of inverse-STFT passes; must be at least 1.
    n_fft : int
        FFT size used for re-analysis and as the synthesis window length.
    hop_length : int
        Hop between successive frames, in samples.

    Returns
    -------
    np.ndarray
        The time-domain signal produced by the final inverse STFT.

    Raises
    ------
    ValueError
        If ``num_iters`` is smaller than 1 (previously this silently returned ``None``).
    """
    if num_iters < 1:
        raise ValueError(f'num_iters must be at least 1, got {num_iters}')

    magnitudes = S.T

    # Random initial phase; not seeded, so repeated calls give different results.
    phase = 2 * np.pi * np.random.random_sample(magnitudes.shape) - np.pi
    signal = None
    for idx in range(num_iters):
        D = magnitudes * np.exp(1j * phase)
        signal = librosa.istft(D, hop_length=hop_length, win_length=n_fft)
        # don't calculate phase during the last iteration, because it will not be used.
        if idx < num_iters - 1:
            phase = np.angle(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length))

    return signal

In [None]:
# Reconstruct audio from the first dur(10) frames of the dataset with 100 phase-retrieval iterations.
# The magnitudes are multiplied by dataset.maxima first — presumably undoing a normalisation
# applied inside AtemporalDataset (TODO confirm).
sig = stft_to_signal((dataset[:dur(10)] * dataset.maxima).numpy(), num_iters=100)

In [None]:
plt.rcParams['figure.figsize'] = (18, 4)
# Sample k is played at k / 44100 s. (linspace with an included endpoint stretched the axis slightly.)
t = np.arange(len(sig)) / 44100

fig, ax1 = plt.subplots(1, 1)
ax1.plot(t, sig)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Amplitude')
# Trailing no-op keeps the cell from echoing the last expression's repr.
pass

In [None]:
# Audio player for the reconstruction of the original spectrogram (44.1 kHz playback).
Audio(sig, rate=44100)

In [None]:
class Encoder(nn.Module):
    """Fully connected encoder for a single 513-bin spectrogram frame.

    The last layer emits one flat vector that is cut into four groups:
    an 8-dimensional latent code followed by class scores for f0 (87 classes),
    confidence (11 classes) and loudness (11 classes).

    The class scores are returned as raw, unbounded logits. ``custom_loss``
    passes them to ``F.cross_entropy``, which applies log-softmax itself and
    expects unnormalised scores. Squashing them through a sigmoid first (as
    this class used to) confines every logit to (0, 1) and prevents the
    softmax from ever becoming confident. Apply ``softmax(dim=1)`` to a head
    if probabilities are needed.
    """

    # Widths of the four output groups, in the order they are laid out in `c4`.
    LATENT_SIZE = 8
    NUM_F0_CLASSES = 87
    NUM_CONF_CLASSES = 11
    NUM_LOUD_CLASSES = 11

    def __init__(self):
        super().__init__()

        self.c1 = nn.Linear(513, 256)
        self.c2 = nn.Linear(256, 128)
        self.c3 = nn.Linear(128, 64)
        self.c4 = nn.Linear(
            64,
            self.LATENT_SIZE + self.NUM_F0_CLASSES + self.NUM_CONF_CLASSES + self.NUM_LOUD_CLASSES,
        )

    def forward(self, x):
        """Encode a batch of frames.

        Args:
            x: tensor whose last dimension is 513 and whose first dimension is
               the batch (the output is sliced along dim 1).

        Returns:
            Tuple ``(latent, f0_logits, conf_logits, loud_logits)`` with
            8, 87, 11 and 11 columns respectively.
        """
        z = F.relu(self.c1(x))
        z = F.relu(self.c2(z))
        z = F.relu(self.c3(z))
        z = self.c4(z)

        # Column offsets of the three classification heads inside the flat output.
        f0_start = self.LATENT_SIZE
        conf_start = f0_start + self.NUM_F0_CLASSES
        loud_start = conf_start + self.NUM_CONF_CLASSES

        return (
            z[:, :f0_start],
            z[:, f0_start:conf_start],
            z[:, conf_start:loud_start],
            z[:, loud_start:],
        )

In [None]:
class Decoder(nn.Module):
    """Fully connected decoder from a latent code plus three class indices to a 513-value frame.

    Each conditioning index (f0, confidence, loudness) is looked up in its own
    embedding table. The embeddings are rectified and concatenated after the
    latent code before the linear stack.
    """

    def __init__(self):
        super().__init__()

        # Embedding tables: 87 f0 classes -> 8 dims, 11 confidence / loudness classes -> 4 dims each.
        self.embed1 = nn.Embedding(87, 8)
        self.embed2 = nn.Embedding(11, 4)
        self.embed3 = nn.Embedding(11, 4)

        # Input width = 8 (latent) + 8 + 4 + 4 (embeddings).
        self.c2 = nn.Linear(8+8+4+4, 64)
        self.c3 = nn.Linear(64, 128)
        self.c4 = nn.Linear(128, 256)
        self.c5 = nn.Linear(256, 513)

    def forward(self, x, f0, conf, loud):
        """Decode a batch.

        Args:
            x: latent codes, concatenated along dim 1 with the embeddings (8 columns expected by `c2`).
            f0, conf, loud: integer index tensors; dim 1 is squeezed away after
                the embedding lookup, so a singleton second dimension is expected.

        Returns:
            Tensor with 513 values per batch row; no activation is applied to the last layer.
        """
        lookups = ((self.embed1, f0), (self.embed2, conf), (self.embed3, loud))
        conditioning = [F.relu(table(index).squeeze(1)) for table, index in lookups]

        hidden = torch.cat([x, *conditioning], 1)
        for layer in (self.c2, self.c3, self.c4):
            hidden = F.relu(layer(hidden))

        return self.c5(hidden)

In [None]:
class Autoencoder(nn.Module):
    """Encoder/decoder pair trained jointly.

    The decoder is conditioned on the *given* f0 / confidence / loudness
    indices, not on the encoder's own predictions. Those predictions are
    returned alongside the reconstruction so a loss can supervise them.
    """

    def __init__(self):
        super().__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x, f0, conf, loud):
        """Return ``(reconstruction, f0_hat, conf_hat, loud_hat)`` for a batch of frames ``x``."""
        latent, *predictions = self.encoder(x)
        reconstruction = self.decoder(latent, f0, conf, loud)

        return (reconstruction, *predictions)

In [None]:
# Build the autoencoder, apply the project's weight initialiser, and report its trainable parameter count.
model = Autoencoder()
initialize_model(model)
print('\t', get_num_trainable_params(model))

In [None]:
def custom_loss(x, x_hat, f0, conf, loud, f0_hat, conf_hat, loud_hat):
    """Reconstruction error plus three classification terms, summed with equal weight.

    Args:
        x, x_hat: target and reconstructed frames, compared with mean squared error.
        f0, conf, loud: integer class targets; dim 1 is squeezed away before
            they are handed to ``F.cross_entropy``.
        f0_hat, conf_hat, loud_hat: per-class scores for the matching target.
            ``F.cross_entropy`` applies log-softmax to them internally.

    Returns:
        Scalar tensor: MSE + CE(f0) + CE(conf) + CE(loud).
    """
    total = F.mse_loss(x_hat, x)
    heads = ((f0_hat, f0), (conf_hat, conf), (loud_hat, loud))
    for scores, target in heads:
        total = total + F.cross_entropy(scores, target.squeeze(1))
    return total

In [None]:
# Alias used by the training loop, so the loss can be swapped in one place.
loss_fn = custom_loss

In [None]:
# Adam with its default hyper-parameters over every parameter of the autoencoder.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# Shuffled mini-batches of BATCH_SIZE items, loaded by 8 worker processes into pinned host memory.
data_loader = DataLoader(dataset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

### This is the training loop

In [None]:
# Number of full passes over the dataset.
NUM_EPOCHS = 5

model.to(DEVICE)
model.train()
for i in range(NUM_EPOCHS):
    # One progress bar per epoch, advanced by the number of samples in each batch.
    with tqdm_notebook(total=len(dataset)) as pbar:
        for absolute, f0, conf, loud in data_loader:
            absolute = absolute.to(DEVICE)
            f0 = f0.to(DEVICE)
            conf = conf.to(DEVICE)
            loud = loud.to(DEVICE)

            optimizer.zero_grad()
            x_hat, f0_hat, conf_hat, loud_hat = model(absolute, f0, conf, loud)
            loss = loss_fn(absolute, x_hat, f0, conf, loud, f0_hat, conf_hat, loud_hat)
            loss.backward()
            optimizer.step()

            # .item() gives the scalar loss as a Python float, replacing loss.data.cpu().numpy().
            pbar.set_description(f'Epoch: {i + 1} - loss: {loss.item():.2E}')
            pbar.update(absolute.shape[0])

## Test performance

In [None]:
# Assemble the first dur(10) dataset items into one batch for evaluation.
num_frames = dur(10)

sample = torch.zeros((num_frames, 513))
f0s = torch.zeros(num_frames, 1, dtype=torch.int64)
confidences = torch.zeros(num_frames, 1, dtype=torch.int64)
loudnesses = torch.zeros(num_frames, 1, dtype=torch.int64)
for idx in range(num_frames):
    # Fetch each item once; indexing the dataset separately per field repeated the lookup four times.
    sample[idx], f0s[idx], confidences[idx], loudnesses[idx] = dataset[idx]

sample = sample.to(DEVICE)
f0s = f0s.to(DEVICE)
confidences = confidences.to(DEVICE)
loudnesses = loudnesses.to(DEVICE)

In [None]:
# Sanity check: shapes of the assembled frames and f0 index tensors.
sample.shape, f0s.shape

In [None]:
# Run the trained autoencoder on the evaluation batch; only the reconstruction
# (first output) is needed, the three predicted class-score tensors are ignored.
model.eval()
with torch.no_grad():
    reconstruction = model(sample, f0s, confidences, loudnesses)[0]

# Back on the CPU, multiply by dataset.maxima as was done for the original
# spectrogram, then convert to NumPy for phase retrieval.
sample_hat = reconstruction.cpu() * dataset.maxima
sample_hat_np = sample_hat.numpy()

In [None]:
# Turn the reconstructed magnitudes into audio with 100 phase-retrieval iterations.
signal_hat = stft_to_signal(sample_hat_np, num_iters=100)

In [None]:
plt.rcParams['figure.figsize'] = (18, 4)
# Sample k is played at k / 44100 s. (linspace with an included endpoint stretched the axis slightly.)
t = np.arange(len(signal_hat)) / 44100

fig, ax1 = plt.subplots(1, 1)
ax1.plot(t, signal_hat)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Amplitude')
# Trailing no-op keeps the cell from echoing the last expression's repr.
pass

In [None]:
# Audio player for the autoencoder's reconstruction (44.1 kHz playback).
Audio(signal_hat, rate=44100)

In [None]:
# Both values are passed to sample_z below; num_cv is how many whole
# `resolution`-frame segments fit into dur(10) frames.
resolution = 50
num_cv = dur(10) // resolution

In [None]:
# Draw latent codes with the project helper. The leading 8 matches the decoder's latent width;
# NOTE(review): the meaning of the remaining positional arguments is defined in
# zachary.datasets.sample_z, which is not visible here — confirm there.
zs = sample_z(8, 0., 1., num_cv, resolution, 2, True)

In [None]:
# Cast to float32 (the dtype of the model's weights by default) and move to the model's device.
zs_t = torch.from_numpy(zs.astype('float32')).to(DEVICE)

In [None]:
# Fixed conditioning for every sampled latent: one class index per row, shaped like
# the f0 / confidence / loudness index tensors used for the evaluation batch.
num_latents = zs_t.shape[0]

# f0 class 10 for every row.
constant = torch.full((num_latents, f0s.shape[1]), 10, dtype=torch.int64).to(DEVICE)

# confidence class 5 for every row.
constant1 = torch.full((num_latents, confidences.shape[1]), 5, dtype=torch.int64).to(DEVICE)

# loudness class 4 for every row.
constant2 = torch.full((num_latents, loudnesses.shape[1]), 4, dtype=torch.int64).to(DEVICE)

In [None]:
# Decode the sampled latents with the constant conditioning; the encoder is bypassed entirely.
model.eval()
with torch.no_grad():
    y = model.decoder(zs_t, constant, constant1, constant2)

# Back on the CPU, multiply by dataset.maxima as was done for the other spectrograms, then convert to NumPy.
y_hat = y.cpu() * dataset.maxima
y_hat_np = y_hat.numpy()

In [None]:
# Turn the decoded magnitudes into audio with 100 phase-retrieval iterations.
s_hat = stft_to_signal(y_hat_np, num_iters=100)

In [None]:
plt.rcParams['figure.figsize'] = (18, 4)
# Sample k is played at k / 44100 s. (linspace with an included endpoint stretched the axis slightly.)
t = np.arange(len(s_hat)) / 44100

fig, ax1 = plt.subplots(1, 1)
ax1.plot(t, s_hat)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Amplitude')
# Trailing no-op keeps the cell from echoing the last expression's repr.
pass

In [None]:
# Audio player for the signal decoded from sampled latents (44.1 kHz playback).
Audio(s_hat, rate=44100)

In [None]:
# Dataset for the GAN experiments, built from the frame dataset and the trained encoder.
# NOTE(review): how example_length / stft_hop_length are interpreted is defined in
# zachary.datasets.GANDataset, which is not visible here.
gan_dataset = GANDataset(dataset, model.encoder, example_length=64, stft_hop_length=32)

In [None]:
class Generator(nn.Module):
    """1-D transposed-convolution generator.

    Maps a (batch, 128, L) input to a (batch, 16, 64 * L) output: each of the
    three ConvTranspose1d layers has kernel 4 and stride 4, so it multiplies
    the length by four while the channel count goes 128 -> 64 -> 32 -> 16.
    """

    def __init__(self):
        super().__init__()

        # The second positional argument of BatchNorm1d is `eps`, so the previous
        # `nn.BatchNorm1d(n, 0.8)` added 0.8 to the variance instead of setting a
        # momentum. The value is now passed by keyword.
        # NOTE(review): PyTorch's momentum is the weight of the *new* batch statistic;
        # if 0.8 was meant as the weight of the running average, use 0.2 instead.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm1d(128),

            nn.ConvTranspose1d(128, 64, 4, 4),
            nn.BatchNorm1d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.ConvTranspose1d(64, 32, 4, 4),
            nn.BatchNorm1d(32, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.ConvTranspose1d(32, 16, 4, 4),
        )

    def forward(self, z):
        """Return the generated sequence for input ``z`` of shape (batch, 128, L)."""
        img = self.conv_blocks(z)
        return img

In [None]:
class Discriminator(nn.Module):
    """1-D convolutional discriminator.

    Maps a (batch, 16, L) input to a (batch, 128, L / 64) output: each of the
    three Conv1d layers has kernel 4 and stride 4, so it divides the length by
    four while the channel count goes 16 -> 32 -> 64 -> 128. No activation is
    applied to the final layer.
    """

    def __init__(self):
        super().__init__()

        # The second positional argument of BatchNorm1d is `eps`, so the previous
        # `nn.BatchNorm1d(n, 0.8)` added 0.8 to the variance instead of setting a
        # momentum. The value is now passed by keyword.
        # NOTE(review): PyTorch's momentum is the weight of the *new* batch statistic;
        # if 0.8 was meant as the weight of the running average, use 0.2 instead.
        self.conv_blocks = nn.Sequential(
            nn.Conv1d(16, 32, 4, 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),
            nn.BatchNorm1d(32, momentum=0.8),

            nn.Conv1d(32, 64, 4, 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),
            nn.BatchNorm1d(64, momentum=0.8),

            nn.Conv1d(64, 128, 4, 4),
        )

    def forward(self, img):
        """Return the raw output of the convolution stack for ``img`` of shape (batch, 16, L)."""
        validities = self.conv_blocks(img)
        return validities

In [None]:
# nn.Module.to() moves the parameters in place and returns the module itself,
# so construction and placement can be chained.
gen = Generator().to(DEVICE)
disc = Discriminator().to(DEVICE)

In [None]:
# Dummy generator input: batch of 1, 128 channels, length 4.
# Tensor.to() returns a new tensor rather than moving the existing one, so the result
# has to be assigned. Previously it was discarded and `asd` stayed on the CPU.
asd = torch.ones(1, 128, 4).to(DEVICE)

In [None]:
# Shape check: the generator stretches the length 4 -> 256 and reduces channels 128 -> 16.
gen(asd).shape

In [None]:
# Shape check: the discriminator applied to generator output brings it back to 128 channels, length 4.
disc(gen(asd)).shape

In [None]:
# Take the first GAN example, swap its two axes and move it to the model's device.
qwe = gan_dataset[0].transpose(0, 1).to(DEVICE)

model.eval()
with torch.no_grad():
    # NOTE(review): Encoder.forward returns a 4-tuple (latent, f0, conf, loud) and a tuple has no
    # .transpose(), so this line raises AttributeError as written. Decide which of the encoder
    # outputs should be fed to the discriminator and index or concatenate accordingly.
    qwe = model.encoder(qwe).transpose(0, 1)

# NOTE(review): the discriminator's first Conv1d expects 16 input channels, while the encoder's
# latent has 8 columns — confirm how the 16-channel input is meant to be assembled.
disc(qwe.unsqueeze(0)).shape