# Imports and settings

In [1]:
from honeybee_comb_inferer.inference import HoneyBeeCombInferer
import sysconfig
import os
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import gc
import torch
import cv2
from tqdm import tqdm

import math
import tempfile
from joblib import Parallel, delayed
from scipy.stats import mode

2025-03-19 19:41:57.593399: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-19 19:41:57.647220: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-19 19:41:57.670261: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-19 19:41:57.676425: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-19 19:41:57.734188: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Candidate model names (presumably the architectures with pretrained weights — confirm
# against the honeybee_comb_inferer package). Only `model_name` is actually used below.
possible_models = [
    "unet_effnetb0",
    "deeplabv3_resnet18",
    "unet_resnet18",
    "unet_resnet34",
]
model_name = "unet_effnetb0"

# Pretrained weights are looked up in "<sysconfig 'data' path>/models".
data_path = sysconfig.get_path("data")
path_to_pretrained_models = os.path.join(data_path, "models")

# Inference device string handed to the inferer.
device = "cuda"

model = HoneyBeeCombInferer(
    model_name=model_name,
    path_to_pretrained_models=path_to_pretrained_models,
    device=device,
)

  state_dict = torch.load(path_to_state_dict)


# Functions

In [None]:
def create_masked_images(folder, model, bee_classes=(1, 8)):
    """
    Black out bee pixels in every image of `folder` and save the result to `folder/masked`.

    Each image is read as grayscale, `model.infer(img_path, return_logits=False)` is called
    to obtain a per-pixel class map, and every pixel whose class is in `bee_classes` is set
    to 0 in the image.

    Masked frames are always written as PNG (same base name, '.png' extension). A lossy
    format such as JPEG would smear the masked pixels to small non-zero values, and the
    background computation treats only exact zeros as "no data". If `folder/masked` still
    holds JPEG files from an earlier run, remove them so frames are not counted twice.

    Parameters:
    - folder: Directory containing the source images (png / jpg / jpeg)
    - model: Object exposing `infer(path, return_logits=False)`; assumed to return an
      (H, W) integer class map matching the image size — confirm against the inferer
    - bee_classes: Class ids to black out; the default (1, 8) is bee and bee-in-cell

    Returns:
    - Path of the 'masked' subfolder
    """
    masked_dir = os.path.join(folder, "masked")
    os.makedirs(masked_dir, exist_ok=True)

    # The glob is not recursive, so files already in 'masked/' are not picked up again.
    image_files = sorted(glob(os.path.join(folder, "*.[pj][np][ge]*")))

    for img_path in tqdm(image_files, desc=f"Masking bees in {folder}"):
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print(f"Warning: could not read {img_path}")
            continue

        pred_mask = model.infer(img_path, return_logits=False)

        # Zero out every pixel whose predicted class is one of the bee classes.
        bee_pixels = np.isin(pred_mask, bee_classes)
        img[bee_pixels] = 0

        # Lossless output keeps the masked pixels at exactly 0.
        stem = os.path.splitext(os.path.basename(img_path))[0]
        out_path = os.path.join(masked_dir, stem + ".png")
        # cv2.imwrite reports failure through its return value instead of raising.
        if not cv2.imwrite(out_path, img):
            print(f"Warning: could not write {out_path}")

    print(f"Saved masked images to: {masked_dir}")
    return masked_dir



def create_background_masked_median_mode(folder, window_size=10, tile_size=(512,512), sampling_rate=5, use_median=True):
    """
    Estimate a background image from the masked frames in `folder/masked`.

    Pixels equal to 0 in a masked frame are treated as "no data". The estimate is built
    in two stages:
    1) A rolling median over `window_size` consecutive (sub-sampled) frames, ignoring
       zeros, stored in a temporary on-disk memmap.
    2) For each pixel (processed tile by tile), the non-zero values of all rolling-median
       frames are reduced to one value via median or mode. Pixels without any data stay 0.

    The result is saved as `folder/background_masked_ignoreblack.png` and displayed with
    matplotlib. The temporary memmap file is unique per call and removed afterwards.

    Parameters:
    - folder: Directory that contains the 'masked/' subfolder
    - window_size: Number of frames in each rolling-median window (must be >= 1)
    - tile_size: (height, width) of the tiles used in stage 2
    - sampling_rate: Keep every `sampling_rate`-th masked frame
    - use_median: True for the median, False for the mode (ties go to the smallest value)

    Raises:
    - ValueError: if `window_size` < 1, if fewer than `window_size` frames remain after
      sub-sampling, or if the first frame cannot be read
    """
    if window_size < 1:
        raise ValueError("window_size must be at least 1.")

    masked_folder = os.path.join(folder, "masked")
    file_list = sorted(glob(os.path.join(masked_folder, "*.[pj][np][ge]*")))
    # Filter out any existing background or unneeded files
    file_list = [f for f in file_list if "background" not in os.path.basename(f).lower()]
    file_list = file_list[::sampling_rate]
    num_files = len(file_list)
    print("Number of masked images:", num_files)

    if num_files < window_size:
        raise ValueError("Not enough images to apply rolling median; adjust window_size or sampling_rate.")

    def read_image(filepath):
        return cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)

    # Read first image to get shape
    first_img = read_image(file_list[0])
    if first_img is None:
        raise ValueError("Cannot read the first masked image.")
    H, W = first_img.shape
    print(f"Image shape: {H} x {W}")

    num_medians = num_files - window_size + 1
    print("Number of median images to produce:", num_medians)

    # Unique scratch file per call: a fixed name in the temp dir would be clobbered by a
    # concurrent run, and the file (num_medians * H * W bytes) would otherwise be left behind.
    fd, memmap_file = tempfile.mkstemp(prefix="median_images_", suffix=".dat")
    os.close(fd)
    median_memmap = None
    try:
        median_memmap = np.memmap(memmap_file, dtype='uint8', mode='w+', shape=(num_medians, H, W))

        # Pre-fill the window with the first window_size - 1 readable frames.
        window_imgs = []
        for f in file_list[:window_size - 1]:
            img = read_image(f)
            if img is not None:
                window_imgs.append(img)

        median_index = 0
        for f in tqdm(file_list[window_size - 1:], desc="Computing rolling medians (masked)"):
            img = read_image(f)
            if img is None:
                # Unreadable frames are skipped; the memmap rows left unused as a result
                # stay all-zero and are ignored as "no data" in stage 2.
                continue
            window_imgs.append(img)
            if len(window_imgs) == window_size:
                # Stack images: shape = (window_size, H, W)
                stack_ = np.stack(window_imgs, axis=0)
                # Mask out black pixels so they do not take part in the median.
                masked_stack = np.ma.masked_equal(stack_, 0)
                # Pixels that are masked in every frame of the window are filled with 0.
                median_img = np.ma.median(masked_stack, axis=0).filled(0).astype(np.uint8)
                median_memmap[median_index, :, :] = median_img
                median_index += 1
                window_imgs.pop(0)

        median_memmap.flush()
        print("Rolling median images computed.")

        # Re-open read-only for the per-pixel reduction.
        median_memmap = np.memmap(memmap_file, dtype='uint8', mode='r', shape=(num_medians, H, W))
        background = np.zeros((H, W), dtype=np.uint8)

        n_tiles_y = math.ceil(H / tile_size[0])
        n_tiles_x = math.ceil(W / tile_size[1])
        print(f"Processing background in {n_tiles_y} x {n_tiles_x} tiles, ignoring black=0 pixels...")

        def process_tile(i, j):
            i_end = min(i + tile_size[0], H)
            j_end = min(j + tile_size[1], W)
            # Extract tile of shape (num_medians, tile_h, tile_w)
            tile_stack = median_memmap[:, i:i_end, j:j_end]
            N, th, tw = tile_stack.shape

            # Flatten each (th, tw) patch across N frames => shape (N, th*tw)
            tile_flat = tile_stack.reshape(N, -1)
            out_tile = np.zeros((th*tw,), dtype=np.uint8)

            for k in range(th*tw):
                pixel_values = tile_flat[:, k]
                nonzero = pixel_values[pixel_values != 0]
                if len(nonzero) == 0:
                    # No valid data => keep it black
                    continue
                if use_median:
                    out_tile[k] = np.median(nonzero).astype(np.uint8)
                else:
                    # Mode of uint8 values = most populated histogram bin; argmax returns
                    # the smallest value on ties, as scipy.stats.mode does, without
                    # depending on the SciPy version's return shape.
                    out_tile[k] = np.bincount(nonzero).argmax()

            return i, i_end, j, j_end, out_tile.reshape(th, tw)

        # Parallel tile processing
        results = Parallel(n_jobs=8)(
            delayed(process_tile)(i, j)
            for i in range(0, H, tile_size[0])
            for j in range(0, W, tile_size[1])
        )

        for (i, i_end, j, j_end, tile_result) in results:
            background[i:i_end, j:j_end] = tile_result
    finally:
        # Drop the mapping before deleting the scratch file.
        median_memmap = None
        try:
            os.remove(memmap_file)
        except OSError:
            # Best effort: the file may still be mapped elsewhere (e.g. by a worker on Windows).
            print(f"Warning: could not remove temporary file {memmap_file}")

    # Display and save
    out_path = os.path.join(folder, "background_masked_ignoreblack.png")
    # cv2.imwrite reports failure through its return value instead of raising.
    if cv2.imwrite(out_path, background):
        print("Masked background (ignoring black) saved to:", out_path)
    else:
        print(f"Warning: could not write {out_path}")

    plt.figure(figsize=(24, 16))
    plt.imshow(background, cmap='gray')
    if use_median:
        plt.title("Background from masked frames (Median, ignoring black)")
    else:
        plt.title("Background from masked frames (Mode, ignoring black)")
    plt.axis("off")
    plt.show()

