In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
import os
import math
import scipy
import glob
import xml.etree.ElementTree as ET
from skimage.feature import peak_local_max

### __Metrics__
##### Normalized MCME (Mean Coordinate Matching Error)

In [None]:
def diagonal(image):
  """Return the length of the diagonal of `image`, i.e. sqrt(shape[0]**2 + shape[1]**2)."""
  height = image.shape[0]
  width = image.shape[1]
  return (height ** 2 + width ** 2) ** 0.5



def sqdiff(xx, yy, diag):
  """Euclidean distance between points `xx` and `yy`, divided by `diag` (normalized distance).

  Despite the name this is not a squared difference. Coordinates are paired with zip,
  so the extra components of the longer point are ignored.
  """
  squared_terms = [(a - b) ** 2 for a, b in zip(xx, yy)]
  return sum(squared_terms) ** 0.5 / diag



def match(A, B, diag=None):
    """Map every point of A to its nearest neighbour in B (normalized distance).

    Args:
        A, B: iterables of hashable coordinate tuples (both are used as dict keys).
        diag: normalisation length passed to `sqdiff`. Defaults to the diagonal of
            `test_data[0]` (the notebook-level test set). This function always used
            that value before the parameter existed.

    Returns:
        dict {a: (nearest_b, normalized_distance)}. On ties, the first minimal point
        in B's iteration order wins.

    Raises:
        ValueError: if A is non-empty but B is empty, because no neighbour exists.
    """
    if diag is None:
        # Loop-invariant: previously recomputed for every (a, b) pair.
        diag = diagonal(test_data[0])
    NN = {}
    for aa in A:
        dist = {bb: sqdiff(aa, bb, diag) for bb in B}
        if not dist:
            raise ValueError("match(): B is empty, cannot find a nearest neighbour for %r" % (aa,))
        min_bb = min(dist, key=dist.get)
        NN[aa] = (min_bb, dist[min_bb])
    return NN



def MCME2_ext(P, GT):
    """Symmetric nearest-neighbour matching error between point sets P and GT.

    Every predicted point is matched to its nearest GT point and every GT point to its
    nearest predicted point. A pair found in both directions is stored once.

    Returns:
        (pairs, mean): `pairs` maps (p, gt) -> normalized distance; `mean` is the
        average of those distances.
    """
    pred_to_gt = match(P, GT)
    gt_to_pred = match(GT, P)

    pairs = {}
    for p, (gt, d) in pred_to_gt.items():
        pairs[(p, gt)] = d
    # Reverse matches: a mutual pair overwrites the existing key (keeping its position),
    # and new pairs are appended.
    for gt, (p, d) in gt_to_pred.items():
        pairs[(p, gt)] = d

    mean_error = sum(pairs.values()) / len(pairs)
    return pairs, mean_error

def MCME2(P, GT):
    """Return only the mean matching error (second element) of MCME2_ext(P, GT)."""
    _pairs, mean_error = MCME2_ext(P, GT)
    return mean_error


def Global_MCME(pred_centroids, gt_centroids):
  """Computes the average normalized MCME over all test images.

  Peaks are extracted from each predicted and ground-truth centroid map with
  `peak_local_max(..., min_distance=1)` and compared with `MCME2`.

  Args:
    pred_centroids: sequence of predicted centroid maps, one per image.
    gt_centroids: sequence of ground-truth centroid maps, same length.

  Returns:
    (mean_mcme, counter, MCME_list):
      - mean_mcme: mean of MCME_list rounded to 3 decimals, or NaN if no image had
        predicted peaks.
      - counter: number of images skipped because no peak was predicted.
      - MCME_list: per-image MCME for the non-skipped images.

  Raises:
    ValueError: if an image has predicted peaks but its ground-truth map has none.
      MCME is undefined in that case; previously this failed with an opaque
      "min() arg is an empty sequence" error.
  """
  assert len(gt_centroids) == len(pred_centroids)
  MCME_list = []
  counter = 0  # images with no predicted centroids
  for img_no, (pred_map, gt_map) in enumerate(zip(pred_centroids, gt_centroids)):
    coords_set_pred = {tuple(c) for c in peak_local_max(pred_map, min_distance=1)}
    if not coords_set_pred:
      counter += 1
      continue
    coords_set_gt = {tuple(c) for c in peak_local_max(gt_map, min_distance=1)}
    if not coords_set_gt:
      raise ValueError(f"Global_MCME: image {img_no} has predicted peaks but no ground-truth peaks")
    MCME_list.append(MCME2(coords_set_pred, coords_set_gt))
  # Explicit NaN instead of np.mean([]) (which emits a RuntimeWarning).
  mean_mcme = round(np.mean(MCME_list), 3) if MCME_list else float('nan')
  return (mean_mcme, counter, MCME_list)

