In [2]:
import cv2
import torchvision  # scope for faster 'import ultralytics'
from pathlib import Path
from ultralytics.data.dataset import classify_transforms
from ultralytics.nn.tasks import torch_safe_load, ClassificationModel
import torch
from safetensors.torch import load_file
from tqdm import tqdm
from PIL import Image

In [3]:
# Index the on-disk test directory with torchvision's ImageFolder dataset.
base = torchvision.datasets.ImageFolder(root="/home/yuzhong/data1/image_classifier_data/test")

In [6]:
# Map each integer class id to its class name.
#
# ImageFolder already exposes the name -> id mapping as `class_to_idx`, so it is
# inverted here rather than re-derived from sample paths: taking the parent
# directory of each file yields the wrong name when a class folder contains
# nested sub-directories (ImageFolder walks them recursively) and hard-codes
# '/' as the path separator.
id_to_class_names = {
    class_id: class_name for class_name, class_id in base.class_to_idx.items()
}

In [7]:
# Display the id -> class-name mapping built above.
id_to_class_names

{0: 'animal', 1: 'human', 2: 'no_target'}

In [2]:
def get_predict_image(root, imgsz):
    """Lazily yield one preprocessed sample per image found under ``root``.

    Args:
        root: Dataset directory handed to ``torchvision.datasets.ImageFolder``.
        imgsz: Size forwarded to ``classify_transforms(size=imgsz)``.

    Yields:
        dict: ``{"img": <output of the transform>, "cls": <class id>}`` for every
        entry of ``ImageFolder.samples``, in that order.

    Raises:
        OSError: If OpenCV cannot read or decode one of the image files.
    """
    dataset = torchvision.datasets.ImageFolder(root=root)
    torch_transforms = classify_transforms(size=imgsz)
    for file_name, class_id in dataset.samples:
        bgr = cv2.imread(file_name)
        if bgr is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than letting cvtColor
            # die with an opaque assertion in the middle of a long run.
            raise OSError(f"cv2.imread could not read image: {file_name}")
        # OpenCV decodes to BGR; convert to RGB before wrapping in a PIL image.
        rgb = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        yield {"img": torch_transforms(rgb), "cls": class_id}

In [3]:
def load_yolo_model(path, nc):
    """Load a YOLO classification checkpoint as a frozen model.

    Args:
        path: Checkpoint file passed to ``torch_safe_load``.
        nc: Number of output classes passed to
            ``ClassificationModel.reshape_outputs``.

    Returns:
        The checkpoint's ``"model"`` entry after its outputs were reshaped for
        ``nc`` classes, every parameter had ``requires_grad`` disabled, and
        ``.float()`` was applied.
    """
    # torch_safe_load returns a pair; only the checkpoint dict is needed here.
    checkpoint, _ = torch_safe_load(path)
    net = checkpoint["model"]
    ClassificationModel.reshape_outputs(net, nc)
    # Freeze all parameters: no gradients are tracked on the returned model.
    for param in net.parameters():
        param.requires_grad = False
    return net.float()

# Build the model from the yolov8m-cls checkpoint with its outputs reshaped to 3 classes.
model = load_yolo_model("/home/yuzhong/data1/models/yolo/yolov8m-cls.pt", 3)
# Read the weights stored as safetensors (presumably the epoch-31 training
# checkpoint, judging by the path - confirm).
weights = load_file('/home/yuzhong/data1/code/object_detection/image_classifier/epoch_31/model.safetensors')
# Last expression of the cell, so the key-matching report is displayed as output.
model.load_state_dict(weights)

<All keys matched successfully>

In [4]:
def predict_images(model, image_paths, imgsz, device_id=0):
    """Classify the images under ``image_paths`` one sample at a time.

    Args:
        model: Model to evaluate; it is switched to eval mode and moved to
            ``device_id``.
        image_paths: Dataset root forwarded to ``get_predict_image``.
        imgsz: Image size forwarded to ``get_predict_image``.
        device_id: Device the model and each input are moved to.

    Returns:
        tuple: ``(predicted class ids, ground-truth class ids)``, two lists in
        the order the samples were produced.
    """
    model.eval()
    model.to(device_id)
    predicted = []
    targets = []
    samples = get_predict_image(image_paths, imgsz)
    with torch.no_grad():
        for sample in tqdm(samples, desc="predict image"):
            # Add a leading batch dimension of 1 before running the model.
            single = sample['img'].unsqueeze(0).to(device_id)
            targets.append(sample['cls'])
            scores = model(single)
            predicted.append(scores.argmax(dim=-1).item())
    return predicted, targets

def _classify_batch(model, images, device_id):
    """Stack ``images`` into one batch, run ``model`` and return argmax class ids as a list."""
    batch_tensor = torch.stack(images).to(device_id)
    # Tensor.tolist() copies to host memory itself, so no .cpu().numpy() hop is needed.
    return model(batch_tensor).argmax(dim=-1).tolist()


def predict_batch_images(model, image_paths, imgsz, device_id=0, batch_size=32):
    """Classify the images under ``image_paths`` in batches.

    Args:
        model: Model to evaluate; it is switched to eval mode and moved to
            ``device_id``.
        image_paths: Dataset root forwarded to ``get_predict_image``.
        imgsz: Image size forwarded to ``get_predict_image``.
        device_id: Device the model and each batch are moved to.
        batch_size: Number of images per forward pass; must be at least 1.

    Returns:
        tuple: ``(predicted class ids, ground-truth class ids)``, two lists in
        the order the samples were produced.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Without this check a non-positive batch size never triggers a flush, so
    # the whole dataset would silently pile up in memory as one giant batch.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    model.eval()
    model.to(device_id)
    predict_class = []
    ground_truth_class = []

    batch_images = []
    batch_classes = []

    with torch.no_grad():
        for im_info in tqdm(get_predict_image(image_paths, imgsz), desc="predict image"):
            batch_images.append(im_info['img'])
            batch_classes.append(im_info['cls'])

            if len(batch_images) >= batch_size:
                predict_class.extend(_classify_batch(model, batch_images, device_id))
                ground_truth_class.extend(batch_classes)

                # Start a fresh batch.
                batch_images = []
                batch_classes = []

        # Process the remaining images when the total is not a multiple of batch_size.
        if batch_images:
            predict_class.extend(_classify_batch(model, batch_images, device_id))
            ground_truth_class.extend(batch_classes)

    return predict_class, ground_truth_class

In [5]:
# Run batched inference over the test directory with imgsz=224 and 1280 images per batch.
# NOTE: the recorded run below took about 8 hours for 251,286 images.
predict_class, ground_truth_class = predict_batch_images(model, "/home/yuzhong/data1/image_classifier_data/test", 224, batch_size=1280)

predict image: 251286it [8:03:57,  8.65it/s]


In [6]:
# Sanity check: number of predictions produced.
len(predict_class)

251286

In [7]:
# Sanity check: number of ground-truth labels collected (should equal the number of predictions).
len(ground_truth_class)

251286

In [8]:
import evaluate

In [9]:
# Load the "accuracy" metric from the `evaluate` library.
metric = evaluate.load("accuracy")

In [10]:
# Register all predictions and their ground-truth labels with the metric in one call.
metric.add_batch(predictions=predict_class, references=ground_truth_class)

In [11]:
# Compute the metric over everything added so far.
eval_metric = metric.compute()

In [12]:
# Print the computed metric dictionary.
print(eval_metric)

{'accuracy': 0.944843723884339}
