# DISF + TBM Segmentation
---
Applies DISF as the pre-segmentation method and TBM as the superpixel classifier to obtain a segmentation of FoxP3+ stained cells.

## Initial Setup

In [1]:
# Necessary Imports
import sys
sys.path.append("../DISF/python3/")
sys.path.append("../iDISF/python3/")
sys.path.append("../")
from scripts.metrics import calculate_precision, calculate_recall, calculate_f1_score

from idisf import iDISF_scribbles
from disf import DISF_Superpixels
from scripts.segmentation_utils import *
from PIL import Image
from scripts.utils import *
from scripts.superpixel_treatment import *
from ptk_code.utils import *
from ptk_code.TBM_PLOT import batch_PLOT

import datetime
import matplotlib.pyplot as plt
import joblib
import numpy as np
import cv2

  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)
  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)


Defining the helper functions and parameters used throughout the pipeline

In [2]:
def get_positive_ids(predictions, superpixel_indexes):
    """Return the superpixel ids whose prediction equals 1.

    ``predictions[i]`` is treated as the predicted class of
    ``superpixel_indexes[i]``; the selected ids keep their original order.
    """
    return [
        superpixel_indexes[position]
        for position, label in enumerate(predictions)
        if label == 1
    ]

In [3]:
def get_overlap_img(pred, gt):
    """Create an overlapped visualization between a prediction and the ground truth.

    Both inputs are converted to arrays and combined with a bitwise XOR, so
    the result is non-zero only where their bits differ.
    """
    return cv2.bitwise_xor(np.asarray(pred), np.asarray(gt))

In [4]:
def get_positive_classification_visualization(tp_img, fp_img, ground_truth):
    """Overlay true-positive and false-positive masks on the ground truth.

    The true-positive mask is coloured green and the false-positive mask blue
    (RGB order) via ``change_color``; both are then XOR-combined with
    ``ground_truth`` through ``get_overlap_img``.

    Args:
        tp_img: mask of the true-positive regions.
        fp_img: mask of the false-positive regions.
        ground_truth: ground-truth image the coloured masks are overlaid on.

    Returns:
        The combined overlay image.
    """
    green = np.array([0, 255, 0])
    # Was [0, 255, 0] (green again), which made false positives
    # indistinguishable from true positives in the overlay.
    blue = np.array([0, 0, 255])
    tp_img_colored = np.int32(change_color(tp_img, green))
    fp_img_colored = np.int32(change_color(fp_img, blue))
    ov_img = get_overlap_img(tp_img_colored, ground_truth)
    ov_img = get_overlap_img(ov_img, fp_img_colored)
    return ov_img

In [5]:
def get_predictions(classifier,input_batch: list[np.array], n_masses:int) -> list[int]:
    """
    Classify a batch of images in the TBM-transformed space.

    The training images under ``../data/immuno_cells/training/`` are loaded as
    the reference set and their mean is used as the transform template. The
    grayscale version of ``input_batch`` is passed through
    ``batch_PLOT.forward_seq`` together with the reference set, and the
    transformed inputs are handed to ``classifier.predict``.
    """
    # Reference (training) data; its labels are not needed here.
    dataset = 'immuno_cells'
    train_dir = '../data/'+dataset+'/training/'
    reference_x, _ = load_image_data(train_dir)

    # Inputs are converted to float64 grayscale before the transform.
    gray_inputs = np.array([np.float64(rgb2gray(sample)) for sample in input_batch])

    transformer = batch_PLOT(Nmasses = n_masses)
    template = np.mean(reference_x, axis=0)
    # Only the transformed inputs (second returned value) are used.
    _, transformed_inputs, _, _ = transformer.forward_seq(reference_x, gray_inputs, template)

    return classifier.predict(transformed_inputs)

