In [None]:
# Install PyTorch Lightning into the notebook environment.
# NOTE(review): this install is unpinned, while the cells below use
# Trainer(gpus=..., progress_bar_refresh_rate=...) and validation_epoch_end,
# which newer Lightning releases may no longer provide — consider pinning a
# compatible version.
!pip install pytorch_lightning

In [None]:
# Show which GPU (if any) is attached before training.
!nvidia-smi

In [None]:
import io
import torch
import PIL.Image
import numpy as np
import scipy.signal
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from scipy.stats import loguniform
from torchvision.transforms import ToTensor

# Dataset

In [None]:
class IIRFilterDataset(torch.utils.data.Dataset):
  """Synthetic dataset of (magnitude response, filter coefficients) pairs.

  Items are generated on the fly and ``idx`` is ignored: each access designs a
  Chebyshev type I filter with ``scipy.signal.cheby1``. The normalized cutoff
  is drawn from ``wn_choices`` and the passband ripple (dB) from
  ``rp_choices``. The filter's magnitude response ``|H|`` is then evaluated at
  ``num_points`` frequencies with ``scipy.signal.freqz``.

  Coefficients are laid out as ``[b_0 .. b_N, a_0 .. a_N]`` with
  ``N = filter_order``, i.e. ``2 * (filter_order + 1)`` values.

  Args:
    num_points: Number of frequency bins (``worN`` for ``freqz``).
    filter_order: Order of the designed filter.
    eps: Stored on the instance; not used by ``__getitem__``.
    factor: Both the magnitude and the coefficients are divided by this value.
    num_examples: Length reported by ``__len__``.
    wn_choices: Candidate normalized cutoff frequencies (0 < wn < 1).
    rp_choices: Candidate passband ripples in dB.
  """
  def __init__(self,
               num_points = 512,
               filter_order = 2,
               eps = 1e-8,
               factor = 1,
               num_examples = 10000,
               wn_choices = (0.05, 0.1, 0.4),
               rp_choices = (1, 10, 20)):
    super().__init__()
    self.num_points = num_points
    self.filter_order = filter_order
    self.eps = eps
    self.factor = factor
    self.num_examples = num_examples
    self.wn_choices = wn_choices
    self.rp_choices = rp_choices

  def __len__(self):
    return self.num_examples

  def __getitem__(self, idx):
    """Return ``(mag, coef)`` as float32 arrays of a freshly designed filter.

    ``mag`` has shape ``(num_points,)`` and ``coef`` has shape
    ``(2 * (filter_order + 1),)``; both are divided by ``factor``.
    """
    # Sampling uses the global NumPy RNG. NOTE(review): with num_workers > 0,
    # confirm that DataLoader workers are seeded differently.
    wn = np.random.choice(self.wn_choices)
    rp = np.random.choice(self.rp_choices)

    b, a = scipy.signal.cheby1(self.filter_order, rp, wn)
    coef = np.concatenate((b, a), axis=-1)

    # freqz returns (w, h): w holds the bin frequencies and h the complex
    # response. The magnitude response is |h|. Use b and a directly so the
    # split is correct for any filter order.
    _, h = scipy.signal.freqz(b=b, a=a, worN=self.num_points)
    mag = np.abs(h)

    # Cast to float32 for torch and apply the normalization factor.
    mag = mag.astype('float32') / self.factor
    coef = coef.astype('float32') / self.factor

    return mag, coef

# Plotting


In [None]:
def plot_compare_mag_response(pred_coef, target_coef, num_points=512, eps=1e-8, fs=44100):
  """Render predicted vs. target filter magnitude responses as an image tensor.

  Each coefficient vector is laid out as ``[b..., a...]`` with equally many
  numerator and denominator taps, so it is split at its midpoint. The inputs
  are not modified.

  Args:
    pred_coef: Predicted coefficients (1-D array-like).
    target_coef: Ground-truth coefficients with the same layout.
    num_points: Number of frequency bins evaluated by ``scipy.signal.freqz``.
    eps: Added to ``|H|`` before ``log10`` to avoid ``log(0)``.
    fs: Sample rate in Hz; the x-axis is therefore in Hz.

  Returns:
    The rendered PNG converted by ``ToTensor`` into a ``(C, H, W)`` tensor.
  """
  fig, ax = plt.subplots()

  # Target first, then prediction, so the prediction is drawn on top.
  for coef, color, label in ((target_coef, 'b', 'target'), (pred_coef, 'r', 'pred')):
    coef = np.asarray(coef)
    n_b = len(coef) // 2
    w, h = scipy.signal.freqz(b=coef[:n_b], a=coef[n_b:], worN=num_points, fs=fs)
    ax.plot(w, 20 * np.log10(np.abs(h) + eps), color=color, label=label)

  ax.set_xscale('log')
  ax.set_ylim([-60, 40])
  ax.set_ylabel('Amplitude [dB]')
  # freqz was given fs, so w is in Hz rather than rad/sample.
  ax.set_xlabel('Frequency [Hz]')
  ax.legend()
  ax.grid()

  buf = io.BytesIO()
  try:
    fig.savefig(buf, format='png')
  finally:
    # Close only this figure; other open figures are left alone.
    plt.close(fig)

  buf.seek(0)
  with PIL.Image.open(buf) as pil_image:
    image = ToTensor()(pil_image)
  buf.close()

  return image


# Model

