# 코드 설명

0. 아래 코드 셀을 복사해서 main.ipynb의 맨 마지막 셀에 붙여 넣어 주세요. 이 셀은 main.ipynb에서 정의된 `torch`, `np`, `model`, `device`, `evaluate`, `criterion`을 그대로 사용합니다.

1. `model`을 실제로 사용하는 모델로 교체해 주세요 (아래는 GoogLeNet 예시입니다).

```python
# model = models.googlenet(pretrained=True).to(device)
# setup_model_fc(model)
# model.load_state_dict(torch.load("Googlenet_32_0.0001.pth"))
```

2. `target_layers`에 사용하는 모델의 CAM 대상 레이어를 지정해 주세요 (빈 리스트로 두면 CAM 계산이 실패합니다).

```python
# TODO: target layer 수정 필요!!!!!!!!
target_layers = []
cam = EigenCAM(model=model, target_layers=target_layers)
```


In [None]:
import math
from pytorch_grad_cam import GradCAM, EigenCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import torchvision.models as models
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

from utils.loader import get_test_loaders

# TODO: Replace the example below with the model you actually use.
# model = models.googlenet(pretrained=True).to(device)
# setup_model_fc(model)
# model.load_state_dict(torch.load("Googlenet_32_0.0001.pth"))


def apply_gradcam(model, images, targets, cam_algorithm):
    """Run a constructed CAM object on a batch and return its grayscale CAMs.

    Args:
        model: Unused; kept so existing call sites keep working (the CAM
            object was already built around the model).
        images: Input batch, forwarded to the CAM as ``input_tensor``.
        targets: Per-sample CAM targets, e.g. ``ClassifierOutputTarget``.
        cam_algorithm: A CAM instance such as ``GradCAM`` or ``EigenCAM``.

    Returns:
        Whatever ``cam_algorithm`` returns for the batch (indexed by batch
        first at the call site).
    """
    return cam_algorithm(input_tensor=images, targets=targets)

# Build the test DataLoader via the project helper.
real_test_loader = get_test_loaders()

# Number of batches in the loader (not the number of images).
print(len(real_test_loader))

# Switch to inference mode BEFORE predicting. In train mode, dropout is
# active and BatchNorm uses per-batch statistics, so the predictions below
# would not reflect the model's real test-time behavior.
model.eval()

# Extract images, labels and predictions from the DataLoader.
images_list = []
labels_list = []
predictions_list = []

with torch.no_grad():
    for images, labels in real_test_loader:
        images_list.append(images)
        labels_list.append(labels)

        outputs = model(images.to(device))
        # Predicted class = index of the largest output per sample. Keep it on
        # the CPU so it sits alongside the CPU images/labels.
        _, predicted = torch.max(outputs, 1)
        predictions_list.append(predicted.cpu())

# TODO: set the target layer(s) for the model in use (e.g. its last
# convolutional / feature block)!
target_layers = []
cam = EigenCAM(model=model, target_layers=target_layers)
if not target_layers:
    # With no target layers the CAM has nothing to aggregate and fails deep
    # inside pytorch_grad_cam with an obscure error; fail fast instead.
    raise ValueError("target_layers is empty: set it to the model's CAM target layer(s) before running this cell.")

# Flatten the per-batch lists into single tensors.
images_list = torch.cat(images_list)
labels_list = torch.cat(labels_list)
predictions_list = torch.cat(predictions_list)

# Number of images to plot: one per batch of the loader. Use
# images_list.size(0) instead to plot every collected image.
num_images = len(real_test_loader)
num_cols = 2
num_rows = math.ceil(num_images / num_cols)

fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 6, num_rows * 6))
axes = axes.flatten()

# Class names indexed by label id (index i is the name of class i).
categories = [
    "DangerousDriving",
    "Distracted",
    "Drinking",
    "SafeDriving",
    "SleepyDriving",
    "Yawn"
]

# Per-channel constants used to undo the input normalization (ImageNet
# mean/std). Presumably these match the Normalize transform in utils.loader;
# confirm. They are built once here rather than on every iteration.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Iterate over the images, labels and predictions collected earlier.
for idx in range(num_images):
    image = images_list[idx]
    labels = labels_list[idx]
    predicted = predictions_list[idx]

    # Explain the model's predicted class for this image.
    targets = [ClassifierOutputTarget(predicted.item())]

    # Add a batch dimension and move the input to `device`. The CAM runs a
    # forward pass through `model`, which was placed on `device`.
    input_tensor = image.unsqueeze(0).to(device)
    grayscale_cam = apply_gradcam(model, input_tensor, targets, cam)
    grayscale_cam = grayscale_cam[0, :]  # drop the batch dimension

    # Undo the normalization and convert (C, H, W) -> (H, W, C). Clamp to
    # [0, 1] because show_cam_on_image raises when the image max exceeds 1,
    # and float rounding in the de-normalization can overshoot slightly.
    image_denormalized = (image.cpu() * norm_std + norm_mean).clamp(0, 1)
    rgb_img = image_denormalized.permute(1, 2, 0).numpy()

    # rgb_img is RGB and matplotlib renders RGB, so request an RGB heatmap.
    # use_rgb=False would blend a BGR heatmap into an RGB image (red/blue swapped).
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    ax = axes[idx]
    ax.imshow(visualization)
    true_label_name = categories[labels.item()]
    predicted_label_name = categories[predicted.item()]
    ax.set_title(f"#{idx}True: {true_label_name} Pred: {predicted_label_name} / {'Right' if true_label_name == predicted_label_name else 'Wrong'}", fontsize=14)
    ax.axis('off')

# Hide any leftover subplots that received no image.
for idx in range(num_images, len(axes)):
    axes[idx].axis('off')

plt.tight_layout()
plt.show()

# Print the current wall-clock time in UTC+9 (presumably KST).
import datetime
# Use an explicit UTC+9 timezone rather than "naive local time + 9 hours".
# The old form was only correct when the host clock runs in UTC; on any other
# host it shifted the time a second time.
utc_plus_9 = datetime.timezone(datetime.timedelta(hours=9))
now = datetime.datetime.now(utc_plus_9)
now = now.strftime('%Y-%m-%d %H:%M:%S')
print(now)

# Evaluate on the test set. `evaluate` and `criterion` are not defined in
# this cell; presumably they come from main.ipynb.
test_loss, test_acc, test_preds, test_labels = evaluate(model, real_test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

# Compute the confusion matrix over every class id. Passing `labels`
# explicitly keeps it len(categories) x len(categories) even when a class never
# occurs in test_labels or test_preds. Otherwise sklearn drops that row/column
# and the matrix no longer lines up with the tick labels below.
class_ids = list(range(len(categories)))
conf_matrix = confusion_matrix(test_labels, test_preds, labels=class_ids)

# Plot the confusion matrix (rows: true labels, columns: predicted labels).
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, cmap="Blues", fmt="d", cbar=False,
            xticklabels=class_ids, yticklabels=class_ids)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()