In [None]:
Install the required libraries

In [None]:
!pip install shap
!pip install tensorflow_addons
!pip install np_utils

Import libraries

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import copy
import warnings
warnings.filterwarnings('ignore')
import cv2
import shap
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
from tensorflow.keras.layers import Conv2D, Reshape, Embedding, Concatenate, Dense, LayerNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50
from tensorflow import keras
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array
import numpy as np
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import np_utils
from tensorflow.keras import layers
import pandas as pd
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Dropout
import tensorflow_addons as tfa
from keras.src.utils.np_utils import to_categorical
from sklearn.metrics import confusion_matrix,classification_report,ConfusionMatrixDisplay

In [None]:
# Data-parallel strategy: with devices=None every visible GPU is used (CPU if none is found).
strategy = tf.distribute.MirroredStrategy(devices=None)

In [None]:
# Preview one sample scan. The title is a fixed caption string, not the output of a model.
sample_path = '/content/Alzheimer_s Dataset/ADimage.jpg'
image = mpimg.imread(sample_path)
plt.imshow(image)
plt.title('true class:AD,predicted class:AD')
plt.show()


In [None]:
# Every image is resized to this many pixels on each side when loaded.
w = h = 128

# Class-folder name -> integer class id (the ids follow this list's order).
label_to_class = {
    folder_name: class_id
    for class_id, folder_name in enumerate(
        ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']
    )
}
# Inverse lookup: integer class id -> class-folder name.
class_to_label = {class_id: folder_name for folder_name, class_id in label_to_class.items()}
n_classes = len(label_to_class)

def get_images(dir_name='/content/Alzheimer_s Dataset', label_to_class=label_to_class,
               splits=('train', 'test'), target_size=None):
    """Load every image under ``dir_name/<split>/<class folder>/`` into memory.

    Images from all ``splits`` are pooled into one set (the caller re-splits it)
    and shuffled with a fixed seed.

    Args:
        dir_name: dataset root containing one sub-directory per split.
        label_to_class: maps class-folder name -> integer class id. A visible
            class folder missing from this mapping raises KeyError.
        splits: split sub-directories of ``dir_name`` to read.
        target_size: ``(height, width)`` to resize every image to. Defaults to
            the notebook globals ``(h, w)``.

    Returns:
        (Images, Classes): float16 arrays holding the resized images (RGB, since
        ``load_img`` defaults to color_mode="rgb") and their integer class ids.
    """
    if target_size is None:
        # load_img's target_size is (height, width), not (width, height).
        target_size = (h, w)

    images = []
    classes = []

    for split in splits:
        split_dir = os.path.join(dir_name, split)
        # Sorting makes the load order (and therefore the seeded shuffle below)
        # reproducible; os.listdir order is filesystem-dependent.
        for label_name in sorted(os.listdir(split_dir)):
            label_dir = os.path.join(split_dir, label_name)
            # Skip hidden entries (e.g. .ipynb_checkpoints, .DS_Store) and stray
            # files, which would otherwise crash the label lookup.
            if label_name.startswith('.') or not os.path.isdir(label_dir):
                continue
            class_id = label_to_class[label_name]

            for img_name in sorted(os.listdir(label_dir)):
                img_path = os.path.join(label_dir, img_name)
                if img_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                img = load_img(img_path, target_size=target_size)
                images.append(img_to_array(img))
                classes.append(class_id)

    Images = np.array(images, dtype=np.float16)
    Classes = np.array(classes, dtype=np.float16)
    Images, Classes = shuffle(Images, Classes, random_state=0)

    return Images, Classes

In [None]:
# Read (and shuffle) every image from the train/ and test/ folders into memory.
Images, Classes = get_images(dir_name='/content/Alzheimer_s Dataset')

