In [None]:
# This file is subject to the terms and conditions defined in file
# `COPYING.md`, which is part of this source code package.

import os
import time
import numpy as np
import cv2
import subprocess
from matplotlib import pyplot as plt
from typing import List, Tuple
import numpy.typing as npt
from ultralytics.utils.plotting import Colors
from IPython.display import clear_output

from classifiers.fastsam.model import FastSAM
from classifiers.helpers import draw_detections, get_spectra_from_mask
from ultralytics import YOLO

from lo_dataset_reader import DatasetReader, rle_to_mask
from lo.sdk.analysis.ml.models.spectral_classifier import CPURFClassifier as classifier
from lo.sdk.api.acquisition.data.formats import LORAWtoLOGRAY12, _debayer
from lo.sdk.helpers.import_numpy_or_cupy import xp


In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
def _download_if_missing(url, filename, min_bytes=1):
    """
    Download `url` to `filename` with wget unless a plausible copy is already present.

    Args:
        url (str): location of the file to fetch.
        filename (str): local path the file is stored under.
        min_bytes (int): smallest size, in bytes, for an existing file to count as complete.
            The default of 1 only rejects empty files.

    Raises:
        subprocess.CalledProcessError: wget exited with a non-zero status.
        OSError: wget could not be started (e.g. it is not installed).
    """
    if os.path.exists(filename) and os.path.getsize(filename) >= min_bytes:
        print(f"{filename} already exists. Skipping download.")
        return

    print(f"Downloading {filename}...")
    try:
        # check=True turns a failed download into an error here, rather than a
        # confusing failure later when the weights are loaded.
        subprocess.run(["wget", url, "-O", filename], check=True)
    except (subprocess.CalledProcessError, OSError):
        # `wget -O` creates the output file even when the transfer fails; remove
        # the leftover so the next run does not mistake it for a finished download.
        if os.path.exists(filename):
            os.remove(filename)
        raise


fastsam_url = "https://github.com/ultralytics/assets/releases/download/v8.3.0/FastSAM-x.pt"
filename = "FastSAM-x.pt"
_download_if_missing(fastsam_url, filename)

yolo_url = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8l-seg.pt"
filename = "yolov8l-seg.pt"
# Anything smaller than 85 MiB is treated as an incomplete download and fetched again.
_download_if_missing(yolo_url, filename, min_bytes=85 * 1024 * 1024)

In [None]:
# Methods

def label_fn(ann):
    """
    Look up the class label stored in an annotation's metadata.

    Args:
        ann (lo.data.tools.tools.Annotation): annotation to read; it is indexed
            with the "category_name" key.

    Returns:
        label (str): the value stored under "category_name"
    """
    category_name = ann["category_name"]
    return category_name

def percentile_norm(im: np.ndarray, low: int = 1, high: int = 99) -> np.ndarray:
    """
    Normalise the image based on percentile values.

    The `low` and `high` percentiles are mapped to 0 and 255 respectively and
    everything outside that range is clipped. Percentiles are taken over the
    first two axes, so any trailing axis (e.g. colour channels) is normalised
    independently.

    Args:
        im (np.ndarray): The input image.
        low (int): The lower percentile for normalization.
        high (int): The higher percentile for normalization.

    Returns:
        np.ndarray: The normalised image, with values in [0, 255].
    """
    # Estimate the percentiles from every 40th row and column only.
    low_value, high_value = np.percentile(im[::40, ::40], (low, high), axis=(0, 1))

    # When both percentiles coincide (e.g. a flat image or channel) there is no
    # range to stretch; dividing by 1 maps that data to 0 instead of NaN.
    value_range = high_value - low_value
    value_range = np.where(value_range == 0, 1, value_range)

    im = (im - low_value) / value_range
    return np.clip(im, 0, 1) * 255

