In [1]:
import argparse
import cv2
import numpy as np
import os
from matplotlib import pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm, trange

# 读取数据
数据来源：scene

In [2]:
x_list = []
y_list = []
for pic in tqdm(os.listdir('scene')):
    pic_path = './scene/' + pic
    # 读取RGB三通道图像(640, 640, 3)
    pic_data = cv2.imread(pic_path, cv2.IMREAD_COLOR)
    pic_data = cv2.resize(pic_data, (224, 224))
    x_list.append(pic_data)
    y_list.append(int(pic[6:8]))
x_list = np.array(x_list)
y_list_int = np.array(y_list)

scene_label = pd.read_excel('scene_label.xlsx')
y_list = np.zeros((y_list_int.shape[0], 20))
for i in trange(y_list_int.shape[0]):
    y_list[i, scene_label[scene_label.id==y_list_int[i]].iloc[:, 1:].dropna(axis=1).astype(int).to_numpy()[0].tolist()] = 1

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

In [3]:
from sklearn.model_selection import StratifiedShuffleSplit

# 假设标签数据保存在label_list中，其中每个标签是一个整数
X = x_list
y = y_list

# 分层抽样，其中train_size和test_size分别表示训练集和测试集的比例
# n_splits表示抽取的次数，random_state表示随机数种子
split = StratifiedShuffleSplit(n_splits=1, train_size=0.8, test_size=0.2, random_state=42)
train_index, test_index = next(split.split(X, y))

# 得到训练集和测试集
X_train, X_test = [X[i] for i in train_index], [X[i] for i in test_index]
y_train, y_test = [y[i] for i in train_index], [y[i] for i in test_index]

# 将训练集进一步划分为训练集和验证集，其中test_size表示验证集的比例
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_index, valid_index = next(split.split(X_train, y_train))

# 得到训练集、验证集和测试集的索引
train_index = [train_index[i] for i in range(len(train_index))]
valid_index = [valid_index[i] for i in range(len(valid_index))]
test_index = [test_index[i] for i in range(len(test_index))]

# 分类器

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm, trange

# 定义超参数
batch_size = 8
learning_rate = 0.0001
num_epochs = 100

