In [None]:
!pip install pytorch_lightning

In [None]:
# Show which GPU (if any) this runtime provides; the training cell requests gpus=1.
!nvidia-smi

In [None]:
import io
import torch
import PIL.Image
import numpy as np
import scipy.signal
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from scipy.stats import loguniform
from torchvision.transforms import ToTensor

# Dataset

In [None]:
class IIRFilterDataset(torch.utils.data.Dataset):
  """On-the-fly dataset of random Chebyshev type I IIR filters.

  Every `__getitem__` call designs a brand-new random filter, so `idx` is
  ignored and `num_examples` only determines the epoch length.

  Args:
    num_points (int): Number of frequency bins in the magnitude/phase responses.
    filter_order (int): Filter order N; each example has N + 1 `b` and N + 1 `a` coefficients.
    eps (float): Added to the standard deviations during normalization so that
      bins with zero variance (e.g. the phase at DC) do not produce NaN/inf.
    factor (float): Stored but currently unused.
    num_examples (int): Value returned by `__len__`.
    standard_norm (bool): If True, estimate per-bin mean/std of the magnitude and
      phase responses from `sample_size` random filters and standardize every
      returned example with them.
    sample_size (int): Number of random filters used to estimate the
      normalization statistics (only used when `standard_norm` is True).
  """
  def __init__(self,
               num_points = 512,
               filter_order = 2,
               eps = 1e-8,
               factor = 1,
               num_examples = 10000,
               standard_norm = True,
               sample_size = 100000):
    super().__init__()
    self.num_points = num_points
    self.filter_order = filter_order
    self.eps = eps
    self.factor = factor
    self.num_examples = num_examples
    self.standard_norm = standard_norm

    # normalization statistics; stays an empty dict when standard_norm is False
    self.sample_size = int(sample_size)
    self.stats = {}

    if standard_norm:
      print("Computing normalization factors...")
      coefs = np.zeros((self.sample_size, (self.filter_order + 1) * 2))
      mags = np.zeros((self.sample_size, self.num_points))
      phss = np.zeros((self.sample_size, self.num_points))

      for n in range(self.sample_size):
        coefs[n, :], mags[n, :], phss[n, :] = self.generate_filter()

      # per-bin (mag/phs) and per-coefficient (coef) statistics over the sample
      for key, values in (("coef", coefs), ("mag", mags), ("phs", phss)):
        self.stats[key] = {
            "mean": np.mean(values, axis=0),
            "std": np.std(values, axis=0),
        }

  def __len__(self):
    return self.num_examples

  def __getitem__(self, idx):
    """Return a freshly generated example `(mag, phs, coef)` as float32 arrays.

    `idx` is ignored: every call draws a new random filter.
    """
    coef, mag, phs = self.generate_filter()

    # `self.stats` is an empty dict (never None) when standard_norm is False,
    # so test its truthiness; `is not None` would always take this branch.
    if self.stats:
      # coefficients are returned unnormalized; only the responses are standardized
      mag = (mag - self.stats["mag"]["mean"]) / (self.stats["mag"]["std"] + self.eps)
      phs = (phs - self.stats["phs"]["mean"]) / (self.stats["phs"]["std"] + self.eps)

    return mag.astype('float32'), phs.astype('float32'), coef.astype('float32')

  def generate_filter(self):
    """Design a random Chebyshev type I filter and compute its frequency response.

    The normalized cutoff `wn` is drawn log-uniformly from [1e-3, 1] and the
    passband ripple `rp` uniformly from [0, 10) dB.

    Returns:
      coef (ndarray): Filter coefficients laid out as [b0, ..., bN, a0, ..., aN],
        with N = `filter_order`.
      mag (ndarray): Linear magnitude response at `num_points` frequencies.
      phs (ndarray): Unwrapped phase response (radians) at `num_points` frequencies.
    """
    wn = float(loguniform.rvs(1e-3, 1e0))
    rp = np.random.rand() * 10
    b, a = scipy.signal.cheby1(self.filter_order, rp, wn)
    coef = np.concatenate((b, a), axis=-1)

    # use b/a directly rather than fixed coef slices so any filter order works
    _, h = scipy.signal.freqz(b=b, a=a, worN=self.num_points)

    mag = np.abs(h)
    phs = np.unwrap(np.angle(h))
    return coef, mag, phs


