In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import itertools

# Task: Find the Most Similar Image Groups (one H&E, one Melan, one SOX10 slide) for a Given Patient

## Function 1: For each image, extract outline of tissue shape 

### Motivation: Because we want to compare images from different stains, we cannot use color or saturation to score similarity. The shape of the tissue, however, can be used to assess similarity across stains.

### Outcome: Return two dictionaries keyed by file name: the raw images, and black-and-white masks that capture the shape of the tissue. This format can easily be assessed for similarity of shape.

In [None]:
def extract_images(folder_path, patient):
    '''
    Load one patient's slide images and build a coarse shape mask for each.

    folder_path: a path to a folder of all patients
    patient: the patient id (name of a sub-folder of folder_path)

    Returns (raw_images, masked_images): two dicts keyed by file name. The
    first holds the image as loaded by cv2.imread; the second a uint8 mask
    (values 0/255) with the same height and width.
    Returns (None, None) when the folder has fewer than three entries or when
    no file name contains one of 'h&e', 'melan' or 'sox10'.
    Files OpenCV cannot read are reported and skipped.
    '''
    raw_images = {}
    masked_images = {}

    path = os.path.join(folder_path, patient)
    images = os.listdir(path)

    print(f"Total images: {len(images)}")

    # Check every file name on its own: searching one concatenated string
    # could match a stain keyword that straddles two adjacent names.
    has_all_stains = all(
        any(stain in name for name in images)
        for stain in ('h&e', 'melan', 'sox10')
    )
    if len(images) < 3 or not has_all_stains:
        return (None, None)

    for image_name in images:
        image_path = os.path.join(path, image_name)
        image = cv2.imread(image_path)

        if image is None:
            print(f"Error loading image: {image_path}")
            continue

        raw_images[image_name] = image

        # Grayscale, then histogram equalisation to enhance contrast.
        image_gray = cv2.equalizeHist(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))

        # Blur to reduce noise before thresholding.
        blurred = cv2.GaussianBlur(image_gray, (5, 5), 0)

        # Adaptive-threshold parameters (blockSize, C), tuned per stain.
        # 'melan' is checked before 'sox10'; anything else (H&E) gets the default.
        if 'melan' in image_name:
            block_size, offset = 11, 1
        elif 'sox10' in image_name:
            block_size, offset = 13, 1
        else:
            block_size, offset = 13, 1.8

        # Separate foreground structure from the background: pixels brighter
        # than their local Gaussian-weighted mean minus `offset` become 255.
        adaptive_thresholding = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                                      cv2.THRESH_BINARY, block_size, offset)

        # Average over a large 111x111 window so each pixel holds the local
        # mean of the threshold output (uint8, 0-255).
        blurred_thresh = cv2.boxFilter(adaptive_thresholding, -1, (111, 111))

        # Keep regions that are mostly white (local mean > 200) in the
        # threshold output, stored as a 0/255 uint8 image so it saves correctly.
        lower_white = 200
        masked_images[image_name] = (blurred_thresh > lower_white).astype(np.uint8) * 255

    return raw_images, masked_images

## Function 2: Calculates the similarity between three images

### Motivation: Create a helper function that takes three images as input and compares the shapes of their largest contours pairwise using OpenCV's `matchShapes`.

### Outcome: Returns a combined similarity score for the three images; the lower the score, the more similar the images are.

