In [None]:
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread
from tifffile import imwrite, imsave
import datetime
from numba import jit, njit, prange
from functools import partial
import time 
import pandas as pd
from skimage.measure import label, regionprops, regionprops_table
import trackpy as tp
import matplotlib as mpl
import matplotlib.pyplot as plt
import datetime

Fast (Numba-compiled) implementations of the tiling, plane-leveling, and segmentation functions

In [None]:
#These are efficient implementations of tiling, plane leveling, and segmentation using numba

@partial(jit, nopython=True)
def swap(x, t_args):
    """Return ``x`` transposed by the axis permutation ``t_args``.

    Thin nopython-mode wrapper around ``np.transpose``. The original author notes
    that calling np.transpose directly from the other jitted helpers ran into a
    numba bug, and routing the call through this function works around it.
    """
    transposed = np.transpose(x, t_args)
    return transposed

@njit
def tile_image(imframe: np.ndarray, tilesize: tuple):
    """Cut a 2-D frame into a grid of equally sized, non-overlapping tiles.

    Returns a 4-D array of shape (Y, X, n, m): Y is the tile-row index, X the
    tile-column index, and (n, m) = ``tilesize`` is the pixel height and width of
    each tile. The frame height/width must be exact multiples of the tile
    height/width, otherwise the reshape fails.
    """
    frame_h, frame_w = imframe.shape
    tile_h, tile_w = tilesize
    n_tile_rows = frame_h // tile_h
    n_tile_cols = frame_w // tile_w
    # numba only reshapes contiguous arrays and the input may be a strided view,
    # so reshape a private copy.
    blocks = imframe.copy().reshape(n_tile_rows, tile_h, n_tile_cols, tile_w)
    # (row-block, pixel-row, col-block, pixel-col) -> (row-block, col-block, pixel-row, pixel-col)
    return swap(blocks, (0, 2, 1, 3))

@njit
def reform_image(tilearray: np.ndarray, originalsize: tuple):
    """Stitch a (Y, X, n, m) tile array, as produced by tile_image, back into one image.

    Y is the tile-row index, X the tile-column index and (n, m) the tile pixel
    height/width. ``originalsize`` is the shape of the reassembled image and
    should equal (Y * n, X * m). Returns a new contiguous array of that shape.
    """
    # Transposing back to (row-block, pixel-row, col-block, pixel-col) lets a plain
    # row-major reshape lay the pixels out in image order. numba refuses to reshape
    # the non-contiguous transposed view, hence the copy in between.
    return swap(tilearray, (0, 2, 1, 3)).copy().reshape(originalsize)

@njit
def linalg_test(solvematrix, zmatrix):
    """Solve ``solvematrix @ coefs ~= zmatrix`` in the least-squares sense.

    Scratch helper for checking that ``np.linalg.lstsq`` compiles under numba; it
    is not called elsewhere in this section of the notebook. Returns only the
    solution vector and discards the residuals, rank and singular values.
    """
    solution, _residuals, _rank, _singular_values = np.linalg.lstsq(solvematrix, zmatrix)
    return solution

