In [None]:
# model definition, seq2seq stateful lstm for real-time audio processing

import math
import torch
import torch.optim as optim
import torch.nn as nn

class AudioLSTM(nn.Module):
  """
  Stateful seq2seq LSTM followed by a linear layer, with an optional residual
  connection from the input to the output.

  The (hidden state, cell state) pair is stored on the module between calls to
  ``forward`` so that a long sequence can be fed in consecutive chunks. Call
  ``reset_hidden`` before processing an unrelated sequence or a batch of a
  different size.
  """

  def __init__(self, input_size, output_size, hidden_size, skip=1, num_layers=1):
    """
    :param input_size: number of features per timestep of the input
    :param output_size: number of features per timestep of the output
    :param hidden_size: size of the LSTM hidden and cell states
    :param skip: if truthy, the input is added to the network output
    :param num_layers: number of stacked LSTM layers
    """
    super().__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.skip = skip

    self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
    self.fc = nn.Linear(hidden_size, output_size)

    # needed for stateful LSTM
    self.hidden = None  # (hidden state, cell state)

  def forward(self, x):
    """
    Run one chunk through the network, continuing from the stored state.

    :param x: tensor of shape (batch_size, seq_length, features)
    :return: tensor with the same batch and sequence dimensions as ``x``
    """
    if self.hidden is None:
      # Create the zero state directly on the input's device/dtype instead of
      # allocating on the default device and copying it over.
      state_shape = (self.num_layers, x.size(0), self.hidden_size)
      self.hidden = (
        torch.zeros(state_shape, device=x.device, dtype=x.dtype),
        torch.zeros(state_shape, device=x.device, dtype=x.dtype),
      )

    out, self.hidden = self.lstm(x, self.hidden)

    out = self.fc(out)

    # add residual step
    if self.skip:
      out += x

    return out

  def reset_hidden(self):
    """Forget the stored state; the next ``forward`` call starts from zeros."""
    self.hidden = None

  def detach_hidden(self):
    """
    Detach the stored state from the computation graph so that the next
    backward pass stops at the current chunk (truncated backprop).
    Does nothing when no state has been stored yet.
    """
    if self.hidden is None:
      return
    if isinstance(self.hidden, tuple):
      self.hidden = tuple(h.clone().detach() for h in self.hidden)
    else:
      self.hidden = self.hidden.clone().detach()

  def train_epoch(self, input_data, target_data, loss_fcn, optim, bs, init_len=200, up_fr=1000):
    """
    Train for one pass over the data using truncated backpropagation.

    :param input_data: x examples with shape (num_examples, seq_length, features)
    :param target_data: y target with shape (num_examples, seq_length, features)
    :param loss_fcn: our loss function
    :param optim: our optimizer
    :param bs: batch size, number of examples per batch
    :param init_len: number of audio samples to initialize hidden state
    :param up_fr: number of timesteps before we backprop (truncated backprop)
    :return: mean over batches of the per-batch mean chunk loss (detached);
             0 when there was nothing to train on
    """
    num_examples = input_data.shape[0]
    num_batches = math.ceil(num_examples / bs)

    # Shuffle the examples. The permutation is created on the data's device so
    # indexing works even when the default tensor type differs from it.
    shuffle = torch.randperm(num_examples, device=input_data.device)

    ep_loss = 0
    for batch_i in range(num_batches):
      # Load batch of shuffled segments
      batch_idx = shuffle[batch_i * bs:(batch_i + 1) * bs]
      input_batch = input_data[batch_idx, :, :]
      target_batch = target_data[batch_idx, :, :]

      # Drop any state left over from earlier calls (it may belong to a
      # different batch size), then warm the state up on the first init_len
      # samples and clear the gradient buffers.
      self.reset_hidden()
      self(input_batch[:, 0:init_len, :])
      self.zero_grad()

      # Number of up_fr-sized chunks left after the warm-up. The sequence axis
      # is dim 1 because the data is batch-first.
      seq_len = input_batch.shape[1]
      num_chunks = max(0, math.ceil((seq_len - init_len) / up_fr))

      start_i = init_len
      batch_loss = 0
      for _ in range(num_chunks):
        # Process input chunk with neural network
        output = self(input_batch[:, start_i:start_i + up_fr, :])

        # Calculate loss and update network parameters
        loss = loss_fcn(output, target_batch[:, start_i:start_i + up_fr, :])
        loss.backward()
        optim.step()

        # Detach the state from the graph so the next backward stops here
        self.detach_hidden()
        self.zero_grad()

        # Advance to the next chunk; accumulate a detached copy so the running
        # total does not hold on to autograd history.
        start_i += up_fr
        batch_loss += loss.detach()

      # Add the average batch loss to the epoch loss and reset the hidden states to zeros
      if num_chunks:
        ep_loss += batch_loss / num_chunks
      self.reset_hidden()

    return ep_loss / num_batches if num_batches else ep_loss


In [None]:
# set up: model dimensions
input_size = 2  # features per timestep on the input side (2 = stereo)
output_size = 2  # features per timestep on the output side (2 = stereo)
hidden_size = 16  # width of the LSTM hidden state
skip_con = 1  # non-zero enables the residual (skip) connection

# hyperparams
learn_rate = 1e-2

network = AudioLSTM(
    input_size=input_size,
    output_size=output_size,
    hidden_size=hidden_size,
    skip=skip_con,
)

# Use the first CUDA device when one is present; `cuda` records the choice (1/0).
# NOTE: set_default_tensor_type makes every newly created float tensor live on the GPU.
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)
    print('cuda device available')
    network = network.cuda()
    cuda = 1
else:
    print('cuda device not available/not selected')
    cuda = 0

# Optimiser and plateau-based learning-rate scheduler (minimising the monitored value)
optimiser = optim.Adam(network.parameters(), lr=learn_rate, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimiser, mode='min', factor=0.5, patience=5, verbose=True
)
# loss_functions = training.LossWrapper(args.loss_fcns, args.pre_filt)

In [None]:
# Train the model.
# NOTE(review): `num_epochs`, `batch_size`, `loss_fcn`, `train_input` and
# `train_target` are not defined in the cells shown here - define them in an
# earlier cell before running this one. `train_input` / `train_target` must be
# tensors shaped (num_examples, seq_length, features), which is what
# AudioLSTM.train_epoch expects.
# NOTE(review): `scheduler` is created during set-up but never stepped here;
# ReduceLROnPlateau only takes effect once scheduler.step(metric) is called.
for epoch in range(num_epochs):
    network.train()  # Set the model to training mode
    train_loss = network.train_epoch(train_input, train_target, loss_fcn, optimiser, batch_size)
    print(f"Epoch {epoch + 1}, Loss: {float(train_loss)}")