<a href="https://colab.research.google.com/github/dtabuena/Workshop/blob/main/Image/NeuN_Volumes_Cellpose.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib
import tifffile
import os
import numpy as np
import matplotlib.pyplot as plt
import skimage as ski
import pandas as pd
from tqdm import tqdm
import scipy as sci
from cellpose import models
import statsmodels.api as sm
from statsmodels.formula.api import ols

In [None]:
"""
IMAGE MANIPULATION
"""
def dimension_to_front(data,dim_x):
    """Move axis ``dim_x`` of ``data`` to position 0 and cast to float16.

    Parameters
    ----------
    data : numpy.ndarray
        Array to reorder.
    dim_x : int
        Axis to bring to the front. Negative values count from the last
        axis (``-1`` is the last axis).

    Returns
    -------
    data_t : numpy.ndarray
        Transposed array, cast to ``float16``.
    transpose : list
        Axis order that was applied to ``data``.
    invs_transpose : list
        Axis order that undoes ``transpose``; ``data_t.transpose(invs_transpose)``
        has the original axis layout.

    Raises
    ------
    IndexError
        If ``dim_x`` is not a valid axis of ``data``.
    """
    n_dims = len(data.shape)
    if not -n_dims <= dim_x < n_dims:
        raise IndexError(
            'axis %d is out of range for an array with %d dimensions' % (dim_x, n_dims))
    # A negative axis would otherwise make the slices below overlap and
    # produce a transpose list of the wrong length.
    dim_x = dim_x % n_dims

    dim_list = list(np.arange(n_dims))
    transpose = [dim_list[dim_x]] + dim_list[:dim_x] + dim_list[dim_x+1:]
    # NOTE: float16 has a maximum of 65504 and ~3 significant decimal digits,
    # so un-normalised intensities above that overflow to inf.
    data_t = data.transpose(transpose).astype('float16')
    invs_transpose = list(np.argsort(transpose))
    return data_t, transpose, invs_transpose

def med_filt_stack(img,d):
    """Median-filter every 2-D plane of a 4-D stack, writing results in place.

    The stack is indexed as ``img[z, c, :, :]`` (axis 0 is z, axis 1 is the
    channel). Each plane is passed through ``ski.filters.median`` with a
    ``d`` x ``d`` all-ones footprint and written back into ``img``.

    Parameters
    ----------
    img : numpy.ndarray
        Stack to filter; modified in place.
    d : int
        Side length of the square median window.

    Returns
    -------
    numpy.ndarray
        The same ``img`` object that was passed in.
    """
    # Fixed 'zcyx' layout: z is axis 0, channel is axis 1.
    z_axis, c_axis = 0, 1
    window = np.ones([d,d])
    n_planes = img.shape[z_axis]
    n_channels = img.shape[c_axis]
    for channel in range(n_channels):
        for plane in range(n_planes):
            filtered = ski.filters.median(img[plane,channel,:,:],footprint=window)
            img[plane,channel,:,:] = filtered
    return img

def tiff_to_RGB_stack(tiff_data,rbg_channel_order = (1,0,0)):
    """Build a uint8 channels-last stack from a ``[z, c, y, x]`` array.

    Channels are picked from axis 1 in the order given by
    ``rbg_channel_order``, the second output channel is forced to zero, the
    data are scaled, clamped to ``[0, 255]`` and cast to ``uint8``.

    Parameters
    ----------
    tiff_data : numpy.ndarray
        4-D array indexed ``[z, c, y, x]``. It is not modified.
    rbg_channel_order : sequence of int
        Source channel index for each of the three output channels.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array indexed ``[z, y, x, channel]``.
    """
    # Indexing with a list makes a copy, so zeroing a channel below never
    # touches the caller's array.
    picked = tiff_data[:,list(rbg_channel_order),:,:]
    picked[:,1,:,:]=0
    # NOTE(review): the scale factor is 225 while the clamp is 255 -- this
    # looks like a typo for 255 but is kept to preserve existing output.
    rgb_stack = picked.transpose([0,2,3,1])*225
    rgb_stack[rgb_stack<0]=0
    rgb_stack[rgb_stack>255]=255
    rgb_stack=rgb_stack.astype('uint8')
    return rgb_stack

