In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Audio

In [None]:
from collections import OrderedDict

In [None]:
import numpy as np
import librosa

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchbearer
from torchbearer.cv_utils import DatasetValidationSplitter

In [None]:
from zachary.audio_data import AudioDataset
import zachary.transforms as transforms
from zachary.utils import get_torch_device, get_num_trainable_params, initialize_model

In [None]:
import random

BATCH_SIZE = 128
VALIDATION_SPLIT = 0.1
DEVICE = get_torch_device()

# Seed every RNG this notebook can draw from (Python `random`, NumPy, torch —
# torch.manual_seed also seeds CUDA) so that DataLoader shuffling, weight
# initialisation and the np.random GAN noise are repeatable under
# Restart & Run All. NOTE(review): whether DatasetValidationSplitter uses one of
# these RNGs is not visible here — confirm the split is reproducible too.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Project dataset (zachary.audio_data.AudioDataset) built with its default
# constructor arguments. Later cells index items as dataset[i][0] / dataset[i][1]
# and change `dataset.example_length` between the autoencoder and GAN phases.
dataset = AudioDataset()

In [None]:
# example_length appears to set the time length of each example; the GAN cells
# below assume the encoder output length is example_length - 2.
dataset.example_length = 3
splitter = DatasetValidationSplitter(len(dataset), VALIDATION_SPLIT)
train_dataset = splitter.get_train_dataset(dataset)
val_dataset = splitter.get_val_dataset(dataset)
# Only training benefits from shuffling; the validation loader is used for
# evaluation only, so a fixed order keeps validation passes repeatable.
traingen = torch.utils.data.DataLoader(train_dataset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
valgen = torch.utils.data.DataLoader(val_dataset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

In [None]:
plt.rcParams['figure.figsize'] = (26, 8)

# Index the dataset once: both panels then show the same example, and the item
# is only loaded/processed a single time.
sample = dataset[1000]

fig, (ax1, ax2) = plt.subplots(2, 1)
ax1.matshow(sample[0].numpy(), aspect='auto', interpolation='none', origin='lower')
ax1.set_title('Mu-Law encoded frame')

ax2.matshow(sample[1].numpy(), aspect='auto', interpolation='none', origin='lower')
ax2.set_title('Mu-Law expanded frame')
plt.show()

In [None]:
def conv_layers_weights_init(model):
    """Re-initialise convolution and batch-norm weights of ``model`` in place.

    * Every submodule whose class name contains ``'Conv'`` and that owns a weight
      tensor of rank >= 3 gets weights drawn from
      ``N(0, 1 / sqrt(weight.shape[0] * weight.shape[2]))``. For ``nn.Conv1d``
      this is ``out_channels * kernel_size``. For ``nn.ConvTranspose1d``,
      dimension 0 is ``in_channels``.
    * Wrapper modules whose name matches (e.g. ``SeparableConv1d``) have no
      ``weight`` of their own and are skipped. ``model.modules()`` visits their
      inner convolutions separately.
    * Every submodule whose class name contains ``'BatchNorm'`` gets its affine
      weight drawn from ``N(1, 0.02)`` and its bias zeroed. Non-affine batch
      norms (weight/bias ``None``) are left alone.

    Args:
        model: an ``nn.Module``; its parameters are modified in place.
    """
    for m in model.modules():
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)

        if 'Conv' in classname:
            if isinstance(weight, torch.Tensor) and weight.dim() >= 3:
                # torch.Size entries are plain ints, so the product is an int:
                # compute 1/sqrt(size) with ** rather than tensor methods.
                size = weight.shape[0] * weight.shape[2]
                with torch.no_grad():
                    weight.normal_(0.0, size ** -0.5)

        elif 'BatchNorm' in classname:
            bias = getattr(m, 'bias', None)
            with torch.no_grad():
                if weight is not None:
                    weight.normal_(1.0, 0.02)
                if bias is not None:
                    bias.fill_(0)

In [None]:
class SeparableConv1d(nn.Module):
    """Depthwise-separable 1-D convolution.

    A per-channel convolution (``groups=in_channels``, kept as ``conv1``) is
    followed by a 1x1 ``pointwise`` convolution that mixes channels into
    ``out_channels``. The attribute names form the ``state_dict`` keys, so they
    must stay stable for saved checkpoints.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()
        depthwise_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # Depthwise stage: channel count preserved, each channel filtered alone.
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, **depthwise_options)
        # Pointwise stage: 1x1 convolution that mixes channels.
        self.pointwise = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.conv1(x))

In [None]:
class SeparableConvTransposed1d(nn.Module):
    """Depthwise-separable 1-D transposed convolution.

    A per-channel transposed convolution (``groups=in_channels``, kept as
    ``conv1``) is followed by a 1x1 transposed ``pointwise`` convolution mapping
    to ``out_channels``. The attribute names form the ``state_dict`` keys, so
    they must stay stable for saved checkpoints.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()
        depthwise_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # Depthwise stage: channel count preserved, each channel upsampled alone.
        self.conv1 = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, **depthwise_options)
        # Pointwise stage: 1x1 transposed convolution that mixes channels.
        self.pointwise = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.conv1(x))