def apply_clahe_gray(img, clipLimit=2.0, tileGridSize=(8, 8)):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to a grayscale image.

    Floating-point input of any precision is interpreted as intensities in [0, 1]; it is
    clipped to that range, scaled to [0, 255] and rounded to the nearest integer before
    conversion to uint8. Rounding (rather than truncating) matters because an 8-bit image
    loaded as float may come back as e.g. 254.99998 after scaling, which truncation would
    turn into 254. Non-float input is passed to CLAHE unchanged.

    Parameters:
    - img: Input grayscale image (float [0-1] or uint8 [0-255])
    - clipLimit: Contrast limit for CLAHE
    - tileGridSize: Size of grid for CLAHE

    Returns:
    - CLAHE-enhanced grayscale image
    """
    if np.issubdtype(img.dtype, np.floating):
        img = np.rint(np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)

    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    return clahe.apply(img)

# Run 

In [None]:
base_folder = "/home/beesbook/mnt/scratch/beesbook2024/extracted_frames/"
# "20240720/cam-0" is not part of this run (presumably already processed — confirm).
datecam_folders = [
    "20240725/cam-1",
    "20240621/cam-1",
    "20240720/cam-2",
    "20240621/cam-3",
]

# Per date/camera folder: mask the bees first, then build the background from the masked frames.
for datecam_folder in datecam_folders:
    folder = os.path.join(base_folder, datecam_folder)
    create_masked_images(folder, model)
    create_background_masked_median_mode(
        folder,
        window_size=10,
        tile_size=(512, 512),
        sampling_rate=1,
        use_median=True,  # or use_median=False for mode
    )

Masking bees in /home/beesbook/mnt/scratch/beesbook2024/extracted_frames/20240725/cam-1:   2%| | 16/1056 [01:33<1:40:16,  5.79s/i

In [None]:
# apply clahe to results, and copy to local directory to look at
import shutil
base_folder = "/home/beesbook/mnt/scratch/beesbook2024/extracted_frames/"
datecam_folders = ["20240720/cam-0", "20240725/cam-1","20240621/cam-1","20240720/cam-2","20240621/cam-3"]
for datecam_folder in datecam_folders:
    folder = os.path.join(base_folder,datecam_folder)
    filename = os.path.join(folder,"background_masked_ignoreblack.png")
    img = plt.imread(filename)
    limg = apply_clahe_gray(img)
    outfile = 'combseg-output/'+'background_masked_ignoreblack_'+datecam_folder.replace('/','_')+'.png'
    cv2.imwrite(outfile,limg)
    print('wrote to',outfile)

wrote to combseg-output/background_masked_ignoreblack_20240720_cam-0.png
wrote to combseg-output/background_masked_ignoreblack_20240725_cam-1.png
wrote to combseg-output/background_masked_ignoreblack_20240621_cam-1.png
wrote to combseg-output/background_masked_ignoreblack_20240720_cam-2.png
wrote to combseg-output/background_masked_ignoreblack_20240621_cam-3.png
