# Script for preprocessing of Tina's Calcium imaging data

### Dependencies: matplotlib, scikit-image, h5py, numpy, scipy, pyqtgraph (interactive plots of 3D stacks), roi (custom Python class for drawing ROIs)

### Overview of steps
1. Import tiff stack
2. Open using pyqtgraph (interactive browsing of data)
3. Simple movement correction
4. Extraction of time series from ROIs
5. Some visualizations of ROI time series

In [2]:
import matplotlib
# Use the Qt5 backend so matplotlib figures open as interactive windows
matplotlib.use('Qt5Agg')

# IPython magic: run the Qt event loop so Qt windows (matplotlib, pyqtgraph) stay responsive
%gui qt

In [3]:
from skimage.io import imread, imshow
import h5py
from matplotlib import pyplot as plt
from matplotlib.widgets import Button
import numpy as np
import pyqtgraph as pq
from sys import path
from os.path import sep, exists
from os import mkdir, makedirs, getcwd



In [4]:
# Make the custom ROI-drawing module in <cwd>/utils importable, then import it
utilsDir = sep.join((getcwd(), 'utils'))
path.insert(1, utilsDir)

from roiDrawing import roi

### 0. Choose files to analyse

In [5]:
# Data and plots live in the parent directory of the current working directory
parentDir = sep.join(getcwd().split(sep)[:-1])

# Paths relative to parentDir. They are written with '/' separators, so they
# are split on '/' (not os.sep) -- splitting on os.sep fails on Windows.
rawtiff = 'JF549-BAPTA_MPM/170920/170920_549_bar2ccw_00001.tif'

rawh5 = 'JF549-BAPTA_MPM/170920/170920_bar2ccw_0001.h5'

[dye, date, expt] = rawtiff.split('/')
plotDir = parentDir + sep + 'Plots'

### 1. Import tiff stack

#### Choose and import tiff file (imaging data)

In [43]:
# Read the ScanImage header at the start of the tiff file ("SI.xxx = value",
# one entry per line). Extract the number of saved channels (nCh) and the
# number of frames per volume (fpv). Parsing stops at the first line that
# does not contain 'SI'.

def _si_value(headerLine):
    """Return the text right of '=' in a header line, without blanks/quotes."""
    return headerLine.split('=', 1)[1].strip().strip("'")

nCh = 1  # default if channelSave is not found in the header
fpv = 1  # default for single-plane recordings

with open(parentDir + sep + rawtiff, 'rb') as fh:
    for rawline in fh:
        # Decode instead of str(bytes), so values are not wrapped in "b'...'".
        # latin-1 maps every byte value, so binary tiff bytes cannot raise.
        line = rawline.rstrip(b'\r\n').decode('latin-1')

        # extract version
        if 'VERSION_' in line:
            print(line)

        # Number of saved channels. The value is a single channel ("2") or a
        # list ("[1;2]"), so count the entries instead of reading one digit.
        # The old int(digit) - 1 gave 0 for channelSave = 1.
        if 'hChannels' in line and 'channelSave' in line:
            print(line)
            entries = _si_value(line).strip('[]').replace(';', ' ').replace(',', ' ').split()
            nCh = len(entries)

        # number of planes per z-stack (may have more than one digit)
        if 'hStackManager' in line and 'numFramesPerVolume' in line:
            print(line)
            fpv = int(_si_value(line))

        if 'hFastZ' in line and 'numVolumes' in line:
            print(line)

        if 'SI' not in line:
            break

b"SI.VERSION_MAJOR = '2016b'"
b"SI.VERSION_MINOR = '1'"
b'SI.hChannels.channelSave = 2'
b'SI.hFastZ.numVolumes = 1'


In [44]:
# Load the complete tiff stack into memory and report its shape
rawstack = imread(sep.join((parentDir, rawtiff)))
stackshape = rawstack.shape
print(stackshape)

(1794, 512, 512)


In [45]:
# Directory for this experiment's plots and ROI time series:
# <plotDir>/<dye>/<date>/<experiment file name without '.tif'>
saveDir = sep.join([plotDir, dye, date, expt[:-4]])
# exist_ok replaces the separate exists() check (and its check-then-create race)
makedirs(saveDir, exist_ok=True)

#### Adjust frame settings to match chosen file
These parameters cannot be retrieved from the file header.

