In [None]:
# Download the SignalTrain LA2A dataset archive (v1.1) from Zenodo.
# The extraction cell that follows expects the file to be saved under the name
# "SignalTrain_LA2A_Dataset_1.1.tgz?download=1" (i.e. with the query string kept).
!wget https://zenodo.org/record/3824876/files/SignalTrain_LA2A_Dataset_1.1.tgz?download=1

In [None]:
# Unpack the downloaded archive into the current working directory
# (-v lists every extracted file, so the cell output is long).
!tar -xvf SignalTrain_LA2A_Dataset_1.1.tgz?download=1

In [None]:
# Sanity check: list the contents of the current working directory.
!ls

In [None]:
# Mount Google Drive under /content/drive; later cells read from and write to
# "/content/drive/My Drive".
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Move the extracted dataset directory into Google Drive.
# NOTE(review): the "Dataset" section further down copies
# `SignalTrain_LA2A_Dataset_1.1.zip` from Drive, which this cell does not create —
# confirm where that zip archive is produced.
!mv SignalTrain_LA2A_Dataset_1.1/ "/content/drive/My Drive"

In [None]:
# Move the SSH key archive from the working directory to Google Drive; the next
# cell unpacks it from "/content/drive/My Drive/ssh.tar.gz".
# (presumably ssh.tar.gz was uploaded to the Colab working directory by hand — confirm)
!mv ssh.tar.gz "/content/drive/My Drive"

