# 🪄 yolo/prediction

In [1]:
import torch
import time
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from bluer_options import string
from bluer_objects import objects, file
from bluer_objects.metadata import get_from_object
from bluer_objects import storage
from bluer_sandbox import notebooks

from bluer_algo.host import signature
from bluer_algo.yolo.dataset.classes import YoloDataset
from bluer_algo import env
from bluer_algo.logger import logger

# Log a dot-joined signature of the packages/runtime in use (see the output of
# this cell), followed by the date and time the notebook was run.
logger.info(f"{'.'.join(signature())},\nbuilt on {string.pretty_date()}")

  from .autonotebook import tqdm as notebook_tqdm
🪄  bluer_algo-4.386.1.bluer_ai-12.259.1.bluer_objects-6.258.1.bluer_options-5.164.1.torch-2.2.2.Python 3.12.9.Darwin 23.6.0..Jupyter-Notebook,
built on 15 September 2025, 12:00:13


In [2]:
# Name of a fresh object for this run; per the cell output it is the
# "yolo-prediction" prefix plus a timestamp and a short suffix. The prediction
# figure is saved into it and it is uploaded at the end of the notebook.
prediction_object_name = objects.unique_object("yolo-prediction")

🌀  📂 yolo-prediction-2025-09-15-12-00-13-3kvc3l


---

In [3]:
# Dataset object to run the prediction on, as configured in bluer_algo.env.
# NOTE(review): the download below is disabled, so the object is presumably
# expected to be available locally already — confirm, or re-enable the line.
dataset_object_name = env.BLUER_ALGO_COCO128_TEST_DATASET
# assert storage.download(dataset_object_name)

In [4]:
# Load the dataset from its object and halt unless it reports itself as valid.
dataset = YoloDataset(object_name=dataset_object_name)
assert dataset.valid

🪄  found 128 image(s).
🪄  found 128 label(s).
🪄  missing 2 image(s): 000000000508, 000000000250
🪄  missing 2 label(s): 000000000659, 000000000656
🪄  YoloDataset, 126 record(s),  80 class(es): person, bicycle, car, motorcycle, airplane, ...


---

In [5]:
# Model object to predict with, as configured in bluer_algo.env.
# NOTE(review): the download below is disabled, so the object is presumably
# expected to be available locally already — confirm, or re-enable the line.
model_object_name = env.BLUER_ALGO_COCO128_TEST_MODEL
# assert storage.download(model_object_name)

In [6]:
# Path of the "model.pth" file inside the model object; it is handed to
# torch.load() further down.
model_filename = objects.path_of(object_name=model_object_name, filename="model.pth")

In [7]:
# NOTE(review): unconditional halt — "Run All" stops here, which is why every
# cell below has no execution count. The following cells reference TinyCNN,
# which this notebook's import cell does not import; presumably this guard
# stays until prediction is wired up for YOLO — confirm before removing it.
assert False

AssertionError: 

In [None]:
# Read the "model" entry of the model object's metadata.
model_metadata = get_from_object(object_name=model_object_name, key="model")

# The class count sits under "dataset" -> "class_count" and is what the network
# constructor is called with in the next cell.
class_count = model_metadata["dataset"]["class_count"]

logger.info(f"class_count: {class_count}")

In [None]:
# NOTE(review): TinyCNN is not imported by this notebook's import cell, so on a
# fresh kernel this cell raises NameError — import it before running.
model = TinyCNN(class_count)

# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and keeps a tampered model file from executing
# arbitrary code on load. map_location="cpu" places the loaded tensors on the CPU.
model.load_state_dict(
    torch.load(model_filename, map_location="cpu", weights_only=True)
)

# Switch to inference mode; kept as the last expression so the cell still
# displays the model.
model.eval()

---

In [None]:
# Draw one random row from the rows whose "subset" column equals "test".
# NOTE(review): assumes the dataset exposes a `df` with "subset", "filename"
# and "class_index" columns — this cell has not been executed; confirm.
is_test_row = dataset.df["subset"] == "test"
test_row = dataset.df[is_test_row].sample(n=1)
logger.info(f"test_row: {test_row}")

# Load the image that the sampled row points at, from the dataset object.
image_path = objects.path_of(
    object_name=dataset_object_name,
    filename=test_row["filename"].values[0],
)
success, image = file.load_image(image_path)
assert success

# Class index recorded for the sampled row; compared with the prediction later.
class_index = test_row["class_index"].values[0]

In [None]:
# `image` must be an (H, W, 3) numpy array (presumably RGB — TODO confirm
# against file.load_image).
assert isinstance(image, np.ndarray)
assert image.ndim == 3 and image.shape[2] == 3

# Time preprocessing + inference. perf_counter() is monotonic, so the measured
# duration cannot be skewed by system clock adjustments, and the start value is
# kept in its own name instead of temporarily living in `elapsed_time`.
start_time = time.perf_counter()

# Convert to PIL, which the torchvision transforms below operate on.
image_ = Image.fromarray(image.astype("uint8"))

# Resize to 100x100 and convert to a tensor.
# NOTE(review): this is meant to mirror the training-time transform — confirm
# against the training code.
transform = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.ToTensor(),
])
input_tensor = transform(image_).unsqueeze(0)  # Shape: [1, 3, 100, 100]

# No gradients are needed for inference; take the highest-scoring class.
with torch.no_grad():
    output = model(input_tensor)
    predicted_class = torch.argmax(output, dim=1).item()

# Duration in seconds, consumed by the reporting cell that follows.
elapsed_time = time.perf_counter() - start_time

In [None]:
# Name of the predicted class.
predicted_label = dataset.dict_of_classes[predicted_class]

# Either a "correct" marker or the expected class, depending on whether the
# prediction matches the class index of the sampled row.
if class_index == predicted_class:
    verdict = " (correct) "
else:
    verdict = "<> {} [#{}] ".format(
        dataset.dict_of_classes[class_index],
        class_index,
    )

# Human-readable form of the measured duration.
duration = string.pretty_duration(
    elapsed_time,
    include_ms=True,
    short=True,
)

message = "prediction: {} [#{}] {}- took {}".format(
    predicted_label,
    predicted_class,
    verdict,
    duration,
)
logger.info(message)

# Show the image with the prediction summary as its title.
plt.imshow(image)
plt.title(message)
plt.axis("off")

# Save the figure into the prediction object; halt if saving fails.
prediction_figure_path = objects.path_of(
    object_name=prediction_object_name,
    filename="prediction.png",
)
assert file.save_fig(prediction_figure_path)

---

In [None]:
# Upload the prediction object and halt if the call reports failure.
# NOTE(review): presumably this pushes the object (including the saved
# prediction.png) to storage — confirm against bluer_sandbox.notebooks.
assert notebooks.upload(prediction_object_name)

In [None]:
# END