In [None]:
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import patches as mpatches
import librosa
from scipy import signal as scipysig
from scipy.signal import find_peaks 
from scipy.stats import gaussian_kde
from natsort import natsorted

In [None]:
# Setting parameters
# Root of the tetrahedral-array dataset. Set the BIRD_COUNTS_DATA_DIR environment
# variable to run on another machine; the author's path is the fallback.
# NOTE: `dir` shadows the Python builtin; it is kept because later cells build
# paths by string concatenation onto it.
dir = os.path.join(
    os.environ.get('BIRD_COUNTS_DATA_DIR',
                   '/Users/lillianwang/Documents/bird-counts-25/tetrahedral-data/'),
    '',  # guarantees a trailing separator for the concatenations below
)
birdnet_dir = dir + 'birdnet/'
annotations_dir = dir + 'annotations/'
annotations_dir = dir+'annotations/'

In [None]:
def _load_tables(dir, extension, **read_kwargs):
    """Read every `extension` file in `dir` into a DataFrame, in natural-sort order.

    Each frame gets a 1-based 'Recording' column equal to the file's position in the
    naturally sorted listing (so '2.txt' comes before '10.txt').
    NOTE(review): the recording number comes from sort position, not from the file
    name, so a missing file shifts every later recording; confirm both folders
    always contain matching file sets.
    """
    files = natsorted([f for f in os.listdir(dir) if f.endswith(extension)])
    dfs = []
    for recording_number, file in enumerate(files, start=1):
        # os.path.join works whether or not `dir` ends with a separator.
        df = pd.read_csv(os.path.join(dir, file), encoding='latin1', **read_kwargs)
        df['Recording'] = recording_number
        dfs.append(df)
    return dfs


def load_txts(dir):
    """Load every tab-delimited *.txt table in `dir`.

    Returns a list of DataFrames (one per file) with a 1-based 'Recording' column.
    """
    return _load_tables(dir, '.txt', delimiter='\t')


def load_csvs(dir):
    """Load every comma-separated *.csv table in `dir`.

    Returns a list of DataFrames (one per file) with a 1-based 'Recording' column.
    """
    return _load_tables(dir, '.csv')

In [None]:
# Read data: one DataFrame per recording from each source, in natural file order
annotations_dfs = load_txts(annotations_dir)
birdnet_dfs = load_csvs(birdnet_dir)

# Map species codes to common names using the metadata key. Series.map turns codes
# missing from the key into NaN; annotations_timestamps skips non-string species.
key = pd.read_excel(annotations_dir + 'CTC_Metadata.xlsx')
code_map = {code: name for code, name in zip(key['Code'], key['Common Name'])}
annotations_dfs = [frame.assign(Species=frame['Species'].map(code_map)) for frame in annotations_dfs]

# Long tables spanning all recordings (each row keeps its 'Recording' number).
combined_birdnet, combined_annotations = pd.concat(birdnet_dfs), pd.concat(annotations_dfs)

## Helper functions

In [None]:
# Setting parameters
# n_fft: STFT segment length (samples); nf: number of lowest STFT bins kept;
# highpass_filter: cutoff (Hz) passed to butter_highpass_filter when loading audio.
n_fft, nf, highpass_filter = 1024, 220, 500

In [None]:
# ---- filtering function ----
def butter_highpass_filter(data, cut, fs, order=5):
    """Zero-phase Butterworth high-pass filter.

    Args:
        data: signal array, filtered along its last axis.
        cut: cutoff frequency in Hz.
        fs: sampling rate in Hz.
        order: Butterworth filter order.

    Returns:
        The filtered array (forward-backward filtfilt, so no phase shift).
    """
    nyquist = fs / 2
    numerator, denominator = scipysig.butter(order, cut / nyquist, btype='high', analog=False)
    return scipysig.filtfilt(numerator, denominator, data)

# ---- function to convert linear units to dB ----
def logTransform(spec, scale=10**(-5)):
    """Return 20*log10(spec + scale) element-wise.

    `scale` is a small floor so zero-valued entries do not produce -inf.
    """
    offset = np.full(spec.shape, scale)
    return 20 * np.log10(np.add(spec, offset))

# ---- function to load 4-channel audio and convert to B-format ----
def loadData(uploaded_file):
    """Load a multichannel recording, high-pass it, and convert the first 4 channels to B-format.

    Args:
        uploaded_file: path (or file-like) passed to librosa.core.load.

    Returns:
        (s, s_B, framerate): `s` is the scaled, high-passed (4, n_samples) array of the
        first four channels, `s_B = b_transform @ s`, and `framerate` is the file's
        sample rate (sr=None keeps the native rate).

    Raises:
        ValueError: if the file has fewer than four channels.
    """
    # ---- read in file, keeping every channel ----
    (s, framerate) = librosa.core.load(uploaded_file, sr=None, mono=False)
    s = np.atleast_2d(s)  # a mono file loads as 1-D
    if s.shape[0] < 4:
        raise ValueError(f'expected at least 4 channels, got {s.shape[0]} in {uploaded_file!r}')

    # ---- scale the signal ----
    # NOTE(review): `**` binds before `/`, so this divides by the *squared* signed maximum
    # over all channels rather than normalising to unit peak. Kept unchanged because
    # scatter-marker sizes in later cells depend on this scale; confirm the intent.
    s = s/np.max(s)**2

    # ---- keep the first four channels, then filter ----
    # filtfilt runs independently per row, so filtering only these four channels gives
    # the same result as filtering all of them and slicing afterwards.
    s = butter_highpass_filter(s[:4, :], highpass_filter, framerate)

    # ---- convert to b-format through matrix multiplication ----
    b_transform = np.asarray([[1, 1, 1, 1], [1, 1, -1, -1], [1, -1, 1, -1], [1, -1, -1, 1]])
    s_B = b_transform @ s

    return s, s_B, framerate

