In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random

#pytorch
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

# pytorch_lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

#huggingface
from datasets import load_dataset

# Fix the global seed through Lightning so repeated runs start from the same RNG state.
pl.seed_everything(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Load the PatchCamelyon dataset through Hugging Face `datasets`.
dataset = load_dataset(path="1aurent/PatchCamelyon")
# Leave the row counts as the cell's displayed value.
dataset.num_rows

In [None]:
# ====== 1️⃣ Dataset definition ======
class PatchCamelyonDataModule(pl.LightningDataModule):
    """LightningDataModule serving the train / valid / test splits of a dataset.

    Args:
        dataset: Mapping from split name to dataset. It must contain the
            ``'train'``, ``'valid'`` and ``'test'`` keys, and each split must
            support ``with_transform``.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
            Defaults to 6, the value that used to be hard-coded.
    """

    def __init__(self, dataset, batch_size=32, num_workers=6):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        # NOTE(review): the mean/std constants are presumably per-channel
        # statistics of the training images — confirm their origin.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.70075713, 0.5383597, 0.69162006], std=[0.23497961, 0.27741011, 0.21289514]),
        ])

    def setup(self, stage=None):
        """Attach the transform to all three splits (``stage`` is not inspected)."""

        # Applies the image transform to the "image" column only and casts
        # every label to a plain int.
        def transform_fn(examples):
            examples["image"] = [self.transform(img) for img in examples["image"]]
            examples["label"] = [int(label) for label in examples["label"]]
            return examples

        self.train_dataset = self.dataset['train'].with_transform(transform_fn)
        self.val_dataset = self.dataset['valid'].with_transform(transform_fn)
        self.test_dataset = self.dataset['test'].with_transform(transform_fn)

    def _make_loader(self, split, shuffle=False):
        """Build a DataLoader for ``split`` with the module-wide batch size and worker count."""
        return DataLoader(
            split,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return self._make_loader(self.test_dataset)

    def predict_dataloader(self):
        """Unshuffled loader over the test split, used for prediction."""
        return self._make_loader(self.test_dataset)

In [None]:
# ====== 2️⃣ Model definition ======
class ResNet50LitModel(pl.LightningModule):
    """ResNet-50 (ImageNet-1K V2 weights) with a replaced final linear layer.

    Args:
        lr: Learning rate passed to Adam.
        num_classes: Output width of the replacement ``fc`` layer. The default
            of 10 is the value that used to be hard-coded; it is kept so that
            checkpoints saved with the 10-way head still load.
            NOTE(review): the notebook only names two classes, so 2 is
            presumably the right value for new training runs — confirm.
    """

    def __init__(self, lr=1e-3, num_classes=10):
        super().__init__()
        # Stores the constructor arguments under self.hparams (lr is read back
        # in configure_optimizers).
        self.save_hyperparameters()
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        # Swap the classifier head for one with the requested output width.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the raw outputs of the wrapped network for ``x``."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Forward ``batch``, log ``{stage}_loss`` and ``{stage}_acc``, return the loss.

        ``batch`` is indexed with the keys ``"image"`` and ``"label"``.
        """
        x = batch["image"]
        y = batch["label"]
        preds = self(x)
        loss = self.criterion(preds, y)
        # Fraction of samples whose arg-max output equals the label.
        acc = (preds.argmax(dim=1) == y).float().mean()
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_acc`` and return the loss for backpropagation."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc``; returns nothing."""
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc``; returns nothing."""
        self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the raw model outputs for the images in ``batch``."""
        return self(batch["image"])

    def configure_optimizers(self):
        """Adam at ``hparams.lr`` with a StepLR schedule (step_size=10, gamma=0.5)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return [optimizer], [scheduler]

In [None]:
# ====== 3️⃣ Evaluation run ======

# Build the data module and set it up for the test stage.
data_module = PatchCamelyonDataModule(dataset=dataset, batch_size=256)
data_module.setup(stage="test")

# Path of the best checkpoint.
best_model_path = "lightning_logs/resnet50/version_8/checkpoints/resnet50-epoch=00-val_acc=0.84.ckpt"

# Restore the model from that checkpoint.
model = ResNet50LitModel.load_from_checkpoint(checkpoint_path=best_model_path)

# Run the test loop on a single device.
trainer = Trainer(devices=1, accelerator="auto")
trainer.test(model=model, datamodule=data_module)

In [None]:
# Get predictions from the model via the data module.
# NOTE(review): `preds` is rebound by the sampling cell below before being read.
preds = trainer.predict(model=model, datamodule=data_module)

In [None]:
from torch.utils.data import Subset

n_samples = 20  # how many test images to draw

# The data module exposes the transformed test split directly.
test_dataset = data_module.test_dataset

# ====== Draw n_samples distinct random indices ======
indices = random.sample(range(len(test_dataset)), k=n_samples)
subset = Subset(dataset=test_dataset, indices=indices)
# batch_size equals the subset size, so the subset arrives as one batch.
subset_loader = DataLoader(subset, num_workers=6, shuffle=False, batch_size=n_samples)

# ====== Run prediction on the subset ======
preds = trainer.predict(model, dataloaders=subset_loader)
# Arg-max over the class dimension of the first (only) batch, moved to the CPU.
pred_labels = torch.argmax(preds[0], dim=1).cpu()

# ===== Undo the input normalisation =====
def denormalize(img, mean=(0.70075713, 0.5383597, 0.69162006), std=(0.23497961, 0.27741011, 0.21289514)):
    """Invert a channel-wise normalisation and clip the result to [0, 1].

    Args:
        img: Tensor whose LAST dimension is the channel axis (for example
            H x W x C); the statistics broadcast against that dimension.
        mean: Per-channel mean to add back. Defaults to the values that used
            to be hard-coded in this function.
        std: Per-channel standard deviation to multiply back. Same default
            policy as ``mean``.

    Returns:
        ``img * std + mean`` clamped to [0, 1], on the same device as ``img``.
    """
    # Build the statistics on img's device so a non-CPU input does not hit a
    # device-mismatch error.
    mean_t = torch.as_tensor(mean, device=img.device)
    std_t = torch.as_tensor(std, device=img.device)
    return (img * std_t + mean_t).clamp(0, 1)  # clip to 0..1

# Class names, indexed by label id
classes = ["Non-tumor", "Tumor"]

# Visualisation: one panel per sampled image, titled with ground truth and prediction
n_cols = (n_samples + 1) // 2  # ceil(n_samples / 2), so an odd n_samples still fits in two rows
plt.figure(figsize=(15, 4))
for i in range(n_samples):
    sample = subset[i]  # index once; each access presumably re-runs the on-access transform
    image = sample["image"]
    label = int(sample["label"])
    pred_label = int(pred_labels[i])  # plain int instead of a 0-dim tensor
    # Channels-last layout for matplotlib, then undo the normalisation.
    img = denormalize(image.permute(1, 2, 0))
    plt.subplot(2, n_cols, i + 1)
    plt.imshow(img)
    plt.axis("off")
    # The model head may have more outputs than there are named classes;
    # fall back to a generic name instead of raising IndexError.
    pred_name = classes[pred_label] if pred_label < len(classes) else f"class {pred_label}"
    plt.title(
        f"GT: {classes[label]}\nPred: {pred_name}",
        color=("green" if label == pred_label else "red")
    )
plt.tight_layout()
plt.show()

In [None]:
# Grad-CAM imports
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Grad-CAM setup: explain the wrapped network using the last block of layer4.
target_layers = [model.model.layer4[-1]]
cam = GradCAM(model=model.model, target_layers=target_layers)

# Send inputs to the device the model is actually on. The global `device` only
# says whether CUDA is available, and Lightning may leave the module on CPU
# after a Trainer run.
model_device = next(model.parameters()).device

# Visualisation: one Grad-CAM overlay per sampled image
n_cols = (n_samples + 1) // 2  # ceil(n_samples / 2), so an odd n_samples still fits in two rows
plt.figure(figsize=(15, 4))
for i in range(n_samples):
    sample = subset[i]  # index once; each access presumably re-runs the on-access transform
    image = sample["image"]
    label = int(sample["label"])
    pred_label = int(pred_labels[i])  # plain int class index for the CAM target
    # Explain the class the model actually predicted.
    targets = [ClassifierOutputTarget(pred_label)]
    # Channels-last, de-normalised image in [0, 1] as the overlay background.
    img = denormalize(image.permute(1, 2, 0)).cpu().numpy()
    input_tensor = image.unsqueeze(0).to(model_device)
    # The CAM call returns one map per batch item; take the single item.
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
    plt.subplot(2, n_cols, i + 1)
    plt.imshow(cam_image)
    plt.axis("off")
    # The model head may have more outputs than there are named classes;
    # fall back to a generic name instead of raising IndexError.
    pred_name = classes[pred_label] if pred_label < len(classes) else f"class {pred_label}"
    plt.title(
        f"GT: {classes[label]}\nPred: {pred_name}",
        color=("green" if label == pred_label else "red")
    )
plt.tight_layout()
plt.show()

In [None]:
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Grad-CAM for a single sample of the subset.
i = 0
sample = subset[i]  # index once; each access presumably re-runs the on-access transform
image = sample["image"]
label = int(sample["label"])
pred_label = int(pred_labels[i])  # plain int class index for the CAM target
targets = [ClassifierOutputTarget(pred_label)]
# Channels-last, de-normalised image in [0, 1] as the overlay background.
img = denormalize(image.permute(1, 2, 0)).cpu().numpy()
# Send the input to the device the model is actually on. The global `device`
# only says whether CUDA is available, and Lightning may leave the module on
# CPU after a Trainer run.
input_tensor = image.unsqueeze(0).to(next(model.parameters()).device)

grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]  # single item of the batch
visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
# You can also get the model outputs without having to redo inference
model_outputs = cam.outputs

In [None]:
# Display the Grad-CAM overlay stored in `visualization`.
plt.imshow(X=visualization)