In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import h5py
import numpy as np
from tqdm import tqdm
import pandas as pd
import os

# ----------------------------
# ハイパーパラメータ設定
# ----------------------------
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
EPOCHS = 50
PATIENCE = 10

# ----------------------------
# Dataset定義（修正版）
# ----------------------------
class PCamDataset(Dataset):
    def __init__(self, h5_x_path, h5_y_path=None, transform=None):
        self.x_path = h5_x_path
        self.y_path = h5_y_path
        self.transform = transform
        self.has_labels = h5_y_path is not None

        with h5py.File(h5_x_path, 'r') as x_file:
            self.length = len(x_file['x'])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with h5py.File(self.x_path, 'r') as x_file:
            image = x_file['x'][idx]

        image = transforms.ToPILImage()(image.astype(np.uint8))

        if self.transform:
            image = self.transform(image)

        if self.has_labels:
            with h5py.File(self.y_path, 'r') as y_file:
                # 修正箇所: item()メソッドを使って明示的にスカラー値を取得
                label = y_file['y'][idx].item()  # .item()を使う
            return image, label
        else:
            return image

# ----------------------------
# データ変換
# ----------------------------
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(96, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_val_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ----------------------------
# データローダ-
# ----------------------------
train_dataset = PCamDataset('camelyonpatch_level_2_split_train_x.h5',
                          'camelyonpatch_level_2_split_train_y.h5',
                          transform=transform_train)

val_dataset = PCamDataset('valid_x_uncompressed.h5',
                         'valid_y_uncompressed.h5',
                         transform=transform_val_test)

test_dataset = PCamDataset('camelyonpatch_level_2_split_test_x.h5',
                         'camelyonpatch_level_2_split_test_y.h5',  # テストラベルを追加
                         transform=transform_val_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE*4, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE*4, shuffle=False, num_workers=4, pin_memory=True)


# ----------------------------
# モデル構築
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PCamResNet50(nn.Module):
    def __init__(self):
        super(PCamResNet50, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        
        # 特徴抽出器の凍結（最初の数層のみ）
        for param in list(self.resnet.parameters())[:100]:
            param.requires_grad = False
            
        # 最終層を置換
        self.resnet.fc = nn.Sequential(
            nn.Linear(self.resnet.fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1)
        )
        
    def forward(self, x):
        return self.resnet(x)

model = PCamResNet50().to(device)

# ----------------------------
# 損失関数・最適化
# ----------------------------
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1, verbose=True)

# ----------------------------
# 学習ループ（修正版）
# ----------------------------
def train(model, train_loader, val_loader, criterion, optimizer, epochs=EPOCHS, patience=PATIENCE):
    best_val_loss = float('inf')
    best_epoch = 0
    no_improve = 0
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        # 学習ループ
        for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            imgs = imgs.to(device)
            labels = labels.to(device).float().unsqueeze(1)  # 形状を [batch_size, 1] に
            
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * imgs.size(0)
        
        # 検証ループ
        val_loss, val_metrics = evaluate(model, val_loader, criterion)
        train_loss /= len(train_loader.dataset)
        
        # 学習率スケジューリング
        scheduler.step(val_loss)
        
        # 早期停止チェック
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            no_improve = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            no_improve += 1
        
        # メトリクス表示
        print(f"\nEpoch {epoch+1}/{epochs}:")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
        print(f"Val Precision: {val_metrics['precision']:.4f}")
        print(f"Val Recall: {val_metrics['recall']:.4f}")
        print(f"Val F1: {val_metrics['f1']:.4f}")
        print(f"Val AUC: {val_metrics['auc']:.4f}")
        
        if no_improve >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}. Best epoch was {best_epoch+1}.")
            break
    
    # 最良モデルのロード
    model.load_state_dict(torch.load('best_model.pth'))
    return model

# ----------------------------
# テストセット評価関数の追加
# ----------------------------
def evaluate_test_set(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Evaluating Test Set"):
            imgs = imgs.to(device)
            labels = labels.to(device).float().unsqueeze(1)
            
            outputs = model(imgs)
            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = (probs > 0.5).astype(int)
            
            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    
    # メトリクス計算
    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds),
        'recall': recall_score(all_labels, all_preds),
        'f1': f1_score(all_labels, all_preds),
        'auc': roc_auc_score(all_labels, all_probs)  # 確率を使用
    }
    
    return metrics

def evaluate(model, loader, criterion):
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0
    
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device).float().unsqueeze(1)  # 形状を [batch_size, 1] に
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * imgs.size(0)
            
            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = (probs > 0.5).astype(int)
            
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    
    # メトリクス計算
    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds),
        'recall': recall_score(all_labels, all_preds),
        'f1': f1_score(all_labels, all_preds),
        'auc': roc_auc_score(all_labels, all_preds)
    }
    
    return total_loss / len(loader.dataset), metrics

# ----------------------------
# モデル訓練（変更なし）
# ----------------------------
model = train(model, train_loader, val_loader, criterion, optimizer)

# ----------------------------
# 検証セット評価（変更なし）
# ----------------------------
_, val_metrics = evaluate(model, val_loader, criterion)
print("\nValidation Metrics:")
print(f"Accuracy: {val_metrics['accuracy']:.4f}")
print(f"Precision: {val_metrics['precision']:.4f}")
print(f"Recall: {val_metrics['recall']:.4f}")
print(f"F1 Score: {val_metrics['f1']:.4f}")
print(f"ROC AUC: {val_metrics['auc']:.4f}")

# ----------------------------
# テストセット評価（追加）
# ----------------------------
test_metrics = evaluate_test_set(model, test_loader)
print("\nTest Set Metrics:")
print(f"Accuracy: {test_metrics['accuracy']:.4f}")
print(f"Precision: {test_metrics['precision']:.4f}")
print(f"Recall: {test_metrics['recall']:.4f}")
print(f"F1 Score: {test_metrics['f1']:.4f}")
print(f"ROC AUC: {test_metrics['auc']:.4f}")

# ----------------------------
# 混同行列の表示（追加）
# ----------------------------
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# テストセットの予測と真のラベルを取得
test_preds = []
test_labels = []
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        outputs = model(imgs)
        probs = torch.sigmoid(outputs).cpu().numpy()
        preds = (probs > 0.5).astype(int)
        test_preds.extend(preds)
        test_labels.extend(labels.numpy())

# 混同行列の計算と表示
cm = confusion_matrix(test_labels, test_preds)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Negative', 'Positive'], 
            yticklabels=['Negative', 'Positive'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

Epoch 1: 100%|██████████| 4096/4096 [03:59<00:00, 17.09it/s]



Epoch 1/50:
Train Loss: 0.2781
Val Loss: 0.3265
Val Accuracy: 0.8524
Val Precision: 0.9221
Val Recall: 0.7694
Val F1: 0.8389
Val AUC: 0.8523


Epoch 2: 100%|██████████| 4096/4096 [03:57<00:00, 17.23it/s]



Epoch 2/50:
Train Loss: 0.2238
Val Loss: 0.3676
Val Accuracy: 0.8589
Val Precision: 0.9397
Val Recall: 0.7666
Val F1: 0.8444
Val AUC: 0.8588


Epoch 3:  28%|██▊       | 1161/4096 [01:08<02:54, 16.83it/s]


KeyboardInterrupt: 