In [None]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from IPython.display import Audio
import numpy as np

In [None]:
from itertools import chain

In [None]:
import librosa

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [None]:
from zachary.preprocess.datasets import AtemporalDataset, TemporalDataset
from zachary.preprocess.base import Configuration
from zachary.weight_initializers import initialize_model
from zachary.utils import get_torch_device, get_num_trainable_params

In [None]:
def plot_spectrum(spect):
    """Show ``spect`` as an image on a freshly created, wide figure.

    The data is passed to ``imshow`` with the origin at the bottom, no
    interpolation and an automatic aspect ratio.  Nothing is returned.

    Side effect: the global matplotlib default figure size is set to
    19x6 inches through ``rcParams``, so later figures inherit it too.
    """
    plt.rcParams['figure.figsize'] = (19, 6)

    _, axis = plt.subplots(1, 1)
    axis.imshow(spect, origin='lower', interpolation='none', aspect='auto')

In [None]:
 class Encoder(nn.Module):
     """Three-stage 1-D convolutional encoder (513 -> 256 -> 128 -> 64 channels).

     The first convolution (kernel 5, stride 1, padding 2) keeps the length of
     the last axis; the two that follow (kernel 3, stride 2, padding 1) each
     roughly halve it.  Every convolution is followed by instance
     normalisation and a leaky ReLU.
     """

     def __init__(self):
         super().__init__()

         # Convolutions, created in the order in which they are applied.
         self.conv1 = nn.Conv1d(513, 256, kernel_size=5, stride=1, padding=2)
         self.conv2 = nn.Conv1d(256, 128, kernel_size=3, stride=2, padding=1)
         self.conv3 = nn.Conv1d(128, 64, kernel_size=3, stride=2, padding=1)

         # One instance-norm layer per convolution output.
         self.norm1 = nn.InstanceNorm1d(256)
         self.norm2 = nn.InstanceNorm1d(128)
         self.norm3 = nn.InstanceNorm1d(64)

     def forward(self, x):
         """Encode ``x`` (513 channels, in the layout ``nn.Conv1d`` expects).

         Returns the 64-channel activation of the last stage.
         """
         stages = (
             (self.conv1, self.norm1),
             (self.conv2, self.norm2),
             (self.conv3, self.norm3),
         )

         hidden = x
         for conv, norm in stages:
             hidden = F.leaky_relu(norm(conv(hidden)))

         return hidden

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``z + leaky_relu(norm(conv(z)))``.

    The convolution (kernel 3, stride 1, padding 1) keeps both the channel
    count and the length unchanged, so the skip connection can be added
    directly.

    Args:
        num_channels: number of input and output channels.
    """

    def __init__(self, num_channels=64):
        super().__init__()

        self.conv = nn.Conv1d(
            num_channels, num_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm = nn.InstanceNorm1d(num_channels)

    def forward(self, z):
        """Return ``z`` plus its convolved, normalised and activated copy."""
        residual = self.conv(z)
        residual = self.norm(residual)
        residual = F.leaky_relu(residual)
        return z + residual

In [None]:
class Transformer(nn.Module):
    """Sequential stack of ``ResidualBlock`` modules.

    Args:
        num_channels: channel count passed to every ``ResidualBlock``.
        num_blocks: how many blocks to chain.

    The blocks are held in an ``nn.ModuleList`` so that they are registered
    as sub-modules.  With a plain Python list they would be invisible to
    ``parameters()`` (and therefore to any optimizer built from it),
    ``state_dict()``, ``train()``/``eval()`` and ``to()``.  Because the
    inherited ``nn.Module.to`` now reaches every block and returns ``self``,
    no custom ``to`` override is needed.
    """

    def __init__(self, num_channels=64, num_blocks=6):
        super().__init__()

        self.blocks = nn.ModuleList(
            [ResidualBlock(num_channels) for _ in range(num_blocks)]
        )

    def forward(self, z):
        """Pass ``z`` through every block in order and return the result."""
        for block in self.blocks:
            z = block(z)

        return z

In [None]:
 class Decoder(nn.Module):
     """Transposed-convolution decoder (64 -> 128 -> 256 -> 513 channels).

     The first two layers (kernel 3, stride 2, padding 1) are each followed by
     instance normalisation and a leaky ReLU; the last layer (kernel 5,
     stride 1, padding 2) is followed by a sigmoid, so outputs lie between
     0 and 1.

     Both stride-2 layers use ``output_padding=1``.  For ``ConvTranspose1d``
     the output length is
     ``(L - 1) * stride - 2 * padding + (kernel_size - 1) + output_padding + 1``,
     which is ``2 * L - 1`` without output padding and exactly ``2 * L`` with
     it.  Two exact doublings undo two stride-2 halvings, so an input of
     length 8 is decoded to length 32 rather than 29.
     """

     def __init__(self):
         super().__init__()

         self.conv1 = nn.ConvTranspose1d(
             in_channels=64,
             out_channels=128,
             kernel_size=3,
             stride=2,
             padding=1,
             output_padding=1,  # length L -> 2L instead of 2L - 1
         )
         self.conv2 = nn.ConvTranspose1d(
             in_channels=128,
             out_channels=256,
             kernel_size=3,
             stride=2,
             padding=1,
             output_padding=1,  # length L -> 2L instead of 2L - 1
         )
         # Stride 1 with padding 2 for a kernel of 5 keeps the length as is.
         self.conv3 = nn.ConvTranspose1d(
             in_channels=256,
             out_channels=513,
             kernel_size=5,
             stride=1,
             padding=2,
         )
         self.norm1 = nn.InstanceNorm1d(num_features=128)
         self.norm2 = nn.InstanceNorm1d(num_features=256)

     def forward(self, z):
         """Decode ``z`` (64 channels) into a 513-channel tensor in [0, 1].

         The last axis of the result is four times as long as that of ``z``.
         """
         y_hat = F.leaky_relu(self.norm1(self.conv1(z)))
         y_hat = F.leaky_relu(self.norm2(self.conv2(y_hat)))
         y_hat = F.sigmoid(self.conv3(y_hat))

         return y_hat
In [None]:
class Discriminator(nn.Module):
    """Five-layer 2-D convolutional discriminator.

    Every convolution uses a 5x5 kernel, stride 2 and padding 2, so each
    layer roughly halves both spatial axes.  Channels go
    1 -> 64 -> 128 -> 256 -> 512 -> 1.  The first four layers are followed by
    instance normalisation and a leaky ReLU, the last one by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Geometry shared by all five convolutions.
        geometry = dict(kernel_size=5, stride=2, padding=2)

        self.conv1 = nn.Conv2d(1, 64, **geometry)
        self.conv2 = nn.Conv2d(64, 128, **geometry)
        self.conv3 = nn.Conv2d(128, 256, **geometry)
        self.conv4 = nn.Conv2d(256, 512, **geometry)
        self.conv5 = nn.Conv2d(512, 1, **geometry)

        # The single-channel output layer is not normalised.
        self.norm1 = nn.InstanceNorm2d(64)
        self.norm2 = nn.InstanceNorm2d(128)
        self.norm3 = nn.InstanceNorm2d(256)
        self.norm4 = nn.InstanceNorm2d(512)

    def forward(self, x):
        """Score ``x`` (one input channel); returns a map of values in [0, 1]."""
        hidden_stages = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
            (self.conv4, self.norm4),
        )

        features = x
        for conv, norm in hidden_stages:
            features = F.leaky_relu(norm(conv(features)))

        return F.sigmoid(self.conv5(features))

In [None]:
class Generator(nn.Module):
    """Encoder -> Transformer -> Decoder pipeline."""

    def __init__(self):
        super().__init__()

        self.encoder = Encoder()
        self.transformer = Transformer()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, transform the latent code and decode it again."""
        z = self.encoder(x)
        z_hat = self.transformer(z)
        y_hat = self.decoder(z_hat)

        return y_hat

    def to(self, *args, **kwargs):
        """Move/cast the generator and return ``self``.

        Accepts the same arguments as ``nn.Module.to``.  Returning ``self``
        (as ``nn.Module.to`` does) keeps the usual
        ``model = Generator().to(device)`` idiom working.
        """
        super().to(*args, **kwargs)
        # nn.Module.to converts the tensors of registered children directly
        # and never calls a child's own ``to`` method, so invoke each one
        # explicitly in case a sub-module overrides it to do extra work.
        self.encoder.to(*args, **kwargs)
        self.transformer.to(*args, **kwargs)
        self.decoder.to(*args, **kwargs)

        return self

In [None]:
# One generator per translation direction (gen_a_b is fed domain-A examples,
# gen_b_a domain-B examples) and one discriminator per domain.
gen_a_b, gen_b_a = Generator(), Generator()
disc_a, disc_b = Discriminator(), Discriminator()

In [None]:
# Loss modules: mean squared error for the GAN criterion, mean absolute
# error for the cycle and identity criteria (two separate L1 instances).
# NOTE(review): none of them is applied in the cells visible here.
criterion_GAN = nn.MSELoss()
criterion_cycle, criterion_identity = nn.L1Loss(), nn.L1Loss()

In [None]:
# A single Adam optimizer updates both generators together; each
# discriminator gets its own.  All three share lr=2e-4 and betas=(0.5, 0.999).
optimizer_gen = torch.optim.Adam(
    chain.from_iterable(gen.parameters() for gen in (gen_a_b, gen_b_a)),
    lr=0.0002,
    betas=(0.5, 0.999),
)
optimizer_disc_a = torch.optim.Adam(
    disc_a.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)
optimizer_disc_b = torch.optim.Adam(
    disc_b.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)

In [None]:
# Run-wide settings.
BATCH_SIZE = 1
DEVICE = get_torch_device()
EPOCHS = 2  # NOTE(review): not referenced by any cell visible here

# Domain A.
# NOTE(review): absolute, machine-specific audio directory.
conf_a = Configuration(
    audio_dir='/home/kureta/Music/Palestrina - Missa Papæ Marcelli - Ensemble Officium, Wilfried Rombach/'
)
dataset_a = TemporalDataset(
    conf=conf_a,
    example_length=32,
    example_hop_length=4,
)
data_loader_a = DataLoader(
    dataset_a,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# Domain B, built with exactly the same dataset and loader settings.
# NOTE(review): absolute, machine-specific audio directory.
conf_b = Configuration(
    audio_dir='/home/kureta/Music/gertrude/'
)
dataset_b = TemporalDataset(
    conf=conf_b,
    example_length=32,
    example_hop_length=4,
)
data_loader_b = DataLoader(
    dataset_b,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

In [None]:
# Switch all four networks to training mode and move them to DEVICE.
gen_a_b.train()
gen_a_b.to(DEVICE)
gen_b_a.train()
gen_b_a.to(DEVICE)
disc_a.train()
disc_a.to(DEVICE)
disc_b.train()
disc_b.to(DEVICE)

# Stand-alone encoder / transformer / decoder instances.
# NOTE(review): nothing in the cells visible here uses these three.
enc = Encoder()
enc.to(DEVICE)
trans = Transformer()
trans.to(DEVICE)
dec = Decoder()
dec.to(DEVICE)

# Forward-pass smoke test: only the first pair of batches is processed
# (the inner loop breaks immediately); no loss is computed and no
# optimizer step is taken.
for i in range(1):
    for example_a, example_b in zip(data_loader_a, data_loader_b):
        example_a, example_b = example_a.to(DEVICE), example_b.to(DEVICE)

        # Translate each batch into the other domain.
        b_hat, a_hat = gen_a_b(example_a), gen_b_a(example_b)

        # The discriminators are 2-D convolutional, so a singleton channel
        # axis is inserted at position 1 before scoring real and generated
        # examples.
        real_a, fake_a = disc_a(example_a.unsqueeze(1)), disc_a(a_hat.unsqueeze(1))
        real_b, fake_b = disc_b(example_b.unsqueeze(1)), disc_b(b_hat.unsqueeze(1))
        break

# TODO: implement the actual training step (losses, backward pass, optimizer
# updates, progress reporting).  The commented-out sketch that used to live
# here was written for a different model taking spectrum / pitch /
# confidence / loudness inputs and a single optimizer.

In [None]:
# Compare the shape of the generated domain-A batch (output of gen_b_a) with
# the shape of the real domain-A batch left over from the loop above.
a_hat.shape, example_a.shape

In [None]:
# Shape of the real domain-A batch with a trailing singleton axis appended
# (the loop above inserts the extra axis at position 1 instead).
example_a.unsqueeze(-1).shape

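**Note:** the decoder and the encoder are not symmetrical unless every stride-2 `ConvTranspose1d` in the decoder uses `output_padding=1`. With `kernel_size=3, stride=2, padding=1` and no output padding, a transposed convolution maps a length of $L$ to $2L - 1$ instead of $2L$, so a 32-frame example is encoded as 32 → 16 → 8 frames but decoded as 8 → 15 → 29 frames, and the generator output no longer matches the shape of its input.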