## New in version 2:
* updated traces variable from C to C_raw
* updated dff from C.raw to C_df
* added save workspace at end
* plotting of ROIs on mean image is up and running again

In [None]:

from IPython.display import display
from IPython.display import HTML
import IPython.core.display as di # Example: di.display_html('<h3>%s:</h3>' % str, raw=True)

# This line will hide code by default when the notebook is exported as HTML
di.display_html('<script>jQuery(function() {if (jQuery("body.notebook_app").length == 0) { jQuery(".input_area").toggle(); jQuery(".prompt").toggle();}});</script>', raw=True)

# This line will add a button to toggle visibility of code blocks, for use with the HTML export version
di.display_html('''<button onclick="jQuery('.input_area').toggle(); jQuery('.prompt').toggle();">Toggle code</button>''', raw=True)


In [None]:
# all modules necessary for this nb
import os
import sys
import pickle

import numpy as np
import pylab as pl
from sklearn.covariance import EmpiricalCovariance
from sklearn.cluster import KMeans, AffinityPropagation
from sklearn.metrics import silhouette_score as clust_score
from sklearn.preprocessing import StandardScaler
from scipy import stats as sstats
from matplotlib.patches import Circle, Wedge, Polygon
from matplotlib.collections import PatchCollection

# setting parameters for default matplotlib plots
%matplotlib inline
pl.rcParams['savefig.dpi'] = 300 # dpi for most publications
pl.rcParams['xtick.labelsize'] = 7
pl.rcParams['ytick.labelsize'] = 7
pl.rcParams['axes.labelsize'] = 7
from ipywidgets import interact

# needs to find the library of functions
sys.path.append('/home/fabios/code/forco/')  # to be replaced!

import utils as ut
import plots as pt

------------------

These cells are used to back up and restore the variables in the workspace, so there is no need to re-run the whole notebook. The variables are saved into `/home/fabios/autorestore/NOTEBOOK_NAME` (this is no longer true, I believe - JSB 5/26/18; the cell below builds the folder as `<current working directory>/autorestore/NOTEBOOK_NAME`). Use the function `save_workspace(db)` to save the variables at any point. Only **variables** get saved, not functions, so you have to re-define your functions after restoring.

**Avoid moving the notebook into a different folder or renaming it**.

In [None]:
%%javascript
var nb = IPython.notebook;
var kernel = IPython.notebook.kernel;
var command = "NOTEBOOK_NAME = '" + nb.base_url + nb.notebook_path + "'";
kernel.execute(command);

In [None]:
# NOTEBOOK_NAME = NOTEBOOK_NAME.split('/')[-1][:-6]
NOTEBOOK_NAME = 'forco'

In [None]:
# Persistent on-disk store used to back up and restore notebook variables.
from pickleshare import PickleShareDB

# Variables live under <current working directory>/autorestore/<NOTEBOOK_NAME>.
autorestore_folder = os.path.join(os.getcwd(), 'autorestore', NOTEBOOK_NAME)
db = PickleShareDB(autorestore_folder)
import sys
# Hard-coded, machine-specific location of the `workspace` module imported below.
sys.path.append('/home/fabios/code')
# presumably provides `load_workspace` / `save_workspace` -- verify in workspace.py
from workspace import *
import IPython
ip = IPython.get_ipython()

# this will restore all the saved variables. ignore the errors listed.
load_workspace(ip, db)

# use `save_workspace(db)` to save variables at the end

In [None]:
def extract_single_cycle(time_ax, dff, cycles, cycle, cell,
                         cycle_start=-8):
    """Return the re-zeroed time axis and the trace of one cell for one cycle.

    The samples belonging to `cycle` are selected with
    `ut.filter_cycle(time_ax, cycles, cycle)`. The selected times are shifted
    so that the first one equals `cycle_start`.

    Parameters
    ----------
    time_ax : array indexed by frame.
    dff : 2-D array indexed as [frame, cell].
    cycles, cycle : forwarded to `ut.filter_cycle`.
    cell : column of `dff` to extract.
    cycle_start : value given to the first sample of the returned time axis.

    Returns
    -------
    (times, trace) : the shifted times and `dff[:, cell]` for the same samples.
    """
    in_cycle = ut.filter_cycle(time_ax, cycles, cycle)
    cycle_times = time_ax[in_cycle]
    first_time = cycle_times[0]
    cell_trace = dff[:, cell]
    return cycle_times - first_time + cycle_start, cell_trace[in_cycle]


def extract_single_cycle_signal(time_ax, signal, cycles, cycle,
                         cycle_start=-8):
    """Return the re-zeroed time axis and the samples of `signal` for one cycle.

    Same as `extract_single_cycle`, but `signal` is indexed directly with the
    selection returned by `ut.filter_cycle` (no cell/column is picked).

    Returns
    -------
    (times, values) : times shifted so the first sample equals `cycle_start`,
        and `signal` restricted to the same samples.
    """
    in_cycle = ut.filter_cycle(time_ax, cycles, cycle)
    cycle_times = time_ax[in_cycle]
    first_time = cycle_times[0]
    return cycle_times - first_time + cycle_start, signal[in_cycle]


def extract_single_cycle_time_ax(time_ax, cycles, cycle_duration=4, cycle_start=-8):
    """Build the time axis shared by all single-cycle traces.

    For every cycle the frame times are re-zeroed on the first frame of that
    cycle, cut at `cycle_duration` seconds and shifted by `cycle_start`.
    The returned axis is the one computed for the LAST cycle, truncated to the
    length of the shortest cycle so that it can index every cycle.

    Parameters
    ----------
    time_ax : array of frame times.
    cycles : sequence of cycles, forwarded to `ut.filter_cycle`.
    cycle_duration : only samples earlier than this (relative to the first
        sample of the cycle) are kept.
    cycle_start : offset added to the re-zeroed times.

    Raises
    ------
    ValueError : if `cycles` is empty (there is no cycle to build the axis from).
    """
    if len(cycles) == 0:
        raise ValueError("cycles is empty: cannot build a single-cycle time axis")

    min_len = None
    for i in range(len(cycles)):
        # select the frames of this cycle once and reuse them
        cycle_times = time_ax[ut.filter_cycle(time_ax, cycles, i)]
        time_ax_single = cycle_times - cycle_times[0]
        time_ax_single = time_ax_single[time_ax_single < cycle_duration] + cycle_start
        if min_len is None or len(time_ax_single) < min_len:
            min_len = len(time_ax_single)
    return time_ax_single[:min_len]


------------------

# Forco is for Trace Conditioning

## Enter here the correct folder

In [None]:
# autorestore_folder is <notebook dir>/autorestore/<NOTEBOOK_NAME>; going up three
# levels therefore selects the PARENT of the directory the notebook runs in.
# data_folder = '/media/data/DATA1/dg_odor_nwoods/adam/2dayhabit/d1//'
data_folder = os.path.dirname(os.path.dirname(os.path.dirname(autorestore_folder)))

In [None]:
data_folder

In [None]:
import scipy.io as sio
from scipy.io import loadmat

In [None]:
# Load the CNMFe outputs from <data_folder>/CNMFe Conservative/.
# The text matrices are transposed on load; later cells index them as
# [frame, cell] (e.g. dff[:, cell]).
# traces = np.loadtxt(os.path.join(data_folder, 'traces/C_raw.txt')).T
traces = np.loadtxt(os.path.join(data_folder, 'CNMFe Conservative/C_raw.txt')).T
# traces_raw = np.loadtxt(os.path.join(data_folder, 'traces/C_raw.txt')).T
events = np.loadtxt(os.path.join(data_folder, 'CNMFe Conservative/S.txt')).T
dff = np.loadtxt(os.path.join(data_folder, 'CNMFe Conservative/C_df.txt')).T
denoised = np.loadtxt(os.path.join(data_folder, 'CNMFe Conservative/C.txt')).T
# Cnn is used later as the background image for the ROI overlay (not transposed).
Cnn=np.loadtxt(os.path.join(data_folder, 'CNMFe Conservative/Cnn.txt'))
# First column of the 'coor' variable stored in Coor.mat: one entry per ROI.
Coor = sio.loadmat(os.path.join(data_folder, 'CNMFe Conservative/Coor.mat'))['coor'][:,0]
#events = ut.event_detection_cnmfe_denoised(denoised)
# z-scored version of dff (see ut.zscore_traces for the exact normalisation)
dff_zs = ut.zscore_traces(dff)
# areas = np.loadtxt(os.path.join(data_folder, 'area/A.txt')).T

##   MAKE SURE YOU HAVE THE PROPER PIXEL DIMENSIONS FOR YOUR VIDEO BELOW
##def load_spatial_footprints_A(A_file, shape=(421, 514)):
##    return np.loadtxt(A_file).T.reshape([-1, shape[0], shape[1]])
##areas = load_spatial_footprints_A(os.path.join(data_folder, 'area/A.txt'))
# Reads the same Coor.mat / Cnn.txt files again through the project helper.
mean_image, contours = ut.load_spatial_footprints(os.path.join(data_folder, 'CNMFe Conservative/Coor.mat'),
                                                  os.path.join(data_folder, 'CNMFe Conservative/Cnn.txt'),
                                                  key='coor')

# Behavioural log written by the arduino; each entry is used below as
# (timestamp, event name), i.e. b[0] and b[1].
filename = os.path.join(data_folder, 'arduino/behavior.txt')
behavior = ut.read_behavior(filename)
# sorted unique event names found in the log
events_list = np.unique([b[1] for b in behavior])

In [None]:
# grab time axis from the xml file

import xml.etree.ElementTree as ET
xmlfile = os.path.join(data_folder, 'tseries.xml')
print "I infer the time axis from:\n", xmlfile
tree = ET.parse(xmlfile)
root = tree.getroot()

# unfortunately we miss the first frame
# One entry per <Frame> element, taken from its 'absoluteTime' attribute and
# converted from string to float.
time_ax = np.r_[[child.attrib['absoluteTime']
                 for child in root.iter('Frame')]].astype(float)

In [None]:
# sync times
# Put both clocks on a common origin: behaviour timestamps become relative to
# the first 'BEGIN' event, imaging times relative to the first frame.
start_2p = ut.parse_behavior(behavior, 'BEGIN')[0]
behavior = [[float(b[0])-start_2p, b[1]] for b in behavior]
# in-place shift of the time axis loaded in the previous cell
time_ax -= time_ax[0]

In [None]:
time_ax

# Modify this one here

Make sure that the following is correct. Use the `parse_behavior` function, which accepts regular expressions (for example, `tone*` to find all the events that start with `tone`).

