In [None]:
import os
from datasets import Dataset, DatasetDict
from transformers import ViTForImageClassification, AutoProcessor, TrainingArguments, Trainer
from PIL import Image
import torch
from sklearn.metrics import precision_score, recall_score, f1_score

# Dataset locations on the mounted Google Drive.
train_dataset_path = "/content/drive/MyDrive/AI/data/plantdataset1226"
test_dataset_path = "/content/drive/MyDrive/AI/data/testdata20241226"

# Lower-case file extensions that are accepted as training images.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Build the label mapping from the class sub-folders only.
# A plain os.listdir() also returns hidden entries (e.g. Jupyter's
# ".ipynb_checkpoints") and loose files; each of those would silently become
# an extra class and enlarge the classification head.
class_names = sorted(
    entry
    for entry in os.listdir(train_dataset_path)
    if not entry.startswith(".")
    and os.path.isdir(os.path.join(train_dataset_path, entry))
)
if not class_names:
    raise RuntimeError(f"No class folders found in {train_dataset_path}")

class_labels = {class_name: idx for idx, class_name in enumerate(class_names)}
id2label = {idx: class_name for class_name, idx in class_labels.items()}

# Collect (image path, label id) pairs for every image in every class folder.
train_data = {"image_path": [], "label": []}

for class_name, label in class_labels.items():
    class_folder = os.path.join(train_dataset_path, class_name)
    # Sort the listing: os.listdir() order is arbitrary, and the seeded split
    # below is only reproducible when the examples arrive in a stable order.
    for file_name in sorted(os.listdir(class_folder)):
        # Compare in lower case so ".JPG" / ".PNG" files are not dropped.
        if file_name.lower().endswith(IMAGE_EXTENSIONS):
            train_data["image_path"].append(os.path.join(class_folder, file_name))
            train_data["label"].append(label)

if not train_data["image_path"]:
    raise RuntimeError(f"No image files found under {train_dataset_path}")

# Wrap the collected paths and labels in a Hugging Face dataset.
dataset = Dataset.from_dict(train_data)

# Hold out 20% of the examples for validation (fixed seed).
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)

# Load the image processor that matches the pretrained ViT checkpoint.
processor = AutoProcessor.from_pretrained("google/vit-base-patch16-224-in21k", use_fast=True)

# 定义预处理函数
def preprocess_images(examples):
    """Turn a batch of image paths into processor outputs for the model.

    Args:
        examples: A batch dict providing "image_path" (paths to open) and
            "label" (the label for each path).

    Returns:
        The processor output for the whole batch, with the batch's labels
        attached under the "label" key.
    """
    rgb_images = []
    for image_path in examples["image_path"]:
        # Every image is converted to RGB before it is handed to the processor.
        rgb_images.append(Image.open(image_path).convert("RGB"))

    batch = processor(images=rgb_images, return_tensors="pt")
    batch["label"] = examples["label"]
    return batch

# Run the preprocessing over both splits in batches; the raw path column is
# dropped once the images have been processed.
processed_dataset = split_dataset.map(
    preprocess_images,
    batched=True,
    remove_columns=["image_path"],
)

# Expose both splits in PyTorch format.
train_dataset, eval_dataset = (
    processed_dataset[split_name].with_format("torch")
    for split_name in ("train", "test")
)

# Load the pretrained ViT checkpoint with a classification head sized to the
# number of classes and carrying the label <-> id mappings.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    label2id=class_labels,
    id2label=id2label,
    num_labels=len(class_labels),
)

def compute_metrics(eval_pred):
    """Compute weighted precision, recall and F1 for one evaluation pass.

    Args:
        eval_pred: A (logits, labels) pair. The predicted class of each
            example is the argmax over the last axis of the logits.

    Returns:
        A dict with the keys "precision", "recall" and "f1".
    """
    logits, labels = eval_pred
    predictions = torch.tensor(logits).argmax(dim=-1).numpy()

    # All three scikit-learn scorers share the same call signature and options.
    scorers = {
        "precision": precision_score,
        "recall": recall_score,
        "f1": f1_score,
    }
    return {
        metric_name: scorer(labels, predictions, average='weighted', zero_division=0)
        for metric_name, scorer in scorers.items()
    }