In [None]:
# Load in file
recording = 1                     # 1-based; helpers read annotations_dfs[recording-1]
target_species = 'Red-eyed Vireo'
duration = 600                    # seconds; default end_time of the timestamp helpers
bin_num = 60                      # number of azimuth histogram bins
file = os.path.join(dir, 'recordings', f'{recording}.WAV')
s, s_B, framerate = loadData(file)

In [None]:
# Make b-format spectrogram
def spectrogram(file, s, s_B, framerate, start, end):
    """Compute per-pixel azimuth and intensity for a time slice of a 4-channel signal.

    Args:
        file: unused; kept for call-site compatibility.
        s: (4, n_samples) array of the filtered raw channels.
        s_B: unused; the B-format signal is recomputed from the sliced `s`.
        framerate: sample rate in Hz.
        start: slice start in seconds, or None for the beginning of the signal.
        end: slice end in seconds, or None for the end of the signal.

    Returns:
        azimuth: (n_frames, <=nf) azimuth in radians for each time-frequency pixel.
        weights: (n_frames, <=nf) |W|^2 per pixel, used as histogram weights.
        freqs: frequencies (Hz) of the kept STFT bins.
        inds: STFT frame times in seconds, relative to the slice start.
        hist, azim_edges, time_edges: |W|^2-weighted 2-D histogram of azimuth vs. time.
        log_hist: natural log of hist + 0.01.
    """
    # Slice in samples. `end` is applied whether or not `start` is given.
    start_idx = 0 if start is None else int(np.trunc(start * framerate))
    end_idx = None if end is None else int(np.ceil(end * framerate))
    s = s[:, start_idx:end_idx]

    b_transform = np.asarray([[1, 1, 1, 1], [1, 1, -1, -1], [1, -1, 1, -1], [1, -1, -1, 1]])
    s_B = b_transform @ s

    # STFT of the first three B-format channels; the fourth is never used below.
    specs = []
    for channel in s_B[:3]:
        freqs, inds, spec = scipysig.stft(channel, fs=framerate, nperseg=n_fft)
        specs.append(spec[:nf, :].T)  # lowest `nf` bins; rows are time frames
    freqs = freqs[:nf]

    # Original notation: w ~ p(f,t); x, y ~ components of v(f,t).
    w, x, y = specs
    # azimuth values for all pixels (eq 10)
    azimuth = np.arctan2(np.real(w.conj() * y), np.real(w.conj() * x))

    # weight the azimuth values by the intensity of the W channel
    weights = np.abs(w)**2

    # time coordinate of every pixel (same shape as azimuth)
    _, time_grid = np.meshgrid(freqs, inds)

    # Histogram resolution: 50 ms time bins (at least one, so very short slices do not
    # make histogram2d fail) and 60 azimuth bins.
    slice_duration = s_B.shape[1] / framerate
    time_step = 0.05
    num_time = max(1, int(slice_duration / time_step))
    num_azim = 60

    hist, azim_edges, time_edges = np.histogram2d(x=azimuth.ravel(), y=time_grid.ravel(),
                                                  bins=[num_azim, num_time],
                                                  weights=weights.ravel())

    log_hist = np.log(hist + 0.01)

    return azimuth, weights, freqs, inds, hist, azim_edges, time_edges, log_hist

In [None]:
# Find timestamps of audio segments based on annotations
def annotations_timestamps(recording=recording, target_species=None, start_time=0, end_time=duration):
    """Collect annotated selections of one recording lying inside [start_time, end_time].

    Args:
        recording: 1-based recording number (indexes `annotations_dfs`).
        target_species: common name to keep, or None for every species.
        start_time, end_time: seconds; a selection is kept only if its begin time is
            >= start_time and its end time is <= end_time.

    Returns:
        Five parallel lists: selection ids, (begin, end) tuples, species names,
        low frequencies (Hz) and high frequencies (Hz). Rows whose 'Species' is not a
        string (e.g. NaN from an unmapped code) are skipped.
    """
    df = annotations_dfs[recording - 1]
    rows = df if target_species is None else df[df['Species'] == target_species]

    selections, timestamps, species, low_freqs, high_freqs = [], [], [], [], []

    for position in range(len(rows)):
        row = rows.iloc[position]
        if not isinstance(row['Species'], str):
            continue

        begin, finish = row['Begin Time (s)'], row['End Time (s)']
        # Written as `not (...)` so NaN times are skipped exactly like before.
        if not (begin >= start_time and finish <= end_time):
            continue

        selections.append(row['Selection'])
        timestamps.append((begin, finish))
        low_freqs.append(row['Low Freq (Hz)'])
        high_freqs.append(row['High Freq (Hz)'])
        species.append(row['Species'])

    return selections, timestamps, species, low_freqs, high_freqs

