In [14]:
import os
import numpy as np
from spatialdata import read_zarr
import dask.array as da

In [15]:
import dask.array as da
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
import random as r

class Feature_extractor(BaseEstimator, TransformerMixin):
    """Build per-pixel feature stacks from one random square crop per image.

    For every image the transformer min-max normalizes the whole image,
    picks a random ``image_size`` x ``image_size`` window, and stacks, in
    this order: the normalized crop, the x coordinate of each pixel divided
    by ``max_width``, the y coordinate divided by ``max_height``, and the
    output of every callable in ``functions``.

    Parameters
    ----------
    functions : dict
        Maps a callable to a dict of keyword arguments. Each callable is
        invoked as ``f(crop, **kwargs)`` on the (computed) normalized crop.
    image_size : int, default=1024
        Side length of the crop. Images smaller than this along an axis are
        used in full along that axis instead of raising.
    to_flatten : bool, default=False
        If True, ``transform`` returns one ``(pixels, features)`` array
        covering all images; otherwise a list of ``(features, H, W)`` arrays.
    random_state : int or None, default=None
        Seed for the crop positions. ``None`` keeps drawing from the
        module-level ``random`` generator.
    """

    def __init__(self, functions, image_size=1024, to_flatten=False, random_state=None):
        super().__init__()
        self.to_flatten = to_flatten
        self.functions = functions
        self.max_height = None
        self.max_width = None
        self.image_size = image_size
        self.random_state = random_state

    def fit(self, X, y=None):
        """Record the largest image height and width found in ``X``.

        The shape is read after ``squeeze()`` so that it agrees with the
        shape ``transform`` works on (which also squeezes each image).
        """
        shapes = [img.squeeze().shape for img in X]
        self.max_height = max(shape[0] for shape in shapes)
        self.max_width = max(shape[1] for shape in shapes)
        return self

    def transform(self, X):
        """Extract the feature stack of one random crop for every image in ``X``.

        Returns a single dask array of shape ``(pixels, features)`` when
        ``to_flatten`` is True, otherwise a list with one
        ``(features, H, W)`` dask array per image.
        """
        # getattr keeps instances created before random_state existed usable.
        seed = getattr(self, "random_state", None)
        rng = r if seed is None else r.Random(seed)
        size = self.image_size
        all_features = []

        for img in X:
            # Normalize per image using the min/max of the *whole* image.
            img = img.astype(np.float32).squeeze()
            img_min, img_max = da.compute(img.min(), img.max())
            img = (img - img_min) / (img_max - img_min + 1e-8)  # Avoid division by zero

            h, w = img.shape

            # Clamp at 0 so images smaller than the crop do not make
            # randint fail with an empty range.
            x = rng.randint(0, max(w - size, 0))
            y = rng.randint(0, max(h - size, 0))
            rows = slice(y, y + size)
            cols = slice(x, x + size)

            # Positional features: pixel coordinates scaled by the fitted maxima.
            x_coords, y_coords = da.meshgrid(da.arange(w), da.arange(h))
            x_coords = x_coords[rows, cols] / self.max_width
            y_coords = y_coords[rows, cols] / self.max_height

            img = img[rows, cols]

            image_features = [img, x_coords, y_coords]

            if self.functions:
                # Materialize the crop once; every filter reuses the same array.
                img_np = img.compute() if isinstance(img, da.Array) else img
                for f, params in self.functions.items():
                    feat = f(img_np, **params)
                    feat_da = feat if isinstance(feat, da.Array) else da.from_array(feat)
                    image_features.append(feat_da)

            if self.to_flatten:
                # Flatten every feature plane and stack as (pixels, features).
                flattened = [
                    feat.flatten() if isinstance(feat, da.Array) else feat.ravel()
                    for feat in image_features
                ]
                stacked = da.stack(flattened, axis=1).rechunk({1: -1})
            else:
                stacked = da.stack(image_features)  # shape: (num_features, H, W)
            all_features.append(stacked)

        if self.to_flatten:
            # One 2D array with the pixels of all images stacked along axis 0.
            return da.concatenate(all_features, axis=0)
        # List of 3D arrays (features x H x W), one per image.
        return all_features

In [16]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline

class Custom_Pipeline(Pipeline):
    """Pipeline whose ``predict`` takes image-shaped feature stacks."""

    def predict(self, X):
        """Predict one value per pixel.

        ``X`` is unpacked as ``(images, features, H, W)``; it is rearranged
        into one row per pixel before delegating to ``Pipeline.predict`` and
        the result is reshaped to ``(images, H, W)``.
        """
        n_images, n_features, height, width = X.shape

        # Move the feature axis last so each pixel becomes one sample row.
        pixels_last = da.transpose(X, (0, 2, 3, 1))
        rows = pixels_last.reshape(n_images * height * width, n_features)

        flat_pred = super().predict(rows)
        return flat_pred.reshape(n_images, height, width)


