In [None]:
#import packages

import numpy as np
import napari
from skimage.measure import regionprops
from scipy.ndimage import zoom
import operator
from pathlib import Path
%gui qt


In [None]:
# Load the label volume and the membrane / nuclei intensity volumes.
# Each .npz archive is read through its default 'arr_0' key.
labs, mem_im, nuc_im = (
    np.load(path)['arr_0'] for path in ('labs.npz', 'mem_im.npz', 'nuc_im.npz')
)

# Region properties (label, centroid, coords, ...) for every labelled cell.
props = regionprops(labs)

# Mean centroid over all original cells, per axis (index 0 -> z, 1 -> y, 2 -> x).
all_z = [region['centroid'][0] for region in props]
all_y = [region['centroid'][1] for region in props]
all_x = [region['centroid'][2] for region in props]
average_centroid = (np.mean(all_z), np.mean(all_y), np.mean(all_x))

# Per-cell record for the original image, keyed by label:
# the cell's centroid and the voxel coordinates it occupies.
propDict = {
    'original_im': {
        region['label']: {'centroid': region['centroid'], 'coords': region['coords']}
        for region in props
    }
}

#make a new output folder
# NOTE(review): nothing in this cell writes into 'output' yet (the save step is empty) —
# TODO save the expanded stacks here, or drop this line.
Path('output').mkdir(parents=True, exist_ok=True)

# Padding (voxels) added to every axis of the original volume so the expanded cells have
# room to move outward. Voxels that still land outside the padded volume are dropped.
PAD = 200
new_shape = tuple(dim + PAD for dim in labs.shape)

# Zoom factors 1.05, 1.10, ..., 2.50 (30 steps). linspace fixes the number of steps exactly;
# np.arange with a float step can gain or lose an element through rounding.
scaling = np.linspace(1.05, 2.50, 30)

# One frame per zoom factor for the membrane and nuclei channels: shape (n_scales, *new_shape).
# NOTE(review): each stack is n_scales * prod(new_shape) * 2 bytes — can be many GB for large
# inputs. Values from mem_im / nuc_im are cast to uint16 on assignment — confirm source dtype.
whole_im = np.zeros((len(scaling),) + new_shape, dtype=np.uint16)
whole_nuc = np.zeros((len(scaling),) + new_shape, dtype=np.uint16)

# Geometric centre of the padded volume, and its exclusive per-axis upper bound.
image_center = np.array(new_shape) / 2
volume_bounds = np.array(new_shape)

for num, scale in enumerate(scaling):
    # Zoom the label image with nearest-neighbour interpolation (order=0).
    zoomed_array = zoom(labs, (scale,) * labs.ndim, order=0)
    zoomed_props = regionprops(zoomed_array)

    # Record the centroid of every cell in the zoomed label image.
    propDict['zoomed_im'] = {
        prop['label']: {'centroid': prop['centroid']} for prop in zoomed_props
    }

    # Mean centroid of all zoomed cells. Subtracting (mean - image_center) re-centres the
    # expanded arrangement in the padded volume. This term is the same for every cell.
    zoomed_centroids = np.array(
        [entry['centroid'] for entry in propDict['zoomed_im'].values()], dtype=float
    ).reshape(-1, labs.ndim)
    zoomed_average_centroid = zoomed_centroids.mean(axis=0)
    centroid_delta = zoomed_average_centroid - image_center

    # Write straight into this scale's frame of the output stacks (basic-index views).
    new_image = whole_im[num]
    new_image_nuc = whole_nuc[num]

    for cell, original in propDict['original_im'].items():
        # Each cell keeps its original voxels and is translated rigidly so that its centroid
        # follows the zoomed centroid. Rounding the offset once per cell, instead of truncating
        # every voxel coordinate, preserves the cell's shape exactly.
        delta = (np.asarray(propDict['zoomed_im'][cell]['centroid'])
                 - np.asarray(original['centroid']))
        offset = np.rint(delta - centroid_delta).astype(np.intp)

        coords = original['coords']
        shifted = coords + offset

        # Explicit bounds check: a negative index would otherwise wrap to the opposite side
        # of the volume instead of being dropped.
        inside = np.all((shifted >= 0) & (shifted < volume_bounds), axis=1)
        src = tuple(coords[inside].T)
        dst = tuple(shifted[inside].T)

        # Later cells overwrite earlier ones where they overlap.
        new_image[dst] = mem_im[src]
        new_image_nuc[dst] = nuc_im[src]

#look at the expanded series (membrane in green, nuclei in magenta)
viewer = napari.Viewer()
viewer.add_image(whole_im, colormap='green', blending='additive')
viewer.add_image(whole_nuc, colormap='magenta', blending='additive')