In [None]:
# -----------------------------------------------------------
# these times are relative to the single cycle
# and centered around tone onset
CONTINUOUS = True  # True: cycles are cut around tone events; False: BEGIN/END events are used
CYCLE_START = -8  # seconds
CS_START = 0  # seconds (let's keep this at 0)
CS_DURATION = 2  # seconds
DELAY = 2  # seconds
US_DURATION = 2  # seconds  // IS THIS FIXED?
AFTER_US_PERIOD = 5
REWARD_WIN = 2  # only referenced by the commented-out is_rewarded variant below
# total length of one cycle; with the values above: 8 + 2 + 2 + 2 + 5 = 19 s
CYCLE_DURATION = abs(CYCLE_START) + CS_DURATION + DELAY + US_DURATION + AFTER_US_PERIOD
CS_END = CS_START + CS_DURATION
US_START = CS_START + DELAY +CS_DURATION
US_END = US_START + US_DURATION

# -----------------------------------------------------------
# these times are absolute times, taken from the arduino file
# when the tones starts and ends
tone_CSm_ons = ut.parse_behavior(behavior, 'TONE_CSM')
tone_CSm_offs = ut.parse_behavior(behavior, 'TONE_CSM', offset=CS_DURATION)
tone_rw_ons = ut.parse_behavior(behavior, 'TONE_RW')
tone_rw_offs = ut.parse_behavior(behavior, 'TONE_RW', offset=CS_DURATION)
rewards = np.r_[ut.parse_behavior(behavior, 'REWARD')]

# -----------------------------------------------------------
# when the experiment starts and ends, in absolute time
# begin_end = ut.parse_behavior(behavior, '[be]')
# when each cycle starts and ends
# (last cycle is usually oddly recorded)
if CONTINUOUS:
    # each cycle spans [tone + CYCLE_START, tone + CYCLE_START + CYCLE_DURATION]
    cycles_starts = ut.parse_behavior(behavior, 'TONE_*', offset=CYCLE_START)
    cycles_ends = ut.parse_behavior(behavior, 'TONE_*', offset=CYCLE_DURATION+CYCLE_START)
else:
    cycles_starts = ut.parse_behavior(behavior, 'BEGIN')
    cycles_ends = ut.parse_behavior(behavior, 'END')
# cycle_subtract is used as a slice end ([:cycle_subtract]), so a NEGATIVE value
# drops that many cycles from the end (e.g. -1 drops the last one); 0 keeps all.
cycle_subtract = 0   #do we need to subtract off the last cycle because it's too short???
if cycle_subtract !=0:
    cycles = np.r_[zip(cycles_starts,  # offset will be ADDED, with sign
                   cycles_ends)][:cycle_subtract]
else:
    cycles = np.r_[zip(cycles_starts,  # offset will be ADDED, with sign
                   cycles_ends)]
print 'we are subtracting off this many cycles'
print cycle_subtract
# -----------------------------------------------------------
# which trials are CS- (CSm) and which are reward-tone (rwt) trials
# is_* are per-cycle booleans over ALL cycles (cycles_starts/cycles_ends, not truncated);
# true_* are the indices of the matching rows of `cycles` (after cycle_subtract).
is_CSm = [any(map(lambda t: (t>=s) and (t<e), tone_CSm_ons)) for s, e in zip(cycles_starts, cycles_ends)]
is_rwt = [any(map(lambda t: (t>=s) and (t<e), tone_rw_ons)) for s, e in zip(cycles_starts, cycles_ends)]
true_CSm = np.where([any(map(lambda t: (t>=s) and (t<e), tone_CSm_ons)) for s, e in cycles])[0]
true_rwt = np.where([any(map(lambda t: (t>=s) and (t<e), tone_rw_ons)) for s, e in cycles])[0]
# is_rewarded = [any(map(lambda t: (t>=s) and (t<e) and any((t-rewards)<(CS_DURATION+DELAY+REWARD_WIN)), tone_rw_ons))
#                for s, e in zip(cycles_starts, cycles_ends)]
# a cycle counts as rewarded when at least one REWARD event falls inside [start, end)
is_rewarded = [any(map(lambda r: (r<e)*(r>=s), rewards))
               for s, e in zip(cycles_starts, cycles_ends)]
# reward-tone cycles in which no REWARD event was logged
is_not_rewarded = is_rwt * ~np.r_[is_rewarded]

In [None]:
ut.get_cycles_durations(cycles, time_ax)

In [None]:
max_cycles = len(cycles)

time_ax_single = extract_single_cycle_time_ax(time_ax, cycles, cycle_duration=CYCLE_DURATION, cycle_start=CYCLE_START)

In [None]:
print "The first cycle starts at %f and ends at %f seconds." % (cycles[0][0], cycles[0][1])

In [None]:
licks = ut.parse_behavior(behavior, "LICK")

In [None]:
# Exponential-decay kernel (time constant 10 samples, 100 samples long) whose
# peak sits at index 50; the sample indices are rotated before exponentiating.
conv_func = np.exp(-np.roll(np.arange(100), 50)/10.)

In [None]:
lick_trace = np.zeros_like(time_ax)

In [None]:
# Count each lick in the imaging frame closest to it in time.
# The trace is rebuilt from zeros here so that re-running this cell does not
# add the same licks on top of a previous run.
lick_trace = np.zeros_like(time_ax)
for lick_time in licks:
    lick_trace[np.argmin(np.abs(time_ax - lick_time))] += 1

In [None]:
lick_trace_conv = np.convolve(lick_trace, conv_func/np.sum(conv_func), mode='same')
pl.plot(time_ax, lick_trace_conv)
pl.xlim(0, 400)

In [None]:
len(is_rwt)

In [None]:
def plot_licks(tone_onsets, licks, ax=None, positions=None, filter_ons=None, **vlines_args):
    """Raster plot of lick times aligned to tone onsets, one row per tone.

    Parameters
    ----------
    tone_onsets : sequence of tone onset times.
    licks : lick times, on the same clock as `tone_onsets`.
    ax : axes to draw on; a new figure is created when None.
    positions : row (y position) of each tone; defaults to 0, 1, 2, ...
    filter_ons : optional sequence of booleans, one per tone; tones whose flag
        is False are not drawn (their row is left empty). Defaults to all True.
    **vlines_args : forwarded to `ax.vlines` (e.g. color).

    Returns
    -------
    The axes that was drawn on.
    """
    if ax is None:
        fig, ax = pl.subplots(1, 1)
    if filter_ons is None:
        filter_ons = [True] * len(tone_onsets)
    if positions is None:
        positions = range(len(tone_onsets))
    licks = np.asarray(licks)  # allow plain lists of lick times
    for row, tone_ons, keep in zip(positions, tone_onsets, filter_ons):
        if not keep:
            continue
        aligned = licks - tone_ons
        # keep licks inside the cycle window, which runs from CYCLE_START to
        # CYCLE_START + CYCLE_DURATION relative to the tone
        in_window = (aligned > CYCLE_START) & (aligned < CYCLE_START + CYCLE_DURATION)
        ax.vlines(aligned[in_window], row, row+1, **vlines_args)
    return ax

In [None]:
# Lick raster aligned to tone onset: reward-tone trials first (blue, rows
# 0..len(true_rwt)-1), CS- trials stacked right above them (red).
fig, ax = pl.subplots(1, 1, figsize=(6, 4))
plot_licks(tone_rw_ons, licks, ax=ax, positions=range(len(true_rwt)), color='b')
plot_licks(tone_CSm_ons, licks, ax=ax, positions=len(true_rwt)+np.arange(len(true_CSm)), color='r')
# axs[0].set_title('CS+')
# plot_licks(tone_csm_ons, licks, ax=axs[1])
# axs[1].set_title('CS-')
# plot_licks(tone_ap_ons, licks, ax=axs[2])
# axs[2].set_title('AP')
# for ax in axs:
# show exactly one cycle window around the tone
ax.set_xlim(CYCLE_START, CYCLE_START+CYCLE_DURATION)
# ax.set_ylim(-10, len(cycles))
#ax.set_xlim(CYCLE_START, 15)
ax.set_ylabel('Trial #')
ax.set_xlabel("Time from tone onset (s)")
# bars drawn at y = -3 mark the CS (green) and US (magenta) periods
pt.plot_period_bar(ax, -3, delta_y=2, color='g', start_end=(CS_START, CS_END))
pt.plot_period_bar(ax, -3, delta_y=2, color='m', start_end=(US_START, US_END))

In [None]:
def compute_lick_ratios(licks, cycles, no_lick_value=-1):
    """Anticipatory licking ratio for every cycle.

    Lick times are re-expressed as `licks - cycle_start + CYCLE_START`
    (time from tone onset when each cycle starts CYCLE_START seconds relative
    to its tone, as in CONTINUOUS mode). For each cycle:

    * numerator: licks strictly between CS_START and CS_END + DELAY
      (CS plus trace period);
    * denominator: licks strictly between -(CS_DURATION + DELAY) and
      CS_END + DELAY (an equally long pre-tone window plus CS and trace).

    Parameters
    ----------
    licks : lick times (array-like).
    cycles : iterable of (start, end) pairs; only the start is used.
    no_lick_value : value stored for cycles without any lick in the
        denominator window (default -1, which the histogram cell plots in a
        separate "no licks" bin).

    Returns
    -------
    1-D array with one ratio per cycle.
    """
    licks = np.asarray(licks)  # allow plain lists of lick times
    lick_ratios = []
    for start, _end in cycles:
        aligned = licks - start + CYCLE_START
        before_us = aligned < CS_END + DELAY
        licks_during = ((aligned > CS_START) * before_us).sum()
        licks_all = ((aligned > -CS_DURATION - DELAY) * before_us).sum()
        lick_ratios.append(1.*licks_during/licks_all if licks_all > 0 else no_lick_value)
    return np.r_[lick_ratios]
lick_ratios = compute_lick_ratios(licks, cycles)

In [None]:
def plot_lick_ratios(lick_ratios, true_rwt, true_CSm, axs=None, colors=['b', 'r']):
    """Histogram of anticipatory lick ratios for reward-tone and CS- trials.

    Two panels share the y axis: the narrow left one shows the trials stored
    with a negative ratio (labelled 'no licks'), the wide right one shows the
    ratios between 0 and 1 in bins of 0.25.

    Parameters
    ----------
    lick_ratios : array of per-cycle ratios.
    true_rwt, true_CSm : indices into `lick_ratios` of the reward-tone and
        CS- trials.
    axs : pair of axes (left, right); created when None.
    colors : colours of the two groups, in the order (reward, CS-). Used for
        BOTH panels. The default list is never modified. Note that the text
        legend in the right panel always reads blue/red.

    Returns
    -------
    The pair of axes.
    """
    if axs is None:
        fig, axs = pl.subplots(1, 2, gridspec_kw={'width_ratios':(1, 5)}, sharey=True)
    groups = [lick_ratios[true_rwt], lick_ratios[true_CSm]]

    # right panel: ratios in [0, 1]
    ax = axs[1]
    y, bins, patches = ax.hist(groups, color=colors,
               bins=np.arange(0, 1.1, .25))
    ax.text(0.05, 12, 'BLUE = reward, RED = CSminus')
    ax.set_xlabel("Anticipatory licking ratio (CS+trace/baseline+CS+trace)")
    ax.set_xticks(bins)
    # ax.set_xticklabels(['no licks', 0, 0.5, 1])

    # left panel: only the [-1, -0.75) bin is shown, i.e. the "no licks" trials
    ax = axs[0]
    ax.hist(groups, color=colors,
               bins=np.arange(-1, 0, 0.25))
    ax.set_xlim(-1, -0.75)
    ax.set_xticks((-0.875,))
    ax.set_xticklabels(('no licks',))
    ax.set_ylabel('Frequency')

    return axs