# Training hyper-parameters, gathered in one mapping before being handed to
# TrainingArguments.
training_config = {
    "output_dir": "./results",            # where checkpoints are written
    "eval_strategy": "epoch",             # evaluate once per epoch
    "save_strategy": "epoch",             # checkpoint once per epoch
    "learning_rate": 5e-5,
    "per_device_train_batch_size": 256,
    "per_device_eval_batch_size": 256,
    "num_train_epochs": 200,
    "logging_dir": "./logs",              # where logs are written
    "logging_steps": 10,                  # log every 10 steps
    "load_best_model_at_end": True,       # restore the best checkpoint after training
    "metric_for_best_model": "f1",        # "best" is judged by the f1 metric
    "greater_is_better": True,            # a larger f1 is better
    "save_total_limit": 2,                # keep at most two checkpoints
}
training_args = TrainingArguments(**training_config)

# Wire the model, data, processor and metric function into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=processor,
    compute_metrics=compute_metrics,
)

# Fine-tune, then save the resulting model.
trainer.train()
trainer.save_model("./vit-plant-classifier")

# Report the validation metrics.
metrics = trainer.evaluate()
print("验证集结果:")
for title, metric_key in (
    ("Precision", "eval_precision"),
    ("Recall", "eval_recall"),
    ("F1", "eval_f1"),
):
    print(f"{title}: {metrics.get(metric_key, 0.0)}")


Map:   0%|          | 0/600 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,No log,2.936173,0.434262,0.46,0.374389
2,No log,2.773307,0.562358,0.6,0.543663
3,No log,2.621866,0.599611,0.646667,0.586791
4,2.841100,2.48027,0.690669,0.693333,0.647983
5,2.841100,2.34231,0.777778,0.766667,0.73089
6,2.841100,2.212509,0.794649,0.826667,0.793706
7,2.296200,2.075393,0.846111,0.86,0.831699
8,2.296200,1.950212,0.910044,0.913333,0.895123
9,2.296200,1.832065,0.92,0.94,0.923273
10,1.844600,1.71923,0.932101,0.96,0.943695


验证集结果:
Precision: 0.9885714285714285
Recall: 0.98
F1: 0.9793939393939394


Mounted at /content/drive


In [26]:
from google.colab import drive
drive.mount('/content/drive')

import os
from PIL import Image
import torch
from transformers import AutoProcessor, ViTForImageClassification
import json

# Load the fine-tuned model and its processor exactly once.
# The checkpoint is loaded with its own saved config, so there is no size
# mismatch to ignore; without `ignore_mismatched_sizes=True` a genuinely
# incompatible checkpoint raises instead of silently getting a randomly
# re-initialised classification head.
model_path = "/content/drive/MyDrive/AI/data/vit-plant-classifier"
model = ViTForImageClassification.from_pretrained(model_path)
processor = AutoProcessor.from_pretrained(model_path, use_fast=True)

# Test-set directory.
test_dir = "/content/drive/MyDrive/AI/data/testImage/"   # full test set (~30k images)
# test_dir = "/content/drive/MyDrive/AI/data/miniTestImage/"  # mini test set (23 plants)

# Input width of the classification head (not used by the visible cells; kept
# in case a later cell reads it).
num_features = model.classifier.in_features

# Show the classification head so its output size can be checked by eye.
print(model.classifier)

# Reference image file name -> full plant label.
true_labels = {
    "A-01.png": "A-01_鼠尾草",
    "A-02.png": "A-02_芸香",
    "A-03.png": "A-03_甜薰衣草",
    "A-04.png": "A-04_艾草",
    "A-05.png": "A-05_檸檬百里香",
    "A-06.png": "A-06_冰淇淋蘭香草",
    "A-07.png": "A-07_檸檬馬鞭草",
    "A-08.png": "A-08_芳香萬壽菊",
    "A-09.png": "A-09_紫蘇",
    "A-10.png": "A-10_小葉左手香",
    "A-11.png": "A-11_白斑左手香",
    "A-12.png": "A-12_檸檬左手香",
    "A-13.png": "A-13_金錢薄荷",
    "A-14.png": "A-14_防蚊草",
    "A-15.png": "A-15_迷迭香",
    "B-01.png": "B-01_九層塔",
    "B-02.png": "B-02_甜羅勒",
    "B-03.png": "B-03_甜薄荷",
    "B-04.png": "B-04_巧克力薄荷",
    "B-05.png": "B-05_奧勒岡",
    "C-01.png": "C-01_綠薄荷",
    "C-02.png": "C-02_香蜂草",
    "C-03.png": "C-03_貓薄荷"
}

