# Evaluation script: compare nnU-Net blob detections with manually annotated points

In [None]:
import csv
import os
import time
from collections import deque

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
%matplotlib inline

In [None]:
def init_cnn(model_folder="../MODELS/MODEL_512_V3", fold=4, device_index=0,
             checkpoint_name='checkpoint_final.pth'):
    """Create an nnU-Net predictor on a CUDA device and load trained weights.

    Calling ``init_cnn()`` with no arguments reproduces the original setup
    (model ``../MODELS/MODEL_512_V3``, fold 4, ``cuda:0``, final checkpoint).

    Args:
        model_folder: Trained-model directory passed to
            ``initialize_from_trained_model_folder``. A relative path resolves
            against the current working directory.
        fold: The single fold whose checkpoint is loaded.
        device_index: CUDA device ordinal used for inference.
        checkpoint_name: Checkpoint file name loaded for that fold.

    Returns:
        The initialised ``nnUNetPredictor``.

    Raises:
        RuntimeError: If CUDA is not available. Inference is configured to run
            entirely on the GPU.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available; init_cnn needs GPU cuda:%d" % device_index
        )
    predictor = nnUNetPredictor(
        # Step of one full patch between sliding-window tiles, i.e. no tile
        # overlap (per nnU-Net docs; its default is 0.5).
        tile_step_size=1,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=torch.device('cuda', device_index),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=False,
    )
    # Log which GPU is used so timings can be related to the hardware.
    print(torch.cuda.get_device_name(device_index))
    predictor.initialize_from_trained_model_folder(
        model_folder, checkpoint_name=checkpoint_name, use_folds=(fold,)
    )
    return predictor

In [None]:
def cnn_result(image, predictor, image_size=(512, 512)):
    """Run the nnU-Net predictor on one single-channel image and return a mask.

    Args:
        image: 2-D (single-channel) image. It is scaled by 1/255, so uint8
            input is presumably expected.
        predictor: Initialised nnU-Net predictor (see ``init_cnn``).
        image_size: ``(width, height)`` in OpenCV ``dsize`` order that the
            image is resized to before inference. The default (512, 512)
            matches the original hard-coded size.

    Returns:
        ``uint8`` array of shape ``(height, width)`` holding the predicted
        label values multiplied by 255. For a binary model (labels 0/1, which
        is assumed; confirm for the model in use) this is a 0/255 mask.
    """
    width, height = image_size
    image = cv2.resize(image, (width, height)).astype(np.float32) / 255.0
    # One channel, one slice: (1, 1, rows, cols). Numpy shapes are
    # (rows, cols) = (height, width), the reverse of cv2's dsize order.
    cnn_input = image.reshape(1, 1, height, width)
    # A large dummy spacing on the singleton slice axis, presumably following
    # nnU-Net's convention for 2-D inputs.
    props = {'spacing': (999, 1, 1)}
    start_time = time.perf_counter()  # monotonic clock for measuring durations
    # Last positional argument True: per nnU-Net v2 this returns
    # (segmentation, probabilities); only the segmentation is kept.
    output = predictor.predict_single_npy_array(cnn_input, props, None, None, True)[0]
    print("Prediction: %s seconds" % (time.perf_counter() - start_time))
    return (output * 255).astype(np.uint8).reshape(height, width)

In [None]:
def display_results(image_display, cnn_display, manual_points, correct_detections, missed_detections, false_detections, contours, image_name, csv_path="detection_data.csv"):
    """Show the image, the CNN mask and an annotated overlay; log the counts.

    The overlay fills every CNN contour in red. It marks correctly detected
    manual points in green and missed points in yellow. The counts are printed
    and appended as one CSV row to *csv_path*, in this order:
    image_name, #manual, #correct, #missed, #false.

    Args:
        image_display: 2-D grayscale uint8 image (converted with COLOR_GRAY2RGB).
        cnn_display: Mask shown in the middle panel, e.g. from ``cnn_result``.
        manual_points: All manually annotated points (only counted here).
        correct_detections: (x, y) integer points drawn in green.
        missed_detections: (x, y) integer points drawn in yellow.
        false_detections: Blobs without a manual point (only counted here).
        contours: Contours as returned by ``cv2.findContours``.
        image_name: Written as the first CSV column.
        csv_path: File the result row is appended to (created if missing).
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # cvtColor allocates a new RGB array, so image_display is never modified.
    overlay = cv2.cvtColor(image_display, cv2.COLOR_GRAY2RGB)
    cv2.drawContours(overlay, contours, -1, (255, 0, 0), -1)  # thickness -1 = filled
    for point in correct_detections:
        cv2.circle(overlay, point, 4, (0, 255, 0), -1)
    for point in missed_detections:
        cv2.circle(overlay, point, 4, (255, 255, 0), -1)

    panels = (
        (image_display, "Original MRI image"),
        (cnn_display, "Prediction"),
        (overlay, "Manually annotated detection with CNN Contours"),
    )
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel, cmap='gray')  # cmap is ignored for the RGB overlay
        ax.set_title(title)
        ax.axis('off')

    print("Image:", image_name)
    print("Manual Points:", len(manual_points))
    print("Correct Detections:", len(correct_detections))
    print("Missed Detections:", len(missed_detections))
    print("False Detections:", len(false_detections))

    # csv.writer quotes fields when needed, so image names containing commas
    # or quotes can no longer break the row. lineterminator="\n" in text mode
    # keeps the line endings the previous hand-formatted rows used.
    with open(csv_path, "a") as file:
        csv.writer(file, lineterminator="\n").writerow([
            image_name,
            len(manual_points),
            len(correct_detections),
            len(missed_detections),
            len(false_detections),
        ])

    plt.show()

