In [1]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
from torch.nn import Module, Parameter
from torch import FloatTensor
from scipy import signal
import numpy as np
from torchaudio import transforms
import matplotlib.pyplot as plt
import IPython.display as ipd
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from scipy.signal import sosfiltfilt
import os
dirname = os.path.abspath('')
rootdir = os.path.split(dirname)[0]

H1_TRAINING_INPUT_PATH = "".join([rootdir, "/data/train/ht1-input.wav"])
H1_TRAINING_TARGET_PATH = "".join([rootdir, "/data/train/ht1-target.wav"])

metadata = torchaudio.info(H1_TRAINING_INPUT_PATH)
print(metadata)


AudioMetaData(sample_rate=44100, num_frames=14994001, num_channels=1, bits_per_sample=16, encoding=PCM_S)


In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device=", device) 

device= cpu


In [3]:
train_input, fs = torchaudio.load(H1_TRAINING_INPUT_PATH)
train_target, fs = torchaudio.load(H1_TRAINING_TARGET_PATH)

In [4]:
#len(train_input)
print(train_input.shape)
print(train_target.shape)

torch.Size([1, 14994001])
torch.Size([1, 14994001])


In [5]:
train_input_array = train_input.numpy().reshape(-1)
train_target_array = train_target.numpy().reshape(-1)

# For Reproducibility

In [6]:
torch.manual_seed(0)

<torch._C.Generator at 0x7f8cea174290>

## Initialize Dataloader

In [7]:
class DIIRDataSet(Dataset):
    def __init__(self, input, target, sequence_length):
        self.input = input
        self.target = target
        
        self._sequence_length = sequence_length
        self.input_sequence = self.wrap_to_sequences(self.input, self._sequence_length)
        self.target_sequence = self.wrap_to_sequences(self.target, self._sequence_length)
        self._len = self.input_sequence.shape[0]

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        return {'input': self.input_sequence[index, :, :]
               ,'target': self.target_sequence[index, :, :]}

    def wrap_to_sequences(self, data, sequence_length):
        num_sequences = int(np.floor(data.shape[0] / sequence_length))
        print(num_sequences)
        truncated_data = data[0:(num_sequences * sequence_length)]
        wrapped_data = truncated_data.reshape((num_sequences, sequence_length, 1))
        wrapped_data = wrapped_data.permute(0,2,1)
        print(wrapped_data.shape)
        return np.float32(wrapped_data)


In [8]:
train_input.squeeze(0).shape

torch.Size([14994001])

In [9]:
batch_size = 512#1024
sequence_length = 80
train_dataset=DIIRDataSet(train_input.squeeze(0), train_target.squeeze(0), sequence_length)
loader = DataLoader(train_dataset, batch_size=batch_size, shuffle = False, pin_memory=True, drop_last=True) #? what does the shuffle really shuffles here?

187425
torch.Size([187425, 1, 80])
187425
torch.Size([187425, 1, 80])


In [10]:
len(loader)

366

# Declare Model

In [11]:
class IIRNN(Module):
    def __init__(self, n_input=1, n_output=1, hidden_size=80, n_channel=1):
        super(IIRNN, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size = 80, hidden_size = self.hidden_size)

        self.fc1 = nn.Linear(self.hidden_size, 512)
        self.fc2 = nn.Linear(512, 80)
#         self.fc3 = nn.Linear(80 , 80)

        
        self.mlp_layer = nn.Sequential(
            self.fc1 ,
            nn.Tanh(),
            self.fc2,
#             nn.Tanh(),
#             self.fc3
        )

    def forward(self, x):

    
#         print(f"xshape: {x.shape}")
        x, hn = self.lstm(x) # output; (sequence_length, batch_size, hidden_size)
#         print(f"x rnn shape: {x.shape}")
        x = self.mlp_layer(x)
        x, hn = self.lstm(x)

        return x


In [12]:
model = IIRNN()

## Define optimizer and criterion

In [13]:
import torch.nn as nn
from torch.optim import Adam

n_epochs = 100
lr = 1e-3

optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

criterion = nn.MSELoss()

In [14]:
for ind, batch in enumerate(loader):
    input_seq_batch = batch['input'].to(device)
    target_seq_batch = batch['target'].to(device)
    optimizer.zero_grad()
    predicted_output = model(input_seq_batch)
    print(input_seq_batch.shape)
    print(predicted_output.shape)
    break

torch.Size([512, 1, 80])
torch.Size([512, 1, 80])


# Define train loop

In [15]:
def train(criterion, model, loader, optimizer):
    model.train()
    device = next(model.parameters()).device
    total_loss = 0
    
    for ind, batch in enumerate(loader):
        input_seq_batch = batch['input'].to(device)
        target_seq_batch = batch['target'].to(device)
        optimizer.zero_grad()
        predicted_output = model(input_seq_batch)

        # premphasis filter
        target_seq_batch_filt = signal.filtfilt([1, -0.95], [1], target_seq_batch)
        predicted_output_filt = signal.filtfilt([1, -0.95], [1], predicted_output)
        
        loss = criterion(target_seq_batch_filt, predicted_output_filt)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    total_loss /= len(loader)
    return total_loss