# Plotting


In [None]:
def plot_compare_response(pred_coef, target_coef, num_points=512, eps=1e-8, fs=44100):
  """Plot predicted vs. target filter responses and return the figure as an image tensor.

  Args:
    pred_coef (array-like): Predicted coefficients laid out as [b0, ..., bN, a0, ..., aN].
    target_coef (array-like): Target coefficients in the same layout.
    num_points (int): Number of frequency bins passed to `scipy.signal.freqz`.
    eps (float): Floor added to |H| before converting to dB, to avoid log10(0).
    fs (float): Sampling rate in Hz; the frequency axis is therefore in Hz.

  Returns:
    The rendered PNG converted with torchvision's `ToTensor`.
  """
  def _response(coef):
    # split the flat [b..., a...] layout in half so any filter order works;
    # np.asarray avoids touching the caller's array
    coef = np.asarray(coef)
    n_b = coef.shape[-1] // 2
    return scipy.signal.freqz(b=coef[:n_b], a=coef[n_b:], worN=num_points, fs=fs)

  w_pred, h_pred = _response(pred_coef)
  w_target, h_target = _response(target_coef)

  fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(8, 8))

  mag_pred = 20 * np.log10(np.abs(h_pred) + eps)
  mag_target = 20 * np.log10(np.abs(h_target) + eps)
  ax[0].plot(w_target, mag_target, color='b', label="target")
  ax[0].plot(w_pred, mag_pred, color='r', label="pred")
  ax[0].set_xscale('log')
  ax[0].set_ylim([-60, 40])
  ax[0].set_ylabel('Amplitude [dB]')
  ax[0].set_xlabel('Frequency [Hz]')
  ax[0].legend()
  ax[0].grid()

  ang_pred = np.unwrap(np.angle(h_pred))
  ang_target = np.unwrap(np.angle(h_target))
  ax[1].plot(w_target, ang_target, color='b', label="target")
  ax[1].plot(w_pred, ang_pred, color='r', label="pred")
  ax[1].set_xscale('log')
  ax[1].set_ylabel('Angle (radians)')
  ax[1].set_xlabel('Frequency [Hz]')
  ax[1].grid()
  ax[1].axis('tight')
  ax[1].legend()

  buf = io.BytesIO()
  fig.savefig(buf, format='png')
  # close only this figure so other open figures are left untouched
  plt.close(fig)
  buf.seek(0)

  return ToTensor()(PIL.Image.open(buf))


# Model