def preprocess_image(tiff_data):
    """Smooth, median-filter and normalise a stack before segmentation.

    Steps, in order:
      1. Gaussian blur (``sigma=[1,1,1]``, reflect padding) with axis 1
         treated as the channel axis.
      2. 5 x 5 median filter on every plane via ``med_filt_stack``.
      3. Per-channel rescaling via ``norm_channels`` so the 1st percentile
         maps to 0 and the 99th to 1, with values clipped to ``[0, 1]``.

    Parameters
    ----------
    tiff_data : numpy.ndarray
        Stack with the channel on axis 1.

    Returns
    -------
    numpy.ndarray
        The processed stack as returned by ``norm_channels``.
    """
    smoothed = ski.filters.gaussian(
        tiff_data,
        sigma=[1,1,1],
        mode='reflect',
        truncate=4,
        channel_axis=1,
    )
    despeckled = med_filt_stack(smoothed,5)
    return norm_channels(despeckled,'zcxy',clip=True,pct=[99,1])

def split_proj(tiff_data):
    """Plot one maximum-intensity projection per channel, side by side.

    Parameters
    ----------
    tiff_data : numpy.ndarray
        4-D array whose last axis (axis 3) is the channel; the projection is
        the maximum over axis 0. Values are displayed on a fixed 0-255 scale.

    Notes
    -----
    Prints ``tiff_data.shape`` and draws into a new matplotlib figure;
    nothing is returned.
    """
    print(tiff_data.shape)
    num_c = tiff_data.shape[3]
    # squeeze=False always returns a 2-D array of Axes; without it a
    # single-channel stack yields a bare Axes object that cannot be indexed.
    fig,ax=plt.subplots(1,num_c,figsize=(num_c*1,1),dpi=300,squeeze=False)
    colors = ['Reds','Greens','Blues','gray']
    for c, panel in enumerate(ax[0]):
        # Channels beyond the named colormaps fall back to gray.
        cmap = colors[c] if c < len(colors) else 'gray'
        panel.imshow(np.max(tiff_data[:,:,:,c],axis=0), cmap=cmap,vmin=0,vmax=255)
        panel.axis('off')
    plt.tight_layout()

def norm_channels(im_data,tiff_format,clip=False,pct=(99,1)):
    """Rescale each channel so two chosen percentiles map to 1 and 0.

    Parameters
    ----------
    im_data : numpy.ndarray
        Image data with a channel axis.
    tiff_format : str
        Axis-label string; the position of ``'c'`` in it selects the channel
        axis (e.g. ``'zcyx'`` -> axis 1).
    clip : bool
        If True, clamp the result to ``[0, 1]``.
    pct : sequence of two numbers
        ``(top, bottom)`` percentiles; ``top`` maps to 1 and ``bottom`` to 0.

    Returns
    -------
    numpy.ndarray
        Normalised data in the original axis order. The values come from
        ``dimension_to_front``, which casts to float16.

    Notes
    -----
    Prints the input shape and the detected channel axis. A channel whose two
    percentiles coincide has no usable range and is set to 0 rather than
    being divided by zero.
    """
    print(im_data.shape)
    color_dim = tiff_format.index('c')
    print('norm color dim =',color_dim)
    data_t, _, invs_transpose = dimension_to_front(im_data,color_dim)
    for c in range(data_t.shape[0]):
        c_data = data_t[c,:]
        (top,bot) = np.percentile(c_data.flatten(),pct)
        span = top-bot
        if span == 0:
            # Flat channel: avoid 0/0 -> NaN (or x/0 -> inf).
            data_t[c,:] = 0
            continue
        data_t[c,:] = (c_data-bot)/span
    im_data = data_t.transpose(invs_transpose)
    if clip:
        im_data[im_data>1]=1
        im_data[im_data<0]=0
    return im_data

In [None]:
"""
Model Building
"""
def fit_cell_model(img_rgb, cell_channel=1, nuclear_channel=0,channel_axis=3,cell_diam=40):
    """Segment cells with the Cellpose ``cyto3`` model.

    Parameters
    ----------
    img_rgb : numpy.ndarray
        Image stack passed straight to ``Cellpose.eval``.
    cell_channel, nuclear_channel : int
        Passed to Cellpose as ``channels=[cell_channel, nuclear_channel]``.
    channel_axis : int
        Axis of ``img_rgb`` holding the channels.
    cell_diam : number
        Cell diameter passed to Cellpose as ``diameter``. Defaults to 40, the
        value that used to be hard-coded here.

    Returns
    -------
    dict
        Keys ``'cell_masks'``, ``'flows'``, ``'styles'``, ``'diams'`` (the
        four values returned by ``eval``) and ``'cell_model'`` (the model
        object itself).

    Notes
    -----
    The model is created with ``gpu=True``. Evaluation uses ``do_3D=False``
    with ``stitch_threshold=0.02``; see the Cellpose documentation for how
    2-D masks are stitched across planes (TODO confirm against the installed
    Cellpose version).
    """
    cell_model = models.Cellpose(model_type='cyto3',gpu=True)
    (cell_masks, flows, styles, diams) = cell_model.eval(
        img_rgb, channels=[cell_channel,nuclear_channel],
        channel_axis = channel_axis, diameter=cell_diam,
        do_3D=False, stitch_threshold=0.02)

    results_dict = {'cell_masks':cell_masks,
                'flows':flows,
                'styles':styles,
                'diams':diams,
                'cell_model':cell_model,}
    return results_dict

