# Shallow classification for separating songs vs calls

## Workflow
1. Select detections of the species across the dataset / stratification levels
   - select 2-5x more clips than you plan to review in TDL, so that you still have enough after filtering out calls
2. Select a random subset of clips (ideally stratified by point) and annotate them for song vs call
   (not shown in this notebook because in this case Brooke had already created some labels; I used labels from 20% of points for training and 80% for validation)
   - e.g., 200-500 clips total
   - rapidly annotate with a binary label for "song" (or whatever the desired sound type is) in a notebook or in Dipper
   - it's OK to go fast and leave some clips as 'uncertain' or have a few wrong labels
3. Embed all detections using Perch2 (works better than other foundation models for shallow classification)
4. Train a shallow classifier on your annotations
   - exclude any clips labeled 'uncertain' from train/validation
   - use 80-90% of labels for training (I used 80% of _points_ for validation here just to demonstrate that the classifier generalizes very well)
5. Apply the shallow classifier to all detections
6. Filter the detections by the shallow classifier's score to retain songs and remove other sound types
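The annotation-subset and train/validation steps above can be sketched with pandas. This is a minimal illustration on a toy table, not the selection code used for this dataset: the `point` column name, the cap of 30 clips per point, and the 80/20 split of points are all assumptions chosen for the example.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Toy detections table. The "point" column name is an assumption for
# illustration; use whatever column identifies recording locations.
detections = pd.DataFrame(
    {
        "file": [f"clip_{i}.wav" for i in range(1000)],
        "point": rng.choice([f"P{j}" for j in range(10)], size=1000),
    }
)

# Stratified subset: shuffle, then keep at most 30 clips per point.
to_annotate = (
    detections.sample(frac=1, random_state=0).groupby("point").head(30)
)

# Point-level split: hold out whole points so that validation clips
# come from locations the classifier never saw during training.
points = sorted(to_annotate["point"].unique())
shuffled = [str(p) for p in rng.permutation(points)]
n_train_points = int(0.8 * len(points))
train_points = shuffled[:n_train_points]

is_train = to_annotate["point"].isin(train_points)
train_labels = to_annotate[is_train]
val_labels = to_annotate[~is_train]
```

Splitting by point rather than by clip avoids having near-duplicate clips from the same recorder on both sides of the split, which would make validation scores look better than they are.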

In [7]:
import plot_utils # pip install plotly, anywidget
import pandas as pd
import bioacoustics_model_zoo as bmz
import torch

from sklearn.model_selection import train_test_split
from opensoundscape.ml.shallow_classifier import MLPClassifier, quick_fit

select an embedding model from the bioacoustics model zoo

In [8]:
embedder = bmz.Perch2()

                    This architecture is not listed in opensoundscape.ml.cnn_architectures.ARCH_DICT.
                    It will not be available for loading after saving the model with .save() (unless using pickle=True). 
                    To make it re-loadable, define a function that generates the architecture from arguments: (n_classes, n_channels) 
                    then use opensoundscape.ml.cnn_architectures.register_architecture() to register the generating function.

                    The function can also set the returned object's .constructor_name to the registered string key in ARCH_DICT

                    See opensoundscape.ml.cnn_architectures module for examples of constructor functions
                    


load detections

here, the detections already have some labels. 1 means song, 0 means not FISP, 'c' means call

our goal is to remove calls and only get the songs for review

In [9]:
from opensoundscape.data_selection import resample
train_clips=resample(pd.read_csv('Experiment4/train_set.csv', index_col=([0,1,2])), n_samples_per_class= 1267,random_state=0)
val_clips=pd.read_csv('Experiment4/valid_set.csv', index_col=([0,1,2]))
train_clips.KAAM_song.value_counts()

KAAM_song
True     1267
False    1267
Name: count, dtype: int64

In [10]:
train_clips.head()
len(train_clips)

2534

embed the samples with Perch2

Note: Perch2 takes about 30-75 seconds to get going but then runs quickly on the GPU

In [11]:
embs_train = embedder.embed(train_clips[[]], batch_size=32,num_workers=4)

  0%|          | 0/80 [00:00<?, ?it/s]

2026-01-08 14:21:40.859403: I external/local_xla/xla/service/service.cc:163] XLA service 0x7f6cc406ded0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2026-01-08 14:21:40.859467: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
2026-01-08 14:21:40.859497: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (1): NVIDIA GeForce RTX 3090, Compute Capability 8.6
2026-01-08 14:21:41.359218: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2026-01-08 14:21:41.441558: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91002
2026-01-08 14:21:41.666804: I external/local_xla/xla/service/gpu/autotuning/dot_search_space.cc:208] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints

In [12]:
embs_val = embedder.embed(val_clips[[]], batch_size=32,num_workers=4)

  0%|          | 0/16 [00:00<?, ?it/s]

2026-01-08 14:22:35.949110: I external/local_xla/xla/service/gpu/autotuning/dot_search_space.cc:208] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs? Working around this by using the full hints set instead.

In [13]:
embs_train.shape

(2534, 1536)

In [14]:
type(embs_train)        # should be pandas.core.frame.DataFrame

pandas.core.frame.DataFrame

In [15]:
# create and fit 1-layer MLP classifier on embeddings
clf = MLPClassifier(input_size=embs_train.shape[1], output_size=1, hidden_layer_sizes=())
y_train = train_clips['KAAM_song'].values.reshape(-1, 1).astype(float)
y_val   = val_clips['KAAM_song'].values.reshape(-1, 1).astype(float)
quick_fit(clf, train_features=embs_train.to_numpy(), train_labels=y_train,
          validation_features=embs_val.to_numpy() , validation_labels=y_val)

Epoch 100/1000, Loss: 0.294359028339386, Val Loss: 1.3898756504058838
val AU ROC: 0.782
val MAP: 0.451
Epoch 200/1000, Loss: 0.22988708317279816, Val Loss: 1.6369023323059082
val AU ROC: 0.775
val MAP: 0.435
Epoch 300/1000, Loss: 0.1988362967967987, Val Loss: 1.7770893573760986
val AU ROC: 0.772
val MAP: 0.432
Epoch 400/1000, Loss: 0.17799879610538483, Val Loss: 1.8734829425811768
val AU ROC: 0.769
val MAP: 0.431
Epoch 500/1000, Loss: 0.1620756983757019, Val Loss: 1.944496989250183
val AU ROC: 0.767
val MAP: 0.427
Epoch 600/1000, Loss: 0.14910900592803955, Val Loss: 1.9984170198440552
val AU ROC: 0.768
val MAP: 0.426
Epoch 700/1000, Loss: 0.1381618082523346, Val Loss: 2.040426254272461
val AU ROC: 0.768
val MAP: 0.433
Epoch 800/1000, Loss: 0.12870177626609802, Val Loss: 2.073974370956421
val AU ROC: 0.768
val MAP: 0.436
Epoch 900/1000, Loss: 0.12038873881101608, Val Loss: 2.1013917922973633
val AU ROC: 0.769
val MAP: 0.438
Epoch 1000/1000, Loss: 0.11298775672912598, Val Loss: 2.1242887
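With `hidden_layer_sizes=()` the classifier has no hidden layers, so it is a single linear layer on top of the embeddings. If it is trained with a binary cross-entropy loss (an assumption about `quick_fit`, not confirmed here), that is the same model as logistic regression. The sketch below fits such a model with plain NumPy on synthetic features; the data, learning rate, and step count are made up for illustration and have nothing to do with the real embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for embeddings: 400 samples, 16 features.
n, d = 400, 16
X = rng.normal(size=(n, d))
true_w = rng.normal(size=d)
y = (X @ true_w + 0.5 * rng.normal(size=n) > 0).astype(float)

# One linear layer: logit = X @ w + b
w = np.zeros(d)
b = 0.0
learning_rate = 0.1

for _ in range(500):
    logits = X @ w + b
    probs = 1.0 / (1.0 + np.exp(-logits))
    # Gradient of the mean binary cross-entropy with respect to the logits
    grad_logits = (probs - y) / n
    w -= learning_rate * (X.T @ grad_logits)
    b -= learning_rate * grad_logits.sum()

# A logit above 0 corresponds to a predicted probability above 0.5.
logits = X @ w + b
accuracy = ((logits > 0) == (y == 1)).mean()
```

Because the model is this small, it trains in seconds on a few hundred labels, which is what makes the quick annotate-then-fit loop practical.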

In [None]:
# save classifier weights
# torch.save(clf.state_dict(),'fisp_song_classifier_1layer.pth')

In [None]:
# to reload the classifier in another script/notebook
# (input_size must match the embedding dimension used for training):
if False:
    clf = MLPClassifier(input_size=embs_train.shape[1], output_size=1, hidden_layer_sizes=())
    clf.load_state_dict(torch.load('fisp_song_classifier_1layer.pth'))

In [None]:
# make "KAAM_song" predictions on all clips by applying classifier to embeddings
val_clips['KAAM_song_score'] = clf(torch.tensor(embs_val.values,dtype=torch.float32)).detach().numpy()[:,0]
val_clips

file,start_time,end_time,KAAM_song,KAAM_song_score
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240322_110000.wav,93.0,96.0,True,5.220784
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240322_110000.wav,147.0,150.0,True,6.255454
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240324_085800.wav,78.0,81.0,True,4.698673
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240412_174100.wav,102.0,105.0,True,3.368534
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240412_174100.wav,108.0,111.0,True,4.580192
...,...,...,...,...
/media/kiwi/datasets/finalized/kaua2024a/PUAI_OS/OS-UUK-20_11919/Data/SMM11919_20240527_061500.wav,150.0,153.0,False,2.131114
/media/kiwi/datasets/finalized/kaua2024a/PUAI_OS/OS-UUK-20_11919/Data/SMM11919_20240527_061500.wav,165.0,168.0,False,2.066825
/media/kiwi/datasets/finalized/kaua2024a/PUAI_OS/OS-UUK-20_11919/Data/SMM11919_20240527_061500.wav,171.0,174.0,False,1.101629
/media/kiwi/datasets/finalized/kaua2024a/PUAI_OS/OS-UUK-20_11919/Data/SMM11919_20240604_121400.wav,9.0,12.0,False,1.152523


In [19]:
# filter clips to those with a positive song score (scores are raw classifier outputs, not probabilities)
song_clips = val_clips[val_clips['KAAM_song_score']>0]
song_clips

file,start_time,end_time,KAAM_song,KAAM_song_score
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240322_110000.wav,93.0,96.0,True,5.220784
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240322_110000.wav,147.0,150.0,True,6.255454
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240324_085800.wav,78.0,81.0,True,4.698673
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240412_174100.wav,102.0,105.0,True,3.368534
/media/kiwi/datasets/finalized/kaua2024a/AKEK_grid/A1-0_11883/Data/SMM11883_20240412_174100.wav,108.0,111.0,True,4.580192
...,...,...,...,...
/media/kiwi/datasets/finalized/kaua2024a/PUAI_OS/OS-UUK-20_11919/Data/SMM11919_20240527_061500.wav,150.0,153.0,False,2.131114
/media/kiwi/datasets/finalized/kaua2024a/PUAI_OS/OS-UUK-20_11919/Data/SMM11919_20240527_061500.wav,165.0,168.0,False,2.066825
/media/kiwi/datasets/finalized/kaua2024a/PUAI_OS/OS-UUK-20_11919/Data/SMM11919_20240527_061500.wav,171.0,174.0,False,1.101629
/media/kiwi/datasets/finalized/kaua2024a/PUAI_OS/OS-UUK-20_11919/Data/SMM11919_20240604_121400.wav,9.0,12.0,False,1.152523
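The scores above are unbounded (values above 5 appear), so they look like raw logits rather than probabilities. Under that assumption, a threshold of 0 corresponds to a sigmoid probability of 0.5. Several of the `False` rows shown above still score above 0, so a stricter threshold may be worth trying. The sketch below converts between the two scales; the first two scores are taken from the table, the last two are made up.

```python
import numpy as np


def sigmoid(x):
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def logit(p):
    """Map a probability in (0, 1) back to a logit."""
    return np.log(p / (1.0 - p))


# First two values appear in the table above; 0.0 and -2.0 are hypothetical.
scores = np.array([5.220784, 1.101629, 0.0, -2.0])
probs = sigmoid(scores)

# Thresholding logits at 0 is the same as thresholding probabilities at 0.5.
same_decision = np.array_equal(scores > 0, probs > 0.5)

# A stricter cutoff, e.g. probability 0.9, expressed on the logit scale.
strict_threshold = logit(0.9)
keep_strict = scores > strict_threshold
```

Working on the logit scale avoids an extra conversion step when filtering, as long as the threshold is chosen on the same scale.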


## Optional visualization

In [20]:
# plot_utils
# requires plotly: pip install plotly
import base64
import io

import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import HTML, display
from scipy.io import wavfile
from opensoundscape import Audio, Spectrogram


def plot_row(row):
    center_t = (row.start_time + row.end_time) / 2
    a = Audio.from_file(row.file, offset=center_t - 0.25, duration=0.5)
    s = Spectrogram.from_audio(a).bandpass(0, 3000)
    a.show_widget()
    s.plot()


def inspect(
    rows, dur=None, N=20, bandpass_range=None, dB_range=[-100, -20], cmap="Greys"
):
    rows = rows.sample(min(N, len(rows)))

    cells = []

    for _, row in rows.iterrows():
        if dur is None:
            # no fixed duration requested: show the full clip
            start = row.start_time
            clip_dur = row.end_time - row.start_time
        else:
            # show a window of length `dur` centered on the clip
            center_t = (row.start_time + row.end_time) / 2
            start = max(0, center_t - dur / 2)
            clip_dur = dur

        a = Audio.from_file(
            row.file,
            offset=start,
            duration=clip_dur,
            out_of_bounds_mode="ignore",
        ).normalize()

        s = Spectrogram.from_audio(a)
        if bandpass_range is not None:
            s = s.bandpass(*bandpass_range)

        # --- spectrogram array ---
        spec = s.spectrogram  # (freq, time)

        spec = np.clip(spec, a_min=dB_range[0], a_max=dB_range[1])

        # --- render spectrogram to PNG ---
        fig, ax = plt.subplots(figsize=(2.2, 2.2))
        ax.imshow(
            spec,
            origin="lower",
            aspect="auto",
            cmap=cmap,
        )
        ax.axis("off")

        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        plt.close(fig)

        img_b64 = base64.b64encode(buf.getvalue()).decode()

        # --- audio → WAV bytes ---
        wav_buf = io.BytesIO()

        samples = a.samples
        # if samples.ndim > 1:
        #     samples = samples.mean(axis=0)

        # normalize safely to int16
        samples = samples / max(1e-9, np.max(np.abs(samples)))
        samples_int16 = (samples * 32767).astype(np.int16)

        wavfile.write(wav_buf, a.sample_rate, samples_int16)

        audio_b64 = base64.b64encode(wav_buf.getvalue()).decode()

        cells.append(
            f"""
            <div class="cell">
                <img src="data:image/png;base64,{img_b64}"
                     onclick="this.nextElementSibling.play()"/>
                <audio src="data:audio/wav;base64,{audio_b64}"></audio>
            </div>
            """
        )

    html = f"""
    <style>
    .grid {{
        display: grid;
        grid-template-columns: repeat(auto-fill, minmax(140px, 1fr));
        gap: 10px;
    }}
    .cell {{
        cursor: pointer;
        border-radius: 6px;
        overflow: hidden;
        box-shadow: 0 2px 6px rgba(0,0,0,0.15);
        transition: transform 0.1s ease;
    }}
    .cell:hover {{
        transform: scale(1.03);
    }}
    .cell img {{
        width: 100%;
        display: block;
    }}
    </style>

    <div class="grid">
        {''.join(cells)}
    </div>
    """

    display(HTML(html))


import ipywidgets as widgets
from IPython.display import display

def get_selected_row_ids(fw):
    row_ids = []

    for tr in fw.data:
        if tr.selectedpoints is None:
            continue

        pts = np.asarray(tr.selectedpoints, dtype=int)
        row_ids.extend(tr.customdata[pts, 0])

    return np.unique(row_ids).astype(int)

def explore_features(
    df,
    x_col="x",
    y_col="y",
    color_col=None,
    symbol_col=None,
    size_col=None,
    hover_name_col=None,
    duration=0.5,
    N=12,
    dB_range=[-100, -20],
    bandpass_range=None,
    spec_window_samples=512,
):

    fig_out = widgets.Output()
    inspect_out = widgets.Output()

    df = df.copy()
    df["_row_id"] = np.arange(len(df))
    df["x"] = df[x_col]
    df["y"] = df[y_col]

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color=color_col,
        hover_name=hover_name_col,
        symbol=symbol_col,
        size=size_col,
        opacity=0.8,
        custom_data=["_row_id"],
    )

    fw = go.FigureWidget(fig)
    # fw.selected_row_ids = np.array([], dtype=int)

    def on_select(trace, points, selector):
        # Only handle box selections
        if not hasattr(selector, "xrange"):
            return

        # xmin, xmax = selector.xrange
        # ymin, ymax = selector.yrange

        # selected = df[
        #     (df["x"] >= xmin)
        #     & (df["x"] <= xmax)
        #     & (df["y"] >= ymin)
        #     & (df["y"] <= ymax)
        # ]
        row_ids = []
        for tr in fw.data:
            # skip if trace is not visible

            if tr.selectedpoints is None:
                continue

            # selectedpoints are trace-local indices
            tr_points = np.asarray(tr.selectedpoints, dtype=int)

            # map to row ids via customdata
            tr_row_ids = tr.customdata[tr_points, 0]
            row_ids.extend(tr_row_ids)

        # fw.selected_row_ids = np.array(row_ids)

        row_ids = np.unique(row_ids).astype(int)
        if len(row_ids) == 0:
            return
        selected = df.iloc[row_ids]

        with inspect_out:
            inspect_out.clear_output(wait=True)
            print(f"{len(selected)} points selected")

            inspect(
                selected,
                dur=duration,
                N=N,
                bandpass_range=bandpass_range,
                dB_range=dB_range,
                cmap="Greys",
            )

    # Attach to ONE trace intentionally
    fw.data[0].on_selection(on_select)
    # for tr in fw.data:
    #     tr.on_selection(on_select)

    with fig_out:
        display(fw)

    display(fig_out, inspect_out)

    return fw

def make_label_buttons(fw, df, label_col="label"):
    btn0 = widgets.Button(
        description="Label selected = 0",
        button_style="danger",
        icon="times",
    )

    btn1 = widgets.Button(
        description="Label selected = 1",
        button_style="success",
        icon="check",
    )

    out = widgets.Output()

    def apply_label(label):
        with out:
            out.clear_output(wait=True)

            idx = get_selected_row_ids(fw)

            if len(idx) == 0:
                print("No points selected")
                return

            # idx holds row positions (not index labels), so map them to labels first
            df.loc[df.index[idx], label_col] = label
            print(f"Labeled {len(idx)} points as {label}")

    btn0.on_click(lambda b: apply_label(0))
    btn1.on_click(lambda b: apply_label(1))

    display(widgets.HBox([btn0, btn1]), out)


def explore_histogram(
    df,
    value_col,
    label_col,
    positive_value=1,
    negative_value=0,
    bins=30,
    sample_n=6,
    inspect_fn=None,
):

    # --- Widgets ---
    fig_out = widgets.Output()
    status = widgets.Output()

    show_pos = widgets.ToggleButton(
        description="Toggle Positives",
        value=True,
        button_style="success",

    )

    # grey for negatives
    show_neg = widgets.ToggleButton( 
        description="Toggle Negatives",
        value=True,
        button_style="",
    )

    sample_btn = widgets.Button(
        description="Inspect from visible",
        button_style="info",
    )

    # --- Figure ---
    fw = go.FigureWidget()

    df_pos = df[df[label_col] == positive_value]
    df_neg = df[df[label_col] == negative_value]

    fw.add_histogram(
        x=df_pos[value_col],
        nbinsx=bins,
        name="Positive",
        opacity=0.6,
        marker_color="green",
    )

    fw.add_histogram(
        x=df_neg[value_col],
        nbinsx=bins,
        name="Negative",
        opacity=0.6,
        marker_color="lightgrey",
    )

    fw.update_layout(
        barmode="overlay",
        dragmode="zoom",
        width=900,
        height=450,
    )

    # --- Logic ---
    def get_visible_range():
        r = fw.layout.xaxis.range
        if r is None:
            return None
        return float(r[0]), float(r[1])

    def get_selected_rows():
        r = get_visible_range()
        lo, hi = r if r else (df[value_col].min(), df[value_col].max())

        mask = (df[value_col] >= lo) & (df[value_col] <= hi)

        if show_pos.value and not show_neg.value:
            mask &= df[label_col] == positive_value
        elif show_neg.value and not show_pos.value:
            mask &= df[label_col] == negative_value

        return df[mask]

    def update_visibility(change=None):
        fw.data[0].visible = show_pos.value
        fw.data[1].visible = show_neg.value

    def on_sample(b):
        with status:
            status.clear_output(wait=True)
            sel = get_selected_rows()
            if len(sel) == 0:
                print("No samples selected")
                return
            sampled = sel.sample(min(sample_n, len(sel)))
            print(f"Inspecting {len(sampled)} samples")
            if inspect_fn:
                inspect_fn(sampled)

    show_pos.observe(update_visibility, names="value")
    show_neg.observe(update_visibility, names="value")
    sample_btn.on_click(on_sample)

    # --- Assemble UI ---
    controls = widgets.HBox([show_neg,show_pos,sample_btn])
    ui = widgets.VBox([fig_out, controls, status])

    # attach figure AFTER container exists
    fig_out.append_display_data(fw)

    return ui, fw


In [21]:
# show interactive histogram
ui, fw = explore_histogram(
    val_clips.reset_index(), 'KAAM_song_score', label_col='KAAM_song',
    inspect_fn=lambda rows: plot_utils.inspect(rows, dur=2),
)
ui

VBox(children=(Output(outputs=({'output_type': 'display_data', 'data': {'application/vnd.jupyter.widget-view+j…

### inspect random clips from subsets

clips predicted to have song:

In [22]:
plot_utils.inspect(val_clips[val_clips.KAAM_song_score>0].reset_index(),N=16)

clips predicted to not have song:

In [24]:
plot_utils.inspect(val_clips[val_clips.KAAM_song_score<0].reset_index(),N=16)
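Besides listening to random clips, the labeled validation set can give a quick numeric check of how well a given score threshold separates songs from other sounds. The sketch below computes precision and recall at a threshold on a small made-up table that reuses the `KAAM_song` and `KAAM_song_score` column names; the helper `precision_recall` is hypothetical, not part of `plot_utils` or opensoundscape.

```python
import pandas as pd

# Made-up labels and scores, only to demonstrate the calculation.
toy = pd.DataFrame(
    {
        "KAAM_song": [True, True, True, False, False, False],
        "KAAM_song_score": [5.2, 3.4, -0.5, 2.1, 1.1, -3.0],
    }
)


def precision_recall(df, threshold, label_col="KAAM_song", score_col="KAAM_song_score"):
    """Precision and recall of keeping clips with score > threshold."""
    kept = df[score_col] > threshold
    is_song = df[label_col].astype(bool)
    true_positives = int((kept & is_song).sum())
    n_kept = int(kept.sum())
    n_songs = int(is_song.sum())
    precision = true_positives / n_kept if n_kept else float("nan")
    recall = true_positives / n_songs if n_songs else float("nan")
    return precision, recall


precision_at_0, recall_at_0 = precision_recall(toy, threshold=0)
precision_at_3, recall_at_3 = precision_recall(toy, threshold=3)
```

Calling the same helper on `val_clips` over a range of thresholds would show how much recall is given up for each gain in precision, which helps pick a cutoff before applying the filter to all detections.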