In [76]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import lmdb
import torchaudio
import librosa
from udls.generated import AudioExample
import IPython.display as ipd
import matplotlib.pyplot as plt

In [16]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device in use:", device)

Device in use: cuda


In [159]:
class AudioTransform():
    """Convert waveforms to power mel spectrograms and (approximately) back.

    ``forward`` computes an STFT, squares its magnitude and maps the result
    onto the mel scale. ``invert`` maps a mel spectrogram back to a linear
    power spectrogram and estimates a waveform with Griffin-Lim.

    Args:
        sample_rate: Sampling rate of the audio, in Hz.
        n_fft: FFT size; the linear spectrogram has ``n_fft // 2 + 1`` bins.
        hop_size: Number of samples between successive STFT frames.
        window_size: Length of the analysis window, in samples.
    """

    def __init__(self,
                 sample_rate: int= 44100,
                 n_fft: int = 512,
                 hop_size: int = 16,
                 window_size: int = 128
                 ):

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.window_size = window_size

        # Hann window: this is the window the STFT in `forward` has always
        # used, and it is also GriffinLim's default in `invert`. Storing it
        # here (instead of a never-used Hamming window) keeps the analysis and
        # synthesis sides consistent without changing default outputs.
        self.window = torch.hann_window(window_size)

    def forward(self, wav, **kwargs):
        """Return the power mel spectrogram of ``wav``.

        Args:
            wav: Waveform tensor accepted by ``torch.stft`` (1-D ``(time,)``
                or 2-D ``(batch, time)``).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Tensor whose last two dimensions are ``(n_mels, frames)``.
        """
        # Use the configured STFT parameters rather than hard-coded values so
        # that non-default instances behave as requested.
        stft = torch.stft(wav,
                          n_fft=self.n_fft,
                          hop_length=self.hop_size,
                          win_length=self.window_size,
                          # torch.stft needs the window on the input's device
                          window=self.window.to(wav.device),
                          return_complex=True)

        # power spectrogram (squared magnitude)
        spectrogram = torch.abs(stft) ** 2

        melscale_transform = torchaudio.transforms.MelScale(
            sample_rate=self.sample_rate,
            n_stft=self.n_fft // 2 + 1,
        ).to(spectrogram.device)

        return melscale_transform(spectrogram)

    def invert(self, spec, **kwargs):
        """Estimate a waveform from a power mel spectrogram.

        Args:
            spec: Mel spectrogram as produced by ``forward``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The waveform estimated by Griffin-Lim.
        """
        # NOTE(review): "gelsd" is documented as a CPU least-squares driver, so
        # `spec` presumably has to live on the CPU here - confirm.
        inverse_melscale_transform = torchaudio.transforms.InverseMelScale(
            sample_rate=self.sample_rate,
            n_stft=self.n_fft // 2 + 1,
            driver="gelsd",
        )
        spectrogram = inverse_melscale_transform(spec)

        # window_fn / power are spelled out (they equal the defaults) to make
        # explicit that they mirror `forward`: Hann window, squared magnitude.
        griffin_lim_transform = torchaudio.transforms.GriffinLim(
            n_fft=self.n_fft,
            win_length=self.window_size,
            hop_length=self.hop_size,
            window_fn=torch.hann_window,
            power=2.0,
        )

        return griffin_lim_transform(spectrogram)

    