In [None]:
# Split sample indices 80/20.
# - stratify keeps each class's proportion the same in both splits, which
#   matters for imbalanced label sets.
# - random_state makes the split reproducible across kernel restarts.
indices_train, indices_test = train_test_split(
    np.arange(Images.shape[0]),
    train_size=0.8,
    test_size=0.2,
    shuffle=True,
    stratify=Classes,
    random_state=0,
)
x_train = Images[indices_train]
y_train = Classes[indices_train]
x_test = Images[indices_test]
y_test = Classes[indices_test]

In [None]:
# One-hot encode the integer class ids for CategoricalCrossentropy, using the
# documented Keras utility.
# The ndim guard keeps this cell idempotent: re-running it must not one-hot an
# already one-hot array, which would add an extra trailing axis.
if y_train.ndim == 1:
    y_train = tf.keras.utils.to_categorical(y_train, num_classes=n_classes)
if y_test.ndim == 1:
    y_test = tf.keras.utils.to_categorical(y_test, num_classes=n_classes)

In [None]:
class ClassToken(tf.keras.layers.Layer):
    """Learnable class token, broadcast to one copy per batch element.

    Given an input whose last axis has size ``dim``, ``call`` returns a tensor
    of shape ``(batch, 1, dim)`` cast to the input's dtype. The input values
    themselves are not used, only its batch size and dtype.
    """

    def __init__(self, **kwargs):
        # Forward name/dtype/trainable etc. so the layer can be named and
        # re-created from its config like any built-in Keras layer.
        super().__init__(**kwargs)

    def build(self, input_shape):
        w_init = tf.random_normal_initializer()
        # A single (1, 1, dim) vector shared across the batch.
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, input_shape[-1]),
            initializer=w_init,
            trainable=True,
        )

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]  # dynamic batch size
        hidden_dim = self.w.shape[-1]

        cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
        cls = tf.cast(cls, dtype=inputs.dtype)

        return cls

In [None]:
def mlp(x, cf):
    """Position-wise feed-forward network used inside each transformer block.

    Expands to ``cf["mlp_dim"]`` units with GELU, then projects back to
    ``cf["hidden_dim"]``. Each Dense is followed by Dropout(``cf["dropout_rate"]``).
    """
    rate = cf["dropout_rate"]
    expanded = Dense(cf["mlp_dim"], activation="gelu")(x)
    expanded = Dropout(rate)(expanded)
    projected = Dense(cf["hidden_dim"])(expanded)
    return Dropout(rate)(projected)

In [None]:
def transformer_encoder(x, cf):
    """One pre-norm transformer encoder block: self-attention, then the MLP.

    Each sub-block runs LayerNorm -> sub-layer -> residual Add -> Dropout.
    Because ``mlp`` projects back to ``cf["hidden_dim"]``, the last axis of
    ``x`` must equal ``cf["hidden_dim"]`` for the second residual Add.

    NOTE(review): Dropout is applied *after* each residual Add, i.e. to the
    residual stream itself. The common ViT layout drops only the sub-layer
    output before the Add. Kept unchanged here to preserve the architecture.
    """
    rate = cf["dropout_rate"]

    # Self-attention sub-block.
    attn_in = tf.keras.layers.LayerNormalization()(x)
    attn_out = tf.keras.layers.MultiHeadAttention(
        num_heads=cf["num_heads"], key_dim=cf["hidden_dim"]
    )(attn_in, attn_in)
    attn_res = tf.keras.layers.Add()([attn_out, x])
    attn_res = tf.keras.layers.Dropout(rate)(attn_res)

    # Feed-forward sub-block.
    mlp_in = tf.keras.layers.LayerNormalization()(attn_res)
    mlp_out = mlp(mlp_in, cf)
    mlp_res = tf.keras.layers.Add()([mlp_out, attn_res])
    return tf.keras.layers.Dropout(rate)(mlp_res)