In [None]:
class MLP(pl.LightningModule):
  """ Multi-layer perceptron that regresses filter coefficients from a magnitude response.

  Args:
    num_points (int): Input size (number of magnitude-response bins).
    num_layers (int): Number of hidden Linear + PReLU layers (may be 0).
    hidden_features (int): Width of each hidden layer.
    filter_order (int): Filter order N; the network outputs (N + 1) * 2 coefficients.
    lr (float): Adam learning rate, read from `self.hparams` in `configure_optimizers`.
  """
  def __init__(self, 
               num_points = 512,
               num_layers = 4,
               hidden_features = 2048,
               filter_order = 2,
               lr = 3e-4,
               **kwargs):
    super().__init__()

    self.save_hyperparameters()

    self.layers = torch.nn.ModuleList()

    # track the running width so the output layer is correct even with zero hidden layers
    in_features = self.hparams.num_points
    for _ in range(self.hparams.num_layers):
      self.layers.append(torch.nn.Sequential(
        torch.nn.Linear(in_features, self.hparams.hidden_features),
        torch.nn.PReLU(),
      ))
      in_features = self.hparams.hidden_features

    n_coef = (self.hparams.filter_order + 1) * 2
    self.layers.append(torch.nn.Linear(in_features, n_coef))

  def forward(self, mag, phs=None):
    """Map a batch of magnitude responses to predicted coefficients (`phs` is unused)."""
    x = mag
    for layer in self.layers:
      x = layer(x)
    return x

  def _shared_step(self, batch):
    # batches are (mag, phs, coef); only the magnitude is fed to the network
    mag, phs, coef = batch
    pred_coef = self(mag)
    loss = torch.nn.functional.mse_loss(pred_coef, coef)
    return mag, coef, pred_coef, loss

  def training_step(self, batch, batch_idx):
    _, _, _, loss = self._shared_step(batch)

    self.log('train_loss', 
              loss, 
              on_step=True, 
              on_epoch=True, 
              prog_bar=True, 
              logger=True)
    return loss

  def validation_step(self, batch, batch_idx):
    mag, coef, pred_coef, loss = self._shared_step(batch)

    self.log('val_loss', loss)

    # move tensors to cpu/numpy so validation_epoch_end can plot them
    return {
        "pred_coef": pred_coef.detach().cpu().numpy(),
        "coef": coef.detach().cpu().numpy(),
        "mag": mag.detach().cpu().numpy(),
    }

  def validation_epoch_end(self, validation_step_outputs):
    """Log a plot comparing the first prediction of the last validation batch to its target."""
    if not validation_step_outputs:
      return

    # only loggers exposing a TensorBoard-style `add_image` are supported
    experiment = getattr(self.logger, "experiment", None)
    if not hasattr(experiment, "add_image"):
      return

    last_batch = validation_step_outputs[-1]
    image = plot_compare_response(last_batch["pred_coef"][0], last_batch["coef"][0])
    experiment.add_image("mag", image, self.global_step)

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

  # add any model hyperparameters here
  @staticmethod
  def add_model_specific_args(parent_parser):
    """Extend `parent_parser` with the MLP's command-line hyperparameters."""
    # local import: ArgumentParser is not imported at the top of this notebook
    from argparse import ArgumentParser

    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    # --- model related ---
    parser.add_argument('--num_points', type=int, default=512)
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--hidden_features', type=int, default=2048)
    parser.add_argument('--filter_order', type=int, default=2)

    # --- training related ---
    # NOTE: differs from the constructor default (3e-4)
    parser.add_argument('--lr', type=float, default=1e-3)

    return parser

# Training

In [None]:
# --- configuration ---
SEED = 42
batch_size = 128
# NOTE: generate_filter draws from numpy's global RNG; if num_workers > 0,
# make sure each worker is seeded differently (e.g. via worker_init_fn).
num_workers = 0
shuffle = True

num_examples = 10000
num_points = 64
filter_order = 2
lr = 1e-3

# seed python/numpy/torch so the random filters and the weight init are reproducible
pl.seed_everything(SEED)

# setup the dataloaders
train_dataset = IIRFilterDataset(num_points=num_points,
                                 filter_order=filter_order,
                                 num_examples=num_examples)
train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                                               shuffle=shuffle,
                                               batch_size=batch_size,
                                               num_workers=num_workers)
# Validation reuses the dataset (and therefore its normalization statistics) but
# is not shuffled; examples are generated on the fly, so it still sees new filters.
val_dataloader = torch.utils.data.DataLoader(train_dataset,
                                             shuffle=False,
                                             batch_size=batch_size,
                                             num_workers=num_workers)

# build the model; lr is passed to the constructor so it is saved with the hyperparameters
model = MLP(num_points=num_points, filter_order=filter_order, lr=lr)

# init the trainer
trainer = pl.Trainer(gpus=1, progress_bar_refresh_rate=20, max_epochs=10, auto_lr_find=False)

# train!
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
# Start TensorBoard in the notebook. lightning_logs/ is presumably where the
# Trainer's default logger writes (the model logs images via logger.experiment.add_image) — confirm if the logger changes.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/