In [None]:
import os

import numpy as np
import random

import datetime

In [None]:
# Paths of the three numbered mp3 parts (01..03) under ./librivox used as input audio
audio_filenames = [
    './librivox/guidetomen_%02d_rowland_64kb.mp3' % part_number
    for part_number in range(1, 4)
]
audio_filenames

In [None]:
import librosa
librosa.__version__  # '0.5.1'

In [None]:
# --- Audio front-end configuration ---
sample_rate = 24000  # Hz : input will be standardised to this rate

# STFT framing : given in seconds, then converted to whole sample counts
fft_step, fft_window = 12.5/1000., 50.0/1000.  # 12.5ms hop, 50ms window
hop_length, win_length = int(fft_step*sample_rate), int(fft_window*sample_rate)

n_fft = 4*512  # FFT size passed to the STFT (larger than win_length)

# Mel filterbank
n_mels, fmin = 80, 125  # number of mel bands, lowest frequency in Hz
#fmax = ~8000

# Floor applied to spectra before taking the log : np.log(0.01) ~ -4.605
#   ("Audio tests" suggest a min log of -4.605; -6 confirmed fine; np.exp(-7.0) also considered)
spectra_abs_min = 0.01 # From Google paper, seems justified

win_length, hop_length

In [None]:
# Training windowing / batching :
mel_samples, batch_size = 1024, 8  # timesteps per training window, windows per batch
epochs = 10

# The same seed is applied to both numpy's and the stdlib's global RNG
seed = 10
np.random.seed(seed)
random.seed(seed)

In [None]:
# pip install https://github.com/telegraphic/hickle/archive/dev.zip
import hickle as hkl

def audio_to_melspectrafile(audio_filepath, regenerate=False):
    """Convert one audio file into a cached hickle file of mel spectra.

    The output is written next to the input, with the audio extension replaced
    by '.melspectra.hkl'.  An existing output file is reused unless
    `regenerate` is True (note that files cached by an earlier version of this
    function are only refreshed when `regenerate=True` is passed).

    Entries stored in the hickle file :
      mels            : mel-scaled power spectrogram (channel axis first, then time)
      mel_log         : log( max(spectra_abs_min, |mels|) )
      spectra_complex : the complex STFT that the mels were computed from

    Parameters
    ----------
    audio_filepath : str
        Path of the audio file to convert (loaded at `sample_rate`).
    regenerate : bool
        When True, recompute and overwrite an existing spectra file.

    Returns
    -------
    str
        Path of the '.melspectra.hkl' file.
    """
    print("audio_to_melspectrafile(%s)" % (audio_filepath,))

    # splitext (rather than str.replace('.mp3', ...)) so that :
    #   * a non-mp3 input can never map onto itself and be overwritten by hkl.dump
    #   * a '.mp3' appearing elsewhere in the path is left alone
    melspectra_filepath = os.path.splitext(audio_filepath)[0] + '.melspectra.hkl'
    if os.path.isfile(melspectra_filepath) and not regenerate:
        print("  Already present")
        return melspectra_filepath

    samples, _sample_rate = librosa.core.load(audio_filepath, sr=sample_rate)

    # Force amplitude of waveform into range -1 ... +1.0 : scale by the largest
    # *absolute* sample (the negative peak may be the bigger one), and leave
    # silent/empty input untouched instead of dividing by zero.
    peak = np.max(np.abs(samples)) if samples.size else 0.0
    if peak > 0:
        samples = samples/peak

    spectra_complex = librosa.stft(samples, n_fft=n_fft,
                       hop_length=hop_length,
                       win_length=win_length, window='hann', )

    power_spectra = np.abs(spectra_complex)**2
    # sr must be passed explicitly : otherwise librosa builds the mel filterbank
    # for its default 22050Hz rather than the rate the audio was loaded at.
    melspectra = librosa.feature.melspectrogram(S=power_spectra, sr=sample_rate,
                                                n_mels=n_mels, fmin=fmin)

    # Floor the values before the log so that silence does not produce -inf
    mel_log = np.log( np.maximum(spectra_abs_min, np.abs(melspectra) ))

    # Shape of batches will be (Batch, MelsChannel, TimeStep) for PyTorch - no need for Transpose
    data = dict(
        mels = melspectra,
        mel_log = mel_log,
        spectra_complex = spectra_complex,
        #spectra_real = spectra_complex.real,
        #spectra_imag = spectra_complex.imag,
    )

    hkl.dump(data, melspectra_filepath, mode='w', compression='gzip')
    return melspectra_filepath

