# Disf + TBM Segmentation
---
Applying DISF as pre-segmentation method and TBM as superpixel classifier to obtain segmentation on FoxP3+ stained cells.

## Initial Setup

In [12]:
# Necessary Imports
import sys
sys.path.append("../DISF/python3/")
sys.path.append("../iDISF/python3/")
sys.path.append("../")
from scripts.metrics import calculate_precision, calculate_recall, calculate_f1_score

from idisf import iDISF_scribbles
from disf import DISF_Superpixels
from scripts.segmentation_utils import *
from PIL import Image
from scripts.utils import *
from scripts.superpixel_treatment import *
from ptk_code.utils import *
from ptk_code.TBM_PLOT import batch_PLOT

import datetime
import matplotlib.pyplot as plt
import joblib
import numpy as np
import cv2

Defining the params for all the pipeline

In [13]:
def get_positive_ids(predictions, superpixel_indexes):
    """Creating a list with only the positively classified superpixel ids"""
    positive_preds = []
    for i, prediction in enumerate(predictions):
        if prediction == 1:
            positive_preds.append(superpixel_indexes[i])
    return positive_preds

In [14]:
def get_overlap_img(pred, gt):
    """ Creates an overlapped visualization
    bewtween the ground truth and the obtained prediction"""
    pred_arr = np.asarray(pred)
    gt_arr = np.asarray(gt)
    binary_xor = cv2.bitwise_xor(pred_arr, gt_arr)
    return binary_xor

In [15]:
def get_positive_classification_visualization(tp_img, fp_img, ground_truth):
    green = np.array([0, 255, 0])
    blue = np.array([0, 255, 0])
    tp_img_colored = np.int32(change_color(tp_img, green))
    fp_img_colored = np.int32(change_color(fp_img, blue))
    ov_img = get_overlap_img(tp_img_colored, ground_truth)
    ov_img = get_overlap_img(ov_img, fp_img_colored)
    return ov_img

In [16]:
def get_predictions(classifier,input_batch: list[np.array], n_masses:int) -> list[int]:
    """
    Applies the TBM transform and predicts on the dataset
    """
    # Loading the reference dataset
    dataset = 'immuno_cells'
    train_dir = '../data/'+dataset+'/training/'
    (reference_x, y_train) = load_image_data(train_dir)
    # Obtaining the test dataset
    x_input = np.array([np.float64(rgb2gray(datapoint)) for datapoint in input_batch])
    # Applying transform
    batch_plot = batch_PLOT(Nmasses = n_masses)
    x_template=np.mean(reference_x,axis=0)
    reference_x_hat, x_input_hat, Pl_tem, P_tem = batch_plot.forward_seq(reference_x, x_input, x_template)
    # Making predictions
    preds = classifier.predict(x_input_hat)
    return preds

In [17]:
def filter_superpixels(ids, superpixels):
    """Applies the superpixel filter by color mean thresholding"""
    filtered_ids= []
    filtered_superpixels = []
    red_channel_mean_threshold = 130 
    green_channel_mean_threshold = 90 
    blue_channel_mean_threshold = 80 
    for id, superpixel in zip(ids, superpixels):
        red_channel_mean = sum(np.unique(superpixel[:,:,0])) / len(np.unique(superpixel[:,:,0]))
        green_channel_mean = sum(np.unique(superpixel[:,:,1])) / len(np.unique(superpixel[:,:,1]))
        blue_channel_mean = sum(np.unique(superpixel[:,:,2])) / len(np.unique(superpixel[:,:,2]))
        if red_channel_mean < red_channel_mean_threshold and  green_channel_mean < green_channel_mean_threshold and blue_channel_mean < blue_channel_mean_threshold:
            filtered_superpixels.append(superpixel)
            filtered_ids.append(id)
    return filtered_ids, filtered_superpixels

