In [2]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

from efficientnet_pytorch import EfficientNet
from kymatio.torch import TimeFrequencyScattering1D
from kymjtfs.batch_norm import ScatteringBatchNorm

class MedleySolosClassifier(LightningModule):
    """8-class classifier built from joint time-frequency scattering features.

    Pipeline (see ``forward``): the input is passed through
    ``TimeFrequencyScattering1D``; the first element of its output is sent
    through a small conv/ReLU/average-pool stage, padded to match the second
    element, concatenated with it along dim 1, batch-normalised with
    ``ScatteringBatchNorm`` and finally classified by an ``EfficientNet``.
    The network returns log-probabilities, which ``training_step`` consumes
    with ``nll_loss``.

    Args:
        in_shape: Number of samples of the 1-D input signal.
        J, Q, F, T: Forwarded unchanged to ``TimeFrequencyScattering1D``.
    """

    def __init__(self, in_shape = 2**17, J = 12, Q = 16, F = 4, T = 2**12):
        super().__init__()

        self.in_shape = in_shape
        self.J = J
        self.Q = Q
        # NOTE: inside ``__init__`` the parameter ``F`` shadows the module-level
        # ``torch.nn.functional as F`` alias. The name is part of the public
        # signature, so it is kept; just do not call functional ops here.
        self.F = F
        self.T = T

        # Must exist before ``setup_jtfs``: the output-size probe run there
        # goes through ``_s1_forward``, which uses this layer.
        self.s1_conv1 = nn.Conv2d(1, 4, kernel_size=(16, 1))

        self.setup_jtfs()

        # NOTE(review): efficientnet_pytorch models are normally built with
        # ``EfficientNet.from_name`` / ``from_pretrained``; confirm that the
        # installed version accepts a plain dict for ``global_params`` without
        # ``blocks_args`` -- TODO verify.
        self.conv_net = EfficientNet(global_params = {'include_top': True, 'num_classes': 8})

    def setup_jtfs(self):
        """Create the scattering transform and its batch-norm layer.

        The batch-norm width is obtained by probing the scattering front end
        once with a dummy signal (see ``_get_jtfs_out_dim``).
        """
        self.jtfs = TimeFrequencyScattering1D(
            shape=(self.in_shape, ),
            T=self.T,
            Q=self.Q,
            J=self.J,
            F=self.F,
            average_fr=True,
            max_pad_factor=1,
            max_pad_factor_fr=1,
            out_3D=True,)

        # Probe once only: the dummy forward pass is not free.
        out_dim = self._get_jtfs_out_dim()
        self.jtfs_bn = ScatteringBatchNorm(out_dim)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)  # log-probabilities, see ``forward``
        loss = F.nll_loss(logits, y)
        return loss

    def forward(self, x):
        """Map input ``x`` to per-class log-probabilities.

        Args:
            x: Input passed directly to the scattering transform.

        Returns:
            ``log_softmax`` (over dim 1) of the EfficientNet output.
        """
        Sx = self.jtfs(x)

        # Keep only the first 32 entries along dim 2 of the merged features.
        sx = self._merge_orders(Sx)[:, :, :32, :]
        sx = self.jtfs_bn(sx)
        y = self.conv_net(sx)
        y = F.log_softmax(y, dim=1)
        return y

    def configure_optimizers(self):
        """Adam (lr=1e-3) with a cosine-annealing schedule (``T_max=10``)."""
        # Fully-qualified names: ``Adam`` / ``CosineAnnealingLR`` are not
        # imported anywhere in this notebook.
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [scheduler]

    def _get_jtfs_out_dim(self):
        """Return the size of dim 1 of the merged scattering features.

        Runs the scattering transform and the merge step on a random dummy
        signal of length ``in_shape``.
        """
        dummy_in = torch.randn(self.in_shape)
        # Shape probe only; no gradients are needed.
        with torch.no_grad():
            merged = self._merge_orders(self.jtfs(dummy_in))
        return merged.size(1)

    def _merge_orders(self, Sx):
        """Concatenate the processed ``Sx[0]`` with ``Sx[1]`` along dim 1.

        ``Sx[0]`` goes through ``_s1_forward`` and is then padded at the start
        of its second-to-last axis so that axis matches ``Sx[1]``. The pad
        amount is computed from the *processed* tensor, since that is the one
        being concatenated.
        """
        s1 = self._s1_forward(Sx[0])
        s2 = Sx[1]
        # F.pad takes pairs starting from the last axis: (0, 0) leaves the last
        # axis alone, (missing, 0) pads the front of the second-to-last axis.
        missing = s2.shape[-2] - s1.shape[-2]
        s1 = F.pad(s1, (0, 0, missing, 0))
        return torch.cat([s1, s2], dim=1)

    def _s1_forward(self, s1):
        """Conv -> ReLU -> average-pool stage applied to ``Sx[0]``.

        A singleton axis is inserted at position 1 to serve as the single
        input channel of ``s1_conv1`` (1 -> 4 channels, kernel ``(16, 1)``);
        the result is average-pooled with kernel ``(4, 1)`` and padding
        ``(2, 0)``.
        """
        s1 = s1.unsqueeze(1)
        return F.avg_pool2d(F.relu(self.s1_conv1(s1)),
                            kernel_size=(4, 1),
                            padding=(2, 0))

ModuleNotFoundError: No module named 'kymjtfs'

In [None]:
from torch.utils.data import Dataset, DataLoader

class MedleySolosDB(Dataset):
    """Dataset stub; item access and length are not implemented yet.

    Args:
        csv_path: Path to a CSV file describing the dataset. It is stored on
            the instance as ``csv_path``; nothing is read from it yet.
    """

    def __init__(self, csv_path):
        super().__init__()
        # Keep the path so the eventual implementation can load from it.
        self.csv_path = csv_path

    def __getitem__(self, index):
        """Return the item at ``index``.

        The ``index`` parameter is required by the indexing protocol: without
        it ``dataset[i]`` fails with ``TypeError`` before reaching this body.

        Raises:
            NotImplementedError: Always, until the dataset is implemented.
        """
        raise NotImplementedError

    def __len__(self):
        """Return the number of items.

        Raises:
            NotImplementedError: Always, until the dataset is implemented.
        """
        raise NotImplementedError

In [None]:
import os
from pytorch_lightning import Trainer

# Location of the dataset metadata CSV. The default file name is a placeholder
# -- TODO confirm; override it with the MEDLEY_SOLOS_CSV environment variable.
csv_path = os.environ.get("MEDLEY_SOLOS_CSV", "Medley-solos-DB_metadata.csv")

# The dataset class defined in this notebook is ``MedleySolosDB`` and it
# requires ``csv_path``.
# NOTE: the names are historical -- ``train_loader`` holds the Dataset and
# ``train_ds`` holds the DataLoader. They are kept as-is so that other cells
# referring to them keep working.
train_loader = MedleySolosDB(csv_path)
train_ds = DataLoader(train_loader, batch_size=64)

In [None]:
# Instantiate the classifier with the constructor's default arguments; as the
# last expression in the cell, its value is what the notebook displays.
MedleySolosClassifier()