In [None]:
# Convert each audio file (or reuse its cached spectra file), keeping the order of audio_filenames
mel_filenames = list(map(audio_to_melspectrafile, audio_filenames))

In [None]:
# Don't see a clean way of shuffling without having loaded all the input first...

#class DatasetFromMelspectraFile(torch.utils.data.Dataset):
#    def __init__(self, melspectra_filepath):
#        super(DatasetFromMelspectraFile, self).__init__()
#        
#        data = hkl.load(melspectra_filepath)
#        self.mels = data['mels']
#
#    def __getitem__(self, index):
#        offset = index*mel_samples 
#        a = self.mels[:, offset:offset+mel_samples]
#        return a,a  # This is a VAE situation
#
#    def __len__(self):  
#        return self.mels.shape[1]//mel_samples
#    
#class DatasetFromFiles(torch.utils.data.Dataset):
#    def __init__(self, filepath_arr, length_arr):
#        super(DatasetFromFiles, self).__init__()
#        self.filepaths = filepath_arr
#        self.file_index, self.item_index = -1,-1
#        self.d = None
#        
#    def __getitem__(self, index):
#        self.item_index+=1
#        if self.d is None or self.item_index >= len(self.d):
#            self.file_index+=1
#            self.d = DatasetFromMelspectraFile(self.filepaths[self.file_index])
#            self.item_index=0
#        return d[self.item_index]
#
#    def __len__(self):  
#        #return len(self.filepaths)
#        return -1 # DUNNO

In [None]:
# This approach allows us to load the files into memory only as needed - 
#   But may not be necessary for our purposes, since the data is actually pretty small

def yield_batches_from(melspectra_filepath, bs=batch_size, shuffle=False):
    """Yield (input, target) batches of mel windows read from one spectra file.

    The 'mels' array of the file is cut into consecutive, non-overlapping
    windows of `mel_samples` timesteps; windows are grouped `bs` at a time.
    Input and target are the same array (autoencoder-style training).

    Parameters
    ----------
    melspectra_filepath : str
        hickle file containing a 'mels' entry indexed as [channel, timestep].
    bs : int
        Maximum number of windows per batch.
    shuffle : bool
        When True, the window order is shuffled (numpy global RNG).

    Yields
    ------
    (batch, batch) where batch has shape (k, n_mels, mel_samples); k == bs for
    every batch except possibly the last one, which holds the leftover windows.

    NOTE : every yielded batch is a view onto one buffer that is overwritten on
    the next iteration - copy it if it must outlive the loop step.
    """
    data = hkl.load(melspectra_filepath)
    mels = data['mels']
    # '+1' keeps the final window when the length is an exact multiple of mel_samples;
    # files shorter than one window simply produce no offsets (and no batches)
    offsets = np.arange(0, mels.shape[1]-mel_samples+1, mel_samples)
    print("Batches from file : ", melspectra_filepath, mels.shape, offsets.shape)
    if shuffle:
        np.random.shuffle(offsets)  # in-place
    batch_x = np.zeros( shape=(bs, n_mels, mel_samples) )  # Allocate once
    for batch_idx in range(0, offsets.shape[0], bs):
        # Offsets belonging to *this* batch (was always the first `bs` offsets)
        batch_offsets = offsets[batch_idx:batch_idx+bs]
        for i, offset in enumerate(batch_offsets):
            batch_x[i, :, :] = mels[:, offset:offset+mel_samples]
        batch = batch_x[:len(batch_offsets)]  # shorter than bs only for the last batch
        yield batch, batch # input -> target
    # Stop

def yield_batches_from_files(filepaths, bs=batch_size, shuffle=False, shuffle_within=False):
    """Chain the batches of several spectra files into one generator.

    filepaths      : sequence of spectra file paths
    bs             : batch size handed to yield_batches_from
    shuffle        : visit the files in a random order (the caller's list is not modified)
    shuffle_within : shuffle the windows inside each file

    Only one file is loaded at a time.
    """
    # random.sample returns a new shuffled list, leaving `filepaths` itself unchanged
    visit_order = random.sample(filepaths, len(filepaths)) if shuffle else filepaths
    for path in visit_order:
        yield from yield_batches_from(path, bs=bs, shuffle=shuffle_within)

