In [6]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import random
import timm
from timm.scheduler.cosine_lr import CosineLRScheduler
import torch.backends.cudnn as cudnn
from torchsampler import ImbalancedDatasetSampler
import os
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, recall_score, precision_score, confusion_matrix


In [7]:
class PCamDataset(Dataset):
    """PCam image patches (and optional labels) read from HDF5 files.

    Images come from the ``'x'`` dataset of ``h5_x_path``. When ``h5_y_path``
    is given, the ``'y'`` dataset of that file is loaded into memory as
    ``self.labels``.

    The image file is opened lazily, once per process. An ``h5py.File`` opened
    in the parent process should not be used from forked ``DataLoader`` workers
    (``num_workers > 0``), and h5py handles cannot be pickled for the ``spawn``
    start method. For both reasons, no handle is kept open after ``__init__``
    and none is pickled.

    Args:
        h5_x_path: Path to the HDF5 file holding images under key ``'x'``.
        h5_y_path: Optional path to the HDF5 file holding labels under key ``'y'``.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, h5_x_path, h5_y_path=None, transform=None):
        self.x_path = h5_x_path
        self.y_path = h5_y_path
        self.transform = transform
        self.has_labels = h5_y_path is not None

        # Per-process image handle; opened on first access by _images().
        self.x_file = None
        self._x_file_pid = None

        # Only the sample count is needed now. Close the file right away so
        # that DataLoader worker processes never inherit an open handle.
        with h5py.File(h5_x_path, 'r') as x_file:
            self.length = len(x_file['x'])

        if self.has_labels:
            # Labels are read fully into memory, so their file can be closed.
            with h5py.File(h5_y_path, 'r') as y_file:
                self.labels = y_file['y'][:]

    def _images(self):
        """Return the ``'x'`` dataset, opening the file if this process has no handle yet."""
        pid = os.getpid()
        if self.x_file is None or self._x_file_pid != pid:
            # A handle opened by another process (e.g. inherited through fork)
            # is not reused; this process opens its own.
            self.x_file = h5py.File(self.x_path, 'r')
            self._x_file_pid = pid
        return self.x_file['x']

    def __getstate__(self):
        """Pickle without the open h5py handle; the receiving process reopens it lazily."""
        state = self.__dict__.copy()
        state['x_file'] = None
        state['_x_file_pid'] = None
        return state

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(image, label)`` if labels were given, otherwise just ``image``.

        ``image`` is the PIL image (after ``transform`` if one is set). ``label``
        is the Python scalar extracted with ``.item()``.
        """
        # Cast to uint8 because ToPILImage infers the image mode from the dtype.
        image = self._images()[idx].astype(np.uint8)
        image = transforms.ToPILImage()(image)
        if self.transform:
            image = self.transform(image)
        if self.has_labels:
            label = self.labels[idx].item()
            return image, label
        return image

    def __del__(self):
        # Close only a handle that this process opened itself.
        x_file = getattr(self, 'x_file', None)
        if x_file is not None and getattr(self, '_x_file_pid', None) == os.getpid():
            x_file.close()

    def get_labels(self):
        """Return the labels flattened to 1-D, or None for an unlabeled dataset.

        NOTE(review): presumably consumed by ``ImbalancedDatasetSampler``
        (imported at the top of the notebook), which looks for ``get_labels``.
        """
        if self.has_labels:
            return self.labels.reshape(-1)
        return None

# Per-channel normalization statistics: (x - 0.5) / 0.5 maps each RGB channel
# from ToTensor's [0, 1] range to [-1, 1].
# NOTE(review): presumably these match the training-time normalization;
# confirm against the training code. This evaluation pipeline applies no
# augmentation (e.g. no RandAugment).
mean, std = [0.5] * 3, [0.5] * 3

# Deterministic evaluation pipeline: PIL image -> normalized float tensor.
eval_transform = transforms.Compose(
    [
        # Resize the shorter side to 96, then take the central 96x96 crop.
        # If the inputs are already 96x96, these steps leave the size unchanged.
        transforms.Resize(96),
        transforms.CenterCrop(96),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean, std),
    ]
)

In [8]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, confusion_matrix

# ===== Test data =====
# The data directory can be overridden with the PCAM_DATA_DIR environment
# variable; the default is the original hard-coded location.
PCAM_DATA_DIR = os.environ.get('PCAM_DATA_DIR', '/home/gotou/Medical/b4us/pcamdata')
CHECKPOINT_PATH = 'bestresnet50.pth'

test_dataset = PCamDataset(
    os.path.join(PCAM_DATA_DIR, 'camelyonpatch_level_2_split_test_x.h5'),
    os.path.join(PCAM_DATA_DIR, 'camelyonpatch_level_2_split_test_y.h5'),
    transform=eval_transform,
)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    Checkpoints saved from an ``nn.DataParallel`` wrapper name every parameter
    ``module.<name>``. Only the leading prefix is removed, so any occurrence of
    ``prefix`` elsewhere in a key is preserved. Keys without the prefix are
    copied unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _predict(model, loader, device):
    """Run ``model`` over ``loader`` and collect per-sample results.

    Returns:
        Tuple ``(labels, preds, probs)`` of 1-D numpy arrays:
        - ``labels``: the ground-truth labels from the loader;
        - ``preds``: the argmax class of the logits;
        - ``probs``: the softmax probability of class 1.

    Raises:
        ValueError: if the loader yields no batches.
    """
    label_parts, pred_parts, prob_parts = [], [], []
    with torch.no_grad():
        for imgs, labels in loader:
            logits = model(imgs.to(device))
            prob_parts.append(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
            pred_parts.append(logits.argmax(dim=1).cpu().numpy())
            # Labels are only needed for the metrics, so they stay on the CPU.
            label_parts.append(labels.numpy())
    if not label_parts:
        raise ValueError('loader yielded no batches')
    return np.concatenate(label_parts), np.concatenate(pred_parts), np.concatenate(prob_parts)


# ===== Build model & load weights =====
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('resnet50', pretrained=False, num_classes=2)

# A state_dict contains only tensors and plain containers, so weights_only=True
# is enough. It also refuses arbitrary pickled objects.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(_strip_module_prefix(state_dict))

model = model.to(device)
model.eval()

# ===== Inference =====
all_labels, all_preds, all_probs = _predict(model, test_loader, device)

# ===== Metrics (positive class = 1) =====
acc = accuracy_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
roc_auc = roc_auc_score(all_labels, all_probs)
cm = confusion_matrix(all_labels, all_preds)

print(f"Accuracy: {acc:.4f}")
print(f"Recall: {recall:.4f}")
print(f"Precision: {precision:.4f}")
print(f"F1-score: {f1:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")
print("Confusion Matrix:\n", cm)


Accuracy: 0.8830
Recall: 0.7999
Precision: 0.9591
F1-score: 0.8723
ROC-AUC: 0.9558
Confusion Matrix:
 [[15833   558]
 [ 3277 13100]]
