Let's train SEAE with Fast.ai

In [None]:
import os
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchaudio import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from data import SpeechDataset
import time
from model import Autoencoder
from transformer import Transformer
import pdb
import matplotlib.pyplot as plt
from pypesq import pesq
import torch.nn.functional as F
import torchaudio
from tqdm.notebook import trange, tqdm
from IPython.display import Audio
from fastai.vision import *
import numpy as np

In [None]:
# Exploration switch: when truthy, the data-loading cell caps the number of
# dataset indices used at batch_size * 3 so a full run finishes quickly.
EXPLORE = True

In [None]:
# Mini-batch size shared by the training and validation DataLoaders; it is
# also passed to Autoencoder as `bs` and sizes the EXPLORE subset.
batch_size = 128

In [None]:
dataset = SpeechDataset('data/clean/360/', 'data/noise/', window_size=16384, overlap=50, snr=10)

# In EXPLORE mode only a small prefix of the dataset is used so a complete
# train/validate cycle finishes quickly.  The prefix is capped at the real
# dataset length so the samplers can never yield an out-of-range index.
dataset_size = len(dataset)
if EXPLORE:
    dataset_size = min(dataset_size, batch_size * 3)
indices = list(range(dataset_size))

# The first `validation_split` fraction of the indices is held out for
# validation and the remainder is used for training.  The indices are not
# shuffled before the split; shuffling only happens inside each sampler.
validation_split = .2
split = int(np.floor(validation_split * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# Both loaders read from the same dataset object; the samplers keep the two
# index sets disjoint.
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)

In [None]:
# Instantiate the project's Autoencoder with the notebook batch size and move
# its parameters to the GPU (`.cuda()` needs an available CUDA device).
model = Autoencoder(bs=batch_size).cuda()

In [None]:
# Implementation of https://arxiv.org/pdf/1903.03107v1.pdf 
def weightedSDR(output, target, window_size=16384, eps=1e-8):
    """Return the negative cosine similarity between `output` and `target`.

    Both tensors are reshaped to rows of `window_size` samples and a single
    scalar is then computed over *all* elements of the batch at once:
    ``-<output, target> / (||output|| * ||target||)``.

    NOTE(review): the weighted-SDR loss in the referenced paper also includes
    a noise-estimate term weighted by an energy ratio and is defined per
    sample; this function only computes the batch-pooled SDR-style term.
    Confirm that this simplification is intended.

    Args:
        output: Model prediction; its element count must be a multiple of
            `window_size`.
        target: Reference signal with the same number of elements as
            `output`.
        window_size: Number of samples per row after reshaping.  Defaults to
            16384, the value previously hard-coded here.
        eps: Lower bound applied to the norm product so that an all-zero
            signal yields a finite loss instead of NaN.

    Returns:
        A 0-dim tensor; -1 when `output` is a positive multiple of `target`
        and +1 when it is a negative multiple.
    """
    # reshape (unlike view) also accepts non-contiguous inputs.
    output = output.reshape(-1, window_size)
    target = target.reshape(-1, window_size)

    dot_product = torch.sum(output * target)
    norm_product = torch.norm(target) * torch.norm(output)

    # A silent (all-zero) signal makes the denominator 0 and the plain
    # quotient 0/0 = NaN.  Clamping keeps the value finite and leaves the
    # result untouched whenever the norm product is above `eps`.
    return (-1 * dot_product) / torch.clamp(norm_product, min=eps)

In [None]:
# Bundle the two PyTorch loaders into the container fastai trains from.
data = DataBunch(train_loader, validation_loader)

# The learner optimises with RMSprop against an RMSE loss and attaches the
# ShowGraph callback.
learner = Learner(
    data,
    model,
    opt_func=torch.optim.RMSprop,
    loss_func=root_mean_squared_error,
    callback_fns=ShowGraph,
)

In [None]:
# Run fastai's learning-rate range test; the losses it records are kept on
# `learner.recorder`.
learner.lr_find()

In [None]:
# Plot what the recorder collected (loss against learning rate after an
# lr_find() run) to help choose a learning rate.
learner.recorder.plot()

In [None]:
# Short run: 20 epochs with the one-cycle schedule, peak learning rate 1e-5.
learner.fit_one_cycle(20, 1e-5)

In [None]:
# Long run: 100 epochs with the one-cycle schedule, peak learning rate 1e-5.
# The same `learner`/`model` objects are reused, so running this after the
# short run continues from those weights.
learner.fit_one_cycle(100, 1e-5)

In [None]:
# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before writing the trained weights.
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), 'models/attn_3_cycle.pth')

#### Let's hear how the model denoises

In [None]:
from pypesq import pesq

In [None]:
pesqs = []

# Keep the sample under its own names: rebinding `data` here would overwrite
# the DataBunch created for the Learner and break re-running that cell.
sample_pair = dataset[0]
sample_input, sample_target = sample_pair[0], sample_pair[1]

model.eval()
with torch.no_grad():
    # Only the forward pass needs no_grad; the input is reshaped to
    # (N, 1, 16384) and moved to the GPU once.
    output = model(sample_input.reshape(-1, 1, 16384).cuda())

# `ref` is the model output (first batch item, first channel) and `target`
# is the dataset's target signal, both converted to NumPy arrays.
ref = output[0, :, :].cpu().detach().numpy().T[:, 0]
target = sample_target[0, :].cpu().detach().numpy().T[:]

plt.figure()
plt.plot(ref)
plt.figure()
plt.plot(target)

# NOTE(review): pypesq documents the call as pesq(reference, degraded, fs),
# so the dataset target goes first and the model output second — confirm.
pesqs.append(pesq(target, ref, 16000))

print(round(sum(pesqs) / len(pesqs), 4))
# Change between ref/target to hear model output/original
Audio(ref, rate=16000)

In [None]:
# Stand-alone shape check: push one dataset window through a single
# transformer encoder layer and print the resulting shape.
# The first item of the sample is reshaped to (1, 1, 16384); `d_model` below
# matches that last dimension.
src = dataset[0][0].reshape(1,1, 16384)
# NOTE(review): d_model=16384 makes the layer's projection matrices very
# large — presumably acceptable for a one-off experiment; confirm before
# reusing this in the model.
encoder_layer = nn.TransformerEncoderLayer(d_model=16384, nhead=4)
#src = torch.rand(10, 32, 512)
out = encoder_layer(src)
print(out.shape)