# This is how this code looks when used :
#for epoch in range(epochs):
#    t0 = datetime.datetime.now()
#    train_batcher = yield_batches_from_files(mel_filenames, bs=batch_size, shuffle=True, shuffle_within=True)
#    for batch_idx, batch in enumerate(train_batcher):
#        input, target = batch
#        ...

In [None]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data # required

In [None]:
# Disabled scratch cell (never executed) : a small worked example of the
# transpose -> view -> transpose sequence that reshapes a (channels, time)
# array into (blocks, channels, block_len).  Here channels=3, time=10, block_len=5.
if False:  # Test ops to get correct Tensor format
    t = torch.from_numpy(np.array([[10,11,12,13,14,15,16,17,18,19], 
                                   [20,21,22,23,24,25,26,27,28,29], 
                                   [30,31,32,33,34,35,36,37,38,39]
                                  ]))
    t
    # A plain view does not split along time per channel, hence the transposes below
    #t.view(2,3,5)
    # (3,10) -> (10,3) -> (2,5,3) -> (2,3,5)
    t.transpose(0,1).contiguous().view(2,5,3).transpose(1,2)

    # Want to convert long set of mels into batches of length mel_samples:
    # 0 :
    #   10    11    12    13    14
    #   20    21    22    23    24
    #   30    31    32    33    34
    # 1 :
    #   15    16    17    18    19
    #   25    26    27    28    29
    #   35    36    37    38    39

def TensorFromMelspectraFile(melspectra_filepath, block_len=mel_samples):
    """Load the 'mel_log' entry of a spectra file as a tensor of equal-length blocks.

    The time axis is cut into consecutive blocks of `block_len` timesteps; any
    remainder at the end of the file that does not fill a whole block is dropped.

    Parameters
    ----------
    melspectra_filepath : str
        hickle file containing a 'mel_log' array indexed as [channel, timestep].
    block_len : int or None
        Timesteps per block.  None means 'whole of file' : a single block
        covering every timestep.

    Returns
    -------
    torch tensor of shape (n_blocks, channels, block_len)

    Raises
    ------
    ValueError
        If an explicit `block_len` is not positive.
    """
    data = hkl.load(melspectra_filepath)
    mel_log = data['mel_log']
    n_channels, n_steps = mel_log.shape[0], mel_log.shape[1]

    if block_len is None: # Allow for 'whole of file' tensor(1,mels,everything)
        block_len = n_steps
    elif block_len <= 0:
        raise ValueError("block_len must be positive (or None), got %r" % (block_len,))

    # Everything below uses block_len (not the global mel_samples), so that the
    # block count, the truncation and the view always agree with each other.
    n_blocks = n_steps//block_len if block_len else 0  # block_len==0 only for an empty file
    print("Read %5d log(mel[%2d]) = %4d blocks from %s" %
          (n_steps, n_channels, n_blocks, melspectra_filepath,))

    # (channels, time) -> (time, channels), so that a view can split along time
    mel_log_trunc_t = mel_log[:, :n_blocks*block_len ].T
    #print(torch.from_numpy(mel_log_trunc_t).contiguous().size())
    return ( torch.from_numpy(mel_log_trunc_t).contiguous()
             .view(n_blocks, block_len, n_channels).transpose(1,2))

In [None]:
# One dataset per spectra file; input and target are the same tensor of blocks
mel_datasets = [
    torch.utils.data.TensorDataset(blocks, blocks)
    for blocks in map(TensorFromMelspectraFile, mel_filenames)
]

In [None]:
# Present the per-file datasets as a single dataset indexed end-to-end
mel_dataset = torch.utils.data.ConcatDataset(datasets=mel_datasets)

