# Script for preprocessing of Tina's Calcium imaging data

### Dependencies: matplotlib, scikit-image, h5py, numpy, pyqtgraph (interactive plots of 3D stacks), roiDrawing / roiVisualization (custom Python modules in ./utils for drawing and visualizing ROIs)

### Overview of steps
1. Import tiff stack
2. Open using pyqtgraph (interactive browsing of data)
3. Simple movement correction
4. Extraction of time series from ROIs
5. Some visualizations of ROI time series

In [170]:
import matplotlib
matplotlib.use('Qt5Agg')

%gui qt

In [171]:
from skimage.io import imread, imshow
import h5py
from matplotlib import pyplot as plt
from matplotlib.widgets import Button
import numpy as np
import pyqtgraph as pq
from sys import path
from os.path import sep, exists
from os import mkdir, makedirs, getcwd

In [172]:
# import custom roi module
path.insert(1, getcwd() + sep + 'utils')

from roiDrawing import roi

from roiVisualization import illustrateRois

### 1. Import tiff stack

#### Choose and import tiff file (imaging data) and generate directory for storing analysis plots

In [173]:
# Data live one directory above the current working directory
parentDir = sep.join(getcwd().split(sep)[:-1])

# Raw tiff path relative to parentDir, in the form '<dye>/<date>/<experiment>.tif'.
# It is written with '/', so split it on '/' rather than os.sep (which is '\\' on Windows).
rawtiff = 'JF549-BAPTA_MPM/170920/170920_549_whitenoise3cw_00001.tif'
[dye, date, expt] = rawtiff.split('/')
plotDir = parentDir + sep + 'Plots'

In [174]:
# Parse the ScanImage header text stored at the start of the tiff file.
# Header lines look like "SI.<object>.<property> = <value>"; reading stops after
# the first line that does not contain 'SI'.
import re


def _si_value(headerLine):
    """Return the text to the right of the first '=' of a header line, stripped."""
    return headerLine.split('=', 1)[1].strip()


fpv = -1      # frames (planes) per volume; -1 = not found in header
numVols = -1  # number of volumes; -1 = not found in header
nCh = 1       # number of saved channels; assume one if the header does not say

with open(parentDir + sep + rawtiff, 'rb') as fh:
    for rawLine in fh:
        # latin-1 maps every byte to one character, so binary header bytes never raise
        line = rawLine.rstrip(b'\r\n').decode('latin-1')

        # extract version
        if 'VERSION_' in line:
            print(line)

        # get channel info; assumes channelSave lists the saved channel numbers
        # (e.g. "2" or "[1;2]"), so the channel count is the number of entries
        if 'hChannels' in line and 'channelSave' in line:
            print(line)
            nCh = len(re.findall(r'\d+', _si_value(line)))

        # get number of planes per z-stack (full integer, not just the last digit)
        if 'hStackManager' in line and 'numFramesPerVolume' in line:
            print(line)
            fpv = int(_si_value(line))

        if 'hFastZ' in line and 'numVolumes' in line:
            print(line)
            numVols = int(_si_value(line))

        if 'SI' not in line:
            break

b"SI.VERSION_MAJOR = '2016b'"
b"SI.VERSION_MINOR = '1'"
b'SI.hChannels.channelSave = 2'
b'SI.hFastZ.numVolumes = 1'


In [175]:
# Load the complete tiff stack into memory and report its dimensions
rawstack = imread(sep.join([parentDir, rawtiff]))
stackshape = rawstack.shape
print(stackshape)

(1794, 512, 512)


In [176]:
# Generate directory where to save plots: <plotDir>/<dye>/<date>/<experiment name without '.tif'>
saveDir = sep.join([plotDir, dye, date, expt[:-4]])
# exist_ok makes this a no-op when the folder exists and avoids a check-then-create race
makedirs(saveDir, exist_ok=True)

#### Adjust frame settings to match chosen file
These are parameters that cannot be retrieved from the file header.

