<a href="https://colab.research.google.com/github/noahdanieldsouza/PAM-classification/blob/main/run_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install torchaudio and numpy, which the classification cell below imports.
!pip install torchaudio numpy

In [None]:
# Install perch-hoplite straight from GitHub, pinned to a specific commit hash.
!pip install git+https://github.com/google-research/perch-hoplite.git@782acd0e409eb27df51a695de4cb6608dae0db25

In [None]:
# Mount Google Drive (forcing a remount) so files under MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Recursively copy the DB folder from Drive into /content, giving /content/DB.
# NOTE(review): if /content/DB already exists (e.g. this cell is re-run),
# `cp -r` nests the copy at /content/DB/DB instead of replacing it - confirm
# that is acceptable.
!cp -r  /content/drive/MyDrive/labeled_fish/DB /content

In [None]:
#@title Load model and connect to database. { vertical-output: true }
# Opens the embeddings database, rebuilds the embedding model from the config
# stored in the database metadata, and builds a loader for the source audio.

from google.colab import drive

# These modules are imported here because no earlier cell imports them;
# without this the cell fails with NameError when the notebook is run from
# top to bottom.
from perch_hoplite.agile import audio_loader
from perch_hoplite.agile import source_info
from perch_hoplite.db import sqlite_usearch_impl
from perch_hoplite.zoo import model_configs

drive.mount('/content/drive', force_remount=True)

#@markdown Location of database containing audio embeddings.
db_path = '/content/DB'  #@param {type:'string'}
#@markdown Identifier (eg, name) to attach to labels produced during validation.
annotator_id = 'linnaeus'  #@param {type:'string'}

# The model configuration and the audio source description are both read back
# from metadata stored in the database itself.
db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)
db_model_config = db.get_metadata('model_config')
embed_config = db.get_metadata('audio_sources')
model_class = model_configs.get_model_class(db_model_config.model_key)
embedding_model = model_class.from_config(db_model_config.model_config)
audio_sources = source_info.AudioSources.from_config_dict(embed_config)

# Fall back to a 5 second window when the model does not declare one.
window_size_s = getattr(embedding_model, 'window_size_s', 5.0)

audio_filepath_loader = audio_loader.make_filepath_loader(
    audio_sources=audio_sources,
    window_size_s=window_size_s,
    sample_rate_hz=embedding_model.sample_rate,
)

# Sanity check: report how many embeddings and which classes the DB holds.
print("Embeddings in DB:", len(db.get_embedding_ids()))
print(db.get_classes())

In [None]:


import os
import torch
import torchaudio
import numpy as np
import csv
import gc
import tempfile

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

from perch_hoplite.agile import colab_utils, embed, source_info
from perch_hoplite.db import sqlite_usearch_impl
from perch_hoplite.zoo import model_configs
from perch_hoplite.agile.classifier import LinearClassifier

# --- Paths ---
# The database, classifier weights and input recording are read from the
# mounted Drive; the results CSV is written under /content.
db_path = '/content/drive/MyDrive/labeled_fish/DB'
classifier_path = '/content/drive/MyDrive/labeled_fish/DB/agile_classifier_v2.pt'
wav_path = '/content/drive/MyDrive/labeled_fish/white_sucker1.wav'
output_csv_path = '/content/chunk_classification_results.csv'

# --- Load DB + model ---
# Open the database and rebuild the embedding model from the 'model_config'
# entry stored in the database metadata.
db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)
db_model_config = db.get_metadata('model_config')
model_class = model_configs.get_model_class(db_model_config['model_key'])
embedding_model = model_class.from_config(db_model_config['model_config'])
# NOTE(review): embedding_ids is not used anywhere in the visible code.
embedding_ids = db.get_embedding_ids()


# --- Load classifier ---
# Load the saved linear classifier; class_names is later indexed by the
# argmax of the classifier output to turn a prediction into a label.
classifier = LinearClassifier.load(classifier_path)
class_names = classifier.classes
print("‚úÖ Loaded classifier with classes:", class_names)



