In [1]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
from torch.nn import Module, Parameter
from torch import FloatTensor
from scipy import signal
import numpy as np
from torchaudio import transforms
import matplotlib.pyplot as plt
import IPython.display as ipd
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from scipy.signal import sosfiltfilt
import os
dirname = os.path.abspath('')
rootdir = os.path.split(dirname)[0]

# Training audio lives in <parent of the notebook's working dir>/data/train.
# os.path.join inserts the platform's separator instead of hard-coding '/'.
TRAIN_DATA_DIR = os.path.join(rootdir, "data", "train")

H1_TRAINING_INPUT_PATH = os.path.join(TRAIN_DATA_DIR, "ht1-input.wav")
H1_TRAINING_TARGET_PATH = os.path.join(TRAIN_DATA_DIR, "ht1-target.wav")

MUFF_TRAINING_INPUT_PATH = os.path.join(TRAIN_DATA_DIR, "muff-input.wav")
MUFF_TRAINING_TARGET_PATH = os.path.join(TRAIN_DATA_DIR, "muff-target.wav")

metadata = torchaudio.info(H1_TRAINING_INPUT_PATH)
print(metadata)


AudioMetaData(sample_rate=44100, num_frames=14994001, num_channels=1, bits_per_sample=16, encoding=PCM_S)


In [2]:
# Device configuration
# Device configuration: use the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("device=", device)

device= cpu


In [3]:
train_input, fs = torchaudio.load(MUFF_TRAINING_INPUT_PATH)
train_target, target_fs = torchaudio.load(MUFF_TRAINING_TARGET_PATH)

# The dataset pairs input/target chunks index-by-index, so both files must share
# a sample rate and have identical (channels, frames) shapes.
if fs != target_fs:
    raise ValueError(f"sample-rate mismatch: input {fs} Hz vs target {target_fs} Hz")
if train_input.shape != train_target.shape:
    raise ValueError(
        f"shape mismatch: input {tuple(train_input.shape)} vs target {tuple(train_target.shape)}"
    )

In [4]:
# Shapes are (channels, frames); input and target are expected to match.
for _waveform in (train_input, train_target):
    print(_waveform.shape)

torch.Size([1, 14994001])
torch.Size([1, 14994001])


In [5]:
# 1-D NumPy versions of the input and target waveforms.
train_input_array, train_target_array = (
    waveform.numpy().reshape(-1) for waveform in (train_input, train_target)
)

## Seed for reproducibility

In [6]:
# Seed torch (weight initialisation) and NumPy's global RNG so runs are repeatable.
SEED = 0
torch.manual_seed(SEED);
np.random.seed(SEED)

<torch._C.Generator at 0x7ff510667290>

## Initialize the Dataset and DataLoader

