In [2]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

from efficientnet_pytorch.model import EfficientNet
from kymatio.torch import TimeFrequencyScattering1D
from kymjtfs.batch_norm import ScatteringBatchNorm

class MedleySolosClassifier(LightningModule):
    def __init__(self, in_shape = 2**17, J = 12, Q = 16, F = 4, T = 2**12):
        super().__init__()

        self.in_shape = in_shape
        self.J = J
        self.Q = Q
        self.F = F
        self.T = T
        
        self.s1_conv1 = nn.Conv2d(1, 4, kernel_size=(16, 1)).cuda()
        
        self.setup_jtfs()
        
        self.conv_net = EfficientNet.from_name('efficientnet-b0',
                                               in_channels=self.jtfs_dim,
                                               include_top = True,
                                               num_classes = 8).cuda()

    def setup_jtfs(self):
        self.jtfs = TimeFrequencyScattering1D(
            shape=(self.in_shape, ),
            T=self.T,
            Q=self.Q,
            J=self.J,
            F=self.F,
            average_fr=True,
            max_pad_factor=1, 
            max_pad_factor_fr=1,
            out_3D=True,).cuda()
        
        n_channels = self._get_jtfs_out_dim()
        
        self.jtfs_dim = self._get_jtfs_out_dim()
        self.jtfs_bn = ScatteringBatchNorm(self.jtfs_dim).cuda()
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def forward(self, x):
        Sx = self.jtfs(x)
        
        s1, s2 = Sx[0], Sx[1]
        s1 = self._s1_forward(s1)
        s1 = F.pad(s1, 
                   (0, 0, s2.shape[-2] - s1.shape[-2], 0))
        
        sx = torch.cat([s1, s2], dim=1)[:, :, :32, :]
        sx = self.jtfs_bn(sx)
        y = self.conv_net(sx)
        y = F.log_softmax(y, dim=1)
        return y
    
    def configure_optimizers(self):
        opt = Adam(self.parameters(), lr=1e-3)
        scheduler = CosineAnnealingLR(opt, T_max=10)
        return [opt], [scheduler]
    
    def _get_jtfs_out_dim(self):
        dummy_in = torch.randn(self.in_shape).cuda()
        sx = self.jtfs(dummy_in)
        s1 = self._s1_forward(sx[0])
        s1 = F.pad(s1, (0, 0, sx[1].shape[-2] - s1.shape[-2], 0))
        S = torch.cat([s1, sx[1]], dim=1)
        out_dim = S.size(1)
        return out_dim
        
    def _s1_forward(self, s1):
        s1 = s1.unsqueeze(1)
        return F.avg_pool2d(F.relu(self.s1_conv1(s1)), 
                            kernel_size=(4, 1), 
                            padding=(2, 0))

In [11]:
from torch.utils.data import Dataset, DataLoader

class MedleySolosDB(Dataset):
    def __init__(self, csv_path='/import/c4dm-datasets/', fold='training'):
        super().__init__()

        path = os.path.join(csv_path, 'Medley-solos-DB_metadata.csv')
        df = pd.read_csv(path)
        df = df.where(df['subset'] == fold)

    def __getitem__(self):
#                 y = 'instrument_id'
       raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
        

class MedleyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: Optional[str] = None):
        self.train_ds = MedleySolosDB(fold='training')
        self.val_ds = MedleySolosDB(fold='validation')
        self.test_ds = MedleySolosDB(fold='test')

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size)

    def teardown(self, stage: Optional[str] = None):
        # Used to clean-up when the run is finished
        ...

In [3]:
import os
from pytorch_lightning import Trainer

train_loader = MedleyDBSolos()
train_ds = DataLoader(train_loader, batch_size=64)

NameError: name 'MedleyDBSolos' is not defined

In [4]:
model = MedleySolosClassifier()

In [5]:
model(torch.randn(4, 2**17))

tensor([[-2.4746, -2.0908, -2.0086, -1.9890, -2.5728, -2.0200, -1.5243, -2.3595],
        [-2.0349, -2.1241, -2.3771, -2.2438, -2.6342, -1.5306, -1.7836, -2.3573],
        [-2.4869, -1.9968, -1.4069, -2.1504, -2.2895, -2.4190, -1.9807, -2.3922],
        [-1.8832, -2.3211, -1.8516, -1.9941, -2.3137, -1.9677, -2.8965, -1.8157]],
       device='cuda:0', grad_fn=<LogSoftmaxBackward0>)

In [7]:
import pandas as pd
path = '/import/c4dm-datasets/Medley-solos-DB_metadata.csv'


Unnamed: 0,subset,instrument,instrument_id,track_id,uuid4
0,test,clarinet,0,0,0e4371ac-1c6a-51ab-fdb7-f8abd5fbf1a3
1,test,clarinet,0,0,33383119-fd64-59c1-f596-d1a23e8a0eff
2,test,clarinet,0,0,b2b7a288-e169-5642-fced-b509c06b11fc
3,test,clarinet,0,0,151b6ee4-313a-58d9-fbcb-bab73e0d426b
4,test,clarinet,0,0,b43999d1-9b5e-557f-f9bc-1b3759659858
...,...,...,...,...,...
21566,validation,violin,7,226,fe4e8e98-6e0f-5a31-f446-99c10e0ac485
21567,validation,violin,7,226,aa606c78-9ee5-507f-f7e9-67c3530faf0f
21568,validation,violin,7,226,05e15c0a-d530-5f3e-fa82-58c55fa44993
21569,validation,violin,7,226,2dd485de-471d-5d8b-fe92-ef957dac021c


In [11]:
df = pd.read_csv(path)
df['subset'].value_counts()

test          12236
training       5841
validation     3494
Name: subset, dtype: int64