### 싱글모델로 분류

In [2]:

import torch 
import argparse
import yaml
import time
import multiprocessing as mp
import torch.nn.functional as F
from tabulate import tabulate
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
#from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from nmc.models import *
from nmc.datasets import * 
from nmc.augmentations import get_train_augmentation, get_val_augmentation
from nmc.losses import get_loss
from nmc.schedulers import get_scheduler
from nmc.optimizers import get_optimizer
from nmc.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp
from tools.val import evaluate_epi
from nmc.utils.episodic_utils import * 
from scipy.cluster import hierarchy
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from torchvision import models
import torch.nn as nn
from torch.optim import lr_scheduler
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mutual_info_score
from scipy.cluster import hierarchy
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.data import Subset
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import cv2

In [3]:
with open('../configs/NMC.yaml') as f:
    cfg = yaml.load(f, Loader=yaml.SafeLoader)
print(cfg)
fix_seeds(3407)
setup_cudnn()
gpu = setup_ddp()
save_dir = Path(cfg['SAVE_DIR'])
save_dir.mkdir(exist_ok=True)
cleanup_ddp()

{'DEVICE': 'cuda:1', 'SAVE_DIR': 'output', 'MODEL': {'NAME': 'FGMaxxVit', 'BACKBONE': 'FGMaxxVit', 'PRETRAINED': 'checkpoints/pretrained/maxvit_base_tf_512.in1k_pretrained_weights.pth', 'UNFREEZE': 'full', 'VERSION': 'ImageNet_APTOS_1024'}, 'DATASET': {'NAME': 'NMCSDataset', 'ROOT': '/data/nmc/processed_image', 'TRAIN_RATIO': 0.7, 'VALID_RATIO': 0.15, 'TEST_RATIO': 0.15}, 'TRAIN': {'IMAGE_SIZE': [512, 512], 'BATCH_SIZE': 32, 'EPOCHS': 100, 'EVAL_INTERVAL': 25, 'AMP': False, 'DDP': False}, 'LOSS': {'NAME': 'BCEWithLogitsLoss', 'CLS_WEIGHTS': False}, 'OPTIMIZER': {'NAME': 'adamw', 'LR': 0.0001, 'WEIGHT_DECAY': 0.01}, 'SCHEDULER': {'NAME': 'warmuppolylr', 'POWER': 0.9, 'WARMUP': 10, 'WARMUP_RATIO': 0.1}, 'EVAL': {'MODEL_PATH': 'checkpoints/pretrained/FGMaxxVit/FGMaxxVit.FGMaxxVit.NMC.pth', 'IMAGE_SIZE': [512, 512]}, 'TEST': {'MODEL_PATH': 'checkpoints/pretrained/FGMaxxVit/FGMaxxVit.FGMaxxVit.NMC.pth', 'FILE': 'assests/ade', 'IMAGE_SIZE': [512, 512], 'OVERLAY': True}}


In [4]:
# Early Stopping
class EarlyStopping:
    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_score):
        if self.best_score is None:
            self.best_score = val_score
        elif val_score < self.best_score + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_score
            self.counter = 0