In [6]:
def filter_superpixels(ids, superpixels, rgb_threshold):
    """Keep the superpixels whose colour level is below a threshold.

    Each of the first three channels of a superpixel is summarised by the
    mean of its *distinct* values. The three channel values are averaged and
    compared with the average of ``rgb_threshold``; a superpixel is kept when
    its value is strictly lower.

    Args:
        ids: identifiers, one per entry of ``superpixels``.
        superpixels: arrays indexable as ``[:, :, channel]`` with at least
            three channels.
        rgb_threshold: sequence whose first three items are the red, green
            and blue thresholds.

    Returns:
        ``(filtered_ids, filtered_superpixels)``: the ids and the original
        superpixel objects that passed the filter, in input order.
    """
    # Loop-invariant: computed once instead of once per superpixel.
    threshold_color_mean = (rgb_threshold[0] + rgb_threshold[1] + rgb_threshold[2]) / 3

    filtered_ids = []
    filtered_superpixels = []
    for superpixel_id, superpixel in zip(ids, superpixels):
        pixels = np.asarray(superpixel)
        if pixels.size == 0:
            # Nothing to average; previously this raised ZeroDivisionError.
            continue
        # np.mean accumulates integer input in float64, so narrow dtypes such
        # as uint8 cannot wrap around the way a Python-level sum() of NumPy
        # scalars can.
        channel_means = [
            float(np.mean(np.unique(pixels[:, :, channel])))
            for channel in range(3)
        ]
        input_color_mean = sum(channel_means) / 3
        if input_color_mean < threshold_color_mean:
            filtered_superpixels.append(superpixel)
            filtered_ids.append(superpixel_id)
    return filtered_ids, filtered_superpixels

In [7]:
def obtain_fltered_superpixels(img, label_img,rgb_threshold):
    """Extract every superpixel of ``label_img`` from ``img`` and colour-filter them.

    For each label except the first one returned by ``np.unique``, the
    superpixel mask is built, both image and mask are cropped, and the mask is
    applied to the crop. The resulting images are passed to
    ``filter_superpixels`` together with their ids.

    Returns:
        ``(filtered_ids, filtered_superpixels)`` as produced by
        ``filter_superpixels``.
    """
    # The first (smallest) label is treated as the empty superpixel and skipped.
    candidate_ids = np.unique(label_img)[1:]

    masked_superpixels = []
    for candidate_id in candidate_ids:
        mask = get_superpixel_img(label_img, candidate_id)
        cropped_img, cropped_mask = get_cropped_superpixel_img(img, mask)
        masked_superpixels.append(apply_mask(cropped_img, cropped_mask))

    filtered_ids, filtered_superpixels = filter_superpixels(
        list(candidate_ids), masked_superpixels, rgb_threshold
    )
    return filtered_ids, filtered_superpixels

In [8]:
def make_predictions(plot_ns_model,superpixels,ids,n_masses= 300, num_of_tries=5):
    """Run the classifier ``num_of_tries`` times and collect the positive ids of each run.

    Every try calls ``get_predictions`` on the same superpixels and keeps the
    entries of ``ids`` that were classified as 1.

    Returns:
        A list with one element per try, each being the list of positively
        classified ids for that try.
    """
    positives_per_try = []
    for _ in range(num_of_tries):
        run_predictions = get_predictions(plot_ns_model, superpixels, n_masses)
        positives_per_try.append(get_positive_ids(run_predictions, ids))
    return positives_per_try

In [9]:
def DISF_TBM_CELLS(img, num_init_seeds=7000, num_final_superpixels = 4000, n_masses = 300, num_of_tries = 5, rgb_threshold = (130,90,80)):
    """Segment ``img`` with DISF superpixels classified by the TBM model.

    Steps: compute the DISF superpixel label image, colour-filter the
    superpixels, classify them ``num_of_tries`` times, rebuild one image per
    try from the positively classified ids, and return (as ``uint8``) the
    image of the try with the most positive ids. The first such try wins ties.
    """
    superpixel_label_img, _border_img = DISF_Superpixels(img, num_init_seeds, num_final_superpixels)
    label_img = superpixel_label_img.copy()

    # NOTE(review): joblib.load unpickles the file; only load trusted checkpoints.
    classifier = joblib.load('../checkpoints/best.pkl')

    ids, superpixels = obtain_fltered_superpixels(img, label_img, rgb_threshold)
    runs = make_predictions(classifier, superpixels, ids, n_masses, num_of_tries)

    # Every run is reconstructed, mirroring the established behaviour.
    reconstructed = [get_reconstructed_image(label_img, run) for run in runs]

    # max() returns the first index among equally sized runs.
    best_run = max(range(len(runs)), key=lambda idx: len(runs[idx]))
    return np.uint8(reconstructed[best_run])

In [10]:
import datetime