In [46]:
caCh = 0  # index of the calcium-imaging channel among the saved channels

fps = 20  # imaging frame rate [Hz]; not stored in the header -- check that with notes

numimgs = int(stackshape[0])
# Frames interleave channels and planes (see the reshape below), so one
# volume spans fpv * nCh frames. Use integer division for a whole count.
numVols = numimgs // (fpv * nCh)
print("# frames: " + str(numimgs))
print("# volumes: " + str(numVols))

# frames: 1794
# volumes: 1794.0


#### Reshape tiff stack

In [47]:
# Split the flat frame axis into (volumes, planes, channels, rows, cols).
# In C order the channel index varies fastest, then the plane.
# NOTE(review): this assumes ScanImage interleaves frames that way -- confirm.
nVolFrames = int(stackshape[0] / (fpv * nCh))
stack5d = rawstack.reshape((nVolFrames, fpv, nCh, stackshape[1], stackshape[2]))
print('Full dimensional stack: stack5d')
print(stack5d.shape)

# Keep only the calcium channel; squeeze drops singleton axes (e.g. planes when fpv == 1)
print('Stack reduced to one channel: stack4d')
caChannelOnly = stack5d[:, :, caCh, :, :]
stack4d = np.squeeze(caChannelOnly)

print(stack4d.shape)

Full dimensional stack: stack5d
(1794, 1, 1, 512, 512)
Stack reduced to one channel: stack4d
(1794, 512, 512)


### 2. Open with pyqtgraph

In [48]:
# Browse the calcium channel in pyqtgraph. For volumetric data the planes of
# each volume are stacked vertically: (volumes, planes * rows, cols).
if fpv != 1:
    nV, nP = stack4d.shape[0], stack4d.shape[1]
    nR, nC = stack4d.shape[2], stack4d.shape[3]
    pq.image(stack4d.reshape(nV, nP * nR, nC))
else:
    pq.image(stack4d.reshape(stack4d.shape[0],
                             stack4d.shape[1],
                             stack4d.shape[2]))

### Optional: Collapse volume using max projection

In [49]:
# Volumetric recordings: collapse each volume into one frame by taking the
# maximum over its planes (axis 1). Single-plane data is used as is.
if fpv > 1:
    stackMP = stack4d.max(axis=1)

    # Visualise new stack in pyqt graph:
    pq.image(stackMP)
else:
    # Use the channel-selected stack. rawstack still contains every saved
    # channel interleaved, which is wrong whenever nCh > 1.
    stackMP = stack4d

In [50]:
# Stack used for the rest of the analysis. stackMP holds the max projection
# for volumetric data; assign rawstack here instead to work on the raw frames.
tiffstack = stackMP
numframes = tiffstack.shape[0]

### 3. Simple movement correction
#### Frame selection by correlation with a reference frame

In [51]:
# Reference frame for the correlation-based frame selection: the last frame
# of the stack (after max projection, that is the last volume).
refframe = numframes - 1
print(f'reference frame: {refframe}')

refimg = tiffstack[refframe]
plt.imshow(refimg, cmap='Greys_r', vmin=0)
plt.title('Reference frame')
plt.show()

reference frame: 1793


In [52]:
# Pearson correlation of every frame with the reference frame, stored as a
# (numframes, 1) column vector.
refvec = refimg.ravel()
framecorr = np.array([np.corrcoef(img.ravel(), refvec)[0, 1]
                      for img in tiffstack]).reshape(-1, 1)

In [53]:
# Set threshold (min. correlation coefficient) for selection of frames
corrTH = 0.7

nAbove = int(np.count_nonzero(framecorr > corrTH))
pctAbove = round(100 * nAbove / numframes, 2)

corrFig, ax = plt.subplots(1, 1, figsize=(10, 3))
ax.plot(framecorr)
ax.set_ylabel('correlation coeff.')
ax.set_xlabel('frame')
ax.set_title('Pearson correlation with reference frame (' + str(refframe) +
             '), threshold: ' + str(corrTH) +
             ' (' + str(pctAbove) + '%)')
# axhline's xmin/xmax are axes fractions (0..1), not frame numbers; the
# default already spans the full width.
ax.axhline(y=corrTH, color='r')
# Save via the figure handle before show(). After a blocking show(),
# plt.savefig may write a new, empty current figure.
corrFig.savefig(saveDir + sep + 'stackCorr_' + expt[:-4] + '.pdf', format='pdf')
plt.show()

