# Imports

In [None]:
import os
import re

from scipy.ndimage import gaussian_filter, label
from skimage.feature import peak_local_max
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2
import numpy as np

%run generator.ipynb
%run model.ipynb


# import warnings
# warnings.filterwarnings('ignore')
# warnings.filterwarnings("ignore", category=UserWarning, module="imageio")
# from google.colab.patches import cv2_imshow  # Only required in Google Colab


In [None]:
# Report the TensorFlow version, the GPUs TensorFlow can see, and whether it is a CUDA build.
print(f"TensorFlow version: {tf.__version__}")

physical_devices = tf.config.list_physical_devices('GPU')
print("Available GPUs:")
for device in physical_devices:
    print(device)

built_with_cuda = tf.test.is_built_with_cuda()
print("TensorFlow is built with CUDA" if built_with_cuda else "TensorFlow is not built with CUDA")
if physical_devices:
    print("TensorFlow is using the GPU")
elif built_with_cuda:
    # CUDA build, but list_physical_devices('GPU') returned nothing, so no GPU is usable.
    print("TensorFlow is built with CUDA, but no GPU is visible")



# Functions

In [None]:
class CustomDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` of (image, centroid label, region label, filenames) batches.

    Every image file in ``data_dir`` (selected by extension) is paired with the file
    of the same name in ``label_dir`` and in ``label_dir_regions``. Each file is read
    with OpenCV, converted to single-channel grayscale and, when smaller than ``dim``,
    zero-padded around the centre. Pixel values are not rescaled: they keep the range
    produced by ``cv2.imread`` but are stored as float32.

    ``__getitem__`` returns ``(X, y, y_r, filenames)``:
      - ``X``: ``(batch, *dim, n_channels)`` input images,
      - ``y``: ``(batch, *dim, 1)`` labels from ``label_dir``,
      - ``y_r``: ``(batch, *dim, 1)`` labels from ``label_dir_regions``,
      - ``filenames``: input-image paths of the batch.
    """

    # File extensions (case-sensitive) accepted as input images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.bmp')

    def extract_main_name(self, filename):
        """Return the part of ``filename`` before its first 3-digit ``_``-separated token.

        Example: ``'sample_A1_003_x.png'`` -> ``'sample_A1'``. Returns None when no
        such token exists.
        """
        parts = filename.split('_')
        for i, part in enumerate(parts):
            if re.fullmatch(r'\d{3}', part):  # first 3-digit part is treated as the tile index
                return '_'.join(parts[:i])
        return None

    def main_name_ends_with_1(self, filename):
        """True when ``extract_main_name(filename)`` exists and ends with ``'1'``."""
        main_name = self.extract_main_name(filename)
        return main_name is not None and main_name.endswith('1')

    def __init__(self, data_dir, label_dir, label_dir_regions, batch_size, dim, n_channels=1,
                 n_classes_or_output_dim=1, shuffle=True, augmentor=None):
        """Index the images in ``data_dir`` and build matching label paths.

        Args:
            data_dir: directory containing the input images.
            label_dir: directory with same-named label images (returned as ``y``).
            label_dir_regions: directory with same-named region labels (returned as ``y_r``).
            batch_size: samples per batch (must be >= 1).
            dim: ``(height, width)`` every sample is padded to.
            n_channels: channel count of ``X``; the grayscale image is broadcast into it.
            n_classes_or_output_dim: stored on the instance; not used by this class.
            shuffle: reshuffle sample order at construction and after each epoch.
            augmentor: optional object with ``random_transform(x, seed=...)`` and
                ``standardize(x)`` (presumably a Keras ``ImageDataGenerator``).

        Raises:
            ValueError: if ``batch_size`` < 1 (``__len__`` would divide by zero).
        """
        # Keras 3 data adapters expect the base initializer to have run; harmless on tf.keras 2.
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dim = dim
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.label_dir = label_dir
        self.label_dir_regions = label_dir_regions
        self.n_channels = n_channels
        self.n_classes_or_output_dim = n_classes_or_output_dim
        self.shuffle = shuffle
        self.augmentor = augmentor
        self.target_size = self.dim[0]  # kept for backward compatibility; padding uses full ``dim``

        self.filtered_filenames = sorted(
            f for f in os.listdir(self.data_dir) if f.endswith(self._IMAGE_EXTENSIONS)
        )
        # Labels are looked up by the same file name, so all three lists are index-aligned.
        self.data_files = [os.path.join(data_dir, f) for f in self.filtered_filenames]
        self.label_files = [os.path.join(label_dir, f) for f in self.filtered_filenames]
        self.label_regions_files = [os.path.join(label_dir_regions, f) for f in self.filtered_filenames]

        self.indexes = np.arange(len(self.data_files))
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return int(np.ceil(len(self.indexes) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, y, y_r, filenames)``.

        With an ``augmentor`` the image and both labels of each sample receive the
        same random transform, and the image is additionally standardized.
        """
        start_idx = index * self.batch_size
        end_idx = min(start_idx + self.batch_size, len(self.indexes))
        batch_indexes = self.indexes[start_idx:end_idx]

        # Past-the-end index: fall back to the last batch rather than return an empty one.
        if len(batch_indexes) == 0:
            batch_indexes = self.indexes[-self.batch_size:]

        X, y, y_r, filenames_ = self.__data_generation(batch_indexes)

        if self.augmentor:
            X, y, y_r = self._augment_batch(X, y, y_r)
        return X, y, y_r, filenames_

    def on_epoch_end(self):
        """Reshuffle the sample order when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _augment_batch(self, X, y, y_r):
        """Apply one random transform per sample, identically to the image and both labels."""
        augmented_X = np.empty_like(X)
        augmented_y = np.empty_like(y)
        augmented_y_r = np.empty_like(y_r)
        for i in range(X.shape[0]):
            # random_transform's 2nd parameter is a seed, not a label. Reusing one seed
            # keeps image and labels geometrically aligned. NOTE(review): assumes the
            # augmentor is configured with geometric transforms only (intensity shifts
            # would also be applied to the labels).
            seed = np.random.randint(0, 2**31 - 1)
            img_aug = self.augmentor.random_transform(X[i].astype('float32'), seed=seed)
            augmented_X[i] = self.augmentor.standardize(img_aug)
            augmented_y[i] = self.augmentor.random_transform(y[i], seed=seed)
            augmented_y_r[i] = self.augmentor.random_transform(y_r[i], seed=seed)
        return augmented_X, augmented_y, augmented_y_r

    @staticmethod
    def _load_gray(path):
        """Read ``path`` with OpenCV and return it as a 2-D grayscale array."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    def _pad_to_dim(self, img):
        """Zero-pad ``img`` around the centre up to ``self.dim`` (extra pixel goes bottom/right)."""
        target_h, target_w = self.dim[0], self.dim[1]
        height, width = img.shape[:2]
        if (height, width) == (target_h, target_w):
            return img
        if height > target_h or width > target_w:
            raise ValueError(
                f"Image of size {(height, width)} is larger than target {tuple(self.dim)}; "
                "only padding is supported"
            )
        pad_top = (target_h - height) // 2
        pad_bottom = target_h - height - pad_top
        pad_left = (target_w - width) // 2
        pad_right = target_w - width - pad_left
        return cv2.copyMakeBorder(img, pad_top, pad_bottom, pad_left, pad_right,
                                  borderType=cv2.BORDER_CONSTANT, value=0)

    def __data_generation(self, batch_indexes):
        """Load, grayscale-convert and pad the samples at ``batch_indexes``."""
        n = len(batch_indexes)
        X = np.empty((n, *self.dim, self.n_channels), dtype=np.float32)
        y = np.empty((n, *self.dim, 1), dtype=np.float32)
        y_r = np.empty((n, *self.dim, 1), dtype=np.float32)

        filenames = []
        for i, idx in enumerate(batch_indexes):
            filenames.append(self.data_files[idx])
            img = self._load_gray(self.data_files[idx])
            lbl = self._load_gray(self.label_files[idx])
            lbl2 = self._load_gray(self.label_regions_files[idx])

            # Labels must align pixel-for-pixel with their image.
            if lbl.shape != img.shape or lbl2.shape != img.shape:
                raise ValueError(
                    f"Size mismatch for {self.data_files[idx]}: image {img.shape}, "
                    f"label {lbl.shape}, region label {lbl2.shape}"
                )

            # [..., np.newaxis] adds the channel axis; X broadcasts it to n_channels.
            X[i] = self._pad_to_dim(img)[:, :, np.newaxis]
            y[i] = self._pad_to_dim(lbl)[:, :, np.newaxis]
            y_r[i] = self._pad_to_dim(lbl2)[:, :, np.newaxis]

        return X, y, y_r, filenames

In [None]:
def analyze_batches(generator, n_batches=2):
    """Plot one random sample per batch and print per-batch image/mask statistics.

    Only the first two elements of ``generator[i]`` (images, masks) are used.
    Therefore both ``(X, y)`` pairs and ``CustomDataGenerator``'s
    ``(X, y, y_r, filenames)`` batches work.

    Args:
        generator: indexable batch source, e.g. a Keras ``Sequence``.
        n_batches: number of batches (figure rows) to inspect.
    """
    # Fetch each batch once so the plots and the printed statistics describe the same data.
    batches = [tuple(generator[i])[:2] for i in range(n_batches)]

    # squeeze=False keeps ``axes`` 2-D even when n_batches == 1.
    fig, axes = plt.subplots(n_batches, 3, figsize=(15, 5 * n_batches), squeeze=False)

    for i, (X_batch, y_batch) in enumerate(batches):
        img_idx = np.random.randint(0, X_batch.shape[0])
        sample_img = X_batch[img_idx]
        sample_mask = y_batch[img_idx]

        # imshow rejects (H, W, 1) arrays; drop a trailing singleton channel for display.
        shown_img = sample_img[..., 0] if sample_img.ndim == 3 and sample_img.shape[-1] == 1 else sample_img
        axes[i, 0].imshow(shown_img, cmap='gray')
        axes[i, 0].set_title(f'Batch {i+1} Image\nShape: {sample_img.shape}')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(sample_mask.squeeze(), cmap='gray')
        axes[i, 1].set_title(f'Batch {i+1} Mask\nMean: {sample_mask.mean():.4f}')
        axes[i, 1].axis('off')

        axes[i, 2].hist(sample_img.flatten(), bins=50, alpha=0.7, label='Image')
        axes[i, 2].hist(sample_mask.flatten(), bins=50, alpha=0.7, label='Mask')
        axes[i, 2].set_title(f'Batch {i+1} Value Distribution')
        axes[i, 2].legend()

    fig.tight_layout()
    plt.show()

    for i, (X_batch, y_batch) in enumerate(batches):
        print(f"\nBatch {i+1} Statistics:")
        print(f"Images - Min: {X_batch.min():.4f}, Max: {X_batch.max():.4f}, Mean: {X_batch.mean():.4f}, Std: {X_batch.std():.4f}")
        print(f"Masks  - Min: {y_batch.min():.4f}, Max: {y_batch.max():.4f}, Mean: {y_batch.mean():.4f}, Std: {y_batch.std():.4f}")

In [None]:

def test_batch_consistency(generator, n_test=5):
    """Plot and report how the mean image / mask value varies over the first batches.

    Only the first two elements of ``generator[i]`` (images, masks) are used, so
    ``CustomDataGenerator``'s 4-tuple batches are accepted.

    Args:
        generator: indexable batch source supporting ``len()``.
        n_test: maximum number of batches to inspect.
    """
    n_batches = min(n_test, len(generator))
    print(f"Generator has {len(generator)} batches; checking {n_batches}")
    if n_batches == 0:
        # np.std of an empty list would print nan with a RuntimeWarning.
        print("No batches to check")
        return

    means_X, means_y = [], []
    for i in range(n_batches):
        batch = generator[i]
        X, y = batch[0], batch[1]
        means_X.append(X.mean())
        means_y.append(y.mean())

    fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(10, 4))
    ax_img.plot(means_X, 'o-')
    ax_img.set_title('Image Batch Means')
    ax_img.set_xlabel('Batch Number')

    ax_mask.plot(means_y, 'o-')
    ax_mask.set_title('Mask Batch Means')
    ax_mask.set_xlabel('Batch Number')

    plt.show()

    print(f"Image mean variation: {np.std(means_X):.4f}")
    print(f"Mask mean variation: {np.std(means_y):.4f}")


In [None]:

def detect_cells(density_map, sigma=2, threshold_factor=1.5, min_distance=2):
    """Detect cell centres as local maxima of a smoothed density / probability map.

    The map is cast to float32 (a copy), singleton axes are dropped, and it is divided
    by its maximum when that maximum exceeds 1. It is then Gaussian-smoothed.
    Local maxima at least ``min_distance`` pixels apart and above
    ``mean + threshold_factor * std`` of the smoothed map are returned.

    Args:
        density_map: 2-D array-like, optionally with extra singleton axes
            (e.g. ``(H, W, 1)`` or ``(1, H, W, 1)``).
        sigma: Gaussian smoothing standard deviation, in pixels.
        threshold_factor: standard deviations above the mean a peak must reach.
        min_distance: minimum peak separation, passed to ``peak_local_max``.

    Returns:
        List of ``(x, y)`` integer tuples, i.e. (column, row).

    Raises:
        ValueError: if the input is None, empty, or not 2-D after squeezing.
    """
    if density_map is None:
        raise ValueError("Input cannot be None")

    density_map = np.array(density_map, dtype=np.float32)  # copy, so in-place ops stay local
    if density_map.ndim > 2:
        density_map = density_map.squeeze()
    if density_map.ndim != 2 or density_map.size == 0:
        raise ValueError(
            f"Expected a non-empty 2-D map (optionally with singleton axes), got shape {density_map.shape}"
        )

    # Bring maps on a larger scale (e.g. 0..255) into [0, 1].
    peak = np.max(density_map)
    if peak > 1.0:
        density_map /= peak

    smoothed = gaussian_filter(density_map, sigma=sigma)
    threshold = np.mean(smoothed) + threshold_factor * np.std(smoothed)
    coords = peak_local_max(smoothed, min_distance=min_distance, threshold_abs=threshold)

    # peak_local_max yields (row, col); report (x, y) = (col, row).
    return [(int(x), int(y)) for y, x in coords]


In [None]:
def normalize_to_uint8(image):
    """Convert a floating-point image in [0, 1] to uint8 in [0, 255].

    Any floating dtype (float16/32/64, ...) is clipped to [0, 1] and scaled by 255.
    ``astype`` then truncates toward zero (no rounding). Non-floating inputs are
    returned unchanged (the same object).
    """
    if np.issubdtype(image.dtype, np.floating):
        image = np.clip(image, 0, 1)
        image = (image * 255).astype(np.uint8)
    return image

In [None]:

def save_sample(img, label , gt,  folder_name , sample_file_name, lable_type ,saved_cell_counts, gt_cell_counts , predict_cell_counts):
    folder_name = folder_name + str(sample_file_name)+ "/"
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)
    # print('image size is :', folder_name , img.shape , label.shape , gt.shape)
    output_path = os.path.join(folder_name, 'counts.txt')
    with open(output_path, 'a') as file:
        # file.write('ground truth cell counts : '+str(gt_cell_counts) + '\n')
        # file.write('predicted cell counts : '+str(predict_cell_counts) + '\n')
        file.write(f'saved auto cell counts:{str(saved_cell_counts)}\n')
        file.write(f'Ground Truth Cells: {str(gt_cell_counts)}\n')
        file.write(f'Predicted Cells {str(lable_type)} : {str(predict_cell_counts)}\n\n\n')

    # squeezed_label =  np.squeeze(label)
    # squeezed_gt =  np.squeeze(gt)

    img = normalize_to_uint8(img)
    # squeezed_label = normalize_to_uint8(squeezed_label)
    # squeezed_gt = normalize_to_uint8(squeezed_gt)


    # # Ensure images are in the correct range
    # if img.dtype == np.float64 or img.dtype == np.float32:
    #     img = np.interp(img, (img.min(), img.max()), (0, 1))
    #
    # if squeezed_label.dtype == np.float64 or squeezed_label.dtype == np.float32:
    #     squeezed_label = np.interp(squeezed_label, (squeezed_label.min(), squeezed_label.max()), (0, 1))
    # if squeezed_gt.dtype == np.float64 or squeezed_gt.dtype == np.float32:
    #     squeezed_gt = np.interp(squeezed_gt, (squeezed_gt.min(), squeezed_gt.max()), (0, 1))

    print('image size is : ',img.shape)
    # Save images using OpenCV (preserves exact dimensions)
    cv2.imwrite(os.path.join(folder_name, 'image_1.jpg'), img)
    cv2.imwrite(os.path.join(folder_name, f'image_pr_{lable_type}.jpg'), (label * 255).astype(np.uint8))
    cv2.imwrite(os.path.join(folder_name, f'image_gt_{lable_type}.jpg'), (gt * 255).astype(np.uint8))

    #  Calculate figure size in inches (512 pixels / dpi)
    # dpi = 100  # Default is 100, but you can adjust
    # height, width = img.shape[0], img.shape[1]
    #
    # figsize = (width / dpi, height / dpi)
    #
    # # Save img
    # fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    # ax.imshow(img)
    # # ax.set_title('Image')
    # plt.axis('off')  # Turn off the axis
    # output_path = os.path.join(folder_name, 'image_1.jpg')
    # plt.savefig(output_path, bbox_inches='tight', pad_inches=0)  # Save without white spaces
    # plt.close(fig)  # Close the figure to free up memory
    # print(img.shape)
    #
    # # Save the third image
    # fig, ax = plt.subplots()
    # ax.imshow(squeezed_label)
    # # ax.set_title('Squeezed Label')
    # plt.axis('off')  # Turn off the axis
    # output_path = os.path.join(folder_name, 'image_pr_'+lable_type+'.jpg')
    # plt.savefig(output_path, bbox_inches='tight', pad_inches=0)  # Save without white spaces
    # plt.close(fig)  # Close the figure to free up memory
    #
    #  # Save the third image
    # fig, ax = plt.subplots()
    # ax.imshow(squeezed_gt)
    # # ax.set_title('Squeezed gt')
    # plt.axis('off')  # Turn off the axis
    # output_path = os.path.join(folder_name, 'image_gt_'+lable_type+'.jpg')
    # plt.savefig(output_path, bbox_inches='tight', pad_inches=0)  # Save without white spaces
    # plt.close(fig)  # Close the figure to free up memory
    # plt.show()

In [None]:
def show_save_all_sample(img , c_label , c_gt , folder_name , sample_file_name , saved_cell_counts , gt_counts  , predict_cell_counts , labletype):

    squeezed_c_label = np.squeeze(c_label)
    # Ensure images are in the correct range
    if img.dtype == np.float64 or img.dtype == np.float32:
        img = np.interp(img, (img.min(), img.max()), (0, 1))

    if squeezed_c_label.dtype == np.float64 or squeezed_c_label.dtype == np.float32:
        squeezed_c_label = np.interp(squeezed_c_label, (squeezed_c_label.min(), squeezed_c_label.max()), (0, 1))

    # figure, axis = plt.subplots(1, 3, figsize=(15, 5))
    # axis[0].imshow(img)
    # axis[0].set_title('Image')
    # axis[0].axis('off')
    #
    #
    # axis[1].imshow(c_gt)
    # axis[1].set_title('Centroids Ground Truth')
    # axis[1].axis('off')
    #
    # axis[2].imshow(squeezed_c_label)
    # axis[2].set_title('Centroids Prediction')
    # axis[2].axis('off')
    # plt.show()
    save_sample(img, squeezed_c_label ,c_gt, folder_name , sample_file_name, labletype ,saved_cell_counts,  gt_counts , predict_cell_counts  )

# Plotting Light-U-Net Predictions

In [None]:
# Light U-Net region model; the checkpoint must match this architecture and input size.
model_regions = buildModel_Light_U_net(
    input_dim=(256, 256, 1),  # 256x256 single-channel tiles, as produced by the generators
)
model_regions.load_weights(
    './checkpoints/Light_U_Net_regions_256_20250921-013403.hdf5'
)


## Building Train & Test Generators and Plotting the Train Set

In [None]:
# Train/test generators over the 256x256 tiles for the Light U-Net section.
base_address = './Dataset/'
train_img_path = base_address + 'train/dic_tiles_256/'
train_centroids_img_path = base_address + 'train/gt_tiles_256/'
# 256-px region tiles, matching the image/centroid tiles and the test split below
# (previously 'train/regions_tiles/', inconsistent with every other tile path).
train_regions_img_path = base_address + 'train/regions_tiles_256/'

test_img_path = base_address + 'test/dic_tiles_256/'
test_centroids_img_path = base_address + 'test/gt_tiles_256/'
test_regions_img_path = base_address + 'test/regions_tiles_256/'

# Settings shared by both splits.
generator_kwargs = dict(
    batch_size=4,
    dim=(256, 256),  # target size
    n_channels=1,    # grayscale
    shuffle=True,
    augmentor=None,
)
train_generator = CustomDataGenerator(
    data_dir=train_img_path,
    label_dir=train_centroids_img_path,
    label_dir_regions=train_regions_img_path,
    **generator_kwargs,
)
test_generator = CustomDataGenerator(
    data_dir=test_img_path,
    label_dir=test_centroids_img_path,
    label_dir_regions=test_regions_img_path,
    **generator_kwargs,
)
print('train batch length: ', len(train_generator), 'test batch length: ', len(test_generator))


In [None]:
folder_name = './results/Light_U_Net/train_outputs_light_u_net_256/'

# Predict region maps for every training batch, detect cells in the centroid GT and in
# the prediction, and save image / prediction / region-GT triplets with their counts.
# range(len(...)) (not max(1, ...)) so an empty dataset runs zero iterations instead of
# indexing into an empty batch.
for batch_idx in range(len(train_generator)):
    tr_X, tr_y, tr_y_r, tr_filenames = train_generator[batch_idx]
    print(tr_filenames)

    pr_tr_x_r = model_regions.predict(tr_X)

    for j in range(tr_X.shape[0]):
        sample_name = os.path.basename(tr_filenames[j])
        gt_cell_centroids = detect_cells(tr_y[j], sigma=2, threshold_factor=1.5)
        predicted_centroids = detect_cells(pr_tr_x_r[j], sigma=2.5, threshold_factor=1.5)

        show_save_all_sample(tr_X[j], pr_tr_x_r[j], tr_y_r[j], folder_name, sample_name,
                             len(gt_cell_centroids), len(gt_cell_centroids),
                             len(predicted_centroids), 'regions')
        print("sample ", sample_name, " saved")
        # break

## Plotting Eval set

In [None]:
# The eval split is presumably unannotated: the image directory is passed for both label
# directories only to satisfy the constructor. As a result y / y_r hold the grayscale
# input images themselves, not ground truth.
base_address = './Dataset/'
eval_img_path = base_address + 'eval/dic_tiles_256/'

eval_generator = CustomDataGenerator(
    data_dir=eval_img_path,
    label_dir=eval_img_path,
    label_dir_regions=eval_img_path,
    dim=(256, 256),   # target size
    n_channels=1,     # grayscale
    batch_size=4,
    augmentor=None,
    shuffle=True,
)

print('eval batch length: ', len(eval_generator))


In [None]:
folder_name = './results/Light_U_Net/eval_outputs_light_u_net_256/'

# eval_generator's label arrays are the input images, so no ground-truth count exists:
# counts.txt records 'N/A' rather than peaks detected in the raw DIC image.
for batch_idx in range(len(eval_generator)):
    batch = eval_generator[batch_idx]
    ev_X, ev_y_r, ev_filenames = batch[0], batch[2], batch[3]
    print(ev_filenames)

    pr_ev_x_r = model_regions.predict(ev_X)

    for j in range(ev_X.shape[0]):
        sample_name = os.path.basename(ev_filenames[j])
        predicted_centroids = detect_cells(pr_ev_x_r[j], sigma=2.5, threshold_factor=1.5)

        show_save_all_sample(ev_X[j], pr_ev_x_r[j], ev_y_r[j], folder_name, sample_name,
                             'N/A', 'N/A', len(predicted_centroids), 'regions')
        print("sample ", sample_name, " saved")

# Plotting U-Net Predictions

In [None]:
# Full U-Net region model; this rebinds `model_regions`, so later cells use the U-Net.
model_regions = build_unet_model(
    input_dim=(256, 256, 1),  # 256x256 single-channel tiles
)
model_regions.load_weights(
    './checkpoints/U_Net_regions_20250908-155828.hdf5'
)


## Building Train & Test Generators and Plotting the Train Set

In [None]:
# Train/test generators over the 256x256 tiles for the U-Net section.
base_address = './Dataset/'
train_img_path, train_centroids_img_path, train_regions_img_path = (
    base_address + 'train/' + sub for sub in ('dic_tiles_256/', 'gt_tiles_256/', 'regions_tiles_256/')
)
test_img_path, test_centroids_img_path, test_regions_img_path = (
    base_address + 'test/' + sub for sub in ('dic_tiles_256/', 'gt_tiles_256/', 'regions_tiles_256/')
)

train_generator = CustomDataGenerator(
    data_dir=train_img_path,
    label_dir=train_centroids_img_path,
    label_dir_regions=train_regions_img_path,
    dim=(256, 256),   # target size
    n_channels=1,     # grayscale
    batch_size=4,
    augmentor=None,
    shuffle=True,
)
test_generator = CustomDataGenerator(
    data_dir=test_img_path,
    label_dir=test_centroids_img_path,
    label_dir_regions=test_regions_img_path,
    dim=(256, 256),   # target size
    n_channels=1,     # grayscale
    batch_size=4,
    augmentor=None,
    shuffle=True,
)
print('train batch length: ', len(train_generator), 'test batch length: ', len(test_generator))



In [None]:

folder_name = './results/U_Net/train_outputs_u_net_256/'

# Predict region maps for every training batch and save image / prediction / region-GT
# triplets with centroid-GT and predicted cell counts.
for batch_idx in range(len(train_generator)):
    tr_X, tr_y, tr_y_r, tr_filenames = train_generator[batch_idx]
    print('train_gen data shape: ', tr_X.shape, tr_y.shape, 'label img min, mean, max : ',
          np.min(tr_y[0]), np.mean(tr_y[0]), np.max(tr_y[0]))

    pr_tr_x_r = model_regions.predict(tr_X)

    for j in range(tr_X.shape[0]):
        sample_name = os.path.basename(tr_filenames[j])
        gt_cell_centroids = detect_cells(tr_y[j], sigma=2, threshold_factor=1.5)
        predicted_centroids = detect_cells(pr_tr_x_r[j], sigma=2, threshold_factor=1)

        show_save_all_sample(tr_X[j], pr_tr_x_r[j], tr_y_r[j], folder_name, sample_name,
                             len(gt_cell_centroids), len(gt_cell_centroids),
                             len(predicted_centroids), 'regions')
        print("sample ", sample_name, " saved")

## Plotting Eval set


In [None]:
folder_name = './results/U_Net/eval_outputs_u_net_256/'

# Uses `eval_generator`, whose label arrays are the input images. No ground-truth count
# exists, so counts.txt records 'N/A' instead of peaks detected in the raw DIC image.
for batch_idx in range(len(eval_generator)):
    batch = eval_generator[batch_idx]
    ev_X, ev_y_r, ev_filenames = batch[0], batch[2], batch[3]
    print('eval_gen data shape: ', ev_X.shape)

    pr_ev_x_r = model_regions.predict(ev_X)

    for j in range(ev_X.shape[0]):
        sample_name = os.path.basename(ev_filenames[j])
        predicted_centroids = detect_cells(pr_ev_x_r[j], sigma=2, threshold_factor=1)

        show_save_all_sample(ev_X[j], pr_ev_x_r[j], ev_y_r[j], folder_name, sample_name,
                             'N/A', 'N/A', len(predicted_centroids), 'regions')
        print("sample ", sample_name, " saved")