In [160]:
# dataset definition
class LoopDataset(torch.utils.data.Dataset):
    """Fixed-length audio loops stored as ``AudioExample`` records in LMDB.

    Each item is a 1-D float tensor of ``SIZE_SAMPLES`` samples, obtained by
    dividing the stored int16 samples by ``2**15 - 1``.

    The LMDB environment is opened lazily and left out of the pickled state:
    an open environment handle cannot be pickled, which is what DataLoader
    worker processes started with "spawn" (the Windows default) require.
    Every process therefore opens its own handle on first access.
    """

    FS = 44100            # expected sampling rate of every stored buffer
    SIZE_SAMPLES = 65536  # expected number of samples per example

    def __init__(self, db_path: str) -> None:
        """Read the list of record keys from the LMDB database at ``db_path``."""
        super().__init__()

        self._db_path = db_path
        self._env = None  # opened on demand by the `env` property

        with self.env.begin(write=False) as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    @property
    def env(self):
        """LMDB environment for ``db_path``, opened on first access."""
        if self._env is None:
            self._env = lmdb.open(self._db_path, lock=False)
        return self._env

    @env.setter
    def env(self, value) -> None:
        # kept so that code assigning `dataset.env = ...` continues to work
        self._env = value

    def __getstate__(self):
        """Return the instance state without the process-local LMDB handle."""
        state = self.__dict__.copy()
        state["_env"] = None
        return state

    def __len__(self):
        """Number of records in the database."""
        return len(self.keys)

    def __getitem__(self, idx: int):
        """Decode record ``idx`` into a float waveform tensor.

        Raises:
            AssertionError: if the stored buffer is not int16, is not sampled
                at ``FS``, or does not contain ``SIZE_SAMPLES`` samples.
        """
        with self.env.begin(write=False) as txn:
            ae = AudioExample.FromString(txn.get(self.keys[idx]))

        buffer = ae.buffers["waveform"]
        assert buffer.precision == AudioExample.Precision.INT16
        assert buffer.sampling_rate == self.FS

        # Copy into a bytearray first: torch.frombuffer warns when handed a
        # read-only buffer, since tensors are expected to be writable.
        audio = torch.frombuffer(bytearray(buffer.data), dtype=torch.int16)
        audio = audio.float() / (2**15 - 1)
        assert len(audio) == self.SIZE_SAMPLES

        return audio


# load the dataset and split it into training / validation subsets
dataset = LoopDataset(db_path="../../../data/")
valid_ratio = 0.2
nb_valid = int(valid_ratio * len(dataset))
nb_train = len(dataset) - nb_valid
# A fixed generator seed keeps the split identical across kernel restarts, so
# validation examples do not end up in the training subset on a later run.
train_dataset, valid_dataset = torch.utils.data.random_split(
    dataset, [nb_train, nb_valid], generator=torch.Generator().manual_seed(0)
)

print(nb_train, nb_valid)

num_threads = 0  # != 0 crashes on windows o_o
batch_size = 128

# The training loader reshuffles every epoch (DataLoader defaults to
# shuffle=False); the validation loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_threads
)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_threads
)

# listen to 5 randomly chosen training examples
random_indices = np.random.randint(len(train_dataset), size=5)
for i in random_indices:
    print(f"example #{i}")
    audio_widget = ipd.Audio(train_dataset[i], rate=LoopDataset.FS)
    ipd.display(audio_widget)

9195 2298
example #852


example #860


example #436


example #7038


example #3125


In [161]:
def plot_spectrogram(stft, title="Spectrogram"):
    """Plot the magnitude of ``stft`` in decibels on a new matplotlib figure.

    Args:
        stft: 2-D tensor (real or complex). Rows are drawn along the vertical
            axis and columns along the horizontal axis; the axis labels
            assume the ``(frequency, frame)`` layout returned by
            ``torch.stft``.
        title: Title placed above the image.
    """
    # detach + cpu: `.numpy()` refuses tensors that track gradients or that
    # live on the GPU
    magnitude = stft.detach().abs().cpu()
    # 1e-8 avoids log10(0); the colour scale is clipped to [-60, 0] dB below
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    _, axis = plt.subplots(1, 1)
    axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
    axis.set_title(title)
    axis.set_xlabel("Frame")
    axis.set_ylabel("Frequency bin")
    plt.tight_layout()

In [169]:
# build the waveform <-> mel-spectrogram transform with its default settings
audio_process = AudioTransform()

In [170]:
# sanity check: shape of the transform output for the first training example
audio_process.forward(train_dataset[0]).shape



torch.Size([128, 4097])