In [48]:
import os

import matplotlib.pyplot as plt
from IPython.display import Audio

from birdset.datamodule.base_datamodule import DatasetConfig
from birdset.datamodule.birdset_datamodule import BirdSetDataModule
from birdset.datamodule.pretrain_datamodule import PretrainDataModule, BirdSetTransformsWrapper

### Load the focal pretraining dataset (XCL) from an existing data directory on the cluster

In [50]:
# Directory holding the XCL dataset that the data module below loads from disk.
# Defaults to the shared copy on the cluster; set the XCL_PATH environment
# variable to point at a different copy without editing the notebook.
XCL_PATH = os.environ.get(
    'XCL_PATH',
    '/mnt/datasets/bird_recordings/birdset_hf_download/XCL/XCL_processed_3_e2056aa472804029',
)

In [51]:
# Dataset settings for XCL; both the data directory and the fingerprint
# point at the existing on-disk copy referenced by XCL_PATH.
xcl_config = DatasetConfig(
    data_dir=XCL_PATH,
    dataset_name='XCL',
    hf_path='DBD-research-group/BirdSet',
    hf_name='XCL',
    n_workers=3,
    val_split=0.2,
    task="multilabel",
    classlimit=500,
    eventlimit=5,
    sampling_rate=32000,
    direct_fingerprint=XCL_PATH,
)

# Transform wrapper constructed with model_type='waveform'.
xcl_transforms = BirdSetTransformsWrapper(model_type='waveform')

# Pretraining data module built from the two pieces above.
dm = PretrainDataModule(dataset=xcl_config, transforms=xcl_transforms)

In [52]:
# Run the data-preparation step first, then set the module up for the
# "fit" stage; the next cell reads `dm.train_dataset`.
dm.prepare_data()
dm.setup("fit")

Loading dataset from disk:   0%|          | 0/240 [00:00<?, ?it/s]

Loading dataset from disk:   0%|          | 0/240 [00:00<?, ?it/s]

In [63]:
# Take the first training example and report its tensor shapes before
# listening to it. (`Audio` is imported in the import cell at the top.)
sample = dm.train_dataset[0]
audio = sample['input_values']
label = sample['labels']

print(f'Audio shape {audio.shape}')
print(f'Label shape {label.shape}')

# Inline audio player; the rate mirrors sampling_rate=32000 from the
# DatasetConfig used to build `dm`.
Audio(audio, rate=32000)

Audio shape torch.Size([1, 160000])
Label shape torch.Size([9736])


### Load the fine-tuning dataset (HSN: focal train split, soundscape test split)

In [75]:
# Dataset settings for HSN, stored under a path relative to the notebook.
hsn_config = DatasetConfig(
    data_dir='../data_birdset/HSN',
    dataset_name='HSN',
    hf_path='DBD-research-group/BirdSet',
    hf_name='HSN',
    n_workers=3,
    val_split=0.2,
    task="multilabel",
    classlimit=500,
    eventlimit=5,
    sampling_rate=32000,
)

# Transform wrapper constructed with model_type='waveform'.
hsn_transforms = BirdSetTransformsWrapper(model_type='waveform')

# Note: this rebinds `dm`, replacing the pretraining data module from earlier.
dm = BirdSetDataModule(dataset=hsn_config, transforms=hsn_transforms)

In [86]:
# Train (focal)
# Prepare the data and set up the "fit" stage, then pull the first
# training example for playback.
dm.prepare_data()
dm.setup(stage="fit")

sample = dm.train_dataset[0]
audio, label = sample['input_values'], sample['labels']

# Inline player at the 32 kHz rate configured in the DatasetConfig.
Audio(audio, rate=32000)

In [85]:
# Test (soundscape)
# Prepare the data and set up the "test" stage, then pull the first
# test example for playback.
dm.prepare_data()
dm.setup(stage="test")

sample = dm.test_dataset[0]
audio, label = sample['input_values'], sample['labels']

# Inline player at the 32 kHz rate configured in the DatasetConfig.
Audio(audio, rate=32000)

In [79]:
# Show the shape of `label` from whichever sample cell ran most recently
# (the result depends on cell execution order).
label.shape

torch.Size([21])