<a href="https://colab.research.google.com/github/noahdanieldsouza/PAM-classification/blob/main/chart_positives.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install runtime dependencies for this notebook.
# perch-hoplite is installed straight from GitHub, pinned to a specific commit
# hash, so the `agile` / `db` / `zoo` APIs used below stay reproducible.
!pip install torchaudio numpy
!pip install git+https://github.com/google-research/perch-hoplite.git@782acd0e409eb27df51a695de4cb6608dae0db25

In [None]:
import os
import torch
import torchaudio
import numpy as np
import csv
import gc
import tempfile
# NOTE(review): os, torch, torchaudio, csv, gc and tempfile (and colab_utils,
# embed, source_info below) are not referenced in any visible cell of this
# notebook -- candidates for removal once confirmed unused elsewhere.

# Mount Google Drive: the DB and classifier paths used below live under
# /content/drive. force_remount=True re-mounts even if already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# perch-hoplite pieces: `sqlite_usearch_impl` opens the embedding DB,
# `model_configs` rebuilds the embedding model recorded in the DB metadata,
# and `LinearClassifier` loads the saved agile classifier.
from perch_hoplite.agile import colab_utils, embed, source_info
from perch_hoplite.db import sqlite_usearch_impl
from perch_hoplite.zoo import model_configs
from perch_hoplite.agile.classifier import LinearClassifier

# --- Paths ---
db_path = '/content/drive/MyDrive/sept26/DB'
classifier_path = '/content/drive/MyDrive/full_labeled_fish/DB/agile_classifier_v2.pt'


# --- Load DB + model ---
# Open the embedding DB and rebuild the embedding model recorded in its
# metadata. NOTE(review): `embedding_model` is not used by the visible cells
# (they only read stored embeddings); it is kept for interactive use.
db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)
db_model_config = db.get_metadata('model_config')
model_class = model_configs.get_model_class(db_model_config['model_key'])
embedding_model = model_class.from_config(db_model_config['model_config'])
embedding_ids = db.get_embedding_ids()
print(f"✅ loaded {len(embedding_ids)} embeddings")

# len() rather than truthiness: the ids may come back as a NumPy array,
# whose truth value is ambiguous.
if len(embedding_ids) == 0:
    raise ValueError(
        f"No embeddings found in DB at {db_path!r}; check that db_path is correct."
    )

# Sanity check: show one stored vector. Named `sample_id` so the builtin
# `id` is not shadowed for the rest of the notebook session.
sample_id = embedding_ids[0]
vector = db.get_embedding(sample_id)
print(f"✅ vector {vector}")

# --- Load classifier ---
classifier = LinearClassifier.load(classifier_path)
class_names = classifier.classes
print("✅ Loaded classifier with classes:", class_names)

# --- Confidence-threshold sweep ---
# For each threshold 0.00, 0.05, ..., 0.95, count the embeddings whose
# top-class softmax probability is strictly above it. The result is stored in
# `graph` as a list of (threshold, count) pairs for the cells below.

# Predicted labels to leave out of the counts, e.g. {"boat"}; empty = count all.
excluded_labels = set()
thresholds = [i * 0.05 for i in range(20)]


def _top_prediction(logits):
    """Return ``(index, probability)`` of the highest-scoring class.

    The probability is softmax(logits) evaluated at the argmax, computed in
    the max-shifted form ``exp(l_max - m) / sum(exp(logits - m))`` with
    ``m = max(logits)``. Subtracting the max keeps ``np.exp`` from overflowing
    to ``inf``. The naive ``exp / sum(exp)`` would give ``inf / inf = nan``,
    and a ``nan`` confidence never passes any threshold.
    """
    logits = np.asarray(logits, dtype=np.float64)
    top_idx = int(np.argmax(logits))
    shift = np.max(logits)
    # Numerator is exp(logits[top] - max) == exp(0) == 1.
    return top_idx, float(1.0 / np.sum(np.exp(logits - shift)))


# Score every embedding once; every threshold then reuses these confidences
# instead of re-reading the DB and re-running the classifier.
confidences = []
for emb_id in embedding_ids:
    pred_idx, confidence = _top_prediction(classifier(db.get_embedding(emb_id)))
    if class_names[pred_idx] in excluded_labels:
        continue
    confidences.append(confidence)
confidences = np.asarray(confidences, dtype=np.float64)
print(f"✅ scored {len(embedding_ids)} embeddings, {len(confidences)} counted")

# int() keeps the counts as plain Python ints in `graph`.
graph = [
    (threshold, int(np.count_nonzero(confidences > threshold)))
    for threshold in thresholds
]




In [None]:
# Show the (threshold, count) pairs one per row. Thresholds are formatted to
# two decimals so accumulated float error (e.g. 3 * 0.05 ==
# 0.15000000000000002) does not clutter the output.
print("threshold  count")
for threshold, count in graph:
    print(f"{threshold:9.2f}  {count}")

In [None]:
import matplotlib.pyplot as plt

# Plot how many predictions clear each confidence threshold: split the
# (threshold, count) pairs into parallel x / y sequences, then draw them on
# an explicitly created Axes.
threshold_axis, count_axis = zip(*graph)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(threshold_axis, count_axis, marker='o', linestyle='-', color='b')
ax.set_title("Number of Predictions Above Confidence Threshold")
ax.set_xlabel("Confidence Threshold")
ax.set_ylabel("Count of Predictions")
ax.grid(True)
plt.show()
