In [94]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
import json
import os

In [144]:
def complex_to_tensor_2d(complex_tensor):
    _, width, height = complex_tensor.shape
    tensor_2d = torch.view_as_real(complex_tensor).reshape(2, width, height)
    return tensor_2d

def tensor_2d_to_complex(tensor_2d):
    width, height, _ = tensor_2d.shape
    complex_tensor = torch.view_as_real(tensor_2d).reshape(width, height, 2)
    return complex_tensor


In [118]:
class NSynth(Dataset):
    def __init__(self, annotations_path, audio_path, target_sr, number_of_samples, transform):
        with open(annotations_path, 'r') as f:
            self.annotations = json.load(f)
        self.audio_path = audio_path
        self.target_sr = target_sr
        self.number_of_samples = number_of_samples
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        audio_sample_path = self._get_audio_sample_path(index)
        signal, sr = torchaudio.load(audio_sample_path)
        signal = self._resample(signal, sr)
        signal = self._collapse_channels(signal)
        signal = self._truncate_signal_size(signal, self.number_of_samples)
        signal_transform = self.transform(signal)

        white_noise = self._generate_white_noise(signal, pct=0.05)
        noisy_signal = (signal + white_noise)
        noisy_signal_transform = self.transform(noisy_signal)
        return complex_to_tensor_2d(noisy_signal_transform), complex_to_tensor_2d(signal_transform)

    def _resample(self, signal, sr):
        if sr != self.target_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sr)
            signal = resampler(signal)
        return signal

    def _get_audio_sample_path(self, index):
        song_title = list(self.annotations.keys())[index] + '.wav'
        path = os.path.join(self.audio_path, song_title)
        return path

    def _collapse_channels(self, signal):
        if signal.shape[0]>1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _generate_white_noise(self, signal, pct):
        max_from_signal = signal.max().item()
        white_noise = max_from_signal * pct * torch.rand_like(signal)
        return white_noise
    
    def _truncate_signal_size(self, signal, sample_number):
        if signal.shape[1] > sample_number:
            signal = signal[:, :sample_number]
        elif signal.shape[1] < sample_number:
            pad_size = sample_number - signal.shape[1]
            signal = torch.nn.functional.pad(signal, pad=(0, pad_size), value=0)
        return signal
        


In [119]:
class AltConvTranspose2d(nn.Module):
    def __init__(self, conv, output_size=None):
        super().__init__()
        self.conv = conv
        self.output_size = output_size
        
    def forward(self, x):
        output = self.conv(x, output_size=x.size())
        return output

class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding='same'),
            nn.ReLU()
        ) 
        
        self.decoder = nn.Sequential(
            AltConvTranspose2d(nn.ConvTranspose2d(in_channels=8, out_channels=8, kernel_size=3, padding=1)),
            nn.ReLU(),
            AltConvTranspose2d(nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=3, padding='same'),
            nn.Sigmoid()
        ) 
  
    def forward(self, x): 
        encoded = self.encoder(x) 
        decoded = self.decoder(encoded) 
        return decoded


In [126]:
autoencoder = Autoencoder()

TARGET_SAMPLE_RATE = 20000
N_FFT = 1024
WIN_LENGTH = 512

spectogram = torchaudio.transforms.Spectrogram(n_fft=N_FFT, power=None)
inverse_spec = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT)

data = NSynth(  annotations_path = '../../Downloads/nsynth-test/examples.json', 
                audio_path = '../../Downloads/nsynth-test/audio', 
                target_sr = TARGET_SAMPLE_RATE,
                number_of_samples=40000,
                transform=spectogram)

loader = DataLoader(data, batch_size=64, shuffle=True)

In [127]:
def train(model, loader, n_epochs, loss_fn, lr=3e-4):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    size = len(loader)
    for epoch_i in range(n_epochs):
        print(f'BATCH: [{epoch_i+1}/{n_epochs}]')
        for i, (noisy_signal, clean_signal) in enumerate(loader):
            # Forward Pass
            clean_signal_prediction = model(noisy_signal)
            loss = loss_fn(clean_signal_prediction, clean_signal)

            # Backpropagate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Display Results
            print(f'    - [{i}/{size}]loss: {loss}')



In [128]:
loss_fn = nn.MSELoss()
train(autoencoder, loader, n_epochs=1, loss_fn=loss_fn)