In [None]:
class WaveNettyCell(torch.nn.Module):
    """Gated, dilated 1-D convolution cell with a residual connection.

    Two parallel unpadded convolutions produce a 'gate' and a 'signal';
    sigmoid(gate) * tanh(signal) is mapped back to `in_channels` by a 1x1
    convolution, zero-padded on the right back to the input length, and added
    to the input.  There is no side/skip output.

    Parameters
    ----------
    in_channels : int
        Channels of the input (and of the output).
    hidden_channels : int
        Channels of the gate/signal branches.
    cond_channels : int
        Channels of an optional conditioning input; 0 disables conditioning.
    kernel_size, stride, dilation : int
        Passed to the gate/signal convolutions.  The residual addition only
        lines up when stride == 1 (the right-padding does not account for stride).
    """
    def __init__(self, in_channels, hidden_channels, cond_channels=0, 
                 kernel_size=3, stride=1, dilation=1):
        super(WaveNettyCell, self).__init__()
        
        self.gate   = torch.nn.Conv1d(in_channels, hidden_channels, 
                                    kernel_size=kernel_size, 
                                    stride=stride, dilation=dilation, 
                                    padding=0, groups=1, bias=True)
        self.signal = torch.nn.Conv1d(in_channels, hidden_channels, 
                                    kernel_size=kernel_size, 
                                    stride=stride, dilation=dilation, 
                                    padding=0, groups=1, bias=True)
        
        self.cond = cond_channels>0
        if self.cond:
            self.gate_cond   = torch.nn.Conv1d(cond_channels, hidden_channels, kernel_size=1, bias=False)
            self.signal_cond = torch.nn.Conv1d(cond_channels, hidden_channels, kernel_size=1, bias=False)

        self.recombine = torch.nn.Conv1d(hidden_channels, in_channels, 
                                    kernel_size=1, stride=1, dilation=1, 
                                    padding=0, groups=1, bias=True)
        
        # (left, right) zero-padding that restores the timesteps removed by the
        # unpadded convolution
        self.padding = (0, (kernel_size-1)*dilation)
            
    def forward(self, input, condition=None):
        """Apply the cell.

        input     : tensor (batch, in_channels, time)
        condition : tensor (batch, cond_channels, time); required when the cell
                    was built with cond_channels > 0, ignored otherwise

        Returns a tensor with the same shape as `input`.  The final
        (kernel_size-1)*dilation timesteps receive only the zero padding, i.e.
        they are the input passed through unchanged.
        """
        gate = self.gate(input)
        signal = self.signal(input)
        if self.cond:
            if condition is None:
                raise ValueError("WaveNettyCell built with cond_channels>0 requires a condition input")
            # The unpadded convolutions shorten the time axis, with output step t
            # starting at input step t : drop the trailing conditioning steps so
            # that the 1x1 projections line up with gate/signal.
            condition = condition[..., :gate.size(-1)]
            gate   = gate   + self.gate_cond(condition)
            signal = signal + self.signal_cond(condition)

        # torch.sigmoid / torch.tanh : the F.* versions are deprecated
        mult = torch.sigmoid(gate) * torch.tanh(signal)
        
        # Yes : There's no side/skip here : It's just a fancy feed-forward
        return input + F.pad( self.recombine(mult), self.padding)

In [None]:
class VQ_encoder(torch.nn.Module):
    """Encoder : five WaveNettyCells with dilations 1,2,4,8,16 applied in sequence."""
    def __init__(self, in_channels, hidden_channels=128):
        super().__init__()

        # See https://fomoro.com/tools/receptive-fields/
        #
        # Strided alternative (not used) :
        #   #3,2,1,VALID;3,2,1,VALID;3,2,1,VALID;3,2,1,VALID;3,2,1,VALID
        #   [ WaveNettyCell(in_channels, hidden_channels, stride=2) for c in range(4) ]
        #
        # Dilated stack (used) :
        #   #3,1,1,VALID;3,1,2,VALID;3,1,4,VALID;3,1,8,VALID;3,1,16,VALID
        #   receptive field = 63 timesteps
        cells = []
        for dilation in (1, 2, 4, 8, 16):
            cells.append(WaveNettyCell(in_channels, hidden_channels, dilation=dilation))
        self.conv = torch.nn.ModuleList(cells)

    def forward(self, input):
        """Pass `input` through every cell in order and return the result."""
        hidden = input
        for cell in self.conv:
            hidden = cell(hidden)
        return hidden

In [None]:
class VQ_quantiser(torch.nn.Module):
    """Placeholder quantiser : currently an identity mapping with no parameters."""
    def __init__(self, n_symbols, latent_dimension):
        # Both arguments are accepted but not used yet
        super().__init__()

    def forward(self, input):
        """Return (input, symbols), where symbols is always an empty list."""
        symbols = []  # Doesn't do quantisation yet...
        return input, symbols