In [18]:
def obtain_fltered_superpixels(img, label_img):
    # Obtaining the ids of the generated superpixels
    superpixel_ids = np.unique(label_img)
    superpixel_ids = superpixel_ids[1:] # Excluding the first empty superpixel

    segmented_superpixels = []
    segmented_superpixels_ids = []


    for superpixel_id in superpixel_ids:
        superpixel_img = get_superpixel_img(label_img, superpixel_id)
        cropped_original_img,cropped_superpixel_img = get_cropped_superpixel_img(img, superpixel_img)
        segmented = apply_mask(cropped_original_img, cropped_superpixel_img)
        segmented_superpixels.append(segmented)
        segmented_superpixels_ids.append(superpixel_id)

    filtered_ids, filtered_superpixels = filter_superpixels(segmented_superpixels_ids, segmented_superpixels)
    return filtered_ids, filtered_superpixels

In [19]:
def make_predictions(plot_ns_model,superpixels,ids,n_masses= 300, num_of_tries=5):
    higher_num_of_predictions = 0
    positive_ids = []
    all_predictions = []
    for i in range(num_of_tries):
        current_positive_ids = []
        # Obtainig the predictions on the dataset
        predictions = get_predictions(plot_ns_model,superpixels, n_masses)
        # Obtaining only the ids that were classified positively
        current_positive_ids = get_positive_ids(predictions, ids)
        all_predictions.append(current_positive_ids)
        if len(current_positive_ids) > higher_num_of_predictions:
            higher_num_of_predictions = len(current_positive_ids)
    return all_predictions

In [20]:
def DISF_TBM_CELLS(img, num_init_seeds=7000, num_final_superpixels = 4000, n_masses = 300, num_of_tries = 5):
    superpixel_label_img, border_img = DISF_Superpixels(img, num_init_seeds, num_final_superpixels)
    label_img = superpixel_label_img.copy()
    # Loading the classifier weights
    plot_ns_model = joblib.load('../checkpoints/best.pkl')
    ids, superpixels= obtain_fltered_superpixels(img, label_img)
    all_predictions = make_predictions(plot_ns_model, superpixels, ids, n_masses, num_of_tries)
    reconstructed_imgs = [get_reconstructed_image(label_img, all_predictions[i]) for i in range(len(all_predictions))]
    prediction_amounts = [len(pred) for pred in all_predictions]
    final_prediction_img = reconstructed_imgs[prediction_amounts.index(max(prediction_amounts))]
    final_prediction_img = np.uint8(final_prediction_img)
    return final_prediction_img

In [21]:
input_img = '2.png'
img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
segmentation = DISF_TBM_CELLS(img)

gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))
# Obtaining the object count on each image
gt_obj_count = cv2.connectedComponents(gray_gt)[0]
true_positive =cv2.connectedComponents(apply_open(true_positive_img))[0]
false_negative = cv2.connectedComponents(false_negative_img)[0]
false_positive = cv2.connectedComponents(false_positive_img)[0]
comparison_ovl_img = get_positive_classification_visualization(
    true_positive_img,
    false_positive_img,
    gt
)
Image.fromarray(segmentation).save(f'../quick_results/segmentation/{input_img}')
Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'../quick_results/overlap/{input_img}')

print(f"True Positive: {true_positive}")
print(f"False Negative: {false_negative}")
print(f"False Positive: {false_positive}")
precision = round(calculate_precision(true_positive, false_positive),2)
recall = round(calculate_recall(true_positive, false_negative), 2)
f1_score = round(calculate_f1_score(precision, recall), 2)
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score{f1_score}")

Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Amount of objects on the ground truth: 158
True Positive: 134
False Negative: 23
False Positive: 13
Precision: 0.91
Recall: 0.85
F1-Score0.88


In [22]:
input_img = '3.png'
img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
segmentation = DISF_TBM_CELLS(img)

gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))
# Obtaining the object count on each image
gt_obj_count = cv2.connectedComponents(gray_gt)[0]
true_positive =cv2.connectedComponents(true_positive_img)[0]
false_negative = cv2.connectedComponents(false_negative_img)[0]
false_positive = cv2.connectedComponents(false_positive_img)[0]
comparison_ovl_img = get_positive_classification_visualization(
    true_positive_img,
    false_positive_img,
    gt
)
Image.fromarray(segmentation).save(f'../quick_results/segmentation/{input_img}')
Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'../quick_results/overlap/{input_img}')