## Train!

In [16]:
for epoch in range(n_epochs):
    loss = train(criterion, model, loader, optimizer)
    print("Epoch {} -- Loss {:3E}".format(epoch, loss))

Epoch 0 -- Loss 5.729002E-02
Epoch 1 -- Loss 3.587623E-02
Epoch 2 -- Loss 3.248468E-02
Epoch 3 -- Loss 3.028600E-02
Epoch 4 -- Loss 2.902006E-02
Epoch 5 -- Loss 2.751315E-02
Epoch 6 -- Loss 2.609730E-02
Epoch 7 -- Loss 2.527577E-02
Epoch 8 -- Loss 2.445731E-02
Epoch 9 -- Loss 2.390423E-02
Epoch 10 -- Loss 2.326568E-02
Epoch 11 -- Loss 2.277314E-02
Epoch 12 -- Loss 2.234956E-02
Epoch 13 -- Loss 2.192946E-02
Epoch 14 -- Loss 2.162480E-02
Epoch 15 -- Loss 2.128638E-02
Epoch 16 -- Loss 2.097728E-02
Epoch 17 -- Loss 2.063996E-02
Epoch 18 -- Loss 2.040741E-02
Epoch 19 -- Loss 2.009798E-02
Epoch 20 -- Loss 1.970065E-02
Epoch 21 -- Loss 1.940156E-02
Epoch 22 -- Loss 1.908588E-02
Epoch 23 -- Loss 1.880674E-02
Epoch 24 -- Loss 1.860345E-02
Epoch 25 -- Loss 1.833004E-02
Epoch 26 -- Loss 1.808931E-02
Epoch 27 -- Loss 1.783240E-02
Epoch 28 -- Loss 1.761018E-02
Epoch 29 -- Loss 1.738038E-02
Epoch 30 -- Loss 1.716311E-02
Epoch 31 -- Loss 1.697613E-02
Epoch 32 -- Loss 1.677983E-02
Epoch 33 -- Loss 1.6

# Evaluate

In [17]:
save_path = os.path.join('../models/lstm_mlp_lstm'.format(n_epochs-1))
torch.save(model.state_dict(), save_path)

In [18]:
val_batch_size = 128
sequence_length = 80
val_dataset=DIIRDataSet(train_input.squeeze(0), train_target.squeeze(0), sequence_length)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle = False, pin_memory=True, drop_last=True)

187425
torch.Size([187425, 1, 80])
187425
torch.Size([187425, 1, 80])


In [19]:
def inspect_file(path):
    print("-" * 10)
    print("Source:", path)
    print("-" * 10)
    print(f" - File size: {os.path.getsize(path)} bytes")
    print(f" - {torchaudio.info(path)}")

In [20]:
def save_audio(batch):
    #1024,512,1
    out_batch = batch.detach().cpu()
    out_batch = out_batch.squeeze(-1).flatten()
    print(out_batch.shape)
    return out_batch

In [21]:
import soundfile as sf

out_path = '../output/'
sample_rate = 44100
save_tensor = torch.zeros(14994001,80)
with torch.no_grad():
    for i, val_batch in enumerate(val_loader):
        input_seq_batch = val_batch['input'].to(device)
        #target_seq_batch = val_batch['target'].to(device)
        predicted_output = model(input_seq_batch)
        output_tmp = predicted_output.squeeze().detach().cpu()
        #print(output_tmp.shape)
        save_tensor[i,:] = output_tmp
    
    print(save_tensor.shape)
    out_audio = save_audio(save_tensor)
    print(out_audio.shape)
    path = os.path.join(out_path, "lstm_mlp_lstm_h1.wav")
    print("Exporting {}".format(path))
    sf.write(path, out_audio, sample_rate,'PCM_24')
    #torchaudio.save(path, out_audio, sample_rate, encoding="PCM_S", bits_per_sample=16)
    inspect_file(path)
    

torch.Size([14994001, 80])
torch.Size([1199520080])
torch.Size([1199520080])
Exporting ../output/lstm_mlp_lstm_h1.wav
----------
Source: ../output/lstm_mlp_lstm_h1.wav
----------
 - File size: 3598560284 bytes
 - AudioMetaData(sample_rate=44100, num_frames=1199520080, num_channels=1, bits_per_sample=24, encoding=PCM_S)


In [22]:
#small test on padding
t4d = torch.ones(3, 3, 4)
print(t4d.shape)
out = F.pad(t4d, (3,0)) #"constant", 0
print(out.shape)

torch.Size([3, 3, 4])
torch.Size([3, 3, 7])


In [23]:
t4d[1,1,:]

tensor([1., 1., 1., 1.])

In [24]:
out[1,1,:]

tensor([0., 0., 0., 1., 1., 1., 1.])