In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random

#pytorch
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

#pyroch_lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

#huggingface
from datasets import load_dataset

In [None]:
# Fix the global random seed through Lightning for reproducible runs.
SEED = 42
pl.seed_everything(SEED)

# Preferred compute device (CUDA when available); note that the Lightning
# Trainer below selects its own accelerator independently of this variable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the PatchCamelyon dataset (all splits; indexed by split name later on).
DATASET_ID = "1aurent/PatchCamelyon"
dataset = load_dataset(DATASET_ID)

# Number of rows per split (displayed as the cell's output).
dataset.num_rows

In [None]:
# Per-channel pixel mean/std of the training images (used for normalization).
# The stored values come from a previous full run of `compute_channel_stats` on
# the streamed train split (streaming keeps RAM usage low). Set
# RECOMPUTE_STATS = True to reproduce them; this iterates over every image.
RECOMPUTE_STATS = False


def compute_channel_stats(samples):
    """Compute per-channel pixel mean and population std over a stream of images.

    Per-image statistics are merged into running totals with the parallel
    variance update (running mean + sum of squared deviations), so only one
    image is held in memory at a time.

    Args:
        samples: Iterable of dicts whose ``"image"`` entry converts via
            ``np.array`` to an ``(H, W, 3)`` array with values in 0..255.

    Returns:
        Tuple ``(mean, std)`` of float64 arrays of shape ``(3,)`` in [0, 1] units.

    Raises:
        ValueError: If ``samples`` contains no pixels at all.
    """
    mean = np.zeros(3, dtype=np.float64)
    m2 = np.zeros(3, dtype=np.float64)  # running sum of squared deviations
    n_pixels = 0

    for sample in samples:
        img = np.array(sample["image"], dtype=np.float32) / 255.0  # (H, W, C)
        # Accumulate in float64 to limit rounding error over millions of pixels.
        pixels = img.reshape(-1, 3).astype(np.float64)  # (H*W, C)
        n_batch = len(pixels)
        if n_batch == 0:
            continue
        batch_mean = pixels.mean(axis=0)
        batch_m2 = pixels.var(axis=0) * n_batch

        # Merge this image's statistics into the running totals.
        old_n = n_pixels
        n_pixels += n_batch
        delta = batch_mean - mean
        mean += delta * n_batch / n_pixels
        m2 += batch_m2 + delta**2 * old_n * n_batch / n_pixels

    if n_pixels == 0:
        raise ValueError("no pixels found in samples")
    return mean, np.sqrt(m2 / n_pixels)


if RECOMPUTE_STATS:
    _train_stream = load_dataset("1aurent/PatchCamelyon", split="train", streaming=True)
    PCAM_MEAN, PCAM_STD = compute_channel_stats(
        tqdm(_train_stream, desc="Computing mean/std")
    )
else:
    # Results of a previous full run over the train split.
    PCAM_MEAN = np.array([0.70075713, 0.5383597, 0.69162006])
    PCAM_STD = np.array([0.23497961, 0.27741011, 0.21289514])

print("Mean:", PCAM_MEAN)
print("Std :", PCAM_STD)