In [11]:
def create_output_folders(result_folder):
    """Create the per-output subfolders of ``result_folder`` via ``create_folder``.

    One subfolder each for the segmentation, the overlap visualization and
    the true-positive, false-negative and false-positive masks.
    """
    for subfolder in ('/segmentation/', '/overlap/', '/tp/', '/fn/', '/fp/'):
        create_folder(result_folder + subfolder)

In [12]:
def _count_objects(binary_img):
    """Return the number of foreground connected components in ``binary_img``."""
    # cv2.connectedComponents returns the total number of labels, and label 0
    # always stands for the background, so the object count is one less.
    num_labels = cv2.connectedComponents(binary_img)[0]
    return max(num_labels - 1, 0)


def _safe_metric(metric_fn, *args):
    """Evaluate ``metric_fn(*args)``, reporting 0.0 when its denominator is zero."""
    # Object counts can legitimately be zero now that the background label is
    # no longer counted, so a ratio metric may hit a zero denominator.
    try:
        return metric_fn(*args)
    except ZeroDivisionError:
        return 0.0


def segment_classify_test(input_img, result_folder, num_init_seeds, num_final_superpixels, n_masses, num_of_tries, csv_logfile_path, logfile_path, rgb_threshold):
    """Applies the complete pipeline of the DISF TBM method to one image.

    The image and its label are read from ``../data/cells_dataset/``, the
    image is segmented with ``DISF_TBM_CELLS``, and true-positive,
    false-negative and false-positive masks are derived by comparing the
    segmentation with the grayscale ground truth. Connected components of
    those masks are counted as objects and used for precision, recall and F1.

    Side effects: saves the segmentation, overlap, tp, fn and fp images under
    ``result_folder`` and appends one row to ``csv_logfile_path`` and one
    record to ``logfile_path``.

    Args:
        input_img: file name of the image (same name in the original and
            labels folders).
        result_folder: folder containing the output subfolders.
        num_init_seeds, num_final_superpixels: DISF parameters.
        n_masses, num_of_tries: classification parameters.
        csv_logfile_path: CSV file that receives the appended result row.
        logfile_path: text log that receives the appended record.
        rgb_threshold: 3-item colour threshold used to filter superpixels.
    """
    print(f'\n\n - Working on image: {input_img}')
    img = np.array(Image.open(f"../data/cells_dataset/original/{input_img}"), dtype= 'int32')
    gt = np.array(Image.open(f"../data/cells_dataset/labels/{input_img}"), dtype = 'int32')
    segmentation = DISF_TBM_CELLS(img, num_init_seeds, num_final_superpixels, n_masses, num_of_tries, rgb_threshold)

    gray_gt = cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2GRAY)
    true_positive_img = cv2.bitwise_and(segmentation, gray_gt)
    false_negative_img = apply_open(cv2.bitwise_xor(gray_gt, true_positive_img))
    false_positive_img = apply_open(cv2.bitwise_xor(segmentation, true_positive_img))

    # Object counts per mask, excluding the background label.
    true_positive = _count_objects(apply_open(true_positive_img))
    false_negative = _count_objects(false_negative_img)
    false_positive = _count_objects(false_positive_img)

    comparison_ovl_img = get_positive_classification_visualization(
        true_positive_img,
        false_positive_img,
        gt
    )

    Image.fromarray(segmentation).save(f'{result_folder}/segmentation/{input_img}')
    Image.fromarray(comparison_ovl_img.astype('uint8')).save(f'{result_folder}/overlap/{input_img}')
    Image.fromarray(true_positive_img.astype('uint8')).save(f'{result_folder}/tp/{input_img}')
    Image.fromarray(false_negative_img.astype('uint8')).save(f'{result_folder}/fn/{input_img}')
    Image.fromarray(false_positive_img.astype('uint8')).save(f'{result_folder}/fp/{input_img}')

    precision = round(_safe_metric(calculate_precision, true_positive, false_positive), 2)
    recall = round(_safe_metric(calculate_recall, true_positive, false_negative), 2)
    # F1 is computed from the already rounded precision and recall.
    f1_score = round(_safe_metric(calculate_f1_score, precision, recall), 2)

    # Rows start with a newline because they are appended to an existing file.
    output_data_row_str = f"\n{input_img},{num_init_seeds},{num_final_superpixels},{n_masses},{num_of_tries},"
    output_data_row_str += f"{rgb_threshold[0]},{rgb_threshold[1]},{rgb_threshold[2]},"
    output_data_row_str += f"{true_positive},{false_positive},{false_negative},{precision},{recall},{f1_score}"
    with open(csv_logfile_path,'a') as log_csv_file:
        log_csv_file.write(output_data_row_str)
    with open(logfile_path, 'a') as log_file:
        log_file.write('\n--------------\n')
        log_file.write(f'img: {input_img}\n')
        log_file.write(f'num_init_seeds: {num_init_seeds}\n')
        log_file.write(f'num_final_superpixels: {num_final_superpixels}\n')
        log_file.write(f'n_masses: {n_masses}\n')
        log_file.write(f'num_of_tries: {num_of_tries}\n')
        log_file.write(f'tp: {true_positive}\n')
        log_file.write(f'fp: {false_positive}\n')
        log_file.write(f'fn: {false_negative}\n')
        log_file.write(f'precision: {precision}\n')
        log_file.write(f'recall: {recall}\n')
        log_file.write(f'f1_score: {f1_score}\n')

