In [2]:
import argparse
import cv2
import numpy as np
import os
from matplotlib import pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm, trange

# 读取数据
数据来源：scene

In [3]:
x_list = []
y_list = []
for pic in tqdm(os.listdir('scene')):
    pic_path = './scene/' + pic
    # 读取RGB三通道图像(640, 640, 3)
    pic_data = cv2.imread(pic_path, cv2.IMREAD_COLOR)
    pic_data = cv2.resize(pic_data, (224, 224))
    x_list.append(pic_data)
    y_list.append(int(pic[6:8]))
x_list = np.array(x_list)
y_list_int = np.array(y_list)

scene_label = pd.read_excel('scene_label.xlsx')
y_list = np.zeros((y_list_int.shape[0], 20))
for i in trange(y_list_int.shape[0]):
    y_list[i, scene_label[scene_label.id==y_list_int[i]].iloc[:, 1:].dropna(axis=1).astype(int).to_numpy()[0].tolist()] = 1

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

In [4]:
from sklearn.model_selection import StratifiedShuffleSplit

# 假设标签数据保存在label_list中，其中每个标签是一个整数
X = x_list
y = y_list

# 分层抽样，其中train_size和test_size分别表示训练集和测试集的比例
# n_splits表示抽取的次数，random_state表示随机数种子
split = StratifiedShuffleSplit(n_splits=1, train_size=0.8, test_size=0.2, random_state=42)
train_index, test_index = next(split.split(X, y))

# 得到训练集和测试集
X_train, X_test = [X[i] for i in train_index], [X[i] for i in test_index]
y_train, y_test = [y[i] for i in train_index], [y[i] for i in test_index]

# 将训练集进一步划分为训练集和验证集，其中test_size表示验证集的比例
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_index, valid_index = next(split.split(X_train, y_train))

# 得到训练集、验证集和测试集的索引
train_index = [train_index[i] for i in range(len(train_index))]
valid_index = [valid_index[i] for i in range(len(valid_index))]
test_index = [test_index[i] for i in range(len(test_index))]

# 分类器

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm, trange

# 定义超参数
batch_size = 8
learning_rate = 0.0001
num_epochs = 100

# 设置 GPU
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

class MyDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img = self.data[index]
        label = self.labels[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# 加载数据集
train_data = x_list[train_index]
train_labels = y_list[train_index]
valid_data = x_list[valid_index]
valid_labels = y_list[valid_index]
test_data = x_list[test_index]
test_labels = y_list[test_index]

# # 定义数据增强和标准化
# # 在scene数据集中只做了标准化
transform = transforms.Compose([
#     transforms.RandomAffine(5),
#     transforms.ColorJitter(hue=.05, saturation=.05),
#     transforms.RandomCrop((88, 88)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
    transforms.ToTensor(), # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化张量
])

# 加载数据集
train_dataset = MyDataset(train_data, train_labels, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataset = MyDataset(valid_data, valid_labels, transform=transform)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_dataset = MyDataset(test_data, test_labels, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [7]:
import torchvision

model = torchvision.models.resnet18()
path = './pretrained_model/resnet18-5c106cde.pth'
trained_state_dict = torch.load(path)
model.load_state_dict(trained_state_dict,strict=True)

num_fits = model.fc.in_features
model.fc = nn.Linear(num_fits, 20)
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [10]:
# 训练模型
for epoch in trange(num_epochs):
    running_loss = 0.0
    model.train()
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_dataset)
    print('Epoch [%d/%d], Loss: %.4f' % (epoch + 1, num_epochs, epoch_loss))
    torch.save(model.state_dict(), "./classifier/model-ResNet18-scene/epoch-%d.pt" % epoch)
    
    model.eval()
    with torch.no_grad():
        running_loss = 0.0
        for inputs, labels in tqdm(valid_loader):
            inputs, labels = inputs.to(device), labels
            outputs = model(inputs)
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(valid_dataset)
        print('Loss of the model on the valid images: %f' % loss)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [1/100], Loss: 4.7953


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.901807


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [2/100], Loss: 4.7940


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.895648


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [3/100], Loss: 4.7896


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.773497


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [4/100], Loss: 4.7957


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.506450


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [5/100], Loss: 4.8674


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.814692


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [6/100], Loss: 4.8407


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.103386


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [7/100], Loss: 4.8248


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.146227


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [8/100], Loss: 4.7945


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.963235


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [9/100], Loss: 4.8143


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.967705


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [10/100], Loss: 4.7914


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.488840


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [11/100], Loss: 4.8034


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.713840


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [12/100], Loss: 4.7905


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.091721


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [13/100], Loss: 4.7910


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.496632


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [14/100], Loss: 4.7871


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.683073


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [15/100], Loss: 4.7878


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.898036


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [16/100], Loss: 4.7864


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.063105


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [17/100], Loss: 4.8040


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.104675


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [18/100], Loss: 4.8750


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.549660


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [19/100], Loss: 4.8034


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.313094


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [20/100], Loss: 4.7933


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.963004


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [21/100], Loss: 4.7883


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.116914


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [22/100], Loss: 4.8300


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.215246


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [23/100], Loss: 4.8020


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.864482


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [24/100], Loss: 4.7939


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.646520


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [25/100], Loss: 4.7891


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.691770


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [26/100], Loss: 4.7896


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.177428


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [27/100], Loss: 4.7842


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.474560


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [28/100], Loss: 4.7861


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.680490


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [29/100], Loss: 4.7852


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.492956


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [30/100], Loss: 4.7864


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.464730


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [31/100], Loss: 4.7858


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.894096


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [32/100], Loss: 4.7880


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.926810


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [33/100], Loss: 4.7870


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.334473


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [34/100], Loss: 4.7925


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.086189


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [35/100], Loss: 4.8344


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.057760


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [36/100], Loss: 4.9403


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.212748


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [37/100], Loss: 4.8273


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.698150


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [38/100], Loss: 4.8164


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.379069


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [39/100], Loss: 4.8339


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.490939


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [40/100], Loss: 4.7888


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.488869


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [41/100], Loss: 4.8062


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.050070


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [42/100], Loss: 4.7857


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.273911


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [43/100], Loss: 4.7851


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.367670


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [44/100], Loss: 4.7856


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.458423


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [45/100], Loss: 4.7872


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.190000


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [46/100], Loss: 4.7945


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.487318


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [47/100], Loss: 4.7831


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.047701


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [48/100], Loss: 4.7826


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.786321


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [49/100], Loss: 4.7853


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.679335


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [50/100], Loss: 4.7893


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.491793


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [51/100], Loss: 4.7855


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.773150


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [52/100], Loss: 4.8137


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.233562


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [53/100], Loss: 4.8006


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.180178


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [54/100], Loss: 4.7865


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.894031


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [55/100], Loss: 4.7834


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.364329


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [56/100], Loss: 4.7825


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.242115


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [57/100], Loss: 4.7916


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.178680


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [58/100], Loss: 4.7874


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.740199


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [59/100], Loss: 4.8154


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.772052


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [60/100], Loss: 4.7921


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.089188


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [61/100], Loss: 4.7852


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.050958


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [62/100], Loss: 4.7821


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.303711


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [63/100], Loss: 4.7845


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.273695


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [64/100], Loss: 4.7845


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.738312


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [65/100], Loss: 4.7814


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.769133


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [66/100], Loss: 4.7813


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.456463


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [67/100], Loss: 4.8015


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.391040


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [68/100], Loss: 4.8129


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.718743


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [69/100], Loss: 4.8160


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.085867


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [70/100], Loss: 4.7899


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.558538


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [71/100], Loss: 4.7853


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.382175


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [72/100], Loss: 4.7818


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.768828


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [73/100], Loss: 4.7822


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.174462


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [74/100], Loss: 4.7916


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.899906


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [75/100], Loss: 4.7813


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.485273


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [76/100], Loss: 4.7827


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.769829


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [77/100], Loss: 4.7834


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.486916


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [78/100], Loss: 4.7819


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.768056


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [79/100], Loss: 4.7925


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.894032


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [80/100], Loss: 4.7809


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.181737


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [81/100], Loss: 4.7810


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.455408


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [82/100], Loss: 4.8121


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.531521


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [83/100], Loss: 4.8205


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.906982


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [84/100], Loss: 4.7872


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.745525


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [85/100], Loss: 4.7843


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.085052


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [86/100], Loss: 4.7813


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.080417


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [87/100], Loss: 4.7850


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.491780


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [88/100], Loss: 4.7816


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 6.549940


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [89/100], Loss: 4.7800


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.675001


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [90/100], Loss: 4.7798


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.768403


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [91/100], Loss: 4.7812


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.361742


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [92/100], Loss: 4.7818


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.172867


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [93/100], Loss: 4.7819


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.678056


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [94/100], Loss: 4.7796


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.767706


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [95/100], Loss: 4.7881


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.909723


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [96/100], Loss: 4.7836


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.582081


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [97/100], Loss: 4.7833


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 5.362476


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [98/100], Loss: 4.7900


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.524560


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [99/100], Loss: 4.9075


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 4.536358


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [100/100], Loss: 4.7980


  0%|          | 0/50 [00:00<?, ?it/s]

Loss of the model on the valid images: 3.586980


# 评价

In [11]:
y_pred = []
y_true = []
# 测试模型（这里是针对单个杯子设计的准确率，多个杯子需修改代码）
model.eval()
with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
        inputs, labels = inputs.to(device), labels
        outputs = model(inputs)
        predicted = np.int64(np.array(outputs.data.cpu()) > 0)
        y_pred.append(predicted)
        y_true.append(np.int64(np.array(labels.cpu())))

  0%|          | 0/63 [00:00<?, ?it/s]

In [12]:
y_pred = np.row_stack(y_pred)
y_true = np.row_stack(y_true)

In [13]:
from sklearn.metrics import multilabel_confusion_matrix, precision_score, recall_score, f1_score

# 计算混淆矩阵
mcm = multilabel_confusion_matrix(y_true, y_pred)

# 计算精确度、召回率和F1分数
precision = precision_score(y_true, y_pred, average='micro')
recall = recall_score(y_true, y_pred, average='micro')
f1 = f1_score(y_true, y_pred, average='micro')

# 打印结果
print("Multilabel Confusion Matrix:")
print(mcm)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

Multilabel Confusion Matrix:
[[[460   0]
  [  0  40]]

 [[360   0]
  [  0 140]]

 [[380   0]
  [  0 120]]

 [[400   0]
  [  0 100]]

 [[400   0]
  [  0 100]]

 [[440   0]
  [  0  60]]

 [[380   0]
  [  0 120]]

 [[300   0]
  [  0 200]]

 [[380   0]
  [  0 120]]

 [[400   0]
  [  0 100]]

 [[380   0]
  [  0 120]]

 [[380   0]
  [  0 120]]

 [[480   0]
  [  0  20]]

 [[440   0]
  [  0  60]]

 [[419   1]
  [  0  80]]

 [[459   1]
  [  0  40]]

 [[400   0]
  [  0 100]]

 [[439   1]
  [  0  60]]

 [[458   2]
  [  0  40]]

 [[419   1]
  [  0  80]]]
Precision: 0.9967141292442497
Recall: 1.0
F1 Score: 0.9983543609434997