In [5]:
def get_train_augmentation(size):
    return transforms.Compose([
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.Lambda(lambda x: x.float() if x.dtype == torch.uint8 else x),
        transforms.Lambda(lambda x: x / 255.0 if x.max() > 1.0 else x),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

def get_val_test_transform(size):
    return transforms.Compose([
        transforms.Resize(size),
        transforms.Lambda(lambda x: x.float() if x.dtype == torch.uint8 else x),
        transforms.Lambda(lambda x: x / 255.0 if x.max() > 1.0 else x),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


In [6]:
start = time.time()
best_mf1 = 0.0
device = torch.device(cfg['DEVICE'])
print("device : ", device)
num_workers = mp.cpu_count()
train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']

image_size = [256,256]
image_dir = Path(dataset_cfg['ROOT']) / 'train_images'
train_transform = get_train_augmentation(image_size)
val_test_transform = get_val_test_transform(image_size)
batch_size = 32
###################
target_label = 0
###################
dataset = eval(dataset_cfg['NAME'])(
    dataset_cfg['ROOT'] + '/combined_images',
    dataset_cfg['TRAIN_RATIO'],
    dataset_cfg['VALID_RATIO'],
    dataset_cfg['TEST_RATIO'],
    transform=None,
    target_label=target_label
)
trainset, valset, testset = dataset.get_splits()
trainset.transform = train_transform
valset.transform = val_test_transform
testset.transform = val_test_transform

trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, drop_last=True, pin_memory=True)
valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True)
testloader = DataLoader(testset, batch_size=1, num_workers=1, pin_memory=True)


device :  cuda:1
/data/nmc/processed_image/combined_images
Target label: 0
Train size: 677, Positive samples: 353
Validation size: 145, Positive samples: 76
Test size: 146, Positive samples: 76


In [7]:

# Model definition (changed to binary classification)
resnet = models.resnext101_32x8d(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Sequential(
    nn.BatchNorm1d(num_ftrs),
    nn.Linear(num_ftrs, 1)
)
resnet = resnet.to(device)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [8]:
# L2 regularization
weight_decay = 1e-5
optimizer = torch.optim.Adam(resnet.parameters(), lr=0.0001, weight_decay=weight_decay)
criterion = nn.BCEWithLogitsLoss()
scaler = GradScaler(enabled=train_cfg['AMP'])
# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5, verbose=True)


In [9]:

def train_epoch(model, dataloader, criterion, optimizer, scaler, device):
    model.train()
    total_loss = 0
    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        with autocast(enabled=scaler is not None):
            outputs = model(images)
            loss = criterion(outputs.squeeze(), labels.float())
        
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [10]:
def evaluate(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            preds = (torch.sigmoid(outputs) > 0.5).int().squeeze()
            
            if preds.dim() == 0:
                preds = preds.unsqueeze(0)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    f1 = f1_score(all_labels, all_preds)
    return f1

In [11]:
def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, scaler, device, epochs):
    best_f1 = 0.0
    early_stopping = EarlyStopping(patience=10, min_delta=0.001)
    
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        
        train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_f1 = evaluate(model, val_loader, device)
        
        print(f"Training Loss: {train_loss:.4f}")
        print(f"Validation F1 Score: {val_f1:.4f}")
        
        scheduler.step(val_f1)
        
        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), 'model/best_model_{}.pth'.format(target_label))
            print("New best model saved!")
        
        early_stopping(val_f1)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break
        
        print()
    
    return best_f1

In [12]:
# Main execution code
# 정규화, lr스케쥴링, 데이터 증강, 조기종료, 배치정규화
epochs = 100
best_f1 = train_and_evaluate(resnet, trainloader, valloader, criterion, optimizer, scaler, device, epochs)

print(f"Training completed. Best F1 Score: {best_f1:.4f}")

# Final evaluation on test set
resnet.load_state_dict(torch.load('model/best_model_{}.pth'.format(target_label)))
test_f1 = evaluate(resnet, testloader, device)
print(f"Test F1 Score: {test_f1:.4f}")

Epoch 1/100


Training: 100%|██████████| 21/21 [00:19<00:00,  1.09it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 18.97it/s]


Training Loss: 0.6780
Validation F1 Score: 0.7919
New best model saved!

Epoch 2/100


Training: 100%|██████████| 21/21 [00:14<00:00,  1.42it/s]
Evaluating: 100%|██████████| 145/145 [00:06<00:00, 21.93it/s]


Training Loss: 0.4548
Validation F1 Score: 0.6833

Epoch 3/100


Training: 100%|██████████| 21/21 [00:14<00:00,  1.50it/s]
Evaluating: 100%|██████████| 145/145 [00:06<00:00, 23.72it/s]


Training Loss: 0.3309
Validation F1 Score: 0.7500

Epoch 4/100


Training: 100%|██████████| 21/21 [00:13<00:00,  1.61it/s]
Evaluating: 100%|██████████| 145/145 [00:06<00:00, 23.39it/s]


Training Loss: 0.2840
Validation F1 Score: 0.7922
New best model saved!

Epoch 5/100


Training: 100%|██████████| 21/21 [00:14<00:00,  1.42it/s]
Evaluating: 100%|██████████| 145/145 [00:06<00:00, 23.62it/s]


Training Loss: 0.2044
Validation F1 Score: 0.7564

Epoch 6/100


Training: 100%|██████████| 21/21 [00:15<00:00,  1.35it/s]
Evaluating: 100%|██████████| 145/145 [00:06<00:00, 20.98it/s]


Training Loss: 0.2163
Validation F1 Score: 0.8662
New best model saved!

Epoch 7/100


Training: 100%|██████████| 21/21 [00:17<00:00,  1.23it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 19.40it/s]


Training Loss: 0.1472
Validation F1 Score: 0.7841

Epoch 8/100


Training: 100%|██████████| 21/21 [00:15<00:00,  1.36it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 20.11it/s]


Training Loss: 0.1325
Validation F1 Score: 0.7651

Epoch 9/100


Training: 100%|██████████| 21/21 [00:14<00:00,  1.44it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 20.00it/s]


Training Loss: 0.1061
Validation F1 Score: 0.8293

Epoch 10/100


Training: 100%|██████████| 21/21 [00:13<00:00,  1.53it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 18.67it/s]


Training Loss: 0.1372
Validation F1 Score: 0.8442

Epoch 11/100


Training: 100%|██████████| 21/21 [00:12<00:00,  1.66it/s]
Evaluating: 100%|██████████| 145/145 [00:08<00:00, 17.53it/s]


Training Loss: 0.0846
Validation F1 Score: 0.8302

Epoch 12/100


Training: 100%|██████████| 21/21 [00:11<00:00,  1.80it/s]
Evaluating: 100%|██████████| 145/145 [00:08<00:00, 18.07it/s]


Training Loss: 0.0996
Validation F1 Score: 0.8690
New best model saved!

Epoch 13/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.92it/s]
Evaluating: 100%|██████████| 145/145 [00:08<00:00, 16.61it/s]


