In [None]:
# General imports
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import cv2
import pickle
import skimage
from skimage.io import imread, imshow
from glob import glob
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_curve, auc, roc_auc_score

# TensorFlow/Keras imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Activation, Dropout, Flatten, Dense, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.applications import VGG16, DenseNet121
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers
from tensorflow.python.client import device_lib

# Ensure that GPU memory growth is enabled (optional, GPU-specific)
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Set random seed for reproducibility
np.random.seed(42)

# Display setup
%matplotlib inline

In [None]:
# Pre-saved training features and integer labels (NumPy .npy files in the working directory).
x_train, y_train = np.load("x_train_leuknet.npy"), np.load("y_train_leuknet.npy")

In [None]:
num_classes = 2

# One-hot encode the integer labels. `np_utils` is never imported in this notebook, so use the
# Keras utility through the already-imported `tf`. The guard keeps the cell idempotent:
# encoding an already one-hot array again would produce shape (n, 2, 2).
labels_already_one_hot = y_train.ndim == 2 and y_train.shape[-1] == num_classes
if not labels_already_one_hot:
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)

In [None]:
# Features and one-hot labels must describe the same number of samples.
assert x_train.shape[0] == y_train.shape[0], (
    f"sample count mismatch: x_train has {x_train.shape[0]}, y_train has {y_train.shape[0]}"
)
x_train.shape, y_train.shape

In [None]:
from sklearn.utils import class_weight

# Integer class labels recovered from the one-hot targets (computed once, used twice).
train_labels = np.argmax(y_train, axis=1)
# `classes` and `y` are keyword-only in scikit-learn >= 1.0; passing them positionally
# raises TypeError there.
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels,
)
print("class weights: ",class_weights)