plot_lick_ratios(lick_ratios, true_rwt, true_CSm, axs=None)

## THE BELOW LICK RATIOS ARE NOT READY FOR PRIME TIME - JSB

In [None]:
def compute_lick_ratios_B(licks, cycles):
    """Anticipatory licking ratio per cycle, with 0 for cycles without licks.

    Uses the same windows as `compute_lick_ratios`: licks strictly between
    CS_START and CS_END + DELAY, divided by licks strictly between
    -(CS_DURATION + DELAY) and CS_END + DELAY. Cycles with no lick in the
    larger window get 0, so they can be averaged with the others.
    """
    ratios = []
    for cycle_begin, _ in cycles:
        aligned = licks - cycle_begin + CYCLE_START
        n_anticipatory = ((aligned > CS_START)*(aligned < CS_END+DELAY)).sum()
        n_total = ((aligned > -CS_DURATION-DELAY)*(aligned < CS_END+DELAY)).sum()
        if n_total > 0:
            ratios.append(1.*n_anticipatory/n_total)
        else:
            ratios.append(0)
    return np.r_[ratios]
lick_ratios_B = compute_lick_ratios_B(licks, cycles)

In [None]:
np.mean(lick_ratios_B[is_rwt])

In [None]:
np.mean(lick_ratios_B[is_CSm])

In [None]:
#Plot CNMFe ROIs
# Left panel: Cnn image alone. Right panel: Cnn with every ROI contour overlaid.
fig, ax = pl.subplots(1, 2, figsize=(8, 8))
ax[0].set_xticks(())
ax[0].set_yticks(())
ax[0].imshow(Cnn, alpha=1, cmap=pl.cm.gray)
ax[0].set_title('avg image')
#CNMFe outputs roi coordinates in Coor with a bunch of false coordinates that make your ROIs look weird.
#This code is to get rid of these outliers.
for neuron in range(Coor.shape[0]):
    for dim in range(2):
        # a jump of more than 10 between consecutive coordinates marks an outlier;
        # the +1 selects the point right AFTER each jump, which is blanked with NaN
        jumps = np.where(np.abs(np.diff(Coor[neuron][dim])) > 10)[0] + 1
        Coor[neuron][dim][jumps] = np.nan
#find number of coordinates in smallest ROI (store this value in min_coor, capped at 1000)
min_coor = min([roi.shape[1] for roi in Coor] + [1000])
#plot Cnn, which is the correlation image and overlay the ROIs
patches = []
ax[1].imshow(Cnn, cmap=pl.cm.gray)
for neuron in range(Coor.shape[0]):
    # NOTE(review): the first 10 vertices of every contour are skipped; the reason is not evident here
    patches.append(Polygon(np.transpose(Coor[neuron][:, 10:])))
    # label each ROI with its 1-based number, anchored on one of its contour points
    ax[1].text(Coor[neuron][0, min_coor-1], Coor[neuron][1, min_coor-1], neuron+1, color='y', size=8)
# random values only serve to give each patch a different colour
colors = 100*np.random.rand(len(patches))
p = PatchCollection(patches, alpha=0.4)
p.set_array(np.array(colors))
ax[1].add_collection(p)
ax[1].set_title('CNMFe')
ax[1].yaxis.set_ticks([])
ax[1].xaxis.set_ticks([])

## df/f

In [None]:
# One dF/F trace per cell, scaled by 1/50 and shifted up by the cell index so
# the traces stack without overlapping; x axis in minutes.
pl.figure(figsize=(5, 12))
for cell in range(dff.shape[1]):
    pl.plot(time_ax/60., dff[:, cell]/50.+cell, lw=.5)
pl.xlabel("time (min)")

In [None]:
behavior

In [None]:
# One denoised trace per cell, scaled by 1/50 and shifted up by the cell index;
# the loop runs over the columns of `denoised` itself (the array being plotted).
pl.figure(figsize=(10, 25))
for cell in range(denoised.shape[1]):
    pl.plot(time_ax/60., denoised[:, cell]/50.+cell, lw=1)
pl.xlabel("time (min)")

In [None]:
licks = ut.parse_behavior(behavior, 'LICK')

In [None]:
def plot_it(cell=0, signals=0,):
    """Plot every trial of one cell, stacked top to bottom, with its licks.

    Parameters
    ----------
    cell : column of the selected signal to plot.
    signals : which signal to use: 0 = `traces`, 1 = `denoised`,
        2 = `events`, 3 = `dff_zs`.

    Reward-tone trials are drawn in blue, all others in red. Licks are drawn
    as grey ticks under each trial, aligned to tone onset.

    Raises
    ------
    ValueError : if `signals` is not one of 0, 1, 2, 3.
    """
    # code -> (array, printed label, vertical spacing between trials, lick tick length)
    choices = {
        0: (traces, 'raw signal', 20, 20),
        1: (denoised, 'denoised signal', 20, 20),
        2: (events, 'events signal', 20, 20),
        3: (dff_zs, 'z-scored raw', 3, 3),
    }
    if signals not in choices:
        raise ValueError("signals must be one of 0, 1, 2, 3; got %r" % (signals,))
    signal, label, delta_y, lick_length = choices[signals]
    print(label)

    all_dffs = ut.compute_all_dffs(time_ax, signal, cell=cell, cycles=cycles, time_ax_single=time_ax_single)

    pl.figure(figsize=(5, 10))

    for cycle in range(len(cycles)):
        pl.plot(time_ax_single, all_dffs[cycle]-cycle*delta_y, color='b' if is_rwt[cycle] else 'r')
        which_licks = ut.search_events(cycles, cycle, event_times=licks)
        # shift lick times onto the single-cycle axis, whose first sample is CYCLE_START
        pl.vlines(licks[which_licks]-cycles[cycle][0]+CYCLE_START, -cycle*delta_y, -cycle*delta_y-lick_length,
                  lw=1, color=(0.5, 0.5, 0.5), zorder=0)

    # draw the CS / US period bars at the top of the current axes
    ax = pl.gca()
    ypos = pl.axis()[-1]
    pt.plot_period_bar(ax, ypos, color='g', start_end=(CS_START, CS_END), delta_y=lick_length)
    pt.plot_period_bar(ax, ypos, color='m', start_end=(US_START, US_END), delta_y=lick_length)

interact(plot_it, cell=(0, dff.shape[1]-1, 1),signals=(0,3,1))

In [None]:
pl.figure(figsize=(10, 15))
pl.imshow(dff.T, extent=(time_ax[0]/60., time_ax[-1]/60., 0, dff.shape[1]),
          aspect='auto', interpolation='nearest', vmin=0,
          cmap=pl.cm.viridis);
pl.xlabel('Time (m)')
pl.ylabel('Cell #')

In [None]:
# Choose the signal used by the following cells:
# 0 = traces, 1 = denoised, 2 = events, 3 = dff_zs.
signals = 0

_signal_table = (
    (0, traces, 'raw signal'),
    (1, denoised, 'denoised signal'),
    (2, events, 'events signal'),
    (3, dff_zs, 'z-scored raw'),
)
for _code, _array, _label in _signal_table:
    if signals == _code:
        signal = _array
        print(_label)
        break

In [None]:
#This cell can take several minutes to complete. Plots every Ca trace for every cell (trialxtime for each cell)
# Grid shape for the per-cell summary: 4 columns and enough rows to give every
# column of dff its own panel.
ncells_x =  4
ncells_y = int(np.ceil(dff.shape[1]/4.))
# first_cell / last_cell are not referenced again within this cell
first_cell = 1
last_cell = first_cell + ncells_x*ncells_y


# Re-select `signal` from the `signals` code set in an earlier cell:
# 0 = traces, 1 = denoised, 2 = events, 3 = dff_zs.
if signals==0:
    signal = traces
    print'raw signal'
elif signals==1:
    signal = denoised
    print'denoised signal'
elif signals==2:
    signal = events
    print'events signal'
elif signals == 3:
    signal = dff_zs
    print'z-scored raw'


def plot_all_cycles(cell, ax, showlicks = False, spacing=20):
    """Plot every cycle of one cell on `ax`, stacked bottom to top.

    The traces are read from the module-level `signal`; CS- cycles are drawn
    in red, all others in blue. The CS and US periods are shaded in green and
    magenta over the full height of the stack.

    Parameters
    ----------
    cell : column of `signal` to plot.
    ax : axes to draw on.
    showlicks : currently ignored (no lick overlay is drawn).
    spacing : vertical offset added between consecutive cycles.

    Returns
    -------
    The axes that was drawn on.
    """
    offset = 0
    for cycle in range(len(cycles)):
        # align with the notebook-wide CYCLE_START rather than the helper's own default
        t, tr = extract_single_cycle(time_ax, signal, cycles, cycle, cell,
                                     cycle_start=CYCLE_START)
        ax.plot(t, tr+offset, color='red' if is_CSm[cycle] else 'blue')
        offset = offset + spacing

    ax.set_ylim([-0.5, offset])
    ax.fill_between([CS_START, CS_START+CS_DURATION],
                    -0.5, offset, color='g', alpha = 0.2)
    ax.fill_between([CS_START+CS_DURATION+DELAY,
                     CS_START+CS_DURATION+DELAY+US_DURATION],
                    -0.5, offset, color='m', alpha = 0.2)
    return ax