In [None]:
# Absolute peak of histogram
def absolute_peaks(azimuth, weights, bins=None):
    """Return the weighted azimuth histogram and the centre(s) of its highest bin.

    Args:
        azimuth: 1-D array of azimuths (radians).
        weights: 1-D array of per-sample weights, same length as `azimuth`.
        bins: number of histogram bins; defaults to the notebook-level `bin_num`.

    Returns:
        (counts, peak_centers): histogram counts and the centre(s) of bins whose count
        equals the maximum. Separated ties give several centres; a flat histogram
        (e.g. empty input) gives none.
    """
    if bins is None:
        bins = bin_num
    counts, bin_edges = np.histogram(azimuth, weights=weights, bins=bins)
    bin_centers = 0.5*(bin_edges[1:]+bin_edges[:-1])

    # find_peaks never reports the first or last sample, so a maximum in an edge bin
    # (azimuths at the limits of the histogram range, e.g. near +/-pi) used to be
    # dropped silently. Pad both ends with the minimum so edge maxima qualify, then
    # shift the indices back to the unpadded histogram.
    floor = counts.min()
    padded = np.concatenate(([floor], counts, [floor]))
    peaks, _ = find_peaks(padded, height=counts.max())
    return counts, bin_centers[peaks - 1]

# Histogram peaks using percent of peak method
def histogram_peaks(azimuth, weights):
    """Weighted azimuth histogram plus the centres of peaks whose prominence is at least
    half of the tallest bin.

    Note that find_peaks never reports the first or last bin.
    Returns (counts, peak_centers).
    """
    hist_counts, edges = np.histogram(azimuth, bins=bin_num, weights=weights)
    centers = (edges[1:] + edges[:-1]) / 2
    min_prominence = np.max(hist_counts) / 2
    peak_idx = find_peaks(hist_counts, prominence=min_prominence)[0]
    return hist_counts, centers[peak_idx]

# Histogram peaks using kernel density estimation method
def kde_peaks(azimuth, weights):
    """Locate azimuth modes from a weighted Gaussian KDE.

    The density is evaluated on 50 evenly spaced points from the smallest to the
    largest azimuth; local maxima with prominence >= 0.01 are reported.

    Returns:
        (peak_locations, grid, density)
    """
    density_fn = gaussian_kde(azimuth, weights=weights)
    grid = np.linspace(np.min(azimuth), np.max(azimuth), num=50)
    density = density_fn(grid)
    peak_idx = find_peaks(density, prominence=0.01)[0]
    return grid[peak_idx], grid, density

In [None]:
# Filter out azimuths outside range [low_freq, high_freq]
def azimuth_filter(azimuth, weights, freqs, low_freq, high_freq):
    """Keep only the frequency columns whose frequency lies in [low_freq, high_freq].

    `azimuth` and `weights` are indexed [time, frequency], with columns aligned to `freqs`.
    Returns the column-filtered (azimuth, weights, freqs).
    """
    keep = np.logical_and(freqs >= low_freq, freqs <= high_freq)
    return azimuth[:, keep], weights[:, keep], freqs[keep]

## Azimuth

In [None]:
# Annotations on azigram
fig, ax = plt.subplots(1, 1, figsize=(14, 4))

start, end = 0, 30

azimuth, weights, freqs, inds, *_ = spectrogram(file, s, s_B, framerate, start, end)
ax.pcolormesh(inds, freqs, azimuth.T, cmap='turbo')

selections, timestamps, species, low_freqs, high_freqs = annotations_timestamps(recording, start_time=start, end_time=end)

# Outline each annotation. STFT times are relative to `start`, so shift the boxes by it.
for (t_begin, t_end), f_low, f_high in zip(timestamps, low_freqs, high_freqs):
    box = mpatches.Rectangle((t_begin - start, f_low), t_end - t_begin, f_high - f_low,
                             linewidth=1, edgecolor='r', facecolor='none')
    ax.add_patch(box)

ax.set_xlabel('Time (sec)')
ax.set_ylabel('Frequency (Hz)')
ax.set_title(f'Azigram ({start}-{end}s)')

fig.tight_layout()
fig.savefig('azigram-with-annotations.png', dpi=200)

In [None]:
# Azimuth over time from annotations (size weighted by amplitude)
fig, ax = plt.subplots(figsize=(8, 6))

# One point per histogram peak: x = annotation start (s), y = peak azimuth (rad),
# size = summed amplitude (sqrt of |W|^2) inside the annotation's frequency band.
# These lists feed the clustering cells below.
x_vals, y_vals, sizes = [], [], []

selections, timestamps, species, low_freqs, high_freqs = annotations_timestamps(recording, target_species)

for (start, end), f_low, f_high in zip(timestamps, low_freqs, high_freqs):
    if end - start < 0.05:  # shorter than one 50 ms histogram step in spectrogram()
        continue

    azimuth, weights, freqs, *_ = spectrogram(file, s, s_B, framerate, start, end)
    azimuth, weights, freqs = azimuth_filter(azimuth, weights, freqs, f_low, f_high)
    azimuth, weights = azimuth.flatten(), weights.flatten()
    amplitude = np.sqrt(weights)

    # Alternatives: kde_peaks(azimuth, weights) or histogram_peaks(azimuth, weights)
    counts, peaks = absolute_peaks(azimuth, weights)

    total_amplitude = np.sum(amplitude)
    x_vals.extend([start] * len(peaks))
    y_vals.extend(peaks)
    sizes.extend([total_amplitude] * len(peaks))

ax.scatter(x_vals, y_vals, s=[size*5 for size in sizes], alpha=.75, color='b')
ax.set_ylim(-np.pi, np.pi)
ax.grid(True)

ax.set_title(f'{target_species} azimuth over time')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Azimuth')

fig.tight_layout()
fig.savefig(f'{target_species}-time-vs-azimuth.png', dpi=200)
plt.show()


