In [39]:
# Import the necessary libraries and modules

import pandas as pd
import torchaudio
import torch
from torch.utils.data import Dataset, DataLoader
import os
import matplotlib.pyplot as plt
from IPython.display import Audio
import numpy as np
import pytorch_lightning as pl

from torch import Tensor
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS

In [44]:
class InsectData(Dataset):

    """
    Dataset of insect sound recordings described by a metadata table.

    Every row of ``data`` must provide the columns ``data_path``, ``species``,
    ``file_name``, ``class_ID`` and ``data_set``. The audio file belonging to a
    row is read from ``<data_path>/<species>/<file_name>``.

    As a side effect of item access, metadata about each loaded file is
    collected in the public lists ``wlen`` (duration in seconds), ``classes``,
    ``species``, ``family``, ``data_set`` and ``path``. Each sample is recorded
    only once, in order of first access, so iterating the dataset repeatedly
    (e.g. once per epoch) neither duplicates entries nor grows the lists
    without bound. The list position equals the dataset index only when the
    samples are first accessed sequentially (``shuffle=False``).

    Args:
        data: Metadata table with one row per audio file.
        transform: Module applied to the first channel of the waveform.
        num_classes: Width of the one-hot target vector.
    """

    def __init__(self, data: pd.DataFrame, transform: torch.nn.Module, num_classes: int):
        self.data = data
        self.transform = transform
        self.num_classes = num_classes
        # Metadata collected lazily while samples are loaded.
        self.wlen = []
        self.classes = []
        self.species = []
        self.family = []
        self.data_set = []
        self.path = []
        # Row positions whose metadata has already been recorded.
        self._recorded: set[int] = set()

    def __len__(self) -> int:
        """Return the number of rows in the metadata table."""
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """
        Load one recording and return ``(transformed_waveform, one_hot_target)``.

        Raises:
            FileNotFoundError: If the audio file of the row does not exist.
        """
        sample_meta = self.data.iloc[idx]
        data_path = sample_meta.data_path
        species = sample_meta.species
        class_id = sample_meta.class_ID
        file_name = sample_meta.file_name
        data_set = sample_meta.data_set

        path = os.path.join(data_path, species, file_name)

        if not os.path.exists(path):
            raise FileNotFoundError(
                f'file not found: \'{path}\'.'
            )

        # `waveform` is indexed as (channel, sample) below; `samplerate` is in Hz.
        waveform, samplerate = torchaudio.load(path)

        # Normalise negative indices so that e.g. -1 and len - 1 share one entry.
        position = idx if idx >= 0 else idx + len(self.data)
        if position not in self._recorded:
            self._recorded.add(position)
            self.wlen.append(waveform.shape[-1] / samplerate)  # duration in seconds
            self.classes.append(class_id)
            self.species.append(species)
            self.family.append(data_path.split('/')[-1])  # last path component
            self.data_set.append(data_set)
            self.path.append(path)

        # Only the first channel is transformed; further channels are ignored.
        spectrogram: Tensor = self.transform(waveform[0, :])

        species_one_hot: Tensor = torch.nn.functional.one_hot(
            torch.as_tensor(class_id, dtype=torch.long),
            num_classes=self.num_classes)

        return spectrogram, species_one_hot
    
    
