### SEGAN without the GAN

Original SEGAN paper: https://arxiv.org/pdf/1703.09452.pdf

In [1]:
import os
import torch
import re
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchaudio import transforms
from data import SpeechDataset
import time
from model import Autoencoder
import pdb
import matplotlib.pyplot as plt
from pypesq import pesq
import torch.nn.functional as F
import torchaudio
from tqdm.notebook import trange, tqdm
from IPython.display import Audio

In [2]:
# Hyperparameters for the training run in the cells below.
num_epochs = 10          # additional epochs to train in this session
batch_size = 512         # used by the DataLoader and passed to Autoencoder(bs=...)
learning_rate = 0.00001  # step size handed to the RMSprop optimizer

In [3]:
# Find the most recent checkpoint: files in 'models' named seae_epoch_<N>.pth.
# A raw string keeps `\d` and `\.` as regex escapes instead of (invalid) string
# escapes, and compiling once avoids matching every filename twice.
checkpoint_pattern = re.compile(r"^seae_epoch_(\d+)\.pth$")
checkpoint_matches = map(checkpoint_pattern.search, os.listdir('models'))
epochs = [int(match[1]) for match in checkpoint_matches if match]

# Fail with an actionable message rather than max()'s "empty sequence" error.
if not epochs:
    raise FileNotFoundError("no 'seae_epoch_<N>.pth' checkpoint found in 'models'")
last_epoch = max(epochs)

In [4]:
# Checkpoint file of the latest epoch found above; loaded before training resumes.
MODEL_PATH = 'models/seae_epoch_{}.pth'.format(last_epoch)

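To run the training loop, download the OpenSLR12 dataset (http://www.openslr.org/12/), convert all `.flac` files to `.wav`, and copy them to `data/clean/open_slr`. Note that the dataset cell below is constructed with the paths `data/clean/360/` and `data/noise/`, so adjust those paths to match where your files actually live.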

In [5]:
# Build the training dataset from the two audio directories.
# NOTE(review): the directory roles (clean speech / noise) are inferred from the
# path names, and `overlap=50` is presumably a percentage of the window —
# confirm both against data.SpeechDataset.
dataset = SpeechDataset(
    'data/clean/360/',
    'data/noise/',
    window_size=16384,
    overlap=50,
)

In [6]:
# Batches are served in dataset order (shuffling is disabled).
# NOTE(review): training normally shuffles; confirm the fixed order is intended.
dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=batch_size,
)

In [None]:
# Resume from the latest checkpoint and train for `num_epochs` more epochs,
# saving the weights after every epoch.
model = Autoencoder(bs=batch_size).cuda()
model.load_state_dict(torch.load(MODEL_PATH))
criterion = nn.MSELoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=1e-5)

model.train()

for epoch in trange(num_epochs):
    # Epoch number counted from the checkpoint that was loaded.
    current_epoch = epoch + 1 + last_epoch
    print(f'Starting epoch {current_epoch}')

    # Wrapping the loader lets tqdm take the total from len(dataloader), which
    # includes a trailing partial batch, and closes the bar when the epoch ends
    # instead of leaving one open bar per epoch.
    for data in tqdm(dataloader):
        # Each batch is indexed as (input, target). Tensors are passed directly:
        # the deprecated `Variable` wrapper is no longer needed for autograd.
        inp = data[0].cuda()
        target = data[1].cuda()
        output = model(inp)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # NOTE(review): checkpoints are written with a 'noisy_' prefix, which the
    # 'seae_epoch_<N>.pth' lookup used to pick MODEL_PATH does not match, so a
    # later run will not resume from these files — confirm this is intended.
    torch.save(model.state_dict(), f'models/noisy_seae_epoch_{current_epoch}.pth')

    # Progress is reported 1-based; the loss shown is that of the last batch only.
    print(f'epoch [{epoch + 1}/{num_epochs}]')
    print(round(loss.item(), 5))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Starting epoch 4


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

In [None]:
# Show the model's repr as the cell output.
model

In [None]:
# Score the model with PESQ on up to 100 windows from one batch, then play the
# last denoised window.
max_pesq_samples = 100
sample_rate = 16000

# Fetch a batch explicitly instead of relying on whatever `data` the training
# loop left behind. Batches are indexed as (input, target), as in training.
eval_batch = next(iter(dataloader))
noisy_batch, clean_batch = eval_batch[0], eval_batch[1]

model.eval()
with torch.no_grad():
    # A single forward pass over the whole batch; the result is the same for
    # every sample, so it does not belong inside the per-sample loop.
    denoised_batch = model(noisy_batch.cuda()).cpu()

pesqs = []
# Tensors are indexed as (batch, channel, time); channel 0 is the waveform scored.
# NOTE(review): assumes the clean target has the same layout as the model output.
for i in range(min(max_pesq_samples, len(clean_batch))):
    clean_wave = clean_batch[i, 0].numpy()
    denoised_wave = denoised_batch[i, 0].numpy()
    # pesq(reference, degraded, fs): the clean target is the reference and the
    # model output is the signal being rated.
    pesqs.append(pesq(clean_wave, denoised_wave, sample_rate))

print(sum(pesqs) / len(pesqs))
Audio(denoised_wave, rate=sample_rate)

In [None]:
# Plot the log-power spectrogram of one clean utterance.
f = 'data/clean/open_slr/2902-9006-0001.wav'
wave, _ = torchaudio.load(f)  # the returned sample rate is not needed here

# Keyword arguments make explicit that 300 is the window length, not the hop.
specgram = torchaudio.transforms.Spectrogram(n_fft=1024, win_length=300)(wave)

print("Shape of spectrogram: {}".format(specgram.size()))

# Clamp before the log so zero-power bins do not become -inf in the image.
log_specgram = specgram.clamp(min=1e-10).log2()

fig, ax = plt.subplots()
# First channel only; origin='lower' draws low frequencies at the bottom and
# aspect='auto' stops the image from being squashed into a thin strip.
ax.imshow(log_specgram[0, :, :].numpy(), origin='lower', aspect='auto')
ax.set(title='Log-power spectrogram', xlabel='Frame', ylabel='Frequency bin');