In [60]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torchmetrics
from efficientnet_pytorch.model import EfficientNet
from kymatio.torch import TimeFrequencyScattering1D

from kymjtfs.batch_norm import ScatteringBatchNorm

class MedleySolosClassifier(LightningModule):
    """Medley-solos-DB instrument classifier: JTFS front end + EfficientNet-b0.

    A kymatio ``TimeFrequencyScattering1D`` transform is applied to the input
    waveform. Its first-order output goes through a small learned conv/pool
    stage (``_s1_forward``), is zero-padded at the start of dim -2 to match the
    second-order output, and the two are concatenated along dim 1. The result is
    normalised with ``ScatteringBatchNorm`` and classified by an EfficientNet-b0
    with 8 output classes. ``forward`` returns log-probabilities, which ``step``
    pairs with ``nll_loss``.

    Args:
        in_shape: number of samples per waveform (JTFS ``shape``).
        J, Q, F, T: JTFS hyper-parameters forwarded to ``TimeFrequencyScattering1D``.
            NOTE: inside ``__init__`` the ``F`` argument shadows the module-level
            ``torch.nn.functional as F``; the other methods still use the
            functional module.
        lr: Adam learning rate.

    NOTE(review): every submodule (and the shape-probe input) is moved to CUDA
    with an explicit ``.cuda()``, so constructing this class requires a GPU.
    """

    def __init__(self, in_shape = 2**16, J = 12, Q = 16, F = 4, T = 2**11, lr=1e-3):
        super().__init__()

        self.in_shape = in_shape
        self.J = J
        self.Q = Q
        self.F = F
        self.T = T

        self.lr = lr

        # Learned stage for first-order coefficients (1 -> 4 channels). It must
        # exist before setup_jtfs(), whose shape probe calls _s1_forward().
        self.s1_conv1 = nn.Conv2d(1, 4, kernel_size=(16, 1)).cuda()

        self.setup_jtfs()

        self.conv_net = EfficientNet.from_name('efficientnet-b0',
                                               in_channels=self.jtfs_dim,
                                               include_top = True,
                                               num_classes = 8).cuda()

        self.acc_metric = torchmetrics.Accuracy()

    def setup_jtfs(self):
        """Build the JTFS transform, infer its merged channel count, and create the batch-norm."""
        self.jtfs = TimeFrequencyScattering1D(
            shape=(self.in_shape, ),
            T=self.T,
            Q=self.Q,
            J=self.J,
            F=self.F,
            average_fr=True,
            max_pad_factor=1,
            max_pad_factor_fr=1,
            out_3D=True,).cuda()

        # The shape probe runs a full JTFS forward pass, so run it only once.
        self.jtfs_dim = self._get_jtfs_out_dim()
        self.jtfs_bn = ScatteringBatchNorm(self.jtfs_dim).cuda()

    def forward(self, x):
        """Classify a batch of waveforms.

        Args:
            x: waveform batch; presumably (batch, in_shape) on the CUDA device
                (TODO confirm against the dataloader's output).

        Returns:
            Log-probabilities over the 8 classes (softmax taken over dim 1).
        """
        Sx = self.jtfs(x)
        sx = self._merge_orders(Sx[0], Sx[1])
        # NOTE(review): keeps only the first 32 entries of dim 2. This magic number's
        # origin is not visible here; confirm it is intentional.
        sx = sx[:, :, :32, :]
        sx = self.jtfs_bn(sx)
        y = self.conv_net(sx)
        return F.log_softmax(y, dim=1)

    def step(self, batch):
        """Shared train/val/test step: return the batch NLL loss and accuracy."""
        x, y = batch
        logits = self(x)  # log-probabilities, matching nll_loss below

        loss, acc = F.nll_loss(logits, y), self.acc_metric(logits, y)
        return {'loss': loss, 'acc': acc}

    def log_metrics(self, outputs, fold):
        """Average every per-batch metric over the epoch and log it as '<fold>/<name>'.

        Args:
            outputs: list of dicts returned by ``step``, one per batch.
            fold: stage prefix: 'train', 'val' or 'test'.
        """
        if not outputs:
            # No batches ran (e.g. an empty loader): there is nothing to average.
            return
        for k in outputs[0]:
            metric = torch.stack([x[k] for x in outputs]).mean()
            # LightningModule exposes ``self.log``; ``log_metric`` does not exist.
            # The 'val/loss' key logged here is what EarlyStopping monitors.
            self.log(f'{fold}/{k}', metric)

    def training_step(self, batch, batch_idx):
        return self.step(batch)

    def training_epoch_end(self, outputs):
        self.log_metrics(outputs, 'train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch)

    def validation_epoch_end(self, outputs):
        self.log_metrics(outputs, 'val')

    def test_step(self, batch, batch_idx):
        return self.step(batch)

    def test_epoch_end(self, outputs):
        self.log_metrics(outputs, 'test')

    def configure_optimizers(self):
        """Optimise every parameter with Adam at ``self.lr``."""
        opt = torch.optim.Adam(self.parameters(), lr=self.lr)
        return opt

    def _get_jtfs_out_dim(self):
        """Return the channel count (dim 1) of the merged JTFS representation.

        Probes the transform with one random waveform of length ``in_shape``.
        Only the shape is needed, so no autograd graph is built.
        """
        dummy_in = torch.randn(self.in_shape).cuda()
        with torch.no_grad():
            sx = self.jtfs(dummy_in)
            merged = self._merge_orders(sx[0], sx[1])
        return merged.size(1)

    def _merge_orders(self, s1, s2):
        """Run s1 through the learned stage, pad it to s2's size along dim -2, and concatenate along dim 1."""
        s1 = self._s1_forward(s1)
        # Pads at the start of dim -2. A negative amount would crop instead,
        # if s1 ends up larger than s2 along that dim.
        s1 = F.pad(s1, (0, 0, s2.shape[-2] - s1.shape[-2], 0))
        return torch.cat([s1, s2], dim=1)

    def _s1_forward(self, s1):
        """Learned first-order stage.

        Adds a channel axis, applies ``s1_conv1`` and ReLU, then average-pools
        with kernel (4, 1) and padding (2, 0).
        """
        s1 = s1.unsqueeze(1)
        return F.avg_pool2d(F.relu(self.s1_conv1(s1)),
                            kernel_size=(4, 1),
                            padding=(2, 0))

### Data Loader

In [75]:
from typing import Optional

from torch.utils.data import Dataset, DataLoader
import mirdata.datasets.medley_solos_db as msdb
import pytorch_lightning as pl

class MedleySolosDB(Dataset):
    """Map-style dataset over one split of Medley-solos-DB.

    Reads ``annotation/Medley-solos-DB_metadata.csv`` under ``data_dir`` and keeps
    the rows whose ``subset`` column equals ``subset``. Each item's wav file is
    loaded from ``audio/`` on access.

    Args:
        data_dir: dataset root containing ``audio/`` and ``annotation/``.
        subset: value of the metadata ``subset`` column to keep. This notebook
            uses 'training', 'validation' and 'test'.

    Raises:
        ValueError: if no metadata row matches ``subset`` (e.g. a typo such as
            'train'), rather than silently building an empty dataset.

    NOTE: the methods use the notebook-level names ``os``, ``pd`` and ``msdb``.
    """
    def __init__(self, data_dir='/import/c4dm-datasets/medley-solos-db/', subset='training'):
        super().__init__()

        self.msdb = msdb.Dataset(data_dir)
        self.audio_dir = os.path.join(data_dir, 'audio')
        self.csv_dir = os.path.join(data_dir, 'annotation')
        self.subset = subset

        df = pd.read_csv(os.path.join(self.csv_dir, 'Medley-solos-DB_metadata.csv'))
        # Build a new frame rather than resetting the index in place on the
        # boolean-indexed slice. The resulting columns are the same.
        self.df = df.loc[df['subset'] == subset].reset_index()
        if self.df.empty:
            raise ValueError(
                f"no metadata rows with subset={subset!r}; "
                f"available subsets: {df['subset'].unique().tolist()}")

    def build_audio_fname(self, df_item):
        """Return the wav filename for one metadata row: 'Medley-solos-DB_<subset>-<instrument_id>_<uuid4>.wav'."""
        uuid = df_item['uuid4']
        instr_id = df_item['instrument_id']
        subset = df_item['subset']
        return f'Medley-solos-DB_{subset}-{instr_id}_{uuid}.wav'

    def __getitem__(self, idx):
        """Return ``(audio, label)`` for row ``idx``, where label is the integer ``instrument_id``."""
        item = self.df.iloc[idx]
        audio_fname = self.build_audio_fname(item)
        # The second value returned by load_audio is ignored (presumably the sample rate).
        audio, _ = msdb.load_audio(os.path.join(self.audio_dir, audio_fname))
        y = int(item['instrument_id'])

        return audio, y

    def __len__(self):
        return len(self.df)
        

class MedleyDataModule(pl.LightningDataModule):
    """LightningDataModule serving the training/validation/test splits of Medley-solos-DB.

    Args:
        data_dir: dataset root passed to ``MedleySolosDB``.
        batch_size: batch size shared by all three loaders.
        num_workers: DataLoader worker processes. Each item reads a wav file from
            disk, so values > 0 let loading overlap with compute. The default 0
            loads in the main process, which is DataLoader's own default.
    """
    def __init__(self,
                 data_dir: str = '/import/c4dm-datasets/medley-solos-db/',
                 batch_size: int = 32,
                 num_workers: int = 0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):
        """Instantiate all three splits, regardless of ``stage``."""
        self.train_ds = MedleySolosDB(self.data_dir, subset='training')
        self.val_ds = MedleySolosDB(self.data_dir, subset='validation')
        self.test_ds = MedleySolosDB(self.data_dir, subset='test')

    def _make_loader(self, ds, shuffle: bool) -> DataLoader:
        """Build a DataLoader with the module-wide batch size and worker count."""
        return DataLoader(ds, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_ds, shuffle=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean-up hook called when the run is finished; nothing to release here."""
        ...

In [46]:
import os

import pandas as pd

# Peek at the test split of the metadata table. `os` is imported here because,
# in notebook order, this cell runs before the later cell that imports it.
# Without this import, Restart & Run All raises NameError.
data_dir = '/import/c4dm-datasets/medley-solos-db/'
csv_dir = os.path.join(data_dir, 'annotation')
metadata_df = pd.read_csv(os.path.join(csv_dir, 'Medley-solos-DB_metadata.csv'))
metadata_df.loc[metadata_df['subset'] == 'test']

Unnamed: 0,subset,instrument,instrument_id,song_id,uuid4
0,test,clarinet,0,0,0e4371ac-1c6a-51ab-fdb7-f8abd5fbf1a3
1,test,clarinet,0,0,33383119-fd64-59c1-f596-d1a23e8a0eff
2,test,clarinet,0,0,b2b7a288-e169-5642-fced-b509c06b11fc
3,test,clarinet,0,0,151b6ee4-313a-58d9-fbcb-bab73e0d426b
4,test,clarinet,0,0,b43999d1-9b5e-557f-f9bc-1b3759659858
...,...,...,...,...,...
12231,test,violin,7,138,508f5f17-ab4e-5701-fd56-d9b77f10b877
12232,test,violin,7,138,5c2fc205-dd93-57e2-ffe9-1390d86c5a42
12233,test,violin,7,138,d6131a7f-7823-5202-f4e8-50b07071133c
12234,test,violin,7,138,64eb5555-a916-5458-f19b-03a908ec7122


### Mock Test

In [51]:
import os
from pytorch_lightning import Trainer

# Smoke test: build the training dataset and loader, then push one random batch
# through the model to check that shapes line up end to end.
train_ds = MedleySolosDB()
train_loader = DataLoader(train_ds, batch_size=64)
model = MedleySolosClassifier()
# The model's submodules are moved to CUDA in __init__, so the input must be too.
# This is only a shape check, so no autograd graph is needed.
with torch.no_grad():
    smoke_out = model(torch.randn(4, model.in_shape, device='cuda'))
smoke_out.shape

### Train and Eval

In [76]:
# Train with early stopping on validation loss, then evaluate on the test split.
SEED = 42
n_epochs = 200

pl.seed_everything(SEED)

# Loss must be minimised. With mode="max", training would stop as soon as
# the validation loss stopped *increasing*.
early_stop_callback = EarlyStopping(monitor="val/loss",
                                    min_delta=0.00,
                                    patience=3,
                                    verbose=False,
                                    mode="min")
trainer = pl.Trainer(gpus=-1,
                     max_epochs=n_epochs,
                     progress_bar_refresh_rate=1,
                     checkpoint_callback=True,
                     callbacks=[early_stop_callback])
model, datamodule = MedleySolosClassifier(), MedleyDataModule()
trainer.fit(model, datamodule)
trainer.test(datamodule=datamodule)

  rank_zero_deprecation(
  rank_zero_deprecation(
INFO: GPU available: True, used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [7]
INFO: 
  | Name       | Type                      | Params
---------------------------------------------------------
0 | s1_conv1   | Conv2d                    | 68    
1 | jtfs       | TimeFrequencyScattering1D | 0     
2 | jtfs_bn    | ScatteringBatchNorm       | 238   
3 | conv_net   | EfficientNet              | 4.1 M 
4 | acc_metric | Accuracy                  | 0     
---------------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.343    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(


RuntimeError: CUDA out of memory. Tried to allocate 8.00 GiB (GPU 0; 23.69 GiB total capacity; 18.03 GiB already allocated; 3.51 GiB free; 18.08 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF