In [1]:
from d2l import torch as d2l
import utils as u
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
import seaborn as sns
from sklearn.preprocessing import LabelBinarizer
from torchvision import transforms
from PIL import Image
import numpy as np


In [None]:
img_size, patch_size, batch_size = 128, 16, 64
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
SEED = 42

# Seed torch and NumPy up front so random draws (weight init, dropout,
# any shuffling done through these RNGs) repeat across Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

# ISIC data module. Use the same square (img_size, img_size) resize as the
# training cell so the preview matches what the model actually receives.
# NOTE(review): an int `resize` presumably goes to torchvision Resize, which
# keeps the aspect ratio (non-square output) — confirm in utils.ISICDataModule.
isic_data = u.ISICDataModule(batch_size, resize=(img_size, img_size))

# Training and test dataloaders
train_loader = isic_data.get_dataloader(train=True)
test_loader = isic_data.get_dataloader(train=False)

# Preview one batch of training images
batch = next(iter(train_loader))
isic_data.visualize(batch, nrows=1, ncols=9)


In [None]:
# Build the ViT classifier and fit it (max_epochs=10, num_gpus=1).
# NOTE(review): if u.ViT mirrors d2l's ViT it also accepts `num_classes`
# (d2l default: 10), while ISIC here is binary (benign/malignant) — confirm
# the classification head size.
model = u.ViT(
    img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
    num_blks, emb_dropout, blk_dropout, lr,
)
trainer = u.Trainer(max_epochs=10, num_gpus=1)
# Training data, resized to a square img_size x img_size
data = u.ISICDataModule(batch_size, resize=(img_size, img_size))
trainer.fit(model, data)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

model.eval()  # evaluation mode: disables the emb/blk dropout used in training
# Validation DataLoader. Resize to the same square (img_size, img_size) shape
# used for training; an int resize may preserve aspect ratio (as
# torchvision.transforms.Resize does) and yield inputs whose patch count does
# not match the ViT's positional embedding.
# NOTE(review): confirm how utils.ISICDataModule interprets `resize`.
data_module = u.ISICDataModule(batch_size, resize=(img_size, img_size))
val_loader = data_module.get_dataloader(train=False)

In [5]:

# Compute classification accuracy
def accuracy(model, dataloader, device=None):
    """Evaluate the classification accuracy of ``model`` over ``dataloader``.

    The model is put in eval mode for the pass, and its previous
    train/eval mode is restored afterwards.

    Args:
        model: a ``torch.nn.Module`` whose output has per-class scores along
            dim 1; the predicted class is the argmax over that dim.
        dataloader: iterable yielding ``(inputs, labels)`` tensor pairs.
        device: device to run on. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        ``(accuracy, all_preds, all_labels)``. ``accuracy`` is the fraction
        of correct predictions. The two lists hold predicted and true class
        indices, in dataloader order.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct, total = 0, 0
    all_preds, all_labels = [], []

    was_training = model.training
    model.eval()  # dropout etc. must be off, or accuracy is noisy and biased
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                predicted = model(inputs).argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

                all_preds.extend(predicted.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
    finally:
        model.train(was_training)  # leave the caller's mode untouched

    if total == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")
    return correct / total, all_preds, all_labels

# Plot the confusion matrix
def plot_confusion_matrix(y_true, y_pred, labels=('benign', 'malignant')):
    """Draw a confusion-matrix heatmap for integer class predictions.

    Args:
        y_true: true class indices.
        y_pred: predicted class indices, aligned with ``y_true``.
        labels: display names for classes ``0 .. len(labels) - 1``. They are
            used as tick labels on both axes. A tuple default avoids a shared
            mutable default argument; any sequence works.
    """
    # Pin the class set so the matrix is always len(labels) x len(labels),
    # even if a class never appears in y_true or y_pred. Otherwise the tick
    # labels would no longer line up with the matrix rows/columns.
    # Note: samples whose class index is outside this range are not counted.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap='Blues',
                xticklabels=list(labels), yticklabels=list(labels), ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    plt.show()

# Plot the ROC curve
def plot_roc_curve(y_true, y_pred_prob):
    """Plot the ROC curve of binary scores and label it with its AUC.

    Args:
        y_true: true binary labels.
        y_pred_prob: positive-class score per sample, aligned with ``y_true``
            (for example, the softmax probability of class 1).
    """
    false_pos, true_pos, _thresholds = roc_curve(y_true, y_pred_prob)
    area = auc(false_pos, true_pos)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(false_pos, true_pos, color='b', lw=2,
            label=f'ROC curve (area = {area:0.2f})')
    ax.plot([0, 1], [0, 1], color='gray', linestyle='--')  # chance diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc='lower right')
    plt.show()

In [None]:
# Validation accuracy, plus the per-sample predictions/labels that the
# confusion-matrix cell below reuses
accuracy_val, all_preds, all_labels = accuracy(model, val_loader)
print("Validation Accuracy: {:.2f}%".format(accuracy_val * 100))

In [None]:
# Confusion matrix of the validation predictions (class 0 = benign, 1 = malignant)
plot_confusion_matrix(y_true=all_labels, y_pred=all_preds)

In [None]:
# ROC needs continuous scores, not hard predictions. The model outputs one
# score per class, so apply softmax and keep the probability of class 1.
# Labels are collected in this same pass, so each probability stays paired
# with its own sample even if val_loader shuffles or applies random
# augmentation. Reusing `all_labels` from the earlier accuracy pass could
# silently misalign the two lists.
model.eval()  # make sure dropout is off for scoring
all_preds_prob, roc_labels = [], []
with torch.no_grad():
    for inputs, labels in val_loader:
        probs = torch.softmax(model(inputs.to(device)), dim=1)[:, 1]  # P(class 1)
        all_preds_prob.extend(probs.cpu().tolist())
        roc_labels.extend(labels.tolist())

plot_roc_curve(roc_labels, all_preds_prob)