def plot_summary(ncells_x, ncells_y, cells,
                 rescalex=2, rescaley=1, cmap=pl.cm.rainbow,
                 cs_start_end=(0, 2), us_start_end=(4, 6), cs_color='g', us_color='m'):
    """Draw `plot_all_cycles` for several cells on a grid of shared axes.

    Parameters
    ----------
    ncells_x, ncells_y : number of columns and rows of the grid.
    cells : iterable of cell indices; cells beyond the number of panels are
        not drawn (cells and panels are paired with `zip`).
    rescalex, rescaley : width and height, in inches, of each panel.
    cmap, cs_start_end, us_start_end, cs_color, us_color : accepted for
        compatibility but not used by this function.

    Returns
    -------
    (fig, ax) : the figure and the last axes drawn on (None if `cells` is empty).
    """
    # squeeze=False always returns a 2-D array of axes, so a 1x1 grid also works
    fig, axs = pl.subplots(ncells_y, ncells_x, sharex=True, sharey=True,
                           figsize=(ncells_x*rescalex, ncells_y*rescaley),
                           squeeze=False)

    ax = None
    for cell, ax in zip(cells, axs.flatten()):
        ax = plot_all_cycles(cell, ax)
        # label the panel with the 1-based cell number
        ax.text(CYCLE_DURATION + CYCLE_START, len(cycles)-2, cell+1)

    return fig, ax


# interact(plot_all_cycles, cell=(0, dff.shape[1]-1, 1));

# use the grid computed at the top of this cell so that every cell gets a panel
plot_summary(ncells_x, ncells_y, range(dff.shape[1]),
                 rescalex=5, rescaley=5, cmap=pl.cm.rainbow,
                 cs_start_end=(0, 2), us_start_end=(4, 6), cs_color='g', us_color='m')

In [None]:
def plot_single_cycle(cell=0, cycle=0, signals=0):
    """Plot one cycle of one cell, with markers for the CS and US periods.

    Parameters
    ----------
    cell : column of the selected signal to plot.
    cycle : index of the cycle to plot.
    signals : which signal to use: 0 = `traces`, 1 = `denoised`,
        2 = `events`, 3 = `dff_zs`.

    The trace is red for CS- cycles and blue otherwise.

    Raises
    ------
    ValueError : if `signals` is not one of 0, 1, 2, 3.
    """
    # code -> (array, printed label)
    choices = {
        0: (traces, 'raw signal'),
        1: (denoised, 'denoised signal'),
        2: (events, 'events signal'),
        3: (dff_zs, 'z-scored raw'),
    }
    if signals not in choices:
        raise ValueError("signals must be one of 0, 1, 2, 3; got %r" % (signals,))
    signal, label = choices[signals]
    print(label)

    # align with the notebook-wide CYCLE_START rather than the helper's own default
    t, tr = extract_single_cycle(time_ax, signal, cycles, cycle, cell,
                                 cycle_start=CYCLE_START)
    pl.plot(t, tr, color='red' if is_CSm[cycle] else 'blue')
    pl.ylim(-5, 70)
    # CS period: green bar plus a line at its onset
    pl.fill_between([CS_START, CS_START+CS_DURATION],
                    -0.35, -0.25, color='g')
    pl.vlines(CS_START, -0.35, 2, color='g')
    # US period: magenta bar plus a line at its onset
    pl.fill_between([CS_START+CS_DURATION+DELAY,
                     CS_START+CS_DURATION+DELAY+US_DURATION],
                    -0.35, -0.25, color='m')
    pl.vlines(CS_START+CS_DURATION+DELAY, -0.35, 2, color='m')
interact(plot_single_cycle, cycle=(0, max_cycles-1, 1), cell=(0, dff.shape[1]-1, 1),signals=(0,3,1))
# plot_single_cycle(3, 10)

In [None]:
def combine_cycles(time_ax, dff, cycles, cell, lim_len=-1):
    """Stack the per-cycle segments of one column of `dff`, one row per cycle.

    Parameters
    ----------
    time_ax, cycles : forwarded to `ut.filter_cycle` to select each cycle.
    dff : 2-D array indexed as [frame, cell].
    cell : the column actually read is ``cell - 1``, so ``cell=1`` reads the
        first column and ``cell=0`` wraps around to the LAST column.
        NOTE(review): the callers visible in this notebook pass 0-based
        indices -- confirm which convention is intended.
    lim_len : each segment is cut with ``[:lim_len]``; the default -1 drops
        the last sample of every segment.

    Returns
    -------
    The segments combined with ``np.r_``.
    """
    segments = []
    for cycle in range(len(cycles)):
        in_cycle = ut.filter_cycle(time_ax, cycles, cycle)
        segments.append(dff[:, cell-1][in_cycle][:lim_len])
    return np.r_[segments]

In [None]:
# NOTE(review): `cell` and `signal` are not assigned in this cell; it relies on
# whatever values earlier cells left in the workspace. `all_dffs` is recomputed
# for every cell by the sorting loop that follows.
all_dffs = combine_cycles(time_ax, signal, cycles, cell, lim_len=len(time_ax_single))

In [None]:
# For every cell, order the trials by the time (on the single-cycle axis) at
# which the trial reaches its maximum; sorts[cell] is that trial ordering.
# Uses whichever `signal` is currently selected in the workspace.
sorts = []
n_samples = len(time_ax_single)
for cell in range(dff.shape[1]):
    all_dffs = combine_cycles(time_ax, signal, cycles, cell, lim_len=n_samples)
    peak_times = []
    for cycle in range(max_cycles):
        peak_times.append(time_ax_single[np.argmax(all_dffs[cycle])])
    sorts.append(np.argsort(peak_times))

In [None]:
def plot_me(cell=0, signals=0, sort=False, reward=True):
    """Heat map (trial x time) of one cell for reward-tone or CS- trials.

    Parameters
    ----------
    cell : cell index, forwarded to `combine_cycles` and used to pick `sorts[cell]`.
    signals : which signal to use: 0 = `traces`, 1 = `denoised`,
        2 = `events`, 3 = `dff_zs`.
    sort : if True, trials are ordered with `sorts[cell]` before selection.
    reward : True shows reward-tone trials (`is_rwt`), False shows CS- trials
        (`is_CSm`).

    If the requested trials cannot be selected (IndexError), a message is
    printed and nothing is drawn.

    Raises
    ------
    ValueError : if `signals` is not one of 0, 1, 2, 3.
    """
    # code -> (array, printed label)
    choices = {
        0: (traces, 'raw signal'),
        1: (denoised, 'denoised signal'),
        2: (events, 'events signal'),
        3: (dff_zs, 'z-scored raw'),
    }
    if signals not in choices:
        raise ValueError("signals must be one of 0, 1, 2, 3; got %r" % (signals,))
    signal, label = choices[signals]
    print(label)

    flags = np.r_[is_rwt if reward else is_CSm]
    try:
        all_dffs = combine_cycles(time_ax, signal, cycles, cell, lim_len=len(time_ax_single))
        if sort:
            all_dffs = all_dffs[sorts[cell]]
            which_ones = np.where(flags[sorts[cell]])
        else:
            if cycle_subtract != 0:
                # the flags cover all cycles; drop the ones removed from `cycles`
                flags = flags[:cycle_subtract]
            which_ones = np.where(flags)
        selected = all_dffs[which_ones]
    except IndexError:
        print("There are no such trials.")
        return

    pl.imshow(selected, extent=(CYCLE_START, CYCLE_START+CYCLE_DURATION, 0.5, max_cycles+0.5),
              origin='lower', cmap=pl.cm.hot, aspect='auto', interpolation='nearest', vmin=-5, vmax=50)
    pl.xlabel('Time (s)')
    pl.ylabel('Trial #')
    # CS (green) and US (magenta): bar above the image plus a line at onset
    pl.fill_between([CS_START, CS_START+CS_DURATION],
                    max_cycles, max_cycles+1, color='g')
    pl.vlines(CS_START, 0, max_cycles, color='g')
    pl.fill_between([CS_START+CS_DURATION+DELAY,
                     CS_START+CS_DURATION+DELAY+US_DURATION],
                    max_cycles, max_cycles+1, color='m')
    pl.vlines(CS_START+CS_DURATION+DELAY, 0, max_cycles, color='m')
    pl.plot([CYCLE_START, CYCLE_START+CYCLE_DURATION], [0.5, max_cycles+0.5], 'k--')
    pl.ylim(0.5, max_cycles+1.)
# `sorts` has one entry per column of dff, so the last valid cell is dff.shape[1]-1
interact(plot_me, cell=(0, dff.shape[1]-1, 1), signals=(0,3,1), reward=True, sort=True)
# plot_me(18, reward=True)

In [None]:
def plot_me(ax, cell=0, sort=False, reward=True):
    """Heat map (trial x time) of the dF/F of one cell, drawn on `ax`.

    Only reward-tone trials (`reward=True`, from `is_rwt`) or CS- trials
    (`reward=False`, from `is_CSm`) are shown.

    Parameters
    ----------
    ax : axes to draw on.
    cell : cell index, forwarded to `combine_cycles`.
    sort : if True, the selected trials are ordered by their rank in
        `sorts[cell]` (which orders ALL trials).
    reward : which trial type to show.
    """
    flags = is_rwt if reward else is_CSm
    if cycle_subtract != 0:
        # the flags cover all cycles; drop the ones removed from `cycles`
        flags = flags[:cycle_subtract]
    which_ones = np.where(flags)
    try:
        all_dffs = combine_cycles(time_ax, dff, cycles[which_ones], cell, lim_len=len(time_ax_single))
        if sort:
            # rank[t] is the position of trial t in the full ordering; sorting
            # the ranks of the selected trials orders the rows of the subset
            rank = np.argsort(sorts[cell])
            all_dffs = all_dffs[np.argsort(rank[which_ones])]
    except IndexError:
        print("There are no such trials.")
        all_dffs = [[]]

    ax.imshow(all_dffs, extent=(CYCLE_START, CYCLE_START+CYCLE_DURATION, 0.5, max_cycles+0.5),
              origin='lower', cmap=pl.cm.hot, aspect='auto', interpolation='nearest', vmin=-5, vmax=50)
    # CS (green) and US (magenta): bar above the image plus a line at onset
    ax.fill_between([CS_START, CS_START+CS_DURATION],
                    max_cycles, max_cycles+1, color='g')
    ax.vlines(CS_START, 0, max_cycles, color='g')
    ax.fill_between([CS_START+CS_DURATION+DELAY,
                     CS_START+CS_DURATION+DELAY+US_DURATION],
                    max_cycles, max_cycles+1, color='m')
    ax.vlines(CS_START+CS_DURATION+DELAY, 0, max_cycles, color='m')
    ax.plot([CYCLE_START, CYCLE_START+CYCLE_DURATION], [0.5, max_cycles+0.5], 'k--')
    ax.set_ylim(0.5, max_cycles+1.)

# one panel per cell, 7 panels per row; panels beyond the last cell are hidden
n_cells = dff.shape[1]
n_rows = int(np.ceil(n_cells/7.))
fig, axs = pl.subplots(n_rows, 7, figsize=(7, n_rows), sharex=True, sharey=True)
for cell, ax in enumerate(axs.flatten()):
    if cell < n_cells:
        plot_me(ax, cell)
    else:
        ax.axis('off')

print("note: cell 0 below appears to correspond to the last cell in the mean/SD trace figure below. ie, cell0 here does not = cell1 below (cell1 here = cell1 below)")