print(f"True Positive: {true_positive}")
print(f"False Negative: {false_negative}")
print(f"False Positive: {false_positive}")
precision = round(calculate_precision(true_positive, false_positive),2)
recall = round(calculate_recall(true_positive, false_negative), 2)
f1_score = round(calculate_f1_score(precision, recall), 2)
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score{f1_score}")

Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Amount of objects on the ground truth: 117
True Positive: 105
False Negative: 16
False Positive: 24
Precision: 0.81
Recall: 0.87
F1-Score0.84


In [23]:
input_img = '4.png'
img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
segmentation = DISF_TBM_CELLS(img)

gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))
# Obtaining the object count on each image
gt_obj_count = cv2.connectedComponents(gray_gt)[0]
true_positive =cv2.connectedComponents(true_positive_img)[0]
false_negative = cv2.connectedComponents(false_negative_img)[0]
false_positive = cv2.connectedComponents(false_positive_img)[0]
comparison_ovl_img = get_positive_classification_visualization(
    true_positive_img,
    false_positive_img,
    gt
)
Image.fromarray(segmentation).save(f'../quick_results/segmentation/{input_img}')
Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'../quick_results/overlap/{input_img}')

print(f"True Positive: {true_positive}")
print(f"False Negative: {false_negative}")
print(f"False Positive: {false_positive}")
precision = round(calculate_precision(true_positive, false_positive),2)
recall = round(calculate_recall(true_positive, false_negative), 2)
f1_score = round(calculate_f1_score(precision, recall), 2)
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score{f1_score}")

Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Amount of objects on the ground truth: 95
True Positive: 77
False Negative: 34
False Positive: 4
Precision: 0.95
Recall: 0.69
F1-Score0.8


In [24]:
input_img = '5.png'
img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
segmentation = DISF_TBM_CELLS(img)

gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))
# Obtaining the object count on each image
gt_obj_count = cv2.connectedComponents(gray_gt)[0]
true_positive =cv2.connectedComponents(true_positive_img)[0]
false_negative = cv2.connectedComponents(false_negative_img)[0]
false_positive = cv2.connectedComponents(false_positive_img)[0]
comparison_ovl_img = get_positive_classification_visualization(
    true_positive_img,
    false_positive_img,
    gt
)
Image.fromarray(segmentation).save(f'../quick_results/segmentation/{input_img}')
Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'../quick_results/overlap/{input_img}')

print(f"True Positive: {true_positive}")
print(f"False Negative: {false_negative}")
print(f"False Positive: {false_positive}")
precision = round(calculate_precision(true_positive, false_positive),2)
recall = round(calculate_recall(true_positive, false_negative), 2)
f1_score = round(calculate_f1_score(precision, recall), 2)
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score{f1_score}")

Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30


TypeError: Cannot handle this data type: (1, 1, 3), <i4

In [None]:
input_img = '6.png'
img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
segmentation = DISF_TBM_CELLS(img)

gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))
# Obtaining the object count on each image
gt_obj_count = cv2.connectedComponents(gray_gt)[0]
true_positive =cv2.connectedComponents(true_positive_img)[0]
false_negative = cv2.connectedComponents(false_negative_img)[0]
false_positive = cv2.connectedComponents(false_positive_img)[0]
comparison_ovl_img = get_positive_classification_visualization(
    true_positive_img,
    false_positive_img,
    gt
)
Image.fromarray(segmentation).save(f'../quick_results/segmentation/{input_img}')
Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'../quick_results/overlap/{input_img}')
print(f"Amount of objects on the ground truth: {gt_obj_count}")
print(f"True Positive: {true_positive}")
print(f"False Negative: {false_negative}")
print(f"False Positive: {false_positive}")
precision = round(calculate_precision(true_positive, false_positive),2)
recall = round(calculate_recall(true_positive, false_negative), 2)
f1_score = round(calculate_f1_score(precision, recall), 2)
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score{f1_score}")

In [None]:
input_img = '7.png'
img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
segmentation = DISF_TBM_CELLS(img)

gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))
# Obtaining the object count on each image
gt_obj_count = cv2.connectedComponents(gray_gt)[0]
true_positive =cv2.connectedComponents(true_positive_img)[0]
false_negative = cv2.connectedComponents(false_negative_img)[0]
false_positive = cv2.connectedComponents(false_positive_img)[0]
comparison_ovl_img = get_positive_classification_visualization(
    true_positive_img,
    false_positive_img,
    gt
)
Image.fromarray(segmentation).save(f'../quick_results/segmentation/{input_img}')
Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'../quick_results/overlap/{input_img}')
print(f"Amount of objects on the ground truth: {gt_obj_count}")
print(f"True Positive: {true_positive}")
print(f"False Negative: {false_negative}")
print(f"False Positive: {false_positive}")
precision = round(calculate_precision(true_positive, false_positive),2)
recall = round(calculate_recall(true_positive, false_negative), 2)
f1_score = round(calculate_f1_score(precision, recall), 2)
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score{f1_score}")

In [None]:
input_img = '8.png'
img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
segmentation = DISF_TBM_CELLS(img)

gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))
# Obtaining the object count on each image
gt_obj_count = cv2.connectedComponents(gray_gt)[0]
true_positive =cv2.connectedComponents(true_positive_img)[0]
false_negative = cv2.connectedComponents(false_negative_img)[0]
false_positive = cv2.connectedComponents(false_positive_img)[0]
comparison_ovl_img = get_positive_classification_visualization(
    true_positive_img,
    false_positive_img,
    gt
)
Image.fromarray(segmentation).save(f'../quick_results/segmentation/{input_img}')
Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'../quick_results/overlap/{input_img}')
print(f"Amount of objects on the ground truth: {gt_obj_count}")
print(f"True Positive: {true_positive}")
print(f"False Negative: {false_negative}")
print(f"False Positive: {false_positive}")
precision = round(calculate_precision(true_positive, false_positive),2)
recall = round(calculate_recall(true_positive, false_negative), 2)
f1_score = round(calculate_f1_score(precision, recall), 2)
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score{f1_score}")

In [None]:
input_img = '9.png'
img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
segmentation = DISF_TBM_CELLS(img)

gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))
# Obtaining the object count on each image
gt_obj_count = cv2.connectedComponents(gray_gt)[0]
true_positive =cv2.connectedComponents(true_positive_img)[0]
false_negative = cv2.connectedComponents(false_negative_img)[0]
false_positive = cv2.connectedComponents(false_positive_img)[0]
comparison_ovl_img = get_positive_classification_visualization(
    true_positive_img,
    false_positive_img,
    gt
)
Image.fromarray(segmentation).save(f'../quick_results/segmentation/{input_img}')
Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'../quick_results/overlap/{input_img}')
print(f"Amount of objects on the ground truth: {gt_obj_count}")
print(f"True Positive: {true_positive}")
print(f"False Negative: {false_negative}")
print(f"False Positive: {false_positive}")
precision = round(calculate_precision(true_positive, false_positive),2)
recall = round(calculate_recall(true_positive, false_negative), 2)
f1_score = round(calculate_f1_score(precision, recall), 2)
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score{f1_score}")

In [None]:
input_img = '10.png'
img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
segmentation = DISF_TBM_CELLS(img)

gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))
# Obtaining the object count on each image
gt_obj_count = cv2.connectedComponents(gray_gt)[0]
true_positive =cv2.connectedComponents(true_positive_img)[0]
false_negative = cv2.connectedComponents(false_negative_img)[0]
false_positive = cv2.connectedComponents(false_positive_img)[0]
comparison_ovl_img = get_positive_classification_visualization(
    true_positive_img,
    false_positive_img,
    gt
)
Image.fromarray(segmentation).save(f'../quick_results/segmentation/{input_img}')
Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'../quick_results/overlap/{input_img}')
print(f"Amount of objects on the ground truth: {gt_obj_count}")
print(f"True Positive: {true_positive}")
print(f"False Negative: {false_negative}")
print(f"False Positive: {false_positive}")
precision = round(calculate_precision(true_positive, false_positive),2)
recall = round(calculate_recall(true_positive, false_negative), 2)
f1_score = round(calculate_f1_score(precision, recall), 2)
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score{f1_score}")