def fit_nuclear_model(img_rgb,nuclear_channel=3,channel_axis=3,nuc_diam = None):
    """Segment nuclei with the Cellpose ``nuclei`` model and return the masks.

    Parameters
    ----------
    img_rgb : numpy.ndarray
        Image stack passed straight to ``Cellpose.eval``.
    nuclear_channel : int
        Passed to Cellpose as ``channels=[nuclear_channel, 0]``.
    channel_axis : int
        Axis of ``img_rgb`` holding the channels.
    nuc_diam : number or None
        Passed to Cellpose as ``diameter``.

    Returns
    -------
    The first of the four values returned by ``eval`` (the masks); the other
    three are discarded.
    """
    segmenter = models.Cellpose(model_type='nuclei',gpu=True)
    eval_options = dict(
        channels=[nuclear_channel, 0],
        channel_axis=channel_axis,
        diameter=nuc_diam,
        do_3D=False,
        stitch_threshold=0.02,
    )
    nuclear_masks, _flows, _styles, _diams = segmenter.eval(img_rgb, **eval_options)
    return nuclear_masks

def save_masks(cell_masks,nuclear_masks,cell_masks_og,curr_tif):
    """Save the three mask arrays as ``.npy`` files in the working directory.

    Files are named after the base name of ``curr_tif`` without its
    extension: ``<name>_cell_masks.npy``, ``<name>_nuclear_masks.npy`` and
    ``<name>_cell_masks_og.npy`` (``np.save`` appends the ``.npy`` suffix).

    Parameters
    ----------
    cell_masks, nuclear_masks, cell_masks_og : array-like
        Arrays to write.
    curr_tif : str
        Path of the source image; only its base name is used, so the files
        land in the current working directory, not next to the image.
    """
    image_name = os.path.splitext(os.path.basename(curr_tif))[0]
    np.save( image_name+'_cell_masks',cell_masks,allow_pickle=True)
    np.save( image_name+'_nuclear_masks',nuclear_masks,allow_pickle=True)
    np.save( image_name+'_cell_masks_og',cell_masks_og,allow_pickle=True)

def load_masks(curr_tif):
    """Load the three mask arrays written for an image.

    Reads ``<name>_cell_masks.npy``, ``<name>_nuclear_masks.npy`` and
    ``<name>_cell_masks_og.npy`` from the current working directory, where
    ``<name>`` is the base name of ``curr_tif`` without its extension.

    Parameters
    ----------
    curr_tif : str
        Path of the source image; only its base name is used.

    Returns
    -------
    tuple
        ``(cell_masks, nuclear_masks, cell_masks_og)``.
    """
    stem, _ext = os.path.splitext(os.path.basename(curr_tif))
    suffixes = ('_cell_masks.npy', '_nuclear_masks.npy', '_cell_masks_og.npy')
    return tuple(np.load(stem + suffix) for suffix in suffixes)