In [None]:
print("note: cell 0 below appears to correspond to the last cell in the mean/SD trace figure below. ie, cell0 here does not = cell1 below (cell1 here = cell1 below)")
def plot_me(ax, cell=0, sort=False, reward=True, signals=2):
    """Heat map (trial x time) of one cell for the chosen signal, drawn on `ax`.

    Parameters
    ----------
    ax : axes to draw on.
    cell : cell index, forwarded to `combine_cycles` and used to pick `sorts[cell]`.
    sort : if True, trials are ordered with `sorts[cell]` before selection.
    reward : True shows reward-tone trials (`is_rwt`), False shows CS- trials
        (`is_CSm`).
    signals : which signal to use: 0 = `traces`, 1 = `denoised`,
        2 = `events` (default), 3 = `dff_zs`.

    If the requested trials cannot be selected (IndexError), a message is
    printed and the panel is left empty.

    Raises
    ------
    ValueError : if `signals` is not one of 0, 1, 2, 3.
    """
    # code -> (array, printed label)
    choices = {
        0: (traces, 'raw signal'),
        1: (denoised, 'denoised signal'),
        2: (events, 'events'),
        3: (dff_zs, 'z-scored raw'),
    }
    if signals not in choices:
        raise ValueError("signals must be one of 0, 1, 2, 3; got %r" % (signals,))
    signal, label = choices[signals]
    print(label)

    flags = np.r_[is_rwt if reward else is_CSm]
    try:
        all_dffs = combine_cycles(time_ax, signal, cycles, cell, lim_len=len(time_ax_single))
        if sort:
            all_dffs = all_dffs[sorts[cell]]
            which_ones = np.where(flags[sorts[cell]])
        else:
            if cycle_subtract != 0:
                # the flags cover all cycles; drop the ones removed from `cycles`
                flags = flags[:cycle_subtract]
            which_ones = np.where(flags)
        selected = all_dffs[which_ones]
    except IndexError:
        print("There are no such trials.")
        return

    ax.imshow(selected, extent=(CYCLE_START, CYCLE_START+CYCLE_DURATION, 0.5, max_cycles+0.5),
              origin='lower', cmap=pl.cm.hot, aspect='auto', interpolation='nearest', vmin=-5, vmax=50)
    # CS (green) and US (magenta): bar above the image plus a line at onset
    ax.fill_between([CS_START, CS_START+CS_DURATION],
                    max_cycles, max_cycles+1, color='g')
    ax.vlines(CS_START, 0, max_cycles, color='g')
    ax.fill_between([CS_START+CS_DURATION+DELAY,
                     CS_START+CS_DURATION+DELAY+US_DURATION],
                    max_cycles, max_cycles+1, color='m')
    ax.vlines(CS_START+CS_DURATION+DELAY, 0, max_cycles, color='m')
    ax.plot([CYCLE_START, CYCLE_START+CYCLE_DURATION], [0.5, max_cycles+0.5], 'k--')
    ax.set_ylim(0.5, max_cycles+1.)

# one panel per cell, 7 panels per row; panels beyond the last cell are hidden
n_cells = dff.shape[1]
n_rows = int(np.ceil(n_cells/7.))
fig, axs = pl.subplots(n_rows, 7, figsize=(7, n_rows), sharex=True, sharey=True)
for cell, ax in enumerate(axs.flatten()):
    if cell < n_cells:
        plot_me(ax, cell)
    else:
        ax.axis('off')

In [None]:
def plot_em(traces_means, ax, us_color=None):
    """Show the mean traces of all cells as a grey-scale image (cell x time).

    Parameters
    ----------
    traces_means : 2-D array; its transpose is displayed, spanning the
        single-cycle time axis horizontally and the cells vertically.
    ax : axes to draw on.
    us_color : colour of the bar marking the US period. When None, the
        module-level variable `which` is used, which is how existing callers
        select the colour.

    A green bar above the image marks the CS period; the US bar follows.
    """
    if us_color is None:
        us_color = which  # backward compatible: colour set by the caller via the global
    n_cells = dff.shape[1]
    bar_top = (n_cells+1., n_cells+1.)
    bar_bottom = (n_cells+0.5, n_cells+0.5)
    ax.imshow(traces_means.T, extent=(time_ax_single[0], time_ax_single[-1], 0.5, n_cells+.5),
              origin='lower', cmap=pl.cm.gray_r, aspect='auto', interpolation='nearest')
    ax.fill_between([CS_START, CS_START+CS_DURATION],
                    bar_top, bar_bottom,
                    color='g')
    ax.fill_between([CS_START+CS_DURATION+DELAY,
                     CS_START+CS_DURATION+DELAY+US_DURATION],
                    bar_top, bar_bottom,
                    color=us_color)

In [None]:
# Mean responses
# Five stacked panels of mean traces: all trials, CS-, reward-tone, rewarded
# and not-rewarded trials. `plot_em` reads the module-level `which` for the
# colour of the US bar, so `which` is set before each call.
fig, axs = pl.subplots(5, 1, figsize=(6, 25), sharex=True, sharey=True)

# 0 = traces, 1 = denoised, 2 = events, 3 = dff_zs
signals=3
if signals==0:
    signal = traces
    print'raw signal'
elif signals==1:
    signal = denoised
    print'denoised signal'
elif signals==2:
    signal = events
    print'events'
elif signals == 3:
    signal = dff_zs
    print'z-scored raw'

traces_means, traces_std = ut.compute_mean_traces(time_ax, signal, cycles, time_ax_single)
which='m'
plot_em(traces_means, axs[0])
axs[0].set_title('All trials')

# Subsets of `cycles` per trial type; the per-cycle flags cover all cycles, so
# they are cut with [:cycle_subtract] when trailing cycles were dropped.
if cycle_subtract !=0:
    suba = cycles[np.where((is_CSm)[:cycle_subtract])]
    subr = cycles[np.where((is_rwt)[:cycle_subtract])]
    subrd = cycles[np.where((is_rewarded)[:cycle_subtract])]
    subnr = cycles[np.where((is_not_rewarded)[:cycle_subtract])]
else:
    suba = cycles[np.where(is_CSm)]
    subr = cycles[np.where(is_rwt)]
    subrd = cycles[np.where(is_rewarded)]
    subnr = cycles[np.where(is_not_rewarded)]

traces_means, traces_std = ut.compute_mean_traces(time_ax, signal, suba, time_ax_single)
which='r'
plot_em(traces_means, axs[1])
axs[1].set_title('CSminus')
traces_means, traces_std = ut.compute_mean_traces(time_ax, signal, subr, time_ax_single)
which='c'
plot_em(traces_means, axs[2])
axs[2].set_title('Reward')
traces_means, traces_std = ut.compute_mean_traces(time_ax, signal, subrd, time_ax_single)
which='b'
plot_em(traces_means, axs[3])
axs[3].set_title('Rewarded')
traces_means, traces_std = ut.compute_mean_traces(time_ax, signal, subnr, time_ax_single)
# same bar colour as the 'Rewarded' panel
which='b'
plot_em(traces_means, axs[4])
axs[4].set_title('Not Rewarded')
axs[4].set_xlabel('Time (s)')
axs[4].set_ylabel('Cell #')

In [None]:
def plot_summary(time_ax_single, traces_means, traces_std, ncells_x, ncells_y, cells,
                 rescalex=2, rescaley=1, cmap=pl.cm.rainbow,
                 cs_start_end=(0, 2), us_start_end=(4, 6), cs_color='r', us_color='g',
                 ylim=(-0.15, 0.5), bary=(-0.1, 0.02), textxy=(0, 0.3)):
    """Plot the mean trace of each requested cell on its own panel of a grid.

    Parameters
    ----------
    time_ax_single : sequence
        Time axis of one trial; its first and last entries set the x ticks
        and x limits.
    traces_means, traces_std :
        Forwarded unchanged to ``pt.plot_mean``. ``traces_means.shape[1]``
        sets how many colours are sampled from `cmap`.
    ncells_x, ncells_y : int
        Number of panel columns and rows.
    cells : iterable of int
        1-based cell numbers; ``cell - 1`` is the index given to
        ``pt.plot_mean``. Plotting stops at whichever runs out first: the
        cells, the panels, or the colours.
    rescalex, rescaley : number
        Figure width / height per panel column / row.
    cmap : callable
        Colormap sampled at evenly spaced points in [0, 1].
    cs_start_end, us_start_end : tuple
        Passed as ``start_end`` to ``pt.plot_period_bar`` for the two bars.
    cs_color, us_color :
        Colours of the two period bars.
    ylim : tuple
        y limits; also used as the y ticks and as the extent of the vertical
        line drawn at x = 0.
    bary : tuple
        Two values passed positionally to ``pt.plot_period_bar``
        (presumably bar position and height -- confirm against ``pt``).
    textxy : tuple
        Data coordinates of the cell-number label.

    Returns
    -------
    (fig, ax)
        The figure and the last axes drawn on, or the first panel when
        nothing was drawn.
    """
    cells = list(cells)
    colors = cmap(np.linspace(0, 1, traces_means.shape[1]))
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, where the
    # default would return a bare Axes object with no .flatten().
    fig, axs = pl.subplots(ncells_y, ncells_x, sharex=True, sharey=True, squeeze=False,
                           figsize=(ncells_x*rescalex, ncells_y*rescaley))
    panels = axs.flatten()
    # Default target for the shared tick/limit settings below, so an empty
    # `cells` (or an exhausted colour list) no longer leaves `ax` undefined.
    ax = panels[0]
    first_color = cells[0]-1 if cells else 0
    for cell, ax, col in zip(cells, panels, colors[first_color:]):
        ax = pt.plot_mean(time_ax_single, traces_means, traces_std, cell-1, ax=ax, color=col)
        pt.plot_period_bar(ax, bary[0], bary[1], color=cs_color, start_end=cs_start_end)
        pt.plot_period_bar(ax, bary[0], bary[1], color=us_color, start_end=us_start_end)
        ax.text(textxy[0], textxy[1], cell)
        ax.vlines([0], [ylim[0]], [ylim[1]], lw=1, zorder=0)
    # Axes are shared, so configuring one panel configures the whole grid.
    ax.set_xticks((time_ax_single[0], time_ax_single[-1]))
    ax.set_yticks(ylim)
    ax.set_xlim((time_ax_single[0], time_ax_single[-1]))
    ax.set_ylim(ylim)

    return fig, ax


In [None]:
# Per-cell mean traces for a chosen class of reward trials.
signals = 1    # 0 = raw, 1 = denoised, 2 = events, 3 = z-scored raw
rewarded = 0   # 0 = all reward trials, 1 = rewarded, 2 = unrewarded reward trials

if signals == 0:
    signal = traces
    print('raw signal')
