In [None]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from IPython.display import Audio
import numpy as np

In [None]:
from itertools import chain

In [None]:
import librosa

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [None]:
from zachary.preprocess.datasets import AtemporalDataset, TemporalDataset, AtemporalMidiDataset, TemporalMidiDataset
from zachary.preprocess.base import Configuration
from zachary.preprocess.utils import load_audio_file, spectrum_from_signal
from zachary.weight_initializers import initialize_model
from zachary.utils import get_torch_device, get_num_trainable_params

In [None]:
def plot_spectrum(spect, figsize=(19, 6)):
    """Show a 2-D array (spectrogram, piano roll, ...) as an image in a new figure.

    Rows are drawn bottom-up (``origin='lower'``) so row 0 sits at the bottom, the
    image is stretched to fill the axes (``aspect='auto'``), and no interpolation
    is applied, so every element is one flat-coloured cell.

    Args:
        spect: 2-D array-like, rows on the y axis and columns on the x axis.
            Callers in this notebook pass NumPy arrays and CPU torch tensors.
        figsize: (width, height) in inches for this figure only. The size is no
            longer written into the global ``plt.rcParams``, so other plots are
            unaffected by calling this helper.
    """
    _, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(spect, aspect='auto', interpolation='none', origin='lower')

In [None]:
 class Encoder(nn.Module):
     """Convolutional encoder for 1-D feature sequences.

     Maps ``(batch, in_channels, L)`` to ``(batch, out_channels, ceil(L / 4))``
     through three Conv1d -> InstanceNorm1d -> leaky ReLU stages with strides 1, 2, 2.

     Args:
         in_channels: number of input channels.
         out_channels: number of channels produced by the last stage.
     """

     def __init__(self, in_channels, out_channels):
         super(Encoder, self).__init__()

         # Stage 1: project to 256 channels; k=5, s=1, p=2 keeps the length.
         self.conv1 = nn.Conv1d(
             in_channels=in_channels,
             out_channels=256,
             kernel_size=5,
             stride=1,
             padding=2
         )
         # Stage 2: 256 -> 128 channels; stride 2 halves the length (rounded up).
         self.conv2 = nn.Conv1d(
             in_channels=256,
             out_channels=128,
             kernel_size=3,
             stride=2,
             padding=1
         )
         # Stage 3: 128 -> out_channels; stride 2 halves the length again.
         self.conv3 = nn.Conv1d(
             in_channels=128,
             out_channels=out_channels,
             kernel_size=3,
             stride=2,
             padding=1
         )

         self.norm1 = nn.InstanceNorm1d(num_features=256)
         self.norm2 = nn.InstanceNorm1d(num_features=128)
         # Must follow conv3's width. It was hard-coded to 64, which only agreed when
         # out_channels == 64 (recent PyTorch warns on a num_features mismatch).
         self.norm3 = nn.InstanceNorm1d(num_features=out_channels)

     def forward(self, x):
         z = F.leaky_relu(self.norm1(self.conv1(x)))
         z = F.leaky_relu(self.norm2(self.conv2(z)))
         z = F.leaky_relu(self.norm3(self.conv3(z)))
         return z