In [None]:
# Clustering based on time, azimuth, and amplitude (DBSCAN)
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

# Points from the previous cell: annotation start (s), peak azimuth (rad), summed amplitude.
x_vals, y_vals, sizes = np.array(x_vals), np.array(y_vals), np.array(sizes)

# Standardise each feature so a single eps applies to all three.
# NOTE(review): azimuth is treated as linear, so points near +pi and -pi look far apart.
X = np.array([x_vals, y_vals, sizes]).T
X_scaled = StandardScaler().fit_transform(X)

db = DBSCAN(eps=.75, min_samples=5).fit(X_scaled)
labels = db.labels_
noise = labels == -1          # DBSCAN labels noise points -1
cluster = ~noise
n_clusters = len(set(labels)) - (1 if noise.any() else 0)

print(f'Number of clusters: {n_clusters}')

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(x_vals[cluster], y_vals[cluster], c=labels[cluster], cmap=plt.cm.turbo,
           s=sizes[cluster] * 5, alpha=0.75)
ax.scatter(x_vals[noise], y_vals[noise], c='k', s=sizes[noise] * 5, alpha=0.5)
ax.set_ylim(-np.pi, np.pi)
ax.grid(True)

ax.set_title(f'{target_species} azimuth over time (DBSCAN)\n{n_clusters} clusters')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Azimuth')

fig.tight_layout()
fig.savefig(f'{target_species}-dbscan.png', dpi=200)
plt.show()

In [None]:
# Clustering based on time, azimuth, and amplitude (BGM)
from sklearn.mixture import BayesianGaussianMixture
from sklearn.preprocessing import StandardScaler

# Features: annotation start (s), peak azimuth (rad), summed amplitude; standardised.
X = np.array([x_vals, y_vals, sizes]).T
X_scaled = StandardScaler().fit_transform(X)

# A small weight_concentration_prior favours few active components. A fixed
# random_state makes the fit, the labels and the cluster count reproducible.
bgm = BayesianGaussianMixture(tol=1e-2, n_components=20, weight_concentration_prior=.1,
                              random_state=0)
labels = bgm.fit_predict(X_scaled)

# Components carrying more than 10% of the mixture weight count as clusters.
active_clusters = np.sum(bgm.weights_ > .1)

fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(x_vals, y_vals, c=labels, cmap=plt.cm.turbo, s=[size*5 for size in sizes], alpha=0.75)
ax.set_ylim(-np.pi, np.pi)

ax.set_title(f'{target_species} azimuth over time (BGM)\n{active_clusters} clusters')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Azimuth')

fig.tight_layout()
fig.savefig(f'{target_species}-bgm.png', dpi=200)
plt.show()

In [None]:
# All species azimuth over time from annotations (size weighted by amplitude)
selections, timestamps, species, low_freqs, high_freqs = annotations_timestamps(recording)

# Sorted so the panel order is stable between kernel sessions (set order of strings is
# not). Non-string names are dropped up front instead of leaving blank panels.
species_names = sorted({sp for sp in species if isinstance(sp, str)})

# squeeze=False keeps `axes` 2-D, so this also works when there is a single species.
fig, axes = plt.subplots(nrows=len(species_names), ncols=1, figsize=(8, len(species_names)*4),
                         squeeze=False)

for ax, target_species in zip(axes[:, 0], species_names):
    x_vals, y_vals, sizes = [], [], []

    # Separate names so the outer `species` list is not overwritten inside the loop.
    _, sp_timestamps, _, sp_low_freqs, sp_high_freqs = annotations_timestamps(recording, target_species)

    for (start, end), low_freq, high_freq in zip(sp_timestamps, sp_low_freqs, sp_high_freqs):
        if end - start < 0.05:  # shorter than one 50 ms histogram step
            continue

        azimuth, weights, freqs, *_ = spectrogram(file, s, s_B, framerate, start, end)
        azimuth, weights, freqs = azimuth_filter(azimuth, weights, freqs, low_freq, high_freq)

        azimuth = azimuth.flatten()
        weights = weights.flatten()
        amplitude = np.sqrt(weights)

        counts, peaks = absolute_peaks(azimuth, weights)

        for peak in peaks:
            x_vals.append(start)
            y_vals.append(peak)
            sizes.append(np.sum(amplitude))

    ax.scatter(x_vals, y_vals, s=[size*5 for size in sizes], alpha=.75, color='b')

    ax.set_title(f'{target_species} azimuth over time')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Azimuth')
    ax.set_ylim(-np.pi, np.pi)
    ax.grid(True)

fig.tight_layout()
fig.savefig('time-vs-azimuth-all-species.png', dpi=200)
plt.show()


In [None]:
# Combined azimuth over time
fig, ax = plt.subplots(figsize=(12, 9))

x_vals, y_vals, sizes, colors = [], [], [], []

selections, timestamps, species, low_freqs, high_freqs = annotations_timestamps(recording)

species_list = sorted(set(species))  # Consistent order
species_to_color = {sp: i for i, sp in enumerate(species_list)}
num_species = len(species_list)

for (start, end), sp, low_freq, high_freq in zip(timestamps, species, low_freqs, high_freqs):
    if end - start < 0.05:  # shorter than one 50 ms histogram step
        continue
    azimuth, weights, freqs, *_ = spectrogram(file, s, s_B, framerate, start, end)
    # Restrict to the annotated band, as the per-species plots do, so both agree.
    azimuth, weights, freqs = azimuth_filter(azimuth, weights, freqs, low_freq, high_freq)

    azimuth = azimuth.flatten()
    weights = weights.flatten()
    amplitude = np.sqrt(weights)

    counts, peaks = absolute_peaks(azimuth, weights)

    for peak in peaks:
        x_vals.append(start)
        y_vals.append(peak)
        sizes.append(np.sum(amplitude))
        colors.append(species_to_color[sp])

