In [70]:
import os
import cv2
import torch
import albumentations

import tkinter as tk
import pandas as pd
import numpy as np
import torch.nn as nn
import torchvision.models as models

from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
from albumentations.pytorch.transforms import ToTensorV2

In [71]:
# Class-index -> label-name lookup table.
# The first CSV row is skipped (presumably a header row — confirm) and column 0
# holds the names; position i of `labels_unique` is the name shown for model
# output index i.
label_mapping = pd.read_csv('../dataset/label_mapping.csv', header=None, skiprows=1)
labels_unique = label_mapping.iloc[:, 0].to_list()

In [72]:
# Inference-time preprocessing (presumably identical to the training
# pipeline — confirm): resize to 320x320, normalize with the standard
# ImageNet mean/std on 0-255 pixel values, then convert the HWC numpy
# array to a CHW torch tensor.
transform = albumentations.Compose(
    [
        albumentations.Resize(height=320, width=320),
        albumentations.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        ToTensorV2(p=1.0),
    ]
)

In [73]:
# Load the best checkpoint of every cross-validation fold; the GUI's
# predict_images() runs all of them and takes a majority vote.
NUM_FOLDS = 5
# Output size of the ResNet-50 classifier head; must equal the head size stored
# in the checkpoints, otherwise load_state_dict() fails on the fc weights.
NUM_CLASSES = 176
MODEL_DIR = '../models'

models_list = []
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Predicted class indices are looked up in labels_unique, so both sizes must agree;
# a mismatch would yield wrong label names or an IndexError at prediction time.
if len(labels_unique) != NUM_CLASSES:
    print(f'Warning: label_mapping has {len(labels_unique)} labels '
          f'but the models have {NUM_CLASSES} outputs')

for fold_n in range(NUM_FOLDS):
    model_path = f'{MODEL_DIR}/fold_{fold_n}_best_model.pth'
    if os.path.exists(model_path):
        # torch.load is pickle-based: only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        model = models.resnet50()
        model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()  # inference mode (fixed BatchNorm statistics, no dropout)
        models_list.append(model)
        print(f'Loaded model from fold {fold_n}')
    else:
        print(f'Model not found: {model_path}')

Loaded model from fold 0
Loaded model from fold 1
Loaded model from fold 2
Loaded model from fold 3
Loaded model from fold 4


