In [6]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
import json
import os

In [7]:
class NSynth(Dataset):
    """Dataset yielding ``(noisy, clean)`` transformed audio pairs.

    Each item is loaded from ``<audio_path>/<key>.wav`` where ``key`` is a
    top-level key of the JSON annotations file. The waveform is resampled to
    ``target_sr``, averaged down to a single channel, cut or zero-padded to
    ``number_of_samples`` samples, and then a noisy copy is created by adding
    zero-mean uniform white noise. ``transform`` is applied to both versions.

    Args:
        annotations_path: Path to a JSON file whose top-level object maps
            sample names (file names without the ``.wav`` suffix) to metadata.
        audio_path: Directory containing the ``.wav`` files.
        target_sr: Sample rate every waveform is resampled to.
        number_of_samples: Fixed length (in samples) of every returned waveform
            before ``transform`` is applied.
        transform: Callable applied to both the noisy and the clean waveform.
    """

    def __init__(self, annotations_path, audio_path, target_sr, number_of_samples, transform):
        with open(annotations_path, 'r') as f:
            self.annotations = json.load(f)
        # Materialise the key order once: rebuilding this list on every
        # __getitem__ call costs O(len(dataset)) per sample.
        self._sample_names = list(self.annotations.keys())
        # Resample transforms are reused per (source rate, target rate) pair
        # instead of being rebuilt for every item.
        self._resamplers = {}
        self.audio_path = audio_path
        self.target_sr = target_sr
        self.number_of_samples = number_of_samples
        self.transform = transform

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self._sample_names)

    def __getitem__(self, index):
        """Return ``(transform(noisy_signal), transform(clean_signal))`` for ``index``."""
        audio_sample_path = self._get_audio_sample_path(index)
        signal, sr = torchaudio.load(audio_sample_path)
        signal = self._resample(signal, sr)
        signal = self._collapse_channels(signal)
        signal = self._truncate_signal_size(signal, self.number_of_samples)

        white_noise = self._generate_white_noise(signal, pct=0.05)
        noisy_signal = signal + white_noise
        return self.transform(noisy_signal), self.transform(signal)

    def _resample(self, signal, sr):
        """Resample ``signal`` from ``sr`` to ``self.target_sr`` (no-op if equal)."""
        if sr == self.target_sr:
            return signal
        cache_key = (sr, self.target_sr)
        resampler = self._resamplers.get(cache_key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sr)
            self._resamplers[cache_key] = resampler
        return resampler(signal)

    def _get_audio_sample_path(self, index):
        """Build the ``.wav`` path of the sample at position ``index``."""
        song_title = self._sample_names[index] + '.wav'
        return os.path.join(self.audio_path, song_title)

    def _collapse_channels(self, signal):
        """Average over the first dimension when it holds more than one entry.

        The first dimension is assumed to be the channel axis (as returned by
        ``torchaudio.load``); it is kept with size 1.
        """
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _generate_white_noise(self, signal, pct):
        """Return zero-mean uniform noise shaped like ``signal``.

        The noise lies in ``[-pct * peak, pct * peak)`` where ``peak`` is the
        largest absolute value in ``signal``. Using the absolute peak keeps
        the scale meaningful for signals whose maximum is zero or negative,
        and centring the noise avoids adding a DC offset to the clean signal.
        """
        peak_amplitude = signal.abs().max().item()
        # rand_like is uniform on [0, 1); map it to [-1, 1) so the mean is 0.
        unit_noise = 2.0 * torch.rand_like(signal) - 1.0
        return peak_amplitude * pct * unit_noise

    def _truncate_signal_size(self, signal, sample_number):
        """Cut or right-pad (with zeros) dimension 1 to exactly ``sample_number``."""
        if signal.shape[1] > sample_number:
            signal = signal[:, :sample_number]
        elif signal.shape[1] < sample_number:
            pad_size = sample_number - signal.shape[1]
            signal = torch.nn.functional.pad(signal, pad=(0, pad_size), value=0)
        return signal

    def _get_spectogram(self, signal):
        """Apply a default-configured ``torchaudio`` Spectrogram to ``signal``."""
        spectogram = torchaudio.transforms.Spectrogram()
        return spectogram(signal)
        


In [8]:
class AltConvTranspose2d(nn.Module):
    """Adapter that lets a transposed convolution run inside ``nn.Sequential``.

    ``nn.Sequential`` only forwards the input tensor, so the ``output_size``
    argument of the wrapped layer's ``forward`` cannot be supplied there. This
    wrapper supplies it.

    Args:
        conv: The transposed-convolution module to wrap; its ``forward`` must
            accept an ``output_size`` keyword argument.
        output_size: Optional fixed output size forwarded to ``conv``. When
            left as ``None`` the spatial size (last two dimensions) of the
            input is requested, i.e. the layer preserves height and width.
    """

    def __init__(self, conv, output_size=None):
        super().__init__()
        self.conv = conv
        self.output_size = output_size

    def forward(self, x):
        """Apply ``conv`` to ``x`` requesting the configured output size."""
        if self.output_size is not None:
            # Honour the size given at construction time (previously ignored).
            target_size = self.output_size
        else:
            # Default: keep the input's height and width.
            target_size = list(x.shape[-2:])
        return self.conv(x, output_size=target_size)