class Custom_Random_Forest(RandomForestClassifier):
    """Random forest whose ``predict`` takes image-shaped feature stacks."""

    def predict(self, X):
        """Predict one class per pixel.

        ``X`` is unpacked as ``(images, features, H, W)``; it is rearranged
        into one row per pixel before delegating to
        ``RandomForestClassifier.predict`` and the result is reshaped to
        ``(images, H, W)``.
        """
        n_images, n_features, height, width = X.shape

        # Move the feature axis last so each pixel becomes one sample row.
        pixels_last = da.transpose(X, (0, 2, 3, 1))
        rows = pixels_last.reshape(n_images * height * width, n_features)

        flat_pred = super().predict(rows)
        return flat_pred.reshape(n_images, height, width)

In [17]:
from scipy.ndimage import gaussian_filter, gaussian_gradient_magnitude, gaussian_laplace

# Three Gaussian-derived filters, all at sigma=1, evaluated on every crop;
# to_flatten=True asks for a (pixels, features) output.
trans = Feature_extractor(
    functions={
        filter_fn: {'sigma': 1}
        for filter_fn in (gaussian_filter, gaussian_gradient_magnitude, gaussian_laplace)
    },
    to_flatten=True,
)

In [18]:
# data_dir = r"E:\data\train"
# max_width = 0
# max_height = 0
# for arg in os.listdir(data_dir):
#     path = os.path.join(data_dir, arg)
#     if os.path.isdir(path) and arg.endswith(".zarr"):
#         sdata = read_zarr(path)

#         img = sdata['annotations'].data

#         h, w = img.shape

#         if max_width<w:
#             max_width = w

#         if max_height<h:
#             max_height = h

In [19]:
# Filter callable -> keyword arguments; every filter uses sigma=1.
functions = {
    filter_fn: {'sigma': 1}
    for filter_fn in (gaussian_filter, gaussian_gradient_magnitude, gaussian_laplace)
}

In [20]:
def get_features(img, label, functions, image_size=1024, max_width=None, max_height=None):
    """Build the feature stack and matching label window for one random crop.

    Parameters
    ----------
    img : array
        Single-channel image; singleton axes are squeezed away before use.
    label : array
        Label array indexed with the same ``[rows, cols]`` window as ``img``.
    functions : dict
        Maps a callable to a dict of keyword arguments; each callable is
        invoked as ``f(crop, **kwargs)`` on the (computed) normalized crop.
    image_size : int, default=1024
        Side length of the crop. Images smaller than this along an axis are
        used in full along that axis.
    max_width, max_height : number, optional
        Divisors for the x / y coordinate features. When omitted, a
        module-level ``max_width`` / ``max_height`` is used if one exists
        (the previous behavior); otherwise the image's own width / height.

    Returns
    -------
    tuple
        ``(features, label_crop)`` where ``features`` stacks, in order, the
        normalized crop, the scaled x coordinates, the scaled y coordinates
        and one plane per entry of ``functions``.
    """
    # Squeeze before reading the shape so a singleton axis does not break
    # the two-value unpacking below.
    img = img.astype(np.float32).squeeze()
    h, w = img.shape

    # Clamp at 0 so images smaller than the crop do not make randint fail.
    x = r.randint(0, max(w - image_size, 0))
    y = r.randint(0, max(h - image_size, 0))
    rows = slice(y, y + image_size)
    cols = slice(x, x + image_size)

    # The original read these as globals, which raises NameError on a fresh
    # kernel unless another cell defined them first.
    if max_width is None:
        max_width = globals().get("max_width", w)
    if max_height is None:
        max_height = globals().get("max_height", h)

    # Min-max normalize using the statistics of the whole image.
    img_min, img_max = da.compute(img.min(), img.max())
    img = (img - img_min) / (img_max - img_min + 1e-8)

    # Positional encodings
    x_coords, y_coords = da.meshgrid(da.arange(w), da.arange(h))
    x_coords = x_coords[rows, cols] / max_width
    y_coords = y_coords[rows, cols] / max_height

    # Crop image
    img = img[rows, cols]

    image_features = [img, x_coords, y_coords]

    if functions:
        # Materialize the crop once; every filter reuses the same array.
        img_np = img.compute() if isinstance(img, da.Array) else img
        for f, params in functions.items():
            feat = f(img_np, **params)
            feat_da = feat if isinstance(feat, da.Array) else da.from_array(feat)
            image_features.append(feat_da)

    return da.array(image_features), label[rows, cols]

In [21]:
# from spatialdata import SpatialData
# from spatialdata.models import Image2DModel, Labels2DModel

# new_sdata = SpatialData()

# count = 0
# data_dir = r"E:\data\train"