In [54]:
# Keep only frames whose correlation with the reference exceeds corrTH,
# copied into a new float64 stack.
keepFrames = np.where(framecorr > corrTH)[0]
slct_numframes = int(keepFrames.size)
slct_tiffstack = tiffstack[keepFrames].astype(np.float64)

In [55]:
# Open the correlation-selected frames in pyqtgraph and report their shape
pq.image(slct_tiffstack)
print(slct_tiffstack.shape)
plt.show()

(1794, 512, 512)


### 4. Extraction of time series from ROIs
#### Draw ROIs

In [56]:
# Background image for ROI drawing, computed over the selected frames:
# 'mean', 'median', or -- for any other value -- the standard deviation.
imgtype = 'std'  # ['std', 'mean', 'median']

# Interactive figure to draw ROIs (uses roi.py)
fig, axs = plt.subplots(figsize=(10,6))
summaryImage = {
    'mean': lambda stk: stk.mean(axis=0),
    'median': lambda stk: np.median(stk, axis=0),
}
sample = summaryImage.get(imgtype, lambda stk: stk.std(axis=0))(slct_tiffstack)

axs.imshow(sample, cmap='gray')
drawing = roi.RoiDrawing(axs, sample)

# Control buttons. Each is bound to a module-level name because matplotlib
# widgets need a live reference to stay responsive.
ax_butt = plt.axes([0.1, 0.15, 0.1, 0.075])
wipe_butt = Button(ax_butt, 'Wipe')
wipe_butt.on_clicked(drawing.wipe)

ax_decr = plt.axes([.1, .25, .1, .075])
decr_button = Button(ax_decr, 'Prev ROI')
decr_button.on_clicked(drawing.focus_decr)

ax_incr = plt.axes([.1, .35, .1, .075])
incr_button = Button(ax_incr, 'Next ROI')
incr_button.on_clicked(drawing.focus_incr)

plt.show()

In [57]:
# Save the ROI-drawing figure (created in the previous cell) after drawing the ROIs
fig.savefig(saveDir+sep+'rois_'+expt[:-4]+'.pdf', format = 'pdf')

In [58]:
# Show each drawn ROI mask in its own panel and save the figure
if drawing.rois:
    nRois = len(drawing.rois)
    # squeeze=False keeps axs 2D, so indexing also works with a single ROI
    fig, axs = plt.subplots(nrows=1, ncols=nRois, figsize=(3*nRois, 4), squeeze=False)
    for ind, r in enumerate(drawing.rois):
        axs[0, ind].imshow(r.get_mask())
        axs[0, ind].set_title('roi' + str(ind))

    plt.show()
    fig.savefig(saveDir+sep+'rois2_'+expt[:-4]+'.pdf', format = 'pdf')
else:
    # plt.subplots(ncols=0) would raise
    print('No ROIs drawn yet.')

In [59]:
# Mean fluorescence inside each ROI for every selected frame:
# one row per frame, one column per ROI.
roiTS = np.full((slct_numframes, len(drawing.rois)), np.nan)

for col, r in enumerate(drawing.rois):
    pix = np.where(r.get_mask())
    roiTS[:, col] = slct_tiffstack[:, pix[0], pix[1]].mean(1)

roiFileBase = saveDir + sep + 'rois_' + expt[:-4]
np.save(roiFileBase, roiTS)  # .npy file (easy to load into Python)
np.savetxt(roiFileBase + '.csv', roiTS, delimiter=',')  # .csv (easy to open e.g. in Excel)

### 5. Some visualizations of ROI time series
#### Plot time series of chosen ROIs

In [60]:
# Frame range [ts, te) of the selected stack used for the time-series plot
ts = 10
te = slct_numframes

# Time [s] of each frame in the range: frame k is at k / fps.
# linspace(ts/fps, te/fps, te-ts) spaced the points by more than 1/fps.
# NOTE: frames dropped by the correlation threshold are not accounted for,
# so this axis is only exact when no frames were removed.
time = np.arange(ts, te) / fps

In [61]:
# Plot dF/F traces, (F - F0) / F0, for every ROI over frames ts..te-1 of the
# selected stack. F0 is the mean of that ROI over the plotted range.
fig, axs = plt.subplots(figsize=(10,4))

