### SEGAN + WaveNet 

TODO:
* Encode the waveform with mu-law before feeding it to the model
* Decode the mu-law output back to a waveform before reconstruction

In [1]:
import os
import torch
import re
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchaudio import transforms
from data import SpeechDataset
import time
from wavenetish_model import Wavenetish
import pdb
import matplotlib.pyplot as plt
from pypesq import pesq
import torch.nn.functional as F
import torchaudio
from tqdm.notebook import trange, tqdm
from IPython.display import Audio

In [2]:
# Select the torchaudio I/O backend used by the torchaudio.load calls in this notebook.
audio_backend = 'sox_io'
torchaudio.set_audio_backend(audio_backend)

In [3]:
# Sanity check: report when both cuDNN is enabled and a CUDA device is available.
cuda_ready = torch.backends.cudnn.enabled and torch.cuda.is_available()
if cuda_ready:
    print('CUDA is ready!')

CUDA is ready!


In [11]:
# Training hyperparameters shared by the dataset, dataloader and training loop.
window_size = 8192      # passed to SpeechDataset and used to reshape inference windows
batch_size = 1024       # DataLoader batch size; also passed to Wavenetish as `bs`
num_epochs = 150        # number of passes of the training loop
learning_rate = 1e-3    # RMSprop learning rate

In [12]:
# Run-mode switches.
preload_model_from_weights = False  # when True, resume from the latest saved checkpoint
overfit_one_batch = True            # when True, cap the sample limit at a single batch

# A limit of 0 is used when not overfitting
# (presumably "no limit" -- confirm against SpeechDataset).
if overfit_one_batch:
    limit_samples = batch_size
else:
    limit_samples = 0

In [13]:
# Determine the epoch number of the most recent checkpoint to resume from.
# `last_epoch` stays 0 when training starts from scratch.
if preload_model_from_weights:
    # Raw string: "\d" inside a plain literal is an invalid escape sequence.
    checkpoint_pattern = re.compile(r"^seae_epoch_(\d+)\.pth$")
    # NOTE(review): this scans 'models' while MODEL_PATH points at '../models' -- confirm
    # which directory is the intended one.
    matches = (checkpoint_pattern.search(name) for name in os.listdir('models'))
    epochs = [int(match[1]) for match in matches if match]
    if not epochs:
        # Fail with an explicit message instead of max() raising on an empty sequence.
        raise FileNotFoundError("no seae_epoch_<N>.pth checkpoint found in 'models'")
    last_epoch = max(epochs)
else:
    last_epoch = 0

In [14]:
# Only define the checkpoint path when there is a checkpoint epoch to resume from.
if int(last_epoch) > 0:
    MODEL_PATH = f'../models/seae_epoch_{last_epoch}.pth'

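If you'd like to run the training loop, download the OpenSLR12 dataset (http://www.openslr.org/12/), convert all .flac files to .wav and copy them to 'data/clean/open_slr'. Note that the dataset cell below reads clean speech from 'data/clean/360/', so either copy the files there or change `clean_dir` to match.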

In [15]:
# Training dataset built from the clean-speech and noise directories.
# `limit_samples` comes from the run-flags cell so that toggling `overfit_one_batch`
# actually changes how many samples are requested; with overfit_one_batch=True it
# equals batch_size, which is what was hard-coded here before.
# NOTE(review): a limit of 0 presumably means "no limit" -- confirm in SpeechDataset.
dataset = SpeechDataset(clean_dir='data/clean/360/',
                        noise_dir='data/noise/',
                        window_size=window_size,
                        overlap=50, snr=5,
                        limit_samples=limit_samples,
                        output_one=True)

In [16]:
# Shuffled mini-batch iterator over the training dataset.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

## The training

In [17]:
# Build the model, loss and optimizer, then run the training loop.
model = Wavenetish(bs=batch_size, pay_attention=False).cuda()

if preload_model_from_weights:
    model.load_state_dict(torch.load(MODEL_PATH))
    # NOTE(review): resetting last_epoch restarts the printed epoch numbering at 1
    # even though weights were preloaded -- confirm this is intended.
    last_epoch = 0

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=1e-5)

model.train()

# Batch-level progress bar, reset at the start of every epoch.
pbar = tqdm()

for epoch in trange(num_epochs):
    # Log (and optionally checkpoint) every 10th epoch.
    print_epoch = epoch % 10 == 0
    save_state = print_epoch

    if print_epoch: print(f'Starting epoch {epoch + 1 + last_epoch}')

    # len(dataloader) also counts a trailing partial batch, unlike
    # len(dataset) // batch_size, which is 0 for a dataset smaller than one batch.
    pbar.reset(total=len(dataloader))

    for data in dataloader:
        # data[0] is the model input, data[1] the target passed to the loss.
        expected = data[1].cuda()
        output = model(data[0].cuda())
        # Flatten with -1 instead of batch_size so a final batch that is smaller
        # than batch_size does not make reshape fail.
        # NOTE(review): Wavenetish is built with bs=batch_size; confirm the model
        # itself accepts smaller batches.
        loss = criterion(output, expected.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.update()

    pbar.refresh()
    if save_state:
        # Checkpointing is currently disabled; uncomment the line below to enable it.
        #torch.save(model.state_dict(), f'models/noisy_seae_epoch_{epoch + last_epoch + 1}.pth')
        pass

    if print_epoch:
        print(f'epoch [{epoch}/{num_epochs}]')
        # Loss of the last batch processed in this epoch.
        print(round(loss.item(), 5))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))

Starting epoch 1
epoch [0/80]
5.58871
Starting epoch 11
epoch [10/80]
16.60398
Starting epoch 21
epoch [20/80]
5.68861
Starting epoch 31
epoch [30/80]
5.56296
Starting epoch 41
epoch [40/80]
4.9871
Starting epoch 51
epoch [50/80]
4.96757
Starting epoch 61
epoch [60/80]
4.93975
Starting epoch 71
epoch [70/80]
4.77773



In [None]:
# Play the first model-input window of a freshly drawn batch.
# The batch is fetched here so the cell does not depend on `inp`, which is only
# defined by a later cell and is therefore missing on a fresh Restart & Run All.
# Index 0 of a batch is what the training loop feeds to the model.
input_batch = next(iter(dataloader))[0]
Audio(input_batch[0].cpu(), rate=16000)

## Listening to the results

In [None]:
# Example clean-speech and noise recordings that are mixed and played back below.
clean_file = 'data/clean/360/1001-134707-0035.wav'
noise_file = 'data/noise/128316.wav'

In [2]:
import math

# Mix the example noise recording into the example clean recording.
# Note: `snr` is applied as a plain divisor on the noise amplitude.
snr = 5
noise_wave = torchaudio.load(noise_file)[0]
clean_wave = torchaudio.load(clean_file)[0]

# Number of samples along the second axis of each loaded waveform.
noise_len = noise_wave.shape[1]
clean_len = clean_wave.shape[1]

# Tile the noise along the second axis until it is at least as long as the clean signal.
if clean_len > noise_len:
    repeat_times = math.ceil(clean_len / noise_len)
    noise_wave = noise_wave.repeat(1, repeat_times)

# First row of each waveform only; the noise is truncated to the clean length.
scaled_noise = noise_wave[0, :clean_len] / snr
noised_wave = (clean_wave[0, :] + scaled_noise).reshape(1, -1)

NameError: name 'torchaudio' is not defined

In [3]:
# Visualise mu-law encoding/decoding on the first `first_n` samples of the
# noise and clean waveforms.
first_n = 25000

mu_encode = torchaudio.transforms.MuLawEncoding()
mu_decode = torchaudio.transforms.MuLawDecoding()

nw = noise_wave[0, :first_n].reshape(-1)
wv = mu_encode(nw)

cw = clean_wave[0, :first_n].reshape(-1)
cwave = mu_encode(cw)

# Round trip of the noise wave: encode, then decode.
deml = mu_decode(wv)
xs = range(len(deml))

fig, axs = plt.subplots(5, figsize=(15, 15))

# One (title, plot arguments) entry per row, top to bottom.
panels = [
    ('Mu Law encoded wave', (wv,)),
    ('Encoded clean signal', (cwave,)),
    ('Decoded wave', (deml,)),
    ('Overlap between raw and decoded waves', (xs, deml, xs, nw)),
    ('Raw wave', (nw,)),
]
for ax, (title, series) in zip(axs, panels):
    ax.set_title(title)
    ax.plot(*series)

for ax in axs:
    ax.label_outer()

NameError: name 'noise_wave' is not defined

In [65]:
# Number of samples in the mu-law encoded wave (flattened to 1-D above).
wv.shape[0]

256000

In [25]:
# Listen to the clean recording with noise mixed in.
Audio(data=noised_wave, rate=16000)

In [26]:
# Listen to the original clean recording for comparison.
Audio(data=clean_wave, rate=16000)

In [2]:
from utils import windows

# Run the trained model over the first 4000 sliding windows of the noised wave
# and keep one value per window.
model.eval()
with torch.no_grad():
    noise_inputs = windows(noised_wave, window_size, 50, step=1)

    predicts = []

    for window in noise_inputs[0, :4000]:
        # Reshape a single window to (1, 1, window_size) before feeding the model.
        model_input = window.reshape(1, 1, window_size).cuda()
        first_output = model(model_input)[0]
        # Keep only element 0 of the first output row.
        predicts.append(first_output.detach().cpu()[0])

NameError: name 'model' is not defined

In [38]:
# Convert the collected per-window predictions to plain floats and play them back.
predicts = list(map(float, predicts))
Audio(data=predicts, rate=16000)

In [81]:
# Mean PESQ score of the model output against the target for up to 50 samples
# of a single batch.
pesqs = []

data = next(iter(dataloader))
model.eval()
with torch.no_grad():
    sample = data[1].cuda()
    inp = data[0]

    # The forward pass covers the whole batch and does not depend on the loop
    # index, so it runs once here instead of once per scored sample.
    output = model(inp.cuda())

    for i in range(len(sample[:50])):
        # First channel of sample i, as 1-D numpy arrays.
        degraded = output[i, 0, :].cpu().numpy()
        reference = sample[i, 0, :].cpu().numpy()

        # Same argument order as before: target first, model output second.
        pesqs.append(pesq(reference, degraded, 16000))

print(round(sum(pesqs) / len(pesqs), 4))

2.5972