def get_true_values_from_mask(values: np.ndarray, sampling_coordinates: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Keep the rows of `values` whose sampling coordinate lands on a truthy mask pixel.

    Args:
        values (np.ndarray): An array of values with shape (n_spectra, ...)
        sampling_coordinates (np.ndarray): An array of sampling coordinates with shape (n_spectra, 2);
            column 0 indexes the first axis of `mask`, column 1 the second
        mask (np.ndarray): a binary mask with shape (h, w)

    Returns:
        (np.ndarray): an array of selected values with shape (n_selected, ...)
    """
    # Coordinates are cast to int32 so they can be used as array indices.
    pixel_coords = np.int32(sampling_coordinates)
    first_axis, second_axis = pixel_coords[:, 0], pixel_coords[:, 1]
    keep = mask[first_axis, second_axis].astype(bool)
    return values[keep, :]

def visualise(scene, pred):
    """
    Overlay per-class sample points on a scene image and append a colour legend.

    Args:
        scene: Original image (H, W) or (H, W, 3). The input is not modified; a
            2-D scene is replicated to three channels so it can be stacked with
            the colour legend.
        pred: Dictionary mapping class IDs to coordinate arrays whose rows are
            (y, x) positions. Class IDs without a colour in the colour map are
            skipped.
    Returns:
        vis_img: 3-channel image of shape (H + 60, W, 3) with the predictions
            drawn as filled circles and the legend strip stacked underneath.
            Colours are given in OpenCV's BGR channel order.
    """
    h, w = scene.shape[:2]

    # Colour per class id, in BGR order (the order cv2.imwrite expects).
    # Adjust according to your actual class labels.
    color_map = {
        0: (255, 255, 0),   # background (cyan)
        1: (0, 255, 0),     # class 1 (green) - grapes
        2: (255, 0, 0),     # class 2 (blue) - tray
        3: (0, 0, 255),     # class 3 (red) - tyvec
    }

    class_names = {
        0: "Background",
        1: "Grapes",
        2: "Tray",
        3: "Tyvec"
    }

    # Work on a copy so the caller's image is left untouched.
    vis_img = scene.copy()
    if vis_img.ndim == 2:
        # Promote greyscale input to three channels; otherwise the vstack with
        # the (60, W, 3) legend below would fail on mismatched dimensions.
        vis_img = np.dstack([vis_img, vis_img, vis_img])

    # Draw each class with its corresponding color
    for class_id, coords in pred.items():
        if class_id not in color_map:
            # No colour defined for this class: nothing sensible to draw.
            continue

        # Ensure coordinates are integers and within bounds
        coords = np.asarray(coords).astype(int)
        if coords.size == 0:
            continue
        in_bounds = (coords[:, 0] < h) & (coords[:, 1] < w) & (coords[:, 0] >= 0) & (coords[:, 1] >= 0)

        # Draw each point as a small filled circle; OpenCV takes (x, y) centres.
        for y, x in coords[in_bounds]:
            cv2.circle(vis_img, (int(x), int(y)), radius=5, color=color_map[class_id], thickness=-1)

    # Build the legend strip: one coloured box plus class name per entry.
    legend_height = 60
    legend = np.zeros((legend_height, scene.shape[1], 3), dtype=np.uint8)

    font_scale = 0.9
    font_thickness = 2
    box_width = 25
    box_height = 25
    y_box_top = 17
    y_text = y_box_top + box_height - 5

    x_pos = 10
    spacing = 16
    for class_id, color in color_map.items():
        x_box = x_pos
        x_text = x_box + box_width + spacing

        # Draw rectangle
        cv2.rectangle(legend, (x_box, y_box_top), (x_box + box_width, y_box_top + box_height), color, -1)

        # Draw outlined text: a thicker black pass underneath a white pass.
        text = class_names[class_id]
        cv2.putText(legend, text, (x_text, y_text), cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, (0, 0, 0), font_thickness + 1, lineType=cv2.LINE_AA)
        cv2.putText(legend, text, (x_text, y_text), cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, (255, 255, 255), font_thickness, lineType=cv2.LINE_AA)

        # Advance past the box, the rendered text width and a fixed 40 px gap.
        x_pos += box_width + spacing + cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)[0][0] + 40

    # Combine image and legend
    vis_img = np.vstack([vis_img, legend])

    return vis_img

### Download the training dataset

In [None]:
# Load dataset using the DatasetReader class.
# Set the GRAPES_DATASET_PATH environment variable to point at your own copy of
# the dataset; the location below is only used as a fallback when it is not set.
path = os.environ.get("GRAPES_DATASET_PATH", "/home/seonghyun/hugging_face/Grapes-Dataset.zip")
reader = DatasetReader(dataset_path=path)

In [None]:
# Directory the trained spectral classifier is saved to; it is passed to the
# classifier as `classifier_path` and created by the training cell if missing.
model_path = './fruit_classifier'

### Train the spectral classifier

In [None]:
# Training set accumulators, filled frame by frame.
all_spectra = []
all_labels = []

# Integer class ids and the annotation category names they stand for.
class_number_to_label = dict(enumerate(("background", "grapes", "tray", "tyvec")))
label_to_class_number = {name: number for number, name in class_number_to_label.items()}

for (info, scene_frame, spectra, *_), _, annotations, *_ in reader:
    print("scene_frame shape:", scene_frame.shape)

    h, w = scene_frame.shape[:2]
    coordinates = info.sampling_coordinates.astype(int)  # one (y, x) row per spectrum

    # Per-pixel class id for this frame; 0 (background) wherever nothing is annotated.
    label_mask = np.zeros((h, w), dtype=np.uint8)

    for ann in annotations:
        class_name = label_fn(ann)
        print(class_name)

        # Only annotations of a known class that carry a segmentation are rasterised.
        if class_name in label_to_class_number and ann.get("segmentation"):
            label_index = label_to_class_number[class_name]
            mask = rle_to_mask(ann["segmentation"], (h, w))
            # Pixels already claimed by an earlier annotation keep their label.
            unclaimed = label_mask == 0
            label_mask[(mask > 0) & unclaimed] = label_index

    # Pair every spectrum whose sampling point lies inside the frame with the
    # label found at that pixel.
    for i, (y, x) in enumerate(coordinates):
        if not (0 <= y < h and 0 <= x < w):
            continue
        label = label_mask[y, x]
        if label not in class_number_to_label:
            continue
        all_spectra.append(spectra[i])
        all_labels.append(label)

# Convert to arrays: one row per collected spectrum, one label per row.
all_spectra = np.stack(all_spectra)
all_labels = np.array(all_labels)

In [None]:
# The notebook's import cell binds the classifier class to the name `classifier`,
# and this cell rebinds that same name to the trained instance. Importing the
# class under its own name here keeps the cell re-runnable: on a second run
# `classifier` is already an instance, not the class.
from lo.sdk.analysis.ml.models.spectral_classifier import CPURFClassifier

# Make sure the output directory exists (no error if it is already there).
os.makedirs(model_path, exist_ok=True)

# Instantiate Classifier
classifier = CPURFClassifier(
    classifier_path=model_path,
    plot_spectra=False,
    do_reflectance=False,
    class_number_to_label=class_number_to_label,
)

print("Starting training")
classifier.train(
    all_spectra=all_spectra,
    all_labels=all_labels,
    n_estimators=70,
    warm_start=False,
)

# Print class labels
for k, v in classifier.metadata.items():
    print(k, v[0])

## Run segmentation enhanced with spectral classification

This inference example demonstrates a simple integration between the YOLO-based FastSAM segmentation model and a spectral classifier to achieve subclass classification and improved recognition while preserving the semantic understanding of the model.


In [None]:
# Minimum mean class probability a segment needs before it is drawn. The same
# value is passed to the spectral classifier as its confidence setting.
prob_thresh = 0.7

# Overlays of the raw spectral classification are written here. The directory
# is created up front because cv2.imwrite reports a missing directory only
# through its return value, not by raising.
save_dir = "visualisations"
os.makedirs(save_dir, exist_ok=True)

# Re-open the dataset so this loop starts from the first frame no matter how
# far `reader` was consumed by the training cell (the YOLO comparison cell
# below does the same).
reader = DatasetReader(dataset_path=path)

# Load the FastSAM segmentation model from the weights downloaded earlier.
model = FastSAM('FastSAM-x.pt')
plt.figure()

count = 0
for (info, scene_frame, spectra, *_), *_ in reader:

    if scene_frame is None:
        break

    clear_output()

    scene_frame = np.ascontiguousarray(scene_frame)

    # Prepare the scene image: drop singleton axes, convert frames whose peak
    # value exceeds 1000 with LORAWtoLOGRAY12, then bring the frame to three
    # channels (plain replication when a dimension is odd, debayering otherwise).
    if len(scene_frame.shape) == 3:
        scene_frame = scene_frame.squeeze()
    if np.amax(scene_frame) > 1000:
        scene_frame = LORAWtoLOGRAY12(scene_frame)

    if scene_frame.shape[0] % 2 == 1 or scene_frame.shape[1] % 2 == 1:
        scene_frame = np.dstack([scene_frame, scene_frame, scene_frame])
    else:
        scene_frame = _debayer(scene_frame)

    results = model(scene_frame, device='cpu', retina_masks=True, imgsz=480, conf=0.6, iou=0.9)

    # Per-spectrum classification for this frame.
    frame = (info, scene_frame, spectra)
    metadata_out, classes, probs = classifier(
        frame,
        confidence=prob_thresh,
        sa_factor=4,
        similarity_filter=False,
    )

    # Save the point-wise spectral classification before detections are drawn.
    vis = visualise(scene_frame, metadata_out)
    save_path = os.path.join(save_dir, f"output_{count:03d}.jpg")
    if not cv2.imwrite(save_path, vis):
        print(f"Warning: could not write {save_path}")

    # Assign classes to the objects identified by the segmentation model.
    # `masks` is None when the model found nothing in this frame.
    detections = results[0]
    if detections.masks is not None:
        masks = detections.masks.data.detach().cpu().numpy()
        segs = detections.masks.xy
        boxes = detections.boxes.data.detach().cpu().numpy()

        for bbox, mask, seg in zip(boxes, masks, segs):
            bb_probs = get_spectra_from_mask(probs, info.sampling_coordinates, mask)
            if bb_probs.shape[0] == 0:
                # No spectra were sampled inside this segment, so there is nothing to vote with.
                continue

            # The segment's class is the one with the highest median probability;
            # class 0 (background) is never drawn.
            bb_class = np.argmax(np.median(bb_probs, 0))
            if bb_class == 0:
                continue
            bb_prob = bb_probs[:, bb_class].mean(0)
            if bb_prob < prob_thresh:
                continue
            scene_frame = draw_detections(classifier.classes, scene_frame, bbox, bb_prob, bb_class, seg)

    count += 1
    plt.imshow(scene_frame)
    plt.show()
    time.sleep(0.2)
        

## Run vanilla YOLO segmentation model:

This inference example shows a visualisation of the performance of a pretrained YOLOv8, for performance comparison purposes.


In [None]:
# Run the pretrained YOLO segmentation model on its own, for comparison.
# The dataset is re-opened so the loop starts from the first frame.
reader = DatasetReader(dataset_path=path)

model = YOLO("yolov8l-seg.pt")

plt.figure()
count = 0
for (info, scene_frame, spectra, *_), *_ in reader:
    if scene_frame is None:
        break

    clear_output()
    scene_frame = np.ascontiguousarray(scene_frame)

    # Prepare the scene image: drop singleton axes, convert frames whose peak
    # value exceeds 1000 with LORAWtoLOGRAY12, then bring the frame to three
    # channels (plain replication when a dimension is odd, debayering otherwise).
    if len(scene_frame.shape) == 3:
        scene_frame = scene_frame.squeeze()
    if np.amax(scene_frame) > 1000:
        scene_frame = LORAWtoLOGRAY12(scene_frame)

    if scene_frame.shape[0] % 2 == 1 or scene_frame.shape[1] % 2 == 1:
        scene_frame = np.dstack([scene_frame, scene_frame, scene_frame])
    else:
        scene_frame = _debayer(scene_frame)

    results = model(scene_frame, device='cpu', retina_masks=True, imgsz=480, conf=0.05, iou=0.9)

    # `masks` is None when the model found nothing in this frame.
    detections = results[0]
    if detections.masks is not None:
        segments = detections.masks.xy
        boxes = detections.boxes.data.detach().cpu().numpy()

        for detection, segment in zip(boxes, segments):
            # The last two entries of each box row are the confidence and class id.
            conf = float(detection[-2])
            cls_ = int(detection[-1])
            # Draw this loop's own detection, labelled with the YOLO model's class
            # names: YOLO class ids do not index the spectral classifier's classes.
            # The full box row is passed as in the FastSAM cell.
            # NOTE(review): assumes draw_detections looks names up as
            # classes[class_id], which works for the model.names dict — confirm.
            scene_frame = draw_detections(model.names, scene_frame, detection, conf, cls_, segment)

    count += 1
    plt.imshow(scene_frame)
    plt.show()