In [None]:
def get_manual_points(image, display_scale=1.5):
    """Collect point annotations by left-clicking on an enlarged copy of *image*.

    Each left click is mapped back to *image*'s pixel grid and marked on
    screen, so already-annotated spots stay visible. Press Enter to finish.

    Args:
        image: Image to annotate. It is never modified; the display uses a
            resized copy.
        display_scale: Factor by which the image is enlarged for display.

    Returns:
        ``deque`` of ``(x, y)`` integer tuples in *image* coordinates, in
        click order.
    """
    window_name = "Image"
    height, width = image.shape[:2]
    canvas = cv2.resize(image, (int(width * display_scale), int(height * display_scale)))
    if canvas.ndim == 2:
        # Convert to colour so the click markers stand out on a grayscale image.
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)
    manual_points = deque()

    def mouse_callback(event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            # Map display coordinates back to the original resolution.
            manual_points.append((int(x // display_scale), int(y // display_scale)))
            # Draw the marker into the canvas that the loop below re-shows.
            cv2.circle(canvas, (x, y), 4, (0, 255, 0), -1)

    cv2.namedWindow(window_name)
    cv2.setMouseCallback(window_name, mouse_callback)
    try:
        while True:
            cv2.imshow(window_name, canvas)
            key = cv2.waitKey(20) & 0xFF  # 20 ms: responsive without spinning the CPU
            # Enter arrives as CR (13) on most builds, LF (10) on some.
            if key in (ord('\r'), ord('\n')):
                break
    finally:
        # Close the window even if the loop is interrupted (e.g. kernel interrupt).
        cv2.destroyAllWindows()
    return manual_points

In [None]:
def process_cnn_output(cnn_output, manual_points):
    """Match manually annotated points against the blobs predicted by the CNN.

    Blobs are the outer contours of *cnn_output*. A point counts as detected
    when it lies inside or on the edge of at least one contour. Each point is
    recorded exactly once, even if several contours contain it.

    Args:
        cnn_output: Single-channel 8-bit mask as accepted by ``cv2.findContours``.
        manual_points: Iterable of ``(x, y)`` pixel coordinates.

    Returns:
        Tuple ``(correct_detections, missed_detections, false_detections,
        contours)``:

        * ``correct_detections``: list of points inside some contour.
        * ``missed_detections``: list of points outside every contour.
        * ``false_detections``: set of indices into ``contours`` for blobs
          that contain no manual point.
        * ``contours``: the outer contours found in *cnn_output*.
    """
    contours, _ = cv2.findContours(cnn_output, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    correct_detections = []
    missed_detections = []
    matched_contours = set()
    for point in manual_points:
        # pointPolygonTest wants a float point; converting also accepts numpy ints.
        probe = (float(point[0]), float(point[1]))
        # measureDist=False returns +1 inside, 0 on the edge, -1 outside.
        hits = [
            index
            for index, contour in enumerate(contours)
            if cv2.pointPolygonTest(contour, probe, False) >= 0
        ]
        if hits:
            correct_detections.append(point)
        else:
            missed_detections.append(point)
        matched_contours.update(hits)

    # Blobs that contain no manual point are false positives.
    false_detections = set(range(len(contours))) - matched_contours
    return correct_detections, missed_detections, false_detections, contours

In [None]:
if __name__ == '__main__':
    # Size every image is resized to, as (width, height) in cv2 dsize order.
    IMAGE_SIZE = (512, 512)
    # Folder with the test images; each entry is tried as an image.
    image_dir = "C:/Users/O/Desktop/Master Thesis/MRI Data/IMAGES_FOR_CNN_TEST"

    predictor = init_cnn()

    # Sorted so images are processed, and CSV rows written, in a reproducible order.
    for image_name in sorted(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, image_name)
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # imread returns None (instead of raising) for non-image entries
            # such as Thumbs.db or sub-directories; resize would crash on it.
            print("Skipping unreadable file:", image_path)
            continue

        image = cv2.resize(image, IMAGE_SIZE)
        manual_points = get_manual_points(image)
        prediction = cnn_result(image, predictor)
        correct_detections, missed_detections, false_detections, contours = process_cnn_output(prediction, manual_points)
        display_results(image, prediction, manual_points, correct_detections, missed_detections, false_detections, contours, image_name)