In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

In [2]:
def generate_gradcam(model, input_data, target_class, last_conv_layer_name, output_layer_name):
    """Compute a Grad-CAM heatmap for the first sample of ``input_data``.

    An auxiliary model exposes both the activations of ``last_conv_layer_name``
    and the output of ``output_layer_name``. The gradient of the
    ``target_class`` score with respect to those activations is averaged over
    the spatial axes to give one weight per channel. The weighted sum of the
    channels is then passed through a ReLU and scaled to [0, 1].

    Args:
        model: Keras model to explain.
        input_data: Batched model input (batch axis first). Only sample 0 is
            explained.
        target_class: Index into the last axis of the output layer's
            predictions whose score is explained.
        last_conv_layer_name: Name of the convolutional layer whose feature
            maps are weighted (as listed by ``model.summary()``).
        output_layer_name: Name of the layer whose output holds the class
            scores.

    Returns:
        numpy array with the spatial shape of the conv layer's feature map
        (channel axis removed), values in [0, 1]. The map is all zeros (not
        NaN) when no location contributes positively to the target class.

    Raises:
        ValueError: if the target score does not depend on the conv layer
            output (the gradient is ``None``).
    """
    # model.inputs is already a list; wrapping it in another list would create
    # a nested input structure that newer Keras versions warn about.
    grad_model = tf.keras.models.Model(
        model.inputs,
        [
            model.get_layer(last_conv_layer_name).output,
            model.get_layer(output_layer_name).output,
        ],
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(input_data)
        loss = predictions[:, target_class]

    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        raise ValueError(
            f"Output of layer '{output_layer_name}' does not depend on "
            f"layer '{last_conv_layer_name}'; cannot compute Grad-CAM."
        )

    # Explain sample 0 only, using its own gradients and activations instead
    # of averaging gradients over the whole batch.
    sample_grads = grads[0]
    sample_conv = conv_outputs[0]

    # One weight per channel: mean gradient over every spatial axis. All axes
    # except the trailing channel axis are reduced, so 2D and 3D convs both work.
    spatial_axes = list(range(len(sample_grads.shape) - 1))
    pooled_grads = tf.reduce_mean(sample_grads, axis=spatial_axes)

    # Channel weights applied to the feature map
    cam = tf.reduce_sum(sample_conv * pooled_grads, axis=-1)

    # ReLU, then normalise to [0, 1]. divide_no_nan returns 0 instead of NaN
    # when the map is entirely non-positive (max == 0).
    cam = tf.maximum(cam, 0)
    cam = tf.math.divide_no_nan(cam, tf.reduce_max(cam))

    return cam.numpy()

In [3]:
# Data locations.
# NOTE(review): the model comes from the *sagittal* experiment (`7_slices_sagital`),
# but `sample_path` points to the *axial* training split (`7_slices_axial/train`).
# Grad-CAM is only meaningful on data preprocessed the same way as the model's
# training data, and is usually inspected on held-out samples. Confirm this mix
# is intended.
DATA_ROOT = "C:/Users/Team Taiane/Desktop/ADNI/FULL_ADNI/processed_7_slices_data"
base_path = f"{DATA_ROOT}/7_slices_sagital/"
model_path = f"{base_path}results/3d/test_1/binary_classifier_150_epochs_batch_64_2_classes.keras"
# Alternative held-out sample from the sagittal validation split:
#     f"{base_path}validation/ad/I89119.nii.gz"
sample_path = f"{DATA_ROOT}/7_slices_axial/train/ad/I10861.nii.gz"

model = tf.keras.models.load_model(model_path)

# Load the NIfTI volume as float32 (get_fdata defaults to float64) and add the
# batch axis (front) and channel axis (back): (1, *volume_shape, 1).
sample = nib.load(sample_path).get_fdata(dtype=np.float32)[np.newaxis, ..., np.newaxis]

# Fail early with a readable message if the volume does not fit the model input.
expected_shape = model.inputs[0].shape[1:]
assert len(expected_shape) == sample.ndim - 1 and all(
    e is None or e == s for e, s in zip(expected_shape, sample.shape[1:])
), f"sample shape {sample.shape} does not match model input {tuple(model.inputs[0].shape)}"

# Index of the predicted score to explain. Presumably 1 = AD, since the sample
# comes from an `ad` folder; confirm against the training label encoding.
target_class = 1
# Layer names as listed by model.summary().
last_conv_layer_name = 'conv3d_3'
output_layer_name = 'dense_2'

In [None]:
# Show the input MRI volume (not the Grad-CAM) at its central slice along
# axis 3 of `sample` (axes: batch, dim1, dim2, dim3, channel).
input_slice_index = sample.shape[3] // 2
fig, ax = plt.subplots()
im = ax.imshow(sample[0, :, :, input_slice_index, 0], cmap='gray')
ax.set_title(f"Input volume, slice {input_slice_index} of {sample.shape[3]}")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Input tensor shape: batch axis first, channel axis last (both added when loading).
sample.shape

In [None]:
# Print the layer table. `last_conv_layer_name` and `output_layer_name` must
# match layer names listed here, because generate_gradcam looks them up with
# model.get_layer.
model.summary()

In [None]:
cam = generate_gradcam(model, sample, target_class, last_conv_layer_name, output_layer_name)

# The CAM lives on the (downsampled) grid of `last_conv_layer_name`, with axes
# ordered like the input volume. Take its central slice along axis 2 and the
# (approximately) matching central slice of the input along the same axis.
cam_slice = cam[:, :, cam.shape[2] // 2]
input_slice = sample[0, :, :, sample.shape[3] // 2, 0]

# Bilinearly upsample the CAM slice to the input slice's in-plane size so it
# can be overlaid on the anatomy.
cam_upsampled = tf.image.resize(cam_slice[..., np.newaxis], input_slice.shape).numpy()[..., 0]

fig, (ax_cam, ax_overlay) = plt.subplots(1, 2, figsize=(10, 4))

# Raw Grad-CAM, central slice
im_cam = ax_cam.imshow(cam_slice, cmap='jet')
ax_cam.set_title(f"Grad-CAM (class {target_class}), central slice")
fig.colorbar(im_cam, ax=ax_cam)

# Grad-CAM overlaid on the corresponding input slice
ax_overlay.imshow(input_slice, cmap='gray')
im_overlay = ax_overlay.imshow(cam_upsampled, cmap='jet', alpha=0.4)
ax_overlay.set_title("Grad-CAM overlay on input slice")
fig.colorbar(im_overlay, ax=ax_overlay)

plt.show()