elif signals == 1:
    signal = denoised
    print('denoised signal')
elif signals == 2:
    signal = events
    print('events')
elif signals == 3:
    signal = dff_zs
    print('z-scored raw')

# Pick the trial mask and its label; any other value leaves `rewar` untouched.
if rewarded == 0:
    trial_mask, trial_label = is_rwt, 'all reward trials'
elif rewarded == 1:
    trial_mask, trial_label = is_rewarded, 'all rewarded trials'
elif rewarded == 2:
    trial_mask, trial_label = is_not_rewarded, 'all unrewarded reward trials'
else:
    trial_mask = None

if trial_mask is not None:
    if cycle_subtract != 0:
        # Only the first `cycle_subtract` mask entries are considered.
        trial_mask = trial_mask[:cycle_subtract]
    rewar = cycles[np.where(trial_mask)]
    print(trial_label)

# Seven columns, with as many rows as needed to hold every column of dff.
ncells_x = 7
ncells_y = int(np.ceil(dff.shape[1] / 7.))
first_cell = 1
last_cell = first_cell + ncells_x * ncells_y
traces_means, traces_std = ut.compute_mean_traces(time_ax, signal, rewar, time_ax_single)
plot_summary(time_ax_single, traces_means, traces_std, ncells_x, ncells_y,
             cells=range(first_cell, last_cell),
             cs_start_end=(CS_START, CS_END), us_start_end=(US_START, US_END),
             cs_color='g', us_color='b', ylim=(-5, 20), bary=(15, 2), textxy=(-8, 15))

In [None]:
# Inspect the dimensions of the cycles array.
cycles.shape

In [None]:
# Shape of the reward-trial cycles with the last cycle dropped.
# The [:-1] must be applied to the mask: np.where() returns a tuple holding a
# single index array, and slicing that tuple leaves an empty index that
# selects every row instead of only the reward trials.
cycles[:-1][np.where(is_rwt[:-1])].shape

In [None]:
# Per-cell mean traces over the CS-minus trials.
signals = 1    # 0 = raw, 1 = denoised, 2 = events, 3 = z-scored raw

if signals == 0:
    signal = traces
    print('raw signal')
elif signals == 1:
    signal = denoised
    print('denoised signal')
elif signals == 2:
    signal = events
    print('events')
elif signals == 3:
    signal = dff_zs
    print('z-scored raw')

# Seven columns, with as many rows as needed to hold every column of dff.
ncells_x = 7
ncells_y = int(np.ceil(dff.shape[1] / 7.))
first_cell = 1
last_cell = first_cell + ncells_x * ncells_y
traces_means, traces_std = ut.compute_mean_traces(
    time_ax, signal, cycles[np.where(is_CSm)], time_ax_single)
plot_summary(time_ax_single, traces_means, traces_std, ncells_x, ncells_y,
             cells=range(first_cell, last_cell),
             cs_start_end=(CS_START, CS_END), us_start_end=(US_START, US_END),
             cs_color='g', us_color='r', ylim=(-5, 20), bary=(15, 2), textxy=(-8, 15))

In [None]:
# For every rewarded cycle (start s, end e): take the first entry of `rewards`
# that is greater than s, express it relative to s, and add CYCLE_START.
first_reward_times = np.array([
    rewards[np.where((rewards - s) > 0)[0][0]] - s + CYCLE_START
    for s, e in cycles[np.where(is_rewarded)]
])

In [None]:
# Display the computed first-reward times (one per rewarded cycle).
first_reward_times

In [None]:
# Shift applied to each rewarded trial to align it to its first reward.
# Alternative kept for reference (subtracts CS_DURATION and DELAY):
# first_reward_times_adjustment = first_reward_times-CS_DURATION-DELAY
first_reward_times_adjustment = first_reward_times
# Sanity check: the histogram should recapitulate the figure above.
# The trailing semicolon suppresses the return-value repr.
pl.hist(first_reward_times_adjustment, bins=np.arange(3.9, 6, 0.25));

In [None]:
# One bar per rewarded trial showing its alignment shift.
# x is the index of the rewarded trial and y is the shift itself, so the
# axis labels are assigned accordingly.
pl.bar(range(len(first_reward_times_adjustment)), np.r_[first_reward_times_adjustment])
pl.xlabel('Rewarded trial #')
pl.ylabel('Time adjustment (s)')
pl.ylim(4, 6)

In [None]:
# Add each rewarded trial's shift to both columns of its cycle row
# ([:, None] turns the 1-D shifts into a column so it broadcasts per row).
cycles_shifted_first_reward = cycles[np.where(is_rewarded)]+first_reward_times_adjustment[:, None]

In [None]:
# First ten rewarded cycles, for comparison with their shifted version.
# Selected through np.where(), exactly as cycles_shifted_first_reward was
# built, so the rows correspond even if is_rewarded is a plain list of bools.
cycles[np.where(is_rewarded)][:10]

In [None]:
# First ten reward-aligned (shifted) cycles.
cycles_shifted_first_reward[:10]

In [None]:
# Smoke test: extract the first shifted cycle (index 0) for the first cell (index 0).
t, tr = extract_single_cycle(time_ax, dff, cycles_shifted_first_reward, 0, 0)

In [None]:
# Display max_cycles, which the trial heat maps below use as their vertical extent.
max_cycles

In [None]:
def plot_single_cycle(cell=0, cycle=0, signals=0):
    """Plot one cell's trace for one reward-aligned rewarded trial.

    Parameters
    ----------
    cell : int
        Cell index passed to ``extract_single_cycle``.
    cycle : int
        Row of ``cycles_shifted_first_reward`` to plot.
    signals : int
        0 = raw traces, 1 = denoised, 2 = events, 3 = z-scored raw.
    """
    if signals == 0:
        signal = traces
        print('raw signal')
    elif signals == 1:
        signal = denoised
        print('denoised signal')
    elif signals == 2:
        signal = events
        print('events')
    elif signals == 3:
        signal = dff_zs
        print('z-scored raw')

    t, tr = extract_single_cycle(time_ax, signal, cycles_shifted_first_reward, cycle, cell)
    pl.plot(t, tr, color='k')
    pl.ylim(-5, 50)
    pl.vlines(CS_START+CS_DURATION+DELAY, -5, 50, color='b')
    # The title is set here so it lands on the figure redrawn by the widget;
    # calling pl.title() after interact() would title a new, empty figure.
    pl.title('Rewarded trials shifted for reward onset')

# Slider bounds are inclusive, so the last valid cycle index is length - 1.
interact(plot_single_cycle,
         cycle=(0, len(cycles_shifted_first_reward)-1, 1),
         cell=(0, dff.shape[1]-1, 1),
         signals=(0, 3, 1))

In [None]:
# Number of samples kept per trial. Adjust it to the truncated length of the
# plotted signal so that no blanks appear (the reward alignment shifts the
# signals, which shortens the usable window).
lim_len = 101

signals = 0    # 0 = raw, 1 = denoised, 2 = events, 3 = z-scored raw

if signals == 0:
    signal = traces
    print('raw signal')
elif signals == 1:
    signal = denoised
    print('denoised signal')
elif signals == 2:
    signal = events
    print('events')
elif signals == 3:
    signal = dff_zs
    print('z-scored raw')


def plot_me(ax, cycles, cell=0, sort=False):
    """Draw the trials-by-time heat map of one cell on `ax`.

    The matrix comes from ``combine_cycles`` for the given `cycles` and
    `cell`; with `sort` its rows are reordered by ``sorts[cell]``. An
    IndexError while building it is reported and an empty matrix is shown.
    Tone markers are left out because tone onset is not uniform once the
    traces are shifted to reward onset.
    """
    try:
        trial_matrix = combine_cycles(time_ax, signal, cycles, cell, lim_len=lim_len)
        if sort:
            trial_matrix = trial_matrix[sorts[cell]]
    except IndexError:
        print("There are no such trials.")
        trial_matrix = [[]]

    top_edge = max_cycles + 0.5
    ax.imshow(trial_matrix,
              extent=(time_ax_single[0], time_ax_single[lim_len], 0.5, top_edge),
              origin='lower', cmap=pl.cm.hot, aspect='auto',
              interpolation='nearest', vmin=-5, vmax=50)
    # Blue line at t = 0 and a dashed diagonal across the cycle window.
    ax.vlines(0, 0, max_cycles, color='b')
    ax.plot([CYCLE_START, CYCLE_START + CYCLE_DURATION], [0.5, top_edge], 'k--')
    ax.set_ylim(0.5, max_cycles + 1.)


# One panel per cell, seven panels per row.
n_panel_rows = int(np.ceil(dff.shape[1] / 7.))
fig, axs = pl.subplots(n_panel_rows, 7, figsize=(7, n_panel_rows),
                       sharex=True, sharey=True)
for cell, ax in enumerate(axs.flatten()):
    plot_me(ax, cycles_shifted_first_reward, cell)
pl.suptitle('Aligned to REWARDED LICK onset')

In [None]:
# Rewarded trials aligned to tone onset, one panel per cell. Reuses the
# `signals` choice and the `plot_me(ax, cycles, cell)` helper defined in the
# previous cell.
if signals == 0:
    signal = traces
    print('raw signal')
elif signals == 1:
    signal = denoised
    print('denoised signal')
elif signals == 2:
    signal = events
    print('events')
elif signals == 3:
    signal = dff_zs
    print('z-scored raw')

n_panel_rows = int(np.ceil(dff.shape[1]/7.))
fig, axs = pl.subplots(n_panel_rows, 7, figsize=(7, n_panel_rows), sharex=True, sharey=True)
for cell, ax in enumerate(axs.flatten()):
    # np.where() keeps the selection consistent with how the shifted cycles
    # were built, whether is_rewarded is a list or an array.
    plot_me(ax, cycles[np.where(is_rewarded)], cell)
    # Green line at t = 0 on every panel of this figure. The original drew
    # it once, before the figure existed, on the stale `ax` left over from
    # the previous cell.
    ax.vlines(0, 0, max_cycles, color='g')
pl.suptitle('Rewarded trials aligned to TONE onset')
    

In [None]:
# Per-cell mean traces for rewarded trials aligned to reward onset.
print('Rewarded trials shifted for reward onset')
signals = 0    # 0 = raw, 1 = denoised, 2 = events, 3 = z-scored raw

if signals == 0:
    signal = traces
    print('raw signal')
elif signals == 1:
    signal = denoised
    print('denoised signal')
elif signals == 2:
    signal = events
    print('events')
elif signals == 3:
    signal = dff_zs
    print('z-scored raw')

# NOTE(review): with a non-zero cycle_subtract this falls back to the
# *unshifted* is_rwt cycles, which does not match the heading printed above.
# Confirm whether a truncated cycles_shifted_first_reward was intended.
if cycle_subtract != 0:
    rewar = cycles[np.where(is_rwt[:cycle_subtract])]
