In [None]:
import pandas as pd
import os
import csv
import re
from scipy.ndimage import gaussian_filter
from skimage.feature import peak_local_max
import numpy as np
import cv2
import math
%run generator.ipynb
%run model.ipynb
# import xgboost as xgb

# import warnings
# warnings.filterwarnings('ignore')
# from google.colab.patches import cv2_imshow  # Only required in Google Colab
# import warnings
# warnings.filterwarnings("ignore", category=UserWarning, module="imageio")


# Definitions

In [None]:

def detect_cells_a1(density_map, sigma=2, threshold_factor=1.5):
    """Locate cell centres as local maxima of a Gaussian-smoothed density map.

    Only peaks that are at least 2 px apart and brighter than
    ``mean + threshold_factor * std`` of the smoothed map are kept.

    Args:
        density_map: 2-D array (in this notebook, a grayscale label image).
        sigma: standard deviation of the Gaussian smoothing kernel.
        threshold_factor: how many standard deviations above the mean a peak
            must be to count as a cell.

    Returns:
        list of ``(x, y)`` integer tuples, i.e. (column, row).
    """
    blurred = gaussian_filter(density_map, sigma=sigma)
    cutoff = np.mean(blurred) + threshold_factor * np.std(blurred)

    # peak_local_max yields (row, col) pairs; flip them to (x, y).
    peaks = peak_local_max(blurred, min_distance=2, threshold_abs=cutoff)
    return [(int(col), int(row)) for row, col in peaks]


def detect_cells_a2_1(density_map, sigma=2, threshold_factor=1.5, max_peaks=1000):
    """Detect cell centres in a reconstructed prediction map.

    A 3-D (colour) input is collapsed to one channel by averaging over its last
    axis. The map is cast to float32, Gaussian-smoothed, and local maxima
    brighter than ``mean + threshold_factor * std`` of the smoothed map (at
    least 2 px apart, image borders included) are returned.

    If peak detection saturates (``max_peaks`` peaks found), the map is treated
    as unreliable and an empty list is returned. Callers in this notebook treat
    an empty result as "no usable detections".

    Args:
        density_map: array-like prediction map, 2-D or 3-D.
        sigma: standard deviation of the Gaussian smoothing kernel.
        threshold_factor: number of standard deviations above the mean a peak
            must exceed.
        max_peaks: saturation cap passed to ``peak_local_max(num_peaks=...)``.

    Returns:
        list of ``(x, y)`` integer tuples (column, row); empty when saturated.
    """
    density_map = np.asarray(density_map)
    if density_map.ndim == 3:  # colour image -> single channel
        density_map = density_map.mean(axis=2)
    density_map = density_map.astype(np.float32)

    smoothed = gaussian_filter(density_map, sigma=sigma)
    threshold = np.mean(smoothed) + threshold_factor * np.std(smoothed)

    coords = peak_local_max(smoothed,
                            min_distance=2,
                            threshold_abs=threshold,
                            exclude_border=False,
                            num_peaks=max_peaks)
    # Hitting the cap means the map is noise-dominated; report no detections.
    if len(coords) >= max_peaks:
        return []
    # peak_local_max yields (row, col); convert to (x, y).
    return [(int(x), int(y)) for y, x in coords]


In [None]:

def draw_cell_detections(image, centroids, radius=15, color=(255, 0, 0), thickness=2):
    """Return a copy of ``image`` with a circle drawn around every centroid.

    A 2-D (grayscale) image is first converted to 3 channels so coloured
    circles can be drawn. The input image itself is not modified.

    Args:
        image: numpy image array.
        centroids: iterable of ``(x, y)`` points.
        radius, color, thickness: forwarded to ``cv2.circle``.
    """
    canvas = image.copy()
    if canvas.ndim == 2:
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)

    for cx, cy in centroids:
        centre = (int(cx), int(cy))
        cv2.circle(canvas, centre, radius, color, thickness)
    return canvas


def draw_cell_detections_filterd(image, filtred_centroids, non_filtred_centroids, radius=10, color=(255, 0, 0),
                                 thickness=2):
    """Return a copy of ``image`` with two groups of centroids circled in different colours.

    ``non_filtred_centroids`` are drawn first, in ``color``. ``filtred_centroids``
    are drawn afterwards, in the fixed colour ``(0, 0, 255)``, so they end up on
    top where circles overlap. A 2-D (grayscale) image is converted to 3
    channels first. The input image is not modified.
    """
    canvas = image.copy()
    if canvas.ndim == 2:
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)

    # Draw order matters: the filtered group is painted last.
    layers = ((non_filtred_centroids, color), (filtred_centroids, (0, 0, 255)))
    for points, layer_color in layers:
        for cx, cy in points:
            cv2.circle(canvas, (int(cx), int(cy)), radius, layer_color, thickness)
    return canvas


def _draw_tile_box(canvas, x, y, tile_size):
    """Draw one green tile outline plus its "x,y" origin label onto ``canvas`` in place."""
    cv2.rectangle(canvas, (x, y), (x + tile_size, y + tile_size), (0, 255, 0), 1)
    cv2.putText(canvas, f"{x},{y}", (x + 5, y + 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 200, 255), 1, cv2.LINE_AA)


def draw_grids(draw_grid, X, Y, tile_size, save_folder, gt_visualization, rec_visualization, animal_name,
               enhanced=False):
    """Optionally overlay the tile grid on both visualizations, then save them as PNGs.

    When ``draw_grid`` is true, every tile origin ``(X[i], Y[i])`` gets a green
    ``tile_size`` square and an ``"x,y"`` label. These are drawn *in place* on
    both ``gt_visualization`` and ``rec_visualization``.

    Files written to ``save_folder``:
      * ``{animal_name}_1_gt_detections.png``: the GT visualization.
      * ``{animal_name}_2_rec_detections.png`` (``_3_`` when ``enhanced``): the
        reconstruction visualization with a 3-pixel border cropped away.

    A warning is printed for any file that ``cv2.imwrite`` reports it could not write.
    """
    if draw_grid:
        for x, y in zip(X, Y):
            _draw_tile_box(rec_visualization, x, y, tile_size)
            _draw_tile_box(gt_visualization, x, y, tile_size)

    gt_path = os.path.join(save_folder, f"{animal_name}_1_gt_detections.png")
    rec_index = 3 if enhanced else 2
    rec_path = os.path.join(save_folder, f"{animal_name}_{rec_index}_rec_detections.png")
    # Trim a 3-px border from the reconstruction visualization before saving.
    rec_visualization_cropped = rec_visualization[3:-3, 3:-3]

    failed = []
    for path, img in ((gt_path, gt_visualization), (rec_path, rec_visualization_cropped)):
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(path, img):
            failed.append(path)

    for path in failed:
        print(f"Warning: failed to write {path}")
    if not failed:
        print(f"Visualizations with detected cells saved to {save_folder}")


def search_in_csv(animal_name, csv_file):
    """Return the rows of ``csv_file`` whose first column equals ``animal_name``.

    First-column values are converted to ``str`` and stripped of surrounding
    whitespace before comparison; ``animal_name`` is compared as given.

    Returns:
        DataFrame of the matching rows, or ``None`` when nothing matches.
    """
    table = pd.read_csv(csv_file)
    first_column = table.iloc[:, 0].astype(str).str.strip()
    hits = table[first_column == animal_name]
    if hits.empty:
        return None
    return hits


def save_results_csv(save_folder, manual_count_file, animal_name, gt_cell_centroids, rec_cell_centroids,
                     rec_cell_centroids_e, network_type, dataset, xgb_pred):
    """Append one row of cell-count statistics for ``animal_name`` to ``<dataset>_cell_counts.csv``.

    The CSV lives in ``save_folder``. A header row is written only when the file
    does not exist yet. The manual count and the "method 1" automatic count are
    read from columns 1 and 2 of the row matching ``animal_name`` in
    ``manual_count_file`` (looked up with ``search_in_csv``). All other columns
    come from the lengths of the three centroid lists:

      * differences are ``gt - predicted``;
      * error rates are percentages of the ground-truth count.

    Edge cases:
      * If no row matches in ``manual_count_file``, the manual and method-1
        cells are left blank and a warning is printed.
      * If there are zero ground-truth cells, the error-rate cells are written
        as ``nan`` instead of raising ``ZeroDivisionError``.

    Args:
        gt_cell_centroids: reference detections ("method 2" GT).
        rec_cell_centroids: raw predicted detections.
        rec_cell_centroids_e: predicted detections after overlap merging.
        network_type, xgb_pred: joined with a space into the 'Network Type' column.
    """
    csv_path = os.path.join(save_folder, dataset + "_cell_counts.csv")
    file_exists = os.path.isfile(csv_path)

    row_manual_count = search_in_csv(animal_name, manual_count_file)
    if row_manual_count is None:
        print(f"Warning: {animal_name} not found in {manual_count_file}; manual counts left blank.")
        manual_gt, auto_gt_method1 = "", ""
    else:
        manual_gt = row_manual_count.iloc[0, 1]
        auto_gt_method1 = row_manual_count.iloc[0, 2]

    n_gt = len(gt_cell_centroids)
    n_rec = len(rec_cell_centroids)
    n_rec_e = len(rec_cell_centroids_e)

    print("Manual GT cells: ", manual_gt, "\nAuto Ground truth cells _ method 1:",
          auto_gt_method1, "\nAuto Ground truth cells _ method 2 :", n_gt,
          "\nReconstructed cells:", n_rec, "\nremoved overlapped cells:", n_rec_e)

    def _rate(diff):
        # Percentage of the GT count; nan rather than a crash when GT is empty.
        return diff / n_gt * 100 if n_gt else float("nan")

    diff = n_gt - n_rec
    diff_e = n_gt - n_rec_e

    fieldnames = ['Animal Name', 'Network Type', 'Manual GT Cells', 'Auto method GT Cells method 1 ',
                  'Auto method GT Cells method 2 ', 'Reconstructed Cells', 'Difference (gt- pred)',
                  'Abs Difference', 'Error rate', 'Abs Error rate', 'removed overlapped cells',
                  'Difference (gt- pred_e)', 'Abs Difference enh', 'Error rate enh', 'Abs Error rate enh']

    with open(csv_path, mode='a', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        if not file_exists:
            writer.writeheader()

        writer.writerow({
            'Animal Name': animal_name,
            'Network Type': network_type + " " + str(xgb_pred),
            'Manual GT Cells': manual_gt,
            'Auto method GT Cells method 1 ': auto_gt_method1,
            'Auto method GT Cells method 2 ': n_gt,
            'Reconstructed Cells': n_rec,
            'Difference (gt- pred)': diff,
            'Abs Difference': math.fabs(diff),
            'Error rate': _rate(diff),
            'Abs Error rate': _rate(math.fabs(diff)),

            'removed overlapped cells': n_rec_e,
            'Difference (gt- pred_e)': diff_e,
            'Abs Difference enh': math.fabs(diff_e),
            'Error rate enh': _rate(diff_e),
            'Abs Error rate enh': _rate(math.fabs(diff_e)),
        })


def extract_main_name(filename):
    """Return everything before the first '.' in ``filename``.

    Returns the whole string when ``filename`` contains no dot.
    """
    stem, _, _ = filename.partition('.')
    return stem


def main_name_ends_with_1(filename):
    """Return True when the part of ``filename`` before the first '.' ends with '1'."""
    stem = extract_main_name(filename)
    if stem is None:
        return False
    return stem.endswith('1')


def merge_overlapping_cells_keep_others(centroids, radius=15):
    """Greedily merge detections that overlap a seed detection; keep isolated ones unchanged.

    Points are visited in input order. Each unvisited point becomes a seed and
    absorbs every later unvisited point closer than ``2 * radius`` to *the
    seed*. Clustering is not transitive: a point near an absorbed neighbour
    but far from the seed is not merged. A cluster with several members is
    replaced by its mean position; a singleton is kept as is.

    Args:
        centroids: sequence of ``(x, y)`` points.
        radius: assumed cell radius; circles overlap when centres are closer than ``2 * radius``.

    Returns:
        list of ``(x, y)`` tuples of numpy floats.
    """
    points = np.array(centroids, dtype=float)
    n_points = len(points)
    consumed = np.zeros(n_points, dtype=bool)
    overlap_dist = 2 * radius
    result = []

    for seed in range(n_points):
        if consumed[seed]:
            continue
        consumed[seed] = True
        members = [points[seed]]

        for other in range(seed + 1, n_points):
            if consumed[other]:
                continue
            if np.linalg.norm(points[seed] - points[other]) < overlap_dist:
                members.append(points[other])
                consumed[other] = True

        merged = members[0] if len(members) == 1 else np.mean(members, axis=0)
        result.append(tuple(merged))

    return result


In [None]:

# Reconstruct cells
def reconstruct_image_from_tiles(dataset, network_type, xgb_pred, manual_count_file, tile_folder, save_folder,
                                 animal_name, label_image_add, file_names_saved, dic_image, sigma=5.5, threshold=2,
                                 m_color=(255, 0, 0), draw_grid=True, force_save=False,
                                 final_size=1406, tile_size=256):
    """Stitch one image's tile predictions, count cells, and save visualizations and results.

    Steps:
      1. Stitch the tiles in ``tile_folder`` into a ``final_size`` canvas (``reconstruct``).
      2. Detect predicted cells on the canvas (``detect_cells_a2_1`` with
         ``sigma`` / ``threshold``) and reference cells on the label image
         (``detect_cells_a1``).
      3. Merge predicted detections within 30 px of a seed detection
         (``merge_overlapping_cells_keep_others`` with radius 15).
      4. Save the GT and prediction visualizations (``draw_grids``). Prediction
         circles are drawn on ``dic_image`` in ``m_color``.
      5. Append a row to the results CSV (``save_results_csv``) when at least one
         merged detection exists, or unconditionally when ``force_save`` is true.

    Returns:
        ``(reconstructed, rec_visualization, merged_centroids)``. ``merged_centroids``
        is always a list, possibly empty. (An empty result used to come back as
        the int ``0``, which made the callers' ``len()`` checks raise.)

    Raises:
        FileNotFoundError: if ``label_image_add`` cannot be read.
    """
    label_image = cv2.imread(label_image_add)
    if label_image is None:
        raise FileNotFoundError(f"Could not read label image: {label_image_add}")
    label_image = cv2.cvtColor(label_image, cv2.COLOR_BGR2GRAY)

    X, Y, reconstructed = reconstruct(tile_folder, file_names_saved, animal_name, final_size, tile_size)
    rec_cell_centroids = detect_cells_a2_1(reconstructed, sigma=sigma, threshold_factor=threshold)
    gt_cell_centroids = detect_cells_a1(label_image, sigma=2.5, threshold_factor=1.5)

    gt_visualization = draw_cell_detections(label_image, gt_cell_centroids)
    merged_centroids = merge_overlapping_cells_keep_others(rec_cell_centroids, radius=15)
    rec_visualization = draw_cell_detections(dic_image, merged_centroids, color=m_color)

    print("Gt cell counts: ", len(gt_cell_centroids))
    print("Predicted cell counts:  ", len(rec_cell_centroids))
    print("Predicted cell counts after removing overlaps:  ", len(merged_centroids))

    # Draw the optional tile grid and write both visualizations to disk.
    draw_grids(draw_grid, X, Y, tile_size, save_folder, gt_visualization, rec_visualization, animal_name,
               enhanced=False)

    # Save cell counts to CSV only when something was detected, unless forced.
    if len(merged_centroids) > 0 or force_save:
        print("here  saving ", str(xgb_pred))
        save_results_csv(save_folder, manual_count_file, animal_name, gt_cell_centroids, rec_cell_centroids,
                         merged_centroids, network_type, dataset, " xgb " + str(xgb_pred))
    else:
        print("here not saving")

    return reconstructed, rec_visualization, merged_centroids


In [None]:

def reconstruct(tile_folder, file_names_saved, animal_name, final_size, tile_size):
    """Stitch one image's per-tile predictions back into a full-size canvas.

    Tiles are found as entries of ``tile_folder`` whose names match
    ``{animal_name}_<n>_x<X>_y<Y>.png``. For each match the image
    ``<tile_folder>/<entry>/<file_names_saved>`` is loaded, so each matching
    entry is expected to be a directory (NOTE(review): confirm). Unreadable
    tiles are skipped with a warning.

    Where tiles overlap, a pixel takes the tile's colour only if the tile's
    grayscale value is strictly brighter than what is already on the canvas.
    Tiles are pasted in descending (x, y) order. Any part of a tile falling
    outside the ``final_size`` canvas is clipped, and tiles entirely outside it
    are skipped.

    Returns:
        ``(X, Y, reconstructed)``: the tile x/y origins in paste order, and the
        ``(final_size, final_size, 3)`` uint8 canvas.
    """
    X = []
    Y = []
    reconstructed = np.zeros((final_size, final_size, 3), dtype=np.uint8)

    # The animal name is literal text, so escape it; raw string for the \d classes.
    pattern = re.compile(rf"{re.escape(animal_name)}_\d+_x(\d+)_y(\d+)\.png")

    # First pass: load all tiles and their positions.
    tiles = []
    for filename in os.listdir(tile_folder):
        match = pattern.match(filename)
        if not match:
            continue
        x, y = int(match.group(1)), int(match.group(2))
        tile_path = os.path.join(tile_folder, filename, file_names_saved)
        tile_img = cv2.imread(tile_path)
        if tile_img is None:
            print(f"Warning: Couldn't load {tile_path}, skipping.")
            continue
        tiles.append((x, y, tile_img))

    # Second pass: paste tiles, keeping the brighter pixel where they overlap.
    for x, y, tile_pred in sorted(tiles, key=lambda item: (-item[0], -item[1])):
        region = reconstructed[y:y + tile_size, x:x + tile_size]  # view into the canvas
        h = min(region.shape[0], tile_pred.shape[0])
        w = min(region.shape[1], tile_pred.shape[1])
        if h == 0 or w == 0:
            print(f"Warning: tile at x={x}, y={y} lies outside the canvas, skipping.")
            continue
        region = region[:h, :w]
        tile_pred = tile_pred[:h, :w]

        tile_gray = cv2.cvtColor(tile_pred, cv2.COLOR_BGR2GRAY)
        region_gray = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
        brighter = tile_gray > region_gray
        # np.where builds a new array first, so writing back into the view is safe.
        region[...] = np.where(brighter[:, :, None], tile_pred, region)

        X.append(x)
        Y.append(y)
    return X, Y, reconstructed


In [None]:
def get_xgb_results(animal_name, saved_add, set_name):
    """Load the XGBoost feature rows and region labels for one image.

    Reads ``<saved_add>/<set_name>/<set_name>_features_with_labels_f64.csv``
    and selects the rows whose ``filename`` equals ``animal_name``. Those rows
    are split into features ``X`` (every column except ``filename``,
    ``region``, ``manual count`` and ``auto count``) and the ``region`` label
    ``y``.

    Returns:
        ``(X, y)``: a DataFrame of features and a Series of labels.

    Raises:
        ValueError: if no row matches ``animal_name``. Previously an empty
            feature frame was returned and the failure surfaced later, in
            ``predict``.
    """
    # os.path.join works whether or not saved_add ends with a separator.
    csv_path = os.path.join(saved_add, set_name, f"{set_name}_features_with_labels_f64.csv")
    df = pd.read_csv(csv_path)
    row = df[df['filename'] == animal_name]
    if row.empty:
        raise ValueError(f"No feature row for {animal_name!r} in {csv_path}")

    # Identifier and label/count columns are not model features.
    X = row.drop(columns=["filename", "region", "manual count", "auto count"])
    y = row["region"]
    print("y is ----", y)
    return X, y

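# Final Workflow - XGB_UNetFuse

An XGBoost classifier predicts each image's region from precomputed features. Images predicted as region 2 (peripheral) are counted from the Light-U-Net tile predictions. All other images (central or middle) are counted from the U-Net tile predictions; when too few cells are detected, the workflow falls back to a different threshold and finally to Light-U-Net.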

In [None]:
saved_add = './Dataset/'
manual_count_file = './Dataset/match_auto_and_manaull_counts_corrected_high_discrepancy.csv'

# Run the workflow on every split.
address = ['train', 'test', 'eval']
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.bmp')
PERIPHERAL_REGION = 2  # XGB class routed to Light-U-Net (original comment: "its preph")

# NOTE(review): `xgb` must be in scope here. `import xgboost as xgb` is commented
# out in the imports cell, so presumably one of the %run notebooks provides it. Confirm.
# The same model file serves every image, so load it once rather than per image.
loaded_model = xgb.XGBClassifier()
loaded_model.load_model(saved_add + "xgb_regions_detection_model.json")


def _n_cells(cell_counts):
    """Number of detections in reconstruct_image_from_tiles' third return value (list, or int 0)."""
    return cell_counts if isinstance(cell_counts, int) else len(cell_counts)


for split in address:
    dic_dir = saved_add + split + '/dic/'
    all_animals = sorted(f for f in os.listdir(dic_dir) if f.endswith(IMAGE_EXTENSIONS))

    tiles_predicted_light_unet = './results/Light_U_Net/' + split + '_outputs_light_u_net_256/'
    tiles_predicted_unet = './results/U_Net/' + split + '_outputs_u_net_256/'
    save_dir = './results/XGB_UNetFuse/' + split + '_workflow/'
    os.makedirs(save_dir, exist_ok=True)

    for animal_file in all_animals:
        animal_name = animal_file.split('.')[0]
        original_image_add = dic_dir + animal_name + '.png'
        # For 'eval' the DIC image stands in for the label image (presumably no GT masks exist there).
        label_dir = '/dic/' if split == 'eval' else '/gt/'
        label_image_add = saved_add + split + label_dir + animal_name + '.png'
        dic_image = cv2.imread(original_image_add)
        print(label_image_add)
        if dic_image is None:
            print(f"Warning: Couldn't load {original_image_add}, skipping.")
            continue

        X, y = get_xgb_results(animal_name, saved_add, split)
        y_pred = loaded_model.predict(X)
        print(animal_name, y, "---- y prediction -----", y_pred)

        if int(np.ravel(y_pred)[0]) == PERIPHERAL_REGION:
            # Peripheral: Light-U-Net with T=1, retried with T=0.75 if almost nothing is found.
            r, rec_visualization, cell_counts = reconstruct_image_from_tiles(
                split, "Light-U-Net_55_1", y_pred, manual_count_file, tiles_predicted_light_unet, save_dir,
                animal_name, label_image_add, 'image_pr_regions.jpg', dic_image, sigma=5.5,
                threshold=1, m_color=(255, 0, 0))
            if _n_cells(cell_counts) < 5:
                r, rec_visualization, cell_counts = reconstruct_image_from_tiles(
                    split, "Light-U-Net_55_075", y_pred, manual_count_file, tiles_predicted_light_unet, save_dir,
                    animal_name, label_image_add, 'image_pr_regions.jpg', dic_image, sigma=5.5,
                    threshold=0.75, m_color=(255, 0, 0))
        else:
            # Central or middle: U-Net with T=0.5, then U-Net T=0.75, then a forced Light-U-Net T=1.
            r, rec_visualization, cell_counts = reconstruct_image_from_tiles(
                split, "U-Net_55_05", y_pred, manual_count_file, tiles_predicted_unet, save_dir,
                animal_name, label_image_add, 'image_pr_regions.jpg', dic_image, sigma=5.5,
                threshold=0.5, m_color=(0, 255, 0))
            if _n_cells(cell_counts) < 50:
                r, rec_visualization, cell_counts = reconstruct_image_from_tiles(
                    split, "U-Net_55_075", y_pred, manual_count_file, tiles_predicted_unet, save_dir,
                    animal_name, label_image_add, 'image_pr_regions.jpg', dic_image, sigma=5.5,
                    threshold=0.75, m_color=(0, 0, 255))
                if _n_cells(cell_counts) < 350:
                    r, rec_visualization, cell_counts = reconstruct_image_from_tiles(
                        split, "Light-U-Net_55_1", y_pred, manual_count_file, tiles_predicted_light_unet,
                        save_dir, animal_name, label_image_add, 'image_pr_regions.jpg', dic_image,
                        sigma=5.5, threshold=1, m_color=(255, 0, 0), force_save=True)


# Reconstruction of Light_U_Net and U_Net results

You can try plotting the results of Light_U_Net and U_Net with different combinations of the detection threshold factor (T) and the Gaussian smoothing sigma.

In [None]:
saved_add = './Dataset/'
manual_count_file = './Dataset/match_auto_and_manaull_counts_corrected_high_discrepancy.csv'

# Reconstruct every image of every split with both networks (no XGB routing).
address = ['train', 'test', 'eval']

# Root of the pre-computed tile predictions. The default is the original author's
# local path; set the PAPER_FIGURES_ROOT environment variable to run elsewhere.
PAPER_FIGURES_ROOT = os.environ.get(
    "PAPER_FIGURES_ROOT",
    'C:/Users/narges/PycharmProjects3/full_size_data_training/results/Paper_figures')

for split in address:
    dic_dir = saved_add + split + '/dic/'
    all_animals = sorted(f for f in os.listdir(dic_dir)
                         if f.endswith(('.png', '.jpg', '.jpeg', '.tif', '.bmp')))

    tiles_predicted_light_unet = (PAPER_FIGURES_ROOT + '/Light_U_Net_bigger_regions/'
                                  + split + '_outputs_light_u_net_256/')
    tiles_predicted_unet = PAPER_FIGURES_ROOT + '/U_Net/' + split + '_outputs_u_net_256/'
    light_save_dir = './results/Light_U_Net/' + split + '_reconstructed_outputs_light_u_net_256_55_1/'
    unet_save_dir = './results/U_Net/' + split + '_reconstructed_outputs_u_net_256_55_05/'
    os.makedirs(light_save_dir, exist_ok=True)
    os.makedirs(unet_save_dir, exist_ok=True)

    for animal_file in all_animals:
        animal_name = animal_file.split('.')[0]
        original_image_add = dic_dir + animal_name + '.png'
        # For 'eval' the DIC image stands in for the label image (presumably no GT masks exist there).
        label_dir = '/dic/' if split == 'eval' else '/gt/'
        label_image_add = saved_add + split + label_dir + animal_name + '.png'

        dic_image = cv2.imread(original_image_add)
        if dic_image is None:
            print(f"Warning: Couldn't load {original_image_add}, skipping.")
            continue

        # Light-U-Net: sigma=5.5, T=1.
        r, rec_visualization, cell_counts = reconstruct_image_from_tiles(
            split, "Light-U-Net_55_1", "", manual_count_file, tiles_predicted_light_unet, light_save_dir,
            animal_name, label_image_add, 'image_pr_regions.jpg', dic_image, sigma=5.5,
            threshold=1, m_color=(255, 0, 0))

        # U-Net: sigma=5.5, T=0.5.
        r, rec_visualization, cell_counts = reconstruct_image_from_tiles(
            split, "U-Net_55_05", "", manual_count_file, tiles_predicted_unet, unet_save_dir,
            animal_name, label_image_add, 'image_pr_regions.jpg', dic_image, sigma=5.5,
            threshold=0.5, m_color=(0, 255, 0))