# --- Audio config ---
# Sample rate and window length expected by the embedding model; the window
# falls back to 5 seconds when the model does not declare one.
sample_rate = embedding_model.sample_rate
window_size_s = getattr(embedding_model, 'window_size_s', 5.0)
chunk_samples = int(window_size_s * sample_rate)

# Only predictions whose softmax confidence exceeds this are reported/saved.
confidence_threshold = 0.8


def _stable_softmax(logits):
    """Return softmax probabilities for an array of logits.

    The largest logit is subtracted before exponentiating so that big logits
    cannot overflow to inf, which would turn the probabilities into NaN.
    """
    logits = np.asarray(logits)
    exps = np.exp(logits - np.max(logits))
    return exps / np.sum(exps)


# --- Output CSV header ---
with open(output_csv_path, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['start_time', 'end_time', 'predicted_class', 'confidence'])

# --- Load audio info ---
info = torchaudio.info(wav_path)
total_frames = info.num_frames
# frame_offset / num_frames given to torchaudio.load count frames at the
# file's own sample rate, so chunk boundaries and timestamps must be computed
# from that rate rather than from the model's rate.
file_sample_rate = info.sample_rate
chunk_frames = int(window_size_s * file_sample_rate)
if chunk_frames <= 0:
    raise ValueError(
        f"Invalid chunk length: window_size_s={window_size_s}, "
        f"file sample rate={file_sample_rate}"
    )

# --- Process chunks ---
# The results file is opened once for the whole run and flushed after every
# row, so rows already written survive an interrupted run.
with open(output_csv_path, 'a', newline='') as f:
    writer = csv.writer(f)

    for i in range(0, total_frames, chunk_frames):
        # A trailing partial window (shorter than one chunk) is skipped.
        if i + chunk_frames > total_frames:
            break

        start_time = i / file_sample_rate
        end_time = (i + chunk_frames) / file_sample_rate
        print(f"üîÑ Processing chunk: {start_time:.1f}s to {end_time:.1f}s")

        # Load chunk and mix multi-channel audio down to mono.
        chunk, sr = torchaudio.load(wav_path, frame_offset=i, num_frames=chunk_frames)
        if chunk.shape[0] > 1:
            chunk = chunk.mean(dim=0, keepdim=True)

        # Resample to the rate the embedding model expects, then trim/pad so
        # rounding in the resampler cannot change the window length.
        if sr != sample_rate:
            chunk = torchaudio.functional.resample(chunk, sr, int(sample_rate))
            chunk = chunk[..., :chunk_samples]
            shortfall = chunk_samples - chunk.shape[-1]
            if shortfall > 0:
                chunk = torch.nn.functional.pad(chunk, (0, shortfall))

        chunk_np = chunk.squeeze(0).numpy()

        try:
            # Get embeddings directly
            emb_outputs = embedding_model.embed(chunk_np)

            # presumably shaped (frames, channels, D) - TODO confirm for this model
            emb_vectors = emb_outputs.embeddings
            if emb_vectors is None or len(emb_vectors) == 0:
                raise ValueError("No embeddings returned")

            # Classify each embedding in the chunk
            for emb in emb_vectors:
                for emb_vec in emb:
                    # NOTE(review): softmax treats the classes as mutually
                    # exclusive - confirm this matches how the classifier
                    # was trained.
                    probs = _stable_softmax(classifier(emb_vec))
                    pred_idx = np.argmax(probs)
                    pred_label = class_names[pred_idx]
                    confidence = probs[pred_idx]
                    if confidence <= confidence_threshold:
                        continue

                    print(f"‚úÖ Predicted class: {pred_label} with confidence {confidence:.4f}")
                    print(f"‚úÖ Chunk start: {start_time:.1f}s, end: {end_time:.1f}s")
                    writer.writerow([start_time, end_time, pred_label, f"{confidence:.4f}"])
                    f.flush()

        except Exception as e:
            # Best effort: report the failing chunk and carry on with the next.
            print(f"‚ùå Error processing chunk at {start_time:.1f}s: {e}")

        # Cleanup: drop the per-chunk buffers before loading the next chunk.
        del chunk, chunk_np
        gc.collect()

print("‚úÖ Classification complete. Results saved to:", output_csv_path)