# 设置 GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class MyDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img = self.data[index]
        label = self.labels[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# 加载数据集
train_data = x_list[train_index]
train_labels = y_list[train_index]
valid_data = x_list[valid_index]
valid_labels = y_list[valid_index]
test_data = x_list[test_index]
test_labels = y_list[test_index]

# # 定义数据增强和标准化
# # 在scene数据集中只做了标准化
transform = transforms.Compose([
#     transforms.RandomAffine(5),
#     transforms.ColorJitter(hue=.05, saturation=.05),
#     transforms.RandomCrop((88, 88)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
    transforms.ToTensor(), # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化张量
])

# 加载数据集
train_dataset = MyDataset(train_data, train_labels, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataset = MyDataset(valid_data, valid_labels, transform=transform)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_dataset = MyDataset(test_data, test_labels, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [5]:
import torchvision

model = torchvision.models.densenet121()
path = './pretrained_model/densenet121-a639ec97.pth'
trained_state_dict = torch.load(path)
model.load_state_dict(trained_state_dict, strict=False)

num_fits = model.classifier.in_features
model.classifier = nn.Linear(num_fits, 20)
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# 训练模型
for epoch in trange(num_epochs):
    running_loss = 0.0
    model.train()
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_dataset)
    print('Epoch [%d/%d], Loss: %.4f' % (epoch + 1, num_epochs, epoch_loss))
    torch.save(model.state_dict(), "./classifier/model-DenseNet121-scene/epoch-%d.pt" % epoch)
    
    model.eval()
    with torch.no_grad():
        running_loss = 0.0
        for inputs, labels in tqdm(valid_loader):
            inputs, labels = inputs.to(device), labels
            outputs = model(inputs)
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(valid_dataset)
        print('Loss of the model on the valid images: %f' % loss)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [1/100], Loss: 8.8782


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 8.657558


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [2/100], Loss: 7.4128


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 7.709405


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [3/100], Loss: 6.7536


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.853624


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [4/100], Loss: 6.3798


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.710020


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [5/100], Loss: 6.1651


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.716359


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [6/100], Loss: 5.8506


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.076292


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [7/100], Loss: 5.7217


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.035676


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [8/100], Loss: 5.6439


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.079361


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [9/100], Loss: 5.4664


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.461147


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [10/100], Loss: 5.4140


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.429276


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [11/100], Loss: 5.4058


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.589503


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [12/100], Loss: 5.3373


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.379922


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [13/100], Loss: 5.2523


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.285785


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [14/100], Loss: 5.1860


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.386427


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [15/100], Loss: 5.1994


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.024231


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [16/100], Loss: 5.1761


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.172375


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [17/100], Loss: 5.0843


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.140359


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [18/100], Loss: 5.0833


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.702830


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [19/100], Loss: 5.0144


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.415386


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [20/100], Loss: 5.0309


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.389511


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [21/100], Loss: 5.0364


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.356087


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [22/100], Loss: 5.0610


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.426845


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [23/100], Loss: 5.1086


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.599070


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [24/100], Loss: 5.0936


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.645826


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [25/100], Loss: 5.0658


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.629208


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [26/100], Loss: 5.0270


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.520233


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [27/100], Loss: 4.9567


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.831051


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [28/100], Loss: 4.9685


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.436938


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [29/100], Loss: 4.9461


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.253244


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [30/100], Loss: 4.8968


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.811858


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [31/100], Loss: 4.9374


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.359899


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [32/100], Loss: 4.9737


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.133453


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [33/100], Loss: 4.9235


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.753240


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [34/100], Loss: 4.9421


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.616634


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [35/100], Loss: 4.9068


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.719683


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [36/100], Loss: 4.8906


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.278895


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [37/100], Loss: 4.8869


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.581178


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [38/100], Loss: 4.9665


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.374812


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [39/100], Loss: 4.9314


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.781280


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [40/100], Loss: 5.0025


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.544955


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [41/100], Loss: 4.9471


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.156031


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [42/100], Loss: 4.8755


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.529191


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [43/100], Loss: 4.8488


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.223619


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [44/100], Loss: 4.8943


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.114357


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [45/100], Loss: 4.9986


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.615485


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [46/100], Loss: 4.9315


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.125572


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [47/100], Loss: 4.8696


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.292795


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [48/100], Loss: 4.8421


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.100406


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [49/100], Loss: 4.8558


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.209223


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [50/100], Loss: 4.8672


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.214032


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [51/100], Loss: 4.8593


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.487169


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [52/100], Loss: 4.8742


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.525194


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [53/100], Loss: 4.8378


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.138196


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [54/100], Loss: 4.8712


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.574375


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [55/100], Loss: 5.0142


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.512813


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [56/100], Loss: 4.8637


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.038022


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [57/100], Loss: 4.8227


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.516407


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [58/100], Loss: 4.8554


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.009151


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [59/100], Loss: 4.8271


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.611812


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [60/100], Loss: 4.8186


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.594384


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [61/100], Loss: 4.8310


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.096306


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [62/100], Loss: 4.8177


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.787265


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [63/100], Loss: 4.8053


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.464726


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [64/100], Loss: 5.1191


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.581954


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [65/100], Loss: 4.8664


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.214308


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [66/100], Loss: 4.8554


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.145203


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [67/100], Loss: 4.8397


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.140167


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [68/100], Loss: 4.8147


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.324782


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [69/100], Loss: 4.8123


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.364812


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [70/100], Loss: 4.8161


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.097525


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [71/100], Loss: 4.8131


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.919645


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [72/100], Loss: 4.8732


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.919109


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [73/100], Loss: 4.9760


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.692953


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [74/100], Loss: 4.8688


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.265627


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [75/100], Loss: 4.8275


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.889950


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [76/100], Loss: 4.8280


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.808469


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [77/100], Loss: 4.9386


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.827088


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [78/100], Loss: 4.8808


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.459457


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [79/100], Loss: 4.8667


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.529133


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [80/100], Loss: 4.8205


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.591698


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [81/100], Loss: 4.8238


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.903832


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [82/100], Loss: 4.8170


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.504185


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [83/100], Loss: 4.8123


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.524061


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [84/100], Loss: 4.8416


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.935579


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [85/100], Loss: 4.8397


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.110709


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [86/100], Loss: 4.8435


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.285095


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [87/100], Loss: 4.8130


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.196454


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [88/100], Loss: 4.9603


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.199155


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [89/100], Loss: 4.8412


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.375648


  0%|          | 0/200 [00:00<?, ?it/s]

# 评价

In [None]:
y_pred = []
y_true = []
# 测试模型（这里是针对单个杯子设计的准确率，多个杯子需修改代码）
model.eval()
with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
        inputs, labels = inputs.to(device), labels
        outputs = model(inputs)
        predicted = np.int64(np.array(outputs.data.cpu()) > 0)
        y_pred.append(predicted)
        y_true.append(np.int64(np.array(labels.cpu())))

In [None]:
y_pred = np.row_stack(y_pred)
y_true = np.row_stack(y_true)

In [None]:
from sklearn.metrics import multilabel_confusion_matrix, precision_score, recall_score, f1_score

# 计算混淆矩阵
mcm = multilabel_confusion_matrix(y_true, y_pred)

# 计算精确度、召回率和F1分数
precision = precision_score(y_true, y_pred, average='micro')
recall = recall_score(y_true, y_pred, average='micro')
f1 = f1_score(y_true, y_pred, average='micro')

# 打印结果
print("Multilabel Confusion Matrix:")
print(mcm)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)