In [4]:
import os
import json

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
from typing import List, Dict, Any
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
from collections import defaultdict
from tqdm import tqdm
from embedder import Embedder
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from torchvision import transforms

In [5]:
# Converts an opened PIL image into a tensor before any resizing is applied.
to_tensor = transforms.ToTensor()

# Every crop is scaled to a fixed 224x224 with antialiased bilinear interpolation.
resize = transforms.Resize(
    size=(224, 224),
    interpolation=transforms.InterpolationMode.BILINEAR,
    antialias=True,
)

def embed_crops(
    embedder, crops: List[torch.Tensor]
) -> torch.Tensor:
    """
    Embeds a batch of crops by delegating to ``embedder.embed``.

    No resizing or other preprocessing happens here, so the crops must already
    be in whatever form ``embedder.embed`` expects.

    :param embedder: Object exposing an ``embed(crops)`` method (an ``Embedder``
        instance in this notebook).
    :param crops: Crops to embed. Annotated as a list of torch tensors; the
        embedding loop in this notebook passes a single stacked tensor instead.
    :return: The result of ``embedder.embed(crops)``, returned unchanged
        (annotated as a torch tensor of embedded vectors).
    """

    return embedder.embed(crops)

In [6]:
# Use the MPS backend when torch reports it as available; otherwise run on the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [7]:
# Both weight files are expected under ./model, relative to the working directory.
model_root = Path("./model")
trunk_weights_path = model_root / "trunk_weights.pth"
embedder_weights_path = model_root / "embedder_weights.pth"

# Build the embedding model on the device selected earlier in the notebook.
embedder = Embedder(
    trunk_weights=trunk_weights_path,
    embedder_weights=embedder_weights_path,
    device=device,
)

In [10]:
det_path = Path("/Users/iman/345-data/ml-datasets/ccbf/recognition/ccbf-20250324-231335-det renders v2-angles")

images_dir = Path("/Users/iman/345-data/ml-datasets/ccbf/recognition/ccbf-20250324-231335-det renders v2-fixed")
classes_output_path = det_path / "det_classes_angles.json"
batch_size = 1000  # number of images embedded per call to the embedder


def _load_crop(path: Path) -> torch.Tensor:
    """Open one render, force 3-channel RGB, and return the resized tensor on `device`."""
    # The context manager releases the file handle as soon as the pixels are decoded.
    # convert("RGB") keeps the channel count uniform so torch.stack cannot fail on
    # grayscale or CMYK JPEGs.
    with Image.open(path) as img:
        return resize(to_tensor(img.convert("RGB"))).to(device)


# One shared label list is written for all angles, so every angle directory must
# yield the same class names in the same (sorted) order as its vectors.
class_names = None
for angle in range(0, 15):
    vectors_output_path = det_path / f"det_vectors_{angle}"
    renders_dir = det_path / f"det renders {angle}"
    image_paths = sorted(renders_dir.glob("*.jpg"))
    if not image_paths:
        raise FileNotFoundError(f"No .jpg renders found in {renders_dir}")

    # Class name = file stem with its last "_"-separated token removed. It is derived
    # from the very paths that get embedded, so row i of the saved vectors always
    # corresponds to class_names[i]. Checked before the expensive embedding work.
    angle_class_names = ["_".join(path.stem.split("_")[:-1]) for path in image_paths]
    if class_names is None:
        class_names = angle_class_names
    elif angle_class_names != class_names:
        raise ValueError(
            f"Class names for angle {angle} differ from those of angle 0; "
            "a single shared label list would be misaligned with these vectors"
        )

    all_vectors = []
    # Pure inference: make sure no autograd graph is recorded while embedding.
    with torch.no_grad():
        for i in tqdm(range(0, len(image_paths), batch_size)):
            batch_paths = image_paths[i:i + batch_size]
            batch_tensors = torch.stack([_load_crop(path) for path in batch_paths])
            batch_vectors = embed_crops(embedder, batch_tensors)
            all_vectors.append(batch_vectors.cpu())  # Move to CPU immediately
            del batch_tensors, batch_vectors
    # Concatenate all batch results
    vectors = torch.cat(all_vectors, dim=0)
    torch.save(vectors, vectors_output_path.with_suffix(".pt"))

# Save as JSON
with open(classes_output_path, 'w') as f:
    json.dump(class_names, f, indent=2)

print(f"Saved {len(class_names)} class names to {classes_output_path}")

 17%|█▋        | 1/6 [00:08<00:41,  8.39s/it]


KeyboardInterrupt: 

In [None]:
# Collect the saved per-angle embedding tensors (loaded onto the CPU) in angle order.
per_angle_vectors = []
for angle in range(0, 15):
    vectors_file = (det_path / f"det_vectors_{angle}").with_suffix(".pt")
    vectors = torch.load(vectors_file, map_location="cpu")
    per_angle_vectors.append(vectors)
# Stack along a new leading axis so that all_vectors[angle] is one angle's tensor.
all_vectors = torch.stack(per_angle_vectors, dim=0)

# Class labels stored next to the vector files.
classes_file = det_path / "det_classes_angles.json"
with open(classes_file) as f:
    all_classes = json.load(f)


In [None]:
VECTORS_DIR = Path("/Users/iman/345-data/ml-datasets/ccbf/vectors")
TEST_DATASET = "ccbf-20241127-20250326"
DATASET = "ccbf-20250324-231335-det renders v2-angles"


def _load_and_preview(directory, vectors_filename):
    """
    Load a vectors file and the `classes.json` beside it, printing a short preview.

    Prints, in order: the tensor shape, its first five rows, the number of class
    names, and the first five class names.

    :param directory: Directory holding both files (str or Path).
    :param vectors_filename: File name of the saved tensor inside `directory`.
    :return: Tuple `(vectors, class_names)`.
    """
    loaded_vectors = torch.load(
        os.path.join(directory, vectors_filename), map_location="cpu"
    )
    print(loaded_vectors.shape)
    print(loaded_vectors[:5])

    with open(os.path.join(directory, "classes.json")) as classes_file:
        loaded_classes = json.load(classes_file)

    print(len(loaded_classes))
    print(loaded_classes[:5])
    return loaded_vectors, loaded_classes


# Reference ("det") vectors and their class names.
det_svs_path = VECTORS_DIR / DATASET
det_vectors, det_classes = _load_and_preview(det_svs_path, "supervectors_det.pt")

# Test-set vectors and their class names.
test_vectors_dir = os.path.join(VECTORS_DIR, TEST_DATASET, "test set")
test_vectors, test_classes = _load_and_preview(test_vectors_dir, "testvectors.pt")



torch.Size([5984, 256])
tensor([[-0.0024, -0.0078,  0.0130,  ...,  0.0026, -0.0177, -0.0138],
        [-0.0131, -0.0113, -0.0019,  ...,  0.0155, -0.0185, -0.0164],
        [-0.0123, -0.0091, -0.0005,  ...,  0.0170, -0.0158, -0.0155],
        [-0.0084, -0.0185,  0.0038,  ..., -0.0117, -0.0135, -0.0122],
        [-0.0116, -0.0230, -0.0155,  ...,  0.0102,  0.0236,  0.0156]])
5984
['002cd8f0-7908-44ac-a688-d17d626cbb16_00078000003659_back', '002cd8f0-7908-44ac-a688-d17d626cbb16_00078000003659_front', '002cd8f0-7908-44ac-a688-d17d626cbb16_00078000003659_left', '002cd8f0-7908-44ac-a688-d17d626cbb16_00078000003659_right', '00467eba-8f78-4e8a-9ee9-7f8232bf361e_00810014530345_back']
torch.Size([2566, 256])
tensor([[-0.0389,  0.0474, -0.0156,  ...,  0.0650, -0.0370,  0.0159],
        [-0.0479, -0.0581, -0.0751,  ..., -0.0274, -0.0260,  0.0655],
        [-0.0666, -0.0245, -0.0247,  ..., -0.0045,  0.0244,  0.0335],
        [-0.0579, -0.0683, -0.1388,  ..., -0.0103, -0.0062,  0.0094],
        [ 0.0

In [None]:
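# TODO: vectorise somewhere. `mla_vectors` and `mla_gt_classes` are required by the
# next cell but are only defined in the commented-out lines below, so a fresh
# "Restart & Run All" fails with a NameError until they are produced here.
# mla_vectors = [out.vector for out in valid_outputs]
# mla_gt_classes = [out.gt_label for out in valid_outputs]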

In [None]:
# The UPC is the second "_"-separated token of each det class name.
det_upcs = [cls.split("_")[1] for cls in det_classes]
det_upc_set = set(det_upcs)  # constant-time membership tests in the filters below

# Kept for any code that still looks it up. It holds only the LAST vector seen for
# each class, so it must not be used to rebuild per-sample vectors.
cls_to_vec = {
    cls: vec
    for cls, vec in zip(mla_gt_classes, mla_vectors)
    if cls in det_upc_set
}

# Filter labels and vectors together, sample by sample, so that
# mla_filtered_vectors[i] is always the vector belonging to mla_filtered_classes[i].
# Building the vectors through a class-keyed dict would give every sample of a class
# the same vector, and filtering the two lists with different conditions would
# misalign them or raise a KeyError.
mla_filtered_pairs = [
    (cls, vec)
    for cls, vec in zip(mla_gt_classes, mla_vectors)
    if cls in det_upc_set and cls != "missing" and cls != ""
]
if not mla_filtered_pairs:
    raise ValueError("No MLA samples have a ground-truth class present in det_upcs")

mla_filtered_classes = [cls for cls, _ in mla_filtered_pairs]
mla_filtered_vectors = torch.stack([vec for _, vec in mla_filtered_pairs])
print(len(mla_filtered_vectors))


In [None]:
accuracies = []
distances = []

if not mla_filtered_classes:
    raise ValueError("mla_filtered_classes is empty; accuracy cannot be computed")

# The query set is the same for every angle, so convert it to numpy only once.
query_vectors = mla_filtered_vectors.detach().cpu().numpy()

for angle in range(0, 15):
    angle_vectors = all_vectors[angle]

    # Mean Euclidean distance between row i of det_vectors and row i of this angle's
    # vectors. This equals the mean of the diagonal of the full cdist matrix, without
    # building that matrix. Truncating to the shorter length mirrors what taking the
    # diagonal of a rectangular matrix would do.
    n_pairs = min(len(det_vectors), len(angle_vectors))
    mean_dist = torch.linalg.vector_norm(
        det_vectors[:n_pairs] - angle_vectors[:n_pairs], ord=2, dim=1
    ).mean().item()  # plain float: formats and plots without tensor handling

    # 1-nearest-neighbour classifier fitted on this angle's render embeddings.
    knn = KNeighborsClassifier(n_neighbors=1)
    knn.fit(angle_vectors.detach().cpu().numpy(), all_classes)

    predicted_labels = knn.predict(query_vectors)
    # A prediction is correct when the second "_"-separated token of the predicted
    # class name equals the ground-truth label.
    correct_predictions = sum(
        gt == pred.split("_")[1]
        for gt, pred in zip(mla_filtered_classes, predicted_labels)
    )
    accuracy = correct_predictions / len(mla_filtered_classes)
    accuracies.append(accuracy)
    distances.append(mean_dist)

    print(f"Mean distance, image {angle}: {mean_dist:.4f}. Accuracy: {100*accuracy:.2f}%")

# One point per angle, labelled with the angle index (angles run from 0 upward,
# so the list position is the angle).
fig, ax = plt.subplots()
ax.scatter(distances, accuracies, marker='o')
for angle, (dist, acc) in enumerate(zip(distances, accuracies)):
    ax.annotate(f'{angle}', (dist, acc),
                xytext=(5, 5), textcoords='offset points', fontsize=8)
ax.set_xlabel('Mean Distance')
ax.set_ylabel('Accuracy')
# NOTE(review): fixed x-range; points with a mean distance outside it are not shown.
ax.set_xlim(0.815, 0.85);