# for arg in os.listdir(data_dir):
#     path = os.path.join(data_dir, arg)
#     if os.path.isdir(path) and arg.endswith(".zarr"):
#         sdata = read_zarr(path)  # Ensure this is defined elsewhere

#         label = sdata['annotations'].data

#         data = [i.data.squeeze() for i in list(sdata.images.values())]

#         for img in data:
#             image, l = get_features(img, label, functions)

#             new_sdata[f"channel_{count}"] = Image2DModel.parse(
#             image,
#             )

#             new_sdata[f"label_{count}"] = Labels2DModel.parse(
#             l,
#             )
#             count += 1

# new_sdata.write(os.path.join(r"E:\data", "preprocess.zarr"), overwrite=True)

In [22]:
# Load the preprocessed dataset from a hard-coded local Windows path.
# Later cells index it as sdata["channel_<n>"] / sdata["label_<n>"].
sdata = read_zarr(r"E:\data\preprocess.zarr")

In [None]:
X = []
y = []

# Draw 200 distinct entry indices (sampling without replacement).
random_data = r.sample(range(len(sdata.labels)), 200)

for entry in random_data:
    planes = sdata[f"channel_{entry}"].values
    # Keep only the top-left 512x512 window of every feature plane and
    # lay the planes out as columns: one row per pixel.
    columns = [plane[..., 0:512, 0:512].flatten() for plane in planes]
    X.append(np.stack(columns, axis=1).astype(np.float16))

    mask = sdata[f"label_{entry}"].values
    y.append(mask[0:512, 0:512].flatten()//255)

X = np.concatenate(X)
y = np.concatenate(y)

In [74]:
X.shape

(52428800, 6)

In [81]:
from dask.distributed import Client
import joblib

# 50 trees, depth capped at 10, at least 5 samples per leaf;
# n_jobs=-1 lets scikit-learn use every available core.
model = RandomForestClassifier(
    n_estimators=50,
    max_depth=10,
    min_samples_leaf=5,
    n_jobs=-1,
)

# with joblib.parallel_backend("dask"):
model.fit(X, y)

# Persist the fitted forest; the returned file list is the cell output.
joblib.dump(model, "HET.pkl")

Perhaps you already have a cluster running?
Hosting the HTTP server on port 52555 instead


['HET.pkl']

In [82]:
# Only entries that were not drawn for training are eligible for testing.
# A set makes each membership check constant-time instead of a list scan.
train_indices = set(random_data)
pool = [num for num in range(len(sdata.labels)) if num not in train_indices]

random_test = r.sample(pool, 100)

test_data = []
test_labels = []

for entry in random_test:
    # One column per feature plane, one row per pixel (full image, no crop).
    flattened = [plane.flatten() for plane in sdata[f"channel_{entry}"].values]
    stacked = np.stack(flattened, axis=1).astype(np.float16)
    test_data.append(stacked)
    # Labels stay image-shaped so they can be compared per image later.
    test_labels.append(sdata[f"label_{entry}"].values//255)


test = np.concatenate(test_data)
label = np.concatenate(test_labels)

In [59]:
test.shape

(104857600, 6)

In [97]:
import joblib 


# Reload the forest from "HET.pkl" and classify every test pixel row.
# NOTE(review): joblib.load unpickles the file, so only load files you trust.
model = joblib.load("HET.pkl")

pred = model.predict(test)

In [98]:
# Back to one prediction map per test image. The shape is taken from the
# label masks instead of a hard-coded (100, 1024, 1024), so it follows the
# number and size of the sampled test images.
pred = pred.reshape(len(test_labels), *test_labels[0].shape)

In [105]:
def iou_binary(arr1, arr2):
    """Intersection-over-union of two binary masks.

    Any non-zero element counts as foreground. Inputs may be arrays or
    array-likes (e.g. nested lists) and must have identical shapes.

    Returns
    -------
    float
        ``|arr1 AND arr2| / |arr1 OR arr2|``, or 1.0 when both masks are
        empty (two empty masks are treated as a perfect match).

    Raises
    ------
    AssertionError
        If the two inputs differ in shape.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    assert arr1.shape == arr2.shape, "Arrays must have the same shape"

    union = np.logical_or(arr1, arr2).sum()
    if union == 0:
        # An empty union implies an empty intersection, so this is the only
        # possible outcome here.
        return 1.0

    intersection = np.logical_and(arr1, arr2).sum()
    return float(intersection / union)

In [100]:
np.unique(pred)

array([0., 1.], dtype=float16)

In [None]:
# IoU of each predicted map against the label mask at the same position.
mean = [
    iou_binary(pred_mask, true_mask)
    for pred_mask, true_mask in zip(pred, test_labels)
]


In [107]:
np.mean(mean)

0.3797008837412497