In [None]:
def calculate_shape_similarity(image1, image2, image3):
    """Score how alike the shapes in three single-channel mask images are.

    Each image is inverted (``cv2.bitwise_not``) before contour detection;
    presumably the masks from ``extract_images`` have the tissue dark on a white
    background -- TODO confirm. For each image the two largest contours (by
    area) are kept, and every pair of images is compared with OpenCV's
    ``matchShapes`` (method I1). That function returns a distance, so lower
    means more similar.

    Returns:
        The mean largest-contour distance over the three image pairs. If every
        image has at least two contours and the three second-largest-contour
        distances sum to at most 0.5, the result is lowered by 0.1 as a bonus.
        Returns ``None`` if any image yields no contours at all.
    """
    contour_sets = []
    for image in (image1, image2, image3):
        contours, _ = cv2.findContours(cv2.bitwise_not(image), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None
        # Keep the two largest contours for this image.
        contour_sets.append(sorted(contours, key=cv2.contourArea, reverse=True)[:2])

    largest_distances = []
    second_distances = []
    # Pairs in the order (1, 2), (1, 3), (2, 3).
    for first, other in itertools.combinations(contour_sets, 2):
        largest_distances.append(
            cv2.matchShapes(first[0], other[0], cv2.CONTOURS_MATCH_I1, 0.0))
        if len(first) == 2 and len(other) == 2:
            second_distances.append(
                cv2.matchShapes(first[1], other[1], cv2.CONTOURS_MATCH_I1, 0.0))

    mean_largest = sum(largest_distances) / 3

    # All three second-contour distances must exist. Check the count, not
    # truthiness: a perfect match has distance 0.0, which is falsy but valid.
    if len(second_distances) == 3 and sum(second_distances) <= 0.5:
        return mean_largest - 0.1

    return mean_largest


## Function 3: Finds the most similar image groups for a given patient

### Motivation: Calculate similarity scores for every combination of three images of a patient that includes all three stains (H&E, Melan, SOX10). Find the most similar group, and check whether it falls within a minimum similarity threshold. If a group that meets the threshold is found, remove its images and repeat on the remaining images (with a slightly stricter threshold each time) until no group meets the threshold.

### Outcome: Return groups of similar images for a given patient.

In [None]:
def calculate_all_similarities(raw_images, image_dict, matches = None):
    """Greedily pick the most similar image triplets that cover all three stains.

    Args:
        raw_images: file name -> original image. Entries of accepted groups
            are deleted from this dict.
        image_dict: file name -> shape mask (as built by ``extract_images``).
            Entries of accepted groups are deleted from this dict.
        matches: groups already accepted (lets a caller resume); defaults to a
            new list. Its length sets the starting threshold.

    Every 3-image combination whose names include 'h&e', 'melan' and 'sox10'
    is scored once with ``calculate_shape_similarity``; lower means more
    similar. Then the lowest-scoring group whose images are all still unmatched
    is accepted, provided its score is below ``0.6 * 0.9 ** len(matches)``.
    The threshold therefore tightens by 10% after each accepted group. This
    repeats until no group qualifies.

    Returns:
        The list of accepted groups, each a dict (file name -> raw image), or
        ``None`` if no group was accepted.
    """
    if not matches:
        matches = []

    stains = ('h&e', 'melan', 'sox10')

    # Scores depend only on the three images involved, so compute them once
    # instead of recomputing every surviving triplet after each match.
    scores = {}
    for (key1, value1), (key2, value2), (key3, value3) in itertools.combinations(image_dict.items(), 3):
        keys = (key1, key2, key3)
        # Each stain must appear in one of the names (checked per name, so a
        # keyword cannot be formed across two concatenated names).
        if all(any(stain in key for key in keys) for stain in stains):
            similarity = calculate_shape_similarity(value1, value2, value3)
            if similarity is not None:
                scores[keys] = similarity

    while True:
        threshold = 0.6 * 0.9 ** len(matches)
        # Dict order is preserved, so ties break exactly as before
        # (first group in combination order wins).
        candidates = {
            group: score for group, score in scores.items()
            if score < threshold and all(name in image_dict for name in group)
        }
        if not candidates:
            break

        # Find the corresponding image group
        min_group = min(candidates, key=candidates.get)
        min_similarity = candidates[min_group]

        print(f"Minimum similarity score: {min_similarity:.4f} between images: {min_group}")

        matches.append({name: raw_images[name] for name in min_group})

        for img in min_group:
            del image_dict[img]
            del raw_images[img]

    if len(matches) > 0:
        return matches

    print("No similar groups found - using alternative algorithm")
    return None

## Noah algorithm (modified by Aryaman)

In [None]:
def extract_main_contour(image):
    """Return the largest external contour found in a 3-channel BGR image.

    The image is smoothed and inverse adaptive-thresholded, so pixels darker
    than their neighbourhood turn white. A morphological closing then joins
    nearby fragments before external contours are traced. Raises
    ``ValueError`` (from ``max``) when no contour is found.
    """
    smoothed = cv2.GaussianBlur(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), (5, 5), 0)

    # Inverse adaptive threshold copes with uneven stain intensity across the slide.
    binary = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )

    # Closing with a 15x15 square kernel bridges small gaps between fragments.
    closing_kernel = np.ones((15, 15), np.uint8)
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, closing_kernel)

    outlines, _hierarchy = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return max(outlines, key=cv2.contourArea)


# compute shape similarity between two contours
def shape_similarity(image1, image2):
    # extract main contours
    contour1 = extract_main_contour(image1)
    contour2 = extract_main_contour(image2)
    # matchShapes returns distance between shapes, take the inverse for similarity
    return 1/cv2.matchShapes(contour1, contour2, cv2.CONTOURS_MATCH_I1, 0.0) if cv2.matchShapes(contour1, contour2, cv2.CONTOURS_MATCH_I1, 0.0) > 0 else np.inf

def read(folder_path):
    """Load every readable image file directly inside *folder_path*.

    Returns a dict mapping file name -> image array from ``cv2.imread``.
    Sub-directories are ignored. Files OpenCV cannot decode are reported and
    skipped, as in ``extract_images``, so callers never receive ``None``
    images.
    """
    images = {}

    for name in os.listdir(folder_path):
        path = os.path.join(folder_path, name)
        if not os.path.isfile(path):
            continue
        image = cv2.imread(path)
        if image is None:
            print(f"Error loading image: {path}")
            continue
        images[name] = image

    return images


def find_groups(folder):
    """Split the images in *folder* into per-stain dicts based on file name.

    Returns ``(he_images, melan_images, sox10_images)``, each mapping file name
    to the image returned by ``read``. A name is assigned to the first keyword
    it contains, checked in the order 'h&e', 'melan', 'sox10'. Names containing
    none of them are dropped.
    """
    he_images, melan_images, sox10_images = {}, {}, {}
    buckets = (('h&e', he_images), ('melan', melan_images), ('sox10', sox10_images))

    for name, img in read(folder).items():
        for marker, bucket in buckets:
            if marker in name:
                bucket[name] = img
                break

    return he_images, melan_images, sox10_images

