# 🐦 Bird Presence Detection with MobileNetV2 + Augmented Data Pipeline

This notebook builds **binary classifiers** that detect the presence of specific bird species, focusing on **rare species that BirdNET's classes do not cover**. One classifier is trained per species, using class balancing, heavy augmentation, and simulated overlapping vocalizations.

---

## 🔑 Pipeline Overview

### 1. Data Preparation
- Start from processed splits (`train_filenames.npy`, `val_filenames.npy`, etc.).
- Oversample target bird samples up to the count of the most frequent species.
- Add an equal number of **non-bird** or **other bird** samples → balanced binary dataset.
- Optionally inject **soundscape** samples for robustness.
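
A minimal sketch of the balancing step, assuming plain filename lists. The function name `build_balanced_binary_list` and the `(filename, label, augment_flag)` tuple layout are illustrative only; the notebook's own lists hold `(filename, augment_flag)` pairs.

```python
import random

def build_balanced_binary_list(target_files, other_files, target_count, seed=0):
    """Oversample target clips to `target_count`, then add an equal number of negatives.

    Returns a shuffled list of (filename, label, augment_flag) tuples.
    """
    rng = random.Random(seed)
    # Keep every original target clip once, without augmentation.
    positives = [(f, 1, False) for f in target_files]
    # Fill up to target_count with resampled copies flagged for augmentation.
    while len(positives) < target_count:
        positives.append((rng.choice(target_files), 1, True))
    # Draw the same number of negatives (with replacement, so short lists still work).
    negatives = [(rng.choice(other_files), 0, False) for _ in range(len(positives))]
    samples = positives + negatives
    rng.shuffle(samples)
    return samples

samples = build_balanced_binary_list(
    target_files=["rare/a.ogg", "rare/b.ogg"],
    other_files=["common/x.ogg", "common/y.ogg", "common/z.ogg"],
    target_count=6,
)
```

Flagging only the resampled copies for augmentation keeps one clean version of every original recording in the training set.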

### 2. Data Generator (`AudioFeatureGeneratorMixedPadded`)
- Loads 5s audio chunks at 32 kHz.
- Applies waveform-level augmentations:
  - Reverb, compression, clipping
  - Pitch shifting, time-stretch, random shifting
  - Volume scaling, noise mixing
- Applies mixing-based augmentations:
  - **Target bird + soundscape**
  - **Bird + other birds** (same or different species)
  - **Bird + random noise**
- Converts audio chunks into Mel spectrograms:
  - Scaled to [0, 1]
  - Padded/truncated to fixed length
  - Expanded to 3 channels (RGB-style for CNNs)
- Optional **SpecAugment** applied (time/frequency masking).
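
The spectrogram post-processing can be sketched with NumPy alone. This is a hedged illustration rather than the generator's exact code: `prepare_mel_input` is a hypothetical helper, zero-padding at the end of the time axis is an assumption, and the defaults mirror values that appear in the generator (`target_time_length_spectrogram=250`, a dB range of -80 to 0).

```python
import numpy as np

def prepare_mel_input(mel_db, target_frames=250, min_db=-80.0, max_db=0.0):
    """Turn a (frames, n_mels) dB spectrogram into a (target_frames, n_mels, 3) array in [0, 1]."""
    # Linear rescale from [min_db, max_db] to [0, 1], clipping anything outside the range.
    scaled = np.clip((mel_db - min_db) / (max_db - min_db), 0.0, 1.0)
    frames = scaled.shape[0]
    if frames < target_frames:
        # Assumed behaviour: pad with zeros (silence) at the end of the time axis.
        scaled = np.pad(scaled, ((0, target_frames - frames), (0, 0)), mode="constant")
    else:
        scaled = scaled[:target_frames]
    # Repeat the single channel three times so the array matches an RGB-style CNN input.
    return np.repeat(scaled[..., np.newaxis], 3, axis=-1).astype(np.float32)

short_clip = prepare_mel_input(np.full((100, 128), -40.0))
long_clip = prepare_mel_input(np.full((400, 128), -100.0))
```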

### 3. Dataset Construction (`create_dataset`)
- Produces balanced `train_generator` and `val_generator`.
- Wraps them into TensorFlow `tf.data.Dataset` objects for efficient streaming.
- Ensures:
  - Shuffling in training
  - Deterministic loading in validation

---

## 🎯 Model: MobileNetV2 Bird Presence Classifier
- Base network: **MobileNetV2 (ImageNet-pretrained)** with global average pooling.
- Classification head:
  - Dense (512 units, ReLU + Dropout 0.5)
  - Dense (1 unit, sigmoid) → binary output
- L2 regularization applied to Dense layers.
- Option to fine-tune (base layers trainable or frozen).

---

## 🧪 Training Strategy
- **Loss**: Binary Crossentropy
- **Optimizer**: Adam (lr=5e-4)
- **Metrics**: Precision, Recall, F1 (custom callback)
- **Callbacks**:
  - `ValidationF1Callback`: evaluates F1 on validation set each epoch, saves best model + score.
  - `EarlyStopping`: stops if validation loss does not improve for N epochs, restoring best weights.
  - `LambdaCallback`: prints batch-level metrics (loss, precision, recall).
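
For reference, the validation F1 score can be computed from sigmoid outputs as in the sketch below. The helper name `binary_f1` and the 0.5 decision threshold are assumptions; the notebook's `ValidationF1Callback` may compute the score differently.

```python
def binary_f1(y_true, y_prob, threshold=0.5):
    """Return (precision, recall, f1) for binary labels and predicted probabilities."""
    y_pred = [1 if p >= threshold else 0 for p in y_prob]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    # Guard against empty denominators (no positive predictions or no positive labels).
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1

precision, recall, f1 = binary_f1([1, 1, 0, 0, 1], [0.9, 0.4, 0.6, 0.1, 0.8])
```

In this toy example there are two true positives, one false positive, and one false negative, so precision, recall, and F1 all equal 2/3.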

---

## 🚀 Training Loop
- For each **target bird** in the dataset:
  - Create train/val datasets for that species.
  - Build and fine-tune a MobileNetV2-based binary classifier.
  - Save the best model per species (`best_model_<bird_id>.keras`).
  - Track best F1 in `best_val_f1_<bird_id>.txt`.
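
The bookkeeping for the best score could look like the sketch below. `record_if_best` is a hypothetical helper, and writing the text file into a per-species folder under the models directory is an assumption based on how the generator builds `bird_dir`.

```python
import os
import tempfile

def record_if_best(models_dir, bird_id, val_f1, best_so_far):
    """Write val_f1 to best_val_f1_<bird_id>.txt when it beats best_so_far; return the running best."""
    if val_f1 <= best_so_far:
        return best_so_far
    bird_dir = os.path.join(models_dir, bird_id)
    os.makedirs(bird_dir, exist_ok=True)
    with open(os.path.join(bird_dir, f"best_val_f1_{bird_id}.txt"), "w") as fh:
        fh.write(f"{val_f1:.4f}")
    # A real callback would presumably also save the model at this point,
    # e.g. to best_model_<bird_id>.keras.
    return val_f1

# Simulate three epochs for one species in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    best = 0.0
    for epoch_f1 in [0.61, 0.74, 0.70]:
        best = record_if_best(tmp, "grepot1", epoch_f1, best)
    with open(os.path.join(tmp, "grepot1", "best_val_f1_grepot1.txt")) as fh:
        saved = fh.read()
```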

---

## ✅ Key Outcomes
- Robust binary classifiers trained per species:
  - Handle **class imbalance** via oversampling.
  - Handle **background noise + environment** via augmentation.
  - Handle **overlapping calls** via mixing with soundscapes and other birds.
- The final models are species-specific detectors intended to generalize to noisy, real-world audio.





# **Imports**

In [1]:
from google.colab import drive
import zipfile
import os
import pandas as pd
import numpy as np
import librosa
import librosa.display
import random
import ast
import warnings
import joblib
import soundfile as sf

from IPython.display import Audio, display

import tensorflow as tf
from tensorflow.keras import layers, models, regularizers, backend as K
from tensorflow.keras.utils import Sequence, to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess_input

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import precision_score, recall_score, f1_score


drive.mount('/content/drive')
import sys
sys.path.append('/content/drive/MyDrive/Main_Birdclef/scripts')
import birdclef_utils
IMPLEMENTATION='sigmoid'

Mounted at /content/drive


# **Unzip Birdclef Audio and Metadata**

In [2]:
birdclef_utils.retrieve_and_process_birdclef_data()
birdclef_utils.retrieve_and_process_birdclef_data(zip_filename='ColabUploads.zip')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Successfully extracted all files from birdclef-2025.zip to /content/data
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Successfully extracted all files from ColabUploads.zip to /content/data


In [23]:
main_dir='/content/data/'
main_processed_dir=os.path.join(main_dir,'ColabUploads')
processed_dir=os.path.join(main_processed_dir,'KaggleUploads')

drive_dir='/content/drive/MyDrive'
main_birdclef_dir=os.path.join(drive_dir,'Main_Birdclef')
csv_dir=os.path.join(main_birdclef_dir,'CSVs')
supplemental_files_dir=os.path.join(main_birdclef_dir,'supplemental_files')
models_dir=os.path.join(main_birdclef_dir,'models')

# **Training/Validation/Testing Filenames**

In [4]:
original_train_filenames=np.load(os.path.join(processed_dir,'train_filenames.npy'))
original_test_filenames=np.load(os.path.join(processed_dir,'test_filenames.npy'))
original_val_filenames=np.load(os.path.join(processed_dir,'val_filenames.npy'))

# **Birdnet Soft-labels for Train Audio**

In [5]:
df_start_times=pd.read_csv(os.path.join(csv_dir, 'birdnet_train_labels_final.csv'))
df_start_times

Unnamed: 0,filename,start_time,end_time,confidence,scientific_name,primary_label
0,grepot1/iNat1214104.ogg,0.0,3.0,0.662340,Nyctibius grandis,grepot1
1,grepot1/XC135771.ogg,3.0,6.0,0.981190,Nyctibius grandis,grepot1
2,grepot1/XC135771.ogg,6.0,9.0,0.372203,Nyctibius grandis,grepot1
3,grepot1/XC135771.ogg,12.0,15.0,0.979664,Nyctibius grandis,grepot1
4,grepot1/XC135771.ogg,21.0,24.0,0.966239,Nyctibius grandis,grepot1
...,...,...,...,...,...,...
174797,plukit1/XC421627.ogg,3.0,6.0,0.471399,Ictinia plumbea,plukit1
174798,plukit1/XC421627.ogg,6.0,9.0,0.987835,Ictinia plumbea,plukit1
174799,plukit1/XC421627.ogg,9.0,12.0,0.983746,Ictinia plumbea,plukit1
174800,plukit1/XC421627.ogg,12.0,15.0,0.698430,Ictinia plumbea,plukit1


# **Birdnet Soft-labels for Train Soundscapes**

In [6]:
labeled_soundscapes=pd.read_csv(os.path.join(csv_dir, 'birdnet_soft_labels.csv'))
labeled_soundscapes


Unnamed: 0,file,start_time,end_time,confidence,scientific_name,primary_label
0,H98_20230430_114000.ogg,21.0,24.0,0.476130,Pheugopedius fasciatoventris,blbwre1
1,H98_20230430_114000.ogg,27.0,30.0,0.171864,Pheugopedius fasciatoventris,blbwre1
2,H98_20230430_114000.ogg,48.0,51.0,0.330452,Pheugopedius fasciatoventris,blbwre1
3,H16_20230503_085500.ogg,6.0,9.0,0.986548,Tapera naevia,strcuc1
4,H16_20230503_085500.ogg,9.0,12.0,0.996310,Tapera naevia,strcuc1
...,...,...,...,...,...,...
27382,H84_20230422_185500.ogg,33.0,36.0,0.114390,Pulsatrix perspicillata,speowl1
27383,H84_20230422_185500.ogg,36.0,39.0,0.136469,Pulsatrix perspicillata,speowl1
27384,H84_20230422_185500.ogg,42.0,45.0,0.169751,Pulsatrix perspicillata,speowl1
27385,H84_20230422_185500.ogg,45.0,48.0,0.107960,Pulsatrix perspicillata,speowl1


# **Start times and confidences for non-birdnet classes**

In [7]:
start_time_info_non_birdnet=pd.read_csv(os.path.join(csv_dir, 'non_birdnet_start_time_quality_times.csv'))
start_time_info_non_birdnet

Unnamed: 0,filename,start_time,end_time,confidence,scientific_name,primary_label
0,1139490/CSA36385.ogg,0.0,5.0,0.778622,,1139490
1,1139490/CSA36385.ogg,1.0,6.0,0.794913,,1139490
2,1139490/CSA36385.ogg,2.0,7.0,0.796506,,1139490
3,1139490/CSA36385.ogg,3.0,8.0,0.784370,,1139490
4,1139490/CSA36385.ogg,4.0,9.0,0.625431,,1139490
...,...,...,...,...,...,...
10640,shghum1/XC571083.ogg,37.0,42.0,0.958973,,shghum1
10641,shghum1/XC571083.ogg,38.0,43.0,0.820355,,shghum1
10642,shghum1/XC571083.ogg,45.0,50.0,0.786046,,shghum1
10643,shghum1/XC571083.ogg,46.0,51.0,0.855939,,shghum1


# **Load train.csv (speech_cleaned_audio_with_duration.csv) and Saved Label Encoder**

In [8]:
warnings.filterwarnings('ignore')
df = pd.read_csv(os.path.join(processed_dir,'speech_cleaned_audio_with_duration.csv'), dtype={'primary_label': 'object'})
df['secondary_labels'] = df['secondary_labels'].apply(ast.literal_eval)

# Path to the label encoder file
label_encoder_path = os.path.join(supplemental_files_dir,'bird_label_encoder.joblib')

try:
    # Load the label encoder
    label_encoder = joblib.load(label_encoder_path)
    print(f"Successfully loaded label_encoder from: {label_encoder_path}")
except FileNotFoundError:
    print(f"Error: File not found at: {label_encoder_path}. "
          f"Please ensure the dataset 'processing-models' is attached and the path is correct.")

Successfully loaded label_encoder from: /content/drive/MyDrive/Main_Birdclef/supplemental_files/bird_label_encoder.joblib


# **Oversampling: Oversample underrepresented labels and flag for augmentation**

In [9]:
train=df[df['cleaned_filename'].isin(original_train_filenames)]
highest_bird_count=train['primary_label'].value_counts().sort_values(ascending=False).values[0]
birds=train['primary_label'].unique()
train_filenames=[]
train_bird_file_counts={}

for bird in birds:
    data=train[train['primary_label']==bird]
    all_files=list(data['cleaned_filename'].values)
    file_count=0
    original_file_count=0
    for file in all_files:
        train_filenames.append((file,False))
        original_file_count+=1
        file_count+=1
    if file_count>125:
      for i in range(20):
        train_filenames.append((random.choice(all_files),True))
        file_count+=1
    while file_count<highest_bird_count:
        train_filenames.append((random.choice(all_files),True))
        file_count+=1
    train_bird_file_counts[bird]=original_file_count
random.shuffle(train_filenames)
#train_filenames = random.sample(train_filenames, 25600)
val_filenames=[]

val_filenames=[(f,False) for f in original_val_filenames]
print(len(train_filenames))

warnings.filterwarnings("ignore", category=FutureWarning)

filename_mapping = df.set_index('cleaned_filename')['filename'].to_dict()

print('Length of Train Filenames', len(train_filenames))
train_filenames=[(filename_mapping[f[0]],f[1]) for f in train_filenames]
print('Length of Train Filenames', len(train_filenames))
no_bird_samples=[('No Bird',random.choice([True,False])) for _ in range(412)]
val_filenames=[(filename_mapping[f[0]],f[1]) for f in val_filenames]


df['isOneBird']=df['secondary_labels'].apply(lambda x: True if len(x)==0 or x[0]=='' else False)

short_train_filenames=np.load(os.path.join(processed_dir,'short_train_filenames.npy'),allow_pickle=True)
short_train_filenames=[f[0] for f in short_train_filenames]
print(f'Length of Short Train Filenames: {len(short_train_filenames)}')
one_bird_files=df[df['isOneBird']]['filename'].values
multiple_bird_files=df[~df['isOneBird']]['filename'].values

kind='all_birds'
if kind=='one_bird':
    train_filenames=[f for f in train_filenames if f[0] in one_bird_files]+[(f,False) for f in original_test_filenames[:4000] if f in one_bird_files]
    val_filenames=[f for f in val_filenames if f[0] in one_bird_files]
elif kind=='multiple_birds':
    train_filenames=[f for f in train_filenames if f[0] in short_train_filenames or f[0] in multiple_bird_files]
else:
    pass  # kind == 'all_birds': keep the full training list
labeled_soundscapes=pd.read_csv(os.path.join(csv_dir, 'birdnet_soft_labels.csv'))
labeled_soundscape_fnames=labeled_soundscapes['file'].unique()
random.shuffle(labeled_soundscape_fnames)
train_filenames=train_filenames+[(f,random.choice([True,False,False])) for f in labeled_soundscape_fnames[:3000]]*25
print('Length of Train Filenames', len(train_filenames))

121354
Length of Train Filenames 121354
Length of Train Filenames 121354
Length of Short Train Filenames: 20735
Length of Train Filenames 196354


# **Class For retrieving labels and start-times with corresponding hard/soft labels.**

In [10]:
class RetreiveData:
  def __init__(self,labels_df,do_skip=True):
    self.do_skips=do_skip
    self.labels_df=labels_df
    self.label_encoder = joblib.load(os.path.join(supplemental_files_dir,'bird_label_encoder.joblib'))
    self.start_time_info=pd.read_csv(os.path.join(csv_dir, 'birdnet_train_labels_final.csv'))
    self.start_time_info['confidence']=self.start_time_info['confidence'].apply(lambda x: 1 if x>0.22 else x)
    self.start_time_info_secondary_labels=pd.read_csv(os.path.join(csv_dir, 'birdnet_secondary_label_detections.csv'))
    self.start_time_info_secondary_labels['confidence']=self.start_time_info_secondary_labels['confidence'].apply(lambda x: 1 if x>0.22 else x)
    self.labeled_soundscapes=pd.read_csv(os.path.join(csv_dir, 'birdnet_soft_labels.csv'))
    self.labeled_soundscapes=self.labeled_soundscapes.rename(columns={'file':'filename'})
    self.list_of_labeled_soundscapes=list(self.labeled_soundscapes['filename'].values)
    self.list_of_train_filenames=list(self.start_time_info['filename'].values)
    self.unlabeled_soundscapes=np.load(os.path.join(supplemental_files_dir,'unlabeled_background_files.npy'),allow_pickle=True)
    self.start_time_info_two=pd.read_csv(os.path.join(csv_dir, 'non_birdnet_start_time_quality_times.csv'))

    self.data=pd.concat([self.start_time_info,self.labeled_soundscapes,self.start_time_info_two,self.start_time_info_secondary_labels])
    self.all_filenames=self.data['filename'].values
    self.all_possible_filenames=pd.read_csv(os.path.join(processed_dir,'speech_cleaned_audio_with_duration.csv'))['filename'].values
    self.non_birdnet_labels=np.load(os.path.join(supplemental_files_dir,'non_birdnet_labels.npy'),allow_pickle=True)

  def get_file_and_start_time(self,filename):
    file_info=filename.split('/')[0]

    if file_info in self.label_encoder.classes_:
      file_kind='train_birds'
    else:
      file_kind='soundscapes'
    primary_labels=[]
    soft_labels=[]
    data=self.data[self.data['filename']==filename]
    try:
      data=data.sample(1)
    except ValueError:
      # Sampling an empty frame raises ValueError: no labelled window exists for this file.
      if file_info in self.non_birdnet_labels or not self.do_skips:
        # start_time=None tells the caller to choose a random start time.
        return filename,[file_info],[1],None,file_kind
      # Otherwise (no detection for a file outside the non-BirdNET list), tell the caller to skip it.
      return filename,[file_info],[1],"skip",file_kind
    start_time=data['start_time'].values[0]
    filename=data['filename'].values[0]
    # Gather every detection in the same file whose window starts within 6 s of the sampled start time.
    window_mask = (
        (self.data['filename'] == filename) &
        (self.data['start_time'] >= start_time) &
        (self.data['start_time'] < start_time + 6)
    )
    data = self.data[window_mask]
    # Group and average
    agg = data.groupby('primary_label')['confidence'].mean().reset_index()
    primary_labels = agg['primary_label'].tolist()
    soft_labels = agg['confidence'].tolist()

    return filename,primary_labels,soft_labels,start_time,file_kind
random.shuffle(train_filenames)
data=RetreiveData(df.set_index('filename'))
for i in range(256):
  fname,labels,soft_labels,start_time,file_kind=data.get_file_and_start_time(train_filenames[i][0])
  print(f'Filename: {fname}, Start Time: {start_time}, Labels: {labels}, Soft Labels: {soft_labels}, File Kind: {file_kind}')
  print(type(labels))


Filename: plctan1/XC454405.ogg, Start Time: 9.0, Labels: ['plctan1'], Soft Labels: [1.0], File Kind: train_birds
<class 'list'>
Filename: spbwoo1/XC946051.ogg, Start Time: skip, Labels: ['spbwoo1'], Soft Labels: [1], File Kind: train_birds
<class 'list'>
Filename: palhor2/XC939933.ogg, Start Time: 0.0, Labels: ['palhor2'], Soft Labels: [0.584149494767189], File Kind: train_birds
<class 'list'>
Filename: 1192948/CSA36366.ogg, Start Time: 1.0, Labels: ['1192948'], Soft Labels: [0.6600061357021332], File Kind: train_birds
<class 'list'>
Filename: 50186/CSA08495.ogg, Start Time: 202.0, Labels: ['50186'], Soft Labels: [0.6332333485285441], File Kind: train_birds
<class 'list'>
Filename: brtpar1/XC835183.ogg, Start Time: skip, Labels: ['brtpar1'], Soft Labels: [1], File Kind: train_birds
<class 'list'>
Filename: 787625/iNat673795.ogg, Start Time: 11.0, Labels: ['787625'], Soft Labels: [0.5606703758239746], File Kind: train_birds
<class 'list'>
Filename: H14_20230517_122000.ogg, Start Time: 2

# **Data Generator Pipeline Summary**

The custom `AudioFeatureGeneratorMixedPadded` loads 5-second audio chunks and applies a variety of augmentations to prepare training-ready mel spectrogram inputs for bird presence classification.

### Key Processes Applied:

- **Audio Loading & Chunking:**  
  Loads 5s audio segments at 32 kHz from balanced filename lists.

- **Waveform-Level Augmentations:**  
  - Reverb (delay-based)  
  - Soft compression (tanh clipping)  
  - Pitch shifting (±0.5 semitones)  
  - Time stretching and random shifting  
  - Volume scaling and random clipping  
  - Mixing with background noise audio

- **Mixing Augmentations:**  
  - Mix target bird audio with soundscape clips  
  - Mix bird audio with other bird species (same or different)  
  - Simulate overlapping bird calls and real-world environments

- **Feature Extraction:**  
  Converts augmented audio to Mel spectrograms (scaled [0,1]), padded/truncated to fixed length, then duplicated across 3 channels to match CNN input expectations.

- **SpecAugment:**  
  Random time masking and frequency masking applied on mel spectrograms for further robustness.

- **Batching:**  
  Outputs batches of `(mel_spectrograms, labels)` ready for model training, with optional shuffling each epoch. Each label vector marks only whether the target species is present.

This pipeline enables robust, balanced, and diverse training data that helps the model generalize to noisy, overlapping bird vocalizations in natural environments.
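
The waveform-level steps listed above can be illustrated with a short NumPy sketch. The helper names and the `drive`, `delay_s`, and `decay` values are assumptions chosen for illustration and are not taken from the generator's `_augment_audio` method; only the 0.3 to 0.7 mixing weight range matches `_mix_files`.

```python
import numpy as np

def soft_compress(audio, drive=2.0):
    """tanh soft clipping, rescaled so that a full-scale input stays at full scale."""
    return np.tanh(drive * audio) / np.tanh(drive)

def delay_reverb(audio, sr=32000, delay_s=0.05, decay=0.4):
    """Add one delayed, attenuated copy of the signal (a very simple echo-style reverb)."""
    delay = int(sr * delay_s)
    out = audio.copy()
    if 0 < delay < len(audio):
        out[delay:] += decay * audio[:-delay]
    return out

def mix(primary, other, primary_weight):
    """Weighted sum of two equal-length clips; the weights add up to 1."""
    return primary_weight * primary + (1.0 - primary_weight) * other

rng = np.random.default_rng(0)
t = np.arange(32000 * 5) / 32000
target_clip = 0.8 * np.sin(2 * np.pi * 2000 * t)       # stand-in for a bird call
background = 0.1 * rng.standard_normal(len(t))         # stand-in for a soundscape
augmented = soft_compress(delay_reverb(target_clip))
mixed = mix(augmented, background, primary_weight=rng.uniform(0.3, 0.7))
```

Peak-normalizing the result afterwards, as `_normalize` does, keeps the mixed clip within [-1, 1] regardless of the weights drawn.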



In [11]:
class AudioFeatureGeneratorMixedPadded(tf.keras.utils.Sequence):
    def __init__(self, filenames, labels_df, audio_dir='train_audio', soundscapes_dir='train_soundscapes',
                 sr=32000, chunk_duration=5, batch_size=32, shuffle=True,
                 target_time_length_spectrogram=250,
                 num_classes=None, kind='train', return_info=False, imp_type='sigmoid',
                 main_bird=None, all_possible_filenames=None, use_non_birdnet=False, do_skip=True):
        self.use_non_birdnet=use_non_birdnet
        self.do_skip=do_skip
        if self.use_non_birdnet:
          self.random_start_time_prob=0.5
        else:
          self.random_start_time_prob=0.0
        self.all_possible_filenames=all_possible_filenames
        self.start_time_info=pd.read_csv(os.path.join(csv_dir, 'birdnet_train_labels_final.csv'))
        self.start_time_info=self.start_time_info[self.start_time_info['filename'].isin([f[0] for f in self.all_possible_filenames])]
        self.soundscape_audio_info=pd.read_csv(os.path.join(csv_dir, 'birdnet_soft_labels.csv'))
        self.return_info=return_info
        self.filenames = filenames
        self.primary_bird=main_bird

        self.labels_df = labels_df.set_index('filename')
        self.audio_dir = audio_dir
        self.soundscapes_dir = soundscapes_dir
        self.sr = sr
        self.chunk_duration = chunk_duration

        self.batch_size = batch_size
        self.shuffle = shuffle
        bird_dir=os.path.join(models_dir,main_bird)
        self.label_encoder = joblib.load(os.path.join(bird_dir,f'bird_label_encoder_{main_bird}.joblib'))
        self.target_time_length_spectrogram = target_time_length_spectrogram

        self.start_time_label_data=RetreiveData(self.labels_df,do_skip=self.do_skip)
        self.non_birnet_files=self.start_time_label_data.start_time_info_two['filename'].values
        self.start_time_info=self.start_time_info[~self.start_time_info['filename'].isin(self.non_birnet_files)]

        self.start_time = None  # no constructor argument supplies this; per-sample start times come from RetreiveData
        one_bird_files=self.labels_df[self.labels_df['isOneBird']]['cleaned_filename'].values
        self.short_train_filenames = np.load(os.path.join(processed_dir,'short_train_filenames.npy'),allow_pickle=True)
        self.short_train_filenames=[f for f in self.short_train_filenames if f[0] in one_bird_files]
        self.kind = kind
        self.implementation_type=imp_type

        self.num_classes = num_classes if num_classes is not None else len(self.label_encoder.classes_)

        self.normalize_audio = True

        self.all_classes = self.label_encoder.classes_
        self.on_epoch_end()

        self.write_one=True

    def __len__(self):
        return int(np.ceil(len(self.filenames) / self.batch_size))


    def _normalize(self, audio):
        """Normalizes the audio waveform to the range [-1, 1] based on its peak."""
        peak = np.abs(audio).max()
        if peak > 0:
            return audio / peak
        return audio


    def spec_augment(self, mel_spectrogram, time_mask_param=8, num_time_masks=3, freq_mask_param=4, num_freq_masks=3):
        """
        Applies time and frequency masking to the Mel spectrogram.
        It's designed to work with 3-channel Mel spectrograms and maintain that shape.

        Args:
            mel_spectrogram: np.ndarray, shape (time, freq, 3) (or whatever shape your _extract_features produces)
            time_mask_param: int, maximum width of the time mask.
            num_time_masks: int, number of time masks to apply.
            freq_mask_param: int, maximum width of the frequency mask.
            num_freq_masks: int, number of frequency masks to apply.

        Returns:
            np.ndarray: augmented mel_spectrogram (same shape as input: time, freq, 3).
        """
        # Create a copy to avoid modifying the original array in place
        augmented_mel = mel_spectrogram.copy()

        # Get the dimensions of the mel spectrogram
        num_time_steps_mel = augmented_mel.shape[0]  # Time dimension
        num_freq_bins_mel = augmented_mel.shape[1]  # Frequency dimension
        # The channel dimension (augmented_mel.shape[2]) will be implicitly handled

        # --- Time Masking ---
        # Applied across all frequency bins AND all 3 channels
        for _ in range(num_time_masks):
            t = np.random.randint(0, time_mask_param + 1)
            if t > 0 and num_time_steps_mel > t:
                t0 = np.random.randint(0, num_time_steps_mel - t + 1)
                # Set a rectangular region to zero across all frequency bins and channels
                augmented_mel[t0:t0 + t, :, :] = 0

        # --- Frequency Masking ---
        # Applied across all time steps AND all 3 channels
        if freq_mask_param is not None and num_freq_masks > 0:
            for _ in range(num_freq_masks):
                f_mel = np.random.randint(0, freq_mask_param + 1)
                if f_mel > 0 and num_freq_bins_mel > f_mel:
                    f0_mel = np.random.randint(0, num_freq_bins_mel - f_mel + 1)
                    # Set a rectangular region to zero across all time steps and channels
                    augmented_mel[:, f0_mel:f0_mel + f_mel, :] = 0

        # The function now returns only the augmented mel spectrogram.
        # Its shape will be identical to the input shape (e.g., (time, freq, 3)).
        return augmented_mel

    def is_speech_in_file(self,filename):
        # NOTE: self.speech_time_info is not assigned in __init__; set it before calling this helper.
        return len(self.speech_time_info[filename]) > 0
    def __getitem__(self, idx):
      file_index = idx * self.batch_size
      if file_index >= len(self.filenames):
          raise IndexError(f"Batch index {idx} is out of range")
      batch_filenames = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
      batch_mel_features = []
      batch_labels = []
      tracked_filenames = []
      tracked_start_times = []
      tracked_labels = []

      for (filename, augment_flag) in batch_filenames:
          filenames_for_sample=[]
          try:
              # Retrieve metadata and file type
              fname, primary_labels, soft_labels, start_time, file_type = self.start_time_label_data.get_file_and_start_time(filename)
              if start_time=='skip':
                continue
              filenames_for_sample.append(fname)
              # Choose correct directory based on file_type
              if file_type == 'train_birds':
                  audio_path = os.path.join(self.audio_dir, fname)
              else:
                  audio_path = os.path.join(self.soundscapes_dir, fname)

              # If file doesn't exist, skip
              if not os.path.isfile(audio_path):
                  print(f"File not found: {audio_path}")
                  continue

              # If start_time is None, pick a random start time
              if start_time is None or random.random()<self.random_start_time_prob:

                  try:
                      bird_audio, sr = librosa.load(audio_path, sr=self.sr)
                      file_duration = librosa.get_duration(y=bird_audio, sr=self.sr)
                      try:
                          start_time = random.choice([i - 5 for i in range(5, int(file_duration))] + [0])
                      except Exception:
                          start_time = 0
                  except Exception as e:
                      print(f"Error loading full audio for {fname}: {e}")
                      continue

              # Apply adjustment
              if start_time > 1:
                  adjustment = random.uniform(-1, 1)
              else:
                  adjustment = random.uniform(0, 1)
              adjusted_start_time = start_time + adjustment

              # Load chunk
              try:
                  bird_audio_chunk, sr = librosa.load(
                      audio_path, sr=self.sr, duration=self.chunk_duration, offset=adjusted_start_time
                  )
              except Exception as e:
                  print(f"Error loading chunk for {fname}: {e}")
                  continue

              # Pad if needed
              if len(bird_audio_chunk) < self.sr * self.chunk_duration:
                  bird_audio_chunk = np.pad(
                      bird_audio_chunk,
                      (0, self.sr * self.chunk_duration - len(bird_audio_chunk)),
                      mode='constant'
                  )
              added_fname=''
              random_num=random.random()
              if file_type=='soundscapes' and random_num<0.5:
                bird_audio_chunk,primary_labels,soft_labels,added_fname=self._mix_files(bird_audio_chunk,primary_labels,soft_labels)

              if file_type=='train_birds' and random_num<0.5:
                bird_audio_chunk,primary_labels,soft_labels,added_fname=self._mix_files(bird_audio_chunk,primary_labels,soft_labels,audio_type='training')

              filenames_for_sample.append(added_fname)
              tracked_start_times.append(adjusted_start_time)
              tracked_filenames.append(filenames_for_sample)
              sigmoid_label = np.zeros(self.num_classes)
              if self.primary_bird in primary_labels:
                idx_label = self.label_encoder.transform([self.primary_bird])[0]
                sigmoid_label[idx_label] = 1

              # Data augmentation
              if (augment_flag or random.random() < 0.20) and self.kind == 'train':
                  augmented_label = sigmoid_label.copy()
                  augmented_audio = self._augment_audio(bird_audio_chunk)
                  augmented_audio = self._normalize(augmented_audio)
                  mel_features = self._extract_features(augmented_audio)
                  if np.random.rand() < 0.50:
                      mel_features = self.spec_augment(mel_features)
                  batch_mel_features.append(mel_features)
                  batch_labels.append(augmented_label)
              else:
                  audio_to_process = self._normalize(bird_audio_chunk)
                  mel_features = self._extract_features(audio_to_process)
                  batch_mel_features.append(mel_features)
                  batch_labels.append(sigmoid_label)

          except Exception as e_audio:
              # Use `filename` here: `fname` is unbound if the metadata lookup itself failed.
              print(f"Error processing audio file {filename}: {e_audio}")
              batch_mel_features.append(np.zeros((self.target_time_length_spectrogram, 128, 3)))
              batch_labels.append(np.zeros(self.num_classes))

      batch_mel_features = np.array(batch_mel_features)
      batch_labels = np.array(batch_labels)

      if not self.return_info:
          return batch_mel_features, batch_labels
      else:
          return batch_mel_features, batch_labels, tracked_filenames, tracked_start_times


    def get_label(self, filename):
        primary_label = self.labels_df.loc[filename, 'primary_label']
        secondary_labels = self.labels_df.loc[filename, 'secondary_labels']
        encoded_primary = self.label_encoder.transform([primary_label])[0]
        label = np.zeros(self.num_classes, dtype=np.float32)
        label[encoded_primary] = 1.0
        for secondary_label in secondary_labels:
            if secondary_label != '':
                try:
                    encoded_secondary = self.label_encoder.transform([secondary_label])[0]
                    label[encoded_secondary] = 1.0
                except ValueError:
                    print(f"Warning: Secondary label '{secondary_label}' not found in encoder.")
        return label


    def _pad_or_truncate_audio(self, data, target_length):
        if data.shape[0] < target_length:
            padding = np.zeros((target_length - data.shape[0],), dtype=data.dtype)
            return np.concatenate((data, padding))
        elif data.shape[0] > target_length:
            return data[:target_length]
        else:
            return data
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.filenames)
    def mix_soundscape_with_train_audio(self, audio): # The audio is from a soundscape file
      start_time_info=self.start_time_info
      start_time_info=start_time_info[start_time_info['primary_label']!=self.primary_bird]
      audio_file=random.choice(start_time_info['filename'].values)

      fname, primary_labels, soft_labels, start_time, file_type = self.start_time_label_data.get_file_and_start_time(audio_file)
      audio_path=os.path.join(self.audio_dir,audio_file)
      new_audio, _ = librosa.load(audio_path, sr=self.sr, duration=self.chunk_duration,offset=start_time)

      if len(new_audio) < self.sr * self.chunk_duration:
            repetitions = int(np.ceil(self.sr * self.chunk_duration / len(new_audio)))
            new_audio = np.tile(new_audio, repetitions)[:self.sr * self.chunk_duration]
      return new_audio,primary_labels,soft_labels,audio_file

    def mix_train_audio_with_soundscape(self, audio): # The audio is a training example
      start_time_info=self.soundscape_audio_info
      if self.kind=='train':
        start_time_info=self.soundscape_audio_info[:3000]
      else:
        start_time_info=self.soundscape_audio_info[3000:]
      start_time_info=start_time_info[(start_time_info['confidence']>0.3)&(start_time_info['primary_label']!=self.primary_bird)]
      audio_file=random.choice(start_time_info['file'].values)

      fname, primary_labels, soft_labels, start_time, file_type = self.start_time_label_data.get_file_and_start_time(audio_file)
      audio_path=os.path.join(self.soundscapes_dir ,audio_file)
      new_audio, _ = librosa.load(audio_path, sr=self.sr, duration=self.chunk_duration,offset=start_time)

      if len(new_audio) < self.sr * self.chunk_duration:
            repetitions = int(np.ceil(self.sr * self.chunk_duration / len(new_audio)))
            new_audio = np.tile(new_audio, repetitions)[:self.sr * self.chunk_duration]
      return new_audio,primary_labels,soft_labels,audio_file
    def _mix_files(self, audio,primary_labels,soft_labels,audio_type='soundscape'): # Default is to pass a soundscape file to mix with a training example
        try:
            if audio_type=='soundscape':
              new_audio,new_primary_labels,new_soft_labels,new_audio_file=self.mix_soundscape_with_train_audio(audio)
            else:
              new_audio,new_primary_labels,new_soft_labels,new_audio_file=self.mix_train_audio_with_soundscape(audio)

            primary_labels=primary_labels+new_primary_labels
            soft_labels=soft_labels+new_soft_labels
            bird_level_noise = random.uniform(0.3, 0.7)
            other_bird_level = 1 - bird_level_noise
            mixed_audio = audio * bird_level_noise + new_audio * other_bird_level
            return mixed_audio,primary_labels,soft_labels,new_audio_file
        except Exception as e:
            print(f"Error mixing files: {e}")
            # Return a 4-tuple like the success path so callers can always unpack four values
            return audio,primary_labels,soft_labels,None
    def _load_and_pad_noise(self):
        noise_file = random.choice(os.listdir(self.soundscapes_dir))
        noise_path = os.path.join(self.soundscapes_dir, noise_file)
        noise, _ = librosa.load(noise_path, sr=self.sr, duration=self.chunk_duration)
        if len(noise) < self.sr * self.chunk_duration:
            repetitions = int(np.ceil(self.sr * self.chunk_duration / len(noise)))
            noise = np.tile(noise, repetitions)[:self.sr * self.chunk_duration]
        return noise


    def _extract_features(self, audio):
        mel_spec = librosa.feature.melspectrogram(y=audio, sr=self.sr)
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max).T

        min_db = -80.0
        max_db = 0.0

        mel_spec_scaled = (mel_spec_db - min_db) / (max_db - min_db)
        mel_spec_scaled = np.clip(mel_spec_scaled, 0.0, 1.0)

        # Pad or truncate and add a channel dimension
        mel_spec_padded = self._pad_or_truncate(mel_spec_scaled, self.target_time_length_spectrogram)[:, :, np.newaxis]

        # Duplicate the single mel channel across 3 channels for the ImageNet-pretrained CNN (MobileNetV2 here)
        mel_spec_padded_3_channel = np.repeat(mel_spec_padded, 3, axis=-1)

        # Removed all MFCC computation and return
        return mel_spec_padded_3_channel
    def _get_different_filename_same_bird(self, current_filename,peak_times_strategy=False):
        other_cleaned_filenames=[f[0] for f in self.short_train_filenames if f[0].split('/')[0]==current_filename.split('/')[0]]
        other_cleaned_filename=random.choice(other_cleaned_filenames)
        other_filename_df=self.labels_df[self.labels_df['cleaned_filename']==other_cleaned_filename]
        other_filename=other_filename_df.index[0]

        duration=other_filename_df['duration'].values[0]
        if duration<5:
            start_time=0+random.uniform(0,1)
        else:
            start_time=random.choice([0,1,2,3])
            start_time+=random.uniform(-1,1)

        return other_filename,start_time

    def _get_different_filename(self, current_filenames,peak_times_strategy=False):
        if self.kind=='train':
            peak_times_dict=self.peak_times_dict_train
        else:
            peak_times_dict=self.peak_times_dict_val
        if peak_times_strategy:
            other_filenames=[f for f in list(peak_times_dict.keys()) if f not in current_filenames and len(peak_times_dict[f])!=0]
            other_filename=random.choice(other_filenames)
            current_filenames.append(other_filename)
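            # Return the same 3-tuple as the other branch so callers can unpack it uniformly
            return other_filename, random.choice(peak_times_dict[other_filename]), current_filenames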
        else:
            if random.random()<0.8:
                other_cleaned_filenames=[f[0] for f in self.short_train_filenames if f[0] not in current_filenames]
                other_cleaned_filename=random.choice(other_cleaned_filenames)
                other_filename_df=self.labels_df[self.labels_df['cleaned_filename']==other_cleaned_filename]
                other_filename=other_filename_df.index[0]
                current_filenames.append(other_filename)

                duration=other_filename_df['duration'].values[0]
                if duration<5:
                    start_time=0+random.uniform(0,1)
                else:
                    start_time=random.choice([0,1,2,3])
                    start_time+=random.uniform(-1,1)
            else:
                other_cleaned_filenames=list(self.excellent_times.keys())
                other_cleaned_filename=random.choice(other_cleaned_filenames)

                other_filename_df=self.labels_df[self.labels_df['cleaned_filename']==other_cleaned_filename]
                other_filename=other_filename_df.index[0]
                start_time=random.choice(self.excellent_times[other_cleaned_filename])
                adjustment=random.uniform(-1,1)
                if start_time!=0:
                    start_time+=adjustment
                current_filenames.append(other_filename)

            return other_filename,start_time,current_filenames

    def _pad_or_truncate(self, data, target_length):
        if data.shape[0] < target_length:
            padding = np.zeros((target_length - data.shape[0], data.shape[1]), dtype=data.dtype)
            return np.vstack((data, padding))
        elif data.shape[0] > target_length:
            return data[:target_length]
        else:
            return data

    def _mix_with_noise(self, audio):
        try:
            noise = self._load_and_pad_noise()
            bird_level_noise = random.uniform(0.5, 0.9)
            noise_level = 1 - bird_level_noise
            mixed_audio_with_noise = audio * bird_level_noise + noise * noise_level
            return mixed_audio_with_noise
        except Exception as e:
            print(f"Error mixing noise: {e}")
            return audio


    def time_stretch_and_pad_truncate(self,audio):
        """
        Applies a random time stretch to an audio signal and then pads or
        truncates it back to the chunk length.

        The stretch rate is drawn uniformly from [0.8, 1.2] (below 1 slows the
        audio down, above 1 speeds it up). The target length is
        self.sr * self.chunk_duration samples.

        Args:
            audio (np.ndarray): The input audio signal.

        Returns:
            np.ndarray: The time-stretched and padded/truncated audio signal.
        """
        rate = random.uniform(0.8, 1.2)
        target_length=self.sr * self.chunk_duration
        stretched_audio = librosa.effects.time_stretch(audio, rate=rate)
        current_length = stretched_audio.shape[0]

        if current_length < target_length:
            # Pad with zeros at the end if shorter
            padding = target_length - current_length
            stretched_audio = np.pad(stretched_audio, (0, padding), mode='constant')
        elif current_length > target_length:
            # Truncate from the end if longer
            stretched_audio = stretched_audio[:target_length]

        return stretched_audio


    def _augment_audio(self, audio):
        """
        Applies random waveform-level augmentations to the audio.

        Args:
            audio (np.ndarray): The audio waveform.

        Returns:
            np.ndarray: The augmented audio waveform.
        """
        augmented_audio = audio.copy()
        is_mixed_birds = False
        # Reverb (Simple Delay-Based)
        if random.random() < 0.4:  # Probability of applying reverb
            n_delay_samples = int(0.1 * self.sr)  # 100ms delay
            feedback = random.uniform(0.2, 0.5)
            reverberated_audio = np.zeros_like(augmented_audio)
            for i in range(len(augmented_audio)):
                reverberated_audio[i] = augmented_audio[i]
                if i >= n_delay_samples:
                    reverberated_audio[i] += feedback * reverberated_audio[i - n_delay_samples]
            augmented_audio = reverberated_audio

        # Compression (Simple Soft Clipping)
        if random.random() < 0.2:  # Probability of applying compression
            threshold = random.uniform(0.6, 0.9)
            gain = random.uniform(1.0, 1.5)
            compressed_audio = np.tanh(augmented_audio * gain / threshold) * threshold
            augmented_audio = compressed_audio
        # Time shifting
        if random.random() < 0.4:
            shift_seconds = random.uniform(-2.5, 2.5)  # Circular shift by up to 2.5 seconds in either direction
            shift_samples = int(shift_seconds * self.sr)
            augmented_audio = np.roll(augmented_audio, shift_samples)

        # Mu-law compression is disabled: the probability is 0.0, so this branch never runs
        if random.random() < 0.0:
            mulaw_c = random.uniform(150, 300)  # Mu-law parameter
            augmented_audio = librosa.mu_compress(augmented_audio, mu=mulaw_c)
        if random.random() < 0.2:
            clip_factor = random.uniform(0.7, 0.95) # Adjust the range
            augmented_audio = np.clip(augmented_audio, -clip_factor, clip_factor)

        # Pitch shifting
        if random.random() < 0.1:
            n_steps = random.uniform(-0.5, 0.5)
            augmented_audio = librosa.effects.pitch_shift(augmented_audio, sr=self.sr, n_steps=n_steps)
         # Volume change
        if random.random() < 0.3:
            amplitude_factor = random.uniform(0.5, 1.3)
            augmented_audio = augmented_audio * amplitude_factor

        # Time stretching
        if random.random() < 0.1:
            augmented_audio = self.time_stretch_and_pad_truncate(augmented_audio)

        return augmented_audio

    def mix_with_other_bird(self,audio,filenames,multi_hot_label,mix_type='different_bird'):

        other_filename,start_time,filenames = self._get_different_filename(filenames)

        new_label=multi_hot_label.copy()
        mixed_audio=audio  # Fall back to the unmixed audio if no other file is selected
        if other_filename:
            other_bird_audio, _ = librosa.load(os.path.join(self.audio_dir, other_filename), sr=self.sr,
                                                    duration=self.chunk_duration, offset=start_time)
            if random.random()<0.5:
                other_bird_audio=self._augment_audio(other_bird_audio)
            if len(other_bird_audio) < self.sr * self.chunk_duration:
                other_bird_audio = np.pad(other_bird_audio,(0, self.sr * self.chunk_duration - len(other_bird_audio)),
                                        mode='constant')
            bird_level = random.uniform(0.35, 0.65)
            other_bird_level = 1 - bird_level
            mixed_audio = audio * bird_level + other_bird_audio * other_bird_level

            secondary_label_name = self.labels_df.loc[other_filename, 'primary_label']
            secondary_label_encoded = self.label_encoder.transform([secondary_label_name])[0]

            new_label[secondary_label_encoded] = 1
        return mixed_audio,new_label,filenames
    def mix_with_multiple_birds(self,audio,label,filenames,num_birds=1):
        mixed_audio=audio.copy()
        for i in range(num_birds):
            try:
                mixed_audio,label,filenames=self.mix_with_other_bird(mixed_audio,filenames,label)
            except Exception as e:
                print(f"Error in mix_with_multiple_birds: {e}")
        if self.write_one:
            sample_rate = 32000

            # Save the mixed audio
            sf.write('/kaggle/working/mixed_sample.wav', mixed_audio, sample_rate)
            print("Saved to /kaggle/working/mixed_sample.wav")
            np.save('/kaggle/working/mixed_sample_labels.npy',filenames)
            self.write_one=False
        return mixed_audio, label

    def mix_with_same_bird(self,audio,filename,mix_type='different_bird'):

        other_filename,start_time= self._get_different_filename_same_bird(filename)

        mixed_audio=audio  # Fall back to the unmixed audio if no other file is selected
        if other_filename:
            other_bird_audio, _ = librosa.load(os.path.join(self.audio_dir, other_filename), sr=self.sr,
                                                    duration=self.chunk_duration, offset=start_time)
            if random.random()<0.5:
                other_bird_audio=self._augment_audio(other_bird_audio)
            if len(other_bird_audio) < self.sr * self.chunk_duration:
                other_bird_audio = np.pad(other_bird_audio,(0, self.sr * self.chunk_duration - len(other_bird_audio)),
                                        mode='constant')
            bird_level = random.uniform(0.35, 0.65)
            other_bird_level = 1 - bird_level
            mixed_audio = audio * bird_level + other_bird_audio * other_bird_level
        return mixed_audio,filename



# ****************************************augmented_version ends here ***************************************************************

# Example usage (replace with your actual data loading)
#df = pd.read_csv('speech_cleaned_audio_with_duration.csv', dtype={'primary_label': 'object'})

def create_dataset(primary_bird,train_filenames,val_filenames, files_multiplier=10,do_skip=True, use_non_birdnet=False,take_from_test=False):
  original_test_filenames=np.load(os.path.join(processed_dir,'test_filenames.npy'))

  labeled_soundscapes=pd.read_csv(os.path.join(csv_dir, 'birdnet_soft_labels.csv'))
  labeled_soundscape_fnames=labeled_soundscapes['file'].values
  labeled_soundscape_mapping=labeled_soundscapes.set_index('file')['primary_label'].to_dict()

  train_filenames_with_bird=[f for f in train_filenames if f[0].split('/')[0]==primary_bird]
  train_filenames_without_bird=[f for f in train_filenames if f[0].split('/')[0]!=primary_bird]
  #train_filenames_from_soundscapes=[f for f in train_filenames if labeled_soundscape_mapping[f[0]]==primary_bird]

  val_filenames_with_bird=[f for f in val_filenames if f[0].split('/')[0]==primary_bird]
  val_filenames_without_bird=[f for f in val_filenames if f[0].split('/')[0]!=primary_bird]
  #val_filenames_from_soundscapes=[f for f in val_filenames if labeled_soundscape_mapping[f[0]]==primary_bird]
  if take_from_test:
    test_filenames=[f for f in original_test_filenames if f.split('/')[0]==primary_bird]
    if len(test_filenames)>1:
      train_filenames_with_bird+=[(filename_mapping[f],random.choice([True,False])) for f in test_filenames[:1]]*(587//4)
      val_filenames_with_bird+=[(filename_mapping[f],False) for f in test_filenames[1:]]
    else:
      val_filenames_with_bird+=[(filename_mapping[f],False) for f in test_filenames]

  unique_labels = df[df['primary_label']==primary_bird]['primary_label'].unique()
  bird_dir=os.path.join(models_dir,primary_bird)
  try:
    label_encoder=joblib.load(os.path.join(bird_dir, f'bird_label_encoder_{primary_bird}.joblib'))
  except:
    label_encoder = LabelEncoder().fit(unique_labels)

    joblib.dump(label_encoder, os.path.join(bird_dir, f'bird_label_encoder_{primary_bird}.joblib'))
  train_filenames_with_bird=train_filenames_with_bird

  shortened_train_filenames=train_filenames_with_bird+random.sample(train_filenames_without_bird,len(train_filenames_with_bird))

  val_filenames_with_bird=val_filenames_with_bird*files_multiplier
  shortened_val_filenames=val_filenames_with_bird+random.sample(val_filenames_without_bird,len(val_filenames_with_bird))


  random.shuffle(shortened_train_filenames)
  # --- Parameters ---
  audio_dir = os.path.join(main_dir,'train_audio')
  soundscapes_dir = os.path.join(main_dir,'train_soundscapes')
  sr = 32000
  chunk_duration = 5
  batch_size_train = 128
  batch_size_val = 32
  target_time_length_spectrogram = 320
  target_time_length_mfcc = 320
  n_mfcc = 40
  start_time = 0
  num_classes = len(unique_labels) # Get the number of unique bird species
  IMPLEMENTATION = 'sigmoid'
  print('Length of train_filenames',len(shortened_train_filenames))
  # --- Training Data Generator ---
  train_generator = AudioFeatureGeneratorMixedPadded(
      filenames=shortened_train_filenames,
      labels_df=df[['cleaned_filename', 'primary_label','secondary_labels','filename','duration','isOneBird']],
      audio_dir=audio_dir,
      soundscapes_dir=soundscapes_dir,
      sr=sr,
      chunk_duration=chunk_duration,
      batch_size=batch_size_train,
      shuffle=True,
      target_time_length_spectrogram=target_time_length_spectrogram,
      num_classes=num_classes,
      imp_type=IMPLEMENTATION,
      main_bird=primary_bird,
      all_possible_filenames=train_filenames,
      do_skip=do_skip,
      use_non_birdnet=use_non_birdnet


  )
  val_chunk_duration=5
  random.shuffle(val_filenames)
  val_filenames_to_use=shortened_val_filenames
  val_generator = AudioFeatureGeneratorMixedPadded(
      filenames=val_filenames_to_use,
      labels_df=df[['cleaned_filename', 'primary_label','secondary_labels','filename','duration','isOneBird']],
      audio_dir=audio_dir,
      soundscapes_dir=soundscapes_dir,
      sr=sr,
      chunk_duration=chunk_duration,
      batch_size=batch_size_val,  # Use a potentially different batch size for validation
      shuffle=False,             # Important: No shuffling for validation
      target_time_length_spectrogram=target_time_length_spectrogram,
      num_classes=num_classes,
      kind='val',  # Mark this generator as the validation generator
      imp_type=IMPLEMENTATION,
      main_bird=primary_bird,
      all_possible_filenames=val_filenames,
      do_skip=do_skip,
      use_non_birdnet=use_non_birdnet
  )
  print('Length of Validation Filenames', len(shortened_val_filenames))
  print("Train generator size:", len(train_generator))
  print("Validation generator size:", len(val_generator))
  if IMPLEMENTATION=='sigmoid':
    print("Train generator size:", len(train_generator))
    print("Validation generator size:", len(val_generator))

    print_first=True

    if print_first:
        for i in range(1):
            # Your __getitem__ now returns only mel_features and multi_labels
            features, multi_labels = train_generator[i]
            print(f"Batch {i+1} - Mel shape: {features.shape}, Multi-label shape: {multi_labels.shape}")
            print("Multi-labels (first example in batch):", multi_labels[0])
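            # With a single-column binary label, argmax is always 0, so the lines below
            # print the target species name whether the label itself is 0 or 1.
            original_label_index = np.argmax(multi_labels[0])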
            original_label = train_generator.label_encoder.inverse_transform([original_label_index])[0]
            print("Original label index:", original_label_index, "Original label:", original_label)


    num_classes = len(train_generator.label_encoder.classes_)

    # Determine output signature
    # Only one input tensor for features (Mel Spectrogram)
    example_features, example_labels = train_generator[0]
    output_signature = (
        tf.TensorSpec(shape=(None, train_generator.target_time_length_spectrogram, 128, 3), dtype=tf.float32), # Updated shape for 3 channels
        tf.TensorSpec(shape=(None, num_classes), dtype=tf.float32)
    )

    # Create tf.data.Dataset from the generators
    def train_gen():
        for i in range(len(train_generator)):
            features, labels = train_generator[i] # Only mel_features and labels
            yield (
                tf.convert_to_tensor(features, dtype=tf.float32), # Only one feature tensor
                tf.convert_to_tensor(labels, dtype=tf.float32)
            )
    train_dataset = tf.data.Dataset.from_generator(train_gen, output_signature=output_signature)

    def val_gen():
        for i in range(len(val_generator)):
            features, labels = val_generator[i] # Only mel_features and labels
            yield (
                tf.convert_to_tensor(features, dtype=tf.float32), # Only one feature tensor
                tf.convert_to_tensor(labels, dtype=tf.float32)
            )
    val_dataset = tf.data.Dataset.from_generator(val_gen, output_signature=output_signature)
    return train_generator,val_generator,train_dataset,val_dataset,shortened_train_filenames,val_filenames_to_use
train_generator, val_generator, train_dataset, val_dataset,tnames,vnames=create_dataset('22976',train_filenames,val_filenames)

Length of train_filenames 1178
Length of Validation Filenames 100
Train generator size: 10
Validation generator size: 4
Train generator size: 10
Validation generator size: 4
Batch 1 - Mel shape: (124, 320, 128, 3), Multi-label shape: (124, 1)
Multi-labels (first example in batch): [0.]
Original label index: 0 Original label: 22976
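
The delay-based reverb in `_augment_audio` loops over every sample in Python, which at 32 kHz and 5 s means 160,000 iterations per clip. Because each output sample depends only on the output exactly one delay earlier, the same recurrence can be evaluated one delay-length block at a time. The following is a minimal NumPy sketch of that idea, not the notebook's code; `comb_reverb_loop` and `comb_reverb_blocks` are hypothetical helper names, and the random signal stands in for real audio.

```python
import numpy as np

def comb_reverb_loop(x, delay, feedback):
    """Reference version: y[i] = x[i] + feedback * y[i - delay], one sample at a time."""
    y = np.zeros_like(x)
    for i in range(len(x)):
        y[i] = x[i]
        if i >= delay:
            y[i] += feedback * y[i - delay]
    return y

def comb_reverb_blocks(x, delay, feedback):
    """Same recurrence, evaluated one delay-length block at a time."""
    y = x.copy()
    for start in range(delay, len(x), delay):
        stop = min(start + delay, len(x))
        # Block [start, stop) depends only on the previous block, which is already final.
        y[start:stop] += feedback * y[start - delay:stop - delay]
    return y

rng = np.random.default_rng(0)
audio = rng.uniform(-1.0, 1.0, size=16000)   # 0.5 s of synthetic audio at 32 kHz
delay = 3200                                  # 100 ms at 32 kHz
reverb_loop = comb_reverb_loop(audio, delay, 0.3)
reverb_blocks = comb_reverb_blocks(audio, delay, 0.3)
max_abs_diff = float(np.max(np.abs(reverb_loop - reverb_blocks)))
```

For a 5 s clip the block version needs about 50 slice operations instead of 160,000 Python-level iterations, which matters when the augmentation runs inside a data generator.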


In [12]:
# Extract the first 20 labels from train_dataset
first_labels = []
for features, labels in train_dataset.unbatch().take(20):
    first_labels.append(labels.numpy())

print("First 20 labels in train dataset:")
for i, label in enumerate(first_labels):
    print(f"Label {i+1}: {label}")


First 20 labels in train dataset:
Label 1: [0.]
Label 2: [0.]
Label 3: [0.]
Label 4: [0.]
Label 5: [1.]
Label 6: [1.]
Label 7: [1.]
Label 8: [0.]
Label 9: [0.]
Label 10: [1.]
Label 11: [0.]
Label 12: [1.]
Label 13: [1.]
Label 14: [1.]
Label 15: [0.]
Label 16: [0.]
Label 17: [0.]
Label 18: [1.]
Label 19: [1.]
Label 20: [0.]
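
The mix of 0 and 1 labels above reflects how `create_dataset` assembles each split: every target-species entry is kept, and an equal number of other entries is drawn at random. A minimal sketch of that step, using made-up `(path, flag)` tuples; `build_balanced` and the sample paths are hypothetical and not part of the notebook.

```python
import random

def build_balanced(filenames, primary_bird, multiplier=1, seed=0):
    """Return positives (optionally repeated) plus an equal-sized random sample of negatives."""
    rng = random.Random(seed)
    positives = [f for f in filenames if f[0].split('/')[0] == primary_bird]
    negatives = [f for f in filenames if f[0].split('/')[0] != primary_bird]
    positives = positives * multiplier
    # random.sample raises ValueError if there are fewer negatives than positives.
    return positives + rng.sample(negatives, len(positives))

# Three recordings of the target species and twenty of another species
files = [(f"22976/rec{i}.ogg", False) for i in range(3)]
files += [(f"grekis/rec{i}.ogg", False) for i in range(20)]

balanced = build_balanced(files, "22976", multiplier=2)
n_pos = sum(f[0].startswith("22976/") for f in balanced)
n_neg = len(balanced) - n_pos
```

The returned list places all positives first. The training list is shuffled afterwards, while the validation generator is created with `shuffle=False`, which is consistent with the first validation labels all being 1.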


In [13]:
# Extract first 5 labels directly from val_generator
first_5_val_labels = []
for i in range(len(val_generator)):
    _, batch_labels = val_generator[i]
    for label in batch_labels:
        if len(first_5_val_labels) < 5:
            first_5_val_labels.append(label)
        else:
            break
    if len(first_5_val_labels) >= 5:
        break

print("First 5 labels in validation generator:")
for i, label in enumerate(first_5_val_labels):
    print(f"Label {i+1}: {label}")


First 5 labels in validation generator:
Label 1: [1.]
Label 2: [1.]
Label 3: [1.]
Label 4: [1.]
Label 5: [1.]


# **Data Validation: Ensure label and spectrogram shapes match expectations.**
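
One way to make this check explicit is a small assertion helper run on a batch before training. The sketch below is an illustration under assumptions rather than the notebook's code: `check_batch` is a hypothetical name, the batch is synthetic, and it assumes hard 0/1 labels and features already scaled to [0, 1], as in the printed batch shape `(124, 320, 128, 3)`.

```python
import numpy as np

def check_batch(features, labels, time_steps=320, n_mels=128):
    """Raise AssertionError if a batch does not match the expected layout."""
    assert features.ndim == 4, f"expected 4-D features, got {features.ndim}-D"
    assert features.shape[1:] == (time_steps, n_mels, 3), features.shape
    assert labels.shape == (features.shape[0], 1), labels.shape
    assert features.min() >= 0.0 and features.max() <= 1.0, "features outside [0, 1]"
    assert np.isin(labels, [0.0, 1.0]).all(), "labels must be 0 or 1"
    return True

# Synthetic stand-in for one generator batch of four examples
rng = np.random.default_rng(0)
mel = rng.uniform(0.0, 1.0, size=(4, 320, 128, 1)).astype(np.float32)
features = np.repeat(mel, 3, axis=-1)          # single channel copied to 3 channels
labels = np.array([[0.0], [1.0], [1.0], [0.0]], dtype=np.float32)

batch_ok = check_batch(features, labels)
```

With the real generator, the same helper could be called as `check_batch(*train_generator[0])`.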

In [14]:
def play_audio(filename, audio_dir='/kaggle/input/birdclef-2025/train_audio', sample_rate=None):
    """
    Plays the audio from a specified file.

    Args:
        filename (str): Full path to the audio file.
        audio_dir (str, optional): Currently unused; `filename` must already
                                     include the directory.
        sample_rate (int, optional): The sample rate to load and play the audio at.
                                      If None, the original sample rate of the file is used.
    """
    filepath=filename
    if not os.path.exists(filepath):
        print(f"Error: Audio file not found at {filepath}")
        return

    try:
        audio, sr = librosa.load(filepath, sr=sample_rate)
        display(Audio(data=audio, rate=sr))
    except Exception as e:
        print(f"Error playing audio file {filename}: {e}")
analysis=False
if analysis:
    labels=np.load('/kaggle/working/mixed_sample_labels.npy',allow_pickle=True)
    for label in labels:
        filepath=os.path.join(audio_dir,label)
        print(label)
        play_audio(filepath)
    print(labels)
    play_audio('/kaggle/working/mixed_sample.wav',audio_dir=audio_dir)

In [15]:
IMPLEMENTATION='sigmoid'

In [16]:
working_dir=drive_dir
# state 1 0.4928
def write_initial_f1_file(filepath="best_val_f1.txt",score=0.00):
    """Writes a text file with the given score on the first line.
    Args:
      filepath: The path to the text file to be created or overwritten.
                Defaults to "best_val_f1.txt" in the current directory.
      score: The value to write. Defaults to 0.0.
    """
    try:
        with open(filepath, 'w') as f:
            f.write(f"{score}\n")
        print(f"Successfully wrote '{score}' to '{filepath}'")
    except Exception as e:
        print(f"An error occurred while writing to '{filepath}': {e}")

if IMPLEMENTATION=='softmax':
    target_filepath = os.path.join(working_dir, 'best_val_accuracy_softmax_mobilenet.txt')
    print(f"Attempting to write to: {target_filepath}")
    write_initial_f1_file(target_filepath,score=0.6014)

In [17]:

def build_bird_presence_model(input_shape, l2_reg=1e-4, base_trainable=True):
    """
    Builds a MobileNetV2 model for binary bird presence classification.

    Args:
        input_shape (tuple): Shape of the input spectrograms (height, width, channels).
        l2_reg (float): L2 regularization factor.
        base_trainable (bool): Whether to fine-tune the base model.

    Returns:
        tf.keras.Model: Uncompiled MobileNetV2 model for binary classification (call model.compile before training).
    """
    # Base MobileNetV2
    base_model = MobileNetV2(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape,
        pooling='avg'
    )
    base_model.trainable = base_trainable

    inputs = tf.keras.Input(shape=input_shape)
    x = layers.Rescaling(255)(inputs)  # If your data is [0, 1], rescale to [0, 255]
    x = mobilenet_preprocess_input(x)
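    # Note: an explicit training=True here keeps BatchNorm in training mode (batch statistics)
    # during predict/evaluate as well; omit the argument to let Keras decide per call.
    x = base_model(x, training=base_trainable)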

    # Classification head for binary output
    x = layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(1, activation='sigmoid', kernel_regularizer=regularizers.l2(l2_reg))(x)

    model = models.Model(inputs=inputs, outputs=outputs, name='MobileNetV2_BirdPresence')
    return model

# Example usage:
input_shape = (320, 128, 3)  # Adjust as needed
model = build_bird_presence_model(input_shape, l2_reg=1e-4, base_trainable=True)
#model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step
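
In `build_bird_presence_model`, inputs in [0, 1] are first multiplied by 255 and then passed through `mobilenet_preprocess_input`. Assuming that name is an alias for `tf.keras.applications.mobilenet_v2.preprocess_input`, which maps pixel values from [0, 255] to [-1, 1] as `x / 127.5 - 1`, the two steps together reduce to `2x - 1`. The NumPy sketch below checks only that arithmetic and does not call TensorFlow; `rescale_then_preprocess` is a hypothetical helper.

```python
import numpy as np

def rescale_then_preprocess(x01):
    """Mimic Rescaling(255) followed by MobileNetV2-style scaling to [-1, 1]."""
    x255 = x01 * 255.0
    return x255 / 127.5 - 1.0

x = np.linspace(0.0, 1.0, 5)            # 0.0, 0.25, 0.5, 0.75, 1.0
two_step = rescale_then_preprocess(x)
one_step = 2.0 * x - 1.0
```

A spectrogram value of 0 (silence at the -80 dB floor) therefore reaches the network as -1, and the loudest bin as +1.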


In [18]:
def write_initial_f1_file(filepath="best_val_f1.txt",score=0.00):
    """Writes a text file with the given score on the first line.
    Args:
      filepath: The path to the text file to be created or overwritten.
                Defaults to "best_val_f1.txt" in the current directory.
      score: The value to write. Defaults to 0.0.
    """
    try:
        with open(filepath, 'w') as f:
            f.write(f"{score}\n")
        print(f"Successfully wrote '{score}' to '{filepath}'")
    except Exception as e:
        print(f"An error occurred while writing to '{filepath}': {e}")
def ensure_bird_dir(drive_dir, primary_bird):
    bird_dir = os.path.join(drive_dir, primary_bird)
    if not os.path.exists(bird_dir):
        os.makedirs(bird_dir)
    return bird_dir

# Batch metrics callback
def print_batch_metrics(batch, logs):
    print(f"Batch {batch}: loss={logs['loss']:.4f}, Precision={logs.get('precision', 0):.3f}, Recall={logs.get('recall', 0):.3f}")
class ValidationF1Callback(tf.keras.callbacks.Callback):
    def __init__(self, val_data, model_save_path, metrics_path, patience=4):
        super().__init__()
        self.val_data = val_data
        self.model_save_path = model_save_path
        self.metrics_path = metrics_path
        with open(self.metrics_path, 'r') as f:
            self.best_f1 = float(f.readline())
        self.patience = patience
        self.wait = 0  # Number of epochs since last improvement

    def on_epoch_end(self, epoch, logs=None):
        y_true, y_pred = [], []
        for X_batch, y_batch in self.val_data:
            preds = self.model.predict(X_batch)
            y_true.extend(np.array(y_batch).reshape(-1))
            y_pred.extend((preds.reshape(-1) > 0.5).astype(int))
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        print(f"\nEpoch {epoch+1} Validation - Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")
        num_positives = sum(y_true)
        print(f"Validation positives: {num_positives}/{len(y_true)}")

        # Save model and F1 if improved
        if f1 >= self.best_f1:
            self.best_f1 = f1
            self.wait = 0  # Reset wait counter
            self.model.save(self.model_save_path)
            with open(self.metrics_path, 'w') as f:
                f.write(f"{f1:.6f}\n")
            print(f"Model and F1 saved for {self.model_save_path}")
        else:
            self.wait += 1
            print(f"No improvement in F1 for {self.wait} epoch(s).")

        # Early stopping condition
        if self.wait >= self.patience:
            print(f"Early stopping triggered. No improvement in F1 for {self.patience} consecutive epochs.")
            self.model.stop_training = True


all_labels=df['primary_label'].unique()
input_shape = (320, 128, 3)
base_model_path=os.path.join(models_dir, 'best_model_by_val_loss_softmax.keras')


batch_callback = tf.keras.callbacks.LambdaCallback(on_train_batch_end=print_batch_metrics)

all_birds=np.load(os.path.join(supplemental_files_dir,'non_birdnet_labels.npy'),allow_pickle=True)
for primary_bird in all_birds[3:]:
    num_files_for_bird=len(df[df['primary_label']==primary_bird])
    if num_files_for_bird<4:
      multiplier=5
    else:
      multiplier=5
    print('Files Multiplier: ',multiplier)
    bird_dir = ensure_bird_dir(models_dir, primary_bird)
    model_save_path = os.path.join(bird_dir, f'best_model_{primary_bird}.keras')
    target_metrics_filepath = os.path.join(bird_dir, f'best_val_f1_{primary_bird}.txt')

    train_generator, val_generator, train_dataset, val_dataset,tfnames,vfnames=create_dataset(primary_bird,train_filenames,val_filenames,files_multiplier=multiplier)
    print(val_generator.filenames)
    # Write initial F1 file if it doesn't exist
    if not os.path.exists(target_metrics_filepath):
        with open(target_metrics_filepath, 'w') as f:
            f.write('0.0\n')
        print(f"Initialized F1 file at: {target_metrics_filepath}")


    model = build_bird_presence_model(input_shape, l2_reg=1e-4, base_trainable=True)
    model.load_weights(base_model_path,skip_mismatch=True)
    print(f"Loaded Weights from : {base_model_path}")

    val_f1_callback = ValidationF1Callback(val_dataset, model_save_path, target_metrics_filepath)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)

    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=[
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall')
        ]
    )
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',         # Monitor validation loss
        patience=4,                 # Stop after 4 epochs without improvement
        restore_best_weights=True   # Restore weights from the epoch with the best monitored value
    )
    history = model.fit(
        train_dataset,
        epochs=10,
        validation_data=val_dataset,
        callbacks=[batch_callback, val_f1_callback,early_stopping]
    )


Files Multiplier:  5
Length of train_filenames 1178
Length of Validation Filenames 10
Train generator size: 10
Validation generator size: 1
Train generator size: 10
Validation generator size: 1
Batch 1 - Mel shape: (124, 320, 128, 3), Multi-label shape: (124, 1)
Multi-labels (first example in batch): [1.]
Original label index: 0 Original label: 126247
[('126247/iNat888527.ogg', False), ('126247/iNat888527.ogg', False), ('126247/iNat888527.ogg', False), ('126247/iNat888527.ogg', False), ('126247/iNat888527.ogg', False), ('blcant4/XC443836.ogg', False), ('21211/XC882658.ogg', False), ('strowl1/iNat843366.ogg', False), ('bkcdon/XC704853.ogg', False), ('grekis/XC735981.ogg', False)]
Loaded Weights from : /content/drive/MyDrive/Main_Birdclef/models/best_model_by_val_loss_softmax.keras
Epoch 1/10
Batch 0: loss=0.7498, Precision=0.702, Recall=0.485
      1/Unknown 112s 112s/step - loss: 0.7498 - precision: 0.7021 - recall: 0.4853
Batch 1: loss=0.4864, Precision=0.814, Recall=0.724

KeyboardInterrupt: 
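
`ValidationF1Callback` saves a checkpoint whenever the epoch's F1 is greater than or equal to the best so far (ties count as an improvement), otherwise increments a wait counter, and stops training once that counter reaches `patience`. The sketch below replays only that bookkeeping on an invented F1 sequence; `replay_patience` is a hypothetical helper and the scores are made up for illustration.

```python
def replay_patience(f1_scores, best_f1=0.0, patience=4):
    """Return (epochs that would save a checkpoint, epoch at which training stops or None)."""
    saved, wait = [], 0
    for epoch, f1 in enumerate(f1_scores, start=1):
        if f1 >= best_f1:           # ties also trigger a save, as in the callback
            best_f1, wait = f1, 0
            saved.append(epoch)
        else:
            wait += 1
        if wait >= patience:
            return saved, epoch
    return saved, None

saved_epochs, stop_epoch = replay_patience(
    [0.50, 0.80, 0.80, 0.70, 0.75, 0.60, 0.79, 0.90]
)
```

Here epochs 1 to 3 save a checkpoint and training stops at epoch 7, so the 0.90 at epoch 8 is never reached. Because the starting `best_f1` is read from the metrics file, a rerun for a species whose file already holds 1.0 only overwrites the checkpoint when it reaches 1.0 again.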

In [26]:
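def print_bird_accuracies(models_dir):
    """Print each species' best validation F1, read from best_val_f1_<bird>.txt (these are F1 scores, not accuracies)."""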
    print("Bird Model Accuracies:")
    for bird_name in all_birds:
        bird_dir = os.path.join(models_dir, bird_name)
        if os.path.isdir(bird_dir):
            metrics_file = os.path.join(bird_dir, f'best_val_f1_{bird_name}.txt')
            if os.path.isfile(metrics_file):
                try:
                    with open(metrics_file, 'r') as f:
                        score_str = f.readline().strip()
                        score = float(score_str)
                    print(f"{bird_name}: {score:.4f}")
                except Exception as e:
                    print(f"Error reading {metrics_file}: {e}")
            else:
                print(f"{bird_name}: No metrics file found")

# Example usage
print_bird_accuracies(models_dir)


Bird Model Accuracies:
1139490: 0.0000
1192948: 1.0000
1194042: 1.0000
126247: 0.9091
1346504: 1.0000
134933: 1.0000
135045: 1.0000
1462711: 0.9756
1462737: 1.0000
1564122: 1.0000
21038: 1.0000
21116: 1.0000
21211: 0.7500
22333: 0.8333
22973: 0.9091
22976: 0.8000
24272: 1.0000
24292: 1.0000
24322: 1.0000
41663: 0.6875
41778: 1.0000
41970: 1.0000
42007: 0.8462
42087: 0.4000
42113: 0.3333
46010: 0.7500
47067: 0.0833
476537: 1.0000
476538: 1.0000
48124: 0.7500
50186: 0.8000
517119: 0.7451
523060: 0.5714
528041: 0.3333
52884: 0.8571
548639: 0.8889
555086: 1.0000
555142: 1.0000
566513: 0.7059
64862: 1.0000
65336: 0.6667
65344: 0.8000
65349: 0.8000
65373: 1.0000
65419: 1.0000
65448: 0.9412
65547: 1.0000
65962: 0.6667
66016: 1.0000
66531: 1.0000
66578: 0.6061
66893: 1.0000
67082: 1.0000
67252: 1.0000
714022: 0.8889
715170: 1.0000
787625: 0.9091
81930: 0.1600
868458: 1.0000
963335: 1.0000
shghum1: 0.9091
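
Several scores above repeat exactly (0.9091, 0.8000, 0.7500). With a validation set as small as the 10 files listed for species 126247, F1 can only take a handful of values. The sketch below enumerates them from F1 = 2·TP / (2·TP + FP + FN), assuming five positive and five negative clips; that split is taken from the one run shown earlier and may differ for other species, and `f1_from_counts` is a hypothetical helper.

```python
from fractions import Fraction

def f1_from_counts(tp, fp, fn):
    """F1 = 2*TP / (2*TP + FP + FN); taken as 0 when there are no true positives."""
    if tp == 0:
        return Fraction(0)
    return Fraction(2 * tp, 2 * tp + fp + fn)

n_pos, n_neg = 5, 5
# Every F1 value reachable with this many positives and negatives
possible = sorted({
    f1_from_counts(tp, fp, n_pos - tp)
    for tp in range(n_pos + 1)
    for fp in range(n_neg + 1)
})

# All five positives found, one false alarm: 10/11
f1_one_false_alarm = float(f1_from_counts(tp=5, fp=1, fn=0))
```

Under this assumption a single false alarm with perfect recall gives 10/11 ≈ 0.9091, and one miss plus one false alarm gives 0.8, so differences between per-species scores at this scale can come down to one or two clips.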