# Pin the colour scale to the full species-index range so point colours match the legend
# even when some species contribute no points. vmax >= 1 avoids dividing by zero with a
# single species.
vmax = max(num_species - 1, 1)
ax.scatter(x_vals, y_vals, c=colors, cmap='turbo', vmin=0, vmax=vmax,
           s=[size * 5 for size in sizes], alpha=.75)
cmap = plt.cm.turbo
patches = [
    mpatches.Patch(color=cmap(species_to_color[sp] / vmax), label=sp)
    for sp in species_list
]

ax.set_title('Azimuth over time')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Azimuth')
ax.set_ylim(-np.pi, np.pi)
ax.grid(True)
ax.legend(handles=patches, title='Species', loc=(1.05, 0))

fig.tight_layout()
fig.savefig('time-vs-azimuth-combined.png', dpi=200)
plt.show()

In [None]:
# Polar plots by species (weighted)
selections, timestamps, species, low_freqs, high_freqs = annotations_timestamps(recording)

# Sorted for a stable panel order; non-string names dropped so no panel is left blank.
species_names = sorted({sp for sp in species if isinstance(sp, str)})
# squeeze=False keeps `axes` 2-D, so this also works with a single species.
fig, axes = plt.subplots(nrows=len(species_names), ncols=1, figsize=(12, len(species_names)*6),
                         subplot_kw={'polar': True}, squeeze=False)

for ax, bird in zip(axes[:, 0], species_names):
    azimuth_list = []
    weights_list = []

    # Separate names so the outer `species` list is not overwritten inside the loop.
    _, bird_timestamps, _, bird_low_freqs, bird_high_freqs = annotations_timestamps(recording, bird)

    for (start, end), low_freq, high_freq in zip(bird_timestamps, bird_low_freqs, bird_high_freqs):
        if end - start < 0.05:  # shorter than one 50 ms histogram step
            continue

        azimuth, weights, freqs, *_ = spectrogram(file, s, s_B, framerate, start, end)
        azimuth, weights, freqs = azimuth_filter(azimuth, weights, freqs, low_freq, high_freq)

        azimuth_list.append(azimuth.flatten())
        weights_list.append(weights.flatten())

    if not azimuth_list:
        continue

    # |W|^2-weighted histogram of all of this species' azimuths, drawn on a polar axis.
    counts, bin_edges = np.histogram(np.concatenate(azimuth_list), weights=np.concatenate(weights_list), bins=bin_num)
    bin_centers = 0.5*(bin_edges[1:]+bin_edges[:-1])

    ax.plot(bin_centers, counts, color='b')
    ax.set_title(f'{bird} weighted azimuth', pad=30)
    ax.set_xlabel('Azimuth (degrees)')  # polar theta tick labels are shown in degrees
    ax.grid(True)

fig.tight_layout()
fig.savefig('polar-weighted-all-species.png', dpi=200)
plt.show()


In [None]:
# Polar plots by species (unweighted)
selections, timestamps, species, low_freqs, high_freqs = annotations_timestamps(recording)

# Sorted for a stable panel order; non-string names dropped so no panel is left blank.
species_names = sorted({sp for sp in species if isinstance(sp, str)})
# squeeze=False keeps `axes` 2-D, so this also works with a single species.
fig, axes = plt.subplots(nrows=len(species_names), ncols=1, figsize=(12, len(species_names)*6),
                         subplot_kw={'polar': True}, squeeze=False)

for ax, bird in zip(axes[:, 0], species_names):
    azimuth_list = []

    # Separate names so the outer `species` list is not overwritten inside the loop.
    _, bird_timestamps, _, bird_low_freqs, bird_high_freqs = annotations_timestamps(recording, bird)

    for (start, end), low_freq, high_freq in zip(bird_timestamps, bird_low_freqs, bird_high_freqs):
        if end - start < 0.05:  # shorter than one 50 ms histogram step
            continue

        azimuth, weights, freqs, *_ = spectrogram(file, s, s_B, framerate, start, end)
        azimuth, weights, freqs = azimuth_filter(azimuth, weights, freqs, low_freq, high_freq)

        azimuth_list.append(azimuth.flatten())

    if not azimuth_list:
        continue

    # Plain pixel-count histogram (no intensity weighting), drawn on a polar axis.
    counts, bin_edges = np.histogram(np.concatenate(azimuth_list), bins=bin_num)
    bin_centers = 0.5*(bin_edges[1:]+bin_edges[:-1])

    ax.plot(bin_centers, counts, color='b')

    ax.set_title(bird+' unweighted azimuth', pad=30)
    ax.set_xlabel('Azimuth (degrees)')  # polar theta tick labels are shown in degrees
    ax.grid(True)

fig.tight_layout()
fig.savefig('polar-unweighted-all-species.png', dpi=200)
plt.show()


In [None]:
target_species = "Common Yellowthroat"  # species analysed by the histogram cells below

In [None]:
# Unweighted cumulative histogram (all of target_species' annotations pooled together)
fig, ax = plt.subplots(figsize=(12, 9))

selections, timestamps, species, low_freqs, high_freqs = annotations_timestamps(recording, target_species)

azimuth_chunks = []