In [None]:
class Encoder(nn.Module):
    """Separable-convolution encoder: 1026 -> 512 -> 128 -> 32 channels.

    The first layer (kernel 3, padding 1) keeps the time length. Each of the two
    kernel-2 layers shortens it by one, so the output length is the input length
    minus 2. A final Sigmoid squashes the 32-channel output into (0, 1). The same
    class serves as the GAN discriminator.
    """

    def __init__(self):
        super().__init__()
        widths = [1026, 512, 128, 32]

        # Layer names are kept as-is because they appear in module paths
        # (the 'relu_*' entries are actually SELU activations).
        stages = OrderedDict()
        stages['conv1d_01'] = SeparableConv1d(widths[0], widths[1], kernel_size=3, padding=1)
        stages['relu_01'] = nn.SELU()
        stages['conv1d_02'] = SeparableConv1d(widths[1], widths[2], kernel_size=2)
        stages['relu_02'] = nn.SELU()
        stages['conv1d_03'] = SeparableConv1d(widths[2], widths[3], kernel_size=2)
        stages['sigmoid_01'] = nn.Sigmoid()

        self.block = nn.Sequential(stages)

    def forward(self, x):
        return self.block(x)

In [None]:
class Decoder(nn.Module):
    """Separable transposed-convolution decoder: 32 -> 128 -> 512 -> 1026 channels.

    This mirrors :class:`Encoder`. Each of the two kernel-2 transposed layers
    lengthens time by one, and the last layer (kernel 3, padding 1) keeps it, so
    the output length is the input length plus 2. There is no output activation.
    The same class serves as the GAN generator.
    """

    def __init__(self):
        super().__init__()
        widths = [32, 128, 512, 1026]

        # Layer names are kept as-is because they appear in module paths
        # (the 'relu_*' entries are actually SELU activations).
        stages = OrderedDict()
        stages['conv1d_01'] = SeparableConvTransposed1d(widths[0], widths[1], kernel_size=2)
        stages['relu_01'] = nn.SELU()
        stages['conv1d_02'] = SeparableConvTransposed1d(widths[1], widths[2], kernel_size=2)
        stages['relu_02'] = nn.SELU()
        stages['conv1d_03'] = SeparableConvTransposed1d(widths[2], widths[3], kernel_size=3, padding=1)

        self.block = nn.Sequential(stages)

    def forward(self, x):
        return self.block(x)