# Score every triplet made of one image from each of the three stain sets
def find_best_groups(he_images, melan_images, sox10_images):
    """Score every (H&E, Melan, SOX10) triplet, best first.

    A triplet's score is the sum of its three pairwise ``shape_similarity``
    values; higher means more similar. Returns a list of
    ``(he_name, melan_name, sox10_name, score)`` tuples sorted by descending
    score, or ``[]`` if any stain set is empty.
    """
    if not (he_images and melan_images and sox10_images):
        return []

    # A pairwise score depends on only two images, so compute each pair once
    # (O(n^2) comparisons) instead of once per triplet (O(n^3)).
    he_melan = {(h, m): shape_similarity(h_img, m_img)
                for h, h_img in he_images.items()
                for m, m_img in melan_images.items()}
    he_sox10 = {(h, s): shape_similarity(h_img, s_img)
                for h, h_img in he_images.items()
                for s, s_img in sox10_images.items()}
    melan_sox10 = {(m, s): shape_similarity(m_img, s_img)
                   for m, m_img in melan_images.items()
                   for s, s_img in sox10_images.items()}

    # Same iteration and summation order as a direct triple loop.
    best_groups = [
        (h, m, s, he_melan[h, m] + he_sox10[h, s] + melan_sox10[m, s])
        for h in he_images
        for m in melan_images
        for s in sox10_images
    ]
    best_groups.sort(key=lambda group: group[3], reverse=True)
    return best_groups

# Now remove all but the best groups
# The way I will do this is to look at the 1st group, and remove all other groups that have the same image
def remove_duplicates(groups):
    new_groups = []
    used_images = set()
    for i in range(len(groups)):
        if groups[i][0] in used_images or groups[i][1] in used_images or groups[i][2] in used_images:
            continue
        used_images.add(groups[i][0])
        used_images.add(groups[i][1])
        used_images.add(groups[i][2])
        new_groups.append(groups[i])

    return new_groups

def make_new_image_groups(folder):
    """Run the fallback matcher on one patient folder.

    Splits the folder's images by stain, scores every cross-stain triplet, and
    returns the non-overlapping ``(he, melan, sox10, score)`` groups, best first.
    """
    stain_sets = find_groups(folder)
    return remove_duplicates(find_best_groups(*stain_sets))

## Matching slices

#### First, use Faith's algorithm. If it does not identify any matches, then try Noah's algorithm.

In [None]:
def _save_match(match_dir, named_images):
    """Plot a group of images side by side and save each one as a JPEG in *match_dir*."""
    os.makedirs(match_dir, exist_ok=True)

    _, axs = plt.subplots(1, len(named_images), squeeze=False)

    for ax, (name, img) in zip(axs[0], named_images.items()):
        # cv2 loads images as BGR; matplotlib expects RGB.
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        # Swap only the extension, so names that happen to contain 'tif'
        # elsewhere (or end in '.tiff', '.png', ...) are not mangled.
        jpg_name = os.path.splitext(name)[0] + '.jpg'
        cv2.imwrite(os.path.join(match_dir, jpg_name),
                    img,
                    [int(cv2.IMWRITE_JPEG_QUALITY), 50])


def process_all_patients(folder_path, output_root='matches'):
    """Match slices for every patient sub-folder of *folder_path* and save the results.

    For each patient, the mask-based matcher (``extract_images`` +
    ``calculate_all_similarities``) runs first. If it finds no groups, the
    fallback (``make_new_image_groups``) runs instead, and only groups scoring
    above 0.1 are kept. Each accepted group is plotted and written as JPEG
    (quality 50) to ``<output_root>/<patient>/match<k>/``.

    output_root: directory that receives the results. Defaults to 'matches',
        relative to the current working directory.
    """
    for patient in os.listdir(folder_path):
        patient_path = os.path.join(folder_path, patient)
        if not os.path.isdir(patient_path):
            continue

        raw_images, image_dict = extract_images(folder_path, patient)

        if image_dict:
            similarity_results = calculate_all_similarities(raw_images, image_dict)

            output_dir = os.path.join(output_root, patient)

            if similarity_results:
                for match_no, match in enumerate(similarity_results, 1):
                    _save_match(os.path.join(output_dir, f'match{match_no}'), match)
            else:
                groups = make_new_image_groups(patient_path)  # Noah's algorithm

                for match_no, (*names, score) in enumerate(groups, 1):
                    # similarity threshold
                    if score > 0.1:
                        _save_match(os.path.join(output_dir, f'match{match_no}'),
                                    {name: raw_images[name] for name in names})

            # Show the figures from either branch (the fallback previously never did).
            plt.show()

        print(f"Processed {patient}")
        print("------------------------------------------------")


# Example usage - change folder as needed.
# folder_path should contain one sub-folder per patient (non-directories are
# skipped); results are written under 'matches/<patient>/' relative to the
# current working directory.
folder_path = 'processed_images'

process_all_patients(folder_path)