In [13]:
def prepare_log(params):
    """Create a timestamped result folder and start its log file.

    The folder is ``../results/mean_evaluation/<datetime.datetime.now()>``;
    it and its output subfolders are created, then ``log.log`` inside it is
    written with the run parameters taken from ``params``.

    Returns:
        ``(result_folder, logfile_path)``.
    """
    result_folder = f'../results/mean_evaluation/{datetime.datetime.now()}'
    logfile_path =  f'{result_folder}/log.log'
    create_folder(result_folder)
    create_output_folders(result_folder)

    # Parameters recorded in the log header, in output order.
    logged_keys = (
        'num_init_seeds',
        'num_final_superpixels',
        'n_masses',
        'num_of_tries',
        'red_channel_mean_threshold',
        'green_channel_mean_threshold',
        'blue_channel_mean_threshold',
    )
    with open(logfile_path,'w') as log_file:
        for key in logged_keys:
            log_file.write(f"{key}: {params[key]}\n")
    return result_folder,logfile_path

In [14]:
# Shared CSV log: every call to segment_classify_test appends one row to it.
csv_logfile_path =  '../results/mean_evaluation/log.csv'
# One-off header initialisation. Mode 'w' truncates the file, so this is
# presumably kept commented out to avoid wiping rows from earlier runs;
# uncomment only when starting a fresh log.csv.
# with open(csv_logfile_path, 'w') as log_csv_file:
#     columns = 'image_name,num_init_seeds,num_final_superpixels,n_masses,num_of_tries,'
#     columns += 'red_channel_mean_threshold,green_channel_mean_threshold,blue_channel_mean_threshold,'
#     columns += 'tp,fp,fn,precision,recall,f1_score'
#     log_csv_file.write(columns)

In [15]:
# Baseline configuration: colour-mean threshold (130, 90, 80).
# The per-channel entries are what prepare_log writes to the log header;
# 'rgb_threshold' carries the same values as a tuple for the pipeline.
params = dict(
    num_init_seeds=7000,
    num_final_superpixels=4000,
    n_masses=300,
    num_of_tries=5,
    red_channel_mean_threshold=130,
    green_channel_mean_threshold=90,
    blue_channel_mean_threshold=80,
    rgb_threshold=(130, 90, 80),
)
result_folder, logfile_path = prepare_log(params)

In [16]:
# for i in range (2, 11):
#     segment_classify_test(
#     f'{i}.png',
#     result_folder,
#     params['num_init_seeds'],
#     params['num_final_superpixels'],
#     params['n_masses'],
#     params['num_of_tries'], 
#     csv_logfile_path, 
#     logfile_path,
#     params['rgb_threshold']
# )

In [17]:
# Threshold sweep run: colour-mean threshold (140, 100, 90).
# The per-channel entries are what prepare_log writes to the log header;
# 'rgb_threshold' carries the same values as a tuple for the pipeline.
params = dict(
    num_init_seeds=7000,
    num_final_superpixels=4000,
    n_masses=300,
    num_of_tries=5,
    red_channel_mean_threshold=140,
    green_channel_mean_threshold=100,
    blue_channel_mean_threshold=90,
    rgb_threshold=(140, 100, 90),
)
result_folder, logfile_path = prepare_log(params)
# Images 2.png through 10.png are evaluated with the same settings.
for i in range(2, 11):
    segment_classify_test(
        input_img=f'{i}.png',
        result_folder=result_folder,
        num_init_seeds=params['num_init_seeds'],
        num_final_superpixels=params['num_final_superpixels'],
        n_masses=params['n_masses'],
        num_of_tries=params['num_of_tries'],
        csv_logfile_path=csv_logfile_path,
        logfile_path=logfile_path,
        rgb_threshold=params['rgb_threshold'],
    )



 - Working on image: 2.png


 - Working on image: 3.png


 - Working on image: 4.png


 - Working on image: 5.png


 - Working on image: 6.png


 - Working on image: 7.png


 - Working on image: 8.png


 - Working on image: 9.png


 - Working on image: 10.png