In [7]:
class NeuralAudioDataSet(Dataset):
    """Fixed-length (input, target) audio chunk pairs for sequence-to-sequence training.

    Each signal is cut into consecutive, non-overlapping windows of
    ``sequence_length`` samples; trailing samples that do not fill a whole
    window are dropped. Item ``i`` is a dict with keys ``'input'`` and
    ``'target'``, each a float32 NumPy array of shape ``(1, sequence_length)``.

    Args:
        input: mono waveform fed to the model -- a torch tensor or NumPy array,
            either 1-D or with only singleton extra axes (e.g. ``(1, N)``).
        target: waveform the model should reproduce, sample-aligned with ``input``.
        sequence_length: samples per chunk; must be a positive integer.

    Raises:
        ValueError: if ``sequence_length`` is not positive, or a signal has more
            than one non-singleton axis.
    """

    def __init__(self, input, target, sequence_length):
        if sequence_length <= 0:
            raise ValueError(f"sequence_length must be positive, got {sequence_length}")
        self.input = input
        self.target = target

        self._sequence_length = sequence_length
        self.input_sequence = self.wrap_to_sequences(self.input, self._sequence_length)
        self.target_sequence = self.wrap_to_sequences(self.target, self._sequence_length)
        # Only expose indices present in both signals, so a shorter target cannot
        # raise IndexError halfway through an epoch.
        self._len = min(self.input_sequence.shape[0], self.target_sequence.shape[0])

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        return {'input': self.input_sequence[index, :, :],
                'target': self.target_sequence[index, :, :]}

    def wrap_to_sequences(self, data, sequence_length):
        """Cut a mono signal into a float32 array of shape (num_sequences, 1, sequence_length)."""
        # Work on a private float32 NumPy copy. This accepts torch tensors (also on
        # CUDA or requiring grad) as well as NumPy arrays; the previous .permute()
        # call only worked for torch tensors.
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        data = np.array(data, dtype=np.float32)
        if data.ndim != 1:
            if sum(dim > 1 for dim in data.shape) > 1:
                raise ValueError(f"expected a mono 1-D signal, got shape {data.shape}")
            # (1, N) used to produce 0 sequences silently; flatten it instead.
            data = data.reshape(-1)

        num_sequences = data.shape[0] // sequence_length
        print(num_sequences)
        truncated_data = data[:num_sequences * sequence_length]
        # A row-major reshape straight to (N, 1, L) equals the former reshape to
        # (N, L, 1) followed by permute(0, 2, 1).
        wrapped_data = truncated_data.reshape(num_sequences, 1, sequence_length)
        print(wrapped_data.shape)
        return wrapped_data


In [8]:
# Dropping the leading channel axis leaves a 1-D waveform of num_frames samples.
train_input.squeeze(dim=0).size()

torch.Size([14994001])

In [9]:
batch_size = 32
sequence_length = 1024  # samples per training chunk
train_dataset = NeuralAudioDataSet(train_input.squeeze(0), train_target.squeeze(0), sequence_length)
# shuffle=True would only permute the order of whole sequence_length-sample chunks
# between epochs (samples inside a chunk always stay in order); it is left off so
# chunks are visited in file order.
# pin_memory only speeds up host->GPU copies, so enable it just for CUDA.
loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=(device.type == 'cuda'),
    drop_last=True,
)

14642
torch.Size([14642, 1, 1024])
14642
torch.Size([14642, 1, 1024])


In [10]:
# Batches per epoch; drop_last=True discards a final partial batch.
num_batches = len(loader)
num_batches

457

## Declare the model

In [11]:
class IIRNN(Module):
    """Single-layer LSTM followed by a linear read-out, mapping audio to audio.

    Tensors are channel-first: input ``(batch, n_input, time)`` and output
    ``(batch, n_output, time)``.

    Args:
        n_input: input channels per time step (LSTM ``input_size``). Default 1.
        n_output: output channels per time step (Linear ``out_features``). Default 1.
        hidden_size: LSTM hidden state size. Default 80.
        n_channel: currently unused; kept so existing constructor calls still work.
    """

    def __init__(self, n_input=1, n_output=1, hidden_size=80, n_channel=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=n_input, hidden_size=self.hidden_size, batch_first=True)

        self.fc1 = nn.Linear(self.hidden_size, n_output)

        # NOTE: fc1 is also registered as mlp_layer[0], so its weights appear twice
        # in state_dict() ('fc1.*' and 'mlp_layer.0.*'). Kept so saved checkpoints
        # still load with load_state_dict().
        self.mlp_layer = nn.Sequential(
            self.fc1,
        )

    def forward(self, x):
        """Map ``(batch, n_input, time)`` audio to ``(batch, n_output, time)``.

        The LSTM starts from a zero state for every call.
        """
        # The LSTM is batch_first, so move time to dim 1: (batch, time, n_input).
        x, _ = self.lstm(x.permute(0, 2, 1))  # -> (batch, time, hidden_size)
        x = self.mlp_layer(x)  # -> (batch, time, n_output)
        return x.permute(0, 2, 1)  # back to channel-first