else:
    rewar = cycles_shifted_first_reward

# Seven columns, with as many rows as needed to hold every column of dff.
ncells_x = 7
ncells_y = int(np.ceil(dff.shape[1] / 7.))
first_cell = 1
last_cell = first_cell + ncells_x * ncells_y
traces_means, traces_std = ut.compute_mean_traces(time_ax, signal, rewar, time_ax_single)
# cs_color=[0]*4 is an all-zero RGBA colour for the CS bar.
plot_summary(
    time_ax_single, traces_means, traces_std, ncells_x, ncells_y,
    cells=range(first_cell, last_cell),
    cs_start_end=(CS_START, CS_END), us_start_end=(US_START, US_END),
    cs_color=[0]*4, us_color='b', ylim=(-5, 20), bary=(15, 2), textxy=(-8, 15))

In [None]:
# Pre-allocate the per-trial response array, indexed as
# [cycle, column of dff, sample of time_ax_single].
all_dffs = np.zeros((len(cycles), dff.shape[1], len(time_ax_single)))

In [None]:
# Check the dimensions of the freshly allocated array.
all_dffs.shape

In [None]:
# Are we looking at raw traces, denoised traces, events or z-scored traces?
signals = 3

# Choose the source array once. Note that option 0 reads `dff` here.
if signals == 0:
    source = dff
elif signals == 1:
    source = denoised
elif signals == 2:
    source = events
elif signals == 3:
    source = dff_zs
else:
    source = None   # unknown selector: all_dffs is left untouched

if source is not None:
    n_samples = len(time_ax_single)
    # Fill all_dffs[cycle][cell] with the single-cycle trace, truncated to
    # the length of the single-trial time axis.
    for i, s in enumerate(cycles[:, 0]):
        for cell in range(dff.shape[1]):
            t, tr = extract_single_cycle(time_ax, source, cycles, i, cell)
            all_dffs[i][cell] = tr[:n_samples]
    

In [None]:
def plot_me(cycle=0):
    """Show the cells-by-time heat map of one cycle with its summed activity.

    `cycle` indexes the first axis of ``all_dffs``. A blue dot is added when
    ``is_rewarded[cycle]`` is true. The green and magenta bars and lines mark
    the CS and US periods, and the white curve on the twin axis is the sum
    of ``all_dffs[cycle]`` over its first axis.
    """
    n_cells = dff.shape[1]
    us_onset = CS_START+CS_DURATION+DELAY
    pl.imshow(all_dffs[cycle], aspect='auto', interpolation='nearest', cmap=pl.cm.hot, vmin=0,
              extent=(time_ax_single[0], time_ax_single[-1], 0, n_cells))
    if is_rewarded[cycle]:
        pl.plot([-8], [40], 'bo', ms=13)
    pl.ylabel("Cell #")
    pl.fill_between([CS_START, CS_START+CS_DURATION], n_cells, n_cells+1, color='g')
    pl.vlines(CS_START, 0, n_cells, color='g')
    pl.fill_between([us_onset, us_onset+US_DURATION], n_cells, n_cells+1, color='m')
    pl.vlines(us_onset, 0, n_cells, color='m')
    pl.xlabel("Time from tone onset (s)")
    pl.twinx()
    pl.plot(time_ax_single, all_dffs[cycle].sum(0), color='w')
    pl.ylabel("Population activity level")

# Slider bounds are inclusive: the last valid cycle index is len(cycles) - 1.
interact(plot_me, cycle=(0, len(cycles)-1, 1))
if signals == 0:
    print('raw signal')
elif signals == 1:
    print('denoised signal')
elif signals == 2:
    print('events')
elif signals == 3:
    print('z-scored raw')

In [None]:
# Static overview: one small panel per cycle.
if signals == 0:
    print('raw signal')
elif signals == 1:
    print('denoised signal')
elif signals == 2:
    print('events')
elif signals == 3:
    print('z-scored raw')

def plot_me(ax, cycle=0, sharey=True):
    """Draw the heat map of one cycle on `ax` plus its summed activity.

    `cycle` indexes the first axis of ``all_dffs``. `sharey` is accepted but
    not used. Period bars are drawn at y = 58, and the summed activity goes
    on a twin y axis limited to (-50, 200).
    """
    ax.imshow(all_dffs[cycle], aspect='auto', interpolation='nearest', cmap=pl.cm.hot, vmin=0,
              extent=(time_ax_single[0], time_ax_single[-1], 0, dff.shape[1]))
    if is_rewarded[cycle]:
        ax.plot([-7], [40], 'bo', ms=7)
    pt.plot_period_bar(ax, 58, color='g', start_end=(CS_START, CS_END), delta_y=2)
    pt.plot_period_bar(ax, 58, color='m', start_end=(US_START, US_END), delta_y=2)
    ax.set_ylim(0, dff.shape[1]+2)
    ax = ax.twinx()
    ax.set_ylim(-50, 200)
    ax.plot(time_ax_single, all_dffs[cycle].sum(0), color='w')

fig, axs = pl.subplots(int(np.ceil(cycles.shape[0]/5.)), 5, figsize=(8, 20))

# The grid is rounded up to a multiple of five panels; stop at the last real
# cycle instead of indexing all_dffs past its end.
for i, ax in zip(range(len(cycles)), axs.flatten()):
    plot_me(ax, i)

In [None]:
# Split the per-trial array by trial type:
# all1 = reward trials, all2 = CS-minus, all3 = rewarded, all4 = not rewarded.
# The [:cycle_subtract] truncation is applied to the mask itself, as in the
# other cells. Slicing the tuple returned by np.where() never truncates the
# trials and can leave an empty index that selects every row.
trial_masks = [np.asarray(mask) for mask in (is_rwt, is_CSm, is_rewarded, is_not_rewarded)]
if cycle_subtract != 0:
    trial_masks = [mask[:cycle_subtract] for mask in trial_masks]
all1, all2, all3, all4 = [all_dffs[np.where(mask)] for mask in trial_masks]

In [None]:
# Check how many reward trials were selected (first dimension).
all1.shape

In [None]:
# Total number of cycles, for comparison with the subset above.
cycles.shape

In [None]:
# Population activity (sum over cells) averaged over trials, per trial type.
if signals == 0:
    print('raw signal')
elif signals == 1:
    print('denoised signal')
elif signals == 2:
    print('events...is this correct?')
elif signals == 3:
    print('z-scored raw')
print('black = CSm; magenta= all reward trials; red = all unrewarded reward trials; blue = all rewarded trials')

for trials, colour in ((all1, 'm'), (all2, 'k'), (all3, 'b'), (all4, 'r')):
    population = np.sum(trials, 1)   # sum over axis 1 (cells)
    m = np.mean(population, 0)
    # Standard error from this group's own trial count. The original reused
    # all1.shape[0] for the rewarded and not-rewarded groups.
    s = np.std(population, 0)/np.sqrt(trials.shape[0]-1)
    pl.plot(time_ax_single, m, colour)
    pl.fill_between(time_ax_single, m-s, m+s, zorder=0, color=colour, lw=0, alpha=0.1)

ax = pl.axes()
pt.plot_period_bar(ax, 60, delta_y=2, color='g', start_end=(CS_START, CS_END))
pt.plot_period_bar(ax, 60, delta_y=2, color='m', start_end=(US_START, US_END))

In [None]:
# Per-cycle lick ratio: licks counted after CS_START divided by licks counted
# after CYCLE_START, both limited to times below CS_DURATION + DELAY.
# Cycles with no licks in the wider window get the sentinel -1.
lick_ratios = []
for s, e in cycles:
    # Lick times shifted by the cycle start, CS_DURATION and DELAY.
    l = licks - s - CS_DURATION - DELAY
    below_limit = l < (CS_DURATION + DELAY)
    licks_during = ((l > CS_START) * below_limit).sum()
    licks_all = ((l > CYCLE_START) * below_limit).sum()
    if licks_all > 0:
        lick_ratios.append(1. * licks_during / licks_all)
    else:
        lick_ratios.append(-1)

In [None]:
# Distribution of lick ratios. The x range is limited to [0, 1], so the -1
# sentinel used for cycles without licks falls outside the visible range.
pl.hist(lick_ratios, bins=30)
pl.xlim(0, 1)

In [None]:
# Smoke test on the first cycle: the convolved lick trace and the trace of the first cell.
t, lk = extract_single_cycle_signal(time_ax, lick_trace_conv, cycles, 0)
t, tr = extract_single_cycle(time_ax, dff, cycles, 0, 0)

In [None]:
def plot_me(cell=1, cycle=0):
    """Overlay the lick trace and the traces of every cell for one cycle.

    `cell` is accepted because the widget passes it, but it does not affect
    the figure: all columns of ``dff`` are drawn, each offset by 10 per
    column index. The text label reads 'RW' when ``is_rwt[cycle]`` is true
    and 'AP' otherwise.
    """
    t, lk = extract_single_cycle_signal(time_ax, lick_trace, cycles, cycle)
    pl.plot(t, lk)
    pl.ylim(0, 5)
    # Cell traces go on a second y axis so the two scales stay independent.
    pl.twinx()
    for column in range(dff.shape[1]):
        t, tr = extract_single_cycle(time_ax, dff, cycles, cycle, column)
        pl.plot(t, tr + column * 10, color='k')
    pl.ylim(0, 500)
    trial_tag = 'RW' if is_rwt[cycle] else "AP"
    pl.text(20, 400, trial_tag, color='r', fontsize=18)
interact(plot_me, cell=(1, dff.shape[1], 1), cycle=(0, len(cycles) - 1, 1))

In [None]:
# Dimensions of dff.
dff.shape

In [None]:
# Full-length traces of the last four columns of dff, stacked with vertical
# offsets of 50 so they do not overlap.
for column, offset in ((-2, 0), (-1, 50), (-3, 100), (-4, 150)):
    pl.plot(dff[:, column] + offset)

In [None]:
def plot_me(ax, cell=1, cycle=0):
    """Draw the heat map of one cycle on `ax` with the lick trace on top.

    `cycle` indexes ``all_dffs`` and ``is_rwt``. `cell` is accepted but not
    used. The lick trace is drawn in red on a twin y axis limited to (0, 5),
    with a white 'RW' / 'AP' label taken from ``is_rwt[cycle]``.
    """
    t, lk = extract_single_cycle_signal(time_ax, lick_trace, cycles, cycle)
    ax.imshow(all_dffs[cycle], cmap=pl.cm.viridis, interpolation='nearest', vmin=0,
              extent=(time_ax_single[0], time_ax_single[-1], 0, dff.shape[1]))

    lick_ax = ax.twinx()
    lick_ax.plot(t, lk, 'r')
    lick_ax.set_ylim(0, 5)
    lick_ax.text(10, 2, 'RW' if is_rwt[cycle] else "AP", color='w', fontsize=18)

