# 3D Object Detection
### Comparison between training using 2D or 3D Transformations
----

### Code Dependencies

##### a) General libraries and functions

In [None]:
%%capture
import cv2
import pickle
import scipy.io
import numpy as np
from PIL import Image
import deeptrack as dt
import geopandas as gpd
import matplotlib.pyplot as plt
from skimage import img_as_float
from shapely.geometry import Point
from matplotlib.patches import Rectangle
from sklearn.metrics import precision_recall_curve, average_precision_score
from tensorflow.python.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler

##### c) Object detection

In [None]:
#@title
class Preprocessor():
  """Normalize, crop and rescale images for training and multi-scale detection."""

  def __init__(self,
               scales=None
               ):
    """Store the resize factors used by ``load_images``.

    Args:
      scales: Iterable of resize factors. ``None`` selects ``[1, 2, 4]``.
    """
    # A None sentinel plus a fresh list keeps instances from sharing (and
    # accidentally mutating) a single default list object.
    self.scales = [1, 2, 4] if scales is None else list(scales)

  def crop(self, image, x, y, wide):
    """Return the square window of side ``wide`` whose top-left corner is (x, y)."""
    return image[y: y + wide, x: x + wide]

  def normalizate(self, image):
    """Rescale ``image`` so its 1% and 99% quantiles map to 0 and 1.

    NaN/inf entries are replaced by ``np.nan_to_num`` first. Values outside the
    quantile range are not clipped, so the output may fall outside [0, 1].
    A constant image (zero quantile range) yields an all-zero float array
    instead of a division by zero.
    """
    cleaned = np.nan_to_num(image)
    low, high = np.quantile(cleaned, [0.01, 0.99])
    span = high - low
    if span == 0:
      return np.zeros_like(cleaned, dtype=float)

    return (cleaned - low) / span

  def resize(self, image, scale):
    """Resize ``image`` by ``scale`` along both axes using ``cv2.resize``."""
    height, width = image.shape[:2]
    # cv2.resize expects the target size as (width, height).
    target_size = (int(width * scale), int(height * scale))

    return cv2.resize(image, target_size)

  def generate_pipeline(self, data, transformations):
    """Chain ``transformations`` left to right with the ``>>`` operator.

    Args:
      data: Unused; kept so existing callers keep working.
      transformations: Iterable of objects supporting ``>>``.

    Returns:
      The chained pipeline, or ``None`` when ``transformations`` is empty.
    """
    pipeline = None
    for transformation in transformations:
      pipeline = transformation if pipeline is None else pipeline >> transformation

    return pipeline

  def create_train_data(self, data, transformations):
    """Build a training feature that samples a random normalized image and augments it.

    Args:
      data: Iterable of images; each element is normalized separately and
        gets a trailing channel axis.
      transformations: Features chained by ``generate_pipeline``.

    Returns:
      The feature ``dt.Value(...) >> pipeline`` (also plotted once as a preview).
    """
    # Imported here because the notebook's import cell does not provide it.
    import random

    training_images = [
        np.expand_dims(self.normalizate(image), axis=-1) for image in data
    ]

    pipeline = self.generate_pipeline(training_images, transformations)
    train_set = dt.Value(lambda: np.array(random.choice(training_images))) >> pipeline
    train_set.plot()

    return train_set

  def load_images(self, data, plot=True):
    """Normalize ``data`` and return one resized copy per entry in ``self.scales``.

    Args:
      data: A single image array.
      plot: When true, show the resized images side by side.
    """
    original_image = self.normalizate(data)
    input_set = [self.resize(original_image, scale) for scale in self.scales]

    if plot and input_set:
      # squeeze=False always returns a 2-D axes array, so a single scale works too.
      fig, axes = plt.subplots(1, len(input_set), figsize=(25,5), squeeze=False)
      fig.tight_layout()
      fig.suptitle('Scaled 3D images')
      for axis, scaled_image in zip(axes[0], input_set):
        axis.imshow(scaled_image)

    return input_set