In [12]:
# Put the weights on the configured device so they match the tensors fed in later.
model = IIRNN().to(device)

## Define optimizer and criterion

In [13]:
import torch.nn as nn
from torch.optim import Adam
from ignite.metrics import PSNR

n_epochs = 100
lr = 1e-3

# Adam with every hyper-parameter spelled out; betas/eps/weight_decay/amsgrad are
# the values PyTorch documents as its defaults.
optimizer = Adam(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

def esr(output, target):
    """Error-to-signal ratio: mean squared error over the target's mean power.

    1e-6 is added to the denominator so an all-zero target does not divide by zero.
    """
    error_power = torch.abs(target - output).pow(2).mean()
    signal_power = torch.abs(target).pow(2).mean()
    return error_power / (signal_power + 1e-6)


def edc(output, target):
    """DC error: squared mean offset along dim 2, normalised by the target's mean power.

    The difference is averaged along dim 2 (the time axis when tensors are
    (batch, channel, time)) separately for each leading index. The result is
    squared and then averaged over those leading indices before normalisation.
    """
    per_sequence_offset = torch.mean(target - output, dim=2)
    dc_power = torch.abs(per_sequence_offset).pow(2).mean()
    signal_power = torch.abs(target).pow(2).mean()
    return dc_power / (signal_power + 1e-6)


def etotal(output, target):
    """Combined loss: ESR plus DC error for the same (output, target) pair."""
    combined = esr(output, target)
    combined = combined + edc(output, target)
    return combined
    

# Loss passed to train(): plain MSE by default. Set USE_ESR_DC_LOSS = True to use
# ESR + DC error instead -- the function etotal itself is the criterion; calling
# etotal() here would raise TypeError because it needs (output, target).
USE_ESR_DC_LOSS = False
criterion = etotal if USE_ESR_DC_LOSS else nn.MSELoss()

## Define the training loop

In [14]:
import torchaudio
from torchaudio.functional import lfilter

In [15]:
def train(criterion, model, loader, optimizer):
    model.train()
    device = next(model.parameters()).device
    total_loss = 0
    
    for ind, batch in enumerate(loader):
        input_seq_batch = batch['input'].to(device)
        target_seq_batch = batch['target'].to(device)

        optimizer.zero_grad()
        predicted_output = model(input_seq_batch)
        
        
        target_seq_batch_filt = lfilter(target_seq_batch, torch.Tensor([1,0]), torch.Tensor([1, -0.95]))
        predicted_output_filt = lfilter(predicted_output, torch.Tensor([1,0]), torch.Tensor([1, -0.95]))
        
        loss = criterion(target_seq_batch_filt, predicted_output_filt)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
#         print(f"Loss: {loss}")
        

    total_loss /= len(loader)

    return total_loss

## Train!

In [None]:
train_losses = []  # mean training loss per epoch, kept for later plotting
for epoch in range(n_epochs):
    loss = train(criterion, model, loader, optimizer)
    train_losses.append(loss)
    # '.3E' = 3 decimals in scientific notation ('3E' only set a minimum width of 3).
    print("Epoch {} -- Loss {:.3E}".format(epoch, loss))

Epoch 0 -- Loss 3.911578E-04
Epoch 1 -- Loss 3.826969E-04
Epoch 2 -- Loss 3.550321E-04
Epoch 3 -- Loss 3.368950E-04
Epoch 4 -- Loss 3.026051E-04
Epoch 5 -- Loss 3.200919E-04
Epoch 6 -- Loss 2.931160E-04
Epoch 7 -- Loss 2.710257E-04
Epoch 8 -- Loss 2.630627E-04
Epoch 9 -- Loss 2.609727E-04
Epoch 10 -- Loss 2.608075E-04
Epoch 11 -- Loss 2.864918E-04
Epoch 12 -- Loss 2.799091E-04
Epoch 13 -- Loss 2.702745E-04
Epoch 14 -- Loss 2.555225E-04
Epoch 15 -- Loss 2.689694E-04
Epoch 16 -- Loss 2.528190E-04
Epoch 17 -- Loss 2.548809E-04
Epoch 18 -- Loss 2.546860E-04
Epoch 19 -- Loss 2.523009E-04
Epoch 20 -- Loss 2.519202E-04
Epoch 21 -- Loss 2.507434E-04
Epoch 22 -- Loss 2.502143E-04
Epoch 23 -- Loss 2.493604E-04
Epoch 24 -- Loss 2.486217E-04
Epoch 25 -- Loss 2.474392E-04
Epoch 26 -- Loss 2.468442E-04
Epoch 27 -- Loss 2.470597E-04
Epoch 28 -- Loss 2.462033E-04
Epoch 29 -- Loss 2.453266E-04
Epoch 30 -- Loss 2.414719E-04
Epoch 31 -- Loss 2.433626E-04
Epoch 32 -- Loss 2.404930E-04
Epoch 33 -- Loss 2.3

## Evaluate

In [None]:
# Save under <rootdir>/models (same location as '../models' relative to the
# notebook's working directory), creating the directory if needed.
save_path = os.path.join(rootdir, 'models', 'lstm_mlp_premp_filter_mse_muff')
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)