In [None]:
def crop_center(img, bounding):
    """Crop the central region of ``img`` whose size is given by ``bounding``.

    Args:
        img: Array exposing ``.shape`` and ``.astype`` (e.g. a NumPy image).
        bounding: Target size per axis, e.g. ``(150, 150, 3)``. Only the first
            ``min(img.ndim, len(bounding))`` axes are cropped; any remaining axes
            of ``img`` are kept whole.

    Returns:
        The cropped region converted to a new ``float32`` array.

    Raises:
        ValueError: if ``bounding`` exceeds ``img`` along any cropped axis
            (the centred start offset would be negative and slice from the end).

    NOTE(review): every cropped axis is centred, including the channel axis. For a
    3-channel image cropped to 3 this keeps all channels, but a 4-channel (RGBA) image
    cropped to 3 yields channels 1..3 (G, B, A) rather than RGB.
    """
    sizes_and_crops = list(zip(img.shape, bounding))
    for axis, (size, crop) in enumerate(sizes_and_crops):
        if crop > size:
            raise ValueError(
                f"crop size {crop} exceeds image size {size} along axis {axis}"
            )
    # Centre each window: start = half the axis length minus half the crop length.
    slices = tuple(
        slice(size // 2 - crop // 2, size // 2 - crop // 2 + crop)
        for size, crop in sizes_and_crops
    )
    return img[slices].astype('float32')

In [None]:
# Held-out test images, one folder per class ('all' and 'hem').
# NOTE(review): these come from the 'High' model directory while the model and heatmap
# output paths configured further down use 'Low' -- confirm this pairing is intended.
VAL_ALL_PATH, VAL_HEM_PATH = (
    r'/content/drive/MyDrive/leukemia/ALLIDB-2 best models/High/Test/all',
    r'/content/drive/MyDrive/leukemia/ALLIDB-2 best models/High/Test/hem',
)

In [None]:
# Sorted listings of both test-class folders; sorting makes the processing order deterministic.
# NOTE(review): os.listdir also returns sub-directories and non-image files, while later
# cells call imread on every entry -- keep these folders image-only.
all_list, hem_list = (sorted(os.listdir(folder)) for folder in (VAL_ALL_PATH, VAL_HEM_PATH))

# Folder/listing pair consumed by the cells below; point both at the 'all' pair to process that class.
PATH, LIST = VAL_HEM_PATH, hem_list

In [None]:
# Trained model to explain (directory + file name).
root_dir, model_name = (
    r'/content/drive/MyDrive/leukemia/ALLIDB-2 best models/Low',
    r'DenseNet121_low_class_weight.h5',
)
# Destination folder for Grad-CAM heatmaps and the saved prediction crops.
out_dir = r'/content/drive/MyDrive/leukemia/ALLIDB-2 best models/Low/corresponding heatmaps/hem'

# Layer whose activations Grad-CAM uses; it must exist in the loaded model (see model.summary()).
last_conv_layer_name = "conv5_block16_concat"
# Class index passed to Grad-CAM as `class_index`.
# NOTE(review): which class index 0 denotes ('all' or 'hem') is not visible here -- confirm
# against the training labels before interpreting the heatmaps.
index = 0

In [None]:

# Size (pixels) of the centre crop that is fed to the model.
crop_height, crop_width = 150, 150


In [None]:

# Sanity-check the preprocessing pipeline on the first image of the selected class folder.
import os
import skimage
from skimage.io import imread, imshow


def _print_value_range(max_label, min_label, values):
    """Print the maximum and minimum of ``values``, each preceded by its label."""
    print(max_label, np.max(values))
    print(min_label, np.min(values))


img_all = imread(os.path.join(PATH, LIST[0]))
_print_value_range('MAx: ', 'MIN: ', img_all)

cropped_img_all = crop_center(img_all, (crop_height,crop_width,3))
_print_value_range('Cropped_img MAx: ', 'Cropped_img MIN: ', cropped_img_all)

# crop_center returns float32; cast back to uint8 for display.
cropped_img_all = cropped_img_all.astype('uint8')
print(cropped_img_all.dtype)
imshow(cropped_img_all)

# Scale pixel values by 1/255 (uint8 * Python float yields a float64 array here).
rescaled_cropped_img_all = cropped_img_all * (1.0 / 255.0)
_print_value_range('Rescaled_cropped_img MAx: ', 'Resclased_cropped_img MIN: ', rescaled_cropped_img_all)
for attribute in (type(rescaled_cropped_img_all), rescaled_cropped_img_all.shape, rescaled_cropped_img_all.dtype):
    print(attribute)

# Prepend a batch axis of length 1.
array = rescaled_cropped_img_all[np.newaxis, ...]
print(array.dtype, array.shape)

In [None]:
from tensorflow import keras

# compile=False: the model is not compiled after loading -- the cells below only run
# predict() and Grad-CAM on it.
model_path = os.path.join(root_dir, model_name)
model = keras.models.load_model(model_path, compile=False)
model.summary()
output_layer = model.layers[-1]
print(output_layer.name)

In [None]:
import tf_explain
from tf_explain.core.grad_cam import GradCAM

# Grad-CAM explainer; the batch loop further down reuses this instance.
explainer = GradCAM()

# Single preprocessed sample (batch axis removed), wrapped as an (images, labels)-style
# tuple with no labels supplied.
arr = array[0, :, :, :]
data = ([arr], None)

output = explainer.explain(
    data,
    model,
    class_index=index,
    layer_name=last_conv_layer_name,
    colormap=cv2.COLORMAP_JET,
)

In [None]:
# Write one Grad-CAM heatmap per image in LIST to out_dir, named '<image stem>.png'.
for file_name in LIST:
    img = imread(os.path.join(PATH, file_name))
    cropped_img = crop_center(img, (crop_height,crop_width,3))
    # Same 1/255 scaling as the model input used in the prediction cell.
    rescaled_cropped_img = cropped_img * (1./255.)
    data = ([rescaled_cropped_img], None)
    output = explainer.explain(data, model, class_index=index, layer_name=last_conv_layer_name, colormap=cv2.COLORMAP_JET)
    # splitext handles any extension length ('.tif', '.tiff', '.jpeg'); slicing off
    # a fixed 4 characters mangled names whose extension is not exactly 3 characters.
    output_name = os.path.splitext(file_name)[0] + '.png'
    explainer.save(output, out_dir, output_name)

print("Completed")


In [None]:
import os
import numpy as np
import skimage
from skimage.io import imshow, imread, imsave
print(len(LIST))
pred_list = []
for x in LIST:
    single_orig_image = imread(os.path.join(PATH, x))
    single_cropped_img = crop_center(single_orig_image, (crop_height,crop_width,3))
    # Add a batch axis and scale by 1/255, matching the Grad-CAM preprocessing.
    single_cropped_img_dim_extended = np.expand_dims(single_cropped_img, axis=0) / 255.0
    pred_value = model.predict(single_cropped_img_dim_extended)
    print(x, pred_value)

    scores = np.ravel(pred_value)
    if scores.size == 1:
        # Single-unit output: threshold at 0.5 as before.
        pred_flat = int(scores[0] > 0.5)
    else:
        # Multi-unit output: `pred_value > 0.5` would be an ambiguous array truth
        # value in `if`, so take the highest-scoring class index instead.
        pred_flat = int(np.argmax(scores))
    pred_list.append(pred_flat)

    des_path = os.path.join(out_dir, os.path.splitext(x)[0] + '_' + str(pred_flat) + '.png')
    # crop_center returns float32; saving that directly forces a lossy float->integer
    # conversion in the image writer. Casting back to the source dtype stores the
    # original pixel values (the float32 crop of integer pixels is exact).
    imsave(des_path, single_cropped_img.astype(single_orig_image.dtype))


print('Number of 0 :', pred_list.count(0))
print('Number of 1 :', pred_list.count(1))

In [None]:
import os
import re
import shutil

path = r'/content/drive/MyDrive/leukemia/C-NMC best model/High/corresponding heatmaps/all'
all_images = sorted(os.listdir(path))

target_path0 = r'/content/drive/MyDrive/leukemia/C-NMC best model/High/corresponding heatmaps/all/classsified'
target_path1 = r'/content/drive/MyDrive/leukemia/C-NMC best model/High/corresponding heatmaps/all/misclasssified'

# Predicted label that counts as a correct classification for the images in `path`.
# NOTE(review): assumes class 0 is the true label of this folder ('all') -- confirm, and
# use '1' when sorting the other class folder.
true_label = '0'

# Expected naming (as written by the prediction cell): crops are '<stem>_<pred>.png'
# and the matching Grad-CAM heatmap is '<stem>.png'.
_PREDICTION_NAME = re.compile(r'(?P<stem>.+)_(?P<pred>[01])\.png')


def _parse_prediction_name(file_name, available):
    """Return ``(stem, pred)`` when ``file_name`` is a '<stem>_<pred>.png' crop whose
    heatmap '<stem>.png' is in ``available``; otherwise ``None``.

    This skips heatmaps, the target sub-folders and unrelated files. It also skips
    heatmaps whose own stem happens to end in '_0'/'_1', because no '<stem>.png'
    exists for them.
    """
    match = _PREDICTION_NAME.fullmatch(file_name)
    if match is None or match['stem'] + '.png' not in available:
        return None
    return match['stem'], match['pred']


for target_dir in (target_path0, target_path1):
    os.makedirs(target_dir, exist_ok=True)

available = set(all_images)
for file_name in all_images:
    parsed = _parse_prediction_name(file_name, available)
    if parsed is None:
        continue
    stem, pred = parsed
    # The suffix is a string: compare with '0'/'1'. A str never equals the int 0, so an
    # int comparison would send every image to the misclassified folder.
    target_dir = target_path0 if pred == true_label else target_path1
    heatmap_name = stem + '.png'
    print(file_name)
    print('name: ', heatmap_name)
    # Copy both files byte-for-byte instead of decoding and re-encoding them.
    for name in (file_name, heatmap_name):
        shutil.copy2(os.path.join(path, name), os.path.join(target_dir, name))