Training Loss: 0.0528
Validation F1 Score: 0.8767
New best model saved!

Epoch 14/100


Training: 100%|██████████| 21/21 [00:13<00:00,  1.55it/s]
Evaluating: 100%|██████████| 145/145 [00:06<00:00, 21.91it/s]


Training Loss: 0.0404
Validation F1 Score: 0.8734

Epoch 15/100


Training: 100%|██████████| 21/21 [00:15<00:00,  1.39it/s]
Evaluating: 100%|██████████| 145/145 [00:05<00:00, 24.83it/s]


Training Loss: 0.0344
Validation F1 Score: 0.8790
New best model saved!

Epoch 16/100


Training: 100%|██████████| 21/21 [00:17<00:00,  1.20it/s]
Evaluating: 100%|██████████| 145/145 [00:06<00:00, 22.55it/s]


Training Loss: 0.0299
Validation F1 Score: 0.8904
New best model saved!

Epoch 17/100


Training: 100%|██████████| 21/21 [00:10<00:00,  2.10it/s]
Evaluating: 100%|██████████| 145/145 [00:06<00:00, 23.48it/s]


Training Loss: 0.0245
Validation F1 Score: 0.8919
New best model saved!

Epoch 18/100


Training: 100%|██████████| 21/21 [00:11<00:00,  1.80it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 18.54it/s]


Training Loss: 0.0177
Validation F1 Score: 0.8535

Epoch 19/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.95it/s]
Evaluating: 100%|██████████| 145/145 [00:08<00:00, 17.99it/s]


Training Loss: 0.0099
Validation F1 Score: 0.8742

Epoch 20/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.91it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 18.47it/s]


Training Loss: 0.0260
Validation F1 Score: 0.8243

Epoch 21/100


Training: 100%|██████████| 21/21 [00:10<00:00,  2.00it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 19.33it/s]


Training Loss: 0.0366
Validation F1 Score: 0.8535

Epoch 22/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.94it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 18.44it/s]


Training Loss: 0.0634
Validation F1 Score: 0.8344

Epoch 23/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.93it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 18.71it/s]


Training Loss: 0.0666
Validation F1 Score: 0.8944
New best model saved!

Epoch 24/100


Training: 100%|██████████| 21/21 [00:12<00:00,  1.67it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 20.40it/s]


Training Loss: 0.0311
Validation F1 Score: 0.7871

Epoch 25/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.93it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 19.15it/s]


Training Loss: 0.0326
Validation F1 Score: 0.8280

Epoch 26/100


Training: 100%|██████████| 21/21 [00:11<00:00,  1.89it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 18.81it/s]


Training Loss: 0.0183
Validation F1 Score: 0.8267

Epoch 27/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.99it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 19.30it/s]


