In [1]:
from ultralytics.nn.tasks import torch_safe_load, ClassificationModel
import torch
from safetensors.torch import load_file
from ultralytics.data.dataset import classify_transforms
from PIL import Image
import numpy as np
import cv2
import os
from tqdm import tqdm
import random

In [2]:
def load_yolo_model(path, nc):
    ckpt, w = torch_safe_load(path)
    model = ckpt["model"]
    ClassificationModel.reshape_outputs(model, nc)
    for p in model.parameters():
        p.requires_grad = False  # for training
    model = model.float()
    return model

In [3]:
device = 0
model_cls_ori_path = "/mnt/data1/model_1122/image_classifier/yolov8m-cls.pt"
checkpoint_path = "/mnt/data1/model_1122/image_classifier/epoch_30/model.safetensors"
model = load_yolo_model(model_cls_ori_path, 3)
weights = load_file(checkpoint_path)
model.load_state_dict(weights)
model.to(device)

ClassificationModel(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (1): Conv(
      (conv): Conv2d(48, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (2): C2f(
      (cv1): Conv(
        (conv): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (cv2): Conv(
        (conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (m): ModuleList(
        (0-1): 2 x Bottlene

In [4]:
torch_transforms = classify_transforms(size=224) 
def process_image_classifier_image(image_binary, torch_transforms, device = 0):
    try:
        nparr = np.frombuffer(image_binary, np.uint8)
        im = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        # 转换为PIL图像 (RGB格式)
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        # 应用转换
        sample = torch_transforms(im).to(device)
    except Exception as e:
        print(f"Error decoding image: {str(e)}")
        return None 
    return sample

In [5]:
def get_label_and_image_file(directory, label, files_and_label):
    supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.lower().endswith(supported_formats):  # 检查文件扩展名
                file_path = os.path.join(root, file)
                files_and_label.append((file_path, label))
    return files_and_label

In [6]:
files_and_label = []
files_and_label = get_label_and_image_file("/mnt/data1/model_1122/image_classifier/test/animal", 0, files_and_label)
files_and_label = get_label_and_image_file("/mnt/data1/model_1122/image_classifier/test/human", 1, files_and_label)
files_and_label = get_label_and_image_file("/mnt/data1/model_1122/image_classifier/test/no_target", 2, files_and_label)
random_files_and_label_elements = random.sample(files_and_label, 1000)
print(f"原始数据个数:{len(files_and_label)},随机选择的元素个数:{len(random_files_and_label_elements)}")

原始数据个数:107459,随机选择的元素个数:1000


In [19]:
predict_and_labels = []
import torch.nn.functional as F
for image_path, label in tqdm(random_files_and_label_elements, desc="预测图片得到分类"):
    with open(image_path, 'rb') as file:
        image_content = file.read()
    image = process_image_classifier_image(image_content, torch_transforms, device)
    if image is None:
        continue
    image = image.unsqueeze(0)
    outputs = model(image)
    max_prob, predicted_class = torch.max(
        F.softmax(outputs, dim=1), dim=1)
    print(max_prob, predicted_class)
    break

预测图片得到分类:   0%|          | 0/1000 [00:00<?, ?it/s]

tensor([0.5761], device='cuda:0') tensor([0], device='cuda:0')





In [9]:
print(predict_and_labels[:10])

[(0, 0), (0, 0), (2, 2), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (1, 1), (0, 0)]


In [16]:
from evaluate import load
import numpy as np
from typing import List, Union, Dict
from collections import defaultdict

def calculate_metrics(predict_and_labels, 
                     metric_names: List[str] = ["accuracy", "precision", "recall", "f1"]) -> Dict[str, float]:
    """
    Calculate multiple metrics between predictions and ground truth
    
    Args:
        predictions: List of predicted values
        references: List of ground truth values
        metric_names: List of metrics to calculate
        
    Returns:
        Dictionary containing calculated metrics
    """
    predictions, references = zip(*predict_and_labels)
    if len(predictions) != len(references):
        raise ValueError(f"Length mismatch: predictions ({len(predictions)}) != references ({len(references)})")
        
    results = {}
    
    for metric_name in metric_names:
        metric = load(metric_name)
        if metric_name == "accuracy":
            result = metric.compute(predictions=predictions, references=references)
        else:
            result = metric.compute(predictions=predictions, references=references, average='macro')
        results.update(result)
        
    return results

In [17]:
print(calculate_metrics(predict_and_labels))

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

{'accuracy': 0.942242355605889, 'precision': 0.9408027593630216, 'recall': 0.9509068626620701, 'f1': 0.944641139188016}


In [12]:
predictions, references = zip(*predict_and_labels)

In [13]:
print(predictions[:10], references[:10])

(0, 0, 2, 0, 0, 0, 0, 0, 1, 0) (0, 0, 2, 0, 0, 0, 0, 0, 1, 0)
