In [1]:
from dask_image.imread import imread
import dask.array as da
import os
import numpy as np
import joblib
import dask  # Import Dask first
dask.config.set({'dataframe.query-planning': False})  # Disable query-planning

import dask.dataframe as dd  # Now import dask.dataframe
import pandas as pd
import dask.dataframe as dd
from spatialdata import read_zarr



In [2]:
from sklearn.base import BaseEstimator, TransformerMixin

class Feature_extractor(BaseEstimator, TransformerMixin):
    """Turn each 2-D image into a stack of per-pixel feature maps.

    For every image the transformer emits, in this order: the min-max
    normalised intensity, the normalised x coordinate, the normalised
    y coordinate, and then one map per entry of ``functions``.

    Parameters
    ----------
    functions : dict
        Maps a callable ``f(img, **params)`` to the keyword arguments it is
        called with. Each result is stacked with the other maps, so it must
        have the same shape as the (squeezed) image.
    to_flatten : bool, default=False
        If True every image yields an array of shape ``(H*W, n_features)``;
        otherwise ``(n_features, H, W)``.
    """

    def __init__(self, functions, to_flatten=False):
        super().__init__()
        self.to_flatten = to_flatten
        self.functions = functions

    def fit(self, X, y=None):
        """Record the largest image height and width found in ``X``.

        The values are stored as ``max_height`` and ``max_width`` and are used
        by ``transform`` to scale the coordinate maps. ``y`` is ignored.

        Raises
        ------
        ValueError
            If ``X`` contains no images.
        """
        # Measure the squeezed shape, exactly as transform() sees it. Without
        # the squeeze an image stored as (1, H, W) would report height 1 and
        # width H, and the coordinate maps would be scaled by the wrong sizes.
        shapes = [np.squeeze(img).shape for img in X]
        self.max_height = max(shape[0] for shape in shapes)
        self.max_width = max(shape[1] for shape in shapes)
        return self

    def transform(self, X):
        """Build the feature stack for every image in ``X``.

        Returns
        -------
        numpy.ndarray
            ``np.array`` over the per-image stacks; when all images share one
            shape this is ``(n_images, n_features, H, W)``, or
            ``(n_images, H*W, n_features)`` with ``to_flatten=True``.
        """
        all_features = []

        for img in X:
            image_features = []

            img = img.astype(np.float32).squeeze()
            lowest = img.min()
            value_range = img.max() - lowest
            if value_range == 0:
                # A constant image has no range to normalise by; dividing
                # would give 0/0 and fill the whole map with NaN.
                img = np.zeros_like(img)
            else:
                img = (img - lowest) / value_range
            image_features.append(img)

            h, w = img.shape

            # Pixel coordinates, scaled by the largest size seen during fit()
            # so that they are comparable between images of different sizes.
            x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
            x_coords = x_coords / self.max_width
            y_coords = y_coords / self.max_height
            image_features.append(x_coords)
            image_features.append(y_coords)

            for f, params in self.functions.items():
                image_features.append(f(img, **params))

            if self.to_flatten:
                # One row per pixel, one column per feature.
                image_features = [i.flatten() for i in image_features]
                all_features.append(np.stack(image_features, axis=1))
            else:
                all_features.append(np.stack(image_features))

        return np.array(all_features)

In [3]:
from sklearn.ensemble import RandomForestClassifier

class Custom_Random_Forest(RandomForestClassifier):
    """Random forest that treats every pixel of a feature stack as a sample.

    ``X`` is a 4-D array laid out as (image, feature, row, column). It is
    rearranged into one row per pixel before being handed to
    ``RandomForestClassifier``.
    """

    def fit(self, X, y, sample_weight = None):
        """Fit on per-pixel rows; ``y`` is reshaped to one label per pixel."""
        n_images, n_features, height, width = X.shape
        n_pixels = n_images * height * width

        # Move the feature axis last so that each pixel's features are
        # contiguous, then collapse image/row/column into the sample axis.
        pixel_rows = np.moveaxis(X, 1, -1).reshape(n_pixels, n_features)
        pixel_labels = y.reshape(n_pixels)

        return super().fit(pixel_rows, pixel_labels, sample_weight)

    def predict(self, X):
        """Predict a label per pixel and return them shaped (image, row, column)."""
        n_images, n_features, height, width = X.shape
        n_pixels = n_images * height * width

        pixel_rows = np.moveaxis(X, 1, -1).reshape(n_pixels, n_features)
        flat_predictions = super().predict(pixel_rows)
        return flat_predictions.reshape(n_images, height, width)

