In [38]:
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread
from tifffile import imwrite, imsave
import datetime
from numba import jit, njit


In [14]:
image = imread('2-CrXAS-movie@577-5eV.tif')

In [39]:
def crop(imframe, n):
    """This is the old function for tiling the image, it is slower but more understandable than the 
    tile_image() function that uses np.reshape()"""
    M = int(imframe.shape[0] / n)
    N = int(imframe.shape[1] / n)
    tiles = [imframe[x:x + M, y:y + N] for x in range(0, imframe.shape[0], M) for y in range(0, imframe.shape[1], N)]
    return tiles


def tile_image(imframe: np.ndarray, tilesize: tuple):
    """Returns an array of image tiles with shape (X"""
    img_height, img_width = imframe.shape
    tileheight, tilewidth = tilesize

    tiled_array = imframe.reshape(img_height // tileheight, tileheight, img_width // tilewidth, tilewidth)
    tiled_array = tiled_array.swapaxes(1,2)
    return tiled_array


def reform_image(tilearray: np.ndarray, originalimage: np.ndarray):
    """Reforms an image from a tile array with shape (Y, X, n, m) where Y is the tile row number, X is the tile column
    number, n is the tile pixel height and m is the tile pixel width.  Returns an image with the shape of the original imag"""
    finalsize = originalimage.shape
    tilearray = tilearray.swapaxes(1,2)
    return np.reshape(tilearray, finalsize)


def plane_level(img: np.ndarray):
    """Takes a 2d image (numpy array), calculates the mean plane, and returns that 2d image after subtracting the
    mean plane"""

    """Create x, y, and z (pixel intensity) arrays"""
    dimensions = img.shape
    totalpoints = dimensions[0] * dimensions[1]
    xarray = np.linspace(0, totalpoints - 1, totalpoints) % dimensions[0]
    yarray = np.linspace(0, dimensions[1] - 1, totalpoints)  # a sequence that makes y values that correspond to the z points
    xarray = xarray.astype(int)
    yarray = yarray.astype(int)
    flatimg = img.flatten()
    leveled_flat = np.zeros(flatimg.shape)

    """Calculate the best fit plane through the datapoints"""
    cvals = np.ones(xarray.shape)  # make an array of zeros to represent the coefficients in the below equation
    coefs = np.linalg.lstsq(np.stack([xarray, yarray, cvals]).T, flatimg, rcond=None)[0]  # least squares solver for the coefficients of the plane
    mean = np.mean(flatimg)

    """Calculate the leveled Z value """
    for index, pixel in enumerate(flatimg):
        z_fit = (coefs[0] * xarray[index] + coefs[1] * yarray[index] + coefs[2])
        zdiff = z_fit-np.mean(mean)
        leveled_flat[index] = pixel-zdiff

    final_img = leveled_flat.reshape(img.shape)
    return final_img.astype(int)

@njit
def median_prominence_threshold(imgarray, prominence):
    """Takes an input array and thresholds it using median + prominence"""
    brlevel = np.median(imgarray)
    threshold = brlevel + prominence
    return (imgarray > threshold)

In [40]:
segmented_timeseries = np.zeros(image.shape)
for step, frame in enumerate(image[0:5,:,:]):
    tiles = tile_image(frame, (64, 64))
    segmented = np.zeros(tiles.shape)
    for i, tile in enumerate(tiles):
        for j, subtile in enumerate(tile):
            leveltile = plane_level(subtile)
            segmentedtile = median_prominence_threshold(leveltile,8000)
            segmented[i,j,:,:] =segmentedtile



    reformed = reform_image(segmented, frame)
    segmented_timeseries[step,:,:] = reformed.astype(np.uint8)
    print(datetime.datetime.now(), step)

2023-09-28 13:35:25.834105 0
2023-09-28 13:35:42.832782 1
2023-09-28 13:35:59.750210 2
2023-09-28 13:36:17.232004 3
2023-09-28 13:36:34.192269 4


In [9]:
imsave("mov_2_median_plus_8000.tif",segmented_timeseries.astype(np.uint8))