In [None]:
# NOTE: this "validation" set is built from the *training* audio; it renders the
# model's output for the training file rather than measuring generalisation.
val_batch_size = 128  # not used: val_loader below yields one chunk per batch
val_sequence_length = 80  # separate name so the training `sequence_length` is not clobbered
val_dataset = NeuralAudioDataSet(train_input.squeeze(0), train_target.squeeze(0), val_sequence_length)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=(device.type == 'cuda'),
    drop_last=True,
)

In [None]:
def inspect_file(path):
    """Print a short report for an audio file: its size in bytes and torchaudio metadata."""
    separator = "-" * 10
    print(separator)
    print("Source:", path)
    print(separator)
    size_bytes = os.path.getsize(path)
    print(f" - File size: {size_bytes} bytes")
    print(f" - {torchaudio.info(path)}")

In [None]:
def save_audio(batch):
    """Flatten model output into a single 1-D CPU tensor (row-major order).

    Despite the name, nothing is written to disk: the flattened, detached tensor
    is printed (its shape) and returned.
    """
    flat_audio = batch.detach().cpu().squeeze(-1).flatten()
    print(flat_audio.shape)
    return flat_audio

In [None]:
import soundfile as sf

out_path = '../output/'
sample_rate = fs  # write at the rate the training audio was loaded with
os.makedirs(out_path, exist_ok=True)

model.eval()
# Feed inputs to whichever device holds the weights, so the two always match.
model_device = next(model.parameters()).device
# Collect only the chunks actually produced. The old torch.zeros(14994001, 80)
# buffer took ~4.8 GB and padded the WAV with ~1.2e9 trailing zeros.
predicted_chunks = []
with torch.no_grad():
    for val_batch in val_loader:
        input_seq_batch = val_batch['input'].to(model_device)
        predicted_output = model(input_seq_batch)
        predicted_chunks.append(predicted_output.cpu())

if not predicted_chunks:
    raise RuntimeError("val_loader yielded no batches; nothing to export")

# One row per chunk: (num_chunks, chunk_len). Rows are in loader order
# (val_loader is built with shuffle=False).
save_tensor = torch.cat(predicted_chunks, dim=0).flatten(start_dim=1)
print(save_tensor.shape)
out_audio = save_audio(save_tensor)
print(out_audio.shape)
path = os.path.join(out_path, "lstm_mlp_premp_filter_mse_muff.wav")
print("Exporting {}".format(path))
sf.write(path, out_audio.numpy(), sample_rate, 'PCM_24')
inspect_file(path)
    