### Installation and Requirements

Please refer to the [_featureforest repo_](https://github.com/juglab/featureforest) for installation instructions and requirements.

In [None]:
import pickle
from pathlib import Path

import h5py
import numpy as np
import torch
from PIL import Image, ImageSequence
from sklearn.ensemble import RandomForestClassifier

from tqdm.notebook import trange, tqdm

from featureforest.models import get_available_models, get_model
from featureforest.models.SAM import SAMAdapter
from featureforest.utils.data import (
    patchify,
    is_image_rgb, get_stride_margin,
    get_num_patches, get_stride_margin
)
from featureforest.postprocess import (
    postprocess,
    postprocess_with_sam, postprocess_with_sam_auto,
    get_sam_auto_masks
)


### Utility functions

In [2]:
def get_slice_features(
    image: np.ndarray,
    patch_size: int,
    overlap: int,
    model_adapter,
    storage_group,
    batch_size=None,
):
    """Extract the model features for one slice and save them into the storage group.

    The slice is scaled by 1/255 (assumes 8-bit input — TODO confirm for other
    dtypes), turned into a 3-channel 1xCxHxW tensor, split into overlapping
    patches with ``patchify`` and fed to the model adapter in mini-batches.
    The per-patch features are written into the dataset
    ``storage_group[model_adapter.name]`` of shape
    ``(num_patches, stride, stride, total_channels)``. If a dataset with that
    name already exists it is reused, so its content is overwritten in place.

    Args:
        image: the slice, either HxW (grayscale, replicated to 3 channels) or
            HxWxC (RGB; an alpha channel, if present, is dropped).
        patch_size: patch size passed to ``patchify`` / ``get_stride_margin``.
        overlap: overlap between neighbouring patches.
        model_adapter: featureforest model adapter; its ``name``, ``device``,
            ``get_total_output_channels()`` and ``get_features_patches()`` are used.
        storage_group: h5py group in which the feature dataset is created/reused.
        batch_size: patches per forward pass. ``None`` (default) keeps the
            previous behaviour: 8, or 2 for ``SAMAdapter`` models which need
            more memory.

    Raises:
        ValueError: if ``batch_size`` is given and is smaller than 1.
    """
    if batch_size is None:
        # keep the batch small; the big SAM model needs an even smaller one
        batch_size = 2 if isinstance(model_adapter, SAMAdapter) else 8
    elif batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # image to torch tensor
    img_data = torch.from_numpy(image).to(torch.float32) / 255.0
    # the model input must be 4D: BxCxHxW with 3 (RGB) channels
    if is_image_rgb(image):
        # already RGB: drop a possible alpha channel, put channels first, add batch dim
        img_data = img_data[..., :3]
        img_data = img_data.permute([2, 0, 1]).unsqueeze(0)
    else:
        # grayscale: add batch and channel dims, then repeat the channel 3 times
        img_data = img_data.unsqueeze(0).unsqueeze(0).expand(-1, 3, -1, -1)

    # get input patches
    data_patches = patchify(img_data, patch_size, overlap)
    num_patches = len(data_patches)
    num_batches = int(np.ceil(num_patches / batch_size))

    # prepare storage for the slice embeddings
    total_channels = model_adapter.get_total_output_channels()
    stride, _ = get_stride_margin(patch_size, overlap)

    if model_adapter.name not in storage_group:
        dataset = storage_group.create_dataset(
            model_adapter.name, shape=(num_patches, stride, stride, total_channels)
        )
    else:
        dataset = storage_group[model_adapter.name]

    # inference only: no autograd graph is needed for the extracted features
    with torch.no_grad():
        for b_idx in tqdm(range(num_batches), desc="extracting slice feature:"):
            start = b_idx * batch_size
            end = start + batch_size
            slice_features = model_adapter.get_features_patches(
                data_patches[start:end].to(model_adapter.device)
            )
            if not isinstance(slice_features, tuple):
                # model has only one output; the last batch may be shorter
                num_out = slice_features.shape[0]
                dataset[start : start + num_out] = slice_features
            else:
                # model has several outputs: store them side by side along the channel axis
                ch_start = 0
                for feat in slice_features:
                    num_out = feat.shape[0]
                    ch_end = ch_start + feat.shape[-1]  # number of features of this output
                    dataset[start : start + num_out, :, :, ch_start:ch_end] = feat
                    ch_start = ch_end


def predict_slice(
    rf_model, patch_dataset, model_adapter,
    img_height, img_width, patch_size, overlap
):
    """Classify every pixel of one slice, one feature patch at a time.

    Each stored feature patch (``stride x stride x C``) is flattened to
    ``(stride*stride, C)`` and classified by ``rf_model``. The per-patch label
    maps are then tiled back into a ``(patch_rows*stride, patch_cols*stride)``
    canvas, which is cropped to ``img_height x img_width`` to drop the padding.

    Returns:
        np.ndarray: uint8 label image of shape ``(img_height, img_width)``.
    """
    # every feature patch of the slice: N x stride x stride x C
    all_patch_features = patch_dataset[:]
    n_channels = model_adapter.get_total_output_channels()
    stride, _margin = get_stride_margin(patch_size, overlap)

    patch_labels = [
        rf_model.predict(all_patch_features[idx].reshape(-1, n_channels)).astype(np.uint8)
        for idx in tqdm(
            range(all_patch_features.shape[0]),
            desc="Predicting slice patches", position=1, leave=True
        )
    ]

    # one row of stride*stride labels per patch -> patch grid
    patch_rows, patch_cols = get_num_patches(
        img_height, img_width, patch_size, overlap
    )
    label_grid = np.vstack(patch_labels).reshape(
        patch_rows, patch_cols, stride, stride
    )
    # (rows, cols, s, s) -> (rows, s, cols, s) so that a plain reshape tiles the patches
    padded_canvas = label_grid.transpose(0, 2, 1, 3).reshape(
        patch_rows * stride, patch_cols * stride
    )

    # crop away the padding added to fill the last row/column of patches
    return padded_canvas[:img_height, :img_width]


def apply_postprocessing(
    input_image, segmentation_image,
    smoothing_iterations, area_threshold, area_is_absolute,
    use_sam_predictor, use_sam_autoseg, iou_threshold
):
    """Post-process a slice segmentation with one or more methods.

    The simple post-processing is always applied. The SAM predictor and the
    SAM auto-segmentation variants are added when their flags are set.
    ``input_image`` is only used to compute the SAM automatic masks.

    Returns:
        dict: maps the method name ("Simple", "SAMPredictor",
        "SAMAutoSegmentation") to its post-processed mask, in that order.
    """
    results = {
        "Simple": postprocess(
            segmentation_image, smoothing_iterations,
            area_threshold, area_is_absolute
        ),
    }

    if use_sam_predictor:
        results["SAMPredictor"] = postprocess_with_sam(
            segmentation_image,
            smoothing_iterations, area_threshold, area_is_absolute
        )

    if use_sam_autoseg:
        # automatic SAM masks come from the raw input slice, not the RF labels
        auto_masks = get_sam_auto_masks(input_image)
        results["SAMAutoSegmentation"] = postprocess_with_sam_auto(
            auto_masks,
            segmentation_image,
            smoothing_iterations, iou_threshold,
            area_threshold, area_is_absolute
        )

    return results

### Set the input image, RF model, and result directory paths

In [None]:
# input image stack to segment
data_path = Path("../datasets/Johan/dino_host/test_substacks/Stack02_bin4_(1-3598-3).tif")
print(f"data_path exists: {data_path.exists()}")

# trained (pickled) random forest model
rf_model_path = Path("../datasets/Johan/dino_host/mito/host/rf_model.bin")
print(f"rf_model_path exists: {rf_model_path.exists()}")

# output folder for the per-slice results (created if missing)
segmentation_dir = Path("../datasets/Johan/dino_host/mito/host") / "segmentation_result"
segmentation_dir.mkdir(parents=True, exist_ok=True)

# temporary HDF5 file holding the extracted feature patches of the current slice
storage_path = "./temp_storage.hdf5"

### Prepare the Input and RF Model

In [None]:
# open the input stack and read its dimensions: number of slices, rows, columns
input_stack = Image.open(data_path)
num_slices, img_height, img_width = (
    input_stack.n_frames, input_stack.height, input_stack.width
)

print(num_slices, img_height, img_width)

In [None]:
# Load the trained random forest classifier.
# NOTE: pickle.load can execute arbitrary code; only load model files you trust.
with open(rf_model_path, mode="rb") as model_file:
    rf_model = pickle.load(model_file)
# silence the classifier's per-call progress output
rf_model.set_params(verbose=0)

rf_model

### Initializing the Model for Feature Extraction

In [None]:
# show the feature-extraction models featureforest provides
available_models = get_available_models()
available_models

In [7]:
model_name = "MobileSAM"  # feature-extraction model, passed to get_model()

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"running on {device}")

