In [31]:
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from tqdm import tqdm

# --------------------------
# Dataset 定義
# --------------------------
class PCamDataset(Dataset):
    def __init__(self, h5_x_path, h5_y_path=None, transform=None):
        self.x_path = h5_x_path
        self.y_path = h5_y_path
        self.transform = transform
        self.has_labels = h5_y_path is not None

        with h5py.File(h5_x_path, 'r') as x_file:
            self.length = len(x_file['x'])

        if self.has_labels:
            with h5py.File(h5_y_path, 'r') as y_file:
                self.labels = y_file['y'][:]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with h5py.File(self.x_path, 'r') as x_file:
            image = x_file['x'][idx].astype(np.uint8)

        image = transforms.ToPILImage()(image)
        if self.transform:
            image = self.transform(image)

        if self.has_labels:
            label = self.labels[idx].item()
            return image, label
        else:
            return image


# --------------------------
# パスと設定
# --------------------------
TRAIN_X = 'camelyonpatch_level_2_split_train_x.h5'
TRAIN_Y = 'camelyonpatch_level_2_split_train_y.h5'
VAL_X   = 'valid_x_uncompressed.h5'
VAL_Y   = 'valid_y_uncompressed.h5'
TEST_X  = 'camelyonpatch_level_2_split_test_x.h5'
TEST_Y  = 'camelyonpatch_level_2_split_test_y.h5'

mean = [0.702, 0.538, 0.597]
std = [0.144, 0.181, 0.177]

In [32]:
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(96, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(45),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.05),
    transforms.RandomAffine(degrees=0, translate=(0.15, 0.15)),
    transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.RandomApply([transforms.RandomSolarize(threshold=0.5, p=0.3)], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
])

eval_transform = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

In [33]:
# --------------------------
# Dataloader
# --------------------------
batch_size = 128
train_ds = PCamDataset(TRAIN_X, TRAIN_Y, transform=train_transform)
val_ds   = PCamDataset(VAL_X, VAL_Y, transform=eval_transform)
test_ds  = PCamDataset(TEST_X, TEST_Y, transform=eval_transform)


train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4)

In [34]:
import timm  # pip install timm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('efficientnet_b3', pretrained=True, num_classes=1)
# または
# model = timm.create_model('convnext_small', pretrained=True, num_classes=1)

# カスタムヘッドの追加
in_features = model.classifier.in_features
model.classifier = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.BatchNorm1d(512),
    nn.SiLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 1)
)
model = model.to(device)

In [35]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        
    def forward(self, inputs, targets):
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * bce_loss
        return focal_loss.mean()

criterion = FocalLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-4)

In [36]:
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

scheduler = CosineAnnealingWarmRestarts(
    optimizer, 
    T_0=10,  # 最初の周期のエポック数
    T_mult=2, # 周期の倍率
    eta_min=1e-7
)

In [37]:
def tta_predict(model, image, n_aug=5):
    model.eval()
    with torch.no_grad():
        # 基本変換
        base_tf = eval_transform
        
        # TTA変換リスト
        tta_transforms = [
            base_tf,
            transforms.Compose([base_tf, transforms.RandomHorizontalFlip(p=1.0)]),
            transforms.Compose([base_tf, transforms.RandomVerticalFlip(p=1.0)]),
            transforms.Compose([base_tf, transforms.RandomRotation(30)]),
            transforms.Compose([base_tf, transforms.ColorJitter(brightness=0.2, contrast=0.2)])
        ]
        
        outputs = []
        for tf in tta_transforms[:n_aug]:
            augmented = tf(image)
            output = model(augmented.unsqueeze(0).to(device))
            outputs.append(torch.sigmoid(output).cpu())
            
        return torch.stack(outputs).mean()

In [38]:
def train_one_epoch(epoch):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")
    
    for imgs, labels in progress_bar:
        imgs, labels = imgs.to(device), labels.float().unsqueeze(1).to(device)
        
        # MixUp データ拡張
        if np.random.rand() < 0.5:
            lam = np.random.beta(0.4, 0.4)
            rand_index = torch.randperm(imgs.size(0))
            labels_a = labels
            labels_b = labels[rand_index]
            
            mixed_x = lam * imgs + (1 - lam) * imgs[rand_index]
            outputs = model(mixed_x)
            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
        else:
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # 勾配クリッピング
        optimizer.step()
        scheduler.step()
        
        running_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
    
    return running_loss / len(train_loader)

