In [None]:
import os

import numpy as np
import random

import torch
import torch.functional as F
#import torch.utils.data # required

In [None]:
# Paths of the three numbered LibriVox MP3 files (parts 01..03) used as input audio.
audio_filenames = []
for part_number in range(1, 4):
    audio_filenames.append('./librivox/guidetomen_%02d_rowland_64kb.mp3' % (part_number,))
del part_number  # keep the notebook namespace the same as the comprehension version
audio_filenames

In [None]:
import librosa
librosa.__version__  # '0.5.1'

In [None]:
# ---- Audio analysis settings (shared by every cell below) ----

# All input audio is resampled to this rate (Hz) when it is loaded
sample_rate = 24000

# STFT framing, expressed in seconds
fft_step   = 12.5/1000.  # 12.5ms between the starts of successive frames
fft_window = 50.0/1000.  # 50ms analysis window

# FFT size, in samples
n_fft = 512*4

# The same framing, converted from seconds into whole numbers of samples
hop_length = int(sample_rate*fft_step)
win_length = int(sample_rate*fft_window)

# Mel filterbank : number of bands, and lowest frequency covered
n_mels = 80
fmin = 125 # Hz
#fmax = ~8000

# Floor for spectral magnitudes - value taken from the Google paper, seems justified
#np.exp(-7.0), np.log(spectra_abs_min)  # "Audio tests" suggest a min log of -4.605 (-6 confirmed fine)
spectra_abs_min = 0.01

# Show the derived frame sizes as the cell output
win_length, hop_length

In [None]:
# And for the training windowing :
mel_samples  = 1024  # step (in mel-spectrogram frames) between training window offsets
batch_size   = 8     # default number of offsets grouped into one batch

epochs = 10

seed = 10

# Seed every random number generator this notebook relies on, so that a
# "Restart Kernel -> Run All" reproduces the same shuffling *and* the same
# initial model weights (PyTorch has its own generator, separate from
# Python's `random` and from NumPy).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# pip install https://github.com/telegraphic/hickle/archive/dev.zip
import hickle as hkl

def audio_to_melspectrafile(audio_filepath, regenerate=False):
    """Convert one audio file into a cached mel-spectrogram file (hickle format).

    The audio is loaded at the notebook-level `sample_rate`, peak-normalised,
    transformed with an STFT (`n_fft`, `hop_length`, `win_length`, Hann window)
    and reduced to `n_mels` mel bands starting at `fmin`.

    The output file stores a dict with two entries:
      'mels'            : mel power spectrogram, as returned by
                          librosa.feature.melspectrogram
      'spectra_complex' : the complex STFT it was derived from

    Args:
        audio_filepath: path of the audio file to convert.
        regenerate: when True, rebuild the output even if it already exists.

    Returns:
        Path of the '.melspectra.hkl' file (same location and stem as the input).

    Raises:
        ValueError: if the audio file contains no samples.
    """
    print("audio_to_melspectrafile(%s)" % (audio_filepath,))

    # Swap only the extension.  A plain str.replace('.mp3', ...) would leave the
    # path unchanged for any non-mp3 input, and the dump below would then
    # overwrite the audio file itself.
    melspectra_filepath = os.path.splitext(audio_filepath)[0] + '.melspectra.hkl'
    if os.path.isfile(melspectra_filepath) and not regenerate:
        print("  Already present")
        return melspectra_filepath

    samples, _sample_rate = librosa.core.load(audio_filepath, sr=sample_rate)
    if samples.size == 0:
        raise ValueError("No audio samples found in %s" % (audio_filepath,))

    # Force amplitude of waveform into range -1 ... +1.0 : scale by the largest
    # *magnitude* (the largest positive sample alone would let negative peaks
    # exceed -1), and leave pure silence untouched rather than dividing by zero.
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples = samples/peak

    spectra_complex = librosa.stft(samples, n_fft=n_fft, 
                       hop_length=hop_length, 
                       win_length=win_length, window='hann', )

    power_spectra = np.abs(spectra_complex)**2
    # sr must be given explicitly : otherwise the mel filterbank is built for
    # librosa's default rate (22050Hz) instead of the rate the audio was loaded at.
    melspectra = librosa.feature.melspectrogram(S=power_spectra, sr=sample_rate,
                                                n_mels=n_mels, fmin=fmin)

    # Shape of batches will be (Batch, MelsChannel, TimeStep) for PyTorch - no need for Transpose
    data = dict( 
        mels = melspectra,
        spectra_complex = spectra_complex,
        #spectra_real = spectra_complex.real, 
        #spectra_imag = spectra_complex.imag, 
    )
    
    hkl.dump(data, melspectra_filepath, mode='w', compression='gzip')
    return melspectra_filepath

In [None]:
# One mel-spectrogram file path per audio file, in the same order as audio_filenames
mel_filenames = list(map(audio_to_melspectrafile, audio_filenames))

In [None]:
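# NOTE: Abandoned torch.utils.data.Dataset attempt, kept for reference only : there is no clean way of shuffling across files without having loaded all the input first, so the generator-based batchers defined further down are used instead.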

#class DatasetFromMelspectraFile(torch.utils.data.Dataset):
#    def __init__(self, melspectra_filepath):
#        super(DatasetFromMelspectraFile, self).__init__()
#        
#        data = hkl.load(melspectra_filepath)
#        self.mels = data['mels']
#
#    def __getitem__(self, index):
#        offset = index*mel_samples 
#        a = self.mels[:, offset:offset+mel_samples]
#        return a,a  # This is a VAE situation
#
#    def __len__(self):  
#        return self.mels.shape[1]//mel_samples
#    
#class DatasetFromFiles(torch.utils.data.Dataset):
#    def __init__(self, filepath_arr, length_arr):
#        super(DatasetFromFiles, self).__init__()
#        self.filepaths = filepath_arr
#        self.file_index, self.item_index = -1,-1
#        self.d = None
#        
#    def __getitem__(self, index):
#        self.item_index+=1
#        if self.d is None or self.item_index >= len(self.d):
#            self.file_index+=1
#            self.d = DatasetFromMelspectraFile(self.filepaths[self.file_index])
#            self.item_index=0
#        return d[self.item_index]
#
#    def __len__(self):  
#        #return len(self.filepaths)
#        return -1 # DUNNO

In [None]:
def yield_batches_from(melspectra_filepath, bs=batch_size, shuffle=False):
    """Yield batches of fixed-length mel-spectrogram windows from one file.

    The 'mels' array stored in the file is cut into consecutive, non-overlapping
    windows of `mel_samples` frames.  A trailing remainder shorter than a full
    window is dropped, so that every window in a batch has the same length.

    Args:
        melspectra_filepath: path of a hickle file containing a 'mels' array
            whose second axis is the one cut into windows.
        bs: maximum number of windows per batch (the last batch may be smaller).
        shuffle: when True, visit the windows in a random order
            (uses NumPy's global random state).

    Yields:
        numpy array of shape (windows_in_batch, mels.shape[0], mel_samples),
        i.e. (Batch, MelsChannel, TimeStep).
    """
    data = hkl.load(melspectra_filepath)
    mels = data['mels']

    # Start offsets of the *complete* windows only.  (Indexing mels with the
    # offsets array directly would pick out single frames, not windows.)
    offsets = np.arange(0, mels.shape[1] - mel_samples + 1, mel_samples)
    print("Batches from file : ", melspectra_filepath, mels.shape, offsets.shape)
    if shuffle:
        np.random.shuffle(offsets)  # in-place
    for i in range(0, offsets.shape[0], bs):
        windows = [ mels[:, start : start+mel_samples] for start in offsets[i : i+bs] ]
        yield np.stack(windows, axis=0)
    # Stop

def yield_batches_from_files(filepaths, bs=batch_size, shuffle=False, shuffle_within=False):
    """Yield the batches of every file in `filepaths`, one file after another.

    Args:
        filepaths: list of mel-spectrogram file paths.
        bs: batch size handed on to yield_batches_from().
        shuffle: when True, visit the files in a random order.
        shuffle_within: handed on to yield_batches_from() as its `shuffle` flag.

    Yields:
        Whatever yield_batches_from() yields for each file, in turn.
    """
    # random.sample builds a new shuffled list, so the caller's list keeps its order
    # (random.shuffle would have reordered it in-place)
    visit_order = random.sample( filepaths, len(filepaths) ) if shuffle else filepaths
    for path in visit_order:
        for batch in yield_batches_from(path, bs=bs, shuffle=shuffle_within):
            yield batch
    # Stop

In [None]:
#random.shuffle(filenames)
#filenames

In [None]:
class WaveNettyCell(torch.nn.Module):
    """Gated 1-D convolution cell (WaveNet-like) with a residual connection.

    Computes ``input + recombine(sigmoid(gate(input)) * tanh(signal(input)))``,
    optionally adding 1x1-convolved conditioning to both the gate and the signal
    before the non-linearities.

    The gate/signal convolutions are unpadded ('VALID'), so the result is
    shorter along the time axis than the input : with stride=1 it loses
    ``dilation*(kernel_size-1)`` time-steps.  The residual input (and the
    conditioning, when it has the input's length) is therefore cropped/strided
    so that each output step is added to the input step at the centre of its
    receptive field.

    Args:
        in_channels: channels of the input (and of the output).
        hidden_channels: channels inside the gated unit.
        cond_channels: channels of the conditioning input; 0 disables conditioning.
        kernel_size, stride, dilation: settings of the gate/signal convolutions.
    """
    def __init__(self, in_channels, hidden_channels, cond_channels=0, 
                 kernel_size=3, stride=1, dilation=1):
        super(WaveNettyCell, self).__init__()
        
        # Remembered so that forward() can line the residual up with the conv output
        self.stride = stride
        self.trim_start = (dilation*(kernel_size-1))//2
        
        self.gate   = torch.nn.Conv1d(in_channels, hidden_channels, 
                                    kernel_size=kernel_size, 
                                    stride=stride, dilation=dilation, 
                                    padding=0, groups=1, bias=True)
        self.signal = torch.nn.Conv1d(in_channels, hidden_channels, 
                                    kernel_size=kernel_size, 
                                    stride=stride, dilation=dilation, 
                                    padding=0, groups=1, bias=True)
        
        self.cond = cond_channels>0
        if self.cond:
            self.gate_cond   = torch.nn.Conv1d(cond_channels, hidden_channels, kernel_size=1, bias=False)
            self.signal_cond = torch.nn.Conv1d(cond_channels, hidden_channels, kernel_size=1, bias=False)

        self.recombine = torch.nn.Conv1d(hidden_channels, in_channels, 
                                    kernel_size=1, stride=1, dilation=1, 
                                    padding=0, groups=1, bias=True)

    def _match_time(self, x, length):
        """Return `x` (Batch, Channel, Time) reduced to `length` time-steps,
        sampled at the centres of the convolution's receptive fields."""
        if x.size(2) == length:
            return x  # already aligned (e.g. pre-cropped conditioning)
        return x[:, :, self.trim_start::self.stride][:, :, :length]
            
    def forward(self, input, condition=None):
        """Apply the cell.

        Args:
            input: tensor of shape (Batch, in_channels, Time).
            condition: tensor of shape (Batch, cond_channels, Time') - required
                when the cell was built with cond_channels>0.  Time' must equal
                either the input's or the output's number of time-steps.

        Returns:
            Tensor of shape (Batch, in_channels, Time_out).

        Raises:
            ValueError: if conditioning is enabled but `condition` is None.
        """
        gate = self.gate(input)
        signal = self.signal(input)
        out_length = gate.size(2)
        
        if self.cond:
            if condition is None:
                raise ValueError("WaveNettyCell built with cond_channels>0 needs a `condition` input")
            condition = self._match_time(condition, out_length)
            gate   = gate   + self.gate_cond(condition)
            signal = signal + self.signal_cond(condition)

        # torch.sigmoid/torch.tanh : the notebook's `F` alias is torch.functional,
        # which does not provide these activations
        mult = torch.sigmoid(gate) * torch.tanh(signal)
        
        # Yes : There's no side/skip here : It's just a fancy feed-forward
        return self._match_time(input, out_length) + self.recombine(mult)



In [None]:
class VQ_encoder(torch.nn.Module):
    """Encoder : a stack of five WaveNettyCell layers with dilations 1,2,4,8,16.

    Args:
        in_channels: channels of the input, passed to every cell.
        hidden_channels: hidden width passed to every cell.
    """
    def __init__(self, in_channels, hidden_channels=128):
        super(VQ_encoder, self).__init__()
        
        # See https://fomoro.com/tools/receptive-fields/
        
        #   #3,2,1,VALID;3,2,1,VALID;3,2,1,VALID;3,2,1,VALID;3,2,1,VALID
        #self.conv = [ WaveNettyCell(in_channels, hidden_channels, 
        #                            stride=2) for c in range(4) ]
            
        #   #3,1,1,VALID;3,1,2,VALID;3,1,4,VALID;3,1,8,VALID;3,1,16,VALID
        #   receptive field = 63 timesteps
        # ModuleList (rather than a plain Python list) registers the cells as
        # submodules, so their weights appear in .parameters()/.state_dict()
        # and follow the model through .to()/.cuda()/.train()/.eval().
        self.conv = torch.nn.ModuleList(
            [ WaveNettyCell(in_channels, hidden_channels, 
                            dilation=d) for d in [1,2,4,8,16] ]
        )
            
    def forward(self, input):
        """Pass `input` through each cell in turn and return the final result."""
        x = input
        for c in self.conv:
            x = c(x)
        return x

In [None]:
class VQ_quantiser(torch.nn.Module):
    """Placeholder for the vector-quantisation stage.

    For now this is an identity mapping : the constructor arguments are
    accepted but not used, and no parameters are created.
    """
    def __init__(self, in_channels, hidden_channels, cond_channels=0, 
                 kernel_size=3, stride=1, dilation=1):
        torch.nn.Module.__init__(self)
            
    def forward(self, input):
        """Return `input` unchanged."""
        # Doesn't do quantisation yet...
        passthrough = input
        return passthrough

In [None]:
class VQ_decoder(torch.nn.Module):
    """Decoder : a stack of five WaveNettyCell layers with dilations 1,2,4,8,16.

    Conditioning on the latent is not wired in yet (see the commented-out
    lines), so `latent_channels` and the `latent` argument of forward() are
    currently accepted but unused.

    Args:
        in_channels: channels of the input, passed to every cell.
        latent_channels: channels of the latent (unused for now).
        hidden_channels: hidden width passed to every cell.
    """
    def __init__(self, in_channels, latent_channels, hidden_channels=128):
        super(VQ_decoder, self).__init__()
        
        # ModuleList (rather than a plain Python list) registers the cells as
        # submodules, so their weights appear in .parameters()/.state_dict()
        # and follow the model through .to()/.cuda()/.train()/.eval().
        self.conv = torch.nn.ModuleList(
            [ WaveNettyCell(in_channels, hidden_channels, 
                            #cond_channels=latent_channels,
                            dilation=d) for d in [1,2,4,8,16] ]
        )
            
    def forward(self, input, latent):
        """Pass `input` through each cell in turn; `latent` is ignored for now."""
        x = input
        for c in self.conv:
            #x = c(x, latent)
            x = c(x)
        return x

In [None]:
# Training batches : both the file order and the order within each file are shuffled.
# NOTE: this is a generator, so it can be iterated only once - it needs to be
# re-created for every epoch.
train_batcher = yield_batches_from_files(
    mel_filenames,
    bs=batch_size,
    shuffle=True,
    shuffle_within=True,
)

#for epoch in range(epochs):
#    for batch in train_batcher:
#        pass