In [3]:
from birdset.datamodule import DatasetConfig
from birdset.datamodule.birdset_datamodule import BirdSetDataModule
from birdset.datamodule import LoadersConfig, LoaderConfig

import os

# BirdSet "POW" subset configured for task="multiclass" at 32 kHz, with 20% of the
# training data held out for validation.
dm = BirdSetDataModule(
    dataset=DatasetConfig(
        data_dir="./datasets",
        hf_path="DBD-research-group/BirdSet",
        hf_name="POW",
        # Keep the original 21 workers, but never request more than the available cores.
        n_workers=min(21, os.cpu_count() or 1),
        val_split=0.2,
        task="multiclass",
        classlimit=500,
        eventlimit=5,
        sample_rate=32000,
    ),
    loaders=LoadersConfig(
        train=LoaderConfig(batch_size=8, shuffle=True),
        valid=LoaderConfig(batch_size=8, shuffle=False),
        # Evaluation order must be deterministic so predictions can be matched back to samples.
        test=LoaderConfig(batch_size=8, shuffle=False),
    ),
)

dm.prepare_data()
dm.setup(stage="fit")

# NOTE: despite their names these are the dataset objects themselves, not DataLoaders
# (dm.train_dataset / dm.val_dataset); the names are kept for any later cells using them.
train_loader = dm.train_dataset
validation_loader = dm.val_dataset


  torchaudio.set_audio_backend("soundfile")
sampling: 100%|██████████| 48/48 [00:04<00:00, 10.59it/s]


Saving the dataset (0/1 shards):   0%|          | 0/41115 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10279 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/16052 [00:00<?, ? examples/s]

In [10]:
import tensorflow as tf
import numpy as np

num_classes = dm.num_classes

# --- Spectrogram pipeline configuration -------------------------------------
SAMPLE_RATE = 32000        # must match DatasetConfig(sample_rate=...) in the data cell
CLIP_SAMPLES = 32000       # waveform length (in samples) fed to the STFT
FRAME_LENGTH = 2048
FRAME_STEP = 512
FFT_LENGTH = 2048
NUM_MEL_BINS = 128
LOWER_EDGE_HZ = 80.0
# NOTE(review): 7600 Hz discards everything between it and the 16 kHz Nyquist limit of
# 32 kHz audio; bird vocalisations may lie above it — consider raising this.
UPPER_EDGE_HZ = 7600.0
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 8
SHUFFLE_SEED = 42

# The mel filterbank depends only on the constants above, so build it once instead of
# once per sample. tf.signal.stft yields FFT_LENGTH // 2 + 1 frequency bins per frame.
MEL_WEIGHTS = tf.signal.linear_to_mel_weight_matrix(
    num_mel_bins=NUM_MEL_BINS,
    num_spectrogram_bins=FFT_LENGTH // 2 + 1,
    sample_rate=SAMPLE_RATE,
    lower_edge_hertz=LOWER_EDGE_HZ,
    upper_edge_hertz=UPPER_EDGE_HZ,
)


def _fixed_length_audio(audio, length=CLIP_SAMPLES):
    """Flatten `audio` to 1-D float32, then zero-pad or truncate it to `length` samples.

    NOTE(review): if the datamodule emits clips longer than `length`, only the first
    `length` samples are kept; confirm this matches the intended clip duration.
    """
    # np.asarray accepts numpy arrays, CPU torch tensors and plain lists alike.
    flat = np.asarray(audio, dtype=np.float32).reshape(-1)
    if flat.shape[0] < length:
        return np.pad(flat, (0, length - flat.shape[0]), mode="constant")
    return flat[:length]


def _to_class_index(label):
    """Return the integer class index held by a scalar-like label (array, tensor or number).

    NOTE(review): a multi-element label contributes only its first element, as before;
    this is only meaningful when each sample carries a single class index.
    """
    return int(np.asarray(label).reshape(-1)[0])


def _log_mel_image(audio):
    """Turn a 1-D waveform into a (224, 224, 3) float32 log-mel image scaled to [0, 1].

    Rows are mel bins and columns are STFT frames (time).
    """
    stft = tf.signal.stft(
        tf.convert_to_tensor(audio, dtype=tf.float32),
        frame_length=FRAME_LENGTH, frame_step=FRAME_STEP, fft_length=FFT_LENGTH,
    )
    magnitude = tf.abs(stft)                          # (frames, FFT_LENGTH // 2 + 1)
    mel = tf.tensordot(magnitude, MEL_WEIGHTS, 1)     # (frames, NUM_MEL_BINS)
    # Natural log (not decibels); the epsilon avoids log(0) on silent frames.
    log_mel = tf.math.log(mel + 1e-6)
    # (frames, mel) -> (mel, frames, 1): frequency down the rows, time across the columns.
    image = tf.expand_dims(tf.transpose(log_mel), -1)
    image = tf.image.resize(image, IMAGE_SIZE, method="bilinear").numpy()
    # Per-sample min-max scaling to [0, 1].
    lo, hi = image.min(), image.max()
    image = (image - lo) / (hi - lo + 1e-8)
    # Replicate the single channel to the 3 channels the ImageNet backbone expects.
    return np.repeat(image, 3, axis=-1)