# Class index -> plant code ("A-01" ... "C-03"), derived from the keys above
# (in insertion order) so the two tables cannot drift apart.
# NOTE(review): this assumes class index i of the model corresponds to the
# i-th plant code listed here — confirm against the training folder order.
id2label = {
    idx: os.path.splitext(file_name)[0]
    for idx, file_name in enumerate(true_labels)
}
label2id = {v: k for k, v in id2label.items()}

# Flag a head whose size disagrees with the mapping instead of failing later
# with an opaque lookup error.
if model.config.num_labels != len(id2label):
    print(
        f"WARNING: model has {model.config.num_labels} output classes but "
        f"id2label defines {len(id2label)}; predictions outside the mapping "
        "cannot be translated to a plant code."
    )

# Running tallies for the evaluation loop.
correct = 0
incorrect = 0

# def predict_image(image_path):
#     image = Image.open(image_path).convert("RGB")
#     inputs = processor(images=image, return_tensors="pt")

#     with torch.no_grad():
#         outputs = model(**inputs)
#         probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)

#         # 获取所有类别的概率
#         all_probabilities = probabilities.squeeze().tolist()

#         # 获取最高类别和其概率
#         predicted_class_idx = torch.argmax(probabilities, dim=1).item()
#         confidence = probabilities[0, predicted_class_idx].item()

#     return predicted_class_idx, confidence, all_probabilities

# # 对所有 .JPG 图片进行预测
# for filename in os.listdir(test_dir):
#     if filename.endswith(".png"):
#         image_path = os.path.join(test_dir, filename)
#         predicted_class_idx, confidence, all_probabilities = predict_image(image_path)

#         # 输出所有类别的概率
#         probabilities_str = ", ".join([f"{model.config.id2label[idx]}: {prob:.4f}" for idx, prob in enumerate(all_probabilities)])

#         # 输出结果
#         print(f"{filename}: Predicted: {model.config.id2label[predicted_class_idx]} (confidence: {confidence:.4f})")
#         print(f"All probabilities: {probabilities_str}\n")


def predict_image(image_path):
    """Classify a single image file with the module-level model and processor.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A tuple ``(predicted_class_idx, confidence, all_probabilities)``:
        the index of the highest-probability class, that class's softmax
        probability, and the list of softmax probabilities for every class.
    """
    # Open through a context manager so the file handle is always released;
    # convert() returns an independent in-memory RGB copy.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")

    # Keep the inputs on the same device as the model so the function also
    # works when the model has been moved to a GPU.
    inputs = processor(images=image, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)

    # Index the single batch row explicitly; squeeze() would collapse a
    # one-class head to a bare scalar instead of a list.
    all_probabilities = probabilities[0].tolist()

    # Highest-probability class and its probability.
    predicted_class_idx = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0, predicted_class_idx].item()

    return predicted_class_idx, confidence, all_probabilities

# Predict every PNG image in the test directory and tally the results.
# The listing is sorted so repeated runs print in the same order, and the
# extension is compared in lower case so ".PNG" files are not skipped.
for filename in sorted(os.listdir(test_dir)):
    if not filename.lower().endswith(".png"):
        continue

    image_path = os.path.join(test_dir, filename)
    predicted_class_idx, confidence, all_probabilities = predict_image(image_path)

    # The first four characters of the file name carry the expected plant
    # code (e.g. "A-01", "B-02").
    true_prefix = filename[:4]

    # The model head can expose more classes than the hand-written mapping
    # (the saved head has 24 outputs, the mapping 23). Fall back to the
    # model's own label instead of raising KeyError; such a prediction is
    # then counted as incorrect by the prefix check below.
    predicted_label = id2label.get(
        predicted_class_idx,
        str(model.config.id2label.get(predicted_class_idx, predicted_class_idx)),
    )

    # Score the prediction against the expected prefix.
    if predicted_label.startswith(true_prefix):
        print(f"正確_Correct: {filename} -> {predicted_label}, confidence: {confidence:.4f}")
        correct += 1
    else:
        print(f"錯誤_Incorrect: {filename} -> Predicted: {predicted_label}, True Prefix: {true_prefix}, confidence: {confidence:.4f}")
        incorrect += 1

    # Print the probability of every class, labelled with the model's own names.
    probabilities_str = ", ".join([f"{model.config.id2label[idx]}: {prob:.4f}" for idx, prob in enumerate(all_probabilities)])
    print(f"All probabilities: {probabilities_str}\n")