In [None]:
class MLP(pl.LightningModule):
  """Multi-layer perceptron that regresses IIR filter coefficients.

  The network is ``num_layers`` blocks of ``Linear -> PReLU`` of width
  ``hidden_features``, followed by a linear head with
  ``(filter_order + 1) * 2`` outputs. This matches the ``[b..., a...]``
  coefficient layout produced by ``IIRFilterDataset``.
  """
  def __init__(self,
               num_points = 512,
               num_layers = 4,
               hidden_features = 2048,
               filter_order = 2,
               lr = 3e-4,
               **kwargs):
    """
    Args:
      num_points: Input size (number of magnitude-response bins).
      num_layers: Number of hidden ``Linear -> PReLU`` blocks. With 0, the
        model is a single linear map from input to coefficients.
      hidden_features: Width of each hidden block.
      filter_order: Filter order; determines the output size.
      lr: Learning rate used by Adam in ``configure_optimizers``.
      **kwargs: Extra hyperparameters, recorded by ``save_hyperparameters``.
    """
    super().__init__()

    # Records the __init__ arguments under self.hparams.
    self.save_hyperparameters()

    self.layers = torch.nn.ModuleList()

    # Track the running feature size so the head is well-defined even when
    # num_layers == 0.
    in_features = self.hparams.num_points
    for _ in range(self.hparams.num_layers):
      self.layers.append(torch.nn.Sequential(
        torch.nn.Linear(in_features, self.hparams.hidden_features),
        torch.nn.PReLU(),
      ))
      in_features = self.hparams.hidden_features

    n_coef = (self.hparams.filter_order + 1) * 2
    self.layers.append(torch.nn.Linear(in_features, n_coef))

  def forward(self, x):
    for layer in self.layers:
      x = layer(x)

    return x

  def _shared_step(self, batch):
    """Run the model on ``(mag, coef)``; return ``(mag, coef, pred_coef, loss)``."""
    mag, coef = batch
    pred_coef = self(mag)
    loss = torch.nn.functional.mse_loss(pred_coef, coef)
    return mag, coef, pred_coef, loss

  def training_step(self, batch, batch_idx):
    _, _, _, loss = self._shared_step(batch)

    self.log('train_loss',
              loss,
              on_step=True,
              on_epoch=True,
              prog_bar=True,
              logger=True)
    return loss

  def validation_step(self, batch, batch_idx):
    mag, coef, pred_coef, loss = self._shared_step(batch)

    self.log('val_loss', loss)

    # Move tensors to CPU NumPy arrays for plotting in validation_epoch_end.
    # Detach first so this also works if autograd is enabled.
    outputs = {
        "pred_coef": pred_coef.detach().cpu().numpy(),
        "coef": coef.detach().cpu().numpy(),
        "mag": mag.detach().cpu().numpy()}

    return outputs

  def validation_epoch_end(self, validation_step_outputs):
    """Log a target-vs-prediction magnitude plot for one validation example."""
    if not validation_step_outputs:
      return

    # Only loggers exposing add_image (e.g. TensorBoard's SummaryWriter) can
    # receive the figure; skip otherwise instead of crashing.
    experiment = getattr(self.logger, "experiment", None)
    if not hasattr(experiment, "add_image"):
      return

    # Plot the first example of the first validation batch.
    first_batch = validation_step_outputs[0]
    pred_coef = first_batch["pred_coef"][0]
    coef = first_batch["coef"][0]

    experiment.add_image("mag", plot_compare_mag_response(pred_coef, coef), self.global_step)

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

  # add any model hyperparameters here
  @staticmethod
  def add_model_specific_args(parent_parser):
    """Return a parser extending ``parent_parser`` with this model's options."""
    # argparse is not imported at the top of the notebook, so import it here.
    from argparse import ArgumentParser

    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    # --- model related ---
    parser.add_argument('--num_points', type=int, default=512)
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--hidden_features', type=int, default=2048)
    parser.add_argument('--filter_order', type=int, default=2)

    # --- training related ---
    # NOTE(review): this CLI default (1e-3) differs from the constructor
    # default lr=3e-4 — confirm which is intended.
    parser.add_argument('--lr', type=float, default=1e-3)

    return parser

# Training

In [None]:
seed = 0
batch_size = 64
num_workers = 0
shuffle = True

num_examples = 10000
num_val_examples = 1000
num_points = 16

# Seed Python, NumPy and torch so runs are reproducible. The dataset samples
# filters from the global NumPy RNG.
pl.seed_everything(seed)

# init the trainer; fall back to CPU when no GPU is visible instead of failing
# at Trainer construction. (For lr tuning, see trainer.tuner.lr_find.)
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     progress_bar_refresh_rate=20,
                     max_epochs=10,
                     auto_lr_find=False)

# setup the dataloaders: a separate validation loader (not shuffled) instead
# of re-validating on the training loader.
# NOTE: items are drawn at random from a small grid of filter designs, so
# validation items will coincide with training designs. This measures fit,
# not generalization.
train_dataset = IIRFilterDataset(num_points=num_points, num_examples=num_examples)
val_dataset = IIRFilterDataset(num_points=num_points, num_examples=num_val_examples)

train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=shuffle,
                                               batch_size=batch_size,
                                               num_workers=num_workers)
val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                             shuffle=False,
                                             batch_size=batch_size,
                                             num_workers=num_workers)

# build the model; its input size must match the dataset's num_points
model = MLP(num_points=num_points)

# train!
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
# Start TensorBoard inside the notebook, reading logs from lightning_logs/.
# NOTE(review): presumably this is where the Trainer's default logger writes
# (no explicit logger is passed to pl.Trainer above) — confirm.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/