### Credits

This is based on the method used in this paper:

Hernando M. Vergara, Constantin Pape, Kimberly I. Meechan, Valentyna Zinchenko, Christel Genoud, Adrian A. Wanner, Kevin Nzumbi Mutemi, Benjamin Titze, Rachel M. Templin, Paola Y. Bertucci, Oleg Simakov, Wiebke Dürichen, Pedro Machado, Emily L. Savage, Lothar Schermelleh, Yannick Schwab, Rainer W. Friedrich, Anna Kreshuk, Christian Tischer, Detlev Arendt.
Whole-body integration of gene expression and single-cell morphology.
Cell, Volume 184, Issue 18, 2021, Pages 4819-4837.e22, ISSN 0092-8674, https://doi.org/10.1016/j.cell.2021.07.017.

Full text links:
 - [Cell](https://www.sciencedirect.com/science/article/pii/S009286742100876X)
 - [bioRxiv](https://www.biorxiv.org/content/10.1101/2020.02.26.961037v1)



In [None]:
import time

import tifffile
import numpy as np
from matplotlib import pyplot as plt
import mrcfile

import napari

In [None]:
viewer = napari.Viewer()

In [None]:
'''
Open the IMOD-aligned MRC file using the mrcfile package.

The stack is opened as a memory-mapped file so sections can be loaded into
memory as needed.

x : mrcfile memory-map object
    the image stack array is available as x.data

'''
x = mrcfile.mmap('stack_ali.mrc')

In [None]:
'''
napari will only read a section when the z-slider is moved
'''

viewer.add_image(x.data)

In [None]:
def correct_z(a, lref, uref):
    '''
    Adjust the histogram of the section.
    Based on https://github.com/mobie/platybrowser-project/tree/main/misc/intensity_correction
    https://www.biorxiv.org/content/10.1101/2020.02.26.961037v1

    a : array
        the image section to be corrected
    lref : float
        the lower intensity to correct to
    uref : float
        the upper intensity to correct to

    returns : array
        the image with the corrected histogram
    '''
    agood = a[a > 0]  # ignore empty (zero) pixels when measuring the histogram
    u = np.percentile(agood, 95)
    low = np.percentile(agood, 5)
    # linear map that sends the section's 5th/95th percentiles to lref/uref
    c = (uref - lref)*(a - u)/(u - low) + uref
    
    if a.dtype.itemsize == 1:
        c = np.where(c < 0, 0, c)
        c = np.where(c > 255, 255, c)
        c = c.astype(np.uint8)
    
    return c
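
As a quick sanity check of the linear rescaling above, the sketch below applies the same formula to a synthetic section and confirms that its 5th and 95th percentiles land on the reference values. The function name `rescale_to_reference`, the synthetic data, and the reference values 50 and 200 are assumptions made for illustration; the 8-bit clipping step is left out.

```python
import numpy as np

def rescale_to_reference(a, lref, uref):
    # Same linear map as correct_z: section percentiles (5th, 95th) -> (lref, uref).
    good = a[a > 0]                      # ignore empty (zero) pixels
    u = np.percentile(good, 95)
    low = np.percentile(good, 5)
    return (uref - lref) * (a - u) / (u - low) + uref

rng = np.random.default_rng(0)
# synthetic section; clipping to [1, 255] keeps every pixel above zero
section = rng.normal(120.0, 15.0, size=(256, 256)).clip(1, 255)
corrected = rescale_to_reference(section, lref=50.0, uref=200.0)

p5, p95 = np.percentile(corrected, [5, 95])
print(p5, p95)
```

Because the map is linear and increasing, the printed percentiles should match the reference values up to floating-point rounding.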


In [None]:
'''
Get the upper and lower reference intensities from the median
of a few randomly selected sections

rs : array [int]
    randomly selected section indices
_d : array
    random sections, cropped to remove artifacts at the edges
xref : array
    median reference image
uref : number
    the 95th percentile intensity
lref : number
    the 5th percentile intensity
'''

rs = np.random.randint(0, len(x.data), 20)

_d = x.data[rs, 1000:-1000, 1000:-1000]

# per-pixel median over the selected sections; masking zeros before this
# step would flatten _d to 1D and reduce the median to a single number
xref = np.median(_d, axis=0)
print(xref.shape)

# exclude empty (zero) pixels from the reference percentiles
xgood = xref[xref > 0]
uref = np.percentile(xgood, 95)
lref = np.percentile(xgood, 5)

print(uref, lref)
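
The sketch below shows, on a small synthetic stack, one way to take reference percentiles from a median image while ignoring zero-valued pixels. The array names, sizes, and the empty border strip are assumptions for illustration. It also shows why a boolean mask has to be applied after the median: masking a 3D stack returns a flat 1D array, so the per-pixel median over sections is lost.

```python
import numpy as np

rng = np.random.default_rng(1)
stack = rng.integers(1, 256, size=(20, 64, 64)).astype(np.uint8)  # synthetic sections
stack[:, :8, :] = 0                       # pretend a border strip is empty

flat = stack[stack > 0]                   # a boolean mask flattens the stack
print(flat.ndim)                          # 1

ref = np.median(stack, axis=0)            # per-pixel median over sections
valid = ref[ref > 0]                      # drop empty pixels only now
lref_demo, uref_demo = np.percentile(valid, [5, 95])
print(ref.shape, lref_demo, uref_demo)
```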


In [None]:
'''
test the correction on the random sections and view it
'''

rlist = list()
for r in sorted(rs):
    rlist.append(correct_z(x.data[r], lref, uref))

viewer.add_image(np.stack(rlist), colormap="gray")
    

### Using memmap from tifffile

```python
>>> from tifffile import memmap
>>> memmap_image = memmap(
...     'temp.tif',
...     shape=(256, 256, 3),
...     dtype='float32',
...     photometric='rgb'
... )
>>> type(memmap_image)
<class 'numpy.memmap'>
>>> memmap_image[255, 255, 1] = 1.0
>>> memmap_image.flush()
>>> del memmap_image
```

In [None]:
'''
Use the tifffile.memmap function so the output can be written section-by-section. 
'''

t1 = time.time()
mmtif = tifffile.memmap('stack_ali_zcor.tif',
                        shape=x.data.shape,
                        dtype=np.float32)
                        
for i, _x in enumerate(x.data):
    mmtif[i] = correct_z(_x, lref, uref)
    if i % 100 == 0:
        print(i,time.time() - t1)
        t1 = time.time()
    

In [None]:
mmtif.flush()  # write pending data to disk before releasing the map
del mmtif
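
The write, flush, release, and reopen sequence used here can be tried on a tiny array first. In the sketch below `np.memmap` stands in for `tifffile.memmap` so that it needs only NumPy; the file name, shape, and fill values are assumptions for illustration, and the output is a raw binary file rather than a TIFF.

```python
import os
import tempfile
import numpy as np

shape = (5, 32, 32)                                   # hypothetical small stack
path = os.path.join(tempfile.mkdtemp(), "demo_out.dat")

out = np.memmap(path, dtype=np.float32, mode="w+", shape=shape)
for i in range(shape[0]):
    # stand-in for one corrected section
    out[i] = np.full(shape[1:], i, dtype=np.float32)
out.flush()                                           # push pending writes to disk
del out                                               # release the map before reopening

back = np.memmap(path, dtype=np.float32, mode="r", shape=shape)
print(back[3].mean())                                 # 3.0
```

Only one section is held in memory at a time while writing, which is the reason for using a memory-mapped output with large stacks.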

In [None]:
xc = tifffile.memmap('stack_ali_zcor.tif')

In [None]:
viewer.add_image(xc)

In [None]:
'''
Bin the corrected stack by 4 in x and y (z is left unbinned)

'''

t1 = time.time()
nz = xc.shape[0]
ny = xc.shape[1] // 4
nx = xc.shape[2] // 4
mmtif = tifffile.memmap('stack_ali_zcor_bin4.tif',
                        shape=(nz, ny, nx),
                        dtype=np.float32)

for i, _x in enumerate(xc):
    # crop to a multiple of 4, then average each 4x4 block
    k = _x[:ny*4, :nx*4].reshape(ny, 4, nx, 4)
    kb = k.mean(axis=(1, 3))
    mmtif[i] = kb
    if i % 100 == 0:
        print(i,time.time() - t1)
        t1 = time.time()
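
The reshape-and-mean binning in the loop above can be checked on a small array where the block averages are easy to verify by hand. The helper name `bin2d` and the 6x6 test image are assumptions for illustration.

```python
import numpy as np

def bin2d(img, factor):
    # Crop so both dimensions are multiples of `factor`, then average each block.
    ny, nx = img.shape[0] // factor, img.shape[1] // factor
    cropped = img[:ny * factor, :nx * factor]
    return cropped.reshape(ny, factor, nx, factor).mean(axis=(1, 3))

img = np.arange(36, dtype=np.float32).reshape(6, 6)
binned = bin2d(img, 2)
print(binned.shape)   # (3, 3)
print(binned[0, 0])   # 3.5, the mean of 0, 1, 6, 7
```

Cropping before the reshape keeps the approach working when the image size is not an exact multiple of the bin factor.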

In [None]:
mmtif.flush()  # write pending data to disk before releasing the map
del mmtif

In [None]:
xb = tifffile.memmap('stack_ali_zcor_bin4.tif')
viewer.add_image(xb)

In [None]:
# move the y axis to the front so the slider steps through x-z planes
viewer.add_image(np.moveaxis(xb, 1, 0))
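
A minimal sketch of what `np.moveaxis` does to the axis order, using a hypothetical small (z, y, x) volume in place of the binned stack:

```python
import numpy as np

vol = np.zeros((4, 5, 6))            # hypothetical (z, y, x) volume
xz = np.moveaxis(vol, 1, 0)          # now (y, z, x): each slice is an x-z plane
print(xz.shape)                      # (5, 4, 6)
print(np.shares_memory(vol, xz))     # True: a view, no data copied
```

Since the result is a view, reslicing a memory-mapped stack this way does not load the whole volume up front.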