fig, axs = pl.subplots(6, 7, figsize=(10, 10), sharex=True, sharey=True)
# The grid has a fixed 42 panels; stop at the last available cycle so that
# sessions with fewer cycles do not index past the end.
for i, ax in zip(range(len(cycles)), axs.flatten()):
    plot_me(ax, cell=0, cycle=i)

In [None]:
# Recompute the lick ratios with the library helper, replacing the list built
# by the manual loop earlier. zero_value=-1 is passed explicitly; presumably
# it is the value reported for cycles without licks -- confirm in ut.
lick_ratios = ut.compute_lick_ratios(licks, cycles, cycle_start=CYCLE_START, cs_start=CS_START,
                                     cs_end=CS_END, delay=DELAY, cs_duration=CS_DURATION, zero_value=-1)

In [None]:
# Display the lick ratios returned by the helper.
lick_ratios

In [None]:
# Trials counted as "learned": reward trials with a lick ratio of at least
# 0.8, or CS-minus trials with a lick ratio below 0.2.
# `*` and `+` act as element-wise AND / OR here, assuming boolean masks -- TODO confirm dtypes.
cr_learned = (is_rwt * (lick_ratios>=0.8)) + (is_CSm * (lick_ratios<0.2))
# Number of such trials.
sum(cr_learned)

In [None]:
# Complementary criterion: reward trials with a lick ratio below 0.8, or
# CS-minus trials with a lick ratio of at least 0.2.
# `*` and `+` act as element-wise AND / OR here, assuming boolean masks -- TODO confirm dtypes.
cr_not_learned = (is_rwt * (lick_ratios<0.8)) + (is_CSm * (lick_ratios>=0.2))
# Number of such trials.
sum(cr_not_learned)

In [None]:
# Per-cell mean traces over the trials flagged in cr_not_learned.
which_cycles = cycles[np.where(cr_not_learned)]

ncells_x = 7
ncells_y = int(np.ceil(dff.shape[1]/7.))
first_cell = 1
# Plot every cell, as the sibling summary cells do. The leftover override
# `last_cell = 2` drew a single trace inside an otherwise empty grid.
last_cell = first_cell + ncells_x*ncells_y
traces_means, traces_std = ut.compute_mean_traces(time_ax, dff, which_cycles,
                                                  time_ax_single)
plot_summary(time_ax_single, traces_means, traces_std, ncells_x, ncells_y,
             cells=range(first_cell, last_cell),
             cs_start_end=(CS_START, CS_END), us_start_end=(US_START, US_END),
             cs_color="g", us_color="m", ylim=(-5, 20), bary=(15, 2), textxy=(-8, 15))

In [None]:
# Per-cell mean traces over the trials flagged in cr_learned.
which_cycles = cycles[cr_learned]

# Seven columns, with as many rows as needed to hold every column of dff.
ncells_x = 7
ncells_y = int(np.ceil(dff.shape[1] / 7.))
first_cell = 1
last_cell = first_cell + ncells_x * ncells_y
traces_means, traces_std = ut.compute_mean_traces(
    time_ax, dff, which_cycles, time_ax_single)
plot_summary(
    time_ax_single, traces_means, traces_std, ncells_x, ncells_y,
    cells=range(first_cell, last_cell),
    cs_start_end=(CS_START, CS_END), us_start_end=(US_START, US_END),
    cs_color='g', us_color='m', ylim=(-5, 20), bary=(15, 2), textxy=(-8, 15))

In [None]:
# Reward trials only: lick ratio of at least 0.8.
# `*` acts as element-wise AND, assuming boolean masks -- TODO confirm dtypes.
cr_learned_rwt = (is_rwt * (lick_ratios>=0.8))

In [None]:
# Reward trials only: lick ratio below 0.8.
# `*` acts as element-wise AND, assuming boolean masks -- TODO confirm dtypes.
cr_not_learned_rwt = (is_rwt * (lick_ratios<0.8))
# Number of such trials.
sum(cr_not_learned_rwt)

In [None]:
# Per-cell mean traces over the reward trials flagged in cr_learned_rwt.
which_cycles = cycles[cr_learned_rwt]

# Seven columns, with as many rows as needed to hold every column of dff.
ncells_x = 7
ncells_y = int(np.ceil(dff.shape[1] / 7.))
first_cell = 1
last_cell = first_cell + ncells_x * ncells_y
traces_means, traces_std = ut.compute_mean_traces(
    time_ax, dff, which_cycles, time_ax_single)
plot_summary(
    time_ax_single, traces_means, traces_std, ncells_x, ncells_y,
    cells=range(first_cell, last_cell),
    cs_start_end=(CS_START, CS_END), us_start_end=(US_START, US_END),
    cs_color='g', us_color='m', ylim=(-5, 20), bary=(15, 2), textxy=(-8, 15))

In [None]:
# Per-cell mean traces over the reward trials flagged in cr_not_learned_rwt.
which_cycles = cycles[cr_not_learned_rwt]

# Seven columns, with as many rows as needed to hold every column of dff.
ncells_x = 7
ncells_y = int(np.ceil(dff.shape[1] / 7.))
first_cell = 1
last_cell = first_cell + ncells_x * ncells_y
traces_means, traces_std = ut.compute_mean_traces(
    time_ax, dff, which_cycles, time_ax_single)
plot_summary(
    time_ax_single, traces_means, traces_std, ncells_x, ncells_y,
    cells=range(first_cell, last_cell),
    cs_start_end=(CS_START, CS_END), us_start_end=(US_START, US_END),
    cs_color='g', us_color='m', ylim=(-5, 20), bary=(15, 2), textxy=(-8, 15))

In [None]:
def compute_mean_level(time_ax, dff, cycles, cycle, cell, start, end, cycle_start=CYCLE_START):
    """Mean of one cell's trace within a time window of a single cycle.

    Parameters
    ----------
    time_ax, dff, cycles, cycle, cell :
        Forwarded to ``extract_single_cycle``, which returns the time axis
        ``t`` and the trace ``tr`` of that cell for that cycle.
    start, end : number
        Window bounds on ``t``; samples with ``start <= t < end`` are averaged.
    cycle_start : number, optional
        Forwarded to ``extract_single_cycle``. Defaults to ``CYCLE_START``.
        The original ignored this argument and always passed the global.

    Returns
    -------
    The mean of the selected samples.
    """
    t, tr = extract_single_cycle(time_ax, dff, cycles, cycle, cell, cycle_start=cycle_start)
    in_window = (t >= start) * (t < end)
    return tr[in_window].mean()

In [None]:
from sklearn.preprocessing import StandardScaler

In [None]:
# Standardise dff with StandardScaler (zero mean, unit variance per column).
# NOTE(review): earlier cells already read dff_zs for the z-scored option, so
# on a fresh kernel it must exist before they run -- confirm where it is
# first defined or restored.
dff_zs = StandardScaler().fit_transform(dff)

In [None]:
# Mean z-scored level of one cell in four windows of every cycle:
# CS, baseline (from -CS_DURATION up to CS_START), US, and the delay after the CS.
cell = 0
n_cycles = len(cycles)
mean_cs = np.array([
    compute_mean_level(time_ax, dff_zs, cycles, cycle, cell, start=CS_START, end=CS_END)
    for cycle in range(n_cycles)])
mean_base = np.array([
    compute_mean_level(time_ax, dff_zs, cycles, cycle, cell, start=-CS_DURATION, end=CS_START)
    for cycle in range(n_cycles)])
mean_us = np.array([
    compute_mean_level(time_ax, dff_zs, cycles, cycle, cell, start=US_START, end=US_END)
    for cycle in range(n_cycles)])
mean_delay = np.array([
    compute_mean_level(time_ax, dff_zs, cycles, cycle, cell, start=CS_END, end=CS_END + DELAY)
    for cycle in range(n_cycles)])

In [None]:
# Mean baseline-subtracted US response over the reward trials.
# The mask is applied after the subtraction. Masking only mean_base pairs a
# full-length array with a shorter (or mis-indexed) one.
(mean_us - mean_base)[np.where(is_rwt)].mean()

In [None]:
def plot_it(ax, vals):
    """Draw mean and standard error of `vals` for the two trial groups.

    Position 0 holds the trials where ``is_rwt`` is true and position 1 the
    remaining trials. Each group gets a short horizontal line at its mean
    and an error bar of std / sqrt(n - 1).
    """
    groups = (
        (0, (-0.1, 0.1), np.where(is_rwt), np.sum(is_rwt)),
        (1, (0.9, 1.1), np.where(~np.r_[is_rwt]), np.sum(~np.r_[is_rwt])),
    )
    for xpos, (left, right), members, count in groups:
        selected = vals[members]
        ax.plot([left, right], [selected.mean()] * 2, 'k-', zorder=19)
        ax.errorbar([xpos], [np.mean(selected)], [np.std(selected) / np.sqrt(count - 1)],
                    color='k', zorder=19)


fig, axs = pl.subplots(5, 3, figsize=(5, 10), sharey=True)

# One row per cell (the first five), with columns CS, delay and US, each
# relative to the baseline window (-8 up to CS_START).
for cell, a in enumerate(axs):
    mean_cs = np.array([
        compute_mean_level(time_ax, dff_zs, cycles, cycle, cell, start=CS_START, end=CS_END)
        for cycle in range(len(cycles))])
    mean_base = np.array([
        compute_mean_level(time_ax, dff_zs, cycles, cycle, cell, start=-8, end=CS_START)
        for cycle in range(len(cycles))])
    mean_us = np.array([
        compute_mean_level(time_ax, dff_zs, cycles, cycle, cell, start=US_START, end=US_END)
        for cycle in range(len(cycles))])
    mean_delay = np.array([
        compute_mean_level(time_ax, dff_zs, cycles, cycle, cell, start=CS_END, end=CS_END + DELAY)
        for cycle in range(len(cycles))])
    a[0].set_ylabel(cell)
    for ax, level, epoch in zip(a, (mean_cs, mean_delay, mean_us), ('CS', 'delay', 'US')):
        plot_it(ax, level - mean_base)
        ax.text(0.5, 1, epoch, horizontalalignment='center')

# Shared cosmetics: two labelled ticks, and no top or right spines.
for ax in axs.flatten():
    ax.set_xticks((0, 1))
    ax.set_xticklabels(['rew.', 'CSm'])
    ax.set_xlim((-0.2, 1.2))
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

axs[0][0].text(-2, -5, 'Mean z-scored fluorescence', rotation=90)

In [None]:
# Save the workspace variables so the notebook can be restored without
# re-running it (only variables are saved, not functions -- see the note at
# the top of the notebook).
save_workspace(db)