for r in drawing.rois:
    coords = np.where(r.get_mask())
    trace = slct_tiffstack[ts:te, coords[0], coords[1]].mean(1)
    bl = trace.mean(0)
    axs.plot(time, (trace - bl) / bl, '.-')

# Decorate before show(). With a blocking show(), these calls would only run
# after the window had been closed.
axs.legend(['ROI '+str(i+1) for i in range(len(drawing.rois))], fontsize=12)
axs.set_xlabel('Time [s]', fontsize=13)
axs.set_ylabel('$(F - F_0) / F_0$', fontsize=13)
plt.show()

In [62]:
# Save the ROI trace figure (previous cell) after making adjustments through the figure window GUI
fig.savefig(saveDir+sep+'roi-traces_'+expt[:-4]+'.pdf', format = 'pdf')

## Process h5 file (behavior, trigger pulses)
### Set experimental parameters

In [111]:
# Treadmill parameters
tmParams = {
    'px2mm': 0.03, # sensor pixels -> mm (combined with calibParam when converting)
    'calibParam':(1.5, 1.5), # calibration factors: [0] used for pitch, [1] for yaw
    'rBall':48, #mm; presumably the ball radius (mm/s are divided by it to get rad/s) -- confirm
    'tickLength': 0.002, #s; duration of one tick (also used as the smoothing box length)
    'tickAmp':0.2 #V; nominal signal amplitude of one tick (sets the quantization step)
}
arenaSize = 270 # degrees; the arena signal is rescaled to +/- arenaSize/2

trialS = 120 # trial duration [s]

# Acquisition parameters of the h5 recording
wsParams = {
    'fps': 20000, # sampling rate [Hz]
    # analog input channel order: first four = treadmill sensors, index 4 = arena position
    'channelNames': ['x1','x2','y1','y2','arena','unknown','RearCamFramClockchannel','TwoPFrameClockchannel'],
    'Offsets': [2.05, 2.04, 2.05, 2.05] #Offsets of X0,X1,Y0,Y1 tick channels (not used by the visible cells)
}

# You can check channel names using list(myh5['header']['AIChannelNames'])

### Load file

In [7]:
# Open the h5 recording read-only. The handle is left open because later
# cells read the header and sweep data from it.
myh5 = h5py.File(sep.join((parentDir, rawh5)), 'r')
h5keys = list(myh5)

print(h5keys)

['header', 'sweep_0001']


In [8]:
# Select sweep by its position in h5keys (for this file: 0 = 'header', 1 = 'sweep_0001')
sweepInd = 1
sweepkey = h5keys[sweepInd]

# Rows of the sweep's 'analogScans' dataset (presumably one per analog input channel)
sweepdat = list(myh5[sweepkey]['analogScans'])

In [79]:
# Analog scaling coefficients from the h5 header (first row only).
# NOTE(review): the four printed values look like polynomial coefficients
# (offset, gain, 2nd and 3rd order) of a single channel. Later cells multiply
# raw data by scalingCoeffs[0] only -- confirm this is the intended scaling.
scalingCoeffs = list(myh5['header']['Acquisition']['AnalogScalingCoefficients'][0])
print(scalingCoeffs)

[0.00036112387879333203, 0.00032877755951857268, -3.2766974414957873e-14, 2.2555316988730738e-18]


### Quick visual check of the recording

In [81]:
# Quick visual check of all analog channels: one second of traces, then a
# histogram of each whole channel.
# NOTE(review): every channel is scaled by scalingCoeffs[0] only. If these are
# per-channel polynomial coefficients, c0 is an offset rather than a gain -- confirm.
channelFig = plt.figure(figsize=(10,10))

tS = int(14*wsParams['fps'])  # from 14 s ...
tE = int(15*wsParams['fps'])  # ... to 15 s

nChan = len(sweepdat)
for t, chan in enumerate(sweepdat):
    ax = channelFig.add_subplot(nChan, 1, t+1)
    ax.plot(chan[tS:tE]*scalingCoeffs[0], '.-')
plt.show()

channelHistFig = plt.figure(figsize=(15,2))

# whole recording; tE = None keeps the last sample (a stop of -1 dropped it)
tS = 0
tE = None

