In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from models_mamba import VisionMamba

# ---- Configuration: every value a reader might want to tune lives here ----
SEED = 42
DATA_ROOT = '/root/autodl-tmp/Dataset'  # must contain Train/, Valid/ and Test/ sub-folders
IMG_SIZE = 224
PATCH_SIZE = 16
BATCH_SIZE = 64
NUM_CLASSES = 4  # four-way classification
LEARNING_RATE = 1e-4
# Per-channel normalisation statistics.
# NOTE(review): presumably computed on this dataset's training images — confirm.
NORM_MEAN = [0.48538744, 0.45147955, 0.41428733]
NORM_STD = [0.22886747, 0.22356716, 0.22420578]

# Seed PyTorch's random number generator so weight initialisation and the
# training-set shuffle order repeat from one run to the next.
torch.manual_seed(SEED)

# Data preprocessing: resize to the model's input size, convert to a tensor,
# then normalise each channel.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Load the data: one ImageFolder per split, all sharing the same transform.
train_dataset = datasets.ImageFolder(root=f'{DATA_ROOT}/Train', transform=transform)
val_dataset = datasets.ImageFolder(root=f'{DATA_ROOT}/Valid', transform=transform)
test_dataset = datasets.ImageFolder(root=f'{DATA_ROOT}/Test', transform=transform)

# Only the training split is shuffled; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the VisionMamba classifier with a NUM_CLASSES-way output head.
model = VisionMamba(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    depth=12,
    embed_dim=192,
    num_classes=NUM_CLASSES,
).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Learning-rate scheduler: multiply the LR by 0.1 once the monitored metric
# (the validation loss passed to `scheduler.step`) has stopped decreasing for
# more than 5 consecutive epochs.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
# against the installed version before upgrading.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

# Train the model
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25):
    """Train `model` for `num_epochs` epochs, validating after each one.

    Every epoch runs one optimisation pass over `train_loader`, then one
    gradient-free pass over `val_loader`, and finally calls
    `scheduler.step(val_loss)` with that epoch's validation loss.

    All running totals (loss sums, correct counts, sample counts) are reset at
    the start of every epoch, so each reported value describes that epoch only.
    Losses are averaged over the number of samples actually iterated.

    Args:
        model: network to train; it is updated in place.
        train_loader: iterable of `(inputs, labels)` training batches.
        val_loader: iterable of `(inputs, labels)` validation batches.
        criterion: loss callable taking `(outputs, labels)`. The per-sample
            average assumes it returns the batch mean, as
            `nn.CrossEntropyLoss` does by default.
        optimizer: optimizer holding the model's parameters.
        scheduler: object whose `step(metric)` receives the validation loss.
        num_epochs: number of passes over the training data.

    Returns:
        `(train_losses, val_losses, train_accuracies, val_accuracies)`, four
        lists with one float per epoch. Accuracies are fractions in [0, 1].

    Raises:
        ZeroDivisionError: if either loader yields no samples.
    """
    train_losses, val_losses = [], []  # per-epoch average losses
    train_accuracies, val_accuracies = [], []  # per-epoch accuracies

    # Send batches to wherever the model's weights live. The notebook-level
    # `device` is only consulted for a model that has no parameters at all.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')

        # Training phase
        model.train()
        # Fresh accumulators every epoch: carrying them over would turn the
        # accuracy into a running average across epochs and leak the previous
        # epoch's averaged loss into this one.
        train_loss = 0.0
        correct_train, total_train = 0, 0
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by the batch size so a smaller final batch
            # does not skew the epoch average.
            train_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_loss /= total_train
        train_accuracy = correct_train / total_train

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct_val, total_val = 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(run_device), labels.to(run_device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_loss /= total_val
        val_accuracy = correct_val / total_val

        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # Adjust the learning rate based on this epoch's validation loss
        scheduler.step(val_loss)

        print(f'Train Loss: {train_loss:.4f} Train Accuracy: {train_accuracy:.4f} Val Loss: {val_loss:.4f} Val Accuracy: {val_accuracy:.4f}')

    return train_losses, val_losses, train_accuracies, val_accuracies

# Launch training and keep the per-epoch loss / accuracy curves.
(train_losses,
 val_losses,
 train_accuracies,
 val_accuracies) = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=25,
)

# 测试模型
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    all_labels = []
    all_preds = []
    all_probs = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            probs = torch.softmax(outputs, dim=1)  # 获取所有类别的概率
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())  # 保存所有类别的概率
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

    return all_labels, all_preds, all_probs

all_labels, all_preds, all_probs = test_model(model, test_loader)

# Take the class count from the probability vectors the model produced rather
# than hard-coding it, so the plots stay in sync with the model's output size.
num_classes = len(all_probs[0])

# ROC curves, one-vs-rest: each class in turn is treated as the positive class.
fig, ax = plt.subplots()

for i in range(num_classes):
    is_class_i = [1 if label == i else 0 for label in all_labels]
    score_class_i = [prob[i] for prob in all_probs]
    fpr, tpr, _ = roc_curve(is_class_i, score_class_i)
    roc_auc = auc(fpr, tpr)

    ax.plot(fpr, tpr, lw=2, label=f'Class {i} ROC curve (area = {roc_auc:.2f})')

# Diagonal reference line (true positive rate equal to false positive rate).
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic for All Classes')
ax.legend(loc="lower right")
plt.show()

# Confusion matrix. Passing `labels` explicitly keeps the matrix at
# num_classes x num_classes even when some class appears in neither the true
# labels nor the predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues')
plt.show()

ModuleNotFoundError: No module named 'models_mamba'

In [2]:
# Diagnostic cell: print the interpreter's module search path.
# Presumably added to investigate the `ModuleNotFoundError` for `models_mamba`
# raised by the previous cell; a location providing `models_mamba` has to be
# reachable through this list for that import to succeed.
import sys
print(sys.path)

['/root', '/root/miniconda3/lib/python38.zip', '/root/miniconda3/lib/python3.8', '/root/miniconda3/lib/python3.8/lib-dynload', '', '/root/miniconda3/lib/python3.8/site-packages']