# Summary; the guard avoids a division by zero when no image was processed.
total = correct + incorrect
accuracy = correct / total * 100 if total > 0 else 0
print(f"Total Correct: {correct}")
print(f"Total Incorrect: {incorrect}")
print(f"Accuracy: {accuracy:.2f}%")



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Linear(in_features=768, out_features=24, bias=True)
錯誤_Incorrect: A-20_01_512x512_h612_sp_f0.png -> Predicted: A-01, True Prefix: A-20, confidence: 0.1054
All probabilities: 00: 0.1054, 01: 0.0327, 02: 0.0299, 03: 0.0447, 04: 0.0313, 05: 0.0368, 06: 0.0509, 07: 0.0334, 08: 0.0339, 09: 0.0367, 10: 0.0508, 11: 0.0290, 12: 0.0360, 13: 0.0340, 14: 0.0382, 15: 0.0401, 16: 0.0331, 17: 0.0425, 18: 0.0379, 19: 0.0722, 20: 0.0412, 21: 0.0476, 22: 0.0302, .ipynb_checkpoints: 0.0315

錯誤_Incorrect: A-19_32_512x512_h612_turn_light_t_f0.png -> Predicted: B-04, True Prefix: A-19, confidence: 0.1217
All probabilities: 00: 0.0292, 01: 0.0252, 02: 0.0302, 03: 0.0262, 04: 0.0403, 05: 0.0601, 06: 0.0327, 07: 0.0305, 08: 0.0383, 09: 0.0577, 10: 0.0394, 11: 0.0584, 12: 0.0394, 13: 0.0287, 14: 0.0319, 15: 0.0500, 16: 0.0372, 17: 0.0459, 18: 0.1217, 19: 0.0450, 20: 0.0385, 21: 0.036

KeyboardInterrupt: 

In [None]:
import os
from PIL import Image
import torch
from transformers import AutoProcessor, ViTForImageClassification

# Directory holding the images to classify.
test_dir = "/content/20241226twst123"

# Load the saved processor and model from the same local directory.
model_path = "./vit-plant-classifier"
processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
model = ViTForImageClassification.from_pretrained(model_path)

def predict_image(image_path):
    """Classify a single image file and return the model's label for it.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring class.
    """
    # Open through a context manager so the file handle is always released;
    # convert() returns an independent in-memory RGB copy.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")

    # Keep the inputs on the same device as the model so the function also
    # works when the model has been moved to a GPU.
    inputs = processor(images=image, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)
        predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()

    return model.config.id2label[predicted_class_idx]

# Predict every PNG image in the test directory (the filter is ".png", not
# ".JPG"). The listing is sorted so the output order is stable between runs,
# and the extension is compared in lower case so ".PNG" files are included.
for filename in sorted(os.listdir(test_dir)):
    if not filename.lower().endswith(".png"):
        continue
    image_path = os.path.join(test_dir, filename)
    prediction = predict_image(image_path)
    print(f"{filename}: {prediction}")


9deba0c6337205476f0c9727f49e5cc.png: 3
857c2df8b3ba9866e86fd3e57f4ffdb.png: 10
5a93af0b516403d5ea258307b521ada.png: 2
f882479e0c66b0d73133173ad966947.png: 4


In [None]:
# Move the zipped model archive from the Colab workspace into Google Drive.
# NOTE(review): no visible cell creates /content/vit-plant-classifier.zip —
# confirm the archive is produced before this cell runs.
!mv /content/vit-plant-classifier.zip /content/drive/MyDrive/