for t, chan in enumerate(sweepdat):
    ax = channelHistFig.add_subplot(1, nChan, t+1)
    ax.hist(chan[tS:tE]*scalingCoeffs[0], bins=50)

channelHistFig.tight_layout()
plt.show()

### Extract and process treadmill measurements

In [82]:
# Treadmill sensor channels (first four analog inputs), scaled with
# scalingCoeffs[0] and keyed by their names from wsParams['channelNames'].
tm_raw = {name: sweepdat[idx]*scalingCoeffs[0]
          for idx, name in enumerate(wsParams['channelNames'][:4])}

# extract arena position (unscaled; it is normalised below)
arena_raw = sweepdat[4]

# map [0, max] linearly onto [-arenaSize/2, +arenaSize/2] degrees
arenaMax = np.max(arena_raw)
arena = arenaSize*(arena_raw - 0.5*arenaMax)/arenaMax

In [141]:
def processTMSignals(rawtmsignal, tmParams, wsParams):
    """Digitize raw treadmill signals into integer tick counts.

    Each channel is median-centred, smoothed with a box filter one tick long
    and quantized into tick counts.

    Parameters
    ----------
    rawtmsignal : dict
        Channel name -> 1D array of scaled treadmill samples.
    tmParams : dict
        Needs 'tickLength' (s) and 'tickAmp' (signal units per tick).
    wsParams : dict
        Needs 'fps' (sampling rate in Hz).

    Returns
    -------
    dict
        Channel name -> integer array of tick counts in [-3, 3].
    """
    maxTicks = 3
    # Box length of one tick in samples. Round instead of truncating (a
    # product like 39.9999 would otherwise give 39), and use at least 1 to
    # avoid division by zero.
    boxLen = max(1, int(round(wsParams['fps']*tmParams['tickLength'])))

    tmticks = {}  # new dict for digitized values
    # Iterate the argument itself (this used to read the global tm_raw).
    for ch, signal in rawtmsignal.items():
        chfilt = mysmooth(signal - np.median(signal), boxLen)
        tmticks[ch] = quantizeTicks(chfilt, maxTicks, tmParams['tickAmp'])

    return tmticks

def mysmooth(y, box_pts):
    """Return the moving average of y over box_pts samples (same length as y)."""
    box = np.ones(box_pts)/box_pts
    return np.convolve(y, box, mode='same')

def quantizeTicks(y, maxTicks, tickAmp):
    """Round a smoothed treadmill signal to integer tick counts.

    One tick corresponds to a step of 0.7 * tickAmp. The factor 0.7 is
    empirical and taken over from the original. Values are rounded to the
    nearest multiple of the step and clipped to [-maxTicks, maxTicks].
    """
    step = 0.7*tickAmp
    # Bin edges lie halfway between consecutive multiples of step:
    # (k + 0.5) * step for k = -maxTicks .. maxTicks-1. The former extra
    # lowest edge produced an asymmetric -maxTicks-1 bin.
    edges = (np.arange(-maxTicks, maxTicks) + 0.5)*step
    return np.digitize(y, edges) - maxTicks

In [142]:
# Digitize the four treadmill channels in tm_raw into integer tick counts
tm_ticks = processTMSignals(tm_raw, tmParams, wsParams)

In [149]:
# Overlaid histograms: raw treadmill signals (left) vs. digitized ticks (right)
sweepFig, (ax1, ax2) = plt.subplots(1,2,figsize=(10,4))
for rawTrace in tm_raw.values():
    ax1.hist(rawTrace, 100, alpha=0.25)
for tickTrace in tm_ticks.values():
    ax2.hist(tickTrace, 100, alpha=0.25)
plt.show()

In [187]:
# Convert digitized treadmill ticks into fly velocities (forward, sideways,
# rotational), downsample them together with the arena position, and
# convert them to rates.
from scipy.interpolate import interp1d

# These parameters live in tmParams; they were used here as undefined bare names.
px2mm = tmParams['px2mm']
calibParam = tmParams['calibParam']
rBall = tmParams['rBall']  # mm

# digitized tick counts of the four treadmill sensor channels
x1_dig, x2_dig, y1_dig, y2_dig = (tm_ticks[ch] for ch in ('x1', 'x2', 'y1', 'y2'))

conversionFactor_pitch = px2mm*float(calibParam[0])
conversionFactor_yaw = px2mm*float(calibParam[1])