In [None]:
class ResidualBlock(nn.Module):
    """Shape-preserving residual block computing ``z + f(z)`` with a channel bottleneck.

    ``f`` is Conv1d(C -> C // 2) -> InstanceNorm -> leaky ReLU -> Conv1d(C // 2 -> C)
    -> InstanceNorm -> leaky ReLU. Both convolutions use k=3, s=1, p=1, so the
    sequence length is unchanged and the skip connection adds tensors of equal shape.

    Args:
        num_channels: channel count C of both the input and the output.
    """

    def __init__(self, num_channels=64):
        super(ResidualBlock, self).__init__()
        hidden = num_channels // 2

        self.conv1 = nn.Conv1d(
            in_channels=num_channels,
            out_channels=hidden,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.conv2 = nn.Conv1d(
            in_channels=hidden,
            out_channels=num_channels,
            kernel_size=3,
            stride=1,
            padding=1
        )

        # norm1 sees conv1's output, which has num_channels // 2 channels.
        # It previously declared num_channels * 2, a 4x mismatch.
        self.norm1 = nn.InstanceNorm1d(num_features=hidden)
        self.norm2 = nn.InstanceNorm1d(num_features=num_channels)

    def forward(self, z):
        z_hat = F.leaky_relu(self.norm1(self.conv1(z)))
        z_hat = F.leaky_relu(self.norm2(self.conv2(z_hat)))
        return z + z_hat

In [None]:
class Transformer(nn.Module):
    """Stack of ``num_blocks`` ResidualBlocks applied in sequence (shape-preserving).

    "Transformer" here is the middle stage of the encoder/transformer/decoder
    generator, i.e. a residual feature transformer. It contains no attention.

    Args:
        num_channels: channel count passed to every ResidualBlock.
        num_blocks: number of blocks in the stack.
    """

    def __init__(self, num_channels, num_blocks):
        super(Transformer, self).__init__()

        # nn.ModuleList instead of a plain list registers the blocks as submodules.
        # Their parameters then appear in .parameters() (so the optimizer trains
        # them) and in state_dict(), and .to()/.train()/.eval() reach them without
        # a custom override.
        self.blocks = nn.ModuleList(ResidualBlock(num_channels) for _ in range(num_blocks))

    def forward(self, z):
        z_hat = z
        for block in self.blocks:
            z_hat = block(z_hat)
        return z_hat

In [None]:
 class Decoder(nn.Module):
     """Transposed-convolution decoder: the mirror of Encoder.

     Maps ``(batch, in_channels, L)`` to ``(batch, out_channels, 4 * L)``. Two
     stride-2 ConvTranspose1d stages (each followed by InstanceNorm1d and leaky
     ReLU) double the length twice. A final stride-1 ConvTranspose1d projects to
     ``out_channels`` with no activation, so the output is unbounded.
     """

     def __init__(self, in_channels, out_channels):
         super().__init__()

         # k=3, s=2, p=1, output_padding=1 -> exactly doubles the length.
         self.conv1 = nn.ConvTranspose1d(in_channels, 512, kernel_size=3, stride=2, padding=1, output_padding=1)
         self.conv2 = nn.ConvTranspose1d(512, 1024, kernel_size=3, stride=2, padding=1, output_padding=1)
         # k=5, s=1, p=2 -> length unchanged; only the channel count changes.
         self.conv3 = nn.ConvTranspose1d(1024, out_channels, kernel_size=5, stride=1, padding=2)

         self.norm1 = nn.InstanceNorm1d(512)
         self.norm2 = nn.InstanceNorm1d(1024)

     def forward(self, z):
         for upsample, normalize in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
             z = F.leaky_relu(normalize(upsample(z)))
         return self.conv3(z)

In [None]:
class Generator(nn.Module):
    """Domain translator: Encoder -> residual Transformer -> Decoder.

    Args:
        a_channels: channel count of the input domain.
        b_channels: channel count of the output domain.
        z_channels: width of the latent sequence between encoder and decoder.
        transformer_depth: number of residual blocks in the middle stage.
    """

    def __init__(self, a_channels, b_channels, z_channels, transformer_depth):
        super(Generator, self).__init__()

        self.encoder = Encoder(a_channels, z_channels)
        self.transformer = Transformer(z_channels, transformer_depth)
        self.decoder = Decoder(z_channels, b_channels)

    def forward(self, x):
        z = self.encoder(x)
        z_hat = self.transformer(z)
        y_hat = self.decoder(z_hat)
        return y_hat

    def to(self, *args, **kwargs):
        """Move/cast the generator and return ``self``, matching ``nn.Module.to``."""
        super(Generator, self).to(*args, **kwargs)
        # nn.Module.to recurses into registered submodules without calling their own
        # `.to` methods. The explicit calls make any custom `.to` on a child (e.g. a
        # Transformer holding unregistered blocks) run as well.
        self.encoder.to(*args, **kwargs)
        self.transformer.to(*args, **kwargs)
        self.decoder.to(*args, **kwargs)
        # Returning self (previously None) allows `gen = Generator(...).to(device)`.
        return self

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing per-position real/fake scores.

    Three stride-2 Conv1d layers (k=3, p=1) map ``(batch, in_channels, L)`` to
    ``(batch, 1, ceil(L / 8))``. The first two are followed by InstanceNorm1d and
    leaky ReLU, and the last by a sigmoid, so every score lies in [0, 1].

    Args:
        in_channels: channel count of the domain being judged.
    """

    def __init__(self, in_channels):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=1024,
            kernel_size=3,
            stride=2,
            padding=1
        )
        self.conv2 = nn.Conv1d(
            in_channels=1024,
            out_channels=512,
            kernel_size=3,
            stride=2,
            padding=1
        )
        self.conv3 = nn.Conv1d(
            in_channels=512,
            out_channels=1,
            kernel_size=3,
            stride=2,
            padding=1
        )

        # num_features must match the channels each norm actually receives
        # (conv1 -> 1024, conv2 -> 512). They were 64 and 128 before. The unused
        # norm3 and the commented-out 2-D layers have been removed.
        self.norm1 = nn.InstanceNorm1d(num_features=1024)
        self.norm2 = nn.InstanceNorm1d(num_features=512)

    def forward(self, x):
        d = F.leaky_relu(self.norm1(self.conv1(x)))
        d = F.leaky_relu(self.norm2(self.conv2(d)))
        # torch.sigmoid replaces the deprecated F.sigmoid; the result is identical.
        return torch.sigmoid(self.conv3(d))

In [None]:
# Domain A (audio spectra, fed to gen_a_b / judged by disc_a) has 513 channels;
# domain B (MIDI data, fed to gen_b_a / judged by disc_b) has 128.
# TODO confirm these match the shapes produced by the datasets / conf.
A_CHANNELS = 513
B_CHANNELS = 128
LATENT_CHANNELS = 64     # width between Encoder and Decoder
TRANSFORMER_DEPTH = 4    # residual blocks per generator

gen_a_b = Generator(A_CHANNELS, B_CHANNELS, LATENT_CHANNELS, TRANSFORMER_DEPTH)
gen_b_a = Generator(B_CHANNELS, A_CHANNELS, LATENT_CHANNELS, TRANSFORMER_DEPTH)
disc_a, disc_b = Discriminator(A_CHANNELS), Discriminator(B_CHANNELS)

In [None]:
# Least-squares adversarial loss: MSE between discriminator outputs and 0/1 targets.
criterion_GAN = nn.MSELoss()
# L1 reconstruction losses for the cycle-consistency and identity terms
# (the identity term is currently not used by the training cell).
criterion_cycle, criterion_identity = nn.L1Loss(), nn.L1Loss()

In [None]:
# Shared Adam hyper-parameters for all three optimizers.
ADAM_KWARGS = dict(lr=0.0002, betas=(0.5, 0.999))

# A single optimizer updates both generators jointly; each discriminator gets its own.
optimizer_gen = torch.optim.Adam(chain(gen_a_b.parameters(), gen_b_a.parameters()), **ADAM_KWARGS)
optimizer_disc_a = torch.optim.Adam(disc_a.parameters(), **ADAM_KWARGS)
optimizer_disc_b = torch.optim.Adam(disc_b.parameters(), **ADAM_KWARGS)

In [None]:
BATCH_SIZE = 1
DEVICE = get_torch_device()

# NOTE(review): absolute, user-specific paths; consider making these configurable.
AUDIO_DIR = '/home/kureta/Music/Beethoven Piano Sonatas Barenboim/small/'
MIDI_DIR = '/home/kureta/Music/midi/beethoven/small/'
conf = Configuration(audio_dir=AUDIO_DIR, midi_dir=MIDI_DIR)

# Loader settings shared by both domains.
loader_kwargs = dict(pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

# The example_length / example_hop_length given here are overwritten by the
# size curriculum at the start of the training cell.
dataset_a = TemporalDataset(conf=conf, example_length=32, example_hop_length=4)
data_loader_a = DataLoader(dataset_a, **loader_kwargs)

dataset_b = TemporalMidiDataset(conf=conf, example_length=32, example_hop_length=4)
data_loader_b = DataLoader(dataset_b, **loader_kwargs)

In [None]:
# Put every network in training mode and move it to the training device.
for net in (gen_a_b, gen_b_a, disc_a, disc_b):
    net.train()
    net.to(DEVICE)

SIZES = [8, 16, 32, 64]   # curriculum: example length (frames) grows each stage
LAMBDA_CYCLE = 10.0       # weight of the cycle-consistency term
DISC_UPDATE_EVERY = 10    # discriminators are updated on every 10th step only
LOG_EVERY = 100


def discriminator_loss(disc, real, fake):
    """Least-squares discriminator loss: push D(real) towards 1 and D(fake) towards 0.

    ``fake`` is detached, so this loss produces gradients for ``disc`` only.
    Without the detach, backward() would traverse the generator graph, whose
    weights ``optimizer_gen.step()`` has already modified in place. That makes
    autograd raise an in-place-modification error, and also needs retain_graph.
    """
    pred_real = disc(real)
    pred_fake = disc(fake.detach())
    return (criterion_GAN(pred_real, torch.ones_like(pred_real))
            + criterion_GAN(pred_fake, torch.zeros_like(pred_fake)))


loss_disc = float('nan')  # most recent discriminator loss, shown in the progress line
for size in SIZES:
    for dataset in (dataset_a, dataset_b):
        dataset.example_length = size
        dataset.example_hop_length = size // 2

    for step, (example_a, example_b) in enumerate(zip(data_loader_a, data_loader_b)):
        example_a = example_a.to(DEVICE)
        example_b = example_b.to(DEVICE)

        # ---- Generator update (every step) ----
        optimizer_gen.zero_grad()
        a_hat = gen_b_a(example_b)
        b_hat = gen_a_b(example_a)

        # Adversarial term: generators try to make the discriminators output 1 on fakes.
        is_fake_a = disc_a(a_hat)
        is_fake_b = disc_b(b_hat)
        loss_GAN = (criterion_GAN(is_fake_a, torch.ones_like(is_fake_a))
                    + criterion_GAN(is_fake_b, torch.ones_like(is_fake_b)))

        # Cycle-consistency term: A -> B -> A and B -> A -> B should reproduce the inputs.
        cycled_a = gen_b_a(b_hat)
        cycled_b = gen_a_b(a_hat)
        loss_cycle = criterion_cycle(cycled_a, example_a) + criterion_cycle(cycled_b, example_b)

        # TODO: identity loss (criterion_identity) is currently disabled.
        loss_gen = LAMBDA_CYCLE * loss_cycle + loss_GAN
        loss_gen.backward()
        # Step every iteration. This used to sit inside the throttled branch below,
        # so 9 of every 10 generator gradients were discarded by the next zero_grad().
        optimizer_gen.step()

        # ---- Discriminator update (throttled) ----
        if step % DISC_UPDATE_EVERY == 0:
            optimizer_disc_a.zero_grad()  # also clears grads left by loss_gen.backward()
            loss_disc_a = discriminator_loss(disc_a, example_a, a_hat)
            loss_disc_a.backward()
            optimizer_disc_a.step()

            optimizer_disc_b.zero_grad()
            loss_disc_b = discriminator_loss(disc_b, example_b, b_hat)
            loss_disc_b.backward()
            optimizer_disc_b.step()

            loss_disc = (loss_disc_a + loss_disc_b).item()

        if (step + 1) % LOG_EVERY == 0:
            print(f'({size}) iteration: {step + 1}/{dataset_b.midi.shape[0]}, generator_loss: {loss_gen.item():.4e}, '
                  f'cycle_loss: {loss_cycle.item():.4e}, gan_loss: {loss_GAN.item():.4e}, disc_loss: {loss_disc:.4e}')

In [None]:
def stft_to_signal(S, num_iters=15, hop_length=None, frame_length=None, rng=None):
    """Reconstruct a waveform from a magnitude spectrogram with Griffin-Lim iterations.

    Starts from uniformly random phase in [-pi, pi), then alternates
    ``librosa.istft`` / ``librosa.stft``. The given magnitudes are kept and only
    the phase is re-estimated on each pass.

    Args:
        S: magnitude spectrogram in librosa layout (frequency bins x frames),
            e.g. the output of ``librosa.db_to_amplitude``.
        num_iters: number of inverse-STFT passes; must be >= 1.
        hop_length: STFT hop length; defaults to ``conf.hop_length``.
        frame_length: FFT size / window length; defaults to ``conf.frame_length``.
        rng: optional ``numpy.random.Generator`` for a reproducible initial phase.
            ``None`` uses the global NumPy RNG, as before.

    Returns:
        The signal produced by the final ``librosa.istft`` call.

    Raises:
        ValueError: if ``num_iters < 1``. The loop would never run, and the
            function used to silently return ``None``.
    """
    if num_iters < 1:
        raise ValueError(f'num_iters must be >= 1, got {num_iters}')
    hop_length = conf.hop_length if hop_length is None else hop_length
    frame_length = conf.frame_length if frame_length is None else frame_length
    random_sample = np.random.random_sample if rng is None else rng.random

    phase = 2 * np.pi * random_sample(S.shape) - np.pi
    for idx in range(num_iters):
        D = S * np.exp(1j * phase)
        signal = librosa.istft(D, hop_length=hop_length, win_length=frame_length)
        # Skip re-analysis on the final pass; its phase would never be used.
        if idx < num_iters - 1:
            phase = np.angle(librosa.stft(signal, n_fft=frame_length, hop_length=hop_length))

    return signal

In [None]:
# Example inputs for inference: one audio track (domain A) and one MIDI file (domain B).
# NOTE(review): absolute, user-specific paths.
path_a, path_b = (
    '/home/kureta/Music/Beethoven Piano Sonatas Barenboim/split-track01.ape',
    '/home/kureta/Music/midi/beethoven/small/appass_1.mid',
)

In [None]:
from zachary.preprocess.base import load_midi_file, trim_zeros

In [None]:
audio_a = load_audio_file(path_a, conf)
# Cast to float32 so torch.from_numpy yields the same dtype as the model weights.
# trim_zeros presumably strips empty leading/trailing frames (TODO confirm).
midi_b = trim_zeros(load_midi_file(path_b, conf)).astype('float32')

In [None]:
Audio(data=audio_a, rate=conf.sample_rate)

In [None]:
# midi_b is (frames, channels); transpose so channels run along the y axis.
plot_spectrum(midi_b.transpose())

In [None]:
# Magnitude spectrum of the example track, converted to dB, as a torch tensor.
spectrum_a = torch.from_numpy(
    librosa.amplitude_to_db(
        spectrum_from_signal(audio_a, conf)
    )
)

In [None]:
plot_spectrum(torch.transpose(spectrum_a, 0, 1))

In [None]:
gen_a_b.eval()
with torch.no_grad():
    # Conv1d expects (batch, channels, length): swap the spectrum's two axes
    # and prepend a singleton batch dimension.
    gen_input_a = spectrum_a.transpose(0, 1)[None].to(DEVICE)
    midi_b_hat = gen_a_b(gen_input_a)

In [None]:
# Batch size is 1 here, so [0] drops the batch axis exactly like squeeze(0).
plot_spectrum(midi_b_hat[0].cpu())

In [None]:
# Value range of the generated MIDI output.
(midi_b_hat.min(), midi_b_hat.max())

In [None]:
gen_b_a.eval()
with torch.no_grad():
    # (frames, channels) numpy array -> (1, channels, frames) tensor on DEVICE.
    gen_input_b = torch.from_numpy(midi_b).transpose(0, 1)[None].to(DEVICE)
    spectrum_a_hat = gen_b_a(gen_input_b)

In [None]:
# Batch size is 1 here, so [0] drops the batch axis exactly like squeeze(0).
plot_spectrum(spectrum_a_hat[0].cpu())

In [None]:
# Keep only the first 1000 frames, presumably to bound Griffin-Lim run time.
s = spectrum_a_hat[0, :, :1000].cpu()

In [None]:
# Undo the dB scaling (inverse of the amplitude_to_db used for spectrum_a), then
# recover a waveform with Griffin-Lim.
magnitude_a_hat = librosa.db_to_amplitude(s.numpy())
audio_a_hat = stft_to_signal(magnitude_a_hat)

In [None]:
Audio(data=audio_a_hat, rate=conf.sample_rate)