In [None]:
"""
    Copyright 2023 by Michał Stolarz <michal.stolarz@h-brs.de>

    This file is part of dl_behaviour_model_binary.
    It is used for obtaining Grad-CAM++ heatmaps for DLC8 model.

    dl_behaviour_model_binary is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    dl_behaviour_model_binary is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.
    You should have received a copy of the GNU Affero General Public License
    along with dl_behaviour_model_binary. If not, see <http://www.gnu.org/licenses/>.
"""

import torch
from training.deep_classifier import DeepClassifier
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(42)
import torch.nn as nn
from utils.data_loader import InteractionDataset
import cv2

In [None]:
class DLCNet3D(nn.Module):
    """Wrapper that exposes the activations and gradients needed for Grad-CAM++.

    The wrapped network comes from ``DeepClassifier(input_state_size=8)``. Its
    ``features`` stack is split into the first 11 modules (``features_conv``,
    whose output is treated as the last convolutional feature maps) and the
    module at index 11 (``max_pool``); the ``classifier`` head is reused
    unchanged.

    NOTE(review): that ``features[:11]`` ends at the last convolution and that
    ``features[11]`` is a max-pooling layer follows the original author's
    comments; the layout of ``DeepClassifier`` is not visible here - confirm.
    """

    def __init__(self):
        super().__init__()

        # Build the project classifier and reuse its underlying network.
        self.agent = DeepClassifier(input_state_size=8)
        self.model = self.agent.model

        # Dissect the network to access the output of its last convolutional stage.
        features_list = self.model.features[:11]
        self.features_conv = nn.Sequential(*features_list)

        # The module that follows the convolutional stem (the remaining pooling).
        self.max_pool = self.model.features[11]

        # The classification head.
        self.classifier = self.model.classifier

        # Placeholder for the gradients captured by the hook during backward().
        self.gradients = None

    def activations_hook(self, grad):
        """Tensor hook: store the gradient flowing into the conv activations."""
        self.gradients = grad

    def forward(self, x):
        """Run the full network and arm the gradient hook on the conv activations.

        Args:
            x: Input batch accepted by ``features_conv``.

        Returns:
            The output of ``classifier`` for every sample in the batch.
        """
        x = self.features_conv(x)

        # A hook can only be attached to a tensor that takes part in autograd.
        # Skipping it otherwise keeps plain inference (e.g. under
        # torch.no_grad()) from raising a RuntimeError.
        if x.requires_grad:
            x.register_hook(self.activations_hook)

        # Apply the remaining pooling.
        x = self.max_pool(x)
        # Flatten per sample; identical to view((1, -1)) for a batch of one,
        # but also valid for larger batches.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def get_activations_gradient(self):
        """Return the gradient stored by the last backward pass (None before any)."""
        return self.gradients

    def get_activations(self, x):
        """Return the output of the convolutional stem for ``x``."""
        return self.features_conv(x)

In [None]:
# Build the Grad-CAM++ wrapper and switch it to evaluation mode in one step
# (nn.Module.eval() returns the module itself, so the call can be chained).
dlc = DLCNet3D().eval()
print("Done")

def get_heatmap(grads, activations):
    """Compute a normalised Grad-CAM++ heatmap from gradients and activations.

    Intended for a single sample (batch axis of length 1), which is how the
    notebook calls it. Works for any number of channels.

    Args:
        grads: NumPy array with the gradients of the class score with respect
            to the feature maps, indexed as (batch, channel, height, width).
        activations: NumPy array with the feature maps, same layout as
            ``grads``. It is left unmodified.

    Returns:
        NumPy array obtained by summing the weighted feature maps over the
        channel axis and squeezing singleton axes ((height, width) for one
        sample). Values are clipped at zero and divided by the maximum; if no
        location is positive the array is all zeros instead of NaN.
    """
    eps = 0.000001  # keeps the alpha denominator away from zero
    grads_power_2 = grads ** 2
    grads_power_3 = grads_power_2 * grads

    # Per-channel spatial sum of the activations, kept broadcastable.
    sum_activations = np.sum(activations, axis=(2, 3), keepdims=True)

    # Grad-CAM++ pixel-wise coefficients and the resulting channel weights.
    aij = grads_power_2 / (2 * grads_power_2 + sum_activations * grads_power_3 + eps)
    weights = np.sum(np.maximum(grads, 0) * aij, axis=(2, 3), keepdims=True)

    # Weight every channel via broadcasting (no hard-coded channel count and
    # no in-place change of the caller's array), then collapse the channels.
    heatmap = np.squeeze(np.sum(activations * weights, axis=1))
    heatmap = np.maximum(heatmap, 0)

    # Scale to [0, 1]; skip the division when the map is entirely zero.
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap

In [None]:
# Frame directories of the two additional evaluation sets. Dataset 1 is not
# listed because it reuses the test loader owned by the classifier agent.
extra_dataset_paths = {
    2: '/home/michal/thesis/test_frames',
    3: '/home/michal/thesis/white_background_frames',
}

# File name of the single sample that is visualised for each dataset.
selected_images = {
    1: 'x_6XTLNK55_b11_14_0.jpg',
    2: '6XTLNK55_b2_14_0.jpg',
    3: '6XTLNK55_w_14_0.jpg',
}

for dataset_id in (1, 2, 3):
    print('Dataset id', dataset_id)

    if dataset_id == 1:
        testloader = dlc.agent.testloader
    else:
        test = InteractionDataset(path=extra_dataset_paths[dataset_id], n_frames_used=8)
        testloader = torch.utils.data.DataLoader(test, batch_size=1)

    for image, label in testloader:
        input_tensor = image.to('cuda')

        # File name taken from the dataset's img_path attribute.
        # NOTE(review): presumably this refers to the sample just yielded -
        # confirm against InteractionDataset.
        img_path = testloader.dataset.img_path[0].split('/')[-1]
        if img_path != selected_images[dataset_id]:
            continue

        # Class scores and the most likely class for this sample.
        pred = dlc(input_tensor)
        pred1 = pred.argmax(dim=1)
        label_num, label_pred_num = label.tolist()[0], pred1.tolist()[0]

        # Only correctly classified samples are visualised.
        if label_num != label_pred_num:
            continue

        print(img_path)
        # Back-propagate the score of the labelled class; this fills the
        # gradient stored by the hook registered in dlc.forward().
        pred[:, int(label)].backward()

        # Activations of the convolutional stem and their gradients, as NumPy.
        activations = dlc.get_activations(input_tensor).detach().cpu().numpy()
        gradients = dlc.get_activations_gradient().cpu()
        grads = gradients.numpy()

        # Grad-CAM++ map -> 198x198 -> JET colours (RGB order) scaled to [0, 1].
        heatmap = get_heatmap(grads, activations)
        heatmap = np.uint8(255 * cv2.resize(heatmap, (198, 198)))
        heatmap = cv2.cvtColor(cv2.applyColorMap(heatmap, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        heatmap = np.float32(heatmap) / 255

        fig, axs = plt.subplots(1, 8)
        fig.set_size_inches(20.5, 12.5)

        # One panel per slice along the first axis of the sample.
        frames = input_tensor[0].cpu()
        for i, ax in enumerate(axs):
            ax.set_axis_off()
            # NOTE(review): *76 + 127 presumably undoes the input
            # normalisation - confirm against the data loader.
            img = (frames[i] * 76 + 127) / 255
            stacked_gray_img = np.stack([img, img, img], axis=-1)
            img = np.array(stacked_gray_img)

            # Blend heatmap and frame equally, then rescale to 8-bit.
            cam = heatmap * 0.5 + img * 0.5
            cam = np.uint8(255 * (cam / np.max(cam)))

            ax.imshow(cam)

        print(f"Label: {label_num}, Prediction: {label_pred_num} ")
        plt.savefig(f"data_{dataset_id}.jpg", bbox_inches="tight")