In [None]:
class _PositionEmbedding(layers.Layer):
    """Adds a learned position embedding to a ``(batch, num_patches, dim)`` sequence.

    The Embedding lookup lives inside this layer's ``call``, so it is part of
    the functional graph. Its weights are therefore tracked and trained.
    Calling ``Embedding`` directly on a constant ``tf.range`` tensor instead
    runs eagerly and bakes a frozen random tensor into the model.
    """

    def __init__(self, num_patches, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.hidden_dim = hidden_dim
        self.position_embedding = Embedding(input_dim=num_patches, output_dim=hidden_dim)

    def call(self, inputs):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        pos_embed = self.position_embedding(positions)  # (num_patches, hidden_dim)
        return inputs + tf.cast(pos_embed, inputs.dtype)

    def get_config(self):
        layer_config = super().get_config()
        layer_config.update({"num_patches": self.num_patches, "hidden_dim": self.hidden_dim})
        return layer_config


def ResNet50ViT(cf):
    """Build a hybrid classifier: ImageNet ResNet50 features -> ViT encoder -> logits.

    Args:
        cf: config dict. This function reads image_size, num_channels,
            hidden_dim, patch_size, num_layers and num_classes. The encoder
            helpers additionally read num_heads, mlp_dim and dropout_rate.

    Returns:
        An uncompiled ``keras.Model`` mapping images of shape
        ``(image_size, image_size, num_channels)`` to ``num_classes`` logits
        (no softmax).
    """
    # Input
    inputs = layers.Input(shape=(cf["image_size"], cf["image_size"], cf["num_channels"]))

    # Pre-trained ResNet50 backbone, e.g. (None, 4, 4, 2048) for a 128x128 input.
    # NOTE(review): the ImageNet weights expect inputs processed by
    # tf.keras.applications.resnet50.preprocess_input; get_images feeds raw
    # 0-255 pixels. Consider adding that preprocessing.
    resnet50 = ResNet50(include_top=False, weights="imagenet", input_tensor=inputs)
    feature_map = resnet50.output

    # Patch embeddings: Conv2D with padding="same" and the default stride of 1
    # keeps the spatial grid, so each feature-map cell becomes one token of
    # width hidden_dim.
    patch_embed = Conv2D(cf["hidden_dim"], kernel_size=cf["patch_size"], padding="same")(feature_map)
    patch_embed = BatchNormalization()(patch_embed)
    _, grid_h, grid_w, channels = patch_embed.shape
    # Derive the token count from the real grid (e.g. 4*4 = 16) instead of
    # trusting cf["num_patches"] to match it.
    num_patches = grid_h * grid_w
    patch_embed = Reshape((num_patches, channels))(patch_embed)  # (None, num_patches, hidden_dim)

    # Patch + learned position embeddings.
    embed = _PositionEmbedding(num_patches, cf["hidden_dim"])(patch_embed)

    # Prepend the class token.
    token = ClassToken()(embed)
    x = Concatenate(axis=1)([token, embed])  # (None, num_patches + 1, hidden_dim)

    # Transformer encoder stack.
    for _ in range(cf["num_layers"]):
        x = transformer_encoder(x, cf)

    # Classification head on the class-token position.
    x = LayerNormalization()(x)
    x = x[:, 0, :]
    logits = Dense(cf["num_classes"])(x)

    model = keras.Model(inputs=inputs, outputs=logits)

    # summary() prints itself and returns None; wrapping it in print() would
    # add a stray "None".
    model.summary()
    return model


In [None]:
# Model / training hyper-parameters consumed by ResNet50ViT, transformer_encoder and mlp.
# Defined unconditionally: in a notebook __name__ is always "__main__", so an
# `if __name__ == "__main__":` guard only hides `config` when the code is imported.
config = {
    "num_layers": 6,
    "hidden_dim": 64,
    "mlp_dim": 2048,
    "num_heads": 8,
    "dropout_rate": 0.2,
    "image_size": 128,
    "patch_size": 32,
    "num_channels": 3,
    # Tied to the folder->class mapping so the classifier head always matches the labels.
    "num_classes": n_classes,
}
# Number of non-overlapping patch_size x patch_size tiles across the image
# (integer grid math, so partial tiles are never counted).
config["num_patches"] = (config["image_size"] // config["patch_size"]) ** 2

In [None]:
def run_experiment(model, epochs=10, batch_size=16,
                   checkpoint_filepath="/content/workingmodel/weights1.h5"):
    """Compile, train and evaluate ``model`` on the notebook's train/test arrays.

    Reads the globals ``x_train``, ``y_train``, ``x_test`` and ``y_test``
    (one-hot labels). The weights with the best validation accuracy are
    checkpointed during training and restored before the test evaluation, so
    ``model`` holds those best weights when this function returns.

    Args:
        model: Keras model that outputs unnormalized logits, one per class.
        epochs: number of training epochs.
        batch_size: mini-batch size used by ``fit``.
        checkpoint_filepath: HDF5 file that receives the best weights.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    optimizer = tfa.optimizers.AdamW(learning_rate=0.00005, weight_decay=0.0001)

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.CategoricalAccuracy(name="accuracy")],
    )

    # Create the checkpoint directory up front so the first save cannot fail
    # on a missing folder.
    checkpoint_dir = os.path.dirname(checkpoint_filepath)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    # Validate on a held-out 20% of the training data. Keras ignores
    # validation_split whenever validation_data is given, so passing
    # (x_test, y_test) would choose the "best" checkpoint by test accuracy and
    # bias the reported test score upward.
    history = model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.2,
        callbacks=[checkpoint_callback],
    )

    # Restore the best checkpoint. Do not save_weights() to the same path
    # first: that would overwrite the best weights with the final-epoch ones.
    model.load_weights(checkpoint_filepath)
    _, accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")

    return history

In [None]:
# Build the model inside the distribution strategy's scope so its variables
# are created under that strategy, then train and evaluate it.
with strategy.scope():
    vit_classifier = ResNet50ViT(config)
    history = run_experiment(model=vit_classifier)

In [None]:
# Per-epoch training vs validation accuracy from the fit() history.
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'], label='Training_accuracy')
ax.plot(history.history['val_accuracy'], label='Validation_accuracy')
ax.legend()
ax.set_title('Training Accuracy vs Validation Accuracy')
ax.set_xlabel('No.of epochs')
ax.set_ylabel('Accuracy')
# Save before show(): the inline backend closes the figure on show(), so
# saving afterwards writes a blank image.
fig.savefig('Accuracy_Graph')
plt.show()

In [None]:
# Per-epoch training vs validation loss from the fit() history.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='Training_loss')
ax.plot(history.history['val_loss'], label='Validation_loss')
ax.legend()
ax.set_title('Training loss vs Validation loss')
ax.set_xlabel('No.of epochs')
ax.set_ylabel('Loss')
# Save before show(): the inline backend closes the figure on show(), so
# saving afterwards writes a blank image.
fig.savefig('Loss_Graph')
plt.show()

In [None]:
import seaborn as sns

# run_experiment() keeps its predictions local, so compute them here from the
# trained model. The model outputs logits and y_test holds one-hot rows, so
# both are reduced to integer class ids before scoring.
y_pred = np.argmax(vit_classifier.predict(x_test), axis=1)
y_true = np.argmax(y_test, axis=1) if y_test.ndim > 1 else y_test.astype(int)

# Take the names from the same folder->id mapping used for loading, so each
# row/column is labeled with the class it actually represents.
class_ids = list(range(n_classes))
target_names = [class_to_label[i] for i in class_ids]

# Rows are true classes, columns are predicted classes (sklearn convention).
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
print(cm)
print(classification_report(y_true, y_pred, labels=class_ids, target_names=target_names, digits=2))

f = sns.heatmap(cm, annot=True, fmt='d', xticklabels=target_names, yticklabels=target_names)  # confusion matrix plot
f.set(xlabel='Predicted class', ylabel='True class')