# Cell detection in single-cell electron microscopy

## Import libraries

In [None]:
import numpy as np
import cv2
import os
import multiprocessing as mp

## Similarity-based Search

In [None]:
# new version
import math


class Search(object):
    """Locate cell patches by comparing patch encodings with query encodings.

    Rows of the encodings and queries files are ``;``-separated. The first
    three columns are used as ``(imageIndex, row, col)`` of a square patch of
    ``PATCH_SIZE`` pixels; the remaining columns are the feature vector that
    is compared with ``MSE``. Images are read from ``images/``, ground-truth
    masks (pixel value 1 = foreground) from ``labels/`` under the same file
    name, and annotated images are written to ``results/``.
    """

    # Side length in pixels of the square patch described by each encoding.
    PATCH_SIZE = 100
    # Stored encodings scored per batch; bounds the size of the
    # (queries x batch x features) difference array.
    CHUNK_SIZE = 1000
    # Number of highest per-query similarities averaged into a patch's score.
    TOP_K = 5
    # Patches scoring at or above this quantile become candidate matches.
    QUANTILE = 0.9
    # Leading images (in directory-listing order) excluded from the Jaccard
    # statistics and from the written result images.
    NUM_SKIPPED = 20

    @staticmethod
    def MSE(encoding1, encoding2):
        """Return the mean squared difference of two arrays along axis 2.

        The arguments only need to broadcast against each other, e.g.
        ``(Q, 1, D)`` against ``(1, E, D)`` yields a ``(Q, E)`` matrix.
        """
        return np.mean((encoding1 - encoding2) ** 2, axis=2)

    @staticmethod
    def _loadLabel(image):
        """Read ``labels/<image>`` and return a boolean mask of pixels equal to 1.

        Raises:
            FileNotFoundError: if OpenCV cannot read the label file.
        """
        label = cv2.imread('labels/' + image, cv2.IMREAD_GRAYSCALE)
        if label is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError('cannot read label image labels/' + image)
        return label == 1

    @staticmethod
    def _patchHitsLabel(patch, label, size=PATCH_SIZE):
        """Return True if the ``size`` x ``size`` patch at ``(row, col)`` touches foreground in ``label``."""
        row, col = int(patch[0]), int(patch[1])
        return bool(np.any(label[row:row + size, col:col + size]))

    @staticmethod
    def _jaccard(mask, label):
        """Return ``|mask & label| / |mask | label|``; any non-zero mask value counts as foreground.

        An empty union yields NaN (0 / 0), as before.
        """
        mask = np.asarray(mask, dtype=bool)
        return np.sum(mask & label) / np.sum(mask | label)

    @staticmethod
    def jaccardIndex(mask, image):
        """Return the Jaccard index between ``mask`` and the label mask of ``image``."""
        return Search._jaccard(mask, Search._loadLabel(image))

    @staticmethod
    def isCorrect(patch, image):
        """Return True if the patch at ``(row, col)`` overlaps the labelled foreground of ``image``."""
        return Search._patchHitsLabel(patch, Search._loadLabel(image))

    def __init__(self, encodingsFile):
        """Load encodings from ``encodingsFile`` (``;``-separated), keeping every 5th row."""
        # ndmin=2 keeps a single-row file two-dimensional so row slicing works.
        self.encodings = np.loadtxt(encodingsFile, delimiter=';', ndmin=2)[::5]

    def _similarities(self, queries):
        """Return one score per stored encoding.

        The score is the mean of the ``TOP_K`` highest similarities
        ``1 / exp(MSE)`` between that encoding and the query encodings.
        Encodings are processed in batches of ``CHUNK_SIZE`` rows.
        """
        queryFeatures = queries[:, np.newaxis, 3:]  # (Q, 1, D)
        scores = []
        for start in range(0, self.encodings.shape[0], self.CHUNK_SIZE):
            chunk = self.encodings[np.newaxis, start:start + self.CHUNK_SIZE, 3:]  # (1, E, D)
            distances = self.MSE(queryFeatures, chunk)  # (Q, E)
            similarity = 1 / np.exp(distances)
            # Sort along the query axis and average the TOP_K best queries per encoding.
            scores.append(np.mean(np.sort(similarity, axis=0)[-self.TOP_K:], axis=0))
        return np.concatenate(scores) if scores else np.empty(0)

    @classmethod
    def _suppressOverlaps(cls, matches, similarities):
        """Greedy non-maximum suppression over ``(imageIndex, row, col)`` rows.

        Repeatedly keeps the highest-scoring patch and discards every patch in
        the same image whose row and column both lie within ``PATCH_SIZE`` of
        it (the kept patch included). Returns the kept rows in selection order.
        """
        kept = []
        while matches.shape[0] > 0:
            best = np.argmax(similarities)
            overlapping = (
                (matches[:, 0] == matches[best, 0])
                & (np.abs(matches[:, 1] - matches[best, 1]) <= cls.PATCH_SIZE)
                & (np.abs(matches[:, 2] - matches[best, 2]) <= cls.PATCH_SIZE)
            )
            kept.append(matches[best])
            matches = matches[~overlapping]
            similarities = similarities[~overlapping]
        return kept

    @staticmethod
    def _report(jaccard, TP, FP):
        """Print Jaccard statistics and the fraction of matches that are true / false positives."""
        for title, statistic in (
            ("Median Jaccard Index:", np.median),
            ("Minimum Jaccard Index:", np.min),
            ("Maximum Jaccard Index:", np.max),
        ):
            print(title)
            # np.min / np.max raise on an empty sequence; report NaN instead.
            print(statistic(jaccard) if jaccard else float('nan'))
        total = TP + FP
        # These are fractions of all matches (precision and its complement), not counts.
        print("True Positives:")
        print(TP / total if total else float('nan'))
        print("False Positives:")
        print(FP / total if total else float('nan'))

    def query(self, queriesFile):
        """Search for patches similar to the queries in ``queriesFile``.

        Scores every stored encoding, keeps those at or above the ``QUANTILE``
        of all scores, removes overlapping detections, draws the survivors on
        their images (green if the patch touches labelled foreground, red
        otherwise), writes annotated images to ``results/`` and prints
        Jaccard and precision statistics.
        """
        queries = np.loadtxt(queriesFile, delimiter=';', ndmin=2)
        similarities = self._similarities(queries)
        print(similarities.shape)

        if similarities.size:
            threshold = np.quantile(similarities, self.QUANTILE)
            selected = similarities >= threshold
            matches = self._suppressOverlaps(
                self.encodings[selected, :3].astype(int), similarities[selected]
            )
        else:
            matches = []

        size = self.PATCH_SIZE
        # cv2.imwrite fails silently when the target directory is missing.
        os.makedirs('results', exist_ok=True)
        # NOTE(review): match[0] is compared with each file's position in
        # os.listdir(), whose order is filesystem-dependent; confirm it matches
        # the ordering used when the encodings were generated.
        images = os.listdir('images/')
        jaccard = []
        TP = FP = 0
        for imgIndex, image in enumerate(images):
            img = cv2.imread('images/' + image)
            label = self._loadLabel(image)  # read once per image, not once per match
            mask = np.zeros(img.shape[:2], np.uint8)
            for match in matches:
                if match[0] != imgIndex:
                    continue
                row, col = int(match[1]), int(match[2])
                if self._patchHitsLabel((row, col), label, size):
                    TP += 1
                    color = (0, 255, 0)
                else:
                    FP += 1
                    color = (0, 0, 255)
                # OpenCV points are (x, y) = (col, row); plain ints avoid cv2 type errors.
                cv2.rectangle(img, (col, row), (col + size, row + size), color, 2)
                mask[row:row + size, col:col + size] = 1
            jaccard.append(self._jaccard(mask, label))
            # Same cut-off as the Jaccard statistics, so every evaluated image is written.
            if imgIndex >= self.NUM_SKIPPED:
                cv2.imwrite('results/' + image, img)

        # NOTE(review): TP/FP are counted over all images, including the skipped
        # ones, while the Jaccard statistics exclude them — confirm this is intended.
        self._report(jaccard[self.NUM_SKIPPED:], TP, FP)