In [74]:
class LeafPredictorApp:
    """Tkinter front end for the leaf classifier.

    Lets the user pick one or more images, browse them, and run the fold
    ensemble (module-level ``models_list``) on every selected image. Each
    image's prediction text is cached in ``self.results`` and shown whenever
    that image is displayed.
    """

    def __init__(self, root):
        """Configure the window ``root`` (a ``tk.Tk``), initialise state and build widgets."""
        self.root = root
        self.root.title('树叶分类预测系统')
        self.root.geometry("800x600")

        self.image_paths = []
        self.current_image_idx = 0
        # Prediction text per image, aligned with image_paths; empty until predict_images() runs.
        self.results = []

        # Build the UI widgets
        self.create_widgets()

    def create_widgets(self):
        """Create and pack the buttons, the image preview and the result label."""
        # Upload button
        self.upload_btn = tk.Button(self.root, text="上传图片", command=self.upload_images)
        self.upload_btn.pack(pady=10)

        # Image preview area
        self.image_frame = tk.Frame(self.root)
        self.image_frame.pack(fill=tk.BOTH, expand=True, padx=20, pady=10)

        self.image_label = tk.Label(self.image_frame)
        self.image_label.pack(fill=tk.BOTH, expand=True)

        # Navigation buttons
        self.nav_frame = tk.Frame(self.root)
        self.nav_frame.pack(pady=5)

        self.prev_btn = tk.Button(self.nav_frame, text="上一张", command=self.show_prev_image)
        self.prev_btn.pack(side=tk.LEFT, padx=5)

        self.next_btn = tk.Button(self.nav_frame, text="下一张", command=self.show_next_image)
        self.next_btn.pack(side=tk.LEFT, padx=5)

        # Predict button
        self.predict_btn = tk.Button(self.root, text="执行预测", command=self.predict_images)
        self.predict_btn.pack(pady=10)

        # Result display area
        self.result_frame = tk.Frame(self.root)
        self.result_frame.pack(fill=tk.BOTH, expand=True, padx=20, pady=10)

        self.result_label = tk.Label(self.result_frame, text="预测结果显示",
                                     font=("Arial", 14), wraplength=600)
        self.result_label.pack(fill=tk.BOTH, expand=True)

    def upload_images(self):
        """Ask for image files, reset cached results and display the first selected image."""
        filetypes = [("图片文件", "*.jpg *.jpeg *.png *.bmp")]
        paths = filedialog.askopenfilenames(title="选择树叶图片", filetypes=filetypes)
        if paths:
            self.image_paths = list(paths)
            self.current_image_idx = 0
            # Results of a previous batch no longer correspond to the new image list.
            self.results = []
            self.result_label.config(text="预测结果显示")
            self.show_current_image()

    def show_current_image(self):
        """Show a thumbnail of the current image and its cached prediction, if any."""
        if not self.image_paths:
            return
        self._show_current_result()
        path = self.image_paths[self.current_image_idx]
        try:
            # The context manager releases the file handle that Image.open keeps open.
            with Image.open(path) as image:
                image.thumbnail((600, 400))
                # PhotoImage pastes the pixels into the Tk image on construction,
                # so it is complete before the file is closed.
                photo = ImageTk.PhotoImage(image)
        except OSError as exc:  # includes PIL.UnidentifiedImageError
            self.image_label.config(image='')
            self.image_label.image = None
            messagebox.showerror("错误", f"无法打开图片：{os.path.basename(path)}\n{exc}")
            return
        self.image_label.config(image=photo)
        # Keep a Python reference; otherwise the PhotoImage is garbage-collected and the label goes blank.
        self.image_label.image = photo

    def show_next_image(self):
        """Advance to the next image, wrapping around to the first."""
        if self.image_paths:
            self.current_image_idx = (self.current_image_idx + 1) % len(self.image_paths)
            self.show_current_image()

    def show_prev_image(self):
        """Go back to the previous image, wrapping around to the last."""
        if self.image_paths:
            self.current_image_idx = (self.current_image_idx - 1) % len(self.image_paths)
            self.show_current_image()

    def predict_images(self):
        """Run the fold ensemble on every uploaded image and show the current image's result.

        Each model votes with its argmax class; the most frequent vote wins.
        """
        if not self.image_paths:
            messagebox.showwarning("警告", "请先上传图片")
            return

        if not models_list:
            messagebox.showerror("错误", "没有找到模型文件")
            # Nothing to vote with; continuing would call max() on an empty list.
            return

        self.results = [self._predict_one(path) for path in self.image_paths]
        self._show_current_result()

    def _show_current_result(self):
        """Put the cached prediction for the current image into the result label, if one exists."""
        if self.current_image_idx < len(self.results):
            self.result_label.config(text=self.results[self.current_image_idx])

    def _predict_one(self, path):
        """Return the display text for the ensemble prediction on the image at ``path``."""
        name = os.path.basename(path)
        image = self._read_rgb(path)
        if image is None:
            return f'图片：{name} \n无法读取图片'

        image_tensor = transform(image=image)['image'].unsqueeze(0).to(device)
        with torch.no_grad():
            votes = [torch.argmax(model(image_tensor), dim=1).item() for model in models_list]

        # Majority vote. max() returns the first maximal element, so ties resolve
        # deterministically to the tied class that appears first in fold order.
        final_pred = max(votes, key=votes.count)
        label = labels_unique[final_pred]
        return f'图片：{name} \n预测结果：{label}'

    @staticmethod
    def _read_rgb(path):
        """Read the image at ``path`` as an RGB uint8 array, or return None if it cannot be read."""
        # cv2.imread returns None on failure and cannot open non-ASCII paths on
        # Windows; reading the bytes with numpy and decoding in memory avoids that.
        try:
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None  # cv2.imdecode raises on an empty buffer
        image = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if image is None:
            return None
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [75]:
# Launch the GUI. In a notebook __name__ is "__main__", so this runs and
# mainloop() blocks the cell until the window is closed.
if __name__ == "__main__":
    root = tk.Tk()
    app = LeafPredictorApp(root)
    app.root.mainloop()  # app.root is the same Tk instance as `root`