# --------------------------------------------------------------------
class Trainer():
  """Wrap a LodeSTAR model together with its training callbacks."""

  def __init__(self,
               filepath,
               model = None,
               callbacks = None
               ):
    """Set up the model and the default training callbacks.

    Args:
      filepath: Path where ``ModelCheckpoint`` saves the best weights.
      model: Model to train. ``None`` builds a new
        ``dt.models.LodeSTAR(input_shape=(None, None, 1))`` for this trainer,
        so trainers no longer share one model created at class-definition time.
      callbacks: Optional extra callbacks appended after the defaults.
        ``None`` entries (including the legacy ``[None]`` value) are skipped.
    """
    if model is None:
      model = dt.models.LodeSTAR(input_shape=(None, None, 1))

    self.model = model
    self.filepath = filepath
    self.callbacks = [
        # Keep only the weights with the lowest consistency loss.
        ModelCheckpoint(filepath=self.filepath,
                        save_weights_only=True,
                        monitor='consistency_loss',
                        mode='min',
                        save_best_only=True
                        ),
        # Stop after 15 epochs without improvement of the total loss.
        EarlyStopping(monitor="total_loss",
                      patience=15,
                      verbose=1,
                      mode="auto",
                      restore_best_weights=True
                      ),
        # Constant learning rate for 10 epochs, then multiply by exp(-0.1) each epoch.
        LearningRateScheduler(lambda epoch, lr: lr if epoch < 10 else lr * np.exp(-0.1))
    ]
    if callbacks is not None:
      self.callbacks.extend(cb for cb in callbacks if cb is not None)

  def fit(self, train_set, epochs=40, batch_size=8):
    """Train ``self.model`` on ``train_set`` and return the object ``model.fit`` returns."""
    history = self.model.fit(
        train_set,
        epochs = epochs,
        batch_size = batch_size,
        callbacks = self.callbacks)

    return history

  def plot_performance(self, history):
    """Plot the 'total_loss' and 'consistency_loss' curves stored in ``history.history``."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,4))
    fig.suptitle('Total loss and consistency loss')

    ax1.plot(history.history['total_loss'])
    ax1.set_title('Total loss')
    ax1.set(xlabel='epoch', ylabel='loss')
    ax1.set_ylim([0, 1.5])

    ax2.plot(history.history['consistency_loss'])
    ax2.set_title('Consistency loss')
    ax2.set(xlabel='epoch', ylabel='loss')
    ax2.set_ylim([0, 1.5])

    plt.show()

# --------------------------------------------------------------------
class Detector():
  """Run a detection model on images, optionally subsampled before prediction."""

  def __init__(self,
               downsample,
               alpha = 0.1,
               cutoff = 0.998,
               mode = "quantile",
               colors = 'rgb',
               model = None,
               ):
    """Store the detection settings.

    Args:
      downsample: Stride used to subsample the image along both spatial axes.
      alpha: Forwarded to ``predict_and_detect``; ``beta`` is set to ``1 - alpha``.
      cutoff: Forwarded to ``predict_and_detect``.
      mode: Forwarded to ``predict_and_detect``.
      colors: Sequence of matplotlib colors, one per plotted image (cycled).
      model: Model exposing ``predict_and_detect``. ``None`` builds a new
        ``dt.models.LodeSTAR(input_shape=(None, None, 1))`` per detector instead
        of sharing one instance created at class-definition time.
    """
    self.downsample = downsample
    self.alpha = alpha
    self.cutoff = cutoff
    self.mode = mode
    self.colors = colors
    self.model = dt.models.LodeSTAR(input_shape=(None, None, 1)) if model is None else model

  def detect(self, image, plot=True):
    """Detect objects in a single 2-D image.

    Args:
      image: 2-D array; batch and channel axes are added here.
      plot: Unused; kept for backward compatibility.

    Returns:
      The first element returned by ``predict_and_detect``, with columns 0 and 1
      multiplied by ``self.downsample`` to undo the subsampling.
    """
    test_set = image[np.newaxis, :, :, np.newaxis]
    test_image = test_set[:, ::self.downsample, ::self.downsample, :]

    detections = self.model.predict_and_detect(test_image, alpha=self.alpha, beta=1-self.alpha, cutoff=self.cutoff, mode=self.mode)[0]
    detections[:, :2] = detections[:, :2] * self.downsample

    return detections

  def detect_all(self, images, plot=False):
    """Run ``detect`` on every image and optionally scatter the results.

    Args:
      images: Sequence of 2-D images.
      plot: When true, show each image with column 1 of its detections on the
        horizontal axis and column 0 on the vertical axis.

    Returns:
      List with one detections array per image.
    """
    detections = []
    plot = plot and len(images) > 0

    if plot:
      # squeeze=False keeps the axes 2-D so a single image can be indexed too.
      fig, axes = plt.subplots(1, len(images), figsize=(25, 5), squeeze=False)
      fig.tight_layout()
      fig.suptitle('Detections')

    for index, image in enumerate(images):
      det = self.detect(image=image)
      detections.append(det)

      if plot:
        axis = axes[0][index]
        axis.imshow(image)
        # Cycle through the colors so more images than colors does not fail.
        axis.scatter(det[:, 1], det[:, 0], color=self.colors[index % len(self.colors)])

    return detections

# --------------------------------------------------------------------
class Postprocessor():
  """Turn per-scale detections into square boxes and merge them with NMS.

  NOTE(review): ``Preprocessor.resize`` enlarges the image by ``scale`` while
  ``scale_detections`` also multiplies coordinates by ``scale``; confirm that
  this is the intended direction for mapping detections back to the image.
  """

  def __init__(self,
               wide=50,
               scales=None,
               colors = 'rgb'
               ):
    """Store the box half-width, the per-scale factors and the plot colors.

    Args:
      wide: Buffer distance around each detection (boxes are ``2 * wide`` wide).
      scales: Factors applied to the detections of each scale; ``None``
        selects ``[1, 2, 4]``.
      colors: Sequence of matplotlib colors, one per scale (cycled).
    """
    self.wide = wide
    # A None sentinel plus a fresh list avoids a default list shared by all instances.
    self.scales = [1, 2, 4] if scales is None else list(scales)
    self.colors = colors

  def scale_detections(self, detection, scale):
    """Multiply ``detection`` by ``scale`` and return the result as a nested list."""
    scaled_detection = detection * scale

    return scaled_detection.tolist()

  def create_boxes(self, detections):
    """Build a square box around every detection.

    Args:
      detections: Iterable of ``(y, x)`` pairs.

    Returns:
      Tuple ``(boxes, bounds)``: the buffered geometries and an array with
      the ``.bounds`` tuple of every geometry.
    """
    list_detections = [Point((x, y)) for (y, x) in detections]
    points = gpd.GeoSeries(list_detections)
    # cap_style=3 is the square cap, so buffering a point yields a square.
    boxes = points.buffer(self.wide, cap_style = 3)
    bounds = np.array([box.bounds for box in boxes])

    return boxes, bounds

  def NMSupression(self, boxes, overlapThresh):
    """Non-maximum suppression (Malisiewicz et al., non_max_suppression_fast).

    Args:
      boxes: Array-like of ``[x1, y1, x2, y2]`` rows.
      overlapThresh: Boxes whose overlap ratio with a kept box exceeds this
        value are discarded.

    Returns:
      Integer array with the kept boxes, or ``[]`` when ``boxes`` is empty.
    """
    # Accept plain lists as well as arrays.
    boxes = np.asarray(boxes)
    if len(boxes) == 0:
      return []

    if boxes.dtype.kind == "i":
      boxes = boxes.astype("float")
    pick = []
    x1 = boxes[:,0]
    y1 = boxes[:,1]
    x2 = boxes[:,2]
    y2 = boxes[:,3]

    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Process boxes from the largest bottom coordinate downwards.
    idxs = np.argsort(y2)

    while len(idxs) > 0:
      last = len(idxs) - 1
      i = idxs[last]
      pick.append(i)
      remaining = idxs[:last]

      # Intersection of the picked box with every remaining box.
      xx1 = np.maximum(x1[i], x1[remaining])
      yy1 = np.maximum(y1[i], y1[remaining])
      xx2 = np.minimum(x2[i], x2[remaining])
      yy2 = np.minimum(y2[i], y2[remaining])

      w = np.maximum(0, xx2 - xx1 + 1)
      h = np.maximum(0, yy2 - yy1 + 1)

      # Overlap is measured relative to the area of the remaining box.
      overlap = (w * h) / area[remaining]

      idxs = np.delete(idxs, np.concatenate(([last],
        np.where(overlap > overlapThresh)[0])))

    return boxes[pick].astype("int")

  def apply_nms(self, image, list_bounds, figsize = (15,15), overlapThresh=0.3):
    """Merge the bounds of all scales, suppress overlaps and plot the result.

    Args:
      image: Image shown behind the final boxes.
      list_bounds: One bounds array per scale (as returned by ``plot_boxes``).
      figsize: Figure size for the result plot.
      overlapThresh: Threshold forwarded to ``NMSupression``.

    Returns:
      The boxes kept by ``NMSupression``.
    """
    merged_bounds = [box for bounds in list_bounds for box in bounds.tolist()]
    final_bounds = np.array(merged_bounds)
    final_detections = self.NMSupression(final_bounds,
                                         overlapThresh=overlapThresh)
    self.plot_results(image, final_detections, figsize)

    return final_detections


  def plot_boxes(self, test_image, detections, figsize=(15, 15), plot=True):
    """Scale the detections of every scale, box them and optionally draw them.

    Args:
      test_image: Image shown behind the boxes when ``plot`` is true.
      detections: One detections array per scale, aligned with ``self.scales``.
      figsize: Figure size for the plot.
      plot: When true, draw the box outlines, one color per scale.

    Returns:
      List with one bounds array per scale.
    """
    list_bounds = []

    if plot:
      fig, ax = plt.subplots(figsize=figsize)
      plt.imshow(test_image)

    for index, detection in enumerate(detections):
      scaled_det = self.scale_detections(detection, self.scales[index])
      boxes, bounds = self.create_boxes(scaled_det)
      list_bounds.append(bounds)

      if plot:
        # Cycle through the colors so more scales than colors does not fail.
        boxes.boundary.plot(ax=ax, color = self.colors[index % len(self.colors)])

    return list_bounds

  def plot_results(self, image, boxes, figsize, color='red', lw=2):
    """Show ``image`` with one square of side ``2 * self.wide`` per box.

    Each square is anchored at the box's first two values; the side length
    comes from ``self.wide`` rather than from the box's own extent.
    """
    plt.figure(figsize=figsize)
    plt.imshow(image)

    wide = self.wide * 2
    for box in boxes:
      rect = Rectangle((box[0], box[1]), wide, wide,
                       edgecolor=color,
                       facecolor='none',
                       lw=lw)

      plt.gca().add_patch(rect)

In [None]:
def test(test_data, model, downsample=1, alpha=0.1, cutoff=0.99, mode="quantile", plotPrevious=True):
  """Run the full detection pipeline on one image.

  The image is normalized and resized by ``Preprocessor``, every resized copy
  is passed through ``Detector`` and the per-scale boxes are merged with
  non-maximum suppression by ``Postprocessor``.

  Args:
    test_data: Single image to analyse.
    model: Model handed to ``Detector``.
    downsample, alpha, cutoff, mode: Detection settings forwarded to ``Detector``.
    plotPrevious: When true, also plot the intermediate stages; the final
      result is plotted by ``apply_nms`` regardless.

  Returns:
    The boxes returned by ``Postprocessor.apply_nms``.
  """
  # Multi-scale inputs.
  input_set = Preprocessor().load_images(test_data, plot=plotPrevious)

  # Per-scale detections.
  detector = Detector(downsample, model=model, alpha=alpha, cutoff=cutoff, mode=mode)
  detections = detector.detect_all(input_set, plot=plotPrevious)

  # Boxes are drawn on the first entry of the input set.
  postprocessor = Postprocessor()
  reference_image = input_set[0]
  bounds = postprocessor.plot_boxes(reference_image, detections, plot=plotPrevious)

  return postprocessor.apply_nms(reference_image, bounds, figsize = (8,8))

##### d) Computing metrics

In [None]:
def calculate_iou(gt_box, pred_box):
    """
    Compute the IoU (Intersection over Union) of two bounding boxes given as
    [xmin, ymin, xmax, ymax].

    Returns 0.0 when the union has no area (e.g. two degenerate boxes) instead
    of dividing by zero.
    """
    # Corners of the intersection rectangle
    x1 = np.maximum(gt_box[0], pred_box[0])
    y1 = np.maximum(gt_box[1], pred_box[1])
    x2 = np.minimum(gt_box[2], pred_box[2])
    y2 = np.minimum(gt_box[3], pred_box[3])
    # Clamping at 0 makes disjoint boxes contribute no intersection
    intersection = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)

    # Union = sum of both areas minus the part counted twice
    gt_area = (gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])
    pred_area = (pred_box[2] - pred_box[0]) * (pred_box[3] - pred_box[1])
    union = gt_area + pred_area - intersection

    if union <= 0:
        return 0.0

    return intersection / union

def confusion_matrix(tp, fp, fn):
    """
    Arrange the true-positive, false-positive and false-negative counts in a
    2x2 array laid out as [[tp, fp], [fn, tn]].
    """
    # True negatives are assumed to be zero in this setting.
    true_negatives = 0
    first_row = [tp, fp]
    second_row = [fn, true_negatives]

    return np.array([first_row, second_row])

def plot_confusion_matrix(conf_matrix):
    """
    Draw the confusion matrix as a matplotlib heat map with the value of each
    cell written at its centre.
    """
    fig, ax = plt.subplots(figsize=(6,4))
    heatmap = ax.imshow(conf_matrix, cmap='Blues')

    # Text coordinates are (x, y), i.e. (column, row).
    for row, col in np.ndindex(*conf_matrix.shape[:2]):
        ax.text(col, row, str(conf_matrix[row][col]), ha='center', va='center')

    plt.colorbar(heatmap)
    class_labels = ['Positive', 'Negative']
    plt.xticks([0, 1], class_labels)
    plt.yticks([0, 1], class_labels)
    plt.xlabel('Actual values')
    plt.ylabel('Predicted values')
    plt.title('Confusion Matrix')
    plt.show()

def calculate_metrics(gt_boxes, pred_boxes, iou_threshold=0.5):
    """
    Compute precision, recall, F1 score and accuracy for a set of ground-truth
    bounding boxes and a set of predicted ones.

    gt_boxes and pred_boxes are sequences of boxes in the format
    [[xmin, ymin, xmax, ymax], ...]. iou_threshold is the minimum IoU for a
    prediction to count as a match.

    A ground-truth box is a true positive when at least one prediction reaches
    the threshold, otherwise a false negative; a prediction matching no
    ground-truth box is a false positive. Matching is not one-to-one. Any
    metric whose denominator is zero is reported as 0.0. Values are rounded
    to three decimals.
    """
    tp = 0
    fp = 0
    fn = 0

    for gt_box in gt_boxes:
        if any(calculate_iou(gt_box, pred_box) >= iou_threshold for pred_box in pred_boxes):
            tp += 1
        else:
            fn += 1

    for pred_box in pred_boxes:
        if not any(calculate_iou(gt_box, pred_box) >= iou_threshold for gt_box in gt_boxes):
            fp += 1

    # Each ratio is guarded on its own: e.g. ground truth without any
    # prediction gives tp + fp == 0 while fn > 0.
    precision = round(tp / float(tp + fp), 3) if tp + fp else 0.0
    recall = round(tp / float(tp + fn), 3) if tp + fn else 0.0
    if precision == 0 and recall == 0:
        f1_score = 0.0
    else:
        f1_score = round(2 * ((precision * recall) / (precision + recall)), 3)
    accuracy = round(tp / float(tp + fp + fn), 3) if tp + fp + fn else 0.0

    return precision, recall, f1_score, accuracy

def generate_metrics(gt_boxes, pred_boxes, iou_threshold=0.5):
    """
    Return the metrics from calculate_metrics together with the confusion
    matrix and the raw counts for one image.

    The result is (precision, recall, f1_score, accuracy, conf_matrix,
    [fp, fn, tp]). A ground-truth box counts as a true positive when any
    prediction reaches iou_threshold, otherwise as a false negative; a
    prediction that reaches the threshold with no ground-truth box counts as a
    false positive.
    """
    def reaches_threshold(gt_box, pred_box):
        return calculate_iou(gt_box, pred_box) >= iou_threshold

    # Ground-truth side: true positives and false negatives.
    gt_hits = [
        any(reaches_threshold(gt_box, pred_box) for pred_box in pred_boxes)
        for gt_box in gt_boxes
    ]
    tp = sum(1 for hit in gt_hits if hit)
    fn = len(gt_hits) - tp

    # Prediction side: false positives.
    fp = sum(
        1
        for pred_box in pred_boxes
        if not any(reaches_threshold(gt_box, pred_box) for gt_box in gt_boxes)
    )

    precision, recall, f1_score, accuracy = calculate_metrics(gt_boxes, pred_boxes, iou_threshold)
    conf_matrix = confusion_matrix(tp, fp, fn)

    return precision, recall, f1_score, accuracy, conf_matrix, [fp, fn, tp]

-------------------

### Datasets loading (images and ground truth)

In [None]:
import os
import pickle
import scipy.io

def load_data_file(mat_file_path, pkl_file_path, key):
    """Load one image and its ground-truth boxes.

    Args:
        mat_file_path: Path of the .mat file holding the image.
        pkl_file_path: Path of the pickle file holding the ground-truth boxes.
        key: Name of the variable to read from the .mat file.

    Returns:
        Tuple ``(image_data, gt_boxes)``.

    NOTE: unpickling can execute arbitrary code, so only load .pkl files that
    come from a trusted source.
    """
    # Image: the requested variable of the MATLAB file.
    mat_contents = scipy.io.loadmat(mat_file_path)
    image_data = mat_contents[key]

    # Ground truth: whatever object was pickled into the .pkl file.
    with open(pkl_file_path, 'rb') as pkl_handle:
        gt_boxes = pickle.load(pkl_handle)

    return image_data, gt_boxes

def load_dataset(folder_path, key):
    """Load every image in ``folder_path`` that has a ground-truth file.

    For each ``<name>.mat`` file the matching ``<name>_gt.pkl`` is looked up in
    the same folder; .mat files without one are skipped.

    Args:
        folder_path: Folder containing the .mat / _gt.pkl pairs.
        key: Name of the variable to read from every .mat file.

    Returns:
        List of ``(image_data, gt_boxes)`` tuples ordered by file name.
    """
    dataset = []
    # os.listdir returns names in arbitrary order; sorting keeps the dataset
    # (and every per-image result derived from it) reproducible.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.endswith('.mat'):
            continue
        mat_file_path = os.path.join(folder_path, file_name)
        pkl_file_path = os.path.join(folder_path, f'{os.path.splitext(file_name)[0]}_gt.pkl')
        if os.path.isfile(pkl_file_path):
            dataset.append(load_data_file(mat_file_path, pkl_file_path, key))
    return dataset

def display_images_with_gt(dataset):
    """Show every image of ``dataset`` with its ground-truth boxes outlined in red.

    Args:
        dataset: Iterable of ``(image_data, gt_boxes)`` pairs, where each box
            is ``(x1, y1, x2, y2)``.
    """
    for image_data, gt_boxes in dataset:
        plt.imshow(image_data)
        plt.axis('on')

        # One unfilled rectangle per ground-truth box.
        for x1, y1, x2, y2 in gt_boxes:
            outline = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, edgecolor='r', facecolor='none')
            plt.gca().add_patch(outline)

        plt.show()

------------

#### 1- Without surface base

In [None]:
# Folder with the .mat / _gt.pkl pairs of the "without surface base" images.
noSurface_data_path = '../Data/MATLAB/ProcessedData/without_surface_base/'
# Name of the variable read from each .mat file.
noSurface_key = 'Spz'
# List of (image, ground-truth boxes) tuples.
noSurface_dataset = load_dataset(noSurface_data_path, noSurface_key)
# Visual check: every image with its ground-truth boxes.
display_images_with_gt(noSurface_dataset)

#### 2- With surface base

In [None]:
# Folder with the .mat / _gt.pkl pairs of the "with surface base" images.
withSurface_data_path = '../Data/MATLAB/ProcessedData/with_surface_base'
# Name of the variable read from each .mat file.
withSurface_key = 'ZcM_r'
# List of (image, ground-truth boxes) tuples.
withSurface_dataset = load_dataset(withSurface_data_path, withSurface_key)
# Visual check: every image with its ground-truth boxes.
display_images_with_gt(withSurface_dataset)

-------------

# Testing and comparisons

#### Loading pre-trained weights

In [None]:
# Each trainer receives its own freshly built LodeSTAR model, which is then
# filled with the weights stored at the corresponding checkpoint path.

# Model with 2D transformations
filepath_2D = '../Model/checkpoints/2D_transformations'
trainer_2D = Trainer(filepath_2D, model=dt.models.LodeSTAR(input_shape=(None, None, 1)))
trainer_2D.model.load_weights(filepath_2D)

# Model with 3D transformations
filepath_3D = '../Model/checkpoints/3D_transformations'
trainer_3D = Trainer(filepath_3D, model=dt.models.LodeSTAR(input_shape=(None, None, 1)))
trainer_3D.model.load_weights(filepath_3D)

### Settings

In [None]:
# Detection settings shared by the cells below.
size_object = 100 # 100
wide = 20 # size of the training template
# Integer ratio between the two sizes above; passed as `downsample` to test().
downsample = size_object // wide

### 1) Models trained with 2D vs 3D transformations on images with surface removal

In [None]:
# 3D images with surface removal: split the (image, ground truth) pairs
# into two parallel lists.
noSurface_test_images = [image for image, _ in noSurface_dataset]
noSurface_groundTruth = [gt_boxes for _, gt_boxes in noSurface_dataset]
# Number of test images (shown as the cell output).
len(noSurface_test_images)

### a) 2D-transformations-trained model

#### Predictions

In [None]:
# Detection settings
alpha = 0.5 # 0.1
cutoff = 0.9975 # 0.9985
mode = "quantile"

detected_boxes = []
modelPreds_2D_noSurface = []

# Run the 2D-trained model on every image. `mode` is now forwarded explicitly
# so that editing the setting above actually takes effect.
for image in noSurface_test_images:
    detected_boxes = test(image, trainer_2D.model, downsample=downsample, alpha=alpha, cutoff=cutoff, mode=mode, plotPrevious=False)
    modelPreds_2D_noSurface.append(detected_boxes)

#### Results comparison

In [None]:
import statistics
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

def evaluate_predictions(all_images, all_gt_boxes, all_detected_boxes):
    """Plot and score the predictions of every image, then print the averages.

    For each image the ground-truth boxes (black) and the predicted boxes (red)
    are drawn, the metrics from ``generate_metrics`` are printed and the
    confusion matrix is plotted. Finally the mean and sample standard deviation
    of every metric are printed; the deviation is reported as 0.0 when only one
    image was evaluated.

    Args:
        all_images: Sequence of images.
        all_gt_boxes: One sequence of ``(x1, y1, x2, y2)`` boxes per image.
        all_detected_boxes: One sequence of predicted boxes per image.
    """
    # General metrics
    prom_metrics = {"prom_precision": [], "prom_accuracy": [], "prom_recall": [], "prom_f1_score": []}

    for image, gt_boxes, detected_boxes in zip(all_images, all_gt_boxes, all_detected_boxes):
        _, ax = plt.subplots()
        ax.imshow(image)

        # Ground truth bounding boxes (black, thick)
        for x1, y1, x2, y2 in gt_boxes:
            ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=3, edgecolor='black', facecolor='none'))

        # Predicted bounding boxes (red, thin)
        for x1, y1, x2, y2 in detected_boxes:
            ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none'))

        plt.show()

        # Individual metrics
        precision, recall, f1_score, accuracy, conf_matrix, _ = generate_metrics(gt_boxes, detected_boxes)
        prom_metrics["prom_precision"].append(precision)
        prom_metrics["prom_recall"].append(recall)
        prom_metrics["prom_f1_score"].append(f1_score)
        prom_metrics["prom_accuracy"].append(accuracy)

        print(f'Precision: {precision}')
        print(f'Accuracy: {accuracy}')
        print(f'Recall: {recall}')
        print(f'F1 Score: {f1_score}')

        # Confusion Matrix
        plot_confusion_matrix(conf_matrix)

    # Nothing to average when no image was evaluated.
    if not prom_metrics["prom_precision"]:
        print("No images to evaluate.")
        return

    # Average results and standard deviation
    print("-----Average results-----\n")
    for metric, values in prom_metrics.items():
        avg_metric = sum(values) / len(values)
        # statistics.stdev needs at least two data points.
        std_dev = statistics.stdev(values) if len(values) > 1 else 0.0
        print(f'Average {metric}: {round(avg_metric, 3)} +/- {round(std_dev, 3)}')

In [None]:
# Evaluate the predictions of the 2D-transformations model on the images
# without surface base. Colors used in the plots:
# Black - Ground truth
# Red - Predictions
evaluate_predictions(noSurface_test_images, noSurface_groundTruth, modelPreds_2D_noSurface)