In [None]:
class AutoEncoder(nn.Module):
    """:class:`Encoder` followed by :class:`Decoder`; output has the input's shape.

    NOTE(review): the encoder and decoder are registered twice, as
    ``self.encoder`` / ``self.decoder`` and again inside ``self.block``. As a
    result, ``state_dict()`` lists each parameter under two keys. This is left
    unchanged so the checkpoint layout stays the same.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        pipeline = OrderedDict()
        pipeline['encoder'] = self.encoder
        pipeline['decoder'] = self.decoder
        self.block = nn.Sequential(pipeline)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

In [None]:
# Build the autoencoder, apply the project's initialiser (zachary.utils), move it
# to the training device and report its trainable parameter count.
ae = AutoEncoder()
initialize_model(ae)
ae = ae.to(DEVICE)  # nn.Module.to returns the same module
print(get_num_trainable_params(ae))

In [None]:
ae.train()

# Reconstruction training: MSE between the autoencoder output and the target.
trainable_params = [p for p in ae.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)
trial = torchbearer.Trial(ae, optimizer, F.mse_loss, metrics=['loss']).to(DEVICE)
trial.with_generators(train_generator=traingen, val_generator=valgen)
trial.run(epochs=1)

In [None]:
ae.eval()
# Temporarily use much longer examples to get a clip worth listening to. The
# original length is restored in `finally`, so a failure while indexing or in
# the forward pass cannot leave the dataset stuck at 1000 (which would affect
# any later use of `dataset`, including loaders built from it).
train_example_length = dataset.example_length
dataset.example_length = 1000
try:
    stft = dataset[0][0].unsqueeze(0).to(DEVICE)  # add a batch axis
    with torch.no_grad():
        stft_hat = ae(stft)
finally:
    dataset.example_length = train_example_length

In [None]:
# The model's channel axis stacks the real part of the STFT (first half) on top
# of the imaginary part (second half). Split it back into a complex spectrogram
# and invert it. Index [0] drops only the batch axis: .squeeze() would also drop
# the time axis if it had length 1.
stft_np = stft[0].cpu().numpy()
n_bins = stft_np.shape[0] // 2  # 1026 stacked rows -> 513 frequency bins
stft_np = stft_np[:n_bins, :] + 1j * stft_np[n_bins:, :]
audio_np = librosa.istft(stft_np, hop_length=512, center=False)

stft_hat_np = stft_hat[0].cpu().numpy()
stft_hat_np = stft_hat_np[:n_bins, :] + 1j * stft_hat_np[n_bins:, :]
audio_hat_np = librosa.istft(stft_hat_np, hop_length=512, center=False)

In [None]:
plt.rcParams['figure.figsize'] = (24, 8)

# Samples trimmed from each end before plotting (presumably to hide edge
# artefacts of the center=False inverse STFT). The SAME trim is applied to both
# signals so their x-axes stay aligned; trimming only the reconstruction shifted
# it by EDGE samples relative to the real signal.
EDGE = 1024

fig, (ax1, ax2) = plt.subplots(2, 1)
ax1.plot(audio_np[EDGE:-EDGE])
ax1.set_title('Real frame')

ax2.plot(audio_hat_np[EDGE:-EDGE])
ax2.set_title('Autoencoded frame')
plt.show()

In [None]:
# Edges trimmed as in the plot above. NOTE(review): 44100 Hz is assumed to be the
# dataset's sample rate — confirm against AudioDataset.
Audio(data=audio_hat_np[1024:-1024], rate=44100)

In [None]:
# NOTE(review): 44100 Hz is assumed to be the dataset's sample rate — confirm
# against AudioDataset.
Audio(data=audio_np, rate=44100)

In [None]:
# Longer examples for GAN training; the GAN cells below assume the
# discriminator (Encoder) output length is example_length - 2.
dataset.example_length = 13
splitter = DatasetValidationSplitter(len(dataset), VALIDATION_SPLIT)
train_dataset = splitter.get_train_dataset(dataset)
val_dataset = splitter.get_val_dataset(dataset)
# Only training benefits from shuffling; the validation loader is used for
# evaluation only, so a fixed order keeps validation passes repeatable.
traingen = torch.utils.data.DataLoader(train_dataset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
valgen = torch.utils.data.DataLoader(val_dataset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

In [None]:
# Binary cross-entropy applied to the discriminator's sigmoid outputs. The
# real/fake target tensors are built per batch in the loss callback.
adversarial_loss = nn.BCELoss()

In [None]:
# Keys through which GAN.forward hands tensors to the loss callback and the loss
# metrics via torchbearer's state dict (registered in this order).
GEN_IMGS, DISC_GEN, DISC_GEN_DET, DISC_REAL, G_LOSS, D_LOSS = (
    torchbearer.state_key(key_name)
    for key_name in ('gen_imgs', 'disc_gen', 'disc_gen_det', 'disc_real', 'g_loss', 'd_loss')
)

In [None]:
class GAN(nn.Module):
    """Generator/discriminator pair whose forward pass fills torchbearer state.

    ``generator`` is a :class:`Decoder` that maps noise of shape
    ``(batch, 32, encoding_length)`` to fake examples. ``discriminator`` is an
    :class:`Encoder` whose sigmoid outputs are read as "real" probabilities.

    ``forward`` returns ``None``. It stores ``GEN_IMGS``, ``DISC_GEN``,
    ``DISC_GEN_DET`` and ``DISC_REAL`` in ``state`` for the loss callback.

    NOTE(review): ``DISC_GEN`` is not detached from the discriminator, so when
    G_LOSS + D_LOSS is minimised by one optimiser over all parameters, the
    discriminator also receives the generator-loss gradient. Confirm this is
    intended.

    Args:
        encoding_length: time length of the noise fed to the generator.
    """

    def __init__(self, encoding_length):
        super().__init__()
        self.encoding_length = encoding_length
        self.discriminator = Encoder()
        self.generator = Decoder()

    def forward(self, real_imgs, state):
        # N(0, 1) noise sampled directly on the trial's device in the default
        # float dtype, instead of float64 NumPy samples built on the CPU and
        # copied over. 32 channels = the Decoder's input width.
        z = torch.randn(real_imgs.shape[0], 32, self.encoding_length, device=state[torchbearer.DEVICE])

        # Generator forward: fakes stay attached to the generator's graph so the
        # generator loss can backpropagate into it.
        state[GEN_IMGS] = self.generator(z)
        state[DISC_GEN] = self.discriminator(state[GEN_IMGS])
        # zero_grad() only resets the discriminator's accumulated .grad buffers.
        # It does not discard or detach the autograd graph built above.
        self.discriminator.zero_grad()

        # Discriminator forward: detached fakes, so the discriminator loss does
        # not flow back into the generator.
        state[DISC_GEN_DET] = self.discriminator(state[GEN_IMGS].detach())
        state[DISC_REAL] = self.discriminator(real_imgs)

In [None]:
@torchbearer.callbacks.add_to_loss
def loss_callback(state):
    """Compute the GAN losses and add their sum to torchbearer's loss.

    The discriminator is trained to output 1 on real examples and 0 on
    (detached) generated ones. The generator is trained to make the
    discriminator output 1 on its (attached) fakes (non-saturating loss).
    Stores ``state[G_LOSS]`` and ``state[D_LOSS]`` and returns their sum.

    Targets are built with ``*_like`` so their shape, dtype and device always
    match the discriminator outputs. They no longer depend on
    ``dataset.example_length`` or the global ``DEVICE``.
    """
    disc_real = state[DISC_REAL]
    disc_fake_detached = state[DISC_GEN_DET]
    disc_fake = state[DISC_GEN]

    # Real examples must be labelled 1 ("real"). Labelling them 0 teaches the
    # discriminator to call every input fake.
    real_loss = adversarial_loss(disc_real, torch.ones_like(disc_real))
    fake_loss = adversarial_loss(disc_fake_detached, torch.zeros_like(disc_fake_detached))

    state[G_LOSS] = adversarial_loss(disc_fake, torch.ones_like(disc_fake))
    state[D_LOSS] = (real_loss + fake_loss) / 2
    return state[G_LOSS] + state[D_LOSS]

In [None]:
@torchbearer.metrics.running_mean
@torchbearer.metrics.mean
class g_loss(torchbearer.metrics.Metric):
    """Metric 'g_loss': the generator loss written to ``state[G_LOSS]`` by the loss callback.

    The ``mean`` / ``running_mean`` decorators presumably aggregate the
    per-batch value into epoch and running averages.
    """

    def __init__(self):
        super().__init__('g_loss')

    def process(self, state):
        generator_loss = state[G_LOSS]
        return generator_loss

In [None]:
@torchbearer.metrics.running_mean
@torchbearer.metrics.mean
class d_loss(torchbearer.metrics.Metric):
    """Metric 'd_loss': the discriminator loss written to ``state[D_LOSS]`` by the loss callback.

    The ``mean`` / ``running_mean`` decorators presumably aggregate the
    per-batch value into epoch and running averages.
    """

    def __init__(self):
        super().__init__('d_loss')

    def process(self, state):
        discriminator_loss = state[D_LOSS]
        return discriminator_loss

In [None]:
# The generator (Decoder) lengthens time by 2, so noise of length
# example_length - 2 yields fakes as long as the real examples.
model = GAN(encoding_length=dataset.example_length - 2)

In [None]:
import os

# Resume from a previous run when a checkpoint exists. The checkpoint is only
# written by the last cell of this notebook, so on a fresh checkout the file is
# missing and training starts from scratch instead of crashing.
# map_location='cpu' lets a checkpoint saved from a GPU session load on a
# CPU-only machine; the model is moved to DEVICE when the trial is built.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
CHECKPOINT_PATH = 'model.pt'
if os.path.exists(CHECKPOINT_PATH):
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
else:
    print(f'No checkpoint at {CHECKPOINT_PATH!r}; training from scratch.')

In [None]:
model.train()
model.encoding_length = dataset.example_length - 2
optim = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))

# No criterion: the whole loss comes from loss_callback via add_to_loss, and
# pass_state=True makes torchbearer call model(x, state).
gan_metrics = ['loss', g_loss(), d_loss()]
live_plot = torchbearer.callbacks.live_loss_plot.LiveLossPlot(on_batch=True, max_cols=3, on_epoch=False)
torchbearertrial = torchbearer.Trial(
    model,
    optim,
    criterion=None,
    metrics=gan_metrics,
    callbacks=[loss_callback, live_plot],
    verbose=0,
    pass_state=True,
).to(DEVICE)
torchbearertrial.with_train_generator(traingen)
torchbearertrial.run(epochs=1)

In [None]:
model.eval()
# 32 channels = the generator's (Decoder's) input width. The generator consists
# only of (transposed) convolutions, so 13 latent frames give 15 output frames.
# NOTE(review): training used encoding_length = example_length - 2 latent
# frames; 13 here looks like a deliberate longer sample — confirm.
noise = np.random.normal(0, 1, (1, 32, 13))
z = torch.Tensor(noise).to(DEVICE)
with torch.no_grad():
    stft_hat = model.generator(z)

In [None]:
# Undo the channel stacking (real part in the first half of the channel axis,
# imaginary part in the second) and invert the STFT. Index [0] drops only the
# batch axis: .squeeze() would also drop the time axis if it had length 1.
stft_hat_np = stft_hat[0].cpu().numpy()
n_bins = stft_hat_np.shape[0] // 2  # 1026 stacked rows -> 513 frequency bins
stft_hat_np = stft_hat_np[:n_bins, :] + 1j * stft_hat_np[n_bins:, :]
audio_hat_np = librosa.istft(stft_hat_np, hop_length=512, center=False)

In [None]:
plt.rcParams['figure.figsize'] = (24, 4)

fig, ax1 = plt.subplots(1, 1)

ax1.plot(audio_hat_np[512:-512])
# This signal is a GAN generator sample, not an autoencoder reconstruction.
ax1.set_title('Generated frame')
plt.show()

In [None]:
# NOTE(review): this trims 128 samples per side while the plot above trims 512,
# and 44100 Hz is assumed to be the dataset's sample rate — confirm both.
Audio(data=audio_hat_np[128:-128], rate=44100)

In [None]:
# Persist only the GAN parameters (state_dict) to the file the load cell reads.
torch.save(obj=model.state_dict(), f="model.pt")