In [None]:
"""
    Copyright 2023 by Michał Stolarz <michal.stolarz@h-brs.de>

    This file is part of dlbm_binary.
    It is used to obtain Grad-CAM++ heatmaps for the ResNet50 model.

    dlbm_binary is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    dlbm_binary is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.
    You should have received a copy of the GNU Affero General Public License
    along with dlbm_binary. If not, see <http://www.gnu.org/licenses/>.
"""

import torch
from training.pretrained import DeepClassifier
from config import config_resnet
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from utils.data_loader import InteractionDataset
from torchvision.models import ResNet50_Weights as weights
from torchvision.transforms import Resize
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Seed PyTorch's global random number generator so repeated runs of the
# notebook start from the same RNG state.
torch.manual_seed(42)

In [None]:
# Build the classifier wrapper and pull out the network that the Grad-CAM++
# hooks are attached to further down in the notebook.
agent = DeepClassifier(cfg=config_resnet, input_state_size=1, rgb=True, validation=True)
model = agent.model

# Layer whose activations and gradients Grad-CAM++ reads.
# NOTE(review): if `agent.model` is a torchvision ResNet-50, the Bottleneck
# block calls this same `relu` module several times per forward pass, so a
# hook on it fires more than once. The pytorch-grad-cam documentation suggests
# `model.layer4[-1]` for ResNets. Left unchanged here so that previously
# generated heatmaps stay reproducible -- please confirm the choice.
target_layers = [model.layer4[-1].relu]

# Input transform shipped with the ImageNet-1k V1 ResNet-50 weights; it is
# applied to every frame before the forward pass used for prediction.
preprocess = weights.IMAGENET1K_V1.transforms()

In [None]:
# NOTE(review): `limit` and `loop_index` are not referenced anywhere in this
# cell; they are kept only in case another cell relies on them.
limit = 30
loop_index = 0

# Directories of the two extra evaluation sets (dataset ids 2 and 3).
# Dataset id 1 is served by `agent.testloader` instead.
EXTRA_DATASET_PATHS = {
    2: '/home/michal/thesis/test_frames',
    3: '/home/michal/thesis/white_background_frames',
}

# File names of the frames to visualise, keyed by dataset id. Every other
# frame delivered by the loaders is skipped.
SELECTED_FRAMES = {
    1: {'x_6XTLNK55_b11_14_6.jpg', 'x_5J7PWO3G_b8_16_3.jpg',
        'x_3G4MPE2W_randomised_b2_2_2.jpg', 'x_Q4GTE6L4_b3_14_1.jpg'},
    2: {'6XTLNK55_b1_14_6.jpg', '5J7PWO3G_b0_16_3.jpg',
        '3G4MPE2W_randomised_b1_2_2.jpg', 'Q4GTE6L4_b4_14_1.jpg'},
    3: {'6XTLNK55_w_14_6.jpg', '5J7PWO3G_w_16_3.jpg',
        '3G4MPE2W_randomised_w_2_2.jpg', 'Q4GTE6L4_w_14_1.jpg'},
}

# Inference only: put the model in eval mode up front so the very first
# prediction is made under the same mode as all later ones (this is a no-op
# if `DeepClassifier` already did it).
model.eval()

for dataset_id in range(1, 4):

    if dataset_id == 1:
        testloader = agent.testloader
    else:
        test = InteractionDataset(path=EXTRA_DATASET_PATHS[dataset_id],
                                  n_frames_used=1, rgb=True, standardization=False)
        testloader = torch.utils.data.DataLoader(test, batch_size=1)

    # Counts only the heatmaps actually written for this dataset; it is part
    # of the output file name.
    img_id = 0

    for image, label in testloader:
        # Presumably the dataset stores the path of the sample it has just
        # loaded in `img_path` -- verify against InteractionDataset.
        img_path = testloader.dataset.img_path.split('/')[-1]

        # Filter first, so frames that are skipped never get moved to the GPU.
        if img_path not in SELECTED_FRAMES[dataset_id]:
            continue

        # The prediction is only used for the correctness check below, so no
        # autograd graph is needed for this forward pass.
        with torch.no_grad():
            pred = model(preprocess(image.to('cuda')))
        pred1 = pred.argmax(dim=1)
        label_num = label.tolist()[0]
        label_pred_num = pred1.tolist()[0]

        # Only correctly classified frames are visualised.
        if label_pred_num != label_num:
            continue

        print(img_path)
        targets = [ClassifierOutputTarget(int(label))]

        # NOTE(review): the prediction above is made on `preprocess(image)`,
        # whereas the heatmap is computed on the raw `image`. Kept as in the
        # original (the overlay below presumably needs the CAM at the raw
        # frame's resolution) -- please confirm this is intended.
        #
        # The `with` block releases the CAM hooks as soon as the heatmap is
        # computed, so none of them are active during the next prediction.
        # You can also pass aug_smooth=True and eigen_smooth=True to `cam`
        # to apply smoothing.
        with GradCAMPlusPlus(model=model, target_layers=target_layers, use_cuda=True) as cam:
            grayscale_cam = cam(input_tensor=image, targets=targets)

        # The batch holds a single image, so take its heatmap.
        grayscale_cam = grayscale_cam[0, :]

        # Use the first channel of that image as a grey background and
        # replicate it to three channels for the colour overlay.
        gray_image = np.array(image[0][0])
        stacked_gray_img = np.stack((gray_image,) * 3, axis=-1)

        visualization = show_cam_on_image(stacked_gray_img, grayscale_cam, use_rgb=True)
        plt.imshow(visualization)
        plt.show()
        print(f"Label: {label_num} \n"
              f"Predicted: {label_pred_num}")

        # `visualization` was requested in RGB order, while cv2.imwrite
        # expects BGR, hence the channel swap before saving.
        out_name = 'data_' + str(dataset_id) + '_img_' + str(img_id) + '.jpg'
        if not cv2.imwrite(out_name, cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)):
            # cv2.imwrite reports failure through its return value only.
            print(f"WARNING: could not write {out_name}")
        img_id += 1