In [177]:
caCh = 0  # index of the calcium-indicator channel among the saved channels

fps = 20  # check that with notes (frame rate is not read from the header)

numimgs = int(stackshape[0])

# Frames are stored as volume -> plane -> channel (matching the reshape into
# stack5d), so the channel count must be included when deriving counts.
if '_z_' not in rawtiff:
    # single-plane recording
    fpv = 1
    numVols = numimgs // (fpv * nCh)
elif fpv < 0:
    # frames per volume missing from the header: derive it from the volume count
    if numVols <= 0:
        raise ValueError('Neither frames per volume nor number of volumes found in header')
    fpv = numimgs // (numVols * nCh)


print("# frames: " + str(numimgs))
print("# volumes: " + str(numVols))
print("# frames per vol: " + str(fpv))

# frames: 1794
# volumes: 1794
# frames per vol: 1


#### Reshape tiff stack

In [178]:
# Split the flat frame axis into (volume, plane, channel): stack5d[vol, plane, ch, y, x]
if stackshape[0] % (fpv * nCh):
    raise ValueError('Stack of ' + str(stackshape[0]) + ' frames cannot be split into '
                     + str(fpv) + ' planes x ' + str(nCh) + ' channels')
stack5d = rawstack.reshape((stackshape[0] // (fpv * nCh), fpv, nCh, stackshape[1], stackshape[2]))
print('Full dimensional stack: stack5d')
print(stack5d.shape)

print('Stack reduced to one channel: stack4d')
# Index explicitly instead of np.squeeze so singleton time/space axes are never
# dropped; only the plane axis is removed for single-plane data.
stack4d = stack5d[:, :, caCh, :, :]
if fpv == 1:
    stack4d = stack4d[:, 0, :, :]

print(stack4d.shape)

Full dimensional stack: stack5d
(1794, 1, 1, 512, 512)
Stack reduced to one channel: stack4d
(1794, 512, 512)


### 2. Open with pyqtgraph

In [179]:
# Browse the single-channel stack interactively in pyqtgraph
dims = stack4d.shape
if fpv != 1:
    # volumetric data: stack the planes of each volume vertically -> (time, planes*y, x)
    displayStack = stack4d.reshape(dims[0], dims[1] * dims[2], dims[3])
else:
    displayStack = stack4d.reshape(dims[0], dims[1], dims[2])
pq.image(displayStack)

### Optional: Collapse volume using max projection

In [180]:
# For volumetric recordings, collapse each volume into one image by taking the
# maximum over planes (axis 1 of stack4d).
if fpv > 1:
    stackMP = stack4d.max(axis=1)
    print(stackMP.shape)

    # Visualise new stack in pyqt graph:
    pq.image(stackMP)

In [181]:
# Stack used for the rest of the analysis: the per-volume max projection for
# volumetric data, otherwise the selected calcium channel (stack4d). Using the
# raw stack here would interleave channels whenever more than one was saved.
if fpv > 1:
    tiffstack = stackMP
else:
    tiffstack = stack4d

numframes = tiffstack.shape[0]

### 3. Simple movement correction
#### Comparison to a reference frame (frames poorly correlated with it are discarded; no image registration is performed)

In [182]:
# Pick the reference frame: the last frame of tiffstack (for volumetric data,
# the max projection of the last volume scanned).
refframe = numframes - 1
print('reference frame: ' + str(refframe))

refimg = tiffstack[refframe]
plt.imshow(refimg, cmap='Greys_r', vmin=0)
plt.title('Reference frame')
plt.show()

reference frame: 1793


In [183]:
# Pearson correlation of every frame with the reference image,
# stored as a column vector with one row per frame
refVec = refimg.ravel()
framecorr = np.zeros((numframes, 1))
for idx in range(numframes):
    corrMat = np.corrcoef(tiffstack[idx].ravel(), refVec)
    framecorr[idx] = corrMat[0, 1]

In [184]:
# Set threshold (min. correlation coefficient) for selection of frames
corrTH = 0.7

# Percentage of frames above threshold (count_nonzero yields a plain int, unlike
# int(sum(...)) on the (numframes, 1) array, which NumPy deprecates)
pctAbove = round(100 * np.count_nonzero(framecorr > corrTH) / numframes, 2)

corrFig, ax = plt.subplots(1, 1, figsize=(10, 3))
ax.plot(framecorr)
ax.set_ylabel('correlation coeff.')
ax.set_xlabel('frame')
ax.set_title('Pearson correlation with reference frame (' + str(refframe) +
             '), threshold: ' + str(corrTH) +
             ' (' + str(pctAbove) + '%)')
# axhline's xmin/xmax are axes fractions (0-1); the defaults span the full width
ax.axhline(y=corrTH, color='r')
plt.show()
# Save from corrFig explicitly: plt.savefig acts on whatever figure is current,
# which after show() need not be this one.
corrFig.savefig(saveDir + sep + 'stackCorr_' + expt[:-4] + '.pdf', format='pdf')

In [185]:
# Select frames above threshold
keepMask = (framecorr > corrTH)[:, 0]  # framecorr is (numframes, 1) -> 1D frame mask
slct_numframes = int(np.count_nonzero(keepMask))
slct_frames = (framecorr > corrTH).astype('int')  # 0/1 per frame, shape (numframes, 1)
# Boolean indexing copies the kept frames in order; store them as float64
slct_tiffstack = tiffstack[keepMask].astype(np.float64)

In [186]:
# Browse the retained frames in pyqtgraph and report the size of the selection
slctViewer = pq.image(slct_tiffstack)
print(slct_tiffstack.shape)
plt.show()

(1527, 512, 512)


### 4. Extraction of time series from ROIs
#### Draw ROIs

In [187]:
# Summary image used as background for ROI drawing; any value other than
# 'mean' or 'median' falls back to the standard deviation image
imgtype = 'std'  # ['std', 'mean', 'median']

# Interactive figure to draw ROIs (uses roi.py)
fig, axs = plt.subplots(figsize=(10,6))
summaryFuncs = {
    'mean': lambda stk: stk.mean(axis=0),
    'median': lambda stk: np.median(stk, axis=0),
}
stdImage = lambda stk: stk.std(axis=0)
sample = summaryFuncs.get(imgtype, stdImage)(slct_tiffstack)

axs.imshow(sample, cmap='gray')
drawing = roi.RoiDrawing(axs, sample)

# Control buttons stacked in the lower-left corner; the Button objects are kept
# in module-level names so they stay referenced while the figure is open
ax_butt = plt.axes([0.1, 0.15, 0.1, 0.075])
wipe_butt = Button(ax_butt, 'Wipe')
wipe_butt.on_clicked(drawing.wipe)

ax_decr = plt.axes([.1, .25, .1, .075])
decr_button = Button(ax_decr, 'Prev ROI')
decr_button.on_clicked(drawing.focus_decr)

ax_incr = plt.axes([.1, .35, .1, .075])
incr_button = Button(ax_incr, 'Next ROI')
incr_button.on_clicked(drawing.focus_incr)

plt.show()

In [190]:
# Save the ROI drawing figure once drawing is finished (PDF named after the experiment)
roiFigPath = sep.join([saveDir, 'rois_' + expt[:-4] + '.pdf'])
fig.savefig(roiFigPath, format='pdf')

#### Save ROI data to file

In [193]:
# Generate dictionary with all ROI information. When fpv > 1 the image and time
# series come from the per-volume max projection; otherwise from the single-plane stack.
roiData = {
    'imgData': rawtiff,
    'img': sample,
    'numframes': slct_numframes,
    'slctframes': slct_frames,
    'fpv': fpv
}

# Time series: one column per ROI, each the mean over the ROI's pixels in every selected frame
roiTS = np.full((slct_numframes, len(drawing.rois)), np.nan)
roiShapes = []
for i, r in enumerate(drawing.rois):
    coords = np.where(r.get_mask())  # (row indices, column indices) of the ROI pixels
    roiTS[:, i] = slct_tiffstack[:, coords[0], coords[1]].mean(1)
    roiShapes.append(coords)

roiData['numRoi'] = len(drawing.rois)
roiData['roiTS'] = roiTS
roiData['roiShapes'] = roiShapes

# Save as npy (easy to load into Python). The dict is stored as a pickled object
# array, so reload it with np.load(path, allow_pickle=True).item()
np.save(saveDir+sep+'roiData_'+expt[:-4], roiData)
np.savetxt(saveDir+sep+'roisTS_'+expt[:-4]+'.csv', roiTS, delimiter=',')  # save as csv (easy to open e.g. in excel)

### 5. Some visualizations of ROI time series
#### Visualize chosen ROIs and plot time series

In [194]:
# Summary figure of the ROIs, produced by the custom illustrateRois helper
roifig = illustrateRois(roiData)
roifig.show()
roifig.savefig(sep.join([saveDir, 'roi-viz_' + expt[:-4] + '.pdf']), format='pdf')

In [195]:
# Select frame range for time series plot (half-open: frames ts .. te-1)
ts = 0
te = slct_numframes

# Time of each frame in seconds: frame k is at k / fps (exact 1/fps spacing,
# unlike linspace with its endpoint included).
# NOTE: frames below the correlation threshold were dropped, so consecutive
# selected frames are not necessarily 1/fps apart in the real recording.
time = np.arange(ts, te) / fps

In [196]:
# Make time series plot: dF/F of every ROI over frames ts..te-1,
# with baseline F0 = mean of that ROI's trace over the same range
fig, axs = plt.subplots(figsize=(10,4))

for r in drawing.rois:
    coords = np.where(r.get_mask())
    trace = slct_tiffstack[ts:te, coords[0], coords[1]].mean(1)  # mean over ROI pixels
    bl = trace.mean()
    axs.plot(time, (trace - bl) / bl, '.-')

axs.legend(['ROI '+str(i+1) for i in range(len(drawing.rois))], fontsize=12)
axs.set_xlabel('Time [s]', fontsize=13)
axs.set_ylabel('$(F - F_0) / F_0$', fontsize=13)
# Show only after the legend and labels are set so they appear in the displayed figure
plt.show()

In [198]:
# Save the trace figure, after any adjustments made through its window
tracesPath = saveDir + sep + 'roi-traces_' + expt[:-4] + '.pdf'
fig.savefig(tracesPath, format='pdf')

#### Plot correlations between all selected ROIs

In [199]:
# Pairwise scatter plots of dF/F between ROIs (lower triangle incl. diagonal)
nRois = len(drawing.rois)
# squeeze=False keeps axs 2D even when only one ROI was drawn
fig, axs = plt.subplots(nRois, nRois, figsize=(15,14), squeeze=False)

# skip the first ts frames
ts = 100
te = slct_numframes

# dF/F trace of every ROI over frames ts..te-1 (baseline = mean of the trace), computed once per ROI
dffTraces = []
for r in drawing.rois:
    coords = np.where(r.get_mask())
    trace = slct_tiffstack[ts:te, coords[0], coords[1]].mean(1)
    bl = trace.mean()
    dffTraces.append((trace - bl) / bl)

for ri1, roi1 in enumerate(dffTraces):
    for ri2, roi2 in enumerate(dffTraces[:ri1 + 1]):
        pairAx = axs[ri1, ri2]
        pairAx.plot(roi2, roi1, '.', alpha=0.3, color='black')

        pairAx.set_xlabel('ROI ' + str(ri2 + 1))
        pairAx.set_ylabel('ROI ' + str(ri1 + 1))
        pairAx.set_xlim(roi2.min(), roi2.max())
        pairAx.set_ylim(roi1.min(), roi1.max())

fig.tight_layout()
plt.show()

fig.savefig(saveDir+sep+'roi-corr_'+expt[:-4]+'.pdf', format = 'pdf')