for (start, end), low_freq, high_freq in zip(timestamps, low_freqs, high_freqs):
    if end - start < 0.05:  # shorter than one 50 ms histogram step
        continue

    azimuth, weights, freqs, *_ = spectrogram(file, s, s_B, framerate, start, end)
    azimuth, weights, freqs = azimuth_filter(azimuth, weights, freqs, low_freq, high_freq)

    azimuth_chunks.append(azimuth.flatten())

# np.concatenate([]) raises, so handle a species with no usable annotations explicitly.
if azimuth_chunks:
    azimuth_list = np.concatenate(azimuth_chunks)
    ax.hist(azimuth_list, bins=bin_num, color='b', density=True)
else:
    azimuth_list = np.array([])
    print(f'No annotations of {target_species} of at least 0.05 s in recording {recording}')

ax.set_title(f'{target_species} cumulative histogram of azimuth')
ax.set_xlabel('Azimuth (radians)')

fig.tight_layout()
fig.savefig(f'{target_species}-cumulative-histogram.png', dpi=200)
plt.show()

In [None]:
# Histograms of annotations
selections, timestamps, species, low_freqs, high_freqs = annotations_timestamps(recording, target_species)

# Compute every usable annotation first so the figure has exactly one panel per plotted
# annotation: no blank panels, and no crash with zero or one annotation.
panels = []
for selection, (start, end), low_freq, high_freq in zip(selections, timestamps, low_freqs, high_freqs):
    if end - start < 0.05:  # shorter than one 50 ms histogram step
        continue
    azimuth, weights, freqs, *_ = spectrogram(file, s, s_B, framerate, start, end)
    azimuth, weights, freqs = azimuth_filter(azimuth, weights, freqs, low_freq, high_freq)

    azimuth = azimuth.flatten()
    weights = weights.flatten()

    if len(azimuth) <= 1:
        continue

    counts, peaks = absolute_peaks(azimuth, weights)
    panels.append((selection, azimuth, weights, peaks))

if not panels:
    print(f'No plottable annotations of {target_species} in recording {recording}')
else:
    fig, axes = plt.subplots(nrows=len(panels), ncols=1, figsize=(8, len(panels)*4), squeeze=False)

    for ax, (selection, azimuth, weights, peaks) in zip(axes[:, 0], panels):
        ax.hist(azimuth, weights=weights, bins=bin_num, color='b', density=True)
        ymin, ymax = ax.get_ylim()
        # Dashed red lines mark the detected peak azimuth(s).
        ax.vlines(peaks, ymin=0, ymax=ymax, colors='red', linestyles='dashed', linewidth=0.75)

        ax.set_title(f'{target_species} weighted histogram of azimuth ({recording}.{selection})')
        ax.set_xlabel('Azimuth (radians)', fontsize=10)

    fig.tight_layout()
    fig.savefig(f'{target_species}-azimuth-histograms.png', dpi=200)
    plt.show()

## Statistical analysis

In [None]:
# Find timestamps of audio segments containing target vocalization
def birdnet_timestamps(recording, target_species=None, start_time=0, end_time=duration):
    """Collect BirdNET detections of one recording lying inside [start_time, end_time].

    Args:
        recording: 1-based recording number (indexes `birdnet_dfs`).
        target_species: 'Common name' to keep, or None for every species.
        start_time, end_time: seconds; a detection is kept only if its start is
            >= start_time and its end is <= end_time.

    Returns:
        Three parallel lists: (start, end) tuples, common names, and confidences.
        The names list is filled for every kept row, including when `target_species`
        is given, so the three lists can always be zipped together.
    """
    df = birdnet_dfs[recording - 1]
    rows = df if target_species is None else df[df['Common name'] == target_species]

    timestamps, species, confidence = [], [], []

    for position in range(len(rows)):
        row = rows.iloc[position]
        start, end = row['Start (s)'], row['End (s)']
        if start >= start_time and end <= end_time:
            timestamps.append((start, end))
            species.append(row['Common name'])
            confidence.append(row['Confidence'])

    return timestamps, species, confidence

In [None]:
# Convert annotations to 3s windows with 1.5s overlap
def annotations_to_windows(recording, target_species=None, start_time=0, end_time=duration):
    """Label fixed 3-s windows (1.5-s hop) with the annotated species that overlaps them most.

    An annotation qualifies for a window when it overlaps the window by at least 10% of the
    annotation's own duration (and by more than 0 s). Among qualifying annotations, the one
    with the largest overlap labels the window; on ties the earliest wins. Windows with no
    qualifying annotation are omitted.

    Args:
        recording: 1-based recording number.
        target_species: restrict to one species' annotations, or None for all.
        start_time, end_time: seconds. Only annotations lying inside this range are used,
            and windows are generated while their start is < end_time.

    Returns:
        (window_timestamps, window_species): parallel lists of (start, end) and names.
    """
    window_length = 3
    hop = 1.5
    # NOTE(review): windows start at 10 s regardless of start_time, so annotations in the
    # first 10 s are never counted while BirdNET detections there are; confirm intent.
    first_window_start = 10

    selections, timestamps, species, low_freqs, high_freqs = annotations_timestamps(
        recording, target_species, start_time=start_time, end_time=end_time)

    new_timestamps = []
    new_species = []

    window_start = first_window_start
    while window_start < end_time:
        window_end = window_start + window_length

        # Track the best overlap across *all* annotations for this window. It used to be
        # reset per annotation, so the last qualifying one won rather than the largest.
        best_overlap = 0
        best_species = None
        for (ann_start, ann_end), sp in zip(timestamps, species):
            overlap = max(0, min(window_end, ann_end) - max(window_start, ann_start))
            if overlap >= 0.1 * (ann_end - ann_start) and overlap > best_overlap:
                best_overlap = overlap
                best_species = sp

        if best_overlap > 0:
            new_timestamps.append((window_start, window_end))
            new_species.append(best_species)

        window_start += hop

    return new_timestamps, new_species