Training Loss: 0.0234
Validation F1 Score: 0.8571

Epoch 28/100


Training: 100%|██████████| 21/21 [00:11<00:00,  1.85it/s]
Evaluating: 100%|██████████| 145/145 [00:09<00:00, 15.94it/s]


Training Loss: 0.0229
Validation F1 Score: 0.8701

Epoch 29/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.94it/s]
Evaluating: 100%|██████████| 145/145 [00:08<00:00, 17.07it/s]


Training Loss: 0.0295
Validation F1 Score: 0.7947
Epoch 00029: reducing learning rate of group 0 to 1.0000e-05.

Epoch 30/100


Training: 100%|██████████| 21/21 [00:11<00:00,  1.90it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 19.05it/s]


Training Loss: 0.0347
Validation F1 Score: 0.8108

Epoch 31/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.96it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 19.66it/s]


Training Loss: 0.0119
Validation F1 Score: 0.8267

Epoch 32/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.96it/s]
Evaluating: 100%|██████████| 145/145 [00:07<00:00, 19.30it/s]


Training Loss: 0.0044
Validation F1 Score: 0.8267

Epoch 33/100


Training: 100%|██████████| 21/21 [00:10<00:00,  1.98it/s]
Evaluating: 100%|██████████| 145/145 [00:09<00:00, 15.76it/s]


Training Loss: 0.0063
Validation F1 Score: 0.8400
Early stopping triggered
Training completed. Best F1 Score: 0.8944


Evaluating: 100%|██████████| 146/146 [00:06<00:00, 22.56it/s]

Test F1 Score: 0.8874





In [14]:
resnet.load_state_dict(torch.load('model/best_model_{}.pth'.format(target_label)))
test_f1 = evaluate(resnet, testloader, device)
print(f"Test F1 Score: {test_f1:.4f}")

Evaluating: 100%|██████████| 146/146 [00:07<00:00, 18.87it/s]

Test F1 Score: 0.8874





In [15]:
def denormalize(tensor, mean, std):
    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)
    return tensor

def evaluate_and_save_misclassified(model, dataloader, device, save_dir):
    model.eval()
    all_preds = []
    all_labels = []
    misclassified = []
    
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    # 정규화에 사용된 평균과 표준편차 (이 값들은 사용한 transform에 맞게 조정해야 함)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    
    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="Evaluating")):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            preds = (torch.sigmoid(outputs) > 0.5).int().squeeze()
            
            if preds.dim() == 0:
                preds = preds.unsqueeze(0)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            # 잘못 분류된 이미지 저장
            for j, (pred, label) in enumerate(zip(preds, labels)):
                if pred != label:
                    img = images[j].cpu()
                    
                    # 정규화 해제
                    img = denormalize(img, mean, std)
                    
                    # [0, 1] 범위의 float 텐서를 [0, 255] 범위의 uint8로 변환
                    img = (img * 255).byte()
                    
                    img = transforms.ToPILImage()(img)
                    img_name = f"misclassified_{i}_{j}_pred{pred.item()}_true{label.item()}.png"
                    img.save(os.path.join(save_dir, img_name))
                    misclassified.append((i, j, pred.item(), label.item()))
    
    return all_preds, all_labels, misclassified

In [70]:
# 모델 로드
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet = models.resnext101_32x8d()
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Sequential(
    nn.BatchNorm1d(num_ftrs),
    nn.Linear(num_ftrs, 1)
)
resnet.load_state_dict(torch.load('model/best_model_{}.pth'.format(target_label), map_location=device))
resnet = resnet.to(device)
resnet.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1

In [71]:
import os
from torchvision import transforms
# 평가 및 잘못 분류된 이미지 저장
save_dir = "misclassified_images"
all_preds, all_labels, misclassified = evaluate_and_save_misclassified(resnet, testloader, device, save_dir)

print(f"Total misclassified images: {len(misclassified)}")
print("Misclassified images saved in:", save_dir)

# 잘못 분류된 이미지 분석
false_positives = sum(1 for pred, label in zip(all_preds, all_labels) if pred == 1 and label == 0)
false_negatives = sum(1 for pred, label in zip(all_preds, all_labels) if pred == 0 and label == 1)

print(f"False Positives: {false_positives}")
print(f"False Negatives: {false_negatives}")

