In [3]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import load_model # Import load_model
from sklearn.metrics import jaccard_score, f1_score, accuracy_score, precision_score, recall_score
from tensorflow.keras.preprocessing import image
from IPython.display import display

# --- Section 1: Data Loading and Preprocessing ---
# Path to the data directory
data_dir = r"C:\Users\shrir\Music\New folder\Data"

# Parameters
# IMPORTANT FIX: Changed img_height, img_width to 128 to match model's expected input
img_height, img_width = 128, 128
batch_size = 32

# Data generators (still needed to get class information for metrics calculation)
datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_gen = datagen.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training',
    shuffle=True # Shuffle for training
)

val_gen = datagen.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
    shuffle=False # Do NOT shuffle validation set for consistent metrics calculation
)

print(f"Train samples: {train_gen.samples}, Validation samples: {val_gen.samples}")

# --- Section 2: Load Pre-trained CNN Model ---
# Path to your saved .h5 model
model_path = r"C:\Users\shrir\Music\New folder\hello_new_cnn_model.h5"

if os.path.exists(model_path):
    model = load_model(model_path)
    print(f"Successfully loaded pre-trained model from: {model_path}")
    model.summary() # Display summary of the loaded model
else:
    print(f"Error: Model not found at {model_path}. Please ensure the path is correct.")
    exit() # Exit if the model cannot be loaded


# --- Section 3: Prediction on a Single Image (Example) ---
def predict_image(img_path):
    if not os.path.exists(img_path):
        print(f"Image not found at: {img_path}. Skipping prediction example.")
        return

    img = image.load_img(img_path, target_size=(img_height, img_width))
    img_array = image.img_to_array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=0)

    # The ValueError should be resolved by the img_height/width fix
    pred_probs = model.predict(img_array)
    pred_class_idx = np.argmax(pred_probs, axis=1)[0]
    idx_to_class = {v: k for k, v in train_gen.class_indices.items()}
    predicted_label = idx_to_class[pred_class_idx]
    confidence = pred_probs[0][pred_class_idx]

    print(f"Predicted: {predicted_label} (Confidence: {confidence:.2f})")
    return predicted_label, confidence

# Example usage (ensure this path is correct and accessible on your system)
sample_image_path = r"C:\Users\shrir\Music\New folder\Data\Moderate Dementia\OAS1_0308_MR1_mpr-1_100.jpg"
predict_image(sample_image_path)

# --- Section 4: Metrics Calculation and Plotting ---

print("\n--- Calculating Performance Metrics ---")

# Get predictions and true labels from the validation generator
val_gen.reset() # Reset generator to ensure predictions start from the beginning
y_pred_probs = model.predict(val_gen, verbose=1)
y_pred = np.argmax(y_pred_probs, axis=1)

# True labels (integer class indices)
y_true = val_gen.classes

# Number of classes
num_classes = y_pred_probs.shape[1]
class_names = list(val_gen.class_indices.keys())

# Per-class metrics
iou_per_class = jaccard_score(y_true, y_pred, average=None, labels=range(num_classes))
f1_per_class = f1_score(y_true, y_pred, average=None, labels=range(num_classes), zero_division=0)
precision_per_class = precision_score(y_true, y_pred, average=None, labels=range(num_classes), zero_division=0)
recall_per_class = recall_score(y_true, y_pred, average=None, labels=range(num_classes), zero_division=0)

# Macro metrics (average across classes)
miou = jaccard_score(y_true, y_pred, average='macro')
f1_macro = f1_score(y_true, y_pred, average='macro')
precision_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
recall_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)
overall_accuracy = accuracy_score(y_true, y_pred) # Renamed from pixel_acc for clarity in classification

# Build metrics table
metrics_df = pd.DataFrame({
    'IoU': iou_per_class,
    'F1': f1_per_class,
    'Precision': precision_per_class,
    'Recall': recall_per_class
}, index=class_names)

# Add macro/mean row
metrics_df.loc['Macro/Mean'] = [miou, f1_macro, precision_macro, recall_macro]

# Add overall accuracy as a separate row
metrics_df.loc['Overall Accuracy'] = [overall_accuracy, np.nan, np.nan, np.nan]

# Display table
display(metrics_df)

# Plot line graph for per-class metrics
metrics_df.iloc[:-2].plot(marker='o', figsize=(10, 6))
plt.title('Per-Class Classification Metrics')
plt.ylabel('Score')
plt.ylim(0, 1) # Metrics like F1, Precision, Recall are typically between 0 and 1
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.savefig('per_class_metrics_line_plot.png')
plt.close() # Close the plot to free memory

print("\nPerformance metrics table and line plot generated successfully.")
print("The plots are saved as 'per_class_metrics_line_plot.png'.")

Found 25151 images belonging to 4 classes.
Found 6286 images belonging to 4 classes.




Train samples: 25151, Validation samples: 6286
Successfully loaded pre-trained model from: C:\Users\shrir\Music\New folder\hello_new_cnn_model.h5


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 114ms/step


  self._warn_if_super_not_called()


Predicted: Moderate Dementia (Confidence: 0.92)

--- Calculating Performance Metrics ---
[1m197/197[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 45ms/step


Unnamed: 0,IoU,F1,Precision,Recall
Mild Dementia,0.134858,0.237665,0.389522,0.171
Moderate Dementia,0.842593,0.914573,0.892157,0.938144
Non Demented,0.454395,0.624858,0.478858,0.898936
Very mild Dementia,0.178852,0.303434,0.511668,0.215665
Macro/Mean,0.402674,0.520132,0.568051,0.555936
Overall Accuracy,0.485364,,,



Performance metrics table and line plot generated successfully.
The plots are saved as 'per_class_metrics_line_plot.png'.