In [None]:
class VQ_decoder(torch.nn.Module):
    """Decoder : five WaveNettyCells with dilations 1,2,4,8,16 applied in sequence.

    `latent_channels` and the `latent` argument of forward() are accepted but
    not used : conditioning the cells on the latent is currently disabled.
    """
    def __init__(self, in_channels, latent_channels=0, hidden_channels=128):
        super().__init__()

        self.conv = torch.nn.ModuleList()
        for dilation in (1, 2, 4, 8, 16):
            # To enable conditioning, also pass : cond_channels=latent_channels
            self.conv.append(
                WaveNettyCell(in_channels, hidden_channels, dilation=dilation)
            )

        #self.c1 = WaveNettyCell(in_channels, hidden_channels, dilation=1)

    def forward(self, input, latent=None):
        """Pass `input` through every cell in order and return the result."""
        out = input
        for cell in self.conv:
            # Conditioned version would be : out = cell(out, latent)
            out = cell(out)
        return out

In [None]:
# Choose the tensor type that batches are cast to : CUDA floats when a GPU is available
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor
use_cuda

In [None]:
class VQ_VAE_Model(torch.nn.Module):
    """Encoder -> quantiser -> decoder model that carries its own Adam optimizer."""
    def __init__(self):
        super().__init__()
        #self.name=name

        self.channels = n_mels
        self.n_symbols = 64

        self.encoder = VQ_encoder(self.channels)
        self.quant   = VQ_quantiser(self.n_symbols, self.channels)
        self.decoder = VQ_decoder(self.channels)

        print(f"Number of parameter variables : {len(list(self.parameters()))}")

        # Created last, so that it covers the parameters of all three sub-modules
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.01)

    def forward(self, input):
        """Return (reconstruction, symbols) for a batch of inputs."""
        encoded = self.encoder(input)
        quantised, symbols = self.quant(encoded)
        return self.decoder(quantised), symbols

    def train_(self, input, target):
        """Run one optimisation step on a batch; returns the MSE loss tensor."""
        self.train()  # Set mode
        self.optimizer.zero_grad()
        reconstruction, _ = self(input)
        loss = F.mse_loss(reconstruction, target)
        loss.backward()
        self.optimizer.step()
        return loss

    def test_(self, input):
        """Switch to eval mode, run a forward pass and return the symbols."""
        self.eval()
        _, symbols = self(input)
        return symbols

    def save(self, save_template, epoch):
        """Save the state_dict to the path given by save_template.format(epoch)."""
        #torch.save(self.state_dict(), 'model/epoch_{}_{:02d}.pth'.format(self.name, epoch))
        state = self.state_dict()
        torch.save(state, save_template.format(epoch))



In [None]:
# Build the model, moving it to the GPU when CUDA is available
model = VQ_VAE_Model()
model = model.cuda() if use_cuda else model

In [None]:
# http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

# Training : each epoch builds a fresh shuffling loader over all blocks and
# takes one optimisation step per batch, printing the loss scaled by 1e6.
for epoch in range(epochs):
    t0 = datetime.datetime.now()  # epoch start time (not reported anywhere yet)
    train_batches = torch.utils.data.DataLoader(
        mel_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
    )
    for batch_idx, batch in enumerate(train_batches):
        input, target = batch

        # Cast to the CPU/GPU float type selected earlier
        x, y = Variable( input.type(dtype) ), Variable( target.type(dtype) )
        mse = model.train_(x, y)

        print(f"Epoch {epoch:2}, Batch {batch_idx:2}, %.6f" % (float(mse*1000*1000),))
        
        #print('Train Epoch: {:2d} [{:6d}/{:6d} ({:3.0f}%)] Non-relations accuracy: {:3.0f}% | Relations accuracy: {:3.0f}% | Tricky accuracy: {:3.0f}% | '.format(
        #        epoch, batch_idx * bs * example_factor, 
        #        len(norel[0]) * example_factor, 
        #        100. * batch_idx * bs/ len(norel[0]), 
        #        accuracy_norel, accuracy_birel, accuracy_trirel, 
        #     ))        


In [None]:
# f'{234.3453453453434534:6.2}'  Weird choice for format specifiers : overall_width.digits_of_precision