In [1]:
from keras.layers import Input, Lambda, Dense, Flatten
from keras.models import Model
from keras.applications import ResNet50
from keras.applications.vgg16 import VGG16
from keras.applications.resnet import preprocess_input
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import load_img
from keras.models import Sequential
import numpy as np
from glob import glob
from matplotlib import pyplot as plt

from keras.models import load_model
import os

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tf_explain.core.grad_cam import GradCAM
from keras.utils import img_to_array, load_img
import cv2

from PIL import Image, ImageDraw, ImageFont

import tensorflow as tf

In [7]:
train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator(rescale = 1./255)

In [8]:
# Make sure you provide the same target size as initialied for the image size
training_set = train_datagen.flow_from_directory('../../Covid19-dataset/train',
                                                 target_size = (224, 224),
                                                 batch_size = 32,
                                                 class_mode = 'categorical')


Found 251 images belonging to 3 classes.


In [9]:
test_set = test_datagen.flow_from_directory('../../Covid19-dataset/test',
                                            target_size = (224, 224),
                                            batch_size = 32,
                                            class_mode = 'categorical')

Found 66 images belonging to 3 classes.


In [10]:
model = load_model('../../models/covid/covid_vgg1.h5')

In [12]:
test_loss, test_accuracy = model.evaluate(test_set)
print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}")

Test Loss: 0.08715351670980453, Test Accuracy: 0.9696969985961914


In [13]:
print(model.summary())

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0     

In [14]:
def overlay_heatmap(img, heatmap):
    img = cv2.resize(img, (224, 224))
    heatmap = cv2.resize(heatmap, (224, 224))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    superimposed_img = heatmap * 0.4 + img
    return superimposed_img

In [16]:
def gradcam(pathname, outputfilename, layer_name):
    img_path = pathname
    img = load_img(img_path, target_size=(224, 224))
    x = img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    preds = model.predict(x)
    class_idx = np.argmax(preds[0])

    explainer = GradCAM()
    grid = explainer.explain((x, None), model, class_idx, layer_name=layer_name)

    orig_img = cv2.imread(pathname)
    superimposed_img = overlay_heatmap(orig_img, grid)
    cv2.imwrite(outputfilename, superimposed_img)

In [17]:
input_directory = '../../Covid19-dataset/test/'
output_directory = '../../Output_Images/gradcam/'
images_to_process = 5

In [18]:
for i, category in enumerate(os.listdir(input_directory)):
    num_images_processed = 0
    # create directories
    image_output_directory = os.path.join(output_directory, category)
    # Iterate over all image files in the directory
    for j, filename in enumerate(os.listdir(os.path.join(input_directory, category))):
        # Create a new directory for each image
        image_directory = os.path.join(image_output_directory, f"image_{j+1}")
        os.makedirs(image_directory, exist_ok=True)
        
        for layer in model.layers:
            if 'conv' in layer.name or 'pool' in layer.name:
                name, extension = os.path.splitext(filename)
                gradcam(
                    os.path.join(input_directory, category, filename), 
                    os.path.join(image_directory, str(layer.name) + extension), 
                    str(layer.name))

        # Save the original image in the new directory
        root_name, extension = os.path.splitext(filename)
        cv2.imwrite(os.path.join(image_directory, '_original' + extension), cv2.imread(os.path.join(input_directory, category, filename)))
        
        num_images_processed += 1
        
        # Check if we have processed the desired number of images
        if num_images_processed == images_to_process:
            break



  heatmap = (heatmap - np.min(heatmap)) / (heatmap.max() - heatmap.min())
  cv2.cvtColor((heatmap * 255).astype("uint8"), cv2.COLOR_GRAY2BGR), colormap




In [19]:
def stitch_images(folder_path, grid_shape=(5, 4), image_size=(140, 140), padding=20, outer_padding=20, label_height=40, font_size=10):
    grid_width, grid_height = grid_shape
    img_width, img_height = image_size

    canvas_width = grid_width * img_width + (grid_width - 1) * padding + 2 * outer_padding
    canvas_height = grid_height * (img_height + label_height) + (grid_height - 1) * padding + 2 * outer_padding

    canvas = Image.new("RGB", (canvas_width, canvas_height), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.truetype("arial.ttf", font_size)

    image_files = sorted([os.path.join(folder_path, f) for f in os.listdir(folder_path)
                         if os.path.isfile(os.path.join(folder_path, f))
                         and f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif'))])

    for index, image_path in enumerate(image_files):
        img = Image.open(image_path)
        img = img.resize(image_size, Image.LANCZOS)

        x = outer_padding + (index % grid_width) * (img_width + padding)
        y = outer_padding + (index // grid_width) * (img_height + label_height + padding)

        canvas.paste(img, (x, y))

        label = os.path.splitext(os.path.basename(image_path))[0]
        label_width, label_height_actual = draw.textsize(label, font=font)
        label_x = x + (img_width - label_width) // 2
        label_y = y + img_height

        draw.text((label_x, label_y), label, font=font, fill=(0, 0, 0))

    return canvas


In [32]:
# Path to the folder containing the subfolders
parent_folder_path = '../../Output_Images/gradcam/'

# Iterate through the subfolders in the parent folder
for category in os.listdir(parent_folder_path): # /gradcam/CATEGORY
    category_path = os.path.join(parent_folder_path, category)
    for i, image_file in enumerate(os.listdir(category_path)): # /gradcam/CATEGORY/IMAGE_FILE
        # Check if it's a directory
        if os.path.isdir(os.path.join(category_path, image_file)):
            # Stitch images together in a 5x4 grid
            stitched_image = stitch_images(os.path.join(category_path, image_file))

            # Save the stitched image with the folder name as the file name in the "Output_Images" folder
            stitched_image.save(os.path.join(category_path, f"{category}_{image_file}_stitched_image.jpg"))


  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual = draw.textsize(label, font=font)
  label_width, label_height_actual