In [18]:
# Threshold sweep run: colour-mean threshold (150, 110, 100).
# The per-channel entries are what prepare_log writes to the log header;
# 'rgb_threshold' carries the same values as a tuple for the pipeline.
params = dict(
    num_init_seeds=7000,
    num_final_superpixels=4000,
    n_masses=300,
    num_of_tries=5,
    red_channel_mean_threshold=150,
    green_channel_mean_threshold=110,
    blue_channel_mean_threshold=100,
    rgb_threshold=(150, 110, 100),
)
result_folder, logfile_path = prepare_log(params)
# Images 2.png through 10.png are evaluated with the same settings.
for i in range(2, 11):
    segment_classify_test(
        input_img=f'{i}.png',
        result_folder=result_folder,
        num_init_seeds=params['num_init_seeds'],
        num_final_superpixels=params['num_final_superpixels'],
        n_masses=params['n_masses'],
        num_of_tries=params['num_of_tries'],
        csv_logfile_path=csv_logfile_path,
        logfile_path=logfile_path,
        rgb_threshold=params['rgb_threshold'],
    )



 - Working on image: 2.png


 - Working on image: 3.png


 - Working on image: 4.png


 - Working on image: 5.png


 - Working on image: 6.png


 - Working on image: 7.png


 - Working on image: 8.png


 - Working on image: 9.png


 - Working on image: 10.png


In [20]:
# Threshold sweep run: colour-mean threshold (160, 120, 110).
# The per-channel entries are what prepare_log writes to the log header;
# 'rgb_threshold' carries the same values as a tuple for the pipeline.
params = dict(
    num_init_seeds=7000,
    num_final_superpixels=4000,
    n_masses=300,
    num_of_tries=5,
    red_channel_mean_threshold=160,
    green_channel_mean_threshold=120,
    blue_channel_mean_threshold=110,
    rgb_threshold=(160, 120, 110),
)
result_folder, logfile_path = prepare_log(params)
# Images 2.png through 10.png are evaluated with the same settings.
for i in range(2, 11):
    segment_classify_test(
        input_img=f'{i}.png',
        result_folder=result_folder,
        num_init_seeds=params['num_init_seeds'],
        num_final_superpixels=params['num_final_superpixels'],
        n_masses=params['n_masses'],
        num_of_tries=params['num_of_tries'],
        csv_logfile_path=csv_logfile_path,
        logfile_path=logfile_path,
        rgb_threshold=params['rgb_threshold'],
    )



 - Working on image: 2.png


KeyboardInterrupt: 

In [21]:
# Threshold sweep run: colour-mean threshold (170, 130, 120).
# The per-channel entries are what prepare_log writes to the log header;
# 'rgb_threshold' carries the same values as a tuple for the pipeline.
params = dict(
    num_init_seeds=7000,
    num_final_superpixels=4000,
    n_masses=300,
    num_of_tries=5,
    red_channel_mean_threshold=170,
    green_channel_mean_threshold=130,
    blue_channel_mean_threshold=120,
    rgb_threshold=(170, 130, 120),
)
result_folder, logfile_path = prepare_log(params)
# Images 2.png through 10.png are evaluated with the same settings.
for i in range(2, 11):
    segment_classify_test(
        input_img=f'{i}.png',
        result_folder=result_folder,
        num_init_seeds=params['num_init_seeds'],
        num_final_superpixels=params['num_final_superpixels'],
        n_masses=params['n_masses'],
        num_of_tries=params['num_of_tries'],
        csv_logfile_path=csv_logfile_path,
        logfile_path=logfile_path,
        rgb_threshold=params['rgb_threshold'],
    )



 - Working on image: 2.png


KeyboardInterrupt: 

