In [20]:
import os
import gc
import warnings
import logging
import time
import math
import cv2
from pathlib import Path

import numpy as np
import pandas as pd
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from tqdm.auto import tqdm


logging.basicConfig(level=logging.ERROR)

from module import datasets_lib, models_lib, utils_lib, learning_lib, preprocess_lib, inference_lib, config_lib

In [21]:
"TODO: config はkaggle側で変えれたほうがよさそう．いちいちアップするのは面倒"
cfg = config_lib.CFG(mode='inference', kaggle_notebook=False, debug=True)

In [22]:
print(f"Using device: {cfg.device}")
print(f"Loading taxonomy data...")
taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
species_ids = taxonomy_df['primary_label'].tolist()
num_classes = len(species_ids)
print(f"Number of classes: {num_classes}")

Using device: cpu
Loading taxonomy data...
Number of classes: 206


In [23]:

def predict_on_spectrogram(audio_path, models, cfg, species_ids):
    """Process a single audio file and predict species presence for each 5-second segment"""
    predictions = []
    row_ids = []
    soundscape_id = Path(audio_path).stem
    
    try:
        print(f"Processing {soundscape_id}")
        audio_data, _ = librosa.load(audio_path, sr=cfg.FS)
        
        total_segments = int(len(audio_data) / (cfg.FS * cfg.WINDOW_SIZE))
        
        for segment_idx in range(total_segments):
            # intにキャスト
            start_sample = int(segment_idx * cfg.FS * cfg.WINDOW_SIZE)
            end_sample = int(start_sample + cfg.FS * cfg.WINDOW_SIZE)
            segment_audio = audio_data[start_sample:end_sample]
            
            end_time_sec = (segment_idx + 1) * cfg.WINDOW_SIZE
            row_id = f"{soundscape_id}_{end_time_sec}"
            row_ids.append(row_id)

            if cfg.use_tta:
                all_preds = []
                
                for tta_idx in range(cfg.tta_count):
                    mel_spec = preprocess_lib.process_audio_segment(segment_audio, cfg)
                    mel_spec = inference_lib.apply_tta(mel_spec, tta_idx)

                    mel_spec = torch.tensor(mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
                    mel_spec = mel_spec.to(cfg.device)

                    if len(models) == 1:
                        with torch.no_grad():
                            outputs = models[0](mel_spec)
                            probs = torch.sigmoid(outputs).cpu().numpy().squeeze()
                            all_preds.append(probs)
                    else:
                        segment_preds = []
                        for model in models:
                            with torch.no_grad():
                                outputs = model(mel_spec)
                                probs = torch.sigmoid(outputs).cpu().numpy().squeeze()
                                segment_preds.append(probs)
                        
                        avg_preds = np.mean(segment_preds, axis=0)
                        all_preds.append(avg_preds)

                final_preds = np.mean(all_preds, axis=0)
            else:
                mel_spec = preprocess_lib.process_audio_segment(segment_audio, cfg)
                
                mel_spec = torch.tensor(mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
                mel_spec = mel_spec.to(cfg.device)
                
                if len(models) == 1:
                    with torch.no_grad():
                        outputs = models[0](mel_spec)
                        final_preds = torch.sigmoid(outputs).cpu().numpy().squeeze()
                else:
                    segment_preds = []
                    for model in models:
                        with torch.no_grad():
                            outputs = model(mel_spec)
                            probs = torch.sigmoid(outputs).cpu().numpy().squeeze()
                            segment_preds.append(probs)

                    final_preds = np.mean(segment_preds, axis=0)
                    
            predictions.append(final_preds)
            
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
    
    return row_ids, predictions

In [24]:


def run_inference(cfg, models, species_ids):
    """Run inference on all test soundscapes"""
    test_files = list(Path(cfg.test_soundscapes).glob('*.ogg'))
    
    if cfg.debug:
        print(f"Debug mode enabled, using only {cfg.debug_count} files")
        test_files = test_files[:cfg.debug_count]
    
    print(f"Found {len(test_files)} test soundscapes")

    all_row_ids = []
    all_predictions = []

    for audio_path in tqdm(test_files):
        row_ids, predictions = predict_on_spectrogram(str(audio_path), models, cfg, species_ids)
        all_row_ids.extend(row_ids)
        all_predictions.extend(predictions)
    
    return all_row_ids, all_predictions



In [None]:
def main():
    start_time = time.time()
    print("Starting BirdCLEF-2025 inference...")
    print(f"TTA enabled: {cfg.use_tta} (variations: {cfg.tta_count if cfg.use_tta else 0})")

    models = models_lib.load_models(cfg, num_classes)
    
    if not models:
        print("No models found! Please check model paths.")
        return
    
    print(f"Model usage: {'Single model' if len(models) == 1 else f'Ensemble of {len(models)} models'}")

    row_ids, predictions = run_inference(cfg, models, species_ids)
    submission_df = utils_lib.create_submission(row_ids, predictions, species_ids, cfg)
    submission_path = os.path.join(cfg.OUTPUT_DIR, 'submission.csv')
    submission_df.to_csv(submission_path, index=False)
    print(f"Submission saved to {submission_path}")
    
    end_time = time.time()
    print(f"Inference completed in {(end_time - start_time)/60:.2f} minutes")

In [26]:
if __name__ == "__main__":
    main()

Starting BirdCLEF-2025 inference...
TTA enabled: False (variations: 0)
Found a total of 1 model files.
Loading model: ../models/model_fold0.pth
Model usage: Single model
Debug mode enabled, using only 3 files
Found 2 test soundscapes


  0%|          | 0/2 [00:00<?, ?it/s]

Processing H02_20230420_074000
Processing H02_20230420_112000
['H02_20230420_074000_5.0', 'H02_20230420_074000_10.0', 'H02_20230420_074000_15.0', 'H02_20230420_074000_20.0', 'H02_20230420_074000_25.0', 'H02_20230420_074000_30.0', 'H02_20230420_074000_35.0', 'H02_20230420_074000_40.0', 'H02_20230420_074000_45.0', 'H02_20230420_074000_50.0', 'H02_20230420_074000_55.0', 'H02_20230420_074000_60.0', 'H02_20230420_112000_5.0', 'H02_20230420_112000_10.0', 'H02_20230420_112000_15.0', 'H02_20230420_112000_20.0', 'H02_20230420_112000_25.0', 'H02_20230420_112000_30.0', 'H02_20230420_112000_35.0', 'H02_20230420_112000_40.0', 'H02_20230420_112000_45.0', 'H02_20230420_112000_50.0', 'H02_20230420_112000_55.0', 'H02_20230420_112000_60.0'] [array([0.00027689, 0.00019808, 0.00022631, 0.0002259 , 0.00030879,
       0.0001733 , 0.00034921, 0.00024073, 0.00019031, 0.00033517,
       0.00019377, 0.00017845, 0.00030298, 0.00034699, 0.00026703,
       0.00030773, 0.00024324, 0.00026817, 0.00031859, 0.0002936 