In [None]:
"""
Analyze Models
"""
def find_child_nucleii(cell_masks,nuclear_masks,cell_df):
    """Record which nuclear labels fall inside each cell mask.

    For every label in ``cell_df.index`` the positive values of
    ``nuclear_masks`` at the positions where ``cell_masks`` equals that label
    are collected.

    Parameters
    ----------
    cell_masks : numpy.ndarray
        Label image of cells; compared against the values in ``cell_df.index``.
    nuclear_masks : numpy.ndarray
        Label image of nuclei with the same shape as ``cell_masks``; values
        ``<= 0`` are ignored.
    cell_df : pandas.DataFrame
        One row per cell, indexed by cell label. Modified in place.

    Returns
    -------
    pandas.DataFrame
        ``cell_df`` with two added/overwritten columns: ``'nucleii'`` (object
        column holding the sorted unique nuclear labels per cell) and
        ``'nuc_count'`` (float column holding how many there are).
    """
    n_cells = len(cell_df.index)
    # Build the columns as whole arrays and assign once. Storing an array in
    # a single cell with ``.at`` depends on how pandas treats sequence values.
    nucleii = np.empty(n_cells, dtype=object)
    nuc_count = np.empty(n_cells, dtype=float)
    for row, cell in enumerate(cell_df.index):
        overlap = nuclear_masks[cell_masks==cell]
        # Boolean indexing drops background voxels without a Python-level loop.
        labels_in_cell = np.unique(overlap[overlap>0])
        nucleii[row] = labels_in_cell
        nuc_count[row] = labels_in_cell.size
    cell_df['nucleii'] = nucleii
    cell_df['nuc_count'] = nuc_count
    return cell_df


def analyze_ferrets(cell_df,cell_masks,voxel_dim_um):
    """Add per-cell axis-aligned extent measurements to ``cell_df``.

    For every label in ``cell_df.index`` the extents of ``cell_masks == label``
    are measured with ``nd_feret``, scaled by ``voxel_dim_um[[2,0,1]]``.

    Parameters
    ----------
    cell_df : pandas.DataFrame
        One row per cell, indexed by cell label; must already contain a
        ``'sizes_um3'`` column. Modified in place.
    cell_masks : numpy.ndarray
        Label image of cells.
    voxel_dim_um : numpy.ndarray
        Voxel dimensions; indexed with ``[[2,0,1]]`` to get the scalars
        handed to ``nd_feret`` as ``zxy_scalars``.

    Returns
    -------
    pandas.DataFrame
        ``cell_df`` with columns ``'feret_d'`` (the extent vector),
        ``'feret_z'``, ``'feret_x'``, ``'feret_y'``, ``'min_fer'`` (smallest
        extent), ``'box_vol'`` (product of the extents) and ``'pct_fill'``
        (``sizes_um3 / box_vol``).
    """
    # Object column so each cell can hold a whole extent vector.
    cell_df['feret_d'] = np.nan
    cell_df['feret_d'] = cell_df['feret_d'].astype(object)
    for label in cell_df.index:
        this_cell = cell_masks==label
        cell_df.at[label,'feret_d'] = nd_feret(this_cell,zxy_scalars=voxel_dim_um[[2,0,1]])

    extents = np.stack(cell_df['feret_d'].to_numpy(),axis=0)
    for column, name in enumerate(('feret_z','feret_x','feret_y')):
        cell_df[name] = extents[:,column]
    cell_df['min_fer'] = np.min(extents,axis=1)
    cell_df['box_vol'] = np.prod(extents,axis=1)
    cell_df['pct_fill'] = cell_df['sizes_um3']/cell_df['box_vol']

    return cell_df

def nd_feret(mask,zxy_scalars=None):
    """Return the axis-aligned extent of the True region of ``mask`` per axis.

    The extent along an axis is ``max_index - min_index + 1`` over all True
    elements, multiplied by that axis's scalar.

    Parameters
    ----------
    mask : numpy.ndarray
        Boolean (or truthy) array.
    zxy_scalars : sequence of numbers, optional
        One scale factor per axis of ``mask``, in axis order. Defaults to 1
        for every axis, giving extents in index units.

    Returns
    -------
    numpy.ndarray
        One extent per axis. An empty mask gives an extent of 0 on every axis.
    """
    if zxy_scalars is None:
        # The default must be bound to the name the loop below reads.
        zxy_scalars = np.ones_like(mask.shape)
    n_coords=np.where(mask)
    feret_diams = list()
    for s, coord in zip(zxy_scalars,n_coords):
        # No True elements -> no extent (np.max/np.min would raise on empty).
        diam = np.max(coord)-np.min(coord)+1 if coord.size else 0
        feret_diams.append(diam*s)
    return np.asarray(feret_diams)

In [None]:
"""
NeuN vol divided by Num Nuclei
"""

# for curr_tif in tqdm(tiff_list):

