In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model
import warnings

# Quiet TensorFlow's native (C++) logging, its Python logger, and DeprecationWarnings.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is normally read when TensorFlow's native
# library initializes; since `tensorflow` is already imported above, setting it
# here may have little effect — consider moving it before the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings('ignore', category=DeprecationWarning)
tf.get_logger().setLevel('ERROR')

# Function to load a pre-trained model
def load_model(model_path):
    """Deserialize a Keras model from ``model_path`` without compiling it.

    Note: this module-level function shadows the ``load_model`` imported from
    ``tensorflow.keras.models`` at the top of the file.

    Raises:
        FileNotFoundError: if nothing exists at ``model_path``.
    """
    if os.path.exists(model_path):
        # compile=False: the model is only used for inference here.
        return tf.keras.models.load_model(model_path, compile=False)
    raise FileNotFoundError(f"Model file not found at {model_path}")


# Function to predict a mask using the segmentation model
def predict_mask(model, eye_image):
    """Run ``model`` on one image and return its raw prediction.

    A leading batch axis of length 1 is added before ``model.predict`` is
    called; the returned array is passed through unchanged (batch axis kept).
    """
    batched = np.asanyarray(eye_image)[np.newaxis, ...]
    return model.predict(batched)


# Function 1: Capture images using OpenCV
def opencv(camera_index=0, save_folder=r"Captured_Images_Path",
           box_start=(150, 100), box_end=(400, 300)):
    """Show a live camera preview and save crops of a fixed region on demand.

    A green guide box and a help line are drawn on the preview window. Pressing
    'c' saves the region inside the box as ``captured_image_<n>.jpg`` (n
    starting at 1) in ``save_folder``. Pressing 'q', or a failed frame read,
    ends the session. The camera and windows are released even if an error
    occurs.

    Args:
        camera_index: Device index passed to ``cv2.VideoCapture``.
        save_folder: Directory where crops are written (created if missing).
        box_start: (x, y) pixel coordinates of the box's top-left corner.
        box_end: (x, y) pixel coordinates of the box's bottom-right corner.
    """
    os.makedirs(save_folder, exist_ok=True)

    x0, y0 = box_start
    x1, y1 = box_end
    img_counter = 1

    cap = cv2.VideoCapture(camera_index)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to capture frame from the camera.")
                break

            # Draw the overlay on a copy: cv2.rectangle/putText modify the image
            # in place, and cropping from an annotated frame would bake the
            # green border into every saved image.
            preview = frame.copy()
            cv2.rectangle(preview, box_start, box_end, (0, 255, 0), 2)
            cv2.putText(preview, "Press 'c' to capture, 'q' to quit", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.imshow("Camera", preview)

            key = cv2.waitKey(1) & 0xFF
            if key == ord('c'):
                # Crop from the clean (un-annotated) frame.
                cropped_image = frame[y0:y1, x0:x1]
                filename = os.path.join(save_folder, f"captured_image_{img_counter}.jpg")
                if cv2.imwrite(filename, cropped_image):
                    img_counter += 1
                else:
                    print(f"Failed to save image '{filename}'")
            elif key == ord('q'):
                print("Exiting camera capture.")
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()


# Function 2: Segment images using the segmentation model
def _segment_eye_image(model, image_bgr, input_size, threshold):
    """Return ``image_bgr`` resized to ``input_size`` with pixels outside the predicted mask set to 0."""
    resized_bgr = cv2.resize(image_bgr, input_size)
    # The model is fed RGB values scaled to [0, 1]. The output reuses the
    # original uint8 BGR pixels, so no precision is lost in a float round-trip
    # and no RGB->BGR conversion is needed before saving.
    model_input = cv2.cvtColor(resized_bgr, cv2.COLOR_BGR2RGB) / 255.0
    mask = np.squeeze(predict_mask(model, model_input))
    # Equivalent to cv2.threshold(mask * 255, threshold * 255, 255, THRESH_BINARY)
    # without the rescaling. Assumes the model emits per-pixel scores on a
    # [0, 1] scale — TODO confirm.
    keep = np.expand_dims(mask > threshold, axis=-1)
    # uint8 * bool stays uint8: kept pixels unchanged, others zeroed.
    return resized_bgr * keep


def segmentation_model(model_path=r'Segmentation_Model_Path',
                       folder_path=r'Captured_Images_Path',
                       output_folder=r'Segmented_Images_Path',
                       input_size=(256, 256),
                       threshold=0.025):
    """Mask every captured eye image with the segmentation model's prediction.

    Each ``.jpg``/``.jpeg``/``.png`` file in ``folder_path`` is handled in
    sorted name order; extension matching is case-insensitive. The image is
    resized and run through the model, and every pixel whose mask score is not
    above ``threshold`` is zeroed. The result is written to ``output_folder``
    as ``segmented_<original name>``. Failures on individual images are printed
    and skipped, so one bad file does not abort the batch.

    Args:
        model_path: Path to the saved segmentation model.
        folder_path: Directory containing the captured eye images.
        output_folder: Directory for the masked images (created if missing).
        input_size: (width, height) the images are resized to for the model.
        threshold: Mask score above which a pixel is kept.
    """
    model = load_model(model_path)
    os.makedirs(output_folder, exist_ok=True)

    for image_name in sorted(os.listdir(folder_path)):
        if not image_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        image_path = os.path.join(folder_path, image_name)

        try:
            eye_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if eye_image is None:
                print(f"Failed to load image: {image_name}")
                continue

            segmented_bgr = _segment_eye_image(model, eye_image, input_size, threshold)

            save_path = os.path.join(output_folder, f"segmented_{image_name}")
            if not cv2.imwrite(save_path, segmented_bgr):
                print(f"Failed to save segmented image: {save_path}")
        except Exception as e:
            # Best-effort batch: report and move on to the next image.
            print(f"Error processing image {image_name}: {e}")


# Function 3: Classify segmented images using the classification model
def classification_model(model_path=r'Classification_Model_Path',
                         segmented_folder_path=r'Segmented_Images_Path',
                         img_size=(224, 224)):
    """Classify each segmented eye image and print the predicted anemia class.

    Each ``.jpg``/``.jpeg``/``.png`` file in ``segmented_folder_path`` is
    handled in sorted name order; extension matching is case-insensitive. This
    makes the printed predictions line up with the file order. Each image is
    resized to ``img_size``, scaled to [0, 1] and run through the model, and
    the arg-max class is printed together with its score. Failures on
    individual images are printed and skipped.

    Args:
        model_path: Path to the saved classification model.
        segmented_folder_path: Directory containing the segmented images.
        img_size: (width, height) the images are resized to for the model.

    Returns:
        dict mapping image file name -> (reported label, confidence) for every
        image that was classified successfully.
    """
    model = load_model(model_path)

    # Assumed to match the model's output index order — TODO confirm against training.
    class_labels = ['Mild Anemia', 'Moderate Anemia', 'No Anemia', 'Severe Anemia']

    results = {}
    for image_name in sorted(os.listdir(segmented_folder_path)):
        if not image_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        image_path = os.path.join(segmented_folder_path, image_name)

        try:
            eye_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if eye_image is None:
                print(f"Failed to load image: {image_name}")
                continue

            # NOTE(review): unlike segmentation_model, no BGR->RGB conversion is
            # applied here; confirm which channel order the classifier was trained on.
            eye_image = cv2.resize(eye_image, img_size) / 255.0
            prediction = model.predict(np.expand_dims(eye_image, axis=0))

            class_idx = int(np.argmax(prediction[0]))
            confidence = float(prediction[0][class_idx])
            class_label = class_labels[class_idx]
            # NOTE(review): 'Mild Anemia' is deliberately reported as 'No Anemia';
            # the printed confidence is still the Mild Anemia score. Confirm intent.
            if class_label == "Mild Anemia":
                class_label = "No Anemia"
            print(f"Predicted Class: {class_label}, Confidence: {confidence:.2f}")
            results[image_name] = (class_label, confidence)
        except Exception as e:
            # Best-effort batch: report and move on to the next image.
            print(f"Error classifying image {image_name}: {e}")

    return results


# Main function to call all three functions synchronously
def main():
    """Run the pipeline stages in order: capture, segmentation, classification.

    Each stage is announced before it runs; the capture stage also prints a
    confirmation once it returns.
    """
    stages = (
        ("Starting image capture...", opencv, "Image Captured"),
        ("Starting image segmentation...", segmentation_model, None),
        ("Starting image classification...", classification_model, None),
    )
    for start_message, run_stage, done_message in stages:
        print(start_message)
        run_stage()
        if done_message is not None:
            print(done_message)


# Run the main function
# Execute the full pipeline only when run as a script, not when imported.
if __name__ == '__main__':
    main()

Starting image capture...
Exiting camera capture.
Image Captured
Starting image segmentation...
Starting image classification...
Predicted Class: Mild Anemia, Confidence: 1.00