### Overall

In [None]:
# Total number of detections and annotations
# "Annotations" here are the species-labelled 3-s windows from annotations_to_windows.
detection_timestamps, detection_species, confidence = birdnet_timestamps(recording)
annotation_timestamps, annotation_species = annotations_to_windows(recording)

print(f'Number of detections: {len(detection_timestamps)}')
print(f'Number of annotations: {len(annotation_timestamps)}')

In [None]:
# Precision and recall at different thresholds (based on overlap)
# A detection and an annotation window "match" when they overlap by at least 0.3 s
# (10% of a 3-s window) and name the same species.
print('10% overlap between annotation and detection')
print('-----------------------------------------------')


def _overlap_seconds(a, b):
    """Length (s) of the intersection of intervals a and b, each (start, end); 0 if disjoint."""
    return max(0, min(a[1], b[1]) - max(a[0], b[0]))


def _has_species_match(interval, sp, others, others_species, min_overlap=0.3):
    """True if some interval in `others` overlaps `interval` by >= min_overlap s with species `sp`."""
    return any(_overlap_seconds(interval, other) >= min_overlap and sp == other_sp
               for other, other_sp in zip(others, others_species))


thresholds = np.round(np.arange(0, 0.4, 0.05), 2)

for threshold in thresholds:
    print(f'\n{threshold} confidence threshold')

    kept = [(t, sp) for t, sp, c in zip(detection_timestamps, detection_species, confidence) if c >= threshold]
    filtered_timestamps = [t for t, _ in kept]
    filtered_species = [sp for _, sp in kept]
    print(f'Number of detections: {len(filtered_timestamps)}')

    # Detections: matched -> true positive, unmatched -> false positive.
    true_positives = sum(_has_species_match(t, sp, annotation_timestamps, annotation_species)
                         for t, sp in kept)
    false_positives = len(kept) - true_positives
    # Annotation windows with no matching surviving detection -> false negative.
    false_negatives = sum(not _has_species_match(t, sp, filtered_timestamps, filtered_species)
                          for t, sp in zip(annotation_timestamps, annotation_species))

    print(f'True positives: {true_positives}')
    print(f'False positives: {false_positives}')
    print(f'False negatives: {false_negatives}')

    p = true_positives/(true_positives + false_positives) if (true_positives+false_positives) > 0 else 0
    r = true_positives/(true_positives + false_negatives) if (true_positives+false_negatives) > 0 else 0

    print(f'Precision: {p:.3f}')
    print(f'Recall: {r:.3f}')

In [None]:
# Build precision-recall curve
# Uses the detection_*/annotation_* lists from the "Total number" cell. A detection and an
# annotation window match when they overlap by >= 0.3 s and name the same species.


def _overlap_seconds(a, b):
    """Length (s) of the intersection of intervals a and b, each (start, end); 0 if disjoint."""
    return max(0, min(a[1], b[1]) - max(a[0], b[0]))


def _has_species_match(interval, sp, others, others_species, min_overlap=0.3):
    """True if some interval in `others` overlaps `interval` by >= min_overlap s with species `sp`."""
    return any(_overlap_seconds(interval, other) >= min_overlap and sp == other_sp
               for other, other_sp in zip(others, others_species))


fig, ax = plt.subplots(figsize=(8, 6))

thresholds = np.arange(0, 1, 0.01)
precision = []
recall = []

for threshold in thresholds:
    # Filter based on the confidence threshold
    kept = [(t, sp) for t, sp, c in zip(detection_timestamps, detection_species, confidence) if c >= threshold]
    filtered_timestamps = [t for t, _ in kept]
    filtered_species = [sp for _, sp in kept]

    true_positives = sum(_has_species_match(t, sp, annotation_timestamps, annotation_species)
                         for t, sp in kept)
    false_positives = len(kept) - true_positives
    # An annotation is missed when no detection *surviving this threshold* matches it.
    # Comparing against all detections (as before) made recall ignore the threshold.
    false_negatives = sum(not _has_species_match(t, sp, filtered_timestamps, filtered_species)
                          for t, sp in zip(annotation_timestamps, annotation_species))

    p = true_positives/(true_positives+false_positives) if (true_positives+false_positives) > 0 else 0
    r = true_positives/(true_positives+false_negatives) if (true_positives+false_negatives) > 0 else 0

    precision.append(p)
    recall.append(r)

ax.plot(recall, precision, color='b')
ax.grid(True)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

ax.set_title('Precision-Recall Curve')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
plt.show()

### By species

In [None]:
# Find set of all species (in the recording selected above)
# Distinct local names so the detection_species / annotation_species *lists* used by the
# overall precision/recall cells are not overwritten with sets.
_, detected_names, _ = birdnet_timestamps(recording, start_time=0, end_time=duration)
_, annotated_names = annotations_to_windows(recording)

# Sorted list: deterministic iteration/panel order downstream; NaN names are dropped.
all_species = sorted({sp for sp in set(detected_names) | set(annotated_names) if isinstance(sp, str)})