# 몇 개의 잘못 분류된 이미지 예시 출력
print("\nSome examples of misclassified images:")
for i, (batch, index, pred, true) in enumerate(misclassified[:5]):  # 처음 5개만 출력
    print(f"Image {i+1}: Batch {batch}, Index {index}, Predicted {pred}, True label {true}")

Evaluating: 100%|██████████| 146/146 [00:07<00:00, 20.18it/s]

Total misclassified images: 17
Misclassified images saved in: misclassified_images
False Positives: 8
False Negatives: 9

Some examples of misclassified images:
Image 1: Batch 1, Index 0, Predicted 1, True label 0.0
Image 2: Batch 3, Index 0, Predicted 1, True label 0.0
Image 3: Batch 4, Index 0, Predicted 1, True label 0.0
Image 4: Batch 19, Index 0, Predicted 1, True label 0.0
Image 5: Batch 46, Index 0, Predicted 1, True label 0.0





In [72]:
class MultiLayerGradCAM:
    def __init__(self, model, target_layers):
        self.model = model
        self.target_layers = target_layers
        self.activations = {}
        self.gradients = {}
        
        for name, layer in self.target_layers:
            layer.register_forward_hook(self.save_activation(name))
            layer.register_backward_hook(self.save_gradient(name))
    
    def save_activation(self, name):
        def hook(module, input, output):
            self.activations[name] = output
        return hook
    
    def save_gradient(self, name):
        def hook(module, grad_input, grad_output):
            self.gradients[name] = grad_output[0]
        return hook
    
    def generate_cam(self, input_image, target_class=None):
        self.model.eval()
        model_output = self.model(input_image)
        
        if target_class is None:
            target_class = model_output.argmax(dim=1)
        
        self.model.zero_grad()
        model_output[0, target_class].backward()
        
        heatmaps = []
        for name, _ in self.target_layers:
            activation = self.activations[name]
            gradient = self.gradients[name]
            
            pooled_gradients = torch.mean(gradient, dim=[0, 2, 3])
            heatmap = torch.sum(activation * pooled_gradients.unsqueeze(1).unsqueeze(2), dim=1)
            heatmap = F.relu(heatmap)
            heatmap = heatmap / (torch.max(heatmap) + 1e-10)
            heatmap = F.interpolate(heatmap.unsqueeze(0).unsqueeze(0), 
                                    size=input_image.shape[2:],
                                    mode='bilinear', 
                                    align_corners=False)
            heatmaps.append(heatmap)
        
        return torch.mean(torch.stack(heatmaps), dim=0)[0, 0].detach().cpu().numpy()


In [73]:
target_layers = [
    ('layer2', resnet.layer2[-1]),
    ('layer3', resnet.layer3[-1]),
    ('layer4', resnet.layer4[-1])
]
grad_cam = MultiLayerGradCAM(resnet, target_layers)


In [77]:
def visualize_gradcam(image, label, prediction, index, save_path):
    try:
        # 이미지를 CPU로 이동하고 numpy 배열로 변환
        image_np = image.cpu().numpy().transpose(1, 2, 0)
        
        # 정규화 해제
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image_np = std * image_np + mean
        image_np = np.clip(image_np, 0, 1)

        # GradCAM 생성
        input_tensor = image.unsqueeze(0).to(device)
        heatmap = grad_cam.generate_cam(input_tensor)
        
        # Heatmap을 원본 이미지 크기로 리사이즈
        heatmap = cv2.resize(heatmap, (image_np.shape[1], image_np.shape[0]))
        
        # Heatmap을 RGB로 변환
        heatmap_rgb = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB) / 255.0
        
        # Heatmap과 원본 이미지 합성
        cam_image = (heatmap_rgb * 0.4 + image_np * 0.6)
        cam_image = cam_image / np.max(cam_image)

        # 결과 시각화 및 저장
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
        
        ax1.imshow(image_np)
        ax1.set_title(f'Original (True: {label}, Pred: {prediction})')
        ax1.axis('off')
        
        ax2.imshow(heatmap, cmap='jet')
        ax2.set_title('Grad-CAM Heatmap')
        ax2.axis('off')
        
        ax3.imshow(cam_image)
        ax3.set_title('Grad-CAM Overlay')
        ax3.axis('off')
        
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close(fig)
        
    except Exception as e:
        print(f"Error processing image {index}: {str(e)}")

