In [2]:
import os
import numpy as np
import torch
from model_init import load_model, preprocess_image, get_normalized_features

# from scipy.spatial import distance


def mahalanobis_distance(x, y, inv_cov):
    diff = x - y
    return np.sqrt(diff.T @ inv_cov @ diff)

# Paths
test_folder = "../60_images_of_6_cows/test-images"
features_file = "reference_features.npy"
filenames_file = "reference_filenames.npy"

# Load reference features and filenames
reference_features = np.load(features_file)  # (N, 1536)
reference_filenames = np.load(filenames_file)

# Device
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
model = load_model("./models/best_model-3.pth", device)

# Precompute covariance matrix and its inverse for Mahalanobis distance
cov_matrix = np.cov(reference_features, rowvar=False)
inv_cov_matrix = np.linalg.inv(cov_matrix + 1e-6 * np.eye(cov_matrix.shape[0]))  # Add small value for stability

# Matching and evaluation
correct = 0
total = 0

for test_filename in os.listdir(test_folder):
    if test_filename.lower().endswith((".jpg", ".png", ".jpeg")):
        test_path = os.path.join(test_folder, test_filename)
        image_tensor = preprocess_image(test_path)
        test_feature = get_normalized_features(model, image_tensor, device)  # (1, 1536)

        test_feature = test_feature.squeeze()

        # Mahalanobis distance with all reference features
        distances = np.array([
            mahalanobis_distance(test_feature, ref_feat, inv_cov_matrix)
            for ref_feat in reference_features
            ])
        top_idx = np.argmin(distances)
        matched_filename = reference_filenames[top_idx]

        test_id = test_filename.split("_")[0]
        matched_id = matched_filename.split("_")[0]

        is_correct = test_id == matched_id
        total += 1
        correct += int(is_correct)

        print(f"Test Image: {test_filename} -> Matched: {matched_filename} | Correct: {is_correct}")

accuracy = correct / total * 100
print(f"\nFinal Accuracy: {accuracy:.2f}% ({correct}/{total})")


Using device: mps
Test Image: 221_05.jpg -> Matched: 221_01.jpg | Correct: True
Test Image: 207_01.jpg -> Matched: 207_01.jpg | Correct: True
Test Image: 209_04.jpg -> Matched: 209_01.jpg | Correct: True
Test Image: 209_10.jpg -> Matched: 209_01.jpg | Correct: True
Test Image: 217_04.jpg -> Matched: 217_01.jpg | Correct: True
Test Image: 217_10.jpg -> Matched: 217_01.jpg | Correct: True
Test Image: 217_05.jpg -> Matched: 217_01.jpg | Correct: True
Test Image: 209_05.jpg -> Matched: 209_01.jpg | Correct: True
Test Image: 221_04.jpg -> Matched: 221_01.jpg | Correct: True
Test Image: 221_10.jpg -> Matched: 221_01.jpg | Correct: True
Test Image: 221_06.jpg -> Matched: 221_01.jpg | Correct: True
Test Image: 207_02.jpg -> Matched: 207_01.jpg | Correct: True
Test Image: 209_07.jpg -> Matched: 209_01.jpg | Correct: True
Test Image: 217_07.jpg -> Matched: 217_01.jpg | Correct: True
Test Image: 217_06.jpg -> Matched: 217_01.jpg | Correct: True
Test Image: 209_06.jpg -> Matched: 209_01.jpg | Corr