In [None]:
# Precision and recall by species
# Per species, a detection matches an annotation window when they overlap by >= 0.3 s.
# Detections below this confidence are ignored, as the printed header states.
species_conf_threshold = 0.5


def _overlap_seconds(a, b):
    """Length (s) of the intersection of intervals a and b, each (start, end); 0 if disjoint."""
    return max(0, min(a[1], b[1]) - max(a[0], b[0]))


def _has_overlap(interval, others, min_overlap=0.3):
    """True if any interval in `others` overlaps `interval` by at least `min_overlap` seconds."""
    return any(_overlap_seconds(interval, other) >= min_overlap for other in others)


print(f'By each species ({species_conf_threshold} confidence threshold)')
print('-----------------------------------------------')

for bird in sorted(b for b in all_species if isinstance(b, str)):
    print('\n' + bird)

    bird_detections, _, bird_confidence = birdnet_timestamps(recording, target_species=bird, start_time=0, end_time=duration)
    detection_timestamps = [t for t, c in zip(bird_detections, bird_confidence) if c >= species_conf_threshold]
    annotation_timestamps, *_ = annotations_to_windows(recording, target_species=bird)

    # Detections overlapping an annotation window are true positives, the rest false positives.
    true_positives = sum(_has_overlap(t, annotation_timestamps) for t in detection_timestamps)
    false_positives = len(detection_timestamps) - true_positives
    # Annotation windows with no overlapping detection are false negatives.
    false_negatives = sum(not _has_overlap(t, detection_timestamps) for t in annotation_timestamps)

    print(f'\tTrue positives: {true_positives}')
    print(f'\tFalse positives: {false_positives}')
    print(f'\tFalse negatives: {false_negatives}')

    p = true_positives/(true_positives + false_positives) if (true_positives+false_positives) > 0 else 0
    r = true_positives/(true_positives + false_negatives) if (true_positives+false_negatives) > 0 else 0

    print(f'\tPrecision: {p:.3f}')
    print(f'\tRecall: {r:.3f}')

In [None]:
# Precision-recall curves by species
# Per species, a detection matches an annotation window when they overlap by >= 0.3 s.


def _overlap_seconds(a, b):
    """Length (s) of the intersection of intervals a and b, each (start, end); 0 if disjoint."""
    return max(0, min(a[1], b[1]) - max(a[0], b[0]))


def _has_overlap(interval, others, min_overlap=0.3):
    """True if any interval in `others` overlaps `interval` by at least `min_overlap` seconds."""
    return any(_overlap_seconds(interval, other) >= min_overlap for other in others)


species_to_plot = sorted(b for b in all_species if isinstance(b, str))
# squeeze=False keeps `axes` 2-D, so this also works with a single species.
fig, axes = plt.subplots(nrows=len(species_to_plot), ncols=1, figsize=(8, len(species_to_plot)*4),
                         squeeze=False)

thresholds = np.arange(0, 1, 0.01)

for ax, bird in zip(axes[:, 0], species_to_plot):
    # Both sides cover the whole recording, as in the by-species cell above. Detections
    # were previously cut at 120 s while the annotation windows ran to the end, so every
    # later annotation counted as a miss.
    detection_timestamps, detection_species, confidence = birdnet_timestamps(recording, target_species=bird, start_time=0, end_time=duration)
    annotation_timestamps, annotation_species = annotations_to_windows(recording, target_species=bird)

    precision = []
    recall = []

    for threshold in thresholds:
        # Filter based on confidence threshold
        filtered_timestamps = [t for t, c in zip(detection_timestamps, confidence) if c >= threshold]

        true_positives = sum(_has_overlap(t, annotation_timestamps) for t in filtered_timestamps)
        false_positives = len(filtered_timestamps) - true_positives
        # Misses are judged against the detections that survive this threshold.
        false_negatives = sum(not _has_overlap(t, filtered_timestamps) for t in annotation_timestamps)

        p = true_positives/(true_positives+false_positives) if (true_positives+false_positives) > 0 else 0
        r = true_positives/(true_positives+false_negatives) if (true_positives+false_negatives) > 0 else 0

        precision.append(p)
        recall.append(r)

    ax.plot(recall, precision, color='b')
    ax.grid(True)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    ax.set_title(bird)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')

fig.tight_layout()
plt.show()

### Other: species presence/absence

In [None]:
# Precision and recall (based on presence/absence)
# A species is "detected" if any BirdNET detection of it reaches this confidence;
# keep the printed header in sync with the value.
presence_conf_threshold = 0.1

print('Species presence/absence (.1 confidence threshold)')
print('-----------------------------------------------')

# Local names so the detection_*/annotation_* lists used by other cells are untouched.
_, detected_names, detected_confidence = birdnet_timestamps(recording, start_time=0, end_time=duration)
_, annotated_names = annotations_to_windows(recording)

detected_set = {sp for sp, c in zip(detected_names, detected_confidence) if c >= presence_conf_threshold}
annotated_set = set(annotated_names)

true_positives = len(detected_set & annotated_set)
false_positives = len(detected_set - annotated_set)
false_negatives = len(annotated_set - detected_set)

print(f'True positives: {true_positives}')
print(f'False positives: {false_positives}')
print(f'False negatives: {false_negatives}')

# Guard against empty sets (no detections / no annotations) instead of ZeroDivisionError.
precision = true_positives/(true_positives+false_positives) if (true_positives+false_positives) > 0 else 0
print(f'Precision: {precision}')

recall = true_positives/(true_positives+false_negatives) if (true_positives+false_negatives) > 0 else 0
print(f'Recall: {recall}')