# sensor axes are combined with cos/sin of 45 deg below (sensors at +/-45 deg to the fly axis)
gammaRad = 45*np.pi/180.0

# raw sample times: len(x1_dig) samples spanning trialS seconds (last entry dropped below)
time = np.linspace(0, trialS, len(x1_dig)+1)

# Downsampled time base. 500 Hz matches 1/tickLength, presumably so each
# tick is sampled once (the division by dt below relies on that) -- TODO confirm.
fps_ds = 500
time_ds = np.linspace(0, time[-2], fps_ds*trialS)

# compute virtual rotation of fly
vFwd = - (y1_dig + y2_dig) * np.cos(gammaRad)  # add components along longitudinal axis
vSide = - (y1_dig - y2_dig) * np.sin(gammaRad)  # add components along transversal axis
vRot = - (x1_dig + x2_dig)/2  # average measured displacement along azimuth

# convert A.U. --> pixel --> mm
vFwd = vFwd * conversionFactor_pitch  # use scaling factor for pitch
vSide = vSide * 0.5*(conversionFactor_yaw + conversionFactor_pitch)  # use mean
vRot = vRot * conversionFactor_yaw  # use scaling factor for yaw

# downsample through linear interpolation
f_vFwd = interp1d(time[:-1], vFwd, kind = 'linear')
f_vSide = interp1d(time[:-1], vSide, kind = 'linear')
f_vRot = interp1d(time[:-1], vRot, kind = 'linear')
f_arena = interp1d(time[:-1], arena, kind = 'linear')

vFwd_ds = f_vFwd(time_ds)
vSide_ds = f_vSide(time_ds)
vRot_ds = f_vRot(time_ds)
arena_ds = f_arena(time_ds)

# convert to mm/s (last interval repeated so dt has one entry per sample)
dt = np.hstack((np.diff(time_ds), np.mean(np.diff(time_ds))))
vFwd_ds = vFwd_ds / dt
vSide_ds = vSide_ds / dt
vRot_ds = vRot_ds / dt

rotV = vRot_ds / rBall  # (mm/s) / mm -> rad/s; converted to degrees only for plotting

In [188]:
# Arena position (top) and integrated rotational velocity (bottom) over the
# whole trial. The latter is converted to degrees and wrapped to [-180, 180).
yawFig, (ax1, ax2) = plt.subplots(2,1,figsize=(12,5))

ax1.plot(time_ds, arena_ds, 'k.', markersize=1)
ax1.set_ylabel('arena pos [deg]')
ax1.set_xlim(0, trialS)

integratedRotDeg = np.mod(np.cumsum(rotV*dt*180/np.pi) + 180, 360) - 180
ax2.plot(time_ds, integratedRotDeg, '.', markersize=1)
ax2.set_ylabel('integrated rot. vel. [deg]')
ax2.set_xlabel('time [s]')
ax2.set_xlim(0, trialS)
plt.show()

In [189]:
# Assume initial position (0 0 0) = (x-coord, y-coord, theta):
# --> fly in origin, aligned with x axis (head forward)
# During measurement coordinate system is fly-centered, moves with fly.
# Compute all changes along those axes by updating theta and
# projecting the position changes onto the fixed coordinate system

# heading theta [rad]: integrated rotational velocity, wrapped to [-pi, pi)
theta = np.cumsum(rotV * dt)
theta = np.mod((theta + np.pi),2*np.pi) - np.pi

# movement in x and y direction: rotate the fly-centred (sideways, forward)
# velocities by -theta into the fixed frame, then integrate over time
yTM_i = vSide_ds * np.cos(-theta) - vFwd_ds * np.sin(-theta) #compute increments y_i
yTM = np.cumsum(yTM_i* dt) # integrate y_i to get path

xTM_i = vSide_ds * np.sin(-theta) + vFwd_ds * np.cos(-theta) # increments x_i
xTM = np.cumsum(xTM_i * dt) # integrate x_i to get path

# translational speed in the fixed frame
transV = np.hypot(xTM_i, yTM_i)

### Downsample

In [190]:
# Compare the number of imaging frames with the number of downsampled behaviour samples
print(numframes)
print(len(time_ds))

1794
60000


In [111]:
# last downsampled behaviour time point [s]
time_ds[-1]