def calc_neun_layer_mask(tiff_data_raw,thresh=0.2,neun_channel=1):
    """Mask a stack down to its largest connected bright region.

    The selected channel is max-projected over axis 0, z-scored, thresholded
    at ``thresh`` and split into connected components with
    ``ski.measure.label``; the component with the most pixels becomes the mask.

    Parameters
    ----------
    tiff_data_raw : numpy.ndarray
        4-D stack indexed ``[z, c, y, x]``.
    thresh : float
        Z-score threshold applied to the max projection.
    neun_channel : int
        Index on axis 1 of the channel to use. Defaults to 1, the value that
        used to be hard-coded here.

    Returns
    -------
    neun_layer_mask : numpy.ndarray
        2-D boolean mask over the last two axes.
    tiff_data_masked : numpy.ndarray
        ``tiff_data_raw`` multiplied by the mask (broadcast over axes 0 and 1).

    Raises
    ------
    ValueError
        If no pixel exceeds the threshold, so there is no component to pick.
    """
    max_neun = np.max(tiff_data_raw[:,neun_channel,:,:],axis=0)
    neun_z = (max_neun - np.mean(max_neun)) / np.std(max_neun)
    chunks = ski.measure.label(neun_z>thresh)

    # One exact count per label. A histogram whose bin edges are the labels
    # themselves folds the two highest labels into the last (closed) bin.
    sizes = np.bincount(chunks.ravel())
    component_sizes = sizes[1:]  # drop the first entry, label 0 (background)
    if not component_sizes.any():
        raise ValueError('no region above thresh=%s; cannot build a layer mask' % thresh)
    # argmax returns the first maximum, so ties resolve to the lowest label.
    biggest = 1 + int(np.argmax(component_sizes))

    neun_layer_mask = chunks==biggest
    # The (y, x) mask broadcasts over z and channel; no 4-D copy is needed.
    tiff_data_masked = tiff_data_raw*neun_layer_mask
    return neun_layer_mask, tiff_data_masked

def mean_volumes(nuclear_masks,tiff_data_masked,neun_layer_mask,
                 voxel_dim_um=(0.3977476,0.3977476,0.5),neun_thresh=0.5,neun_channel=1):
    """Estimate the mean positive-signal volume per nucleus.

    The selected channel of ``tiff_data_masked`` is z-scored over the whole
    stack; voxels above ``neun_thresh`` count as positive. Their total volume
    is divided by the number of nuclei, taken as the largest label in
    ``nuclear_masks``.

    Parameters
    ----------
    nuclear_masks : numpy.ndarray
        Label image of nuclei; its maximum is used as the nucleus count.
    tiff_data_masked : numpy.ndarray
        4-D stack indexed ``[z, c, y, x]``.
    neun_layer_mask :
        Unused; kept so existing calls keep working.
    voxel_dim_um : sequence of three numbers
        Voxel edge lengths; their product is the volume of one voxel.
    neun_thresh : float
        Z-score above which a voxel counts as positive.
    neun_channel : int
        Index on axis 1 of the channel to measure.

    Returns
    -------
    tuple
        ``(mean_vol_um3, num_nuclei, neun_vox_vol_um3)``.

    Notes
    -----
    NOTE(review): the nucleus count is not restricted to ``neun_layer_mask``
    and assumes labels run 1..N without gaps -- confirm this is intended.
    """
    neun = tiff_data_masked[:,neun_channel,:,:]
    neun_z = (neun - np.mean(neun)) / np.std(neun)
    neun_pos = neun_z > neun_thresh

    num_nuclei = np.max(nuclear_masks)

    neun_vox_vol = np.sum(neun_pos)
    voxel_vol_um3 = np.prod(np.asarray(voxel_dim_um))
    neun_vox_vol_um3 = neun_vox_vol * voxel_vol_um3
    mean_vol_um3 = neun_vox_vol_um3/num_nuclei
    return (mean_vol_um3, num_nuclei,neun_vox_vol_um3)


def np_to_rgb(raw,channel_bal=(1,1,1)):
    """Min-max normalise an array and replicate it into three weighted channels.

    Parameters
    ----------
    raw : numpy.ndarray
        Input array of any shape.
    channel_bal : sequence of three numbers
        Multiplier applied to the normalised data for each output channel.

    Returns
    -------
    numpy.ndarray
        Array of shape ``raw.shape + (3,)``; before weighting the values span
        ``[0, 1]``. A constant input has no range and yields all zeros.
    """
    lowest = np.min(raw)
    span = np.max(raw) - lowest
    if span == 0:
        # Constant image: divide by 1 instead of 0 to avoid 0/0 -> NaN.
        span = 1
    norm = (raw - lowest)/span
    rgb = np.stack([norm*weight for weight in channel_bal[:3]],axis=-1)
    return rgb