In [39]:
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision import transforms
import timm
from tqdm import tqdm

# scikit-learnが使えない場合の代替評価関数
def calculate_metrics(outputs, labels):
    outputs = torch.sigmoid(outputs).cpu().numpy()
    preds = (outputs > 0.5).astype(int)
    labels = labels.cpu().numpy()
    
    # 簡易版Accuracy計算
    acc = (preds == labels).mean()
    
    # 簡易版AUC計算（scikit-learnなし）
    pos_prob = outputs[labels == 1]
    neg_prob = outputs[labels == 0]
    auc = (pos_prob[:, None] > neg_prob).mean() if len(pos_prob) > 0 and len(neg_prob) > 0 else 0.5
    
    # 簡易版F1計算
    tp = ((preds == 1) & (labels == 1)).sum()
    fp = ((preds == 1) & (labels == 0)).sum()
    fn = ((preds == 0) & (labels == 1)).sum()
    precision = tp / (tp + fp + 1e-7)
    recall = tp / (tp + fn + 1e-7)
    f1 = 2 * precision * recall / (precision + recall + 1e-7)
    
    return acc, auc, f1

# ...（その他のコードは以前と同じ）...

def evaluate(loader, split='Val'):
    model.eval()
    correct, total = 0, 0
    all_outputs = []
    all_labels = []
    
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc=f"Evaluating {split}"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            preds = (torch.sigmoid(outputs) > 0.5).long().squeeze()
            
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_outputs.append(outputs.cpu())
            all_labels.append(labels.cpu())
    
    # メトリクス計算
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    acc, auc, f1 = calculate_metrics(all_outputs, all_labels)
    
    print(f"{split} Metrics - Acc: {acc*100:.2f}% | AUC: {auc:.4f} | F1: {f1:.4f}")
    return acc

# ...（その他のコードは変更なし）...

In [40]:
def evaluate_with_tta(loader, split='Test'):
    model.eval()
    correct, total = 0, 0
    
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc=f"TTA {split}"):
            imgs, labels = imgs.to(device), labels.to(device)
            batch_preds = []
            
            for img in imgs:
                img_pil = transforms.ToPILImage()(img.cpu())
                pred = tta_predict(model, img_pil)
                batch_preds.append(pred)
            
            preds = (torch.stack(batch_preds) > 0.5).long().squeeze()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    acc = correct / total
    print(f"TTA {split} Accuracy: {acc * 100:.2f}%")
    return acc

In [41]:
# メイン訓練ループ
# --------------------------
def main():
    best_val_acc = 0
    early_stop_counter = 0
    patience = 7
    
    for epoch in range(1, 50):
        train_loss = train_one_epoch(epoch)
        val_acc = evaluate(val_loader, 'Val')
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
            early_stop_counter = 0
            print(f"New best model saved. Val Acc: {val_acc*100:.2f}%")
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break
    
    # 最終評価
    model.load_state_dict(torch.load('best_model.pth'))
    test_acc = evaluate(test_loader, 'Test')
    tta_acc = evaluate_with_tta(test_loader)
    
    print("\nFinal Results:")
    print(f"Standard Test Accuracy: {test_acc*100:.2f}%")
    print(f"TTA Test Accuracy: {tta_acc*100:.2f}%")

if __name__ == '__main__':
    main()

Epoch 1: 100%|██████████| 2048/2048 [05:53<00:00,  5.80it/s, loss=0.0305, lr=1.01e-5]
Evaluating Val: 100%|██████████| 256/256 [00:11<00:00, 23.15it/s]


Val Metrics - Acc: 50.01% | AUC: 0.8358 | F1: 0.4451
New best model saved. Val Acc: 50.01%


Epoch 2:  13%|█▎        | 259/2048 [00:42<04:52,  6.12it/s, loss=0.0322, lr=2.68e-6]


KeyboardInterrupt: 