In [None]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from IPython.display import Audio
import numpy as np

In [None]:
import librosa

In [None]:
from tqdm import tnrange, tqdm_notebook

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [None]:
from zachary.preprocess.datasets import AtemporalDataset, TemporalDataset, AtemporalMidiDataset, TemporalMidiDataset
from zachary.preprocess.base import Configuration
from zachary.preprocess.utils import load_audio_file, spectrum_from_signal
from zachary.weight_initializers import initialize_model
from zachary.utils import get_torch_device, get_num_trainable_params
from zachary.model.autoencoder import Autoencoder
from zachary.model.generator import Generator
from zachary.model.discriminator import Discriminator

In [None]:
def plot_spectrum(spect):
    """Draw an array as an image on a wide (19 x 6 inch) figure.

    Args:
        spect: array-like accepted by ``Axes.imshow``. It is drawn with row 0
            at the bottom (``origin='lower'``), without interpolation, and
            stretched to fill the axes (``aspect='auto'``).

    Returns:
        None, so a bare call at the end of a cell shows only the inline
        figure and no object repr.
    """
    # Size this figure directly rather than overwriting
    # plt.rcParams['figure.figsize'], which would change the default size of
    # every figure created afterwards in the same kernel.
    _, ax = plt.subplots(1, 1, figsize=(19, 6))
    ax.imshow(spect, aspect='auto', interpolation='none', origin='lower')

In [None]:
def stft_to_signal(S, num_iters=15):
    """Estimate a time-domain signal from the magnitude array ``S``.

    Starts from uniformly random phase in [-pi, pi) and alternates between an
    inverse STFT and a forward STFT, keeping only the phase of each
    re-analysed signal while the magnitudes stay fixed to ``S``.
    Frame and hop lengths are read from the notebook-level ``conf`` object.

    Args:
        S: magnitude array, combined element-wise with the phase estimate.
        num_iters: number of inverse-STFT synthesis passes to run.

    Returns:
        The signal produced by the last ``librosa.istft`` call, or ``None``
        when ``num_iters`` is smaller than 1.
    """
    uniform = np.random.random_sample(S.shape)
    phase = 2 * np.pi * uniform - np.pi

    signal = None
    for pass_no in range(num_iters):
        if pass_no > 0:
            # Every pass after the first re-analyses the previous estimate and
            # keeps only its phase; no analysis is spent after the final pass.
            reanalysed = librosa.stft(signal, n_fft=conf.frame_length, hop_length=conf.hop_length)
            phase = np.angle(reanalysed)
        complex_spectrum = S * np.exp(1j * phase)
        signal = librosa.istft(complex_spectrum, hop_length=conf.hop_length, win_length=conf.frame_length)

    return signal

In [None]:
# Run-wide settings: target device and mini-batch size.
DEVICE = get_torch_device()
BATCH_SIZE = 128

# NOTE(review): both directories are absolute paths on one specific machine.
conf = Configuration(
    midi_dir='/home/kureta/Music/midi/beethoven/small/',
    audio_dir='/home/kureta/Music/Beethoven Piano Sonatas Barenboim/small/',
)