119.99991428571428

In [126]:
# Behaviour duration divided by the number of imaging frames, rounded to two
# decimals (seconds per frame if the frames span the whole trial). Used in
# the output figure file names.
bh_samp = np.round(time_ds[-1]/numframes, decimals=2)
print(bh_samp)

1.15


In [135]:
# Resample the behavioural signals at one point per imaging frame.
# NOTE(review): this assumes the numframes imaging frames are spread evenly
# over 0..time[-2]. Aligning with the 2P frame clock channel listed in
# wsParams['channelNames'] would be more precise -- confirm.
time_dds = np.linspace(0, time[-2], numframes)

def _to_frames(signal):
    """Linearly interpolate a signal sampled at time_ds onto time_dds."""
    return interp1d(time_ds, signal, kind='linear')(time_dds)

transV_dds = _to_frames(transV)
rotV_dds = _to_frames(rotV)
xTM_dds = _to_frames(xTM)
yTM_dds = _to_frames(yTM)
theta_dds = _to_frames(theta)

In [136]:
# Rotational and translational velocity at imaging-frame resolution. The
# figure is saved next to the other plots of this experiment (saveDir).
velFig, (ax1, ax2) = plt.subplots(2,1,figsize=(12,5))
ax1.plot(time_dds, rotV_dds)
# rotV is (mm/s) / rBall [mm] -> rad/s, not mm/s
ax1.set_ylabel('rotational vel. [rad/s]')
ax1.set_xlim(0,trialS)

ax2.plot(time_dds, transV_dds)
ax2.set_ylabel('translational vel. [mm/s]')
ax2.set_xlabel('time [s]')
ax2.set_xlim(0,trialS)

# Sweep number from the h5 key ('sweep_0001' -> 1); `sweepnum` was never defined.
sweepnum = int(sweepkey.split('_')[-1])
# h5 file name without '.h5'; rawh5 uses '/' separators on every OS
h5name = rawh5.split('/')[-1][:-3]
velFig.savefig(saveDir + sep + 'walkingVel_ds' + str(100*bh_samp) + '_' + h5name +
               '_sweep' + str(sweepnum) + '.pdf', format='pdf')
plt.show()

In [137]:
# Walking path in the fixed frame, coloured by heading theta [rad]; the
# starting point is marked in black.
traceplot = plt.figure(figsize = (10,10))
ax = traceplot.add_subplot(111)
sc = ax.scatter(xTM_dds, yTM_dds, s=5, c=theta_dds, vmin=-np.pi, vmax=np.pi,
                alpha=0.8, cmap='hsv')
traceplot.colorbar(sc, ax=ax)
ax.plot(xTM_dds[0], yTM_dds[0], color='black', marker='o')
ax.set_aspect('equal')

# Sweep number from the h5 key ('sweep_0001' -> 1); `sweepnum` was never defined.
sweepnum = int(sweepkey.split('_')[-1])
# h5 file name without '.h5'; rawh5 uses '/' separators on every OS
h5name = rawh5.split('/')[-1][:-3]
traceplot.savefig(saveDir + sep + 'walkingTrace_ds' + str(100*bh_samp) + '_' + h5name +
                  '_sweep' + str(sweepnum) + '.pdf', format='pdf')
plt.show()

In [138]:
# Scatter each ROI's dF/F against translational velocity, frame by frame.
# NOTE(review): `data` was not defined in the visible notebook. tiffstack is
# used because it has numframes frames, matching transV_dds; slct_tiffstack
# may have fewer frames after the correlation threshold -- confirm.
data = tiffstack

nRois = len(drawing.rois)
# squeeze=False keeps axs 2D, so indexing also works with a single ROI
fig, axs = plt.subplots(1, nRois, figsize=(10,4), squeeze=False)
for ri, r in enumerate(drawing.rois):
    coords = np.where(r.get_mask())
    trace = data[:, coords[0], coords[1]].mean(1)
    bl = trace.mean(0)
    # Not named `roi`: that name is the imported ROI-drawing module.
    roi_dff = (trace - bl) / bl

    axs[0, ri].plot(roi_dff, transV_dds, '.', alpha=0.3, color='black')

    axs[0, ri].set_xlabel(str(ri+1))
    axs[0, ri].set_ylabel('vTrans')

fig.tight_layout()
plt.show()