In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np


In [2]:
class ImageClassifier(nn.Module):
    def __init__(self, num_classes):
        super(ImageClassifier, self).__init__()
        self.base_model = models.resnet18(pretrained=True)
        num_features = self.base_model.fc.in_features
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        x = self.base_model(x)
        x = self.fc(x)
        return x


In [3]:
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None

    def forward(self, x):
        self.activations = []
        self.gradients = []

        def forward_hook(module, input, output):
            self.activations.append(output)

        def backward_hook(module, grad_input, grad_output):
            self.gradients.append(grad_output[0])

        target_layer = self.model._modules[self.target_layer]
        hook_a = target_layer.register_forward_hook(forward_hook)
        hook_g = target_layer.register_backward_hook(backward_hook)

        self.model.eval()
        with torch.no_grad():
            _ = self.model(x)

        hook_a.remove()
        hook_g.remove()

        self.activations = self.activations[0]
        self.gradients = self.gradients[0]

    def generate_heatmap(self, target_class):
        weights = torch.mean(self.gradients, dim=(2, 3))[0]
        heatmap = torch.zeros_like(self.activations[0])

        for i, weight in enumerate(weights):
            heatmap += weight * self.activations[i]

        heatmap = torch.relu(heatmap)
        heatmap /= torch.max(heatmap)

        return heatmap


In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
from torch.utils.data import ConcatDataset
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch

# アノマリーデータのディレクトリのパス
anomaly_data_dir = "/content/drive/MyDrive/path_to_data_directory"

# ノーマルデータのディレクトリのパス
normal_data_dir = "/content/drive/MyDrive/path_to_data_directory"

# 前処理の設定
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# アノマリーデータのデータセットを作成し、ラベルを割り当てる
anomaly_dataset = ImageFolder(anomaly_data_dir, transform=transform)
anomaly_labels = torch.ones(len(anomaly_dataset))  # アノマリーデータのラベルを1とする

# ノーマルデータのデータセットを作成し、ラベルを割り当てる
normal_dataset = ImageFolder(normal_data_dir, transform=transform)
normal_labels = torch.zeros(len(normal_dataset))  # ノーマルデータのラベルを0とする

# データセットとラベルを結合
dataset = ConcatDataset([anomaly_dataset, normal_dataset])
labels = torch.cat([anomaly_labels, normal_labels])

# データローダーの定義
batch_size = 32
train_dataset = list(zip(dataset, labels))  # データとラベルをタプルとしてまとめる
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)



FileNotFoundError: ignored

In [6]:
# データの前処理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# モデルの初期化
num_classes = 2
model = ImageClassifier(num_classes)

# モデルの学習と評価

# データの読み込みと分類ラベルの割り当て
# ここでは、datasetsとdataloadersを作成するコードが必要です

# オプティマイザと損失関数の定義
# オプティマイザと損失関数の定義
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# モデルの学習と評価
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    train_loss = running_loss / len(train_dataloader)

    # バリデーションデータでの評価
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = correct / total

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

# テストデータでの評価
model.eval()
test_correct = 0
test_total = 0

with torch.no_grad():
    for images, labels in test_dataloader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

test_accuracy = test_correct / test_total
print(f"Test Accuracy: {test_accuracy:.4f}")

# テストデータの一部をランダムに表示してヒートマップを生成
num_display_images = 5
display_images, display_labels = iter(test_dataloader).next()

model.eval()
gradcam = GradCAM(model, target_layer='layer4')  # Grad-CAMを適用するターゲットレイヤーを指定

for i in range(num_display_images):
    image = display_images[i].unsqueeze(0)
    label = display_labels[i].item()

    output = model(image)
    _, predicted = torch.max(output.data, 1)
    confidence = torch.softmax(output, dim=1)[0, predicted[0]].item()

    gradcam.forward(image)
    heatmap = gradcam.generate_heatmap(predicted[0])

    image = image.squeeze(0)
    heatmap = heatmap.numpy()

    # 画像とヒートマップを表示
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
    ax[0].imshow(np.transpose(image, (1, 2, 0)))
    ax[0].set_title(f"Predicted: {predicted.item()}, Confidence: {confidence:.2f}")
    ax[0].axis('off')
    ax[1].imshow(heatmap, cmap='jet', alpha=0.5)
    ax[1].imshow(np.transpose(image, (1, 2, 0)), alpha=0.5)
    ax[1].set_title("Grad-CAM Heatmap")
    ax[1].axis('off')

    plt.show()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 97.8MB/s]


NameError: ignored