In [None]:
!rm -rf /root/.ssh
!mkdir /root/.ssh
!tar -xvzf "/content/drive/My Drive/ssh.tar.gz"
!cp ssh-colab/* /root/.ssh && rm -rf ssh-colab && rm -rf ssh.tar.gz
#!chmod 700 /root/.ssh$
!touch /root/.ssh/known_hosts
!ssh-keyscan github.com >> /root/.ssh/known_hosts
!chmod 644 /root/.ssh/known_hosts
!chmod 600 /root/.ssh/id_rsa_colab
!ssh -T git@github.com

# Imports

In [None]:
# Install auraloss directly from its GitHub repository over SSH
# (relies on the SSH key installed into /root/.ssh by the setup cell).
!pip install git+ssh://git@github.com/csteinmetz1/auraloss.git

In [None]:
!pip install pytorch_lightning
!pip install torchaudio

In [None]:
import os
import sys
import glob
import torch
import auraloss # here is our package!
import torchaudio
import numpy as np
import torchsummary
from google.colab import output
import pytorch_lightning as pl
from argparse import ArgumentParser
# Select the "sox_io" backend for torchaudio file I/O (used by torchaudio.info / torchaudio.load below).
torchaudio.set_audio_backend("sox_io")

# Dataset

In [None]:
# first we will load the full dataset onto the local disk (this takes about 20 min)
!mkdir "/content/data"
!rsync -aP "/content/drive/My Drive/SignalTrain_LA2A_Dataset_1.1.zip/" "/content/data/"
!unzip "/content/data/SignalTrain_LA2A_Dataset_1.1.zip"

In [None]:
class SignalTrainLA2ADataset(torch.utils.data.Dataset):
    """ SignalTrain LA2A dataset. Source: [10.5281/zenodo.3824876](https://zenodo.org/record/3824876).

    Every item is a fixed-length ``(input, target, params)`` triple cut from a pair of
    ``input_*.wav`` / ``target_*.wav`` files. ``params`` is a ``(1, 2)`` tensor holding the
    two numbers encoded in the target file name (``target_*__<a>__<b>.wav``), with the
    second one divided by 100.
    (presumably these are the LA-2A limit switch and peak-reduction settings — confirm)

    NOTE(review): with ``preload=True`` the audio is cached as float16 and returned
    without the ``1 / (2**31 - 1)`` scaling that the streaming path applies — confirm
    that this asymmetry is intended.
    """
    def __init__(self, root_dir, subset="train", length=16384, preload=False):
        """
        Args:
            root_dir (str): Path to the root directory of the SignalTrain dataset.
            subset (str, optional): Pull data either from "train", "val", or "test" subsets. (Default: "train")
            length (int, optional): Number of samples in the returned examples. (Default: 16384)
            preload (bool, optional): Read in all data into RAM during init. (Default: False)
        """
        self.root_dir = root_dir
        self.subset = subset
        self.length = length
        self.preload = preload

        # list the files of the subset, sorted so that zip() below pairs every
        # target with its matching input
        subset_dir = os.path.join(self.root_dir, self.subset.capitalize())
        self.target_files = sorted(glob.glob(os.path.join(subset_dir, "target_*.wav")))
        self.input_files  = sorted(glob.glob(os.path.join(subset_dir, "input_*.wav")))

        # parse the parameters only AFTER sorting, otherwise they would follow the
        # arbitrary glob order and be attached to the wrong files
        self.params = [self._parse_params(f) for f in self.target_files]

        self.examples = []
        self.audio_files = []
        self.hours = 0  # total number of hours of data in the subset

        # loop over files to count total length
        for idx, (tfile, ifile, params) in enumerate(zip(self.target_files, self.input_files, self.params)):
            print(os.path.basename(tfile), os.path.basename(ifile))
            md = torchaudio.info(tfile)
            self.hours += (md.num_frames / md.sample_rate) / 3600
            num_frames = md.num_frames

            if self.preload:
                output.clear('status_text')
                with output.use_tags('status_text'):
                    print(f"* Pre-loading... {idx+1:3d}/{len(self.target_files):3d} ...")
                input, sr  = torchaudio.load(ifile, normalize=False)
                target, sr = torchaudio.load(tfile, normalize=False)
                # keep the cache small: store as float16
                input = input.half()
                target = target.half()
                self.audio_files.append({"target" : target, "input" : input})
                num_frames = input.shape[-1]
                # NOTE(review): hard-coded caps on how many files get pre-loaded; this
                # looks like a RAM/debugging limit — confirm before relying on it
                if self.subset == "train":
                    if idx > 25: break
                if self.subset == "val":
                    if idx > 1: break

            # create one entry for each patch
            for n in range((num_frames // self.length) - 1):
                offset = int(n * self.length)
                self.examples.append({"idx": idx,
                                      "target_file" : tfile,
                                      "input_file" : ifile,
                                      "params" : params,
                                      "offset": offset,
                                      "frames" : num_frames})

        print(f"Located {len(self.examples)} examples totaling {self.hours:0.1f} hr in the {self.subset} subset.")

    @staticmethod
    def _parse_params(filename):
        """Return the two floats encoded as ``...__<a>__<b>.wav`` in a target file name."""
        # use the basename so that "__" inside a directory name cannot shift the fields
        fields = os.path.basename(filename).replace(".wav", "").split("__")
        return float(fields[1]), float(fields[2])

    def __len__(self):
        """Number of fixed-length patches found across all files."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the ``(input, target, params)`` triple for patch ``idx``."""
        example = self.examples[idx]
        offset = example["offset"]

        if self.preload:
            cached = self.audio_files[example["idx"]]
            # these slices are views into the cached tensors: never modify them in place
            input = cached["input"][:, offset:offset + self.length]
            target = cached["target"][:, offset:offset + self.length]
        else:
            input, sr  = torchaudio.load(example["input_file"],
                                         num_frames=self.length,
                                         frame_offset=offset,
                                         normalize=False)
            target, sr = torchaudio.load(example["target_file"],
                                         num_frames=self.length,
                                         frame_offset=offset,
                                         normalize=False)
            # apply float32 normalization; convert first so the division is also
            # valid when the loader hands back an integer tensor
            input = input.float() / ((2**31) - 1)
            target = target.float() / ((2**31) - 1)

        # at random with p=0.5 flip the phase (out of place, so that the
        # pre-loaded cache is left untouched)
        if np.random.rand() > 0.5:
            input = -input
            target = -target

        # then get the tuple of parameters
        params = torch.tensor(example["params"]).unsqueeze(0)
        params[:,1] /= 100

        return input, target, params

# Model

In [None]:
def center_crop(x, shape):
    """Cut the last axis of ``x`` down to ``shape[-1]`` samples, centred.

    Args:
        x: Array/tensor that supports ``.shape`` and ellipsis slicing.
        shape: Sequence whose last entry is the wanted length of the last axis.

    Returns:
        The slice ``x[..., begin:begin + shape[-1]]`` where ``begin`` is half of the
        surplus length, rounded down.
    """
    width = shape[-1]
    begin = (x.shape[-1] - width) // 2
    return x[..., begin:begin + width]

class FiLM(torch.nn.Module):
    """Conditional affine transform applied on top of a non-affine BatchNorm1d.

    A linear layer maps the conditioning vector to ``2 * num_features`` values, which
    are split into a per-channel scale and a per-channel shift.

    Args:
        num_features (int): Number of channels of the feature map being modulated.
        cond_dim (int): Size of the last dimension of the conditioning tensor.
    """
    def __init__(self, num_features, cond_dim):
        super().__init__()
        self.num_features = num_features
        self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
        self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)

    def forward(self, x, cond):
        """Normalise ``x`` and apply the scale/shift predicted from ``cond``.

        ``cond`` must be a 3-D tensor whose last dimension is ``cond_dim``; after the
        permutation its middle dimension lands on the time axis, so it has to
        broadcast against the length of ``x`` (e.g. be 1).
        """
        # first half of the projection is the gain, second half the bias
        scale, shift = self.adaptor(cond).chunk(2, dim=-1)
        scale = scale.permute(0, 2, 1)
        shift = shift.permute(0, 2, 1)

        normed = self.bn(x)             # BatchNorm without its own affine
        return (normed * scale) + shift # conditional affine

class TCNBlock(torch.nn.Module):
    """Residual block made of two 1-D convolutions with pre-activation.

    Layout: BatchNorm -> LeakyReLU -> conv1 (dilated) -> [pointwise conv] ->
    FiLM or BatchNorm -> LeakyReLU -> conv2 (dilation 1) -> [pointwise conv],
    summed with a 1x1-projected copy of the block input that is centre-cropped to
    the length of the convolved signal.

    Args:
        in_ch (int): Number of input channels.
        out_ch (int): Number of output channels.
        kernel_size (int): Kernel width of both convolutions. Default: 3
        padding (int): Padding of both convolutions. Default: 0
        dilation (int): Dilation of the first convolution only. Default: 1
        depthwise (bool): Add a 1x1 convolution after each convolution and use
            ``groups=out_ch`` when ``in_ch`` is divisible by ``out_ch``. Default: False
        conditional (bool): Use a FiLM layer (driven by a 64-dimensional conditioning
            vector) instead of the second BatchNorm. Default: False
    """
    def __init__(self, in_ch, out_ch, kernel_size=3, padding=0, dilation=1, depthwise=False, conditional=False):
        super(TCNBlock, self).__init__()

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.padding = padding
        self.dilation = dilation
        self.depthwise = depthwise
        self.conditional = conditional

        # grouped convolutions are only possible when the channel counts divide;
        # otherwise silently fall back to ordinary convolutions
        groups = out_ch if depthwise and (in_ch % out_ch == 0) else 1

        self.conv1 = torch.nn.Conv1d(in_ch,
                                     out_ch,
                                     kernel_size=kernel_size,
                                     padding=padding,
                                     dilation=dilation,
                                     groups=groups,
                                     bias=False)
        self.conv2 = torch.nn.Conv1d(out_ch,
                                     out_ch,
                                     kernel_size=kernel_size,
                                     padding=padding,
                                     dilation=1,
                                     groups=groups,
                                     bias=False)

        if depthwise:
            self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)
            self.conv2b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)

        self.bn1 = torch.nn.BatchNorm1d(in_ch)

        if conditional:
            # 64 is the size of the conditioning vector the block expects in forward()
            self.film = FiLM(out_ch, 64)
        else:
            self.bn2 = torch.nn.BatchNorm1d(out_ch)

        self.relu1 = torch.nn.LeakyReLU()
        self.relu2 = torch.nn.LeakyReLU()

        # 1x1 projection so the skip path has the same channel count as the output
        self.res = torch.nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)

    def forward(self, x, p=None):
        """Run the block.

        Args:
            x: Input feature map; fed to BatchNorm1d/Conv1d with ``in_ch`` channels.
            p: Conditioning tensor for the FiLM layer. Required when the block was
                built with ``conditional=True``; ignored otherwise.

        Returns:
            The block output with ``out_ch`` channels.

        Raises:
            ValueError: If the block is conditional and ``p`` is ``None``.
        """
        x_in = x

        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv1(x)

        if self.depthwise: # apply pointwise conv
            x = self.conv1b(x)

        # choose the normalisation that was actually created in __init__, rather than
        # guessing from the presence of `p`
        if self.conditional:   # apply FiLM conditioning
            if p is None:
                raise ValueError("TCNBlock was built with conditional=True but no conditioning tensor `p` was given.")
            x = self.film(x, p)
        else:
            x = self.bn2(x)

        x = self.relu2(x)
        x = self.conv2(x)

        if self.depthwise:
            x = self.conv2b(x)

        # the convolutions shorten the signal; crop the skip path to match
        x_res = self.res(x_in)
        x = x + center_crop(x_res, x.shape)

        return x

class TCNModel(pl.LightningModule):
    """ Temporal convolutional network with conditioning module.

        Params:
            nparams (int): Number of conditioning parameters.
            ninputs (int): Number of input channels (mono = 1, stereo 2). Default: 1
            noutputs (int): Number of output channels (mono = 1, stereo 2). Default: 1
            nblocks (int): Number of total TCN blocks. Default: 10
            kernel_size (int): Width of the convolutional kernels. Default: 3
            dilation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
            channel_growth (int): Compute the output channels at each block as in_ch * channel_growth. Default: 1
            channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64
            stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10
            depthwise (bool): Use depthwise-separable convolutions to reduce the total number of parameters. Default: False
            **kwargs: Stored as hyperparameters. The training/validation code reads
                ``train_loss``, ``lr`` and ``sample_rate`` from here.
        """
    def __init__(self,
                 nparams,
                 ninputs=1,
                 noutputs=1,
                 nblocks=10,
                 kernel_size=3,
                 dilation_growth=1,
                 channel_growth=1,
                 channel_width=64,
                 stack_size=10,
                 depthwise=False,
                 **kwargs):
        super(TCNModel, self).__init__()

        # makes every constructor argument (incl. **kwargs) available as self.hparams.<name>
        self.save_hyperparameters()

        # setup loss functions
        self.l1      = torch.nn.L1Loss()
        self.esr     = auraloss.time.ESRLoss()
        self.dc      = auraloss.time.DCLoss()
        self.logcosh = auraloss.time.LogCoshLoss()
        self.stft    = auraloss.freq.STFTLoss()
        self.mrstft  = auraloss.freq.MultiResolutionSTFTLoss()
        self.rrstft  = auraloss.freq.RandomResolutionSTFTLoss()

        # small MLP that lifts the raw parameters to the 64-dim conditioning vector
        if nparams > 0:
            self.gen = torch.nn.Sequential(
                torch.nn.Linear(nparams, 16),
                torch.nn.PReLU(),
                torch.nn.Linear(16, 32),
                torch.nn.PReLU(),
                torch.nn.Linear(32, 64),
                torch.nn.PReLU()
            )

        self.blocks = torch.nn.ModuleList()
        for n in range(nblocks):
            in_ch = out_ch if n > 0 else ninputs
            out_ch = in_ch * channel_growth if channel_growth > 1 else channel_width

            dilation = dilation_growth ** (n % stack_size)
            self.blocks.append(TCNBlock(in_ch,
                                        out_ch,
                                        kernel_size=kernel_size,
                                        dilation=dilation,
                                        depthwise=self.hparams.depthwise,
                                        conditional=nparams > 0))

        self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)

    def forward(self, x, p=None):
        """Process ``x`` through all blocks, conditioned on ``p`` when the model has a
        conditioning network (``nparams > 0``) and ``p`` is given."""
        # if parameters present,
        # compute global conditioning
        # (a model built with nparams == 0 has no `gen` module, so `p` is ignored there)
        if p is not None and self.hparams.nparams > 0:
            cond = self.gen(p)
        else:
            cond = None

        # iterate over blocks passing conditioning
        for block in self.blocks:
            x = block(x, cond)

        return self.output(x)

    def compute_receptive_field(self):
        """ Compute the receptive field in samples.

        Every block holds two convolutions: the first uses the block's dilation,
        the second always uses dilation 1. Each convolution widens the receptive
        field by ``(kernel_size - 1) * dilation``.
        """
        rf = 1
        for n in range(self.hparams.nblocks):
            dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
            rf = rf + ((self.hparams.kernel_size-1) * dilation)  # first (dilated) conv
            rf = rf + ((self.hparams.kernel_size-1) * 1)         # second conv
        return rf

    def training_step(self, batch, batch_idx):
        """Compute and log the loss selected by ``hparams.train_loss`` for one batch."""
        input, target, params = batch

        # pass the input through the model
        pred = self(input, params)

        # crop the target signal (the un-padded convolutions shorten the output)
        target = center_crop(target, pred.shape)

        # compute the error using appropriate loss
        if   self.hparams.train_loss == "l1":
            loss = self.l1(pred, target)
        elif self.hparams.train_loss == "esr+dc":
            loss = self.esr(pred, target) + self.dc(pred, target)
        elif self.hparams.train_loss == "logcosh":
            loss = self.logcosh(pred, target)
        elif self.hparams.train_loss == "stft":
            loss = torch.stack(self.stft(pred, target),dim=0).sum()
        elif self.hparams.train_loss == "mrstft":
            loss = torch.stack(self.mrstft(pred, target),dim=0).sum()
        elif self.hparams.train_loss == "rrstft":
            loss = torch.stack(self.rrstft(pred, target),dim=0).sum()
        else:
            raise NotImplementedError(f"Invalid loss fn: {self.hparams.train_loss}")

        self.log('train_loss',
                 loss,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log every available loss for one validation batch and return the batch
        (input, cropped target, prediction) as numpy arrays."""
        input, target, params = batch

        # pass the input through the model
        pred = self(input, params)

        # crop the target signal
        target = center_crop(target, pred.shape)

        # compute the validation error using all losses
        l1_loss      = self.l1(pred, target)
        esr_loss     = self.esr(pred, target)
        dc_loss      = self.dc(pred, target)
        logcosh_loss = self.logcosh(pred, target)
        stft_loss    = torch.stack(self.stft(pred, target),dim=0).sum()
        mrstft_loss  = torch.stack(self.mrstft(pred, target),dim=0).sum()
        rrstft_loss  = torch.stack(self.rrstft(pred, target),dim=0).sum()

        aggregate_loss = l1_loss + \
                         esr_loss + \
                         dc_loss + \
                         logcosh_loss + \
                         mrstft_loss + \
                         stft_loss + \
                         rrstft_loss

        self.log('val_loss', aggregate_loss)
        self.log('val_loss/L1', l1_loss)
        self.log('val_loss/ESR', esr_loss)
        self.log('val_loss/DC', dc_loss)
        self.log('val_loss/LogCosh', logcosh_loss)
        self.log('val_loss/STFT', stft_loss)
        self.log('val_loss/MRSTFT', mrstft_loss)
        self.log('val_loss/RRSTFT', rrstft_loss)

        # move tensors to cpu for logging
        outputs = {
            "input" : input.cpu().numpy(),
            "target": target.cpu().numpy(),
            "pred"  : pred.cpu().numpy()}

        return outputs

    def validation_epoch_end(self, validation_step_outputs):
        """Log one input/target/prediction audio example to the experiment logger."""
        # nothing to log when no validation batch was produced
        if not validation_step_outputs:
            return

        # merge the per-step dicts; later batches overwrite earlier ones, so this
        # ends up holding the arrays of the last batch
        outputs = {k: v for d in validation_step_outputs for k, v in d.items()}

        i = outputs["input"][0].squeeze()
        c = outputs["target"][0].squeeze()
        p = outputs["pred"][0].squeeze()

        # log audio examples
        self.logger.experiment.add_audio("input", i, self.global_step, sample_rate=self.hparams.sample_rate)
        self.logger.experiment.add_audio("target", c, self.global_step, sample_rate=self.hparams.sample_rate)
        self.logger.experiment.add_audio("pred",   p, self.global_step, sample_rate=self.hparams.sample_rate)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    # add any model hyperparameters here
    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser that extends ``parent_parser`` with the model and training options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # --- model related ---
        parser.add_argument('--ninputs', type=int, default=1)
        parser.add_argument('--noutputs', type=int, default=1)
        parser.add_argument('--nblocks', type=int, default=10)
        parser.add_argument('--kernel_size', type=int, default=3)
        parser.add_argument('--dilation_growth', type=int, default=1)
        parser.add_argument('--channel_growth', type=int, default=1)
        parser.add_argument('--channel_width', type=int, default=64)
        parser.add_argument('--stack_size', type=int, default=10)
        parser.add_argument('--depthwise', default=False, action='store_true')
        # --- training related ---
        parser.add_argument('--lr', type=float, default=1e-3)
        parser.add_argument('--train_loss', type=str, default="l1")

        return parser

# Training

In [None]:
# add PROGRAM level args
#root_dir = '/content/drive/My Drive/SignalTrain_LA2A_Dataset_1.1'
root_dir = '/content/data/SignalTrain_LA2A_Dataset_1.1'
sample_rate = 44100
shuffle = True          # applies to the training loader only
train_subset = "train"
val_subset = "val"
train_length = 16384    # samples per training example
eval_length = 262144    # samples per validation example
batch_size = 128
num_workers = 0
preload = False
precision = 16

# init the trainer and model
trainer = pl.Trainer(gpus=1, precision=precision)

# setup the dataloaders
train_dataset = SignalTrainLA2ADataset(root_dir,
                                      subset=train_subset,
                                      length=train_length,
                                      preload=preload)

train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=shuffle,
                                               batch_size=batch_size,
                                               num_workers=num_workers)

val_dataset = SignalTrainLA2ADataset(root_dir,
                                    subset=val_subset,
                                    length=eval_length,
                                    preload=preload)

# never shuffle validation data: a fixed order means each epoch is evaluated on the
# batches in the same order and the logged audio example stays comparable across epochs
val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                             shuffle=False,
                                             batch_size=batch_size,
                                             num_workers=num_workers)

dict_args = {
      "nparams" : 2,
      "ninputs" : 1,
      "noutputs" : 1,
      "nblocks" : 10,
      "kernel_size": 3,
      "dilation_growth" : 1,
      "channel_growth" : 1,
      "channel_width" : 64,
      "stack_size" : 10,
      "depthwise" : False,
      "lr" : 0.001,
      "sample_rate" : sample_rate,
      "train_loss" : "mrstft"
}
model = TCNModel(**dict_args)

# find proper learning rate
#trainer.tune(model, train_dataloader)

#torchsummary.summary(model, [(1,eval_length), (1,2)])

# NOTE(review): only the single-resolution STFT loss is moved to the GPU by hand here;
# confirm whether this is still needed (and whether the other losses need the same)
device = "cuda:0"
model.stft.to(device)

# train!
trainer.fit(model, train_dataloader, val_dataloader)


In [None]:
 # Start tensorboard inline, reading the event files written under lightning_logs/.
 # (%reload_ext rather than %load_ext so that re-running the cell does not complain)
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/