In [None]:
# Similarity search: score the encodings in encodings_ae.csv against queries_ae.csv.
search = Search(encodingsFile="encodings_ae.csv")
search.query(queriesFile="queries_ae.csv")

## Clustering

In [None]:
from sklearn.cluster import AgglomerativeClustering

In [None]:
class Cluster(object):
    """Locate cell patches by clustering stored encodings together with queries.

    Rows of the encodings and queries files are ``;``-separated. The first
    three columns are used as ``(imageIndex, row, col)`` of a square patch of
    ``PATCH_SIZE`` pixels; the remaining columns are the features that are
    clustered. A stored patch is reported when it falls in the same cluster
    as at least one query. Images are read from ``images/``, ground-truth
    masks (pixel value 1 = foreground) from ``labels/`` under the same file
    name, and annotated images are written to ``results2/``.
    """

    # Side length in pixels of the square patch described by each encoding.
    PATCH_SIZE = 100
    # Linkage distance above which AgglomerativeClustering stops merging clusters.
    DISTANCE_THRESHOLD = 0.09
    # Leading images (in directory-listing order) excluded from the Jaccard statistics.
    NUM_SKIPPED = 20

    @staticmethod
    def MSE(encoding1, encoding2):
        """Return the mean squared difference of two arrays along axis 2.

        Not used by ``query``; kept so Cluster exposes the same helpers as Search.
        """
        return np.mean((encoding1 - encoding2) ** 2, axis=2)

    @staticmethod
    def _loadLabel(image):
        """Read ``labels/<image>`` and return a boolean mask of pixels equal to 1.

        Raises:
            FileNotFoundError: if OpenCV cannot read the label file.
        """
        label = cv2.imread('labels/' + image, cv2.IMREAD_GRAYSCALE)
        if label is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError('cannot read label image labels/' + image)
        return label == 1

    @staticmethod
    def _patchHitsLabel(patch, label, size=PATCH_SIZE):
        """Return True if the ``size`` x ``size`` patch at ``(row, col)`` touches foreground in ``label``."""
        row, col = int(patch[0]), int(patch[1])
        return bool(np.any(label[row:row + size, col:col + size]))

    @staticmethod
    def _jaccard(mask, label):
        """Return ``|mask & label| / |mask | label|``; any non-zero mask value counts as foreground.

        An empty union yields NaN (0 / 0), as before.
        """
        mask = np.asarray(mask, dtype=bool)
        return np.sum(mask & label) / np.sum(mask | label)

    @staticmethod
    def jaccardIndex(mask, image):
        """Return the Jaccard index between ``mask`` and the label mask of ``image``."""
        return Cluster._jaccard(mask, Cluster._loadLabel(image))

    @staticmethod
    def isCorrect(patch, image):
        """Return True if the patch at ``(row, col)`` overlaps the labelled foreground of ``image``."""
        return Cluster._patchHitsLabel(patch, Cluster._loadLabel(image))

    def __init__(self, encodingsFile):
        """Load encodings from ``encodingsFile`` (``;``-separated), keeping every 10th row."""
        # ndmin=2 keeps a single-row file two-dimensional so row slicing works.
        self.encodings = np.loadtxt(encodingsFile, delimiter=';', ndmin=2)[::10]

    @staticmethod
    def _report(jaccard, TP, FP):
        """Print Jaccard statistics and the fraction of matches that are true / false positives."""
        for title, statistic in (
            ("Median Jaccard Index:", np.median),
            ("Minimum Jaccard Index:", np.min),
            ("Maximum Jaccard Index:", np.max),
        ):
            print(title)
            # np.min / np.max raise on an empty sequence; report NaN instead.
            print(statistic(jaccard) if jaccard else float('nan'))
        total = TP + FP
        # These are fractions of all matches (precision and its complement), not counts.
        print("True Positives:")
        print(TP / total if total else float('nan'))
        print("False Positives:")
        print(FP / total if total else float('nan'))

    def query(self, queriesFile):
        """Cluster stored encodings with the queries in ``queriesFile`` and evaluate the matches.

        Every stored patch sharing a cluster with any query is drawn on its
        image (green if it touches labelled foreground, red otherwise); all
        annotated images are written to ``results2/`` and Jaccard and
        precision statistics are printed.
        """
        queries = np.loadtxt(queriesFile, delimiter=';', ndmin=2)
        numEncodings = self.encodings.shape[0]
        combined = np.append(self.encodings, queries, axis=0)
        clustering = AgglomerativeClustering(
            n_clusters=None, distance_threshold=self.DISTANCE_THRESHOLD, linkage='average'
        )
        clusterLabels = clustering.fit_predict(combined[:, 3:])

        # Stored encodings come first in ``combined`` and the queries after them;
        # slicing by count stays correct even when the queries file is empty.
        targetLabels = clusterLabels[numEncodings:]
        # One vectorised membership test instead of a per-patch scan per image.
        selected = np.isin(clusterLabels[:numEncodings], targetLabels)
        positions = self.encodings[:, :3].astype(int)

        size = self.PATCH_SIZE
        # cv2.imwrite fails silently when the target directory is missing.
        os.makedirs('results2', exist_ok=True)
        # NOTE(review): positions[:, 0] is compared with each file's position in
        # os.listdir(), whose order is filesystem-dependent; confirm it matches
        # the ordering used when the encodings were generated.
        images = os.listdir('images/')
        jaccard = []
        TP = FP = 0
        for imgIndex, image in enumerate(images):
            img = cv2.imread('images/' + image)
            label = self._loadLabel(image)  # read once per image, not once per match
            mask = np.zeros(img.shape[:2], np.uint8)
            for _, row, col in positions[selected & (positions[:, 0] == imgIndex)]:
                row, col = int(row), int(col)
                if self._patchHitsLabel((row, col), label, size):
                    TP += 1
                    color = (0, 255, 0)
                else:
                    FP += 1
                    color = (0, 0, 255)
                # OpenCV points are (x, y) = (col, row); plain ints avoid cv2 type errors.
                cv2.rectangle(img, (col, row), (col + size, row + size), color, 2)
                mask[row:row + size, col:col + size] = 1
            jaccard.append(self._jaccard(mask, label))
            cv2.imwrite('results2/' + image, img)

        # NOTE(review): TP/FP are counted over all images, including the skipped
        # ones, while the Jaccard statistics exclude them — confirm this is intended.
        self._report(jaccard[self.NUM_SKIPPED:], TP, FP)

In [None]:
# Clustering-based detection on the same encoding and query files as the search above.
search = Cluster(encodingsFile="encodings_ae.csv")
search.query(queriesFile="queries_ae.csv")