BATCH: [1/1]
    - [0/64]loss: 9.887795448303223
    - [1/64]loss: 11.647099494934082
    - [2/64]loss: 11.087203025817871
    - [3/64]loss: 10.867298126220703
    - [4/64]loss: 11.925589561462402
    - [5/64]loss: 10.863648414611816
    - [6/64]loss: 9.918437004089355
    - [7/64]loss: 10.863984107971191
    - [8/64]loss: 9.31159782409668
    - [9/64]loss: 11.086263656616211
    - [10/64]loss: 11.1966552734375
    - [11/64]loss: 7.379568576812744
    - [12/64]loss: 11.876327514648438
    - [13/64]loss: 10.837469100952148
    - [14/64]loss: 9.833489418029785
    - [15/64]loss: 8.354236602783203
    - [16/64]loss: 14.09255599975586
    - [17/64]loss: 9.51359748840332


KeyboardInterrupt: 

In [132]:
noisy_signal_spec, clean_signal_spec = next(iter(loader))

noisy_signal_spec_i = noisy_signal_spec[0]
clean_signal_spec_i = clean_signal_spec[0]

In [133]:
noisy_signal_spec_i.shape

torch.Size([2, 513, 79])

In [140]:
torch.tensor([noisy_signal_spec_i]).shape

ValueError: only one element tensors can be converted to Python scalars

In [None]:
torch.extra

In [141]:
clean_signal_spec_pred = autoencoder(noisy_signal_spec_i.unsqueeze(0))

In [146]:
clean_signal_spec_pred.shape

torch.Size([1, 2, 513, 79])

In [148]:
clean_signal_spec_pred.squeeze()

tensor([[[0.4870, 0.4552, 0.4103,  ..., 0.4659, 0.4584, 0.4604],
         [0.5002, 0.4306, 0.4267,  ..., 0.4539, 0.4662, 0.4525],
         [0.4741, 0.4467, 0.4466,  ..., 0.4526, 0.4590, 0.4579],
         ...,
         [0.4774, 0.4709, 0.4674,  ..., 0.4666, 0.4655, 0.4666],
         [0.4773, 0.4714, 0.4675,  ..., 0.4673, 0.4659, 0.4681],
         [0.4744, 0.4736, 0.4730,  ..., 0.4728, 0.4712, 0.4752]],

        [[0.4725, 0.4839, 0.4772,  ..., 0.5000, 0.4998, 0.5167],
         [0.4195, 0.4076, 0.4021,  ..., 0.4776, 0.4785, 0.4993],
         [0.4355, 0.3906, 0.4117,  ..., 0.4699, 0.4877, 0.5006],
         ...,
         [0.4910, 0.4822, 0.4820,  ..., 0.4823, 0.4843, 0.4984],
         [0.4925, 0.4829, 0.4816,  ..., 0.4821, 0.4830, 0.4967],
         [0.5042, 0.4969, 0.4963,  ..., 0.4961, 0.4960, 0.5038]]],
       grad_fn=<SqueezeBackward0>)

In [147]:
tensor_2d_to_complex(clean_signal_spec_pred.squeeze())

RuntimeError: view_as_real is only supported for complex tensors

In [53]:
noisy_signal_i = inverse_spec(noisy_signal_spec_i)

In [54]:
noisy_signal_i

tensor([[0.0109, 0.0023, 0.0356,  ..., 0.0205, 0.0169, 0.0301]])

In [85]:
batch, _, width, heigh = noisy_signal_spec.shape

In [86]:
batch, _, width, heigh = noisy_signal_spec.shape
torch.view_as_real(noisy_signal_spec).reshape(batch, 2, width, heigh)

tensor([[[[ 1.8029e+01,  0.0000e+00,  1.3928e+01,  ...,  1.1009e+01,
            0.0000e+00,  1.1007e+01],
          [ 0.0000e+00,  1.0671e+01,  0.0000e+00,  ...,  0.0000e+00,
            1.0570e+01,  0.0000e+00],
          [-1.1081e+01, -2.3842e-06, -6.8957e+00,  ..., -5.4371e+00,
           -3.7014e-02, -5.5921e+00],
          ...,
          [-5.2303e-01, -6.9919e-07,  6.1342e-01,  ...,  1.1985e-01,
           -6.8469e-02,  2.2392e-01],
          [-3.6077e-02, -6.0674e-02,  5.3807e-02,  ..., -1.4203e-01,
           -5.3823e-02,  2.3000e-02],
          [ 3.1877e-01,  0.0000e+00, -2.6981e-01,  ..., -3.9399e-03,
            5.5269e-02, -2.9821e-01]],

         [[ 1.2707e-01, -6.5051e-02, -7.4519e-02,  ...,  1.8152e-01,
           -3.9045e-02, -2.1917e-02],
          [ 3.3153e-02, -4.6077e-07, -4.9430e-02,  ..., -2.1661e-01,
           -3.5724e-02,  3.2197e-01],
          [ 6.6999e-02,  2.7580e-01, -1.3612e-02,  ..., -1.7254e-01,
            6.2216e-02,  1.4341e-01],
          ...,
     