class Autoencoder(nn.Module):
    """Small convolutional autoencoder operating on single-channel 2-D inputs.

    The encoder is two 3x3 'same'-padded convolutions (1 -> 16 -> 8 channels),
    each followed by ReLU. The decoder is two wrapped 3x3 transposed
    convolutions (8 -> 8 -> 16 channels) with ReLU, then a 3x3 convolution
    back to one channel followed by a Sigmoid.

    NOTE(review): the final Sigmoid bounds every output value to (0, 1);
    targets need to live in that range for a reconstruction loss to be
    attainable -- confirm against the data pipeline.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in encoder-then-decoder order, front to back.
        encoder_layers = [
            nn.Conv2d(1, 16, 3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(16, 8, 3, padding='same'),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            AltConvTranspose2d(nn.ConvTranspose2d(8, 8, 3, padding=1)),
            nn.ReLU(),
            AltConvTranspose2d(nn.ConvTranspose2d(8, 16, 3, padding=1)),
            nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding='same'),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its decoded reconstruction."""
        return self.decoder(self.encoder(x))

In [22]:
autoencoder = Autoencoder()

# --- Audio / STFT configuration ------------------------------------------
TARGET_SAMPLE_RATE = 20000
N_FFT = 1024
# NOTE(review): WIN_LENGTH is not passed to either transform below, so both
# run with their default window settings -- wire it in or drop it.
WIN_LENGTH = 512
NUM_SAMPLES = 40000  # fixed clip length in samples (2 s at TARGET_SAMPLE_RATE)
BATCH_SIZE = 64

# --- Dataset location -----------------------------------------------------
NSYNTH_DIR = '../../Downloads/nsynth-test'
ANNOTATIONS_PATH = f'{NSYNTH_DIR}/examples.json'
AUDIO_PATH = f'{NSYNTH_DIR}/audio'

# Forward and inverse transforms must share the same STFT parameters.
spectogram = torchaudio.transforms.Spectrogram(n_fft=N_FFT)
inverse_spec = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT)

data = NSynth(
    annotations_path=ANNOTATIONS_PATH,
    audio_path=AUDIO_PATH,
    target_sr=TARGET_SAMPLE_RATE,
    number_of_samples=NUM_SAMPLES,
    transform=spectogram,
)

loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

In [23]:
def train(model, loader, n_epochs, loss_fn, lr=3e-4):
    """Train ``model`` with Adam on ``(noisy, clean)`` batches from ``loader``.

    Args:
        model: Module mapping a noisy batch to a prediction of the clean batch.
        loader: Sized iterable yielding ``(noisy_signal, clean_signal)`` pairs.
        n_epochs: Number of full passes over ``loader``.
        loss_fn: Callable ``loss_fn(prediction, clean_signal)`` returning a
            scalar tensor to minimise.
        lr: Learning rate for the Adam optimizer.

    Progress is printed once per epoch and for every 10th batch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    n_batches = len(loader)
    # Make sure train-time behaviour is active even if the caller left the
    # model in eval mode.
    model.train()
    for epoch_i in range(n_epochs):
        # This header marks an epoch, not a batch.
        print(f'EPOCH: [{epoch_i+1}/{n_epochs}]')
        for i, (noisy_signal, clean_signal) in enumerate(loader):
            # Forward pass
            clean_signal_prediction = model(noisy_signal)
            loss = loss_fn(clean_signal_prediction, clean_signal)

            # Backpropagate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Display results; .item() yields a plain Python float.
            if i % 10 == 0:
                print(f'    - [{i}/{n_batches}]loss: {loss.item()}')



In [24]:
# Mean-squared error; train() evaluates it between the model's prediction
# and the clean target of each batch.
loss_fn = nn.MSELoss()
# Run 5 epochs using train()'s default learning rate.
train(autoencoder, loader, n_epochs=5, loss_fn=loss_fn)

BATCH: [1/5]
    - [0/64]loss: 375871.8125


KeyboardInterrupt: 

In [158]:
# Pull one batch from the loader and keep the first example of each stream.
# The clean batch gets its own name so it is not overwritten by the single
# example extracted from it.
noisy_signal_spec, clean_signal_spec_batch = next(iter(loader))

noisy_signal_spec_i = noisy_signal_spec[0]
clean_signal_spec = clean_signal_spec_batch[0]

In [164]:
# Inverse STFT transform built with the same n_fft as the forward transform.
# NOTE(review): an identical `inverse_spec` is already created in the
# configuration cell, so this re-definition looks redundant -- confirm.
# NOTE(review): InverseSpectrogram expects a complex-valued spectrogram; the
# forward Spectrogram is created with its default `power` -- verify the data
# fed to this transform is actually invertible.
inverse_spec = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT)

In [165]:
# Inspect the shape of a single noisy example.
# NOTE(review): with n_fft=1024 a one-sided spectrogram should have 513
# frequency bins; the recorded output below shows 128, so it may come from an
# earlier configuration -- re-run from a fresh kernel to confirm.
noisy_signal_spec_i.shape

torch.Size([1, 128, 201])