In [78]:
# GradCAM 적용 및 시각화
gradcam_dir = "gradcam_results"
os.makedirs(gradcam_dir, exist_ok=True)

resnet.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(tqdm(testloader, desc="Generating Grad-CAM")):
        images, labels = images.to(device), labels.to(device)
        outputs = resnet(images)
        _, predicted = outputs.max(1)
        
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        for i in range(images.size(0)):
            if predicted[i] != labels[i]:  # 잘못 분류된 이미지에 대해서만
                save_path = os.path.join(gradcam_dir, f"gradcam_batch{batch_idx}_img{i}_true{labels[i]}_pred{predicted[i]}.png")
                visualize_gradcam(images[i], labels[i].item(), predicted[i].item(), f"{batch_idx}_{i}", save_path)

print(f"Test Accuracy: {100.*correct/total:.2f}%")
print("Grad-CAM visualizations have been saved in:", gradcam_dir)

Generating Grad-CAM:   6%|▌         | 9/146 [00:00<00:07, 18.20it/s]

Error processing image 5_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 7_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  12%|█▏        | 18/146 [00:00<00:05, 22.63it/s]

Error processing image 14_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 16_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 17_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  18%|█▊        | 27/146 [00:01<00:04, 29.26it/s]

Error processing image 22_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 26_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 27_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 28_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  24%|██▍       | 35/146 [00:01<00:04, 27.17it/s]

Error processing image 31_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 32_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 34_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  28%|██▊       | 41/146 [00:01<00:04, 24.00it/s]

Error processing image 38_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 40_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  32%|███▏      | 47/146 [00:02<00:04, 24.26it/s]

Error processing image 43_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 44_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 45_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 47_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  36%|███▋      | 53/146 [00:02<00:03, 23.44it/s]

Error processing image 49_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 51_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 52_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 53_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  41%|████      | 60/146 [00:02<00:03, 26.22it/s]

Error processing image 54_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 55_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 58_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  44%|████▍     | 64/146 [00:02<00:02, 27.41it/s]

Error processing image 60_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 63_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 65_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  49%|████▊     | 71/146 [00:02<00:02, 27.25it/s]

Error processing image 67_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 68_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 70_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 71_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  53%|█████▎    | 78/146 [00:03<00:02, 25.96it/s]

Error processing image 72_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 73_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 77_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  58%|█████▊    | 84/146 [00:03<00:02, 25.35it/s]

Error processing image 80_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 82_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 84_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  60%|█████▉    | 87/146 [00:03<00:02, 22.11it/s]

Error processing image 85_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 86_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 88_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  64%|██████▍   | 94/146 [00:03<00:02, 21.88it/s]

Error processing image 90_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 91_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 93_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 94_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  68%|██████▊   | 100/146 [00:04<00:02, 21.85it/s]

Error processing image 97_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 98_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 99_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 100_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 101_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  73%|███████▎  | 106/146 [00:04<00:01, 21.85it/s]

Error processing image 102_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 103_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 104_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 105_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  77%|███████▋  | 112/146 [00:04<00:01, 23.73it/s]

Error processing image 107_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 108_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 109_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  79%|███████▉  | 115/146 [00:04<00:01, 22.98it/s]

Error processing image 112_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 113_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 115_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  83%|████████▎ | 121/146 [00:05<00:01, 21.57it/s]

Error processing image 117_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 119_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 120_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  89%|████████▉ | 130/146 [00:05<00:00, 21.18it/s]

Error processing image 125_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 128_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 129_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  93%|█████████▎| 136/146 [00:05<00:00, 22.93it/s]

Error processing image 130_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 132_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 133_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 134_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM:  96%|█████████▌| 140/146 [00:05<00:00, 24.64it/s]

Error processing image 137_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 139_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 140_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 141_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM: 100%|██████████| 146/146 [00:06<00:00, 23.90it/s]

Error processing image 142_0: element 0 of tensors does not require grad and does not have a grad_fn
Error processing image 145_0: element 0 of tensors does not require grad and does not have a grad_fn


Generating Grad-CAM: 100%|██████████| 146/146 [00:06<00:00, 22.86it/s]

Test Accuracy: 47.95%
Grad-CAM visualizations have been saved in: gradcam_results



