# Imports

In [1]:
from general_funcs import *
from sklearn.cluster import KMeans
from skimage.color import rgb2gray
from skimage.io import imread
from sklearn.metrics import precision_recall_fscore_support
import pandas as pd

# K Means

## Functions

In [6]:
def evaluate_k_means_segmentation(gtm_image, binary_mask):
    """
    Computes precision, recall, and F1 score for binary segmentation masks.

    Parameters:
        gtm_image (np.ndarray): Ground truth binary mask (2D array).
        binary_mask (np.ndarray): Predicted binary mask (2D array).

    Returns:
        dict: Dictionary containing precision, recall, and F1 score.
    """
    # Flatten the masks
    gtm_image_flat = gtm_image.flatten()

    binary_mask_flat = binary_mask.flatten()

    # Ensure binary values (0 and 1 only)
    gtm_image_flat = (gtm_image_flat > 0).astype(int)
    binary_mask_flat = (binary_mask_flat > 0).astype(int)

    # Calculate precision, recall, and F1 score
    precision, recall, f1, _ = precision_recall_fscore_support(gtm_image_flat, binary_mask_flat, average='binary')

    dice_score = calculate_dice_score(gtm_image, binary_mask)
    dice_score_class_1 = calculate_dice_score_for_1_class(gtm_image, binary_mask)
    jaccard_score = calculate_jaccard_index(gtm_image, binary_mask)
    accuracy = calculate_accuracy(gtm_image, binary_mask)
    weighted_accuracy = calculate_weighted_accuracy(gtm_image, binary_mask)
    sse = calculate_sse(gtm_image, binary_mask)
    return {"precision": precision,
            "recall": recall,
            "f1_score": f1,
            "dice_score": dice_score,
            "dice_score_class_1": dice_score_class_1,
            "jaccard_score": jaccard_score,
            "accuracy": accuracy,
            "weighted_accuracy": weighted_accuracy,
            "sse" : sse
            }


def calculate_sse(ground_truth, prediction):
    """
    Calculate the Sum of Squared Errors (SSE) between the ground truth mask and the predicted mask.

    Parameters:
    - ground_truth: 2D numpy array (8-bit) representing the ground truth mask.
    - prediction: 2D numpy array (8-bit) representing the predicted mask.

    Returns:
    - sse: A float value representing the sum of squared errors.
    """
    # Ensure the inputs are numpy arrays
    ground_truth = np.array(ground_truth)
    prediction = np.array(prediction)
    
    # Check that both images have the same shape
    if ground_truth.shape != prediction.shape:
        raise ValueError("Ground truth and prediction must have the same shape")
    
    # Return the SSE
    return np.sum((ground_truth - prediction) ** 2)


def print_scores(accuracies):
    avg_metrics = {f"avg_{metric}" : 0 for metric in accuracies[0].keys()}
    num_of_images = len(accuracies)
    # Calculate the total scores for each metric
    for evaluation_dict in accuracies:
        for metric, score in evaluation_dict.items():
                avg_metrics[f"avg_{metric}"] += score
    
    # Calculate the average score for each metric
    for metric, total_score in avg_metrics.items():
         avg_score = total_score/num_of_images
         # convert scores to precentages id needed
         avg_metrics[metric] = avg_score * 100 if metric != "avg_sse" else avg_score
    
    # Convert the dictionary to a pandas DataFrame for better formatting
    df = pd.DataFrame(list(avg_metrics.items()), columns=['Metric', 'Value'])

    # Convert the dictionary to a pandas DataFrame for better formatting
    df = pd.DataFrame(list(avg_metrics.items()), columns=['Metric', 'Value'])

    # Round the values to 2 decimal places for clarity
    df['Value'] = df['Value'].round(2)

    # Use style.set_table_styles to left-align the "Metric" column
    styled_df = df.style.set_table_styles(
    [{'selector': 'td:nth-child(2)', 'props': [('text-align', 'left')]},  # Left-align the first column
     {'selector': 'td', 'props': [('text-align', 'right')]},  # left-align all other columns
     {'selector': 'th', 'props': [('text-align', 'left')]}  # Center-align the headers
    ])

    # Apply the formatting to the "Value" column (display with 2 decimals)
    styled_df = styled_df.format({'Value': '{:.2f}'})

    # Display the styled DataFrame
    display(styled_df)
    


## Testing

In [3]:
gtmasks_dirs = list(filter(lambda dname: dname.endswith('gtmasks'),os.listdir('Img')))
gtmasks_dirs = sorted(gtmasks_dirs, key=lambda dname: int(dname.split('_')[3]))
gtmasks_dict = {dname: os.listdir(f'Img/{dname}') for dname in gtmasks_dirs}
gtmasks_dict.keys()

dict_keys(['wire_images_video_1_gtmasks', 'wire_images_video_2_gtmasks', 'wire_images_video_3_gtmasks', 'wire_images_video_4_gtmasks', 'wire_images_video_5_gtmasks', 'wire_images_video_6_gtmasks', 'wire_images_video_7_gtmasks', 'wire_images_video_8_gtmasks'])

In [4]:
image_dirs = list(filter(lambda dname: dname[-1].isdigit(), os.listdir('Img')))
image_dirs = sorted(image_dirs, key=lambda dname: int(dname[-1]))
images_dict = {dname: os.listdir(f'Img/{dname}') for dname in image_dirs}
images_dict

{'wire_images_video_1': ['wire1_ultrasound_watertank.png',
  'wire2_ultrasound_watertank.png',
  'wire3_ultrasound_watertank.png',
  'wire4_ultrasound_watertank.png',
  'wire5_ultrasound_watertank.png',
  'wire6_ultrasound_watertank.png'],
 'wire_images_video_2': ['wire10_ultrasound_watertank.png',
  'wire11_ultrasound_watertank.png',
  'wire12_ultrasound_watertank.png',
  'wire13_ultrasound_watertank.png',
  'wire14_ultrasound_watertank.png',
  'wire15_ultrasound_watertank.png',
  'wire16_ultrasound_watertank.png',
  'wire17_ultrasound_watertank.png',
  'wire18_ultrasound_watertank.png',
  'wire19_ultrasound_watertank.png',
  'wire20_ultrasound_watertank.png',
  'wire21_ultrasound_watertank.png',
  'wire22_ultrasound_watertank.png',
  'wire23_ultrasound_watertank.png',
  'wire7_ultrasound_watertank.png'],
 'wire_images_video_3': ['wire24_ultrasound_watertank.png',
  'wire25_ultrasound_watertank.png',
  'wire26_ultrasound_watertank.png',
  'wire27_ultrasound_watertank.png',
  'wire28_u

In [7]:
accuracies = []
for image_dir, gtmask_dir in zip(images_dict, gtmasks_dict):
    for image_name, gtmask_name in zip(images_dict[image_dir], gtmasks_dict[gtmask_dir]):
        im_path = get_image_path(image_dir, image_name)
        gtm_path = get_image_path(gtmask_dir, gtmask_name)
        image = rgb2gray(imread(im_path))
        # display_image(image, im_path)
        # Filters the top of the image and converts it to an 8-bit image
        filtered_image = create_8_bit_image(filter_image_top(image))
        # display_image(filtered_image, im_path)
        gtm_image = imread(gtm_path, as_gray=True)

        # Flatten the image for K-Means clustering
        flat_image = filtered_image.reshape((-1, 1))

        # Apply K-Means clustering with k=2
        kmeans = KMeans(n_clusters=2, random_state=42) 
        kmeans.fit(flat_image)

        # Get cluster labels for each pixel
        labels = kmeans.labels_

        # Create a binary mask
        mask = labels.reshape(filtered_image.shape) 

        # Determine which cluster corresponds to the wire (heuristic)
        # This heuristic assumes the wire pixels have a higher average intensity
        cluster_means = kmeans.cluster_centers_
        wire_cluster = np.argmax(cluster_means) 

        # Create the final binary mask
        binary_mask = np.zeros_like(mask)
        binary_mask[mask == wire_cluster] = 255
        binary_mask = create_8_bit_image(binary_mask)
        evaluation_dict = evaluate_k_means_segmentation(gtm_image, binary_mask)
        accuracies.append(evaluation_dict)

print_scores(accuracies)

Unnamed: 0,Metric,Value
0,avg_precision,92.3
1,avg_recall,53.28
2,avg_f1_score,66.59
3,avg_dice_score,99.22
4,avg_dice_score_class_1,66.59
5,avg_jaccard_score,50.94
6,avg_accuracy,98.45
7,avg_weighted_accuracy,2.23
8,avg_sse,40330.96