@njit
def plane_level_njit(img: np.ndarray):
    """Remove the least-squares best-fit plane (linear tilt) from a 2-D image.

    Fits ``z = a*col + b*row + c`` to the pixel intensities, subtracts that plane
    and adds back the image mean. The output therefore keeps the input's mean
    intensity while any linear gradient across the image is removed.

    Parameters
    ----------
    img : 2-D numpy array of pixel intensities (any real dtype).

    Returns
    -------
    2-D float64 array with the same shape as ``img``.
    """
    nrows, ncols = img.shape
    npix = nrows * ncols

    # Pixel coordinates in the same row-major order that flatten() uses: flat
    # index k sits at row k // ncols and column k % ncols. (An earlier version
    # took the column as k % nrows, which is only right for square images.)
    flat_index = np.arange(npix)
    cols = (flat_index % ncols).astype(np.float64)
    rows = (flat_index // ncols).astype(np.float64)

    # numba's lstsq requires every operand to share one dtype, so promote to float64.
    flatimg = img.flatten().astype(np.float64)

    design = np.column_stack((cols, rows, np.ones(npix)))
    coefs = np.linalg.lstsq(design, flatimg)[0]

    plane = coefs[0] * cols + coefs[1] * rows + coefs[2]
    # Subtract the plane but keep the original mean level.
    leveled_flat = flatimg - (plane - np.mean(flatimg))
    return leveled_flat.reshape(img.shape)

@njit
def median_prominence_threshold(imgarray, prominence):
    """Binarize ``imgarray`` at ``median(imgarray) + prominence``.

    Returns a boolean array of the same shape that is True where a value is
    strictly greater than the median of the whole array plus ``prominence``.
    """
    cutoff = np.median(imgarray) + prominence
    return imgarray > cutoff

# Main entry point: level and segment every frame of an image stack.
@njit(parallel=True)
def batch_level_segment(image: np.array, tileshape: tuple, imageshape: tuple, threshold):
    """Tile, plane-level and threshold every frame of a 3-D image stack.

    Each frame ``image[k, :, :]`` goes through the following steps:

    1. It is cut into ``tileshape`` tiles (tile_image).
    2. Each tile is plane-leveled (plane_level_njit).
    3. Each tile is thresholded at its own median + ``threshold``
       (median_prominence_threshold).
    4. The tiles are stitched back into a frame of shape ``imageshape``
       (reform_image).

    Parameters
    ----------
    image : 3-D array indexed as [frame, row, col].
    tileshape : (tile_height, tile_width); must evenly divide the frame size.
    imageshape : shape of a single frame, used to reassemble the tiles.
    threshold : prominence above each tile's median that a pixel must exceed.

    Returns
    -------
    float64 array with the same shape as ``image`` containing 0.0 / 1.0.
    The frame index is printed as each frame finishes.
    """
    result = np.zeros(image.shape)
    n_frames = image.shape[0]

    for k in range(n_frames):
        tiles = tile_image(image[k, :, :], tileshape)
        n_tile_rows = tiles.shape[0]
        n_tile_cols = tiles.shape[1]
        mask_tiles = np.zeros(tiles.shape)

        # Tiles are independent of each other. Note: numba parallelizes only the
        # outermost prange; the nested one runs as an ordinary range loop.
        for row in prange(n_tile_rows):
            for col in prange(n_tile_cols):
                leveled = plane_level_njit(tiles[row, col, :, :])
                mask_tiles[row, col, :, :] = median_prominence_threshold(leveled, threshold)

        # The 0/1 mask values are stored exactly in the float64 output.
        result[k, :, :] = reform_image(mask_tiles, imageshape)
        print(k)

    return result

In [None]:
# Applies a 0-1 mask to every frame of a TIFF stack; same as the function defined in crop_and_norm.ipynb.
def mask_subtract(data: str, mask: str, outfile: str):
    """Multiply every frame of an image stack file by a mask image and save the result.

    Despite the name, the mask is applied by multiplication. Pixels where the
    normalized mask is 0 are zeroed, and pixels where it is 1 are kept.

    Parameters
    ----------
    data : path of the image stack, read with ``imread``; expected to be 3-D
        [frame, row, col].
    mask : path of the mask image. A 2-D mask is used as-is; for a 3-D mask
        (e.g. an RGB export) only ``mask[:, :, 0]`` is used. The mask is
        rescaled to 0-1 by min-max normalization. A constant mask cannot be
        rescaled, so any non-zero value is treated as 1.
    outfile : path the masked stack is written to with ``imwrite``.

    Returns
    -------
    The masked stack, cast back to the input's dtype. For integer stacks the
    fractional products are truncated by the cast. If the mask does not match
    the frame shape, a message is printed and the stack is saved/returned
    unmasked.
    """
    image = imread(data)
    print(image.shape, image.dtype)
    original_dtype = image.dtype
    mask_img = imread(str(mask))
    print(mask_img.shape, mask_img.dtype)

    # Normalize the mask to 0-1.
    mask_f = mask_img.astype(np.float64)
    lo, hi = np.min(mask_f), np.max(mask_f)
    if hi > lo:
        mask_f = (mask_f - lo) / (hi - lo)
    else:
        # Min-max scaling would divide by zero and fill the mask with NaN.
        mask_f = (mask_f != 0).astype(np.float64)
    print(np.min(mask_f), np.max(mask_f))

    mask2d = mask_f[:, :, 0] if mask_f.ndim == 3 else mask_f

    # Explicit shape check instead of bare `except:` blocks, which also hid
    # unrelated errors (and KeyboardInterrupt).
    if image.ndim == 3 and mask2d.shape == image.shape[1:]:
        for idx in range(image.shape[0]):
            # Frame by frame to avoid a full float64 copy of the stack; assigning
            # the float product into an integer stack truncates toward zero.
            image[idx, :, :] = image[idx, :, :] * mask2d
    else:
        print("Mask/Image dimension mismatch, masking skipped")

    image = image.astype(original_dtype)
    # imwrite replaces tifffile's deprecated imsave alias.
    imwrite(outfile, image)
    return image

In [None]:
# Finds particles in each frame as connected groups of pixels whose area lies between minsize and maxsize.
def findparticles_3d_img(frames: np.array, minsize=5, maxsize=200):
    """Locate connected regions ("particles") in every frame of a segmented stack.

    Each frame ``frames[i, :, :]`` is labeled with skimage's ``label``
    (background = 0). Labeling connects neighbouring pixels of the same non-zero
    value. ``regionprops_table`` then records the centroid and area of each
    region.

    Parameters
    ----------
    frames : 3-D array indexed as [frame, row, col]; non-zero pixels are foreground.
    minsize, maxsize : inclusive bounds on region area (pixels) for a region to be kept.

    Returns
    -------
    pandas DataFrame with columns ``centroid-0`` (row), ``centroid-1`` (column),
    ``area`` and ``frame``. It has one row per kept region and a fresh
    0..n-1 index. The previous per-frame indices repeated across frames.
    Progress (timestamp, frame index) is printed per frame.
    """
    per_frame = []
    for i in range(frames.shape[0]):
        label_image = label(frames[i, :, :], background=0)
        props = regionprops_table(label_image, properties=("centroid", "area"))
        framedata = pd.DataFrame(props)
        framedata["frame"] = i
        per_frame.append(framedata)
        print(datetime.datetime.now(), i)

    if not per_frame:
        # No frames: return an empty table with the usual columns instead of failing.
        return pd.DataFrame(columns=["centroid-0", "centroid-1", "area", "frame"])

    # DataFrame.append was removed in pandas 2.0 (and re-copied the table on every
    # call); concatenate all frames once instead.
    features = pd.concat(per_frame, ignore_index=True)
    keep = features["area"].between(minsize, maxsize)  # inclusive on both ends
    return features[keep].reset_index(drop=True)

In [None]:
# Read the hyperspectral image stack to be segmented.
# Inputs used for earlier runs of this notebook:
#   '2-CrXAS-movie@577-5eV.tif'
#   '3-CrXAS-movie@577-5eV.tif'
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
IMAGE_PATH = '/Users/apple/Sync/Research/NSLS Experiments 7-16-23/20230716_Ni5Cr/Oxidation XAS mapping/XAS_30um_2CA_1AN_570-584eV_0.2eV_step_Cr2p_oxidation2_region2.tif'
image = imread(IMAGE_PATH)

In [None]:
# Level and segment every frame of the stack. Each frame is cut into 128x128
# tiles; after plane leveling, a pixel is foreground when it exceeds its tile's
# median by more than PROMINENCE.
TILE_SHAPE = (128, 128)
PROMINENCE = 200
# Take the frame size from the data rather than hard-coding (1024, 1024), so a
# stack of another size is reassembled correctly (each dimension must still be a
# multiple of the tile size).
frame_shape = tuple(image.shape[1:])
output = batch_level_segment(image, TILE_SHAPE, frame_shape, PROMINENCE)
segmented = output.astype(np.uint8)

# tifffile's imsave is a deprecated alias of imwrite; call imwrite directly.
imwrite("cr5_ox2_200.tif", segmented)
# Output names used for earlier datasets:
# imwrite("mov3_median_8000.tif", segmented)
# imwrite("mov2_median_8000.tif", segmented)

In [None]:
# Apply the mask only after leveling and segmenting, so the plane fit is not
# distorted by the sharp edges of the masked-out area.
# NOTE(review): this reads "mov2_median_8000.tif", but the segmentation cell in
# this notebook currently writes "cr5_ox2_200.tif" (the mov2 save is commented
# out). On a fresh run this masks a file left over from an earlier run, or fails
# if it is absent. Confirm the intended input.
masked = mask_subtract(
    data="mov2_median_8000.tif",
    mask="2nd_oxidation_mask.tif",
    outfile="mov2_masked_median_8000.tif",
)

In [None]:
# Find the particles in every frame of the segmented stack and save them to a .csv file.
# Inputs used for earlier runs:
#   total = imread("mov2_masked_median_8000.tif")
#   total = imread("/Users/apple/vscode/Research/Oxidation 3 results/median_8000_masked.tif")
# NOTE(review): this reads the unmasked segmentation; the masked output of the
# masking cell ("mov2_masked_median_8000.tif") is not used here.
SEGMENTED_PATH = "cr5_ox2_200.tif"
total = imread(SEGMENTED_PATH)
# Earlier runs used minsize=5 for the CrOx particles and 15 for the dark regions;
# this run uses 4.
features = findparticles_3d_img(total, minsize=4, maxsize=200)
features.to_csv("cr5_ox2_200_minsize_4.csv")