In [4]:
from scipy.ndimage import gaussian_filter, gaussian_gradient_magnitude, gaussian_laplace

# Three Gaussian-based filters, each called with its own {'sigma': 1}
# keyword dict, in the order listed below.
trans = Feature_extractor(
    functions={
        image_filter: {'sigma': 1}
        for image_filter in (
            gaussian_filter,
            gaussian_gradient_magnitude,
            gaussian_laplace,
        )
    }
)

In [5]:
from sklearn.pipeline import Pipeline

# Two-stage pipeline: feature extraction followed by the random forest.
# random_state fixes the forest's random choices so that re-running the
# notebook on the same data trains the same model.
pipe = Pipeline([
    ("Feature extractor", trans),
    ("classifier", Custom_Random_Forest(n_jobs=-1, random_state=42)),
])

In [9]:
from skimage.transform import resize

def resize_image_channelwise(image, new_shape=(64, 64), max_channels=30):
    """Resize the leading channels of an image one channel at a time.

    Channels are read from the last axis of ``image`` (``image[:, :, c]``).
    At most ``max_channels`` of them are kept; each is passed to ``resize``
    with ``new_shape`` and the results are stacked along a new last axis.

    Raises
    ------
    ValueError
        From ``np.stack`` when there is no channel to process.
    """
    n_channels = min(image.shape[-1], max_channels)
    return np.stack(
        [
            resize(image[:, :, idx], new_shape, preserve_range=True, anti_aliasing=True)
            for idx in range(n_channels)
        ],
        axis=-1,
    )

In [10]:
data = []
labels = []

data_dir = r"E:\data\train"


# os.listdir returns entries in arbitrary order; sorting keeps the order of
# `data` and `labels` the same from run to run.
for arg in sorted(os.listdir(data_dir)):
    path = os.path.join(data_dir, arg)
    if os.path.isdir(path) and arg.endswith(".zarr"):
        sdata = read_zarr(path)

        # Resize all images (one by one to avoid memory spike).
        # They are collected locally first and only added to `data` together
        # with their labels, so an error or interrupt part-way through a store
        # cannot leave `data` and `labels` with different lengths.
        # NOTE(review): resize_image_channelwise reads channels from the LAST
        # axis -- confirm that image_obj.data is laid out channels-last.
        store_images = []
        for name, image_obj in sdata.images.items():
            img = resize_image_channelwise(image_obj.data)
            store_images.append(img)  # not flattened

        # Assume annotation is 1 label per sample (not an image)
        # NOTE(review): "label_column" is a placeholder -- confirm that this
        # lookup returns the intended column of the table.
        ann = sdata.table["label_column"].values[0]  # adjust "label_column" as needed

        data.extend(store_images)
        labels.extend([ann] * len(store_images))

  compressor, fill_value = _kwargs_compat(compressor, fill_value, kwargs)


KeyboardInterrupt: 

In [None]:
import joblib

# Fit once on the whole training set. fit() starts from scratch on every call,
# so fitting image by image would keep only what was learned from the last
# image; it would also hand a single image to the feature extractor, which
# iterates over its input and would treat each image row as a separate image.
# NOTE(review): the classifier reshapes y to one label per pixel and the
# feature extractor unpacks each squeezed image as (h, w) -- confirm that
# `data` holds 2-D images and `labels` holds matching per-pixel masks.
pipe.fit(data, np.asarray(labels))

# Save the trained pipeline once, after training has finished.
joblib.dump(pipe, "random_forest.pkl")

MemoryError: Unable to allocate 1.78 GiB for an array with shape (15355, 15558) and data type float64

In [None]:
from PIL import Image

# NOTE(review): `sdata` is not defined in this cell. In the visible cells it
# is only assigned inside the data-loading loop, so this reads whichever store
# that loop opened last -- confirm that this is the intended dataset.
data_test = sdata['annotations'].values

# Wrap the array in a PIL image and display it.
# NOTE(review): this rebinds `img`, which the loading loop also uses.
img = Image.fromarray(data_test)
img.show()