In [None]:
# build the feature-extraction adapter for this image size and read its patching setup
model_adapter = get_model(model_name, img_height, img_width)
patch_size, overlap = model_adapter.patch_size, model_adapter.overlap

patch_size, overlap

In [9]:
# post-processing parameters
do_postprocess = True

smoothing_iterations = 25
area_threshold = 100        # mask regions with an area below this threshold are ignored
area_is_absolute = True     # True: area_threshold is in pixels; False: it is a percentage

use_sam_predictor = True    # also produce the SAM-predictor post-processed mask
use_sam_autoseg = False     # also produce the SAM auto-segmentation post-processed mask
sam_autoseg_iou_threshold = 0.4  # IoU threshold for the SAM auto-segmentation post-processing

### Prediction

In [10]:
# temporary HDF5 storage for one slice's feature patches ("w" truncates any old file)
storage = h5py.File(storage_path, mode="w")
storage_group = storage.create_group(name="slice")

In [None]:
# Extract features, predict, and (optionally) post-process the stack slice by slice.
# Cleanup runs in `finally` so the temporary HDF5 storage is closed and removed even
# if a slice fails or the cell is interrupted.
try:
    for i, page in tqdm(
        enumerate(ImageSequence.Iterator(input_stack)),
        desc="Slices", total=num_slices, position=0
    ):
        slice_img = np.array(page.convert("RGB"))

        # features go into storage_group[model_adapter.name]; the dataset is reused,
        # so every slice overwrites the previous slice's features
        get_slice_features(slice_img, patch_size, overlap, model_adapter, storage_group)

        segmentation_image = predict_slice(
            rf_model, storage_group[model_adapter.name], model_adapter,
            img_height, img_width,
            patch_size, overlap
        )

        # raw random-forest prediction
        img = Image.fromarray(segmentation_image)
        img.save(segmentation_dir.joinpath(f"slice_{i:04}_prediction.tiff"))

        if do_postprocess:
            post_masks = apply_postprocessing(
                slice_img, segmentation_image,
                smoothing_iterations, area_threshold, area_is_absolute,
                use_sam_predictor, use_sam_autoseg, sam_autoseg_iou_threshold
            )
            # save results: one sub-folder per post-processing method
            for name, mask in post_masks.items():
                img = Image.fromarray(mask)
                seg_dir = segmentation_dir.joinpath(name)
                seg_dir.mkdir(exist_ok=True)
                img.save(seg_dir.joinpath(f"slice_{i:04}_{name}.tiff"))
finally:
    if storage is not None:
        storage.close()
        storage = None
    # tolerate an already-removed file so cleanup itself never raises
    Path(storage_path).unlink(missing_ok=True)

In [12]:
# Manual cleanup (e.g. after an interrupted run): close and delete the temporary
# HDF5 storage if it is still open. Safe to run more than once.
if storage is not None:
    storage.close()
    storage = None
    Path(storage_path).unlink(missing_ok=True)