# Dataset built from the configuration above, served in shuffled batches by
# 8 worker processes with pinned host memory.
audio_dataset = AtemporalDataset(conf=conf)
audio_loader = DataLoader(
    audio_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

In [None]:
# Build the autoencoder and print the value returned by
# get_num_trainable_params for it.
# NOTE(review): the meaning of the positional arguments (513, 32, 3) is defined
# in zachary.model.autoencoder and is not visible here — TODO confirm.
autoencoder = Autoencoder(513, 32, 3)
print(get_num_trainable_params(autoencoder))

In [None]:
# Reconstruction objective: mean absolute error (L1) between output and input.
ae_criterion = nn.L1Loss()

In [None]:
# Adam over all autoencoder parameters, lr=2e-4, with beta1 set to 0.5
# (PyTorch's default is 0.9).
ae_optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [None]:
# Train the autoencoder to reconstruct its own input for 50 epochs.
autoencoder.train()
autoencoder.to(DEVICE)

for epoch in tnrange(50, desc='Epochs'):
    # len(audio_loader) is the number of batches the loader actually yields,
    # including a final partial batch. len(audio_dataset) // BATCH_SIZE
    # undercounts by one whenever the dataset size is not a multiple of
    # BATCH_SIZE, which made the progress bar overshoot its total.
    progress = tqdm_notebook(audio_loader, total=len(audio_loader))
    for step, x in enumerate(progress):
        x = x.to(DEVICE)

        ae_optimizer.zero_grad()
        x_hat = autoencoder(x)
        loss = ae_criterion(x_hat, x)
        loss.backward()
        ae_optimizer.step()

        # Refresh the displayed loss only every 100 steps; .item() turns the
        # 0-dim loss tensor into a plain float before formatting.
        if step % 100 == 0:
            progress.set_description(f'Loss: {loss.item():.2e}')

In [None]:
# Take the first 1000 items of the dataset as a fixed excerpt for inspection.
# NOTE(review): relies on AtemporalDataset supporting slice indexing and
# returning a tensor (.numpy() and .to() are called on the result) — confirm.
example = audio_dataset[:1000]

In [None]:
# Show the excerpt, transposed so its first axis runs along the horizontal
# axis of the image.
plot_spectrum(example.numpy().T)

In [None]:
# Reconstruct the excerpt with the autoencoder: eval mode, no gradient
# tracking, result moved back to the CPU.
autoencoder.eval()
with torch.no_grad():
    x_hat = autoencoder(example.to(DEVICE)).cpu()

In [None]:
# Show the reconstruction in the same orientation as the input excerpt plot.
plot_spectrum(x_hat.numpy().T)

In [None]:
# Multiply the transposed reconstruction by audio_dataset.maxima (presumably
# undoing the dataset's normalisation — TODO confirm) and run 32
# phase-reconstruction passes to obtain a waveform.
audio_hat = stft_to_signal(x_hat.numpy().T * audio_dataset.maxima.numpy(), 32)

In [None]:
# Play the reconstructed waveform at the configured sample rate.
Audio(data=audio_hat, rate=conf.sample_rate)

In [None]:
# Two generators and two discriminators, one pair per domain (A and B).
# NOTE(review): argument meanings come from zachary.model.generator /
# zachary.model.discriminator and are not visible here; 513 and 128 are
# presumably the feature sizes of domain A and domain B — TODO confirm.
gen_a_b = Generator(513, 128, 64, 4)
gen_b_a = Generator(128, 513, 64, 4)
disc_a = Discriminator(513)
disc_b = Discriminator(128)

In [None]:
# Adversarial loss is mean squared error against all-ones / all-zeros targets;
# the cycle term uses L1.
# criterion_identity is only referenced by commented-out code in the training cell.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

In [None]:
# `chain` is used below but is not imported in any visible import cell, so on
# a fresh kernel this cell raised NameError. Import it here.
from itertools import chain

# A single optimizer updates both generators jointly; each discriminator gets
# its own. All use Adam with lr=2e-4 and beta1=0.5.
optimizer_gen = torch.optim.Adam(chain(gen_a_b.parameters(), gen_b_a.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_disc_a = torch.optim.Adam(disc_a.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_disc_b = torch.optim.Adam(disc_b.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [None]:
# Adversarial training of the two generators and two discriminators with an
# MSE adversarial term plus an L1 cycle-consistency term (weight 10).
#
# NOTE(review): dataset_a, dataset_b, data_loader_a and data_loader_b are not
# defined in any visible cell; they must already exist in the kernel before
# this cell runs — TODO add the cell that builds them.
for net in (gen_a_b, gen_b_a, disc_a, disc_b):
    net.train()
    net.to(DEVICE)

# One pass over the paired loaders per example length; the hop between
# examples is always half the example length.
sizes = [8, 16, 32, 64]
for size in sizes:
    dataset_a.example_length = size
    dataset_b.example_length = size
    dataset_a.example_hop_length = size // 2
    dataset_b.example_hop_length = size // 2

    # zip() stops as soon as the shorter of the two loaders is exhausted.
    for step, (example_a, example_b) in enumerate(zip(data_loader_a, data_loader_b)):
        example_a = example_a.to(DEVICE)
        example_b = example_b.to(DEVICE)

        # ---------------- Generators ----------------
        optimizer_gen.zero_grad()

        # The identity loss (previously weighted 0.5) is currently disabled.

        # Adversarial term: each generator wants the discriminator of its
        # target domain to output 1 for the translated batch.
        a_hat = gen_b_a(example_b)
        b_hat = gen_a_b(example_a)
        is_fake_a = disc_a(a_hat)
        is_fake_b = disc_b(b_hat)
        loss_GAN = criterion_GAN(is_fake_a, torch.ones_like(is_fake_a)) + criterion_GAN(is_fake_b, torch.ones_like(is_fake_b))

        # Cycle term: translating there and back should reproduce the input.
        cycled_a = gen_b_a(b_hat)
        cycled_b = gen_a_b(a_hat)
        loss_cycle = criterion_cycle(cycled_a, example_a) + criterion_cycle(cycled_b, example_b)

        # Total generator loss. The discriminator losses below build their own
        # graphs from detached fakes, so retain_graph is no longer needed.
        loss_gen = 10.0 * loss_cycle + loss_GAN
        loss_gen.backward()

        # NOTE(review): gradients are zeroed every iteration but all optimizers
        # step only on every 10th one, so the backward passes in between are
        # discarded. Schedule kept as-is — confirm this is intended.
        if step % 10 == 0:
            optimizer_gen.step()

            # ---------------- Discriminator A ----------------
            optimizer_disc_a.zero_grad()
            is_real_a = disc_a(example_a)
            # Score a detached copy of the fake. Reusing is_fake_a would
            # backpropagate through the generator graph after
            # optimizer_gen.step() has modified the generator weights in
            # place: recent PyTorch raises a RuntimeError for this, and in
            # any case it wastes a backward pass through the generators.
            judged_fake_a = disc_a(a_hat.detach())
            loss_disc_a = criterion_GAN(is_real_a, torch.ones_like(is_real_a)) + criterion_GAN(judged_fake_a, torch.zeros_like(judged_fake_a))
            loss_disc_a.backward()
            optimizer_disc_a.step()

            # ---------------- Discriminator B ----------------
            optimizer_disc_b.zero_grad()
            is_real_b = disc_b(example_b)
            judged_fake_b = disc_b(b_hat.detach())
            loss_disc_b = criterion_GAN(is_real_b, torch.ones_like(is_real_b)) + criterion_GAN(judged_fake_b, torch.zeros_like(judged_fake_b))
            loss_disc_b.backward()
            optimizer_disc_b.step()

        # Log after every 100 completed iterations. The discriminator losses
        # shown are those of the most recent discriminator update.
        completed = step + 1
        if completed % 100 == 0:
            print(f'({size}) iteration: {completed}/{dataset_b.midi.shape[0]}, generator_loss: {loss_gen.item():.4e}, cycle_loss: {loss_cycle.item():.4e}, '
                  f'gan_loss: {loss_GAN.item():.4e}, disc_loss: {(loss_disc_a + loss_disc_b).item():.4e}')

In [None]:
# One audio file (domain A) and one MIDI file (domain B) used for the
# qualitative checks in the following cells.
# NOTE(review): absolute paths specific to one machine.
path_a = '/home/kureta/Music/Beethoven Piano Sonatas Barenboim/split-track01.ape'
path_b = '/home/kureta/Music/midi/beethoven/small/appass_1.mid'

In [None]:
from zachary.preprocess.base import load_midi_file, trim_zeros

In [None]:
# Load the audio example as-is; the MIDI example is loaded, passed through
# trim_zeros and cast to float32 in a single expression.
audio_a = load_audio_file(path_a, conf)
midi_b = trim_zeros(load_midi_file(path_b, conf)).astype('float32')

In [None]:
# Play the loaded audio example at the configured sample rate.
Audio(audio_a, rate=conf.sample_rate)

In [None]:
# Show the trimmed MIDI array, transposed so its first axis runs horizontally.
plot_spectrum(midi_b.T)

In [None]:
# Spectrum of the audio example converted to decibels and wrapped as a tensor.
# NOTE(review): assumes spectrum_from_signal returns amplitude (not power)
# values, as librosa.amplitude_to_db expects — confirm.
spectrum_a = torch.from_numpy(librosa.amplitude_to_db(spectrum_from_signal(audio_a, conf)))

In [None]:
# Show the dB spectrum with its two axes swapped.
plot_spectrum(spectrum_a.transpose(0, 1))

In [None]:
# Translate the audio spectrum with the A->B generator in eval mode, without
# gradient tracking. The input is transposed and given a leading size-1 axis
# (presumably the batch axis — TODO confirm).
gen_a_b.eval()
with torch.no_grad():
    midi_b_hat = gen_a_b(spectrum_a.transpose(0, 1).unsqueeze(0).to(DEVICE))

In [None]:
# Drop the leading size-1 axis and move the result to the CPU before plotting.
plot_spectrum(midi_b_hat.squeeze(0).cpu())

In [None]:
# Inspect the value range of the generated B-domain output.
midi_b_hat.min(), midi_b_hat.max()

In [None]:
# Translate the MIDI example with the B->A generator in eval mode, without
# gradient tracking. The input is transposed and given a leading size-1 axis
# (presumably the batch axis — TODO confirm).
gen_b_a.eval()
with torch.no_grad():
    spectrum_a_hat = gen_b_a(torch.from_numpy(midi_b).transpose(0, 1).unsqueeze(0).to(DEVICE))

In [None]:
# Drop the leading size-1 axis and move the result to the CPU before plotting.
plot_spectrum(spectrum_a_hat.squeeze(0).cpu())

In [None]:
# Drop the leading axis and keep at most the first 1000 entries along the
# second remaining axis (presumably time frames — TODO confirm).
s = spectrum_a_hat.squeeze(0)[:, :1000].cpu()

In [None]:
# Convert from dB back to linear amplitude and reconstruct a waveform with
# stft_to_signal's default of 15 passes.
audio_a_hat = stft_to_signal(librosa.db_to_amplitude(s.numpy()))

In [None]:
# Play the waveform reconstructed from the generated spectrum.
Audio(audio_a_hat, rate=conf.sample_rate)