class InsectDatamodule(pl.LightningDataModule):
    """
    Lightning data module serving insect recordings listed in CSV files.

    For every CSV file ``<dir>/<name>.csv`` a sibling directory ``<dir>/<name>``
    must exist; its path is stored in the ``data_path`` column of the loaded
    table. The CSV files must provide the columns ``class_ID`` and ``data_set``
    (with the values ``'train'``, ``'validation'`` and ``'test'`` selecting the
    respective split) in addition to the columns read by ``InsectData``.

    Args:
        csv_paths: One CSV path or a list of CSV paths.
        n_fft: FFT size of the spectrogram transform.
        hop_length: Hop length of the spectrogram transform.
        batch_size: Batch size used by all dataloaders.

    Raises:
        ValueError: If ``csv_paths`` is empty.
        FileNotFoundError: If a CSV file or its data directory does not exist.
    """

    def __init__(
            self,
            csv_paths: list[str] | str,
            n_fft: int = 256,
            hop_length: int = 128,
            batch_size: int = 8):
        super().__init__()

        self.batch_size = batch_size

        # Accept a single path as shorthand for a one-element list.
        csv_paths = [csv_paths] if isinstance(csv_paths, str) else csv_paths

        if not csv_paths:
            raise ValueError('`csv_paths` must contain at least one CSV path.')

        csv_list = []

        for csv_path in csv_paths:
            if not os.path.exists(csv_path):
                raise FileNotFoundError(
                    f'`csv_path` does not exist: \'{csv_path}\'.'
                )

            # Strip only the trailing extension, so a '.csv' occurring earlier
            # in the path (e.g. in a directory name) does not truncate it.
            data_path = csv_path.removesuffix('.csv')

            # Validate the data directory of every CSV, not only the last one.
            if not os.path.exists(data_path):
                raise FileNotFoundError(
                    f'`data_path` does not exist: \'{data_path}\'.'
                )

            csv = pd.read_csv(csv_path)
            csv['data_path'] = data_path

            csv_list.append(csv)

        csv = pd.concat(csv_list)

        self.class_IDs = sorted(csv.class_ID.unique())
        # NOTE(review): InsectData one-hot encodes the raw class_ID with this
        # width, which only works if the IDs are contiguous 0..N-1 across all
        # loaded CSVs - confirm for partial selections of CSV files.
        self.num_classes = len(self.class_IDs)

        self.csv = csv

        self.transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length)

    def _make_dataloader(self, split: str | None, shuffle: bool) -> DataLoader:
        """Build a dataloader over the rows of ``split`` (all rows if ``None``)."""
        csv = self.csv if split is None else self.csv[self.csv.data_set == split]

        data_set = InsectData(
            data=csv, transform=self.transform, num_classes=self.num_classes)

        return DataLoader(data_set, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Return a shuffled dataloader over the ``'train'`` rows."""
        return self._make_dataloader('train', shuffle=True)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Return an unshuffled dataloader over the ``'validation'`` rows."""
        return self._make_dataloader('validation', shuffle=False)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Return an unshuffled dataloader over the ``'test'`` rows."""
        return self._make_dataloader('test', shuffle=False)

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """Return an unshuffled dataloader over all rows, regardless of split."""
        return self._make_dataloader(None, shuffle=False)


In [45]:
# Load both families and build a loader over every recording (batch size 1).
datamodule = InsectDatamodule(
    csv_paths=[
        '../data/Cicadidae.csv',
        '../data/Orthoptera.csv',
    ],
    batch_size=1,
)
predict_dataloader = datamodule.predict_dataloader()

# Exhaust the loader once: loading each sample makes the dataset record the
# file's metadata (path, family, species, class, split, length).
for x, y in predict_dataloader:
    continue


In [47]:
# Path of the first audio file recorded by the dataset while it was iterated.
predict_dataloader.dataset.path[0]

'../data/Cicadidae\\Myopsaltamelanobasis\\Myopsaltamelanobasis_Myopsalta_melanobasis_Brigalow.wav'

In [42]:
# Family recorded for the first file: the last '/'-separated part of its data_path.
predict_dataloader.dataset.family[0]

'Cicadidae'

In [24]:
# One row per loaded file, assembled from the lists the dataset filled while
# it was iterated. Pairs are (column name, dataset attribute).
meta_data = pd.DataFrame({
    column: getattr(predict_dataloader.dataset, attribute)
    for column, attribute in (
        ('family', 'family'),
        ('species', 'species'),
        ('class_id', 'classes'),
        ('data_set', 'data_set'),
        ('file_length', 'wlen'),
    )
})

In [25]:
# Preview the first ten rows of the collected per-file metadata.
meta_data.head(10)

Unnamed: 0,family,species,class_id,data_set,file_length
0,Cicadidae,Myopsaltamelanobasis,9,test,3.267483
1,Cicadidae,Platypleurasp10,24,train,3.242449
2,Cicadidae,Platypleurasp13,27,validation,10.0
3,Cicadidae,Platypleuraintercapedinis,21,train,10.0
4,Cicadidae,Myopsaltaxerograsidia,10,train,10.0
5,Cicadidae,Platypleuracfcatenata,15,test,10.0
6,Cicadidae,Brevisianabrevis,1,test,10.0
7,Cicadidae,Myopsaltaxerograsidia,10,test,10.0
8,Cicadidae,Brevisianabrevis,1,train,10.0
9,Cicadidae,Platypleuradeusta,17,train,10.0


In [23]:
# The fifteen shortest training recordings, shortest first.
(
    meta_data
    .loc[meta_data['data_set'].eq('train')]
    .sort_values(by='file_length')
    .head(15)
)

Unnamed: 0,family,species,class_id,data_set,file_length
224,Orthoptera,Roeselianaroeselii,30,train,1.004898
47,Cicadidae,Platypleurasp10,24,train,1.478957
212,Orthoptera,Nemobiussylvestris,11,train,1.557324
256,Orthoptera,Roeselianaroeselii,30,train,2.008889
227,Orthoptera,Roeselianaroeselii,30,train,2.413741
219,Orthoptera,Tettigoniaviridissima,31,train,2.551134
228,Orthoptera,Pholidopteragriseoaptera,13,train,2.671837
229,Orthoptera,Pseudochorthippusparallelus,28,train,2.756417
292,Orthoptera,Roeselianaroeselii,30,train,2.972018
1,Cicadidae,Platypleurasp10,24,train,3.242449


In [60]:
def play_audio(idx: int) -> Audio:
    """
    Return an audio player for the ``idx``-th file recorded by the dataset.

    Args:
        idx: Position in ``predict_dataloader.dataset.path``.

    Returns:
        An ``IPython.display.Audio`` widget playing the first channel of the
        file at its native sample rate.
    """
    waveform, sample_rate = torchaudio.load(predict_dataloader.dataset.path[idx])
    # `IPython.display.Audio` has no `loop` parameter; passing one raises a
    # TypeError, so the player is created without it.
    return Audio(waveform.numpy()[0], rate=sample_rate)

In [62]:
# Listen to recording 47, one of the shortest training files.
# `IPython.display.Audio` accepts no `loop` keyword (passing it raised
# "TypeError: Audio.__init__() got an unexpected keyword argument 'loop'"),
# so the player is created without it.
waveform, sample_rate = torchaudio.load(predict_dataloader.dataset.path[47])
Audio(waveform.numpy()[0], rate=sample_rate)

In [7]:
# Total recorded seconds per class for Orthoptera, across all splits.
(
    meta_data
    .loc[meta_data['family'].eq('Orthoptera')]
    .groupby('class_id')['file_length']
    .sum()
)

class_id
2     204.645578
3     135.870680
4     218.979683
11    497.706259
12    269.056961
13    114.017868
28    108.383084
30     47.401814
31     94.290703
Name: file_length, dtype: float64

In [8]:
# Total recorded seconds per class for the Orthoptera training split.
(
    meta_data
    .loc[meta_data['family'].eq('Orthoptera') & meta_data['data_set'].eq('train')]
    .groupby('class_id')['file_length']
    .sum()
)

class_id
2     125.726145
3      90.228798
4     175.820136
11    341.361338
12    145.233197
13     70.643039
28     63.059683
30     43.335170
31     51.306327
Name: file_length, dtype: float64

In [9]:
# Total recorded seconds per class for Cicadidae, across all splits.
(
    meta_data
    .loc[meta_data['family'].eq('Cicadidae')]
    .groupby('class_id')['file_length']
    .sum()
)


class_id
0      40.000000
1      50.000000
5      60.000000
6      70.000000
7      40.000000
8      67.839320
9      43.267483
10     60.000000
14     60.000000
15    213.654739
16     66.467347
17     83.461678
18     60.000000
19     50.000000
20     53.992063
21     50.000000
22    190.000000
23     80.000000
24    144.721406
25     40.000000
26    100.000000
27    120.000000
29     90.000000
Name: file_length, dtype: float64

In [10]:
# Seconds of Orthoptera training audio available for each class.
(
    meta_data
    .loc[meta_data['data_set'].eq('train') & meta_data['family'].eq('Orthoptera')]
    .groupby('class_id')['file_length']
    .sum()
)

class_id
2     125.726145
3      90.228798
4     175.820136
11    341.361338
12    145.233197
13     70.643039
28     63.059683
30     43.335170
31     51.306327
Name: file_length, dtype: float64

In [11]:
# Number of Orthoptera training files per class.
(
    meta_data
    .loc[meta_data['family'].eq('Orthoptera') & meta_data['data_set'].eq('train')]
    .groupby('class_id')['file_length']
    .count()
)

class_id
2     11
3      9
4     16
11    13
12     8
13    10
28     9
30    10
31     9
Name: file_length, dtype: int64

In [16]:
# Total recorded seconds per family, class and split, rounded to whole seconds.
output = (
    meta_data
    .groupby(['family', 'class_id', 'data_set'])['file_length']
    .sum()
    .round(0)
)
# Optional export of this summary: output.to_csv('output.csv')


In [13]:
# Per family / class / split: total seconds, number of files, the sorted list
# of individual lengths (one decimal), and the shortest and longest file.
output = (
    meta_data
    .groupby(['family', 'class_id', 'data_set'])
    .agg(
        file_length_sum=('file_length', 'sum'),
        file_count=('file_length', 'count'),
        file_lengths=(
            'file_length',
            lambda lengths: sorted(round(seconds, 1) for seconds in lengths),
        ),
        shortest_file_length=('file_length', 'min'),
        longest_file_length=('file_length', 'max'),
    )
    .assign(
        # Whole seconds for the total, one decimal for the extremes.
        file_length_sum=lambda frame: frame['file_length_sum'].round(0),
        shortest_file_length=lambda frame: frame['shortest_file_length'].round(1),
        longest_file_length=lambda frame: frame['longest_file_length'].round(1),
    )
)
output.to_csv('metadata.csv')

In [14]:
# Reload the per-class summary from 'metadata.csv' as a flat table.
metadata = pd.read_csv('metadata.csv')

In [15]:
# Training-split summary rows, ordered by their shortest recording.
(
    metadata
    .loc[metadata['data_set'].eq("train")]
    .sort_values(by='shortest_file_length')
)

Unnamed: 0,family,class_id,data_set,file_length_sum,file_count,file_lengths,shortest_file_length,longest_file_length
90,Orthoptera,30,train,43.0,10,"[1.0, 2.0, 2.4, 3.0, 3.4, 3.6, 4.1, 4.4, 7.9, ...",1.0,11.5
55,Cicadidae,24,train,95.0,11,"[1.5, 3.2, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0,...",1.5,10.0
79,Orthoptera,11,train,341.0,13,"[1.6, 4.3, 4.5, 5.5, 5.8, 8.1, 8.7, 9.9, 24.8,...",1.6,167.2
93,Orthoptera,31,train,51.0,9,"[2.6, 3.7, 4.3, 5.2, 5.6, 6.7, 6.7, 7.1, 9.4]",2.6,9.4
85,Orthoptera,13,train,71.0,10,"[2.7, 3.8, 4.2, 4.2, 5.7, 6.9, 7.5, 9.3, 9.5, ...",2.7,17.0
88,Orthoptera,28,train,63.0,9,"[2.8, 3.6, 3.8, 4.7, 4.8, 5.0, 5.6, 12.2, 20.6]",2.8,20.6
34,Cicadidae,17,train,53.0,6,"[3.5, 10.0, 10.0, 10.0, 10.0, 10.0]",3.5,10.0
70,Orthoptera,2,train,126.0,11,"[3.6, 7.1, 8.9, 9.0, 9.9, 10.6, 12.5, 13.7, 14...",3.6,18.5
76,Orthoptera,4,train,176.0,16,"[3.7, 3.9, 6.3, 6.8, 6.9, 7.7, 8.0, 8.5, 10.4,...",3.7,29.5
28,Cicadidae,15,train,144.0,15,"[3.7, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0...",3.7,10.0
