In [1]:
import torch
import torchaudio
import torch.nn as nn
import numpy as np
import torch
from torch.nn import Module, Parameter
from torch import FloatTensor

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device=", device) 

import os
dirname = os.path.abspath('')
rootdir = os.path.split(dirname)[0]

H1_TRAINING_INPUT_PATH = "".join([rootdir, "/data/train/ht1-input.wav"])
H1_TRAINING_TARGET_PATH = "".join([rootdir, "/data/train/ht1-target.wav"])

metadata = torchaudio.info(H1_TRAINING_INPUT_PATH)
print(metadata)

device= cpu
AudioMetaData(sample_rate=44100, num_frames=14994001, num_channels=1, bits_per_sample=16, encoding=PCM_S)


In [2]:
waveform_input, sample_rate = torchaudio.load(H1_TRAINING_INPUT_PATH)
waveform_target, sample_rate = torchaudio.load(H1_TRAINING_TARGET_PATH)

In [3]:
waveform_input.shape

torch.Size([1, 14994001])

In [4]:
sample_rate

44100

In [5]:
for s in range(waveform_input.shape[1] //sample_rate): 
    sample_input = waveform_input[:,s*sample_rate: ((s+1) * sample_rate)]
    sample_target = waveform_target[:,s*sample_rate: ((s+1) * sample_rate)]
    
    break

In [6]:
sample_input.shape

torch.Size([1, 44100])

In [7]:
sample_target.shape

torch.Size([1, 44100])

In [8]:
class MLP(Module): 
    def __init__(self):
        super(MLP, self).__init__()
        
        self.mlp_stack = nn.Sequential(
            nn.Linear(1, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
#             nn.ReLU(),
#             nn.Linear(512, 1),
        )
        
    def forward(self, x):
        
        return self.mlp_stack(x)
        
        
    

In [9]:


class FirstOrderFIRcell(Module):
    def __init__(self, b0=1.0, b1=0.0):
        super(FirstOrderFIRcell, self).__init__()
        self.b0 = Parameter(FloatTensor([b0]))
        self.b1 = Parameter(FloatTensor([b1]))

    def init_states(self, size):
        state = torch.zeros(size).to(self.b0.device) # ? 
        return state

    def forward(self, input, state):
        output = self.b0 * input + state
        state = self.b1 * input 
        return output, state

class FirstOrderFIR(Module):
    def __init__(self):
        super(FirstOrderFIR, self).__init__()
        self.cell = FirstOrderFIRcell()
        
        

    def forward(self, input, initial_states=None):
        batch_size = input.shape[0]
        sequence_length = input.shape[1]
    

        if initial_states is None:
            states = self.cell.init_states(batch_size)
        else:
            states = initial_states

        out_sequence = torch.zeros(input.shape[:-1]).to(input.device)
        for s_idx in range(sequence_length):
            out_sequence[:, s_idx], states = self.cell(input[:, s_idx].view(-1), states)
        out_sequence = out_sequence.unsqueeze(-1)
        

        if initial_states is None:
            return out_sequence
        else:
            return out_sequence

In [10]:
class CompleteNetwork(Module): 
    def __init__(self):
        super(CompleteNetwork, self).__init__()
        self.FIR = FirstOrderFIR()
        self.MLP = MLP()
        
    def forward(self, x): 
        
        out = self.FIR(x)
        return self.MLP(out)
        

In [11]:
model = CompleteNetwork()

In [12]:
from torch.utils.data import Dataset
import numpy as np

class DIIRDataSet(Dataset):
    def __init__(self, input, target, sequence_length):
        self.input = input
        self.target = target
        self._sequence_length = sequence_length
        self.input_sequence = self.wrap_to_sequences(self.input, self._sequence_length)
        self.target_sequence = self.wrap_to_sequences(self.target, self._sequence_length)
        self._len = self.input_sequence.shape[0]

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        return {'input': self.input_sequence[index, :, :]
               ,'target': self.target_sequence[index, :, :]}

    def wrap_to_sequences(self, data, sequence_length):
        num_sequences = int(np.floor(data.shape[0] / sequence_length))
        truncated_data = data[0:(num_sequences * sequence_length)]
        wrapped_data = truncated_data.reshape((num_sequences, sequence_length, 1))
        return np.float32(wrapped_data)

In [13]:
from scipy import signal
import numpy as np
fs = 44100
f0 = 20
f1 = 20e3
t = np.linspace(0, 60, 60*int(fs))

train_input = signal.chirp(t=t, f0=f0, t1=60, f1=f1, method='logarithmic') + np.random.normal(scale=5e-2, size=len(t))

# FIR filter 
n = 2
b = signal.firwin(n, 0.3, window = "hamming", pass_zero=True)
train_target = signal.filtfilt(b, 1, train_input)


# fc = 2e3
# sos = signal.butter(N=2, Wn=fc/fs, output='sos')
# train_target = signal.sosfilt(sos, train_input)

In [14]:
train_input

array([ 1.0139815 ,  0.94593441,  0.87812729, ...,  0.51546728,
       -0.68807516,  0.81333108])

In [15]:
train_target

array([ 1.0139815 ,  0.9459944 ,  0.91859071, ...,  0.04234066,
       -0.01183799,  0.81333108])

In [16]:
waveform_input

tensor([[-3.0518e-05, -3.0518e-05, -3.0518e-05,  ...,  0.0000e+00,
          0.0000e+00, -3.0518e-05]])

In [17]:
from torch.utils.data import DataLoader

batch_size = 1024
sequence_length = 512

# debug
loader = DataLoader(dataset=DIIRDataSet(train_input, train_target, sequence_length), batch_size=batch_size, shuffle = False)

In [18]:
import torch.nn as nn
from torch.optim import Adam

n_epochs = 100
lr = 1e-3

optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

criterion = nn.MSELoss()

In [19]:
def train(criterion, model, loader, optimizer):
    model.train()
    device = next(model.parameters()).device
    total_loss = 0
    for batch in loader:
        input_seq_batch = batch['input'].to(device)
        target_seq_batch = batch['target'].to(device)
        optimizer.zero_grad()
        predicted_output = model(input_seq_batch)
        loss = criterion(target_seq_batch, predicted_output)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    total_loss /= len(loader)
    return total_loss

In [20]:
for epoch in range(n_epochs):
    loss = train(criterion, model, loader, optimizer)
    print("Epoch {} -- Loss {:3E}".format(epoch, loss))

Epoch 0 -- Loss 9.053090E-01
Epoch 1 -- Loss 8.659043E-01
Epoch 2 -- Loss 8.299149E-01
Epoch 3 -- Loss 7.966242E-01
Epoch 4 -- Loss 7.657628E-01
Epoch 5 -- Loss 7.370465E-01
Epoch 6 -- Loss 7.101779E-01
Epoch 7 -- Loss 6.848853E-01
Epoch 8 -- Loss 6.609420E-01
Epoch 9 -- Loss 6.381674E-01
Epoch 10 -- Loss 6.164140E-01
Epoch 11 -- Loss 5.955585E-01
Epoch 12 -- Loss 5.754956E-01
Epoch 13 -- Loss 5.561337E-01
Epoch 14 -- Loss 5.373936E-01
Epoch 15 -- Loss 5.192062E-01
Epoch 16 -- Loss 5.015107E-01
Epoch 17 -- Loss 4.842526E-01
Epoch 18 -- Loss 4.673828E-01
Epoch 19 -- Loss 4.508575E-01
Epoch 20 -- Loss 4.346375E-01
Epoch 21 -- Loss 4.186881E-01
Epoch 22 -- Loss 4.029786E-01
Epoch 23 -- Loss 3.874831E-01
Epoch 24 -- Loss 3.721800E-01
Epoch 25 -- Loss 3.570532E-01
Epoch 26 -- Loss 3.420916E-01
Epoch 27 -- Loss 3.272896E-01
Epoch 28 -- Loss 3.126474E-01
Epoch 29 -- Loss 2.981717E-01
Epoch 30 -- Loss 2.838751E-01
Epoch 31 -- Loss 2.697759E-01
Epoch 32 -- Loss 2.558987E-01
Epoch 33 -- Loss 2.4

In [22]:
model.FIR.cell.b0;

Parameter containing:
tensor([1.0144], requires_grad=True)

In [23]:
model.FIR.cell.b1

Parameter containing:
tensor([0.2343], requires_grad=True)

In [None]:
y[n] = b0* x[n] + b1 * x[n-1] + Wx + b

In [None]:
# upload to github 

# compare the target to output on FIR example 

# FIR talk about CNN implementation

# IIR filter (biquads)