def _augment(image, rng):
    """Random brightness shift in [-0.2, 0.2] plus a 50% chance of time reversal.

    This is not SpecAugment (no time/frequency masking). The result is clipped so it
    stays in the same [0, 1] range as un-augmented images.
    """
    image = np.clip(image + rng.uniform(-0.2, 0.2), 0.0, 1.0)
    if rng.random() > 0.5:
        image = image[:, ::-1, :]  # axis 1 is time
    return image


def data_generator(dataset, augment=False, shuffle=False, rng=None):
    """Yield (image, one_hot_label) float32 pairs from a BirdSet split for Keras.

    Parameters
    ----------
    dataset : iterable of samples with 'input_values' (waveform) and 'labels' keys,
        e.g. ``dm.train_dataset``. When ``shuffle`` is True it must also support
        ``len()`` and integer indexing.
    augment : bool
        Apply ``_augment`` (random brightness + time reversal) to each image.
    shuffle : bool
        Visit samples in a random order; a new permutation is drawn on every call,
        i.e. on every epoch when wrapped by ``tf.data.Dataset.from_generator``.
    rng : numpy.random.Generator, optional
        Randomness for shuffling and augmentation; a fresh unseeded one if None.

    Yields
    ------
    image : np.ndarray, float32, shape (*IMAGE_SIZE, 3)
    label : np.ndarray, float32, shape (num_classes,), one-hot

    Raises
    ------
    ValueError
        If a label is not in ``[0, num_classes)`` (previously it silently became an
        all-zero target).
    """
    rng = np.random.default_rng() if rng is None else rng
    if shuffle:
        samples = (dataset[int(i)] for i in rng.permutation(len(dataset)))
    else:
        samples = iter(dataset)

    for sample in samples:
        image = _log_mel_image(_fixed_length_audio(sample["input_values"]))
        if augment:
            image = _augment(image, rng)

        class_index = _to_class_index(sample["labels"])
        if not 0 <= class_index < num_classes:
            raise ValueError(f"label {class_index} outside [0, {num_classes})")
        label_onehot = np.zeros(num_classes, dtype=np.float32)
        label_onehot[class_index] = 1.0

        yield image.astype(np.float32), label_onehot


def _make_tf_dataset(split, **generator_kwargs):
    """Wrap ``data_generator(split, **generator_kwargs)`` as a batched, prefetched tf.data.Dataset."""
    return tf.data.Dataset.from_generator(
        lambda: data_generator(split, **generator_kwargs),
        output_signature=(
            tf.TensorSpec(shape=(*IMAGE_SIZE, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(num_classes,), dtype=tf.float32),
        ),
    ).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


# Iterating dm.train_dataset directly yields samples in stored order (the LoaderConfig
# shuffle flag presumably only applies to dm.train_dataloader()), so shuffle here. The
# shared, seeded RNG gives a different but reproducible permutation each epoch.
train_dataset = _make_tf_dataset(
    dm.train_dataset, augment=False, shuffle=True,
    rng=np.random.default_rng(SHUFFLE_SEED),
)
val_dataset = _make_tf_dataset(dm.val_dataset)

In [12]:
# We want to use Tensorflow Keras for the model
import keras as ks
from keras.layers import Dense, Dropout

num_classes = dm.num_classes

# ImageNet-pretrained ResNet50 backbone: no classifier head, global-average-pooled features.
backbone = ks.applications.ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
    pooling='avg'
)

# The data generator yields spectrogram images min-max scaled to [0, 1], but the ImageNet
# weights expect inputs prepared by resnet50.preprocess_input (0-255 RGB -> BGR with
# per-channel ImageNet mean subtraction). Do that preprocessing inside the model so it
# is always applied, including at inference time.
inputs = ks.Input(shape=(224, 224, 3))
x = ks.layers.Rescaling(255.0)(inputs)
x = ks.applications.resnet50.preprocess_input(x)
x = backbone(x)
x = Dropout(0.3)(x)
outputs = Dense(num_classes, activation='softmax')(x)
model = ks.models.Model(inputs=inputs, outputs=outputs)


In [13]:
from keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Stop after 5 epochs without val_loss improvement (restore_best_weights=True asks Keras
# to restore the best weights seen), and halve the learning rate after 3 stagnant epochs.
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-7),
]

# 1e-3 is Adam's default learning rate; it is spelled out so it can be tuned (a smaller
# rate is common when fine-tuning every layer of a pretrained backbone).
model.compile(
    optimizer=ks.optimizers.Adam(learning_rate=1e-3),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Keep the History so the per-epoch loss/accuracy curves can be inspected or plotted.
history = model.fit(train_dataset, validation_data=val_dataset, epochs=10, callbacks=callbacks)


Epoch 1/10


  labels = torch.tensor(labels, dtype=torch.float16)


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1ca11425f90>

In [15]:
# Second training run: continues from the current weights with the already-compiled
# optimizer (no re-compile), so any learning-rate reduction from the previous run persists.
history_continued = model.fit(train_dataset, validation_data=val_dataset, epochs=10, callbacks=callbacks)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1ca80a39390>