In [None]:
# Threshold sweep run: colour-mean threshold (120, 80, 70).
# The per-channel entries are what prepare_log writes to the log header;
# 'rgb_threshold' carries the same values as a tuple for the pipeline.
params = dict(
    num_init_seeds=7000,
    num_final_superpixels=4000,
    n_masses=300,
    num_of_tries=5,
    red_channel_mean_threshold=120,
    green_channel_mean_threshold=80,
    blue_channel_mean_threshold=70,
    rgb_threshold=(120, 80, 70),
)
result_folder, logfile_path = prepare_log(params)
# Images 2.png through 10.png are evaluated with the same settings.
for i in range(2, 11):
    segment_classify_test(
        input_img=f'{i}.png',
        result_folder=result_folder,
        num_init_seeds=params['num_init_seeds'],
        num_final_superpixels=params['num_final_superpixels'],
        n_masses=params['n_masses'],
        num_of_tries=params['num_of_tries'],
        csv_logfile_path=csv_logfile_path,
        logfile_path=logfile_path,
        rgb_threshold=params['rgb_threshold'],
    )

Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30


In [None]:
# Threshold sweep run: colour-mean threshold (110, 70, 60).
# The per-channel entries are what prepare_log writes to the log header;
# 'rgb_threshold' carries the same values as a tuple for the pipeline.
params = dict(
    num_init_seeds=7000,
    num_final_superpixels=4000,
    n_masses=300,
    num_of_tries=5,
    red_channel_mean_threshold=110,
    green_channel_mean_threshold=70,
    blue_channel_mean_threshold=60,
    rgb_threshold=(110, 70, 60),
)
result_folder, logfile_path = prepare_log(params)
# Images 2.png through 10.png are evaluated with the same settings.
for i in range(2, 11):
    segment_classify_test(
        input_img=f'{i}.png',
        result_folder=result_folder,
        num_init_seeds=params['num_init_seeds'],
        num_final_superpixels=params['num_final_superpixels'],
        n_masses=params['n_masses'],
        num_of_tries=params['num_of_tries'],
        csv_logfile_path=csv_logfile_path,
        logfile_path=logfile_path,
        rgb_threshold=params['rgb_threshold'],
    )

Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30
Len basis: 30


: 

In [None]:
# Threshold sweep run: colour-mean threshold (100, 60, 50).
# The per-channel entries are what prepare_log writes to the log header;
# 'rgb_threshold' carries the same values as a tuple for the pipeline.
params = dict(
    num_init_seeds=7000,
    num_final_superpixels=4000,
    n_masses=300,
    num_of_tries=5,
    red_channel_mean_threshold=100,
    green_channel_mean_threshold=60,
    blue_channel_mean_threshold=50,
    rgb_threshold=(100, 60, 50),
)
result_folder, logfile_path = prepare_log(params)
# Images 2.png through 10.png are evaluated with the same settings.
for i in range(2, 11):
    segment_classify_test(
        input_img=f'{i}.png',
        result_folder=result_folder,
        num_init_seeds=params['num_init_seeds'],
        num_final_superpixels=params['num_final_superpixels'],
        n_masses=params['n_masses'],
        num_of_tries=params['num_of_tries'],
        csv_logfile_path=csv_logfile_path,
        logfile_path=logfile_path,
        rgb_threshold=params['rgb_threshold'],
    )

In [None]:
# Threshold sweep run: colour-mean threshold (90, 50, 40).
# The per-channel entries are what prepare_log writes to the log header;
# 'rgb_threshold' carries the same values as a tuple for the pipeline.
params = dict(
    num_init_seeds=7000,
    num_final_superpixels=4000,
    n_masses=300,
    num_of_tries=5,
    red_channel_mean_threshold=90,
    green_channel_mean_threshold=50,
    blue_channel_mean_threshold=40,
    rgb_threshold=(90, 50, 40),
)
result_folder, logfile_path = prepare_log(params)
# Images 2.png through 10.png are evaluated with the same settings.
for i in range(2, 11):
    segment_classify_test(
        input_img=f'{i}.png',
        result_folder=result_folder,
        num_init_seeds=params['num_init_seeds'],
        num_final_superpixels=params['num_final_superpixels'],
        n_masses=params['n_masses'],
        num_of_tries=params['num_of_tries'],
        csv_logfile_path=csv_logfile_path,
        logfile_path=logfile_path,
        rgb_threshold=params['rgb_threshold'],
    )