In [None]:
# ====== 1. Dataset definition ======
class PatchCamelyonDataModule(pl.LightningDataModule):
    """LightningDataModule over the Hugging Face PatchCamelyon splits.

    Each sample is converted lazily (via ``with_transform``) into a normalized
    image tensor under ``"image"`` and a Python ``int`` under ``"label"``.

    Args:
        dataset: Mapping with ``"train"``, ``"valid"`` and ``"test"`` splits,
            each exposing ``with_transform`` (e.g. the result of ``load_dataset``).
        batch_size: Samples per batch for every DataLoader.
        num_workers: Worker processes per DataLoader.
        mean: Per-channel mean for ``transforms.Normalize``; ``None`` uses
            ``DEFAULT_MEAN``.
        std: Per-channel std for ``transforms.Normalize``; ``None`` uses
            ``DEFAULT_STD``.
    """

    # Per-channel statistics computed over the PatchCamelyon training images.
    DEFAULT_MEAN = (0.70075713, 0.5383597, 0.69162006)
    DEFAULT_STD = (0.23497961, 0.27741011, 0.21289514)

    def __init__(self, dataset, batch_size=32, num_workers=6, mean=None, std=None):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=list(self.DEFAULT_MEAN if mean is None else mean),
                std=list(self.DEFAULT_STD if std is None else std),
            ),
        ])

    def setup(self, stage=None):
        """Attach the image/label transform to every split (applied on access)."""

        # Batch-wise transform: `examples` maps column name -> list of values.
        # Only images go through `self.transform`; labels are cast to int.
        def transform_fn(examples):
            examples["image"] = [self.transform(img) for img in examples["image"]]
            examples["label"] = [int(label) for label in examples["label"]]
            return examples

        self.train_dataset = self.dataset["train"].with_transform(transform_fn)
        self.val_dataset = self.dataset["valid"].with_transform(transform_fn)
        self.test_dataset = self.dataset["test"].with_transform(transform_fn)

    def _make_loader(self, split_dataset, shuffle):
        """Build a DataLoader with this module's batch size and worker count."""
        return DataLoader(
            split_dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

In [None]:
# Build a datamodule for inspection and get the transformed training split.
datamodule = PatchCamelyonDataModule(dataset, batch_size=256)
datamodule.setup("fit")

train_loader = datamodule.train_dataloader()

# The dataset behind the DataLoader; indexing it yields one transformed sample dict.
train_dataset = train_loader.dataset

# The images are standardized by transforms.Normalize. Undo that before plotting:
# imshow clips float RGB values outside [0, 1], which distorts the colours.
normalize = next(
    t for t in datamodule.transform.transforms if isinstance(t, transforms.Normalize)
)
norm_mean = torch.as_tensor(normalize.mean).view(-1, 1, 1)
norm_std = torch.as_tensor(normalize.std).view(-1, 1, 1)

# Pick 10 random samples.
indices = random.sample(range(len(train_dataset)), 10)
samples = [train_dataset[i] for i in indices]

# Class names, indexed by label value.
label_names = ["Non-tumor", "Tumor"]

# Show images with their labels.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for ax, sample in zip(axes.flatten(), samples):
    img = (sample["image"] * norm_std + norm_mean).clamp(0, 1)
    img = img.permute(1, 2, 0).numpy()  # CHW -> HWC
    label = sample["label"]
    ax.imshow(img)
    ax.set_title(f"Label: {label_names[label]}")
    ax.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# ====== 2. Model definition ======
class ResNet50LitModel(pl.LightningModule):
    """ImageNet-pretrained ResNet-50 fine-tuned as an image classifier.

    Args:
        lr: Learning rate for Adam.
        num_classes: Size of the classification head. PatchCamelyon is binary
            (non-tumor / tumor), hence the default of 2.
    """

    def __init__(self, lr=1e-3, num_classes=2):
        super().__init__()
        self.save_hyperparameters()
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        # Replace the ImageNet head with one sized for this task; extra unused
        # logits would let argmax predict classes that do not exist.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute loss and accuracy for one batch, log them, and return the loss.

        Logs ``f"{stage}_loss"`` and ``f"{stage}_acc"``. ``batch`` is a dict with
        ``"image"`` and ``"label"`` entries, as produced by the datamodule.
        """
        x = batch["image"]
        y = batch["label"]
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # Give Lightning the batch size explicitly: the batch is a dict holding
        # several tensors, so inferring it is ambiguous.
        batch_size = y.size(0)
        self.log(f"{stage}_loss", loss, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_acc", acc, prog_bar=True, batch_size=batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Adam with a step decay that halves the learning rate every 10 epochs."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return [optimizer], [scheduler]

In [None]:
# ====== 3. Run training ======

# Data and model
data_module = PatchCamelyonDataModule(batch_size=256, dataset=dataset)
model = ResNet50LitModel(lr=1e-3)

# ====== Callbacks and logger ======
# (1) Automatic checkpointing of the best model
checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",       # metric to monitor (validation accuracy)
    mode="max",              # higher is better
    save_top_k=1,            # keep only the single best checkpoint
    filename="resnet50-{epoch:02d}-{val_acc:.2f}"
)

# (2) Learning-rate monitoring
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# (3) Early stopping
early_stop_callback = EarlyStopping(
    monitor="val_loss",      # metric to monitor
    mode="min",              # lower is better
    patience=3,              # stop after 3 validation checks without improvement
    verbose=True
)

# (4) TensorBoard logger
logger = TensorBoardLogger(
    save_dir="lightning_logs",   # root log directory
    name="resnet50"              # experiment name (sub-folder)
)

# ====== Trainer ======
trainer = Trainer(
    max_epochs=20,
    accelerator="auto",    # use a GPU automatically when one is available
    # One device on either GPU or CPU. `None` is not an accepted value for
    # `devices` in Lightning 2.x, so it must not be passed on CPU-only machines.
    devices=1,
    precision=16 if torch.cuda.is_available() else 32,  # mixed precision on GPU for speed
    callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
    logger=logger,         # enable the TensorBoard logger
)

# ====== Training ======
trainer.fit(model, datamodule=data_module)