In [1]:
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, f1_score
import torch.nn.functional as F
import functools
from torch.optim import Adam
import torchvision.transforms as transforms
import os
import random
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, recall_score, precision_score, confusion_matrix


In [3]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import random

class PCamDataset(Dataset):
    def __init__(self, h5_x_path, h5_y_path=None, transform=None):
        self.x_path = h5_x_path
        self.y_path = h5_y_path
        self.transform = transform
        self.has_labels = h5_y_path is not None

        # h5pyファイルを開きっぱなしにする
        self.x_file = h5py.File(h5_x_path, 'r')
        self.length = len(self.x_file['x'])

        if self.has_labels:
            self.y_file = h5py.File(h5_y_path, 'r')
            self.labels = self.y_file['y'][:]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        image = self.x_file['x'][idx].astype(np.uint8)  # [H,W,C], uint8 expected

        # PIL画像に変換しaugmentation適用
        image = transforms.ToPILImage()(image)

        if self.transform:
            image = self.transform(image)

        if self.has_labels:
            label = self.labels[idx].item() 
            return image, label
        else:
            return image

    def __del__(self):
        # ファイルを閉じる
        if hasattr(self, 'x_file'):
            self.x_file.close()
        if self.has_labels and hasattr(self, 'y_file'):
            self.y_file.close()

# 平均・標準偏差（ImageNet準拠）
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(96, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(45),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.05),
    transforms.RandomAffine(degrees=0, translate=(0.15, 0.15)),
    transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.RandomApply([transforms.RandomSolarize(threshold=0.5, p=0.3)], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
])

eval_transform = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# データセット・ローダー作成
train_dataset = PCamDataset(
    '/home/gotou/Medical/camelyonpatch_level_2_split_train_x.h5',
    '/home/gotou/Medical/camelyonpatch_level_2_split_train_y.h5',
    transform=train_transform
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)  # バッチサイズ64に増やす

val_dataset = PCamDataset(
    '/home/gotou/Medical/valid_x_uncompressed.h5',
    '/home/gotou/Medical/valid_y_uncompressed.h5',
    transform=eval_transform
)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

# 再現性確保
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # lrを1e-4に増やす

# Cosine Annealing LR scheduler を導入
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

best_acc = 0
train_losses = []
val_accuracies = []

num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for imgs, labels in loop:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} Training Loss: {avg_train_loss:.4f}")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    acc = 100 * correct / total
    print(f"Validation accuracy after epoch {epoch+1}: {acc:.2f}%")

    train_losses.append(avg_train_loss)
    val_accuracies.append(acc)

    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), 'bestresnet50.pth')
        print("Best model saved.")

    scheduler.step()



Epoch 1/20: 100%|██████████| 4096/4096 [05:32<00:00, 12.33it/s, loss=0.373]

Epoch 1 Training Loss: 0.3205





Validation accuracy after epoch 1: 86.10%
Best model saved.


Epoch 2/20: 100%|██████████| 4096/4096 [05:38<00:00, 12.11it/s, loss=0.215] 

Epoch 2 Training Loss: 0.2565





Validation accuracy after epoch 2: 88.42%
Best model saved.


Epoch 3/20: 100%|██████████| 4096/4096 [05:46<00:00, 11.82it/s, loss=0.27]  

Epoch 3 Training Loss: 0.2337





Validation accuracy after epoch 3: 87.12%


Epoch 4/20: 100%|██████████| 4096/4096 [05:36<00:00, 12.18it/s, loss=0.192] 

Epoch 4 Training Loss: 0.2174





Validation accuracy after epoch 4: 86.89%


Epoch 5/20:   2%|▏         | 72/4096 [00:06<05:57, 11.24it/s, loss=0.246] 


KeyboardInterrupt: 