##### MPPR (Multiple Patch Precision-Recall)

In [None]:
def get_patches(img_pred, img_gt, patch_x_width, patch_y_height, rng=None):
  """Cut the same random patch out of a predicted map and its ground-truth map.

  The top-left corner is drawn so the whole patch fits inside the image. Previously
  it could land anywhere in [0, shape), so border patches were silently truncated
  (down to a single pixel). If the requested patch is larger than the image, the
  start is 0 and the whole image is returned.

  Args:
    img_pred, img_gt: 2-D (or more) arrays indexed [row, col, ...]; the same window is
      cut from both. Assumes img_gt is at least as large as img_pred.
    patch_x_width: patch width (columns).
    patch_y_height: patch height (rows).
    rng: optional numpy Generator (e.g. np.random.default_rng(seed)) for reproducible
      sampling; defaults to the global np.random state.

  Returns:
    (patch_pred, patch_gt)
  """
  randint = np.random.randint if rng is None else rng.integers
  # Largest valid corner so the patch stays inside the image (upper bound is exclusive).
  max_start_h = max(img_pred.shape[0] - patch_y_height, 0)
  max_start_w = max(img_pred.shape[1] - patch_x_width, 0)
  # Same draw order as before (width first, then height).
  random_start_w = randint(0, max_start_w + 1)
  random_start_h = randint(0, max_start_h + 1)

  rows = slice(random_start_h, random_start_h + patch_y_height)
  cols = slice(random_start_w, random_start_w + patch_x_width)
  return img_pred[rows, cols], img_gt[rows, cols]



def MPPR(pred, gt):
  '''Classify one patch pair as TP / TN / FP / FN.

  A map counts as "positive" when its sum is non-zero.

  Returns:
    (TP, TN, FP, FN): exactly one of the four is 1, the others are 0.
  '''
  pred_positive = pred.sum() != 0
  gt_positive = gt.sum() != 0

  if pred_positive:
    return (1, 0, 0, 0) if gt_positive else (0, 0, 1, 0)
  return (0, 0, 0, 1) if gt_positive else (0, 1, 0, 0)



def _percent_or_nan(numerator, denominator):
  """Return 100 * numerator / denominator rounded to 2 decimals, or NaN if denominator is 0."""
  if denominator == 0:
    # Undefined ratio: explicit NaN instead of a 0/0 RuntimeWarning.
    return float('nan')
  # Scale first, then round. 100 * round(x, 4) leaves float noise such as 83.33000000000001.
  return round(100 * numerator / denominator, 2)


def Global_MPPR(pred_data, gt_data, no_patches, width, height):
  """Computes patch-based precision and recall (in percent) for every test image.

  For each (prediction, ground-truth) pair, `no_patches` random patches of size
  `width` x `height` are drawn with `get_patches`, and each patch is scored with `MPPR`.

  Returns:
    (list_of_mean_precision, list_of_mean_recall): one percentage per image, rounded
    to 2 decimals. Precision is NaN when no patch was predicted positive. Recall is
    NaN when no patch was positive in the ground truth.
  """
  list_of_mean_precision = []
  list_of_mean_recall = []

  for img_pred, img_gt in zip(pred_data, gt_data):
    TP_sum = FP_sum = FN_sum = 0
    for _ in range(no_patches):
      patch_pred, patch_gt = get_patches(img_pred, img_gt, width, height)
      TP, _TN, FP, FN = MPPR(patch_pred, patch_gt)
      TP_sum += TP
      FP_sum += FP
      FN_sum += FN

    list_of_mean_precision.append(_percent_or_nan(TP_sum, TP_sum + FP_sum))
    list_of_mean_recall.append(_percent_or_nan(TP_sum, TP_sum + FN_sum))

  return list_of_mean_precision, list_of_mean_recall

### __Filtering__

In [None]:
def threshold_outputs(img, tau, sigma):
  """Normalise a predicted map by its maximum, zero values <= tau, then Gaussian-smooth it.

  Args:
    img: predicted map (numpy array).
    tau: threshold on the max-normalised map; values strictly above tau are kept
      (the same rule as OpenCV's THRESH_TOZERO, used previously).
    sigma: standard deviation passed to scipy.ndimage.gaussian_filter.

  Returns:
    The smoothed array, same shape as `img`. An all-zero map returns zeros.
    Previously it divided 0/0 and returned NaN.
  """
  # The public namespace replaces the deprecated scipy.ndimage.filters.
  from scipy import ndimage

  peak = img.max()
  if peak == 0:
    # Nothing to normalise or threshold; avoid 0/0 -> NaN.
    img_tau = np.zeros(img.shape, dtype=float)
  else:
    normalized = img / peak
    # numpy instead of cv2.threshold: same semantics, and works for any dtype/ndim.
    img_tau = np.where(normalized > tau, normalized, 0)
  return ndimage.gaussian_filter(img_tau, sigma=sigma)

### __Displacement__

In [None]:
# Compass label keyed by the sign of the movement along (axis 0, axis 1):
# +1 = coordinate increased, -1 = decreased, 0 = unchanged.
_DISPLACEMENT_DIRECTIONS = {
  (1, 0): 'EAST',
  (-1, 0): 'WEST',
  (0, 1): 'NORTH',
  (0, -1): 'SOUTH',
  (1, 1): 'NORTH-EAST',
  (-1, 1): 'NORTH-WEST',
  (1, -1): 'SOUTH-EAST',
  (-1, -1): 'SOUTH-WEST',
  (0, 0): 'NO MOVEMENT',
}


def _movement_direction(start, end):
  """Return the compass label for the move from `start` to `end` (first two coordinates only)."""
  def sign(a, b):
    # Compare rather than subtract, so unsigned coordinate dtypes cannot wrap around.
    return int(b > a) - int(b < a)
  return _DISPLACEMENT_DIRECTIONS[(sign(start[0], end[0]), sign(start[1], end[1]))]


def displacement(x1, x2):
  """Computes the shift between two sets of centroids x1 (earlier) and x2 (later).

  Each centroid of the smaller set (x1 when the sizes are equal) is matched to its
  nearest centroid (Euclidean) in the other set. Direction is always reported from
  the x1 point to the x2 point: an increase along axis 0 is 'EAST' and an increase
  along axis 1 is 'NORTH'.
  NOTE(review): this is called with skimage peak_local_max output, presumably (row, col).
  In that case axis 0 is the image's vertical axis; confirm the intended compass
  convention.

  Args:
    x1, x2: array-likes of shape (n, d), d >= 2.

  Returns:
    dict {index: (x1_point, x2_point, 'Dist = <distance rounded to 3>', direction)},
    indexed over the smaller set. Empty if either input is empty.
  """
  x1 = np.asarray(x1)
  x2 = np.asarray(x2)
  centroids_dict = {}
  if len(x1) == 0 or len(x2) == 0:
    return centroids_dict

  # D[i, j] = distance of x1[i] from x2[j], built by broadcasting instead of a double loop.
  D = np.linalg.norm(x2[np.newaxis, :, :] - x1[:, np.newaxis, :], axis=-1)

  if len(x2) >= len(x1):
    nearest = np.argmin(D, axis=1)  # for each x1[i], index of the closest x2
    pairs = [(x1[i], x2[j]) for i, j in enumerate(nearest)]
  else:
    nearest = np.argmin(D, axis=0)  # for each x2[j], index of the closest x1
    pairs = [(x1[i], x2[j]) for j, i in enumerate(nearest)]

  for elem, (start, end) in enumerate(pairs):
    distance = np.linalg.norm(end - start)
    centroids_dict[elem] = start, end, f'Dist = {round(distance,3)}', _movement_direction(start, end)

  return centroids_dict



def _predicted_centroids(model, frame, tau, sigma, threshold_rel):
  """Run `model` on one frame and return the peak coordinates of its thresholded, smoothed map."""
  pred = model.predict(np.expand_dims(frame, axis=0))
  # NOTE(review): presumably a two-output model returns a list of 2 arrays here.
  # In that case the second output is used as the centroid map.
  if len(pred) == 2:
    pred = pred[1]
  heatmap = threshold_outputs(img=pred[0, :, :, 0], tau=tau, sigma=sigma)
  return peak_local_max(heatmap, threshold_rel=threshold_rel)


def prediction_and_displacement(model, t0, t1, tau, data=None, sigma=10, threshold_rel=.25):
  """Computes the centroid displacement between frames t0 and t1, given a threshold tau.

  For each frame, the model output is max-normalised and thresholded at `tau`, then
  smoothed with a Gaussian of std `sigma`. Peaks are extracted with
  `peak_local_max(threshold_rel=threshold_rel)`, and the two peak sets are compared
  with `displacement`.

  Args:
    model: object with a `predict` method (batch of one frame in, map(s) out).
    t0, t1: indices of the two frames in `data`.
    tau: threshold passed to `threshold_outputs`.
    data: indexable collection of frames. Defaults to the notebook-level `test_data`,
      which was the hard-coded source before this parameter existed.
    sigma: Gaussian smoothing std (previously the hard-coded 10).
    threshold_rel: relative peak threshold (previously the hard-coded .25).

  Returns:
    A one-element list containing the `displacement` dict (list kept for compatibility).
  """
  if data is None:
    data = test_data
  centroids_t0 = _predicted_centroids(model, data[t0], tau, sigma, threshold_rel)
  centroids_t1 = _predicted_centroids(model, data[t1], tau, sigma, threshold_rel)
  return [displacement(centroids_t0, centroids_t1)]