# imports

In [2]:
import pandas as pd
import os
import csv
import re
from scipy.ndimage import gaussian_filter
from skimage.feature import peak_local_max
import numpy as np
import cv2
import math
%run generator.ipynb
%run model.ipynb
# import warnings
# warnings.filterwarnings('ignore')
# from google.colab.patches import cv2_imshow  # Only required in Google Colab
# import warnings
# warnings.filterwarnings("ignore", category=UserWarning, module="imageio")


# Definitions

In [3]:
def draw_cell_detections(image, centroids, radius=10, color=(255, 0, 0), thickness=2):
    """Return a copy of ``image`` with one circle drawn at every centroid.

    Args:
        image: Array to annotate; it is copied, never modified in place.
        centroids: Iterable of ``(x, y)`` pairs; coordinates are truncated with ``int``.
        radius: Circle radius handed to ``cv2.circle``.
        color: Circle colour handed to ``cv2.circle``.
        thickness: Circle line thickness handed to ``cv2.circle``.

    Returns:
        The annotated copy (converted with ``COLOR_GRAY2BGR`` first when the
        input is two-dimensional).
    """
    canvas = image.copy()
    if canvas.ndim == 2:
        # A 2-D image has no colour channels; expand it so coloured circles can be drawn.
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)

    for cx, cy in centroids:
        center = (int(cx), int(cy))
        cv2.circle(canvas, center, radius, color, thickness)
    return canvas


def draw_cell_detections_filterd(image, filtred_centroids, non_filtred_centroids, radius=10, color=(255, 0, 0), thickness=2):
    """Return a copy of ``image`` with two groups of centroids circled in different colours.

    The non-filtered centroids are drawn first using ``color``; the filtered
    ones are drawn afterwards with the fixed colour ``(0, 0, 255)``, so they end
    up on top wherever circles overlap.

    Args:
        image: Array to annotate; it is copied, never modified in place.
        filtred_centroids: Iterable of ``(x, y)`` pairs drawn with ``(0, 0, 255)``.
        non_filtred_centroids: Iterable of ``(x, y)`` pairs drawn with ``color``.
        radius: Circle radius handed to ``cv2.circle``.
        color: Colour used for the non-filtered centroids.
        thickness: Circle line thickness handed to ``cv2.circle``.

    Returns:
        The annotated copy (converted with ``COLOR_GRAY2BGR`` first when the
        input is two-dimensional).
    """
    canvas = image.copy()
    if canvas.ndim == 2:
        # A 2-D image has no colour channels; expand it so coloured circles can be drawn.
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)

    groups = (
        (non_filtred_centroids, color),
        (filtred_centroids, (0, 0, 255)),
    )
    for points, group_color in groups:
        for cx, cy in points:
            cv2.circle(canvas, (int(cx), int(cy)), radius, group_color, thickness)
    return canvas

def draw_grids(draw_grid, X, Y, tile_size, save_folder, gt_visualization, rec_visualization, animal_name, enhanced=False):
    """Optionally overlay the tile grid on both visualizations, then save them as PNG files.

    Both images are modified in place when ``draw_grid`` is truthy.

    Args:
        draw_grid: When truthy, draw one rectangle and an ``"x,y"`` label per tile.
        X: Tile x offsets; ``X[i]`` pairs with ``Y[i]``.
        Y: Tile y offsets.
        tile_size: Side length of the square drawn for each tile.
        save_folder: Directory the PNG files are written to.
        gt_visualization: Image saved as ``<animal_name>_1_gt_detections.png``.
        rec_visualization: Image saved, after cropping 3 pixels from every side,
            as ``<animal_name>_3_rec_detections.png`` when ``enhanced`` is truthy
            and ``<animal_name>_2_rec_detections.png`` otherwise.
        animal_name: Prefix of the output file names.
        enhanced: Selects which reconstruction file name is used.
    """
    if draw_grid:
        for idx, x in enumerate(X):
            y = Y[idx]
            top_left = (x, y)
            bottom_right = (x + tile_size, y + tile_size)
            label = f"{x},{y}"
            label_origin = (x + 5, y + 20)
            # Same rectangle and label on both images, reconstruction first.
            for canvas in (rec_visualization, gt_visualization):
                cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), 1)
                cv2.putText(canvas, label, label_origin,
                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 200, 255), 1, cv2.LINE_AA)

    # Save visualizations
    gt_path = os.path.join(save_folder, f"{animal_name}_1_gt_detections.png")
    cv2.imwrite(gt_path, gt_visualization)

    rec_cropped = rec_visualization[3:-3, 3:-3]
    rec_name = f"{animal_name}_3_rec_detections.png" if enhanced else f"{animal_name}_2_rec_detections.png"
    cv2.imwrite(os.path.join(save_folder, rec_name), rec_cropped)

    print(f"Visualizations with detected cells saved to {save_folder}")


def search_in_csv(animal_name, csv_file):
    """Look up the rows of a CSV file whose first column equals ``animal_name``.

    The first column is converted to ``str`` and stripped of surrounding
    whitespace before the comparison; ``animal_name`` is compared as given.

    Args:
        animal_name: Value to search for.
        csv_file: Path of the CSV file, read with ``pandas.read_csv``.

    Returns:
        A DataFrame with the matching rows, or ``None`` when nothing matches.
    """
    table = pd.read_csv(csv_file)
    first_column = table.iloc[:, 0].astype(str).str.strip()
    hits = table[first_column == animal_name]
    return None if hits.empty else hits


def save_results_csv(save_folder, manual_count_file, animal_name, gt_cell_centroids, rec_cell_centroids, rec_cell_centroids_e, is_blur):
    """Append one row of cell-count statistics for ``animal_name`` to ``cell_counts.csv``.

    The header is written only when the CSV file does not exist yet. The same
    counts are also printed.

    Args:
        save_folder: Directory containing (or receiving) ``cell_counts.csv``.
        manual_count_file: CSV passed to ``search_in_csv``; columns 1 and 2 of
            the first matching row supply the manual count and the
            "method 1" automatic count.
        animal_name: Identifier written to the row and used for the lookup.
        gt_cell_centroids: Detected ground-truth centroids (only the length is used).
        rec_cell_centroids: Detected reconstruction centroids (only the length is used).
        rec_cell_centroids_e: Centroids left after overlap removal (only the length is used).
        is_blur: Value stored in the ``is_blur`` column.

    Notes:
        When the animal is missing from ``manual_count_file`` the two manual
        columns are left empty instead of raising ``AttributeError``. When no
        ground-truth cell was detected the error rates are written as NaN
        instead of raising ``ZeroDivisionError``.
    """
    csv_path = os.path.join(save_folder, "cell_counts.csv")
    file_exists = os.path.isfile(csv_path)

    row_manual_count = search_in_csv(animal_name, manual_count_file)
    if row_manual_count is None:
        print(f"Warning: no manual count found for {animal_name} in {manual_count_file}")
        manual_gt = auto_gt_method_1 = ''
    else:
        manual_gt = row_manual_count.iloc[0, 1]
        auto_gt_method_1 = row_manual_count.iloc[0, 2]

    n_gt = len(gt_cell_centroids)
    n_rec = len(rec_cell_centroids)
    n_rec_e = len(rec_cell_centroids_e)

    def percent_of_gt(value):
        # Rates are relative to the detected GT count; undefined when that count is zero.
        return value / n_gt * 100 if n_gt else float('nan')

    print("Manual GT cells: ", manual_gt, "\nAuto Ground truth cells _ method 1:", auto_gt_method_1, "\nAuto Ground truth cells _ method 2 :", n_gt, "\nReconstructed cells:", n_rec, "\nremoved overlapped cells:", n_rec_e)

    fieldnames = ['Animal Name', 'is_blur', 'Manual GT Cells', 'Auto method GT Cells method 1 ', 'Auto method GT Cells method 2 ', 'Reconstructed Cells', 'Difference (gt- pred)', 'Abs Difference', 'Error rate', 'Abs Error rate', 'removed overlapped cells', 'Difference (gt- pred_e)', 'Abs Difference enh', 'Error rate enh', 'Abs Error rate enh']

    with open(csv_path, mode='a', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        writer.writerow({
            'Animal Name': animal_name,
            'is_blur': is_blur,
            'Manual GT Cells': manual_gt,
            'Auto method GT Cells method 1 ': auto_gt_method_1,
            'Auto method GT Cells method 2 ': n_gt,
            'Reconstructed Cells': n_rec,
            'Difference (gt- pred)': n_gt - n_rec,
            'Abs Difference': math.fabs(n_gt - n_rec),
            'Error rate': percent_of_gt(n_gt - n_rec),
            'Abs Error rate': percent_of_gt(math.fabs(n_gt - n_rec)),

            'removed overlapped cells': n_rec_e,
            'Difference (gt- pred_e)': n_gt - n_rec_e,
            'Abs Difference enh': math.fabs(n_gt - n_rec_e),
            'Error rate enh': percent_of_gt(n_gt - n_rec_e),
            'Abs Error rate enh': percent_of_gt(math.fabs(n_gt - n_rec_e)),
        })


def extract_main_name(filename):
    """Return the part of ``filename`` that precedes its first ``'.'``.

    The whole string is returned when it contains no dot.
    """
    main_name, _, _ = filename.partition('.')
    return main_name

def main_name_ends_with_1(filename):
    """Tell whether the main name of ``filename`` (as given by ``extract_main_name``) ends with ``'1'``."""
    stem = extract_main_name(filename)
    if stem is None:
        return False
    return stem.endswith('1')


In [4]:
def reconstruct(tile_folder, file_names_saved, animal_name, final_size, tile_size):
    """Stitch the predicted tiles of one animal onto a single square BGR canvas.

    Every entry of ``tile_folder`` whose name starts with
    ``<animal_name>_<digits>_x<X>_y<Y>.png`` is treated as a directory that
    holds the tile image ``file_names_saved``; ``X`` and ``Y`` are the tile's
    top-left corner on the canvas. Where tiles overlap (or cover pixels that
    are already set), the pixel whose grey value is larger wins.

    Args:
        tile_folder: Directory that is scanned for tile entries.
        file_names_saved: File name of the image inside each tile entry.
        animal_name: Name prefix of the tiles to use; matched literally.
        final_size: Side length of the square output canvas.
        tile_size: Side length of the canvas region assigned to each tile.

    Returns:
        ``(X, Y, reconstructed)`` where ``X``/``Y`` list the offsets of the
        tiles that were pasted (sorted by descending x, then descending y)
        and ``reconstructed`` is the ``uint8`` canvas of shape
        ``(final_size, final_size, 3)``.
    """
    X = []
    Y = []
    reconstructed = np.zeros((final_size, final_size, 3), dtype=np.uint8)

    # Raw string keeps the regex escapes intact; re.escape makes sure characters
    # such as '.' or '+' in the animal name are matched literally.
    pattern = re.compile(rf"{re.escape(animal_name)}_\d+_x(\d+)_y(\d+)\.png")

    # First pass: load all tiles together with their positions.
    tiles = []
    for filename in os.listdir(tile_folder):
        match = pattern.match(filename)
        if not match:
            continue
        x = int(match.group(1))
        y = int(match.group(2))
        tile_path = os.path.join(tile_folder, filename, file_names_saved)
        tile_img = cv2.imread(tile_path)
        if tile_img is None:
            print(f"Warning: Couldn't load {tile_path}, skipping.")
            continue
        tiles.append((x, y, tile_img))

    # Second pass: paste the tiles, keeping the brighter pixel wherever they overlap.
    for x, y, tile_pred in sorted(tiles, key=lambda item: (-item[0], -item[1])):
        region = reconstructed[y:y + tile_size, x:x + tile_size]

        # A tile may overhang the canvas or differ from tile_size; restrict both
        # sides to their common area so the comparison below has matching shapes.
        height = min(region.shape[0], tile_pred.shape[0])
        width = min(region.shape[1], tile_pred.shape[1])
        if height == 0 or width == 0:
            print(f"Warning: tile at x={x}, y={y} lies outside the canvas, skipping.")
            continue
        region = region[:height, :width]
        tile = tile_pred[:height, :width]

        tile_gray = cv2.cvtColor(tile, cv2.COLOR_BGR2GRAY)
        region_gray = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
        brighter = tile_gray > region_gray

        # `region` is a view into `reconstructed`, so this updates the canvas in place.
        region[brighter] = tile[brighter]

        X.append(x)
        Y.append(y)
    return X, Y, reconstructed

In [5]:
def local_blur_map(image, window_size=32, step=16):
    """Build a per-pixel blur heatmap from the local variance of the Laplacian.

    A square window slides over the grey-scale image; each window position
    writes its Laplacian variance to all pixels it covers, so where windows
    overlap the value of the window visited last is kept. Pixels that no full
    window reaches keep the initial value 0. The map is then min-max
    normalised and inverted, so a low variance gives a high heatmap value.

    Args:
        image: Image converted with ``cv2.COLOR_BGR2GRAY``.
        window_size: Side length of the sliding window in pixels.
        step: Stride of the sliding window in pixels.

    Returns:
        ``float32`` array with the image's height and width and values in
        ``[0, 1]``. When the variance map is constant (for example the image
        is smaller than one window) the result is all ones instead of the
        NaNs that a 0/0 normalisation would produce.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    h, w = gray.shape
    blur_map = np.zeros((h, w), dtype=np.float32)

    for y in range(0, h - window_size + 1, step):
        for x in range(0, w - window_size + 1, step):
            patch = gray[y:y + window_size, x:x + window_size]
            blur_map[y:y + window_size, x:x + window_size] = cv2.Laplacian(patch, cv2.CV_64F).var()

    lowest = blur_map.min()
    highest = blur_map.max()
    if highest == lowest:
        # Nothing to tell regions apart; avoid dividing by zero.
        return np.ones_like(blur_map)

    # Normalize, then invert so that low variance maps to a high blur value.
    norm_blur_map = (blur_map - lowest) / (highest - lowest)
    return 1 - norm_blur_map


In [6]:
def save_overlayed_heatmap(image, heatmap, alpha=0.5, save_path='overlayed.png'):
    """
    Blend a JET-coloured version of ``heatmap`` over ``image`` and write the result to disk.

    Args:
        image: Image the heatmap is blended onto.
        heatmap: Array that is min-max normalised to 0..255, cast to ``uint8``
            and colour-mapped; it is resized when its height/width differ
            from the image's.
        alpha: Weight of the heatmap in the blend; the image gets ``1 - alpha``.
        save_path: File path the blended image is written to.
    """
    # Scale to the full 8-bit range before colour-mapping.
    heatmap_u8 = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    colored = cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET)

    # cv2.resize expects (width, height), i.e. the reverse of the shape order.
    target = image.shape[:2]
    if colored.shape[:2] != target:
        colored = cv2.resize(colored, (target[1], target[0]))

    blended = cv2.addWeighted(image, 1 - alpha, colored, alpha, 0)

    cv2.imwrite(save_path, blended)
    print(f"Overlayed heatmap saved to {save_path}")




# Check the blurriness

In [7]:

def detect_blur(image):
    """Return the variance of the Laplacian of the grey-scale version of ``image``."""
    laplacian = cv2.Laplacian(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), cv2.CV_64F)
    return laplacian.var()

def detect_brightness_issues(image):
    """Return the mean intensity of the grey-scale version of ``image``."""
    return np.mean(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))

def estimate_cell_density(image):
    """Return the number of external contours found after Otsu-thresholding ``image``."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    otsu_flags = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    binary = cv2.threshold(gray, 0, 255, otsu_flags)[1]
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return len(contours)

def analyze_image_quality(image_path, blur_threshold=80, dark_threshold=50, bright_threshold=200):
    """Compute simple quality metrics for the image stored at ``image_path``.

    Args:
        image_path: Path of the image, read with ``cv2.imread``.
        blur_threshold: The image counts as blurred when its blur score
            (see ``detect_blur``) is below this value.
        dark_threshold: The image counts as too dark when its brightness score
            (see ``detect_brightness_issues``) is below this value.
        bright_threshold: The image counts as too bright when its brightness
            score is above this value.

    Returns:
        ``(quality_report, is_blurred)`` where ``quality_report`` is a dict
        with the keys ``blur_score``, ``brightness_score``,
        ``density_estimate``, ``blurred``, ``too_dark`` and ``too_bright``,
        and ``is_blurred`` repeats the ``blurred`` entry.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None; fail here with a clear message.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    blur_score = detect_blur(image)
    brightness_score = detect_brightness_issues(image)
    density_score = estimate_cell_density(image)

    is_blurred = blur_score < blur_threshold

    quality_report = {
        "blur_score": blur_score,
        "brightness_score": brightness_score,
        "density_estimate": density_score,
        "blurred": is_blurred,
        "too_dark": brightness_score < dark_threshold,
        "too_bright": brightness_score > bright_threshold,
    }

    return quality_report, is_blurred


In [8]:
def merge_overlapping_cells_keep_others(centroids, radius=15):
    """Greedily merge centroids that lie closer than ``2 * radius`` to a seed centroid.

    Centroids are visited in input order. Each centroid that has not been
    consumed yet becomes a seed and absorbs every later, unconsumed centroid
    whose distance to the seed is below ``2 * radius``. A seed without
    neighbours is kept unchanged; otherwise the group is replaced by its mean.

    Args:
        centroids: Sequence of ``(x, y)`` pairs.
        radius: Cell radius; two cells overlap when their centres are closer
            than twice this value.

    Returns:
        List of ``(x, y)`` tuples of floats, one per group, in seed order.
    """
    points = np.array(centroids, dtype=float)
    total = len(points)
    consumed = np.zeros(total, dtype=bool)
    overlap_distance = 2 * radius
    merged_centroids = []

    for seed in range(total):
        if consumed[seed]:
            continue
        consumed[seed] = True
        members = [points[seed]]

        # Only the distance to the seed matters, not to the other group members.
        for other in range(seed + 1, total):
            if consumed[other]:
                continue
            if np.linalg.norm(points[seed] - points[other]) < overlap_distance:
                members.append(points[other])
                consumed[other] = True

        representative = members[0] if len(members) == 1 else np.mean(members, axis=0)
        merged_centroids.append(tuple(representative))

    return merged_centroids


In [9]:
def merge_overlapping_cells(centroids, radius=10):
    """Replace each group of overlapping centroids by the group's mean position.

    Centroids are processed in input order: every centroid that is still free
    starts a group and pulls in all later free centroids whose distance to it
    is below ``2 * radius``.

    Args:
        centroids: Sequence of ``(x, y)`` pairs.
        radius: Cell radius; ``2 * radius`` is the overlap distance.

    Returns:
        List of ``(x, y)`` tuples of floats, one per group.
    """
    pts = np.array(centroids, dtype=float)
    count = len(pts)
    taken = np.zeros(count, dtype=bool)
    limit = 2 * radius
    result = []

    for i in range(count):
        if not taken[i]:
            taken[i] = True
            group = [pts[i]]
            for j in range(i + 1, count):
                # Distance is always measured from the centroid that started the group.
                if not taken[j] and np.linalg.norm(pts[i] - pts[j]) < limit:
                    group.append(pts[j])
                    taken[j] = True
            result.append(tuple(np.mean(group, axis=0)))

    return result


# Counting approaches

In [10]:
def filter_centroids_red_and_blur(image, centroids, blur_map, blur_threshold=100):
    """Split centroids into those lying on a red and blurred pixel and all others.

    A pixel counts as red when its HSV value falls into one of the two red
    hue bands defined below, and as blurred when ``blur_map`` exceeds
    ``blur_threshold`` at that pixel.

    Args:
        image: Image converted with ``cv2.COLOR_BGR2HSV``.
        centroids: Iterable of ``(x, y)`` pairs. Non-integer coordinates are
            truncated with ``int`` for the pixel lookup only; the original
            pair is what gets returned.
        blur_map: Array indexed as ``blur_map[row, col]``.
            NOTE(review): presumably on the same pixel grid as ``image`` and
            on the same scale as ``blur_threshold`` - confirm against the caller.
        blur_threshold: Blur values above this mark the pixel as blurred.

    Returns:
        ``(filtered, non_filtered)``: the centroids on red *and* blurred
        pixels, and every other centroid (including out-of-bounds ones).
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # Red wraps around the hue axis, hence two bands.
    lower_red1 = np.array([0, 70, 50])
    upper_red1 = np.array([10, 255, 255])
    lower_red2 = np.array([170, 70, 50])
    upper_red2 = np.array([180, 255, 255])

    red_mask = cv2.bitwise_or(
        cv2.inRange(hsv, lower_red1, upper_red1),
        cv2.inRange(hsv, lower_red2, upper_red2),
    )
    red_mask_bool = red_mask > 0

    height, width = image.shape[:2]
    filtered = []
    non_filtered = []

    for (x, y) in centroids:
        on_red_blur = False
        if 0 <= y < height and 0 <= x < width:
            # Array indices must be integers; merged centroids are floats.
            row, col = int(y), int(x)
            on_red_blur = bool(red_mask_bool[row, col]) and blur_map[row, col] > blur_threshold

        if on_red_blur:
            filtered.append((x, y))
        else:
            non_filtered.append((x, y))

    return filtered, non_filtered


In [11]:
def detect_cells_a1(density_map, sigma=2, threshold_factor=1.5):
    """Detect cell centres as local maxima of a Gaussian-smoothed density map.

    Args:
        density_map: Array passed to ``gaussian_filter``.
        sigma: Standard deviation of the Gaussian smoothing.
        threshold_factor: A peak must exceed ``mean + threshold_factor * std``
            of the smoothed map.

    Returns:
        List of ``(x, y)`` integer tuples (column first, then row).
    """
    smoothed = gaussian_filter(density_map, sigma=sigma)
    cutoff = np.mean(smoothed) + threshold_factor * np.std(smoothed)

    # peak_local_max yields (row, col) pairs; swap them to (x, y).
    peaks = peak_local_max(smoothed, min_distance=2, threshold_abs=cutoff)
    return [(int(col), int(row)) for row, col in peaks]



In [12]:
def detect_cells_a2_1(density_map, sigma=2, threshold_factor=1.5, max_peaks=1000):
    """Detect cell centres as local maxima of a smoothed map, rejecting saturated results.

    A three-dimensional input is averaged over its last axis first. The map is
    cast to ``float32``, smoothed with a Gaussian filter, and local maxima
    above ``mean + threshold_factor * std`` of the smoothed map are collected,
    at most ``max_peaks`` of them.

    Args:
        density_map: Array-like map or image.
        sigma: Standard deviation of the Gaussian smoothing.
        threshold_factor: Number of standard deviations above the mean a peak
            must exceed.
        max_peaks: Upper bound on the number of peaks requested from
            ``peak_local_max``.

    Returns:
        List of ``(x, y)`` integer tuples, or an empty list when the number of
        peaks reaches ``max_peaks``: hitting the cap is treated as a failed
        detection and a warning is printed.
    """
    density_map = np.asarray(density_map)

    if density_map.ndim == 3:
        # Collapse the channel axis so peaks are searched on a single plane.
        density_map = density_map.mean(axis=2)

    density_map = density_map.astype(np.float32)

    smoothed = gaussian_filter(density_map, sigma=sigma)
    threshold = np.mean(smoothed) + threshold_factor * np.std(smoothed)

    coords = peak_local_max(smoothed,
                            min_distance=2,
                            threshold_abs=threshold,
                            exclude_border=False,
                            num_peaks=max_peaks)

    if len(coords) >= max_peaks:
        print(f"Warning: peak limit of {max_peaks} reached, returning no detections.")
        return []

    # peak_local_max yields (row, col) pairs; swap them to (x, y).
    return [(int(x), int(y)) for y, x in coords]

In [13]:

def print_color_info_at_centroids(image, centroids):
    """Split centroids by whether the pixel underneath them is red or yellow in HSV space.

    Args:
        image: Image converted with ``cv2.COLOR_BGR2HSV``.
        centroids: Iterable of ``(x, y)`` pairs used directly as array indices.

    Returns:
        ``(filtered_centroids, non_filtered_centroids)``: centroids on a red or
        yellow pixel, and centroids on any other pixel. Centroids outside the
        image are reported with a message and appear in neither list.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # (lower, upper) HSV bounds; red wraps around the hue axis, hence two bands.
    red_band_low = (np.array([0, 70, 50]), np.array([10, 255, 255]))
    red_band_high = (np.array([170, 70, 50]), np.array([180, 255, 255]))
    yellow_band = (np.array([20, 70, 50]), np.array([30, 255, 255]))

    def within(pixel, band):
        lower, upper = band
        return (lower <= pixel).all() and (pixel <= upper).all()

    rows, cols = image.shape[0], image.shape[1]
    filtered_centroids = []
    non_filtered_centroids = []

    for i, (x, y) in enumerate(centroids):
        inside = 0 <= y < rows and 0 <= x < cols
        if not inside:
            print(f"Centroid {i} at (x={x}, y={y}): Outside image bounds")
            continue

        hsv_pixel = hsv[y, x]
        is_red = within(hsv_pixel, red_band_low) or within(hsv_pixel, red_band_high)
        is_yellow = within(hsv_pixel, yellow_band)

        target = filtered_centroids if (is_red or is_yellow) else non_filtered_centroids
        target.append((x, y))

    return filtered_centroids, non_filtered_centroids

# Combine the codes part1

In [49]:
def reconstruct_image_from_tiles(is_blur, manual_count_file, tile_folder, enhanced_tiles_predicted, save_folder, animal_name, label_image_add, file_names_saved, output_path=None, draw_grid=True, enhanced=False):
    """Rebuild one animal's prediction from tiles, count cells, and save images and statistics.

    Steps: load the label image, stitch the tiles with ``reconstruct``, build
    a colour-mapped blur heatmap, detect cells on the reconstruction and on
    the label image, merge overlapping detections, write the visualizations
    with ``draw_grids`` and append a CSV row with ``save_results_csv``.

    Args:
        is_blur: When truthy, detections are additionally split by the heatmap
            colour underneath them and an overlay image
            ``<animal_name>_5_blur_heatmap.jpg`` is written to ``save_folder``.
        manual_count_file: CSV forwarded to ``save_results_csv``.
        tile_folder: Folder with the predicted tiles, forwarded to ``reconstruct``.
        enhanced_tiles_predicted: Unused; kept so existing calls stay valid.
        save_folder: Output directory for images and the CSV.
        animal_name: Identifier of the animal / image.
        label_image_add: Path of the ground-truth label image.
        file_names_saved: Image file name inside each tile entry.
        output_path: Unused; kept so existing calls stay valid.
        draw_grid: Forwarded to ``draw_grids``.
        enhanced: Forwarded to the first ``draw_grids`` call.

    Returns:
        The reconstructed image.

    Raises:
        FileNotFoundError: If the label image cannot be read.
    """
    final_size = 1406
    tile_size = 256

    label_image = cv2.imread(label_image_add)
    if label_image is None:
        # cv2.imread signals failure by returning None; fail here with a clear message.
        raise FileNotFoundError(f"Could not read label image: {label_image_add}")
    label_image = cv2.cvtColor(label_image, cv2.COLOR_BGR2GRAY)

    X, Y, reconstructed = reconstruct(tile_folder, file_names_saved, animal_name, final_size, tile_size)

    # Convert the heatmap to 8-bit and colour-map it so colours can be inspected per centroid.
    heatmap = local_blur_map(reconstructed)
    heatmap_uint8 = (heatmap * 255).astype('uint8')
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)

    rec_cell_centroids = detect_cells_a2_1(reconstructed, sigma=2.5, threshold_factor=2)
    gt_cell_centroids = detect_cells_a1(label_image, sigma=2.5, threshold_factor=1.5)
    gt_visualization = draw_cell_detections(label_image, gt_cell_centroids)
    rec_cell_centroids = merge_overlapping_cells_keep_others(rec_cell_centroids)
    print('before changing the threshold factor', len(rec_cell_centroids))

    if is_blur:
        # NOTE(review): same detector settings as above, so this repeats the earlier
        # detection; it is kept because the colour split below needs the un-merged
        # integer centroids.
        rec_cell_centroids = detect_cells_a2_1(reconstructed, sigma=2.5, threshold_factor=2)
        print('-------------This image is blurred-----------------')
        filtered_centroids, non_filtered_centroids = print_color_info_at_centroids(heatmap_color, rec_cell_centroids)
        print("Predicted cell counts:  ", len(rec_cell_centroids))
        print("FILTRED OR NON FILTERD\n", len(filtered_centroids), len(non_filtered_centroids))
        rec_cell_centroids = merge_overlapping_cells_keep_others(rec_cell_centroids)

        rec_visualization = draw_cell_detections_filterd(reconstructed, filtered_centroids, non_filtered_centroids)
        # os.path.join keeps the path valid whether or not save_folder ends with a separator.
        overlay_path = os.path.join(save_folder, animal_name + '_5_blur_heatmap.jpg')
        save_overlayed_heatmap(rec_visualization, heatmap_color, alpha=0.5, save_path=overlay_path)

    rec_visualization = draw_cell_detections(reconstructed, rec_cell_centroids)

    print("-----------------------final cell counts------------------------")
    print("Gt cell counts: ", len(gt_cell_centroids))
    print("Predicted cell counts:  ", len(rec_cell_centroids))

    # Draw the grid (if requested) and save the visualizations.
    draw_grids(draw_grid, X, Y, tile_size, save_folder, gt_visualization, rec_visualization, animal_name, enhanced)
    # NOTE(review): this second call looks like a leftover; with enhanced=False it
    # rewrites the same files. Kept to preserve the existing outputs - confirm
    # whether it can be dropped.
    draw_grids(draw_grid, X, Y, tile_size, save_folder, gt_visualization, rec_visualization, animal_name, enhanced=False)

    # The enhanced-tile pass is not run here, so an empty list stands in for
    # its detections in the CSV row.
    save_results_csv(save_folder, manual_count_file, animal_name, gt_cell_centroids, rec_cell_centroids, [], is_blur)
    return reconstructed


In [50]:

# Dataset root and the CSV with the reference cell counts.
saved_add = 'C:/Users/narges/PycharmProjects3/pythonProject3/Compelet_dataset_preprocessing/datasets/'
manual_count_file = 'C:/Users/narges/PycharmProjects3/full_size_data_training/results/match_auto_and_manaull_counts_corrected_high_discrepancy.csv'

# Dataset splits to process; only 'test' is listed here.
address = ['test']
for i, split_name in enumerate(address):
    dic_folder = saved_add + split_name + '/dic/'

    # Every image file of the split, in name order.
    all_animals = sorted(
        f for f in os.listdir(dic_folder)
        if f.endswith(('.png', '.jpg', '.jpeg', '.tif', '.bmp'))
    )

    tiles_predicted = './results/Light_U_Net/train_256/' + split_name + '_outputs_light_u_net_256/'
    enhanced_tiles_predicted = './results/Light_U_Net/train_256/' + split_name + '_outputs_light_u_net_enhanced_256_1/'

    save_dir = './results/Light_U_Net/train_256_1/' + split_name + '_outputs_Light_u_net_reconstructed_centroids_app_1/'
    os.makedirs(save_dir, exist_ok=True)

    for image_file in all_animals:
        animal_name = image_file.split('.')[0]
        original_image_add = saved_add + split_name + '/dic/' + animal_name + '.png'
        label_image_add = saved_add + split_name + '/gt/' + animal_name + '.png'
        print(label_image_add)

        # The blur flag of the original image selects the extra heatmap analysis.
        quality_metrics, is_blur = analyze_image_quality(original_image_add)
        print(quality_metrics)

        reconstruct_image_from_tiles(
            is_blur,
            manual_count_file,
            tiles_predicted,
            enhanced_tiles_predicted,
            save_dir,
            animal_name,
            label_image_add,
            'image_pr_centroids1.jpg',
            os.path.join(save_dir, f"{animal_name}.png"),
            enhanced=False,
        )

        # break


C:/Users/narges/PycharmProjects3/pythonProject3/Compelet_dataset_preprocessing/datasets/test/gt/10ld2.png
{'blur_score': 41.35636766846728, 'brightness_score': 179.32247346938775, 'density_estimate': 2467, 'blurred': True, 'too_dark': False, 'too_bright': False}
before changing the threshold factor 366
-------------This image is blurred-----------------
Predicted cell counts:   385
FILTRED OR NON FILTERD
 144 241
Overlayed heatmap saved to ./results/Light_U_Net/train_256_1/test_outputs_Light_u_net_reconstructed_centroids_app_1/10ld2_5_blur_heatmap.jpg
-----------------------final cell counts------------------------
Gt cell counts:  348
Predicted cell counts:   366
Visualizations with detected cells saved to ./results/Light_U_Net/train_256_1/test_outputs_Light_u_net_reconstructed_centroids_app_1/
Visualizations with detected cells saved to ./results/Light_U_Net/train_256_1/test_outputs_Light_u_net_reconstructed_centroids_app_1/
Manual GT cells:  350 
Auto Ground truth cells _ method 1: 

# combine codes part2

In [16]:
def reconstruct_image_from_tiles(manual_count_file, tile_folder, save_folder, animal_name, label_image_add, file_names_saved, draw_grid=True):
    """Rebuild one animal's prediction from tiles, count cells, and save images and statistics.

    Args:
        manual_count_file: CSV forwarded to ``save_results_csv``.
        tile_folder: Folder with the predicted tiles, forwarded to ``reconstruct``.
        save_folder: Output directory for the visualizations and the CSV.
        animal_name: Identifier of the animal / image.
        label_image_add: Path of the ground-truth label image.
        file_names_saved: Image file name inside each tile entry.
        draw_grid: Forwarded to ``draw_grids``.

    Returns:
        The reconstructed image.
    """
    final_size, tile_size = 1406, 256

    label_image = cv2.cvtColor(cv2.imread(label_image_add), cv2.COLOR_BGR2GRAY)
    X, Y, reconstructed = reconstruct(tile_folder, file_names_saved, animal_name, final_size, tile_size)

    # Detect on the reconstruction and on the label image.
    rec_cell_centroids = detect_cells_a2_1(reconstructed, sigma=5.5, threshold_factor=2)
    gt_cell_centroids = detect_cells_a1(label_image, sigma=2.5, threshold_factor=1.5)

    gt_visualization = draw_cell_detections(label_image, gt_cell_centroids)
    merged_centroids = merge_overlapping_cells_keep_others(rec_cell_centroids, radius=15)
    # Only the merged detections are drawn; the raw ones are still reported below.
    rec_visualization = draw_cell_detections(reconstructed, merged_centroids)

    print("Gt cell counts: ", len(gt_cell_centroids))
    print("Predicted cell counts:  ", len(rec_cell_centroids))
    print("Predicted cell counts after removing overlaps:  ", len(merged_centroids))

    # Save the visualizations, then append the counts to the CSV.
    draw_grids(draw_grid, X, Y, tile_size, save_folder, gt_visualization, rec_visualization, animal_name, enhanced=False)
    save_results_csv(save_folder, manual_count_file, animal_name, gt_cell_centroids, rec_cell_centroids, merged_centroids, '_')
    return reconstructed


In [17]:

# Dataset root and the CSV with the reference cell counts.
saved_add = 'C:/Users/narges/PycharmProjects3/pythonProject3/Compelet_dataset_preprocessing/datasets/'
manual_count_file = 'C:/Users/narges/PycharmProjects3/full_size_data_training/results/match_auto_and_manaull_counts_corrected_high_discrepancy.csv'

# Dataset splits to process.
address = ['train', 'test']
for i, split_name in enumerate(address):
    dic_folder = saved_add + split_name + '/dic/'

    # Every image file of the split, in name order.
    all_animals = sorted(
        f for f in os.listdir(dic_folder)
        if f.endswith(('.png', '.jpg', '.jpeg', '.tif', '.bmp'))
    )

    tiles_predicted = './results/Paper_figures/Light_U_Net_bigger_regions/' + split_name + '_outputs_light_u_net_256/'

    save_dir = './results/Paper_figures/Light_U_Net_bigger_regions_thresh_2/' + split_name + '_reconstructed_regions_sigma_5_and_half/'
    os.makedirs(save_dir, exist_ok=True)

    for image_file in all_animals:
        animal_name = image_file.split('.')[0]
        original_image_add = saved_add + split_name + '/dic/' + animal_name + '.png'
        label_image_add = saved_add + split_name + '/gt/' + animal_name + '.png'
        print(label_image_add)

        # Region predictions are used here ('image_pr_regions.jpg').
        reconstruct_image_from_tiles(
            manual_count_file,
            tiles_predicted,
            save_dir,
            animal_name,
            label_image_add,
            'image_pr_regions.jpg',
        )


        # break


C:/Users/narges/PycharmProjects3/pythonProject3/Compelet_dataset_preprocessing/datasets/train/gt/10ld1.png
Gt cell counts:  379
Predicted cell counts:   0
Predicted cell counts after removing overlaps:   0
Visualizations with detected cells saved to ./results/Paper_figures/Light_U_Net_bigger_regions_thresh_2/train_reconstructed_regions_sigma_5_and_half/
Manual GT cells:  381 
Auto Ground truth cells _ method 1: 379 
Auto Ground truth cells _ method 2 : 379 
Reconstructed cells: 0 
removed overlapped cells: 0
C:/Users/narges/PycharmProjects3/pythonProject3/Compelet_dataset_preprocessing/datasets/train/gt/10ld3.png
Gt cell counts:  238
Predicted cell counts:   188
Predicted cell counts after removing overlaps:   181
Visualizations with detected cells saved to ./results/Paper_figures/Light_U_Net_bigger_regions_thresh_2/train_reconstructed_regions_sigma_5_and_half/
Manual GT cells:  239 
Auto Ground truth cells _ method 1: 238 
Auto Ground truth cells _ method 2 : 238 
Reconstructed cells:

# watershed algorithms

In [12]:
import numpy as np
import matplotlib.pyplot as plt
from skimage import io, filters, morphology, measure, feature, color
from skimage.segmentation import watershed
from scipy import ndimage as ndi
import pandas as pd
import os
from skimage.restoration import denoise_bilateral, denoise_tv_chambolle
from skimage.measure import regionprops
from skimage import img_as_ubyte, color
import cv2

In [13]:
def make_centroid_img(centroids_loc, dimension):
    """Render centroids as filled white dots of radius 3 on a black three-channel image.

    Args:
        centroids_loc: Iterable of ``(x, y)`` pairs; coordinates are truncated with ``int``.
        dimension: Sequence whose first two entries are the image height and width.

    Returns:
        ``uint8`` array of shape ``(dimension[0], dimension[1], 3)``.
    """
    height, width = dimension[0], dimension[1]
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    white = (255, 255, 255)
    for cx, cy in centroids_loc:
        # thickness=-1 asks OpenCV for a filled circle.
        cv2.circle(canvas, (int(cx), int(cy)), radius=3, color=white, thickness=-1)
    return canvas


In [14]:

def save_results_csv_all(save_folder, manual_count_file, animal_name, gt_cell_centroids, rec_cell_centroids, rec_cell_centroids_e, watershed_cell_counts, is_blur):
    """Append one row of cell-count statistics, including watershed counts, to ``cell_counts_all.csv``.

    The header is written only when the CSV file does not exist yet. The same
    counts are also printed.

    Args:
        save_folder: Directory containing (or receiving) ``cell_counts_all.csv``.
        manual_count_file: CSV passed to ``search_in_csv``; columns 1 and 2 of
            the first matching row supply the manual count and the
            "method 1" automatic count.
        animal_name: Identifier written to the row and used for the lookup.
        gt_cell_centroids: Detected ground-truth centroids (only the length is used).
        rec_cell_centroids: Detected reconstruction centroids (only the length is used).
        rec_cell_centroids_e: Centroids left after overlap removal (only the length is used).
        watershed_cell_counts: Number of cells reported by the watershed method.
        is_blur: Value stored in the ``is_blur`` column.

    Notes:
        When the animal is missing from ``manual_count_file`` the two manual
        columns are left empty instead of raising ``AttributeError``. When no
        ground-truth cell was detected the error rates are written as NaN
        instead of raising ``ZeroDivisionError``.
    """
    csv_path = os.path.join(save_folder, "cell_counts_all.csv")
    file_exists = os.path.isfile(csv_path)

    row_manual_count = search_in_csv(animal_name, manual_count_file)
    if row_manual_count is None:
        print(f"Warning: no manual count found for {animal_name} in {manual_count_file}")
        manual_gt = auto_gt_method_1 = ''
    else:
        manual_gt = row_manual_count.iloc[0, 1]
        auto_gt_method_1 = row_manual_count.iloc[0, 2]

    n_gt = len(gt_cell_centroids)
    n_rec = len(rec_cell_centroids)
    n_rec_e = len(rec_cell_centroids_e)

    def percent_of_gt(value):
        # Rates are relative to the detected GT count; undefined when that count is zero.
        return value / n_gt * 100 if n_gt else float('nan')

    print("Manual GT cells: ", manual_gt, "\nAuto Ground truth cells _ method 1:", auto_gt_method_1, "\nAuto Ground truth cells _ method 2 :", n_gt, "\nReconstructed cells:", n_rec, "\nremoved overlapped cells:", n_rec_e, "\nwatershedd cells:", watershed_cell_counts)

    fieldnames = ['Animal Name', 'is_blur', 'Manual GT Cells', 'Auto method GT Cells method 1 ', 'Auto method GT Cells method 2 ', 'Reconstructed Cells', 'Difference (gt- pred)', 'Abs Difference', 'Error rate', 'Abs Error rate', 'removed overlapped cells', 'Difference (gt- pred_e)', 'Abs Difference enh', 'Error rate enh', 'Abs Error rate enh', 'watershed cell counts', 'Diffrence (gt-water)', 'Error rate watershed', 'Abs Error rate watershed']

    with open(csv_path, mode='a', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        writer.writerow({
            'Animal Name': animal_name,
            'is_blur': is_blur,
            'Manual GT Cells': manual_gt,
            'Auto method GT Cells method 1 ': auto_gt_method_1,
            'Auto method GT Cells method 2 ': n_gt,
            'Reconstructed Cells': n_rec,
            'Difference (gt- pred)': n_gt - n_rec,
            'Abs Difference': math.fabs(n_gt - n_rec),
            'Error rate': percent_of_gt(n_gt - n_rec),
            'Abs Error rate': percent_of_gt(math.fabs(n_gt - n_rec)),

            'removed overlapped cells': n_rec_e,
            'Difference (gt- pred_e)': n_gt - n_rec_e,
            'Abs Difference enh': math.fabs(n_gt - n_rec_e),
            'Error rate enh': percent_of_gt(n_gt - n_rec_e),
            'Abs Error rate enh': percent_of_gt(math.fabs(n_gt - n_rec_e)),

            'watershed cell counts': watershed_cell_counts,
            'Diffrence (gt-water)': n_gt - watershed_cell_counts,
            'Error rate watershed': percent_of_gt(n_gt - watershed_cell_counts),
            'Abs Error rate watershed': percent_of_gt(math.fabs(n_gt - watershed_cell_counts)),
        })



In [15]:
def reconstruct_image_from_tiles_return_cell_locations(manual_count_file , tile_folder , save_folder, animal_name, label_image_add,file_names_saved,  draw_grid=True ):
    """Rebuild one animal's prediction from its tiles and detect cells in it.

    The label image is loaded and converted to grayscale, the prediction is
    reassembled with ``reconstruct``, cells are detected in both images, and
    overlapping predicted detections are merged. The three counts are printed
    and the visualizations are handed to ``draw_grids``.

    Args:
        manual_count_file: Not used by this function (the CSV export that
            consumed it is disabled); kept so existing callers keep working.
        tile_folder: Folder passed to ``reconstruct`` as the tile source.
        save_folder: Folder passed to ``draw_grids``.
        animal_name: Identifier passed to ``reconstruct`` and ``draw_grids``.
        label_image_add: Path of the ground-truth label image.
        file_names_saved: Passed through to ``reconstruct``.
            # presumably the per-tile file name to load - verify against reconstruct
        draw_grid: Forwarded to ``draw_grids`` as its first argument.

    Returns:
        Tuple ``(reconstructed, gt_cell_centroids, rec_cell_centroids,
        rec_cell_centroids_after_merging_overlaps, rec_visualization)``.

    Raises:
        FileNotFoundError: If the label image cannot be read.
    """
    final_size = 1406
    tile_size = 256

    # cv2.imread does not raise on a bad path; it returns None, which would
    # otherwise surface as an opaque cvtColor assertion failure.
    label_image = cv2.imread(label_image_add)
    if label_image is None:
        raise FileNotFoundError(
            f"Could not read label image (missing or unreadable): {label_image_add}"
        )
    label_image = cv2.cvtColor(label_image, cv2.COLOR_BGR2GRAY)

    X, Y, reconstructed = reconstruct(tile_folder, file_names_saved, animal_name, final_size, tile_size)

    rec_cell_centroids = detect_cells_a2_1(reconstructed, sigma=3, threshold_factor=1)
    gt_cell_centroids = detect_cells_a1(label_image, sigma=2.5, threshold_factor=1.5)

    gt_visualization = draw_cell_detections(label_image, gt_cell_centroids)
    # The prediction overlay shows the detections that remain after merging.
    rec_cell_centroids_aftre_removing_overlaps = merge_overlapping_cells_keep_others(rec_cell_centroids, radius=15)
    rec_visualization = draw_cell_detections(reconstructed, rec_cell_centroids_aftre_removing_overlaps)

    print("Gt cell counts: ", len(gt_cell_centroids))
    print("Predicted cell counts:  ", len(rec_cell_centroids))
    print("Predicted cell counts after removing overlaps:  ", len(rec_cell_centroids_aftre_removing_overlaps))

    # X and Y are handed to draw_grids together with both visualizations.
    draw_grids(draw_grid, X, Y, tile_size, save_folder, gt_visualization, rec_visualization, animal_name, enhanced=False)
    return reconstructed, gt_cell_centroids, rec_cell_centroids, rec_cell_centroids_aftre_removing_overlaps, rec_visualization


In [23]:
def my_watershed(dic, reconstructed_img, centroids , which_out  ,  save_address):
    """Segment cells with a marker-based watershed seeded by ``centroids``.

    Args:
        dic: Colour image, converted to grayscale here; it is the image that
            gets segmented unless ``which_out == 1``.
        reconstructed_img: Colour reconstructed prediction; its shape is used
            to build the centroid (marker) image, and its grayscale version is
            the image that gets segmented when ``which_out == 1``.
        centroids: Cell centres passed to ``make_centroid_img``.
        which_out: ``1`` selects the reconstructed image; any other value
            selects ``dic``.
        save_address: Not used by this function.

    Returns:
        ``(number_of_regions, overlay)`` where ``overlay`` is the label
        colouring blended on the segmented image, converted with
        ``img_as_ubyte``.
    """
    rec_gray = cv2.cvtColor(reconstructed_img, cv2.COLOR_RGB2GRAY)
    seed_img = cv2.cvtColor(
        make_centroid_img(centroids, reconstructed_img.shape),
        cv2.COLOR_BGR2GRAY,
    )
    dic_gray = cv2.cvtColor(dic, cv2.COLOR_BGR2GRAY)

    if which_out == 1:
        image = rec_gray
    else:
        image = dic_gray
        # Trim 3 px from every border of the marker image.
        # presumably this aligns the markers with the 1400x1400 DIC image - verify
        seed_img = seed_img[3:-3, 3:-3]
        if image.shape != (1400, 1400):
            image = pad_to_size(image, target_shape=(1400, 1400))
            print(image.shape)  # should print (1400, 1400)
    print(image.shape, seed_img.shape)

    # sigma is tiny, so this call is kept mainly for the value conversion it
    # performs before Otsu thresholding. NOTE(review): confirm that is the intent.
    smoothed = filters.gaussian(image, sigma=0.000000001)
    foreground = smoothed > filters.threshold_otsu(smoothed)

    # Distance to the background drives the flooding; markers are the
    # connected components of the centroid image.
    distance = ndi.distance_transform_edt(foreground)
    seeds = measure.label(seed_img)

    segmented = morphology.remove_small_objects(
        watershed(-distance, seeds, mask=foreground),
        min_size=20,
    )
    cell_count = len(regionprops(segmented))

    print("watershed cell counts:", cell_count)
    overlay = img_as_ubyte(color.label2rgb(segmented, image=image, bg_label=0))
    return cell_count, overlay


In [28]:
import numpy as np

def pad_to_size(img, target_shape=(1400,1400)):
    """
    Zero-pad an image so that its first two axes match ``target_shape``.

    The missing rows/columns are split between both sides; when the amount is
    odd, the extra row/column goes to the bottom/right. Axes after the first
    two (for example a channel axis) are left unpadded.

    Args:
        img: NumPy array with at least two axes.
        target_shape: ``(height, width)`` to reach.

    Returns:
        A new array whose first two dimensions equal ``target_shape``.

    Raises:
        ValueError: If ``img`` is already larger than ``target_shape`` along
            either of the first two axes (padding cannot shrink an image).
    """
    pad_height = target_shape[0] - img.shape[0]
    pad_width = target_shape[1] - img.shape[1]
    if pad_height < 0 or pad_width < 0:
        raise ValueError(
            f"image of shape {img.shape} is larger than target shape "
            f"{tuple(target_shape)}; cannot pad"
        )

    pad_top = pad_height // 2
    pad_left = pad_width // 2
    pad_spec = [
        (pad_top, pad_height - pad_top),
        (pad_left, pad_width - pad_left),
    ]
    # Leave any trailing axes (e.g. colour channels) untouched.
    pad_spec += [(0, 0)] * (img.ndim - 2)

    return np.pad(img, pad_spec, mode='constant', constant_values=0)



In [29]:
def process_image_tv(img, output_dir="outputs"):
    """Denoise ``img`` with total-variation filtering, then high-boost sharpen it.

    Args:
        img: Image array handed directly to ``denoise_tv_chambolle``.
        output_dir: Directory that is created if missing. Nothing is written
            into it by this function (the image dumps are disabled).

    Returns:
        The sharpened image: the denoised image plus 8.0 times the difference
        between it and a 5x5 Gaussian-blurred copy.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Denoise, then rescale to 8-bit.
    # assumes the denoiser returns floats in [0, 1] - TODO confirm
    # NOTE(review): no channel axis is specified; confirm how colour inputs
    # should be handled by the denoiser.
    denoised = (denoise_tv_chambolle(img, weight=0.1) * 255).astype(np.uint8)

    # High-boost sharpening: amplify the detail removed by a Gaussian blur
    # and add it back onto the denoised image.
    boost = 8.0
    detail = cv2.subtract(denoised, cv2.GaussianBlur(denoised, (5, 5), 0))
    return cv2.add(denoised, cv2.multiply(detail, boost))



In [33]:

# Driver: for every image of the train and test splits, rebuild the prediction
# from its tiles, detect cells, run the marker-based watershed on the sharpened
# DIC image, save the segmentation overlay and append the counts to the CSV.
saved_add = 'C:/Users/narges/PycharmProjects3/pythonProject3/Compelet_dataset_preprocessing/datasets/'
manual_count_file= 'C:/Users/narges/PycharmProjects3/full_size_data_training/results/match_auto_and_manaull_counts_corrected_high_discrepancy.csv'


# for both train and test
address = ['train'  , 'test']
for i, split_name in enumerate(address):

    # Image files of this split, in a deterministic order.
    all_animals = sorted([
            f for f in os.listdir(saved_add + split_name + '/dic/')
            if f.endswith(('.png', '.jpg', '.jpeg', '.tif', '.bmp'))
        ])

    # Folder holding the predicted tiles for this split.
    tiles_predicted ='./results/Paper_figures/Light_U_Net/'+split_name +'_outputs_light_u_net_256/'

    save_dir = './results/Paper_figures/Light_U_Net/'+split_name +'_segmented_reconstructed_regions/'
    os.makedirs(save_dir, exist_ok=True)
    for animal_name in all_animals:
        animal_name = animal_name.split('.')[0]
        original_image_add = saved_add + split_name + '/dic/' + animal_name + '.png'
        label_image_add = saved_add + split_name + '/gt/' + animal_name + '.png'
        print(label_image_add)

        reconstructed , gt_cell_centroids , rec_cell_centroids ,rec_cell_centroids_aftre_removing_overlaps ,  rec_visualization = reconstruct_image_from_tiles_return_cell_locations( manual_count_file , tiles_predicted  , save_dir, animal_name , label_image_add,'image_pr_regions.jpg', )

        dic= cv2.imread(original_image_add)
        sharped_dic= process_image_tv(dic, output_dir="./denoisong/")
        seg_save= save_dir + animal_name + '_segment.png'
        print(seg_save)
        # len() rather than truthiness: the centroids may be a NumPy array.
        if len(rec_cell_centroids_aftre_removing_overlaps) >0:
            # Keyword arguments keep which_out/save_address in the right
            # slots; previously the path landed in which_out and 0 in
            # save_address. which_out=0 selects the DIC image, exactly the
            # branch the old call ended up in.
            labels , label_rgb_uint8 = my_watershed(
                sharped_dic,
                reconstructed,
                rec_cell_centroids_aftre_removing_overlaps,
                which_out=0,
                save_address=seg_save,
            )
            cv2.imwrite( seg_save ,label_rgb_uint8)
        else:
            # No markers to seed the watershed with: record zero cells.
            labels = 0
        save_results_csv_all(save_dir , manual_count_file ,   animal_name , gt_cell_centroids ,  rec_cell_centroids , rec_cell_centroids_aftre_removing_overlaps ,labels,  '_' )


C:/Users/narges/PycharmProjects3/pythonProject3/Compelet_dataset_preprocessing/datasets/train/gt/10ld1.png
Gt cell counts:  379
Predicted cell counts:   450
Predicted cell counts after removing overlaps:   395
Visualizations with detected cells saved to ./results/Paper_figures/Light_U_Net/train_segmented_reconstructed_regions/
./results/Paper_figures/Light_U_Net/train_segmented_reconstructed_regions/10ld1_segment.png
(1400, 1400) (1400, 1400)
watershed cell counts: 347
Manual GT cells:  381 
Auto Ground truth cells _ method 1: 379 
Auto Ground truth cells _ method 2 : 379 
Reconstructed cells: 450 
removed overlapped cells: 395 
watershedd cells: 347
C:/Users/narges/PycharmProjects3/pythonProject3/Compelet_dataset_preprocessing/datasets/train/gt/10ld3.png
Gt cell counts:  238
Predicted cell counts:   314
Predicted cell counts after removing overlaps:   281
Visualizations with detected cells saved to ./results/Paper_figures/Light_U_Net/train_segmented_reconstructed_regions/
./results/Pa