In [10]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler
import pytorch_lightning as pl
from torchmetrics.functional import accuracy, matthews_corrcoef
from torchsummaryX import summary
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

from utils import sound_utils

In [2]:
# Training configuration.
# NOTE(review): these values are not passed to the model / data module in the
# cells visible here (AudioCRNN hard-codes or defaults its own lr/eps) — keep in sync.
batch_size = 64
epochs = 20
learning_rate = 1e-4
# Backward-compatible alias for the original misspelled name.
learing_rate = learning_rate
# Epsilon for the Adam optimizer (same value AudioCRNN uses).
eps = 1e-8

# Load the data and prepare the DataLoaders

In [59]:
# Train/test split table (presumably lists the wav file names per split —
# TODO confirm against the CSV).
TRAIN_TEST_SPLIT_CSV = 'data/train_test_split.csv'
df = pd.read_csv(TRAIN_TEST_SPLIT_CSV)

In [61]:
class WavDataset(Dataset):
    """Dataset of cat/dog wav files yielding ``(features, label)`` pairs.

    The label is taken from the first three characters of each file name
    (``'cat'`` -> 0, ``'dog'`` -> 1); an unknown prefix raises ``KeyError``.
    Features are whatever ``sound_utils.SoundUtil().convert_sound`` returns
    for the file.

    Args:
        wave_list: Sequence of wav file names, relative to ``data_path``.
        data_path: Directory that contains the wav files.
        is_augment: Forwarded to ``convert_sound``; defaults to ``True``,
            matching the previously hard-coded behavior.
    """

    def __init__(self, wave_list, data_path='data/cats_dogs/', is_augment=True):
        super().__init__()
        self.wav_list = wave_list
        self.labels_index = {
            'cat': 0,
            'dog': 1
        }
        self.data_path = data_path
        # NOTE(review): not read by this class; kept for any external users.
        self.max_ms = 8000
        self.is_augment = is_augment
        self.audio_utils = sound_utils.SoundUtil()

    def __len__(self):
        """Number of wav files in the dataset."""
        return len(self.wav_list)

    def __getitem__(self, index):
        """Return ``(converted_sound, label)`` for the file at ``index``."""
        wav_file = self.wav_list[index]
        wav_file_path = os.path.join(self.data_path, wav_file)

        # The class is encoded in the first three characters of the file name.
        label = self.labels_index[wav_file[:3]]

        spec = self.audio_utils.convert_sound(wav_file_path, is_augment=self.is_augment)
        # The label must be returned too: the model's steps unpack
        # ``spec, labels = batch``.
        return spec, label

In [None]:
class WavDataModule(pl.LightningDataModule):
    """Lightning data module building train/validation loaders from a split table.

    Args:
        df: DataFrame whose columns hold wav file names (non-null entries are used).
            NOTE(review): the default column names assume a
            ``train_cat``/``train_dog``/``test_cat``/``test_dog`` layout for
            ``data/train_test_split.csv`` — TODO confirm against the CSV.
        batch_size: Samples per batch for both loaders.
        train_columns: Columns of ``df`` whose entries form the training set.
        val_columns: Columns of ``df`` whose entries form the validation set.
    """

    def __init__(self, df, batch_size=64,
                 train_columns=('train_cat', 'train_dog'),
                 val_columns=('test_cat', 'test_dog')):
        # LightningDataModule.__init__ must run to initialise Lightning's state.
        super().__init__()
        self.df = df
        self.batch_size = batch_size
        self.train_columns = tuple(train_columns)
        self.val_columns = tuple(val_columns)
        self.train_set = None
        self.val_set = None

    def _file_list(self, columns):
        """Concatenate the non-null entries of ``columns`` of ``self.df`` as strings."""
        files = []
        for column in columns:
            # dropna() skips empty cells (presumably padding when the columns
            # have different lengths — TODO confirm).
            files.extend(self.df[column].dropna().astype(str).tolist())
        return files

    def setup(self, stage=None):
        """Build the datasets; Lightning passes ``stage`` ('fit', 'validate', ...)."""
        self.train_set = WavDataset(self._file_list(self.train_columns))
        # NOTE(review): uses WavDataset's default settings (augmentation on);
        # consider disabling augmentation for validation.
        self.val_set = WavDataset(self._file_list(self.val_columns))

    def train_dataloader(self):
        """Shuffled loader over the training set."""
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          sampler=RandomSampler(self.train_set))

    def val_dataloader(self):
        """In-order loader over the validation set."""
        return DataLoader(self.val_set, batch_size=self.batch_size,
                          sampler=SequentialSampler(self.val_set))

# Build Model

In [50]:
class AudioCRNN(pl.LightningModule):
    """Convolutional-recurrent network producing 2-class logits.

    Three Conv2d -> BatchNorm2d -> ELU -> MaxPool2d -> Dropout stages produce
    128-channel feature maps. Each spatial position becomes one step of a
    2-layer bidirectional GRU. The top layer's final forward/backward hidden
    states feed a Dropout -> BatchNorm1d -> Linear head.

    Args:
        learning_rate: Adam learning rate (default equals the previously
            hard-coded 1e-4).
        eps: Adam epsilon (default equals the previously hard-coded 1e-8).
    """

    def __init__(self, learning_rate=1e-4, eps=1e-8):
        super().__init__()
        self.learning_rate = learning_rate
        self.eps = eps

        # The misspelled attribute name is kept on purpose: it sets the
        # parameter/state_dict key prefix ("feautre_extract.*").
        self.feautre_extract = nn.Sequential(
            # First level: 2 input channels -> 32
            nn.Conv2d(2, 32, 3),
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.1),

            # Second level
            nn.Conv2d(32, 64, 3),
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.1),

            # Third Level
            nn.Conv2d(64, 128, 3),
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.1)
        )

        # GRU input size 128 = channels of the last conv stage. 64 hidden units
        # per direction give 2 * 64 = 128 features for the classifier head.
        self.recurr = nn.GRU(128, 64, 2, bidirectional=True, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.BatchNorm1d(128),
            nn.Linear(128, 2)
        )
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return (batch, 2) logits for a (batch, 2, H, W) input tensor."""
        x = self.feautre_extract(x)  # -> (batch, 128, h, w)

        # Move channels last before flattening, so each sequence step is the
        # 128-channel feature vector of one (h, w) position. Reshaping the
        # channels-first tensor directly would interleave channels and
        # positions. Taking the size from x.shape (not a hard-coded 6 * 52)
        # also lets other input sizes work.
        # NOTE(review): steps are ordered along the last input axis (w),
        # assuming it is time — TODO confirm.
        batch, channels, height, width = x.shape
        x = x.permute(0, 3, 2, 1).reshape(batch, width * height, channels)

        _, hidden = self.recurr(x)
        # hidden: (num_layers * 2, batch, 64). The last two entries are the top
        # layer's final forward and backward states. The backward half of
        # output[:, -1] would have seen only a single step, so h_n is used instead.
        x = torch.cat((hidden[-2], hidden[-1]), dim=1)
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss on one ``(spec, labels)`` batch, logged as 'train_loss'."""
        spec, labels = batch
        logits = self(spec)
        loss = self.loss(logits, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation cross-entropy as 'val_loss'."""
        spec, labels = batch
        logits = self(spec)
        val_loss = self.loss(logits, labels)
        self.log('val_loss', val_loss)
        return val_loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate and epsilon."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate, eps=self.eps)

In [51]:
# Build a fresh, untrained AudioCRNN instance (inspected by the summary cell below).
model = AudioCRNN()

In [57]:
# summary() runs a real forward pass (it reports per-layer output shapes).
# Run it in eval mode so the all-zeros dummy batch cannot update BatchNorm
# running statistics, then restore the previous mode. This is harmless if
# summary() already switches modes itself.
# Dummy input: batch of 2, 2 channels (matching Conv2d(2, ...)), 64 x 430 spatial.
was_training = model.training
model.eval()
try:
    model_summary = summary(model, torch.zeros(2, 2, 64, 430))
finally:
    model.train(was_training)
model_summary

                                      Kernel Shape       Output Shape  \
Layer                                                                   
0_feautre_extract.Conv2d_0           [2, 32, 3, 3]   [2, 32, 62, 428]   
1_feautre_extract.BatchNorm2d_1               [32]   [2, 32, 62, 428]   
2_feautre_extract.ELU_2                          -   [2, 32, 62, 428]   
3_feautre_extract.MaxPool2d_3                    -   [2, 32, 31, 214]   
4_feautre_extract.Dropout_4                      -   [2, 32, 31, 214]   
5_feautre_extract.Conv2d_5          [32, 64, 3, 3]   [2, 64, 29, 212]   
6_feautre_extract.BatchNorm2d_6               [64]   [2, 64, 29, 212]   
7_feautre_extract.ELU_7                          -   [2, 64, 29, 212]   
8_feautre_extract.MaxPool2d_8                    -   [2, 64, 14, 106]   
9_feautre_extract.Dropout_9                      -   [2, 64, 14, 106]   
10_feautre_extract.Conv2d_10       [64, 128, 3, 3]  [2, 128, 12, 104]   
11_feautre_extract.BatchNorm2d_11            [128] 

Unnamed: 0_level_0,Kernel Shape,Output Shape,Params,Mult-Adds
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0_feautre_extract.Conv2d_0,"[2, 32, 3, 3]","[2, 32, 62, 428]",608.0,15284736.0
1_feautre_extract.BatchNorm2d_1,[32],"[2, 32, 62, 428]",64.0,32.0
2_feautre_extract.ELU_2,-,"[2, 32, 62, 428]",,
3_feautre_extract.MaxPool2d_3,-,"[2, 32, 31, 214]",,
4_feautre_extract.Dropout_4,-,"[2, 32, 31, 214]",,
5_feautre_extract.Conv2d_5,"[32, 64, 3, 3]","[2, 64, 29, 212]",18496.0,113319936.0
6_feautre_extract.BatchNorm2d_6,[64],"[2, 64, 29, 212]",128.0,64.0
7_feautre_extract.ELU_7,-,"[2, 64, 29, 212]",,
8_feautre_extract.MaxPool2d_8,-,"[2, 64, 14, 106]",,
9_feautre_extract.Dropout_9,-,"[2, 64, 14, 106]",,
