# Diggsietone: analysis notebook for associative learning experiments

# What's new in 2.2:
* added similarity_list.txt output
* changed the baseline period for the stability analysis from 4–2 s before CS onset to 2–0 s before CS onset
* avg heatmap now presents normalized averages for each cell
* multiple cosmetic changes to graphs/figures

# 2.3:
* fixed the selectivity output, which incorrectly used the US time window for the trace-selectivity analysis of rewarded and non-rewarded CS+ trials

# 2.4:
* added the number of "stable" (non-selective) cells to the cell-selectivity output file
* added an output file with the percentage of selective cells

### Importing everything necessary to run notebook

In [None]:

from IPython.display import display
from IPython.display import HTML
import IPython.core.display as di # Example: di.display_html('<h3>%s:</h3>' % str, raw=True)

# Inject two raw HTML snippets into the page, in this order:
#   1. a script that hides every code input area and prompt whenever the page is NOT
#      the live notebook app (no <body class="notebook_app">), e.g. an HTML export;
#   2. a "Toggle code" button that shows/hides those same elements when clicked.
for _html_snippet in (
        '<script>jQuery(function() {if (jQuery("body.notebook_app").length == 0) { jQuery(".input_area").toggle(); jQuery(".prompt").toggle();}});</script>',
        '''<button onclick="jQuery('.input_area').toggle(); jQuery('.prompt').toggle();">Toggle code</button>''',
):
    di.display_html(_html_snippet, raw=True)


In [None]:
# all modules necessary for this nb
import os
import sys
import pickle

import numpy as np
import pylab as pl
from sklearn.covariance import EmpiricalCovariance
from sklearn.cluster import KMeans, AffinityPropagation
from sklearn.metrics import silhouette_score as clust_score
from sklearn.preprocessing import StandardScaler
from scipy import stats as sstats

# setting parameters for default matplotlib plots
%matplotlib inline
pl.rcParams['savefig.dpi'] = 300 # dpi for most publications
pl.rcParams['figure.dpi'] = 300 # dpi for most publications
pl.rcParams['xtick.labelsize'] = 7
pl.rcParams['ytick.labelsize'] = 7
pl.rcParams['axes.labelsize'] = 7
from ipywidgets import interact

# needs to find the library of functions
sys.path.append('/home/fabios/code/forco/')  # to be replaced!

import utils as ut
import plots as pt

In [None]:
# A double percent sign (%%) on the first line of a cell marks a cell magic: the next cell starts with %%javascript, so it runs as JavaScript in the browser.

In [None]:
%%javascript
// Runs in the browser: builds the Python statement
//   NOTEBOOK_NAME = '<base_url><notebook_path>'
// and sends it to the kernel, so Python code can learn this notebook's own path.
// NOTE(review): kernel.execute is presumably asynchronous — if NOTEBOOK_NAME is
// undefined in the next cell after "Run all", re-run that cell.
var nb = IPython.notebook;
var kernel = IPython.notebook.kernel;
var command = "NOTEBOOK_NAME = '" + nb.base_url + nb.notebook_path + "'";
kernel.execute(command);

In [None]:
# Keep only the file name and drop its last 6 characters, i.e. the '.ipynb'
# extension (assumes the notebook file name ends in '.ipynb').
NOTEBOOK_NAME = NOTEBOOK_NAME.split('/')[-1][:-6]

In [None]:
from pickleshare import PickleShareDB

# Folder where PickleShareDB stores this notebook's saved variables:
# <cwd>/autorestore/<NOTEBOOK_NAME>
autorestore_folder = os.path.join(os.getcwd(), 'autorestore', NOTEBOOK_NAME)
db = PickleShareDB(autorestore_folder)
import sys
# Hard-coded location of the `workspace` module; its star import presumably provides
# load_workspace / save_workspace, which are used later in this notebook.
sys.path.append('/home/fabios/code')
from workspace import *
import IPython
ip = IPython.get_ipython()

# this will restore all the saved variables. ignore the errors listed.
# load_workspace(ip, db)

# use `save_workspace(db)` to save variables at the end

## If you have already run this notebook and saved its variables, run the following 3 cells to get the selectivity list

In [None]:
#load_workspace(ip, db)

In [None]:
# def count_trials_from_key(k):
#     if k=='rewarded_tc' or k=='rewarded' or k=='rewarded_us' or k=='reward_onset':
#         title = np.sum(is_rewarded)
#     if k=='not_rewarded_us' or k=='not_rewarded' or k=='not_rewarded_tc':
#         title = np.sum(is_not_rewarded)
#     if k== 'reward_tc' or k=='reward_us' or k=='reward':
#         title = np.sum(is_rewardt)
#     if k== 'corrCSm_us' or k=='corrCSm_tc' or k=='corrCSm':
#         title = np.sum(is_corrCSmt)
#     if k== 'any_tc' or k=='any_cs':
#         title = len(cycles)
#     if k== 'errCSm_us' or k=='errCSm_tc' or k=='errCSm':
#         title = np.sum(is_errCSmt)
#     if k== 'CSm_us' or k=='CSm' or k=='CSm_tc':
#         title = np.sum(is_CSmt)
#     if k== 'pre_bouts':
#         title = len(pre_bouts)
#     if k== 'random_rewarded_errCSm_us':
#         title = len(is_random_rewarded_errCSmt)
#     if k == 'random_rewarded_pre_bouts_us':
#         title = len(is_random_rewarded_pre_boutst)

#     return title
# [(k, count_trials_from_key(k)) for k in np.sort(selectivity.keys())]

In [None]:
# outfile = os.path.join(data_folder, 'selectivity_list_2_3_w_percent.txt')
# significance = 0.01

# with open(outfile, 'w+') as f:
#     for k in keys:
#         corrected_ps = ut.adjust_pvalues([s[1] for s in selectivity[k][:, 0]])
#         n_up = np.sum((corrected_ps<significance) * (selectivity[k][:, 1]>0))
#         n_down = np.sum((corrected_ps<significance) * (selectivity[k][:, 1]<0))
#         n_total = events.shape[1]
#         n_stable = n_total - n_up - n_down
#         percent_up = (n_up*1.00/n_total*1.00)*100
#         percent_down = (n_down*1.00/n_total*1.00)*100
#         percent_stable = (n_stable*1.00/n_total*1.00)*100
#         n_trials = count_trials_from_key(k)
#         f.write("%s\n%d\n%.2f\n%d\n%.2f\n%d\n%.2f\n%d\n%d\n" % (k,n_up,percent_up,n_down,percent_down,n_stable,percent_stable,n_total,n_trials))
# #fig.tight_layout()

In [None]:
# Helper functions used throughout this notebook (defined here; the shared ones live in `ut`/`pt`)
from sklearn.covariance import EmpiricalCovariance
def extract_patterns(time_ax, activity, CYCLE_START, STIM_START, STIM_END, mode='average',
                     cycles=None):
    """Summarize the activity of every cell in the stimulus window of each cycle.

    For a cycle (s, e) the window is
    [s - CYCLE_START + STIM_START, s - CYCLE_START + STIM_END) on `time_ax`, i.e.
    STIM_START..STIM_END relative to tone onset (s - CYCLE_START).

    Parameters
    ----------
    time_ax : 1-D array of frame times, one per row of `activity`.
    activity : 2-D array, frames x cells.
    CYCLE_START, STIM_START, STIM_END : window definition (see above).
    mode : 'average' -> mean activity of each cell, shape (n_cycles, n_cells);
           'corr'    -> empirical covariance between cells (EmpiricalCovariance),
                        shape (n_cycles, n_cells, n_cells).
    cycles : sequence of (start, end) pairs. Defaults to the notebook-level
        `cycles` variable for backward compatibility.

    Raises
    ------
    ValueError : if `mode` is neither 'average' nor 'corr'.
    """
    if cycles is None:
        cycles = globals()['cycles']
    if mode not in ('average', 'corr'):
        # previously an unknown mode surfaced later as a NameError on `patterns`
        raise ValueError("mode must be 'average' or 'corr', got %r" % (mode,))

    n_cells = activity.shape[1]
    if mode == 'average':
        patterns = np.zeros((len(cycles), n_cells))
    else:
        cov_model = EmpiricalCovariance()
        patterns = np.zeros((len(cycles), n_cells, n_cells))

    for i, (s, e) in enumerate(cycles):
        tone_onset = s - CYCLE_START
        time_filter = (time_ax >= (tone_onset + STIM_START)) * (time_ax < (tone_onset + STIM_END))
        if mode == 'average':
            patterns[i] = activity[time_filter].mean(0)
        else:
            patterns[i] = cov_model.fit(activity[time_filter]).covariance_

    return patterns


def extract_single_cycle(time_ax, dff, cycles, cycle, cell,
                         cycle_start=-8):
    """Return (times, values) for column `cell` of `dff` during cycle number `cycle`.

    The frames belonging to the cycle are selected with ut.filter_cycle. Times are
    re-referenced so that the cycle's first selected frame sits at `cycle_start`.
    """
    in_cycle = ut.filter_cycle(time_ax, cycles, cycle)
    cycle_times = time_ax[in_cycle]
    first_time = cycle_times[0]
    shifted_times = cycle_times - first_time + cycle_start
    cell_values = dff[:, cell][in_cycle]
    return shifted_times, cell_values


def extract_single_cycle_signal(time_ax, signal, cycles, cycle,
                         cycle_start=-8):
    """Like extract_single_cycle, but for a whole `signal` indexed along its first axis.

    Returns (times, signal values) for the frames of cycle number `cycle`, with times
    shifted so the cycle's first selected frame is at `cycle_start`.
    """
    mask = ut.filter_cycle(time_ax, cycles, cycle)
    times = time_ax[mask]
    return (times - times[0] + cycle_start,
            signal[mask])


def extract_single_cycle_time_ax(time_ax, cycles, cycle_duration=4, cycle_start=-8):
    """Return a single-cycle time axis that every cycle can be sampled on.

    Each cycle's frame times (selected with ut.filter_cycle) are shifted so its first
    frame is at 0, clipped to [0, cycle_duration), and offset by `cycle_start`. The
    returned axis is the last cycle's axis truncated to the length of the shortest
    clipped cycle, so no cycle has fewer samples than the result.

    Raises
    ------
    ValueError : if `cycles` is empty (there is no axis to return).
    """
    if len(cycles) == 0:
        raise ValueError('cycles is empty: cannot build a single-cycle time axis')
    min_len = np.inf
    for i in range(len(cycles)):
        # compute the cycle's frame selection once (it was evaluated twice before)
        cycle_times = time_ax[ut.filter_cycle(time_ax, cycles, i)]
        time_ax_single = cycle_times - cycle_times[0]
        time_ax_single = time_ax_single[time_ax_single < cycle_duration] + cycle_start
        min_len = min(min_len, len(time_ax_single))
    return time_ax_single[:min_len]

def plot_heat_map(ax, time_ax, traces, cell, cycles, time_ax_single, normed=False, **args):
    """Draw a heatmap of one cell's responses across cycles on `ax`.

    The matrix comes from ut.compute_all_dffs (presumably one row per cycle; the image
    spans y = 0..len(cycles) and x = time_ax_single[0]..time_ax_single[-1]). With
    `normed`, the whole matrix is divided by its maximum. Extra keyword arguments are
    forwarded to ax.imshow. Red vertical lines mark the notebook-level STIM_START and
    STIM_END.
    """
    responses = ut.compute_all_dffs(time_ax, traces, cell, cycles, time_ax_single)
    if normed:
        responses = responses / np.max(responses)
    image_extent = (time_ax_single[0], time_ax_single[-1], 0, len(cycles))
    ax.imshow(responses, origin='lower', aspect='auto', extent=image_extent, **args)
    top = ax.axis()[-1]
    ax.vlines([STIM_START, STIM_END], 0, top, color='red')

def plot_heat_map_average(ax, time_ax, traces, cycles, time_ax_single, normed=False,
                          sorted=False, sorting=None, **args):
    """Draw one row per cell on `ax`: that cell's response averaged over cycles.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    time_ax, traces, time_ax_single : forwarded to ut.compute_all_dffs.
    cycles : (start, end) pairs. The LAST one is always excluded (cycles[:-1]),
        presumably because the final cycle is often incomplete.
    normed : divide each cell's average by its own maximum.
    sorted : if True, order the rows by `sorting`; when `sorting` is None, order
        them by each cell's peak average response (ascending, so the strongest cells
        end up on top with origin='lower'). The peak is taken BEFORE normalization —
        after per-cell normalization every peak equals 1 and the order would be
        arbitrary.
    sorting : explicit row order (sequence of cell indices).
    **args : forwarded to ax.imshow.

    Returns
    -------
    The row order actually used (sequence of cell indices).

    Red vertical lines mark the notebook-level STIM_START and STIM_END.
    """
    averages = [ut.compute_all_dffs(time_ax, traces, cell, cycles[:-1], time_ax_single).mean(0)
                for cell in range(traces.shape[1])]
    if not sorted:
        sorting = range(len(averages))
    elif sorting is None:
        sorting = np.argsort([np.max(a) for a in averages])
    if normed:
        averages = [a / np.max(a) for a in averages]

    ax.imshow([averages[s] for s in sorting], origin='lower', aspect='auto',
              extent=(time_ax_single[0], time_ax_single[-1], 0, traces.shape[1]),
              **args)
    ymax = ax.axis()[-1]
    ax.vlines([STIM_START, STIM_END], 0, ymax, color='red')
    return sorting

def compute_auc_all_cycles(traces, time_ax, cycles, between=(0, 4), cycle_start=None):
    """Mean of `traces` (over all frames and columns) in a window of every cycle.

    For cycle (s, e) the window is [s - cycle_start + between[0],
    s - cycle_start + between[1]) on `time_ax`, i.e. `between` relative to tone onset.
    Despite the name, the value is a mean, not an integral.

    Parameters
    ----------
    cycle_start : offset of the cycle start from tone onset. Defaults to the
        notebook-level CYCLE_START (the only value the function used before).

    Returns a 1-D array with one value per cycle.
    """
    if cycle_start is None:
        cycle_start = CYCLE_START
    window_means = []
    for s, _ in cycles:
        onset = s - cycle_start
        in_window = (time_ax >= onset + between[0]) & (time_ax < onset + between[1])
        window_means.append(traces[in_window].mean())
    return np.r_[window_means]

def decode(patterns, labels, which_cells=None, n_loops=10, n_jobs=1, cv=10, decoder=None):
    """Cross-validated decoding of `labels` from `patterns`, plus a permutation chance level.

    Parameters
    ----------
    patterns : 2-D array, one row per sample (e.g. cycles x cells).
    labels : one label per row of `patterns`.
    which_cells : column selector for `patterns` (boolean mask or indices);
        all columns when None.
    n_loops : number of label permutations used to estimate chance.
    n_jobs, cv : forwarded to scikit-learn's cross_val_score.
    decoder : scikit-learn estimator. When None, a notebook-level variable named
        `decoder` is used (backward compatible with the previous global lookup).

    Returns
    -------
    (scores, scores_chance) : cross-validation scores on the true labels, and the
    flattened scores obtained over `n_loops` random permutations of the labels.

    Raises
    ------
    ValueError : if no decoder is passed and no global `decoder` exists.
    """
    if decoder is None:
        decoder = globals().get('decoder')
        if decoder is None:
            raise ValueError('decode() needs an estimator: pass decoder=... '
                             'or define a global named `decoder`')
    # cross_val_score was used but never imported; import it here
    try:
        from sklearn.model_selection import cross_val_score
    except ImportError:  # scikit-learn < 0.18
        from sklearn.cross_validation import cross_val_score

    ps = patterns if which_cells is None else patterns[:, which_cells]
    scores = cross_val_score(decoder, ps, labels, cv=cv, n_jobs=n_jobs)
    scores_chance = [cross_val_score(decoder, ps, np.random.permutation(labels), cv=cv, n_jobs=n_jobs)
                     for _ in range(n_loops)]
    return scores, np.r_[scores_chance].flatten()

def plot_behavior(ax, beh, event_types, **args):
    """Raster of behavioral events on `ax`, one row per event type.

    Parameters
    ----------
    beh : 2-D array; column 0 holds event times (float-convertible), column 1 the
        event type.
    event_types : the types to draw; type i occupies the row y in [i, i + 1).
    **args : forwarded to ax.vlines.

    Row colors come from the rainbow colormap; each row is labelled at its center.
    """
    colors = pl.cm.rainbow(np.linspace(0, 1, len(event_types)))
    for i, (e, c) in enumerate(zip(event_types, colors)):
        which = beh[:, 1] == e
        if which.sum() > 0:
            # a list (not map) so vlines also gets a sequence on Python 3
            ax.vlines([float(t) for t in beh[which, 0]], i, i+1, color=c, **args)
    # one tick per event type (was hard-coded to 4), centered on its row
    ax.set_yticks(np.arange(len(event_types)) + 0.5)
    ax.set_yticklabels(event_types)

    pt.nicer_plot(ax)
    
def compute_similarity(p1, p2):
    """Cosine similarity of two vectors: p1.p2 / sqrt((p1.p1) * (p2.p2)).

    A zero-norm input gives 0/0, i.e. nan under numpy division.
    """
    numerator = p1.dot(p2)
    squared_norm_product = p1.dot(p1) * p2.dot(p2)
    return numerator / np.sqrt(squared_norm_product)

def extract_activity(time_ax, activity, cycles, CYCLE_START, STIM_START, STIM_END,
                     offset=0, which=None):
    """Per-cycle mean of every column of `activity` in a window around tone onset.

    For each selected cycle (start, stop) the window is
    [start - CYCLE_START + STIM_START + offset, start - CYCLE_START + STIM_END + offset).
    `which` selects cycles from `cycles` (boolean mask or indices); all when None.
    Returns an array with one row per selected cycle.
    """
    if which is None:
        which = [True] * len(cycles)
    rows = []
    for start, _stop in cycles[which]:
        lo = start - CYCLE_START + STIM_START + offset
        hi = start - CYCLE_START + STIM_END + offset
        rows.append(activity[(time_ax >= lo) * (time_ax < hi)].mean(0))
    return np.r_[rows]

def find_selective_cells(time_ax, events, cycles, cycle_start, stim_start, stim_end, offset=-4):
    """Per-cell p-values comparing a stimulus window with a shifted baseline window.

    extract_activity averages each cell's activity per cycle in the stimulus window
    [stim_start, stim_end) (relative to tone onset) and in the same window shifted by
    `offset`. For every cell, the baseline values are compared with the stimulus
    values across cycles using a two-sided Mann-Whitney U test. A cell whose two
    samples are identical element by element gets p = 0.5 and is not tested.

    NOTE(review): the 2.2 changelog says the baseline is 2-0 s before CS onset; with a
    2 s window starting at CS onset the default offset=-4 gives 4-2 s before onset —
    confirm callers pass the intended offset.

    Returns a 1-D array with one p-value per column of `events`.
    """
    stimulus = extract_activity(time_ax, events, cycles, cycle_start, stim_start, stim_end)
    baseline = extract_activity(time_ax, events, cycles, cycle_start, stim_start, stim_end,
                                offset=offset)
    p_values = []
    for base_cell, stim_cell in zip(baseline.T, stimulus.T):
        if any(base_cell != stim_cell):
            p_values.append(sstats.mannwhitneyu(base_cell, stim_cell, alternative='two-sided')[-1])
        else:
            p_values.append(.5)
    return np.r_[p_values]

def expand_event(time_ax, onsets, offsets):
    """Count, at every time point, how many [onset, offset) intervals contain it.

    Returns an array shaped like `time_ax` (np.zeros_like, so same dtype).
    """
    counts = np.zeros_like(time_ax)
    for start, stop in zip(onsets, offsets):
        inside = (time_ax >= start) * (time_ax < stop)
        counts[inside] += 1
    return counts

### Autorestoring folder and identifying files to load

In [None]:
# autorestore_folder is <cwd>/autorestore/<NOTEBOOK_NAME>; dropping its last three
# path components gives the PARENT of the notebook's working directory (one level
# above cwd), which is used as data_folder. Splitting on '/' assumes POSIX paths.
a = autorestore_folder.split('/')[:-3]
# data_folder = '/media/data/DATA1/dg_odor_nwoods/adam/2dayhabit/d1//'
data_folder = os.path.join('/', *a)

In [None]:
from scipy.io import loadmat

In [None]:
#loadmat(os.path.join(data_folder, 'CNMFe/Coor.mat'))

In [None]:
# Load the CNMF-E outputs from the 'CNMFe Conservative' sub-folder of data_folder.
# Every matrix is transposed after loading so that rows are frames (time) and
# columns are cells, which is how they are indexed below (e.g. traces[:, cell]).
# traces = np.loadtxt(os.path.join(data_folder, 'traces/C_raw.txt')).T
traces = np.loadtxt(os.path.join(data_folder, 'CNMFe Conservative/C.txt')).T #denoised traces
# raw (non-denoised) traces
traces_raw = np.loadtxt(os.path.join(data_folder, 'CNMFe Conservative/C_raw.txt')).T
# presumably CNMF-E's deconvolved activity (S) — used as `events` below; TODO confirm
events = np.loadtxt(os.path.join(data_folder, 'CNMFe Conservative/S.txt')).T
# presumably dF/F-normalized traces (C_df) — TODO confirm
dff = np.loadtxt(os.path.join(data_folder, 'CNMFe Conservative/C_df.txt')).T
#Cnn=np.loadtxt('/media/data/DATA1/data/2p_mice_44-50/081117/TSeries-08112017-0742_45_235-001/CNMFe/Cnn.txt')
#Coor = sio.loadmat('/media/data/DATA1/data/2p_mice_44-50/081117/TSeries-08112017-0742_45_235-001/CNMFe/Coor.mat')['coor'][:,0]
#events = ut.event_detection_cnmfe_denoised(denoised)
#dff_zs = ut.zscore_traces(dff)
# areas = np.loadtxt(os.path.join(data_folder, 'area/A.txt')).T

##   MAKE SURE YOU HAVE THE PROPER PIXEL DIMENSIONS FOR YOUR VIDEO BELOW
##def load_spatial_footprints_A(A_file, shape=(421, 514)):
##    return np.loadtxt(A_file).T.reshape([-1, shape[0], shape[1]])
##areas = load_spatial_footprints_A(os.path.join(data_folder, 'area/A.txt'))
# Mean image and cell contours, read from Coor.mat (key 'coor') and Cnn.txt.
mean_image, contours = ut.load_spatial_footprints(os.path.join(data_folder, 'CNMFe Conservative/Coor.mat'),
                                                  os.path.join(data_folder, 'CNMFe Conservative/Cnn.txt'),
                                                  key='coor')

# Behavior log (arduino events) and the sorted set of distinct event names it contains.
filename = os.path.join(data_folder, 'behavior.txt')
behavior = ut.read_behavior(filename)
events_list = np.unique([b[1] for b in behavior])

In [None]:
# Restore variables previously saved with save_workspace(db).
# NOTE(review): this runs after the data files above were loaded, so saved variables
# with the same names (e.g. traces, behavior) presumably overwrite the fresh ones —
# confirm this ordering is intended.
load_workspace(ip, db)

In [None]:
# grab time axis from the xml file

import xml.etree.ElementTree as ET
xmlfile = os.path.join(data_folder, 'tseries.xml')
print "I infer the time axis from:\n", xmlfile
tree = ET.parse(xmlfile)
root = tree.getroot()

# unfortunately we miss the first frame
time_ax = np.r_[[child.attrib['absoluteTime']
                 for child in root.iter('Frame')]].astype(float)

In [None]:
# sync times: make behavior event times relative to the first 'BEGIN' event (times
# converted to float) and the imaging time axis relative to its first frame, so both
# clocks are zeroed.
start_2p = ut.parse_behavior(behavior, 'BEGIN')[0]
behavior = [[float(b[0])-start_2p, b[1]] for b in behavior]
time_ax -= time_ax[0]

In [None]:
# Sanity check: time_ax should have one entry per row (frame) of traces; if not,
# use one of the adjustment cells below.
print(time_ax.shape)
print(traces.shape)

In [None]:
#time_ax = time_ax[::4] # use this if video was averaged and need to adjust xml output to match

In [None]:
#time_ax = time_ax[0:20000] # use this if any video frames were truncated (often need to do this a video is averaged)

In [None]:
# Re-check after any adjustment: the first dimensions should now match.
print(time_ax.shape)
print(traces.shape)

In [None]:
# Show the first 10 [time, event] entries to check that presentations are timed correctly.
behavior[:10]

In [None]:
# -----------------------------------------------------------
# these times are relative to the single cycle
# and centered around tone onset
CONTINUOUS = True
CYCLE_START = -8  # seconds
CS_START = 0  # seconds (let's keep this at 0)
CS_DURATION = 2  # seconds
DELAY = 2  # seconds
US_DURATION = 2  # seconds  // IS THIS FIXED?
AFTER_US_PERIOD = 5
REWARD_WIN = 2
CYCLE_DURATION = abs(CYCLE_START) + CS_DURATION + DELAY + US_DURATION + AFTER_US_PERIOD
CS_END = CS_START + CS_DURATION
US_START = CS_START + DELAY + CS_DURATION
US_END = US_START + US_DURATION

# -----------------------------------------------------------
# these times are absolute times, taken from the arduino file
# when the tones starts and ends
tone_CSm_ons = ut.parse_behavior(behavior, 'TONE_CSM')
tone_CSm_offs = ut.parse_behavior(behavior, 'TONE_CSM', offset=CS_DURATION)
tone_reward_ons = ut.parse_behavior(behavior, 'TONE_RW')
tone_reward_offs = ut.parse_behavior(behavior, 'TONE_RW', offset=CS_DURATION)
tone_airpuff_ons = ut.parse_behavior(behavior, 'TONE_AP')
tone_airpuff_offs = ut.parse_behavior(behavior, 'TONE_AP', offset=CS_DURATION)
tone_shock_ons = ut.parse_behavior(behavior, 'TONE_SHOCK')
tone_shock_offs = ut.parse_behavior(behavior, 'TONE_SHOCK', offset=CS_DURATION)
rewards = np.r_[ut.parse_behavior(behavior, 'REWARD')]
shocks = np.r_[ut.parse_behavior(behavior, 'SHOCK')]
airpuffs = np.r_[ut.parse_behavior(behavior, 'AIRPUFF')]
licks = np.r_[ut.parse_behavior(behavior, 'LICK')]

# -----------------------------------------------------------
# when the experiment starts and ends, in absolute time
# begin_end = ut.parse_behavior(behavior, '[be]')
# when each cycle starts and ends
# (last cycle is usually oddly recorded)
if CONTINUOUS:
    cycles_starts = ut.parse_behavior(behavior, 'TONE_*', offset=CYCLE_START)
    cycles_ends = ut.parse_behavior(behavior, 'TONE_*', offset=CYCLE_DURATION+CYCLE_START)
else:
    cycles_starts = ut.parse_behavior(behavior, 'BEGIN')
    cycles_ends = ut.parse_behavior(behavior, 'END')
cycle_subtract = 0   #do we need to subtract off the last cycle because it's too short???
if cycle_subtract !=0:
    cycles = np.r_[zip(cycles_starts,  # offset will be ADDED, with sign
                   cycles_ends)][:cycle_subtract]
else:
    cycles = np.r_[zip(cycles_starts,  # offset will be ADDED, with sign
                   cycles_ends)]
print 'we are subtracting off this many cycles'
print cycle_subtract
# -----------------------------------------------------------
# which trials are a.p. and which reward
is_CSmt = [any(map(lambda t: (t>=s) and (t<e), tone_CSm_ons)) for s, e in zip(cycles_starts, cycles_ends)]
is_rewardt = [any(map(lambda t: (t>=s) and (t<e), tone_reward_ons)) for s, e in zip(cycles_starts, cycles_ends)]
is_airpufft = [any(map(lambda t: (t>=s) and (t<e), tone_airpuff_ons)) for s, e in zip(cycles_starts, cycles_ends)]
is_shockt = [any(map(lambda t: (t>=s) and (t<e), tone_shock_ons)) for s, e in zip(cycles_starts, cycles_ends)]
true_CSmt = np.where([any(map(lambda t: (t>=s) and (t<e), tone_CSm_ons)) for s, e in cycles])[0]
true_rewardt = np.where([any(map(lambda t: (t>=s) and (t<e), tone_reward_ons)) for s, e in cycles])[0]
true_airpufft = np.where([any(map(lambda t: (t>=s) and (t<e), tone_airpuff_ons)) for s, e in cycles])[0]
true_shockt = np.where([any(map(lambda t: (t>=s) and (t<e), tone_shock_ons)) for s, e in cycles])[0]
# is_rewarded = [any(map(lambda t: (t>=s) and (t<e) and any((t-rewards)<(CS_DURATION+DELAY+REWARD_WIN)), tone_rw_ons))
#                for s, e in zip(cycles_starts, cycles_ends)]
is_rewarded = [any(map(lambda r: (r<e)*(r>=s), rewards))
               for s, e in zip(cycles_starts, cycles_ends)]
is_not_rewarded = is_rewardt * ~np.r_[is_rewarded]
is_shocked = [any(map(lambda r: (r<e)*(r>=s), shocks))
              for s, e in zip(cycles_starts, cycles_ends)]
is_not_shocked = is_shockt * ~np.r_[is_shocked]
is_airpuffed = [any(map(lambda r: (r<e)*(r>=s), airpuffs))
                for s, e in zip(cycles_starts, cycles_ends)]
is_not_airpuffed = is_airpufft * ~np.r_[is_airpuffed]
is_lickedust = [any(map(lambda r: (r<(s-CYCLE_START+US_START+US_DURATION))*(r>=(s-CYCLE_START+US_START)),
                       licks))
                for s, e in zip(cycles_starts, cycles_ends)]
#is_CSmerrt = is_lickedust * np.r_[is_CSmt]

# first half of the trials
is_first_half = np.r_[[False]*len(cycles)]
is_first_half[:len(cycles)/2] = True  # only take first half

# CS+ reward trials, only the first half
is_first_half_rewardt = is_rewardt * is_first_half

#Below, write code to extract n random trials from a condition (ie, 20 random CSm trials)
#is_random_rewarded_errCSmt = np.random.choice(np.where(is_rewardt)[0], size=is_errCSmt.sum(), replace=False)

# clean up artefact events at the beginning of each cycle
for s, e in cycles:
    events[time_ax==s] = 0

In [None]:
# Shared single-cycle time axis (starting at CYCLE_START, i.e. relative to tone onset),
# truncated to the shortest cycle so every cycle can be sampled on it.
time_ax_single = extract_single_cycle_time_ax(time_ax, cycles,
                                              cycle_duration=CYCLE_DURATION, cycle_start=CYCLE_START)

# Licks

In [None]:
# Lick rasters aligned to tone onset. Top panel: CS+ (blue) and CS- (red) trials at
# their actual trial positions; middle: CS+ only; bottom: CS- only.
fig, axs = pl.subplots(3, 1, figsize=(2, 8), gridspec_kw={'height_ratios':(2, 1, 1)}, sharex=True)
for onsets, raster_color, target, extra in ((tone_reward_ons, 'b', axs[0], {'positions': true_rewardt}),
                                             (tone_CSm_ons, 'r', axs[0], {'positions': true_CSmt}),
                                             (tone_reward_ons, 'b', axs[1], {}),
                                             (tone_CSm_ons, 'r', axs[2], {})):
    pt.plot_licks(onsets, licks, color=raster_color, ax=target, **extra)

# Shared axis styling; bars below each raster and vertical lines mark the
# CS (green) and US (magenta) onsets/periods.
for ax in axs:
    ax.set_xlim(CYCLE_START, CYCLE_START+CYCLE_DURATION)
    ax.set_xlabel("Time from tone onset (s)")
    ax.set_ylabel('Trial #')
    pt.nicer_plot(ax)
    for bar_color, period in (('g', (CS_START, CS_END)), ('m', (US_START, US_END))):
        pt.plot_period_bar(ax, -3, delta_y=2, color=bar_color, start_end=period)
    ymin, ymax = ax.axis()[2:]
    ax.vlines(CS_START, ymin, ymax, lw=1, color='g')
    ax.vlines(US_START, ymin, ymax, lw=1, color='m')

In [None]:
# Lick counts per cycle (as floats) in three windows; start/end are offsets from the
# cycle start, so -CYCLE_START is tone onset:
#   licks_bs    - baseline: the DELAY + US_DURATION seconds right before tone onset
#   licks_cs    - during the CS
#   licks_tc_us - trace + US: from CS end to US end
# The baseline and trace+US windows have the same length, so their counts are comparable.
licks_bs = 1.*ut.compute_licks_during(licks, cycles,
                                      start=-CYCLE_START-DELAY-US_DURATION,
                                      end=-CYCLE_START)  # w.r.t. cycle start
licks_cs = 1.*ut.compute_licks_during(licks, cycles,
                                      start=-CYCLE_START,
                                      end=-CYCLE_START+CS_DURATION)
licks_tc_us = 1.*ut.compute_licks_during(licks, cycles,
                                         start=-CYCLE_START+CS_DURATION,
                                         end=-CYCLE_START+CS_DURATION+DELAY+US_DURATION)

In [None]:
# Baseline lick rate per cycle: baseline count divided by the baseline window length
# (DELAY + US_DURATION seconds).
lickrates_bs = 1.*licks_bs/(DELAY+US_DURATION)
lickrates_bs

In [None]:
# Per-cycle lick preference in [-1, 1]: (trace+US - baseline) / (trace+US + baseline).
# Cycles without licks in either window give 0/0 = nan, which nan_to_num turns into 0.
lick_ratios = np.nan_to_num(1.*(licks_tc_us-licks_bs)/(licks_tc_us+licks_bs))

In [None]:
# Cycles with at least 5 licks across the baseline and trace+US windows combined.
good_lick_trials = (licks_bs+licks_tc_us)>=5

In [None]:
# Lick-ratio distributions for CS+ (first histogram) and CS- (second) trials, limited to
# good_lick_trials (>= 5 licks) — the same threshold as good_lick_trials and is_errCSmt.
pl.hist(lick_ratios[good_lick_trials * is_rewardt], histtype='step', bins=50)
pl.hist(lick_ratios[good_lick_trials * is_CSmt], histtype='step', bins=50)

In [None]:
is_errCSmt = (lick_ratios>0.8) * ((licks_tc_us+licks_bs)>4) * is_CSmt
print is_errCSmt.sum()

In [None]:
# Correct CS- trials: CS- cycles with no licks at all in the trace+US window.
is_corrCSmt = ((licks_tc_us)==0) * is_CSmt
print is_corrCSmt.sum()

In [None]:
# Persist the current variables (including the trial classifications above) to the
# autorestore database.
save_workspace(db)

# Lick labeling and selectivity

In [None]:
# Label every lick with the trial period it falls in (-1 = outside every cycle):
#   0 Pre (cycle start -> tone), 1 CS, 2 Trace, 3 US, 4 Post US.
# Cycles are processed in order, so if cycles overlap a later one overwrites labels.
lick_labels = np.r_[[-1]*len(licks)]
for s, e in cycles:
    t_tone = s-CYCLE_START
    t_cs_end = t_tone+CS_DURATION
    t_trace_end = t_cs_end+DELAY
    t_us_end = t_trace_end+US_DURATION
    t_post_end = t_us_end+AFTER_US_PERIOD
    for period_label, lo, hi in ((0, s, t_tone),             # Baseline
                                 (1, t_tone, t_cs_end),      # CS
                                 (2, t_cs_end, t_trace_end), # Trace
                                 (3, t_trace_end, t_us_end), # US
                                 (4, t_us_end, t_post_end)): # Post US
        l_filter = (licks>=lo) * (licks<hi)
        lick_labels[l_filter] = period_label

In [None]:
def plot_it(c=0):
    """Lick raster for cycle `c`.

    Each lick is drawn at the height of its period label (-1 null, 0 Pre, 1 CS,
    2 Trace, 3 US, 4 Post); bars below mark the CS and (magenta) US periods, and the
    trial type is written on top (CS+ in blue when is_rewardt[c], else CS- in red).
    """
    cyc_start, cyc_end = cycles[c][0], cycles[c][1]
    tone_on = cyc_start-CYCLE_START
    us_on = tone_on+CS_DURATION+DELAY

    fig, ax = pl.subplots(1, 1, figsize=(3, 2))
    ax.vlines(licks, lick_labels-0.2, lick_labels+0.2, lw=1)
    ax.set_xlim(cyc_start, cyc_end)
    pt.plot_period_bar(ax, -2, start_end=(tone_on, tone_on+CS_DURATION), delta_y=0.2)
    pt.plot_period_bar(ax, -2, start_end=(us_on, us_on+US_DURATION), delta_y=0.2, color='m')
    ax.set_yticks(range(-1, 5))
    ax.set_yticklabels(['null', 'Pre', 'CS', 'Trace', 'US', 'Post'])
    ax.text(tone_on, -1.5, 'CS', ha='center', color='b', fontsize=5)
    ax.text(us_on, -1.5, 'US', ha='center', color='m', fontsize=5)

    if is_rewardt[c]:
        trial_text, trial_color = 'CS+', 'b'
    else:
        trial_text, trial_color = 'CS-', 'r'
    ax.text(cyc_start, 5, trial_text, color=trial_color, fontsize=7)

interact(plot_it, c=(0, len(cycles)-1, 1))

In [None]:
# Split the lick train into bouts: a new bout starts wherever two consecutive licks
# are more than lick_bout_distance apart.
# bouts_starts[i] is the first lick of bout i; bouts_ends[i] is an EXCLUSIVE bound
# (the first lick of bout i + 1), so bout i = licks in [bouts_starts[i], bouts_ends[i]).
# The final bout has no following lick; its bound is +inf so that zipping starts with
# ends keeps it instead of silently dropping it.
lick_bout_distance = 1
bout_breaks = np.where(np.diff(licks)>lick_bout_distance)[0]+1
bouts_starts = licks[np.r_[0, bout_breaks]]
bouts_ends = np.r_[licks[bout_breaks], np.inf]

In [None]:
# Raster of the first 10 lick bouts, each aligned to its own first lick.
pl.figure(figsize=(3, 2))
for i, (bs, be) in enumerate(zip(bouts_starts[:10], bouts_ends[:10])):
    in_bout = (licks>=bs) * (licks<be)
    pl.vlines(licks[in_bout]-bs, i-0.2, i+0.2, color='k', lw=0.5)
pl.xlim(-2, 6)

In [None]:
bouts_number_threshold = 3

# Licks of each bout (bouts_ends is the exclusive upper bound) and their period labels.
bouts = [licks[(licks>=bs) * (licks<be)]
         for bs, be in zip(bouts_starts, bouts_ends)]
bouts_labels = [lick_labels[(licks>=bs) * (licks<be)]
                for bs, be in zip(bouts_starts, bouts_ends)]

# Split each bout by trial period. Slot k holds the licks labelled k - 1
# (0 null, 1 Pre, 2 CS, 3 Trace, 4 US, 5 Post), or None when fewer than
# bouts_number_threshold licks fall in that period. The slots cover all six possible
# labels even if some never occur; keying them on np.unique(lick_labels) would shift
# every slot, and later code relies on b[1] being Pre, b[2] CS, and so on.
bouts_split = [[b[bl==l] if (bl==l).sum()>=bouts_number_threshold else None for l in range(-1, 5)]
               for b, bl in zip(bouts, bouts_labels)]

# Keep only bouts with more than bouts_number_threshold licks (filter bouts last,
# since the other two lists are filtered using it).
bouts_labels = [b for b, bb in zip(bouts_labels, bouts) if len(bb) > bouts_number_threshold]
bouts_split = [b for b, bb in zip(bouts_split, bouts) if len(bb) > bouts_number_threshold]
bouts = [b for b in bouts if len(b) > bouts_number_threshold]


In [None]:
t_pre = 10
t_post = 20
zz = StandardScaler().fit_transform(traces)

def plot_it(c=0):
    """Interactive view of lick bout `c` (index into bouts / bouts_labels / bouts_split).

    Main axes: z-scored traces (zz) of up to the first 100 cells, stacked, from t_pre
    before to t_post after the bout's first lick, plus CS and US (magenta) bars for
    every cycle overlapping that window. Twin axes: the bout's licks split by trial
    period (null, Pre, CS, Trace, US, Post), with CS licks in blue and US licks in magenta.
    """
    b, bl, bs = bouts[c], bouts_labels[c], bouts_split[c]

    # one color per bouts_split slot (null, Pre, CS, Trace, US, Post)
    colors = ['k', 'k', 'b', 'k', 'm', 'k']

    fig, ax = pl.subplots(1, 1, figsize=(3, 2))

    t_filt = (time_ax>=(b[0]-t_pre)) * (time_ax<(b[0]+t_post))
    # at most 100 cells; fewer when the recording has fewer (indexing past the last
    # column raised IndexError)
    for cell in range(min(100, zz.shape[1])):
        ax.plot(time_ax[t_filt]-b[0], zz[t_filt][:, cell]+cell, '-', lw=0.5, color='0.7')

    ax.spines['left'].set_visible(False)
    ax.set_yticks(())

    for s, e in cycles:
        if e<time_ax[t_filt][0] or s>=time_ax[t_filt][-1]: continue
        dy = np.diff(ax.axis()[-2:])[0]/100.
        pt.plot_period_bar(ax, -2,
                           start_end=(s-CYCLE_START-b[0], s-CYCLE_START+CS_DURATION-b[0]),
                           delta_y=dy)
        pt.plot_period_bar(ax, -2,
                           start_end=(s-CYCLE_START+CS_DURATION+DELAY-b[0],
                                      s-CYCLE_START+CS_DURATION+DELAY+US_DURATION-b[0]),
                           delta_y=dy, color='m')
    ax.set_xlabel('Time from lick bout onset (s)')
    ax = pl.twinx()

    for i, l_times in enumerate(bs):
        if l_times is None: continue
        # slot i holds licks labelled i - 1, so they are drawn at height i - 1
        ax.vlines(l_times-b[0], i-0.2-1, i+0.2-1, colors=colors[i], lw=0.5)
    ax.set_yticks(range(-1, 5))
    ax.set_yticklabels(['null', 'Pre', 'CS', 'Trace', 'US', 'Post'])

    ax.set_xlim((-8, 10))
    ax.set_ylim((-2, 5))

    ax.vlines(0, -2, 5, color='k', lw=1, linestyles='dashed')

interact(plot_it, c=(0, len(bouts)-1, 1))

In [None]:
# Pool, across all bouts, the lick-time arrays falling in each trial period.
# bouts_split slots: 1 Pre, 2 CS, 3 Trace, 4 US, 5 Post; None means too few licks.
pre_bouts, cs_bouts, tr_bouts, us_bouts, post_bouts = [
    [split[slot] for split in bouts_split if split[slot] is not None]
    for slot in (1, 2, 3, 4, 5)
]

In [None]:
# Duration (last lick - first lick) of every bout segment, one step histogram per period.
fig, ax = pl.subplots(1, 1, figsize=(3, 2))
period_segments = [('pre_bouts', pre_bouts), ('cs_bouts', cs_bouts), ('tr_bouts', tr_bouts),
                   ('us_bouts', us_bouts), ('post_bouts', post_bouts)]
for _name, these in period_segments:
    ax.hist([b[-1]-b[0] for b in these], np.linspace(0, 4, 30), histtype='step')
ax.legend([name for name, _segs in period_segments], fontsize=5)
ax.set_xlabel('Time length (s)')
ax.set_ylabel('Frequency')

pt.nicer_plot(ax)

In [None]:
# One [first lick, last lick] row per Pre-period bout segment.
pre_bouts_startends = np.r_[[[p[0], p[-1]] for p in pre_bouts]]

# First reward times

In [None]:
# For every rewarded cycle: the first entry of `rewards` (in array order) that comes
# after the cycle start s, expressed relative to tone onset
# (reward - s + CYCLE_START == reward - (s - CYCLE_START)).
first_reward_times = np.r_[[rewards[np.where((rewards-s)>0)[0][0]]-s+
                            CYCLE_START
                            for s, e in cycles[np.where(is_rewarded)]]]

In [None]:
# Display the first-reward latencies (relative to tone onset) of the rewarded cycles.
first_reward_times

In [None]:
# Histogram of first-reward latencies: 0.25-wide bins with edges from 3.9 to 5.9
# (around the US window, US_START=4 to US_END=6); values outside are not counted.
pl.hist(first_reward_times, bins=np.arange(3.9, 6, 0.25));

In [None]:
# use this anytime to save the workspace
# save_workspace(db)

### Start Plotting Traces

In [None]:
# Whole-session overview: every cell's trace z-scored per cell (StandardScaler works
# column-wise and columns are cells), scaled by 0.1 and offset by its index, against
# time in minutes.
zz = StandardScaler().fit_transform(traces) #previously we just plotted 'traces' not 'zz'
#zz = traces

fig, ax = pl.subplots(1, 1, figsize = (5, 7))
# from the last cell down to cell 0 INCLUSIVE (a stop of 0 in range() skipped cell 0)
for cell in range(zz.shape[1]-1, -1, -1):
    pl.plot(time_ax/60., zz[:, cell]*.1+cell, lw=0.2)
pl.yticks((0, (zz.shape[1]-1)))
ax.set_yticklabels((0, zz.shape[1]-1))
# x is time_ax / 60., i.e. minutes (the label used to say seconds)
pl.xlabel("Time (min)")
pl.ylabel('Cell #')


In [None]:
# Interactive 5-minute window over the session: every cell's trace (scaled by 0.1 and
# offset by its index) against time in minutes, with translucent bands at each CS
# (pale colors) and each US (saturated colors).
def plot_it(start=0):
    """Plot all traces between minute `start` and `start + 5`, with CS/US bands."""
    fig, ax = pl.subplots(1, 1, figsize=(5, 5))
    # from the last cell down to cell 0 INCLUSIVE (a stop of 0 in range() skipped cell 0)
    cell_order = range(traces.shape[1]-1, -1, -1)
    for cell in cell_order:
        pl.plot(time_ax/60., traces[:, cell]*.1+cell, 'ko', lw=0.5, ms=0.1)
    for cell in cell_order:
        pl.plot(time_ax/60., traces[:, cell]*.1+cell, 'k-', lw=0.5)
    pl.yticks(range(0, (traces.shape[1]-1), 50))
    pl.xlabel("Time (min)")
    pl.ylabel('Cell #')

    n_rows = traces.shape[1]
    # CS bands: CS-, shock CS, sucrose CS, airpuff CS
    cs_colors = [(0.7, 0.7, 0.7), (1, 0.7, 0.7), (0.7, 1, 0.7), (0.7, 0.7, 1)]
    for c, onss in zip(cs_colors, [tone_CSm_ons, tone_shock_ons, tone_reward_ons, tone_airpuff_ons]):
        # np.asarray: onsets may come as plain lists, which do not support / 60.
        for onset in np.asarray(onss, dtype=float)/60.:
            ax.fill_between([onset, onset+CS_DURATION/60.], [0]*2, [n_rows]*2, color=c, alpha=0.5, lw=0)
    # US bands: shock, sucrose, airpuff
    us_colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
    for c, onss in zip(us_colors, [shocks, rewards, airpuffs]):
        for onset in np.asarray(onss, dtype=float)/60.:
            ax.fill_between([onset, onset+US_DURATION/60.], [0]*2, [n_rows]*2, color=c, alpha=0.5, lw=0)
    ax.set_xlim(start, start+5)
interact(plot_it, start=(0, 50, 1))

In [None]:
def plot_all_cycles(cell, iso=None):
    """Stack one cell's z-scored trace (from ``zz``) across trials, aligned to CS onset.

    Parameters
    ----------
    cell : int
        Column index of the cell in ``zz``.
    iso : None, str or boolean sequence
        Which trials to show. ``None`` shows every cycle. A string is the name of
        a trial-mask variable (what the ``interact`` dropdown passes) and is
        evaluated in the notebook namespace. Anything else is used directly as
        the per-cycle mask.

    Returns
    -------
    The matplotlib axes that was drawn on.
    """
    fig, ax = pl.subplots(1, 1, figsize = (7,5))
    if iso is None:
        iso = np.r_[[True] * len(cycles)]
    elif isinstance(iso, (str, type(u''))):
        # the dropdown supplies a variable name; resolve it only when it is a string
        iso = eval(iso)
    offset = 0
    for cycle in range(len(cycles)):
        if not iso[cycle]:
            continue
        t, tr = extract_single_cycle(time_ax, zz, cycles, cycle, cell)
        ax.plot(t, tr+offset, color='k')
        # fixed vertical spacing between consecutive trials
        offset = offset + 20
    ax.set_yticks(np.arange(np.sum(iso))*20)
    ax.set_yticklabels(np.arange(np.sum(iso)))
    ax.set_ylim([-0.5, offset])
    ax.set_ylabel('trial #')
    ax.set_xlabel('time from CS onset')
    ax.fill_between([CS_START, CS_END],
                    -0.5, offset, color='0.5')
    return ax
interact(plot_all_cycles, cell=(0, traces.shape[1]-1, 1), iso=['is_CSmt', 'is_rewardt', 'is_rewarded', 'is_not_rewarded'])

In [None]:
# Stimulus window for the heat-map helpers (presumably read as globals by
# plot_heat_map_average — confirm).
STIM_START = CS_START
STIM_END = CS_END
# Average heat map over all trials, drawn on its own fresh axes rather than on
# whatever global `ax` an earlier cell left behind. The returned cell order
# (`sorting`) is reused to sort the per-condition heat maps.
fig, ax = pl.subplots(1, 1, figsize=(3, 3))
sorting = plot_heat_map_average(ax, time_ax, traces, cycles, time_ax_single, normed = True, sorted=True, vmax=10)

In [None]:
fig, axs = pl.subplots(5, 1, figsize=(6, 6), sharex=True, sharey=True,
                       gridspec_kw={'height_ratios':(0.2, 1, 1, 1, 1)})
# the thin top row only hosts the colour bar
axs[0].axis('off')

# One average heat map per trial type, every panel using the same cell order `sorting`
for iso, ax in zip([is_CSmt, is_rewardt, is_rewarded, is_errCSmt], axs.flatten()[1:]):
    plot_heat_map_average(ax, time_ax, traces, cycles[iso], time_ax_single, normed = True, sorted=True,
                          sorting=sorting, vmax=1)
# all panels share vmax, so the last panel's image serves for the colour bar
im = ax.images[0]

fig.colorbar(im, ax=axs[0], orientation='horizontal')

axs[1].text(12, 10, 'CS-', color='r')
axs[2].text(12, 10, 'Reward', color='r')
axs[3].text(12, 10, 'Rewarded', color='r')
axs[4].text(12, 10, 'CS- Error', color='r')
ax.set_xlabel('Time (s)')
ax.text(-10, 300, 'Cell #', rotation=90)

In [None]:
# One small heat map per cell for CS- trials, laid out 7 cells per row.
n_cells = traces.shape[1]
n_cols = 7
# enough rows for every cell; squeeze=False keeps axs 2-D even with a single row
n_rows = int(np.ceil(n_cells / float(n_cols)))
fig, axs = pl.subplots(n_rows, n_cols, figsize=(7, max(n_cells / 6., 1.)), squeeze=False)

for cell, ax in zip(range(n_cells), axs.flatten()):
    plot_heat_map(ax, time_ax, traces, cell, cycles[is_CSmt], time_ax_single, normed = True, vmax=0.85)
    # technically, vmax should = 1, but want to make signals stand out for visualization purposes
    for s in ax.spines.values():
        s.set_visible(False)
    ax.set_xticks(())
    ax.set_yticks(())
    # label each panel with its 1-based cell number
    ax.text(5.7, 3.7, cell+1, color='y')
# blank out grid slots after the last cell
for ax in axs.flatten()[n_cells:]:
    ax.axis('off')
fig.subplots_adjust(bottom=0.1)
# full-width colour-bar strip along the bottom of the figure
cax = pl.axes([0, 0, 1, 0.01])
im = axs[0][0].images[0]
fig.colorbar(im, cax=cax, orientation='horizontal')
fig.suptitle('CSm Trials')

In [None]:
# One small heat map per cell for CS+ (reward) trials, laid out 7 cells per row.
n_cells = traces.shape[1]
n_cols = 7
# enough rows for every cell; squeeze=False keeps axs 2-D even with a single row
n_rows = int(np.ceil(n_cells / float(n_cols)))
fig, axs = pl.subplots(n_rows, n_cols, figsize=(7, max(n_cells / 6., 1.)), squeeze=False)

for cell, ax in zip(range(n_cells), axs.flatten()):
    plot_heat_map(ax, time_ax, traces, cell, cycles[is_rewardt], time_ax_single, normed = True, vmax=0.85)
    # technically, vmax should = 1, but want to make signals stand out for visualization purposes
    for s in ax.spines.values():
        s.set_visible(False)
    ax.set_xticks(())
    ax.set_yticks(())
    # label each panel with its 1-based cell number
    ax.text(5.7, 3.7, cell+1, color='y')
# blank out grid slots after the last cell
for ax in axs.flatten()[n_cells:]:
    ax.axis('off')
fig.subplots_adjust(bottom=0.1)
# full-width colour-bar strip along the bottom of the figure
cax = pl.axes([0, 0, 1, 0.01])
im = axs[0][0].images[0]
fig.colorbar(im, cax=cax, orientation='horizontal')
fig.suptitle('Reward Trials')

In [None]:
# One small heat map per cell for rewarded CS+ trials, laid out 7 cells per row.
n_cells = traces.shape[1]
n_cols = 7
# enough rows for every cell; squeeze=False keeps axs 2-D even with a single row
n_rows = int(np.ceil(n_cells / float(n_cols)))
fig, axs = pl.subplots(n_rows, n_cols, figsize=(7, max(n_cells / 6., 1.)), squeeze=False)

for cell, ax in zip(range(n_cells), axs.flatten()):
    plot_heat_map(ax, time_ax, traces, cell, cycles[is_rewarded], time_ax_single, normed = True, vmax=0.85)
    # technically, vmax should = 1, but want to make signals stand out for visualization purposes
    for s in ax.spines.values():
        s.set_visible(False)
    ax.set_xticks(())
    ax.set_yticks(())
    # label each panel with its 1-based cell number
    ax.text(5.7, 3.7, cell+1, color='y')
# blank out grid slots after the last cell
for ax in axs.flatten()[n_cells:]:
    ax.axis('off')
fig.subplots_adjust(bottom=0.1)
# full-width colour-bar strip along the bottom of the figure
cax = pl.axes([0, 0, 1, 0.01])
im = axs[0][0].images[0]
fig.colorbar(im, cax=cax, orientation='horizontal')
fig.suptitle('Rewarded Reward Trials')

In [None]:
# One small heat map per cell for error CS- trials, laid out 7 cells per row.
n_cells = traces.shape[1]
n_cols = 7
# enough rows for every cell; squeeze=False keeps axs 2-D even with a single row
n_rows = int(np.ceil(n_cells / float(n_cols)))
fig, axs = pl.subplots(n_rows, n_cols, figsize=(7, max(n_cells / 6., 1.)), squeeze=False)

for cell, ax in zip(range(n_cells), axs.flatten()):
    plot_heat_map(ax, time_ax, traces, cell, cycles[is_errCSmt], time_ax_single, normed = True, vmax=0.85)
    # technically, vmax should = 1, but want to make signals stand out for visualization purposes
    for s in ax.spines.values():
        s.set_visible(False)
    ax.set_xticks(())
    ax.set_yticks(())
    # label each panel with its 1-based cell number
    ax.text(5.7, 3.7, cell+1, color='y')
# blank out grid slots after the last cell
for ax in axs.flatten()[n_cells:]:
    ax.axis('off')
fig.subplots_adjust(bottom=0.1)
# full-width colour-bar strip along the bottom of the figure
cax = pl.axes([0, 0, 1, 0.01])
im = axs[0][0].images[0]
fig.colorbar(im, cax=cax, orientation='horizontal')
fig.suptitle('Error CSm Trials')

In [None]:
# One small heat map per cell for correct CS- trials, laid out 7 cells per row.
n_cells = traces.shape[1]
n_cols = 7
# enough rows for every cell; squeeze=False keeps axs 2-D even with a single row
n_rows = int(np.ceil(n_cells / float(n_cols)))
fig, axs = pl.subplots(n_rows, n_cols, figsize=(7, max(n_cells / 6., 1.)), squeeze=False)

for cell, ax in zip(range(n_cells), axs.flatten()):
    plot_heat_map(ax, time_ax, traces, cell, cycles[is_corrCSmt], time_ax_single, normed = True, vmax=0.85)
    # technically, vmax should = 1, but want to make signals stand out for visualization purposes
    # hide every spine of this panel (not just one left over from a previous cell)
    for s in ax.spines.values():
        s.set_visible(False)
    ax.set_xticks(())
    ax.set_yticks(())
    # label each panel with its 1-based cell number
    ax.text(5.7, 3.7, cell+1, color='y')
# blank out grid slots after the last cell
for ax in axs.flatten()[n_cells:]:
    ax.axis('off')
fig.subplots_adjust(bottom=0.1)
# full-width colour-bar strip along the bottom of the figure
cax = pl.axes([0, 0, 1, 0.01])
im = axs[0][0].images[0]
fig.colorbar(im, cax=cax, orientation='horizontal')
fig.suptitle('Corr CSm Trials')

# Correlations

In [None]:
# Maximum-likelihood (empirical) covariance estimator for the cell-by-cell covariance analysis.
cov_model = EmpiricalCovariance()

In [None]:
# z-score each column (cell) of traces so covariances are computed on a common scale;
# this recomputes the same standardization used for the stacked-trace plot.
zz = StandardScaler().fit_transform(traces)

In [None]:
# fit() returns the estimator itself, so cov_same is the same object as cov_model;
# the fitted matrix is read later as cov_model.covariance_.
cov_same = cov_model.fit(zz)

In [None]:
clust_model = KMeans(5)
# Cluster cells by the magnitude of their pairwise covariances (sign ignored).
abs_cov = abs(cov_model.covariance_)
cov_labels = clust_model.fit_predict(abs_cov)
# Silhouette score on the same features that were clustered.
print(clust_score(abs_cov, cov_labels))
pl.subplots(1, 1, figsize=(3, 3))

# Show the signed covariance matrix with rows and columns reordered so that
# cells of the same cluster are contiguous.
order = np.argsort(cov_labels)
pl.imshow(cov_model.covariance_[np.ix_(order, order)],
          aspect='auto', cmap=pl.cm.RdBu_r, vmin=-1, vmax=1)
pl.colorbar()
pl.xticks(())
pl.yticks(());

In [None]:
# Distribution of pairwise (off-diagonal) covariances. The diagonal is excluded by
# position, not by value, so variances that round to slightly below 1 cannot leak in.
# `density` replaces the `normed` keyword, which newer matplotlib removed.
_cov = cov_model.covariance_
pl.hist(_cov[~np.eye(_cov.shape[0], dtype=bool)], bins=50, density=True);

In [None]:
def plot_it(cluster=0, start=0, heat=True):
    """Show the z-scored activity (zz) of the cells in one covariance cluster.

    The window spans ``start`` to ``start + 10`` in units of ``time_ax/60.``.
    With ``heat=True`` it is drawn as a cells-by-time image; otherwise each
    member cell is drawn as a line offset by its rank within the cluster.
    """
    stop = start+10
    members = np.where(cov_labels==cluster)[0]
    if heat:
        minutes = time_ax/60.
        in_window = (minutes>=start) * (minutes<stop)
        window = zz[in_window][:, members].T
        pl.imshow(window, extent=(start, stop, 0, (cov_labels==cluster).sum()),
                  origin='lower', aspect='auto', vmin=0, vmax=10)
    else:
        for i, cell in enumerate(members):
            pl.plot(time_ax/60., zz[:, cell]/10.+i, lw=1)
    pl.xlim(start, stop)
interact(plot_it, cluster=(0, np.max(cov_labels), 1), start=(0, 60, 0.5))

In [None]:
def plot_it(cluster=0):
    """Outline one covariance cluster's ROIs on the mean image.

    The two groups passed to pt.plot_rois are the cluster members and all
    remaining cells. Colours are given in group order, so presumably members
    are 'r' and the rest 'g'; confirm against pt.plot_rois.
    """
    members = np.where(cov_labels==cluster)[0]
    others = np.delete(range(events.shape[1]), np.concatenate([members]))

    fig, ax = pl.subplots(1, 1, figsize=(3, 3))
    pt.plot_rois(mean_image, contours, [members, others],
                 colors=['r', 'g', '0.7'], ax=ax)
    ax.set_xticks(())
    ax.set_yticks(())
interact(plot_it, cluster=(0, np.max(cov_labels), 1))

# How much variance is explained by behavior

In [None]:
# Control trial sets: random subsets of the rewarded trials, sized to match the number
# of pre_bouts and the number of error CS- trials. Despite the is_ prefix these are
# integer trial INDICES (not boolean masks). replace=False raises ValueError if the
# requested size exceeds the number of rewarded trials. No seed is set, so the draw
# differs on every run.
is_random_rewarded_pre_boutst = np.random.choice(np.where(is_rewarded)[0], size=len(pre_bouts), replace=False)
is_random_rewarded_errCSmt = np.random.choice(np.where(is_rewarded)[0], size=is_errCSmt.sum(), replace=False)

## Selectivity to each CS

In [None]:
# condition key -> result of ut.compute_selectivity; filled by the selectivity cells that follow
selectivity = {}

In [None]:
to_plot = np.r_[['CSm', 'shock', 'reward', 'airpuff', 'errCSm', 'corrCSm']]

# Response window, relative to cycle start: the whole CS presentation.
cs_timeframe = (CS_START-CYCLE_START, CS_START+CS_DURATION-CYCLE_START)
# Baseline window, relative to cycle start: the 2 s immediately before CS onset.
baseline_timeframe = (CS_START-2-CYCLE_START, CS_START-0-CYCLE_START)

for t in to_plot:
    # trial selector for this CS type, e.g. 'CSm' -> is_CSmt
    which_cycles = eval('is_%st' % t)
    # trial types absent from this session get no entry
    if np.any(which_cycles):
        selectivity[t] = ut.compute_selectivity(
            time_ax, events, cycles[which_cycles], cs_timeframe, baseline_timeframe)
    

## Selectivity to any CS

In [None]:
#labels_time_ax = np.zeros_like(time_ax)
#
## CHOOSE WHAT PERIOD DEFINES THE CS
#cs_timeframe = (CS_START-CYCLE_START, CS_START+CS_DURATION-CYCLE_START)
#
#selectivity['any_cs'] = ut.compute_selectivity(time_ax, events, cycles,
#                                               cs_timeframe, baseline_timeframe)

## Selectivity to each US

In [None]:
to_plot = np.r_[['CSm', 'shock', 'reward', 'airpuff', 'errCSm', 'corrCSm', 'random_rewarded_errCSm', 'random_rewarded_pre_bouts']]

# US response window, relative to cycle start
us_timeframe = (US_START-CYCLE_START, US_START+US_DURATION-CYCLE_START)

for t in to_plot:

    # trial selector for this condition, e.g. 'CSm' -> is_CSmt
    which_cycles = eval('is_%s'%(t+'t'))

    # Selectors are either boolean trial masks or, for the random_rewarded_* controls,
    # integer trial indices. Summing an index array would treat [0] (a single draw of
    # trial 0) as "no trials", so test emptiness according to the selector's dtype.
    selector = np.asarray(which_cycles)
    if selector.dtype == bool:
        no_trials = not selector.any()
    else:
        no_trials = selector.size == 0
    if no_trials:
        continue

    selectivity[t+"_us"] = ut.compute_selectivity(time_ax, events, cycles[which_cycles],
                                            us_timeframe, baseline_timeframe)

## Selectivity to trace

In [None]:
to_plot = np.r_[['CSm', 'shock', 'reward', 'airpuff', 'errCSm', 'corrCSm']]

# Trace period, relative to cycle start: from CS offset until DELAY later.
tc_timeframe = (CS_START-CYCLE_START+CS_DURATION, CS_START-CYCLE_START+CS_DURATION+DELAY)

for t in to_plot:
    # trial selector for this CS type, e.g. 'CSm' -> is_CSmt
    which_cycles = eval('is_%st' % t)
    # trial types absent from this session get no entry
    if np.sum(which_cycles) != 0:
        selectivity['%s_tc' % t] = ut.compute_selectivity(
            time_ax, events, cycles[which_cycles], tc_timeframe, baseline_timeframe)

## Selectivity to any trace

In [None]:
#labels_time_ax = np.zeros_like(time_ax)
#
#tc_timeframe = (CS_START-CYCLE_START+CS_DURATION, CS_START-CYCLE_START+CS_DURATION+DELAY)
#
#selectivity['any_tc'] = ut.compute_selectivity(time_ax, events, cycles,
#                                               tc_timeframe, baseline_timeframe)

## Selectivity to each trace period (tc), only for trials with a delivered US

In [None]:
to_plot = np.r_[['shock', 'reward', 'airpuff', 'not_reward']]

# Trace period (not the US): from CS offset until DELAY later, relative to cycle start
tc_timeframe = (CS_START-CYCLE_START+CS_DURATION, CS_START-CYCLE_START+CS_DURATION+DELAY)

for t in to_plot:
    # delivery-split trial selector, e.g. 'reward' -> is_rewarded, 'not_reward' -> is_not_rewarded
    which_cycles = eval('is_%sed' % t)
    # conditions with no such trials get no entry
    if np.sum(which_cycles) != 0:
        selectivity['%sed_tc' % t] = ut.compute_selectivity(
            time_ax, events, cycles[which_cycles], tc_timeframe, baseline_timeframe)

## Selectivity to each US, only for delivered US

In [None]:
to_plot = np.r_[['shock', 'reward', 'airpuff', 'not_reward']]

# US response window, relative to cycle start
us_timeframe = (US_START-CYCLE_START, US_START+US_DURATION-CYCLE_START)

for t in to_plot:
    # delivery-split trial selector, e.g. 'reward' -> is_rewarded
    which_cycles = eval('is_%sed' % t)
    # conditions with no such trials get no entry
    if np.sum(which_cycles) != 0:
        selectivity['%sed_us' % t] = ut.compute_selectivity(
            time_ax, events, cycles[which_cycles], us_timeframe, baseline_timeframe)

## Selectivity to each CS, only for delivered US

In [None]:
to_plot = np.r_[['shock', 'reward', 'airpuff', 'not_reward']]

# CS response window (whole CS presentation), relative to cycle start
cs_timeframe = (CS_START-CYCLE_START, CS_START+CS_DURATION-CYCLE_START)

for t in to_plot:
    # delivery-split trial selector, e.g. 'reward' -> is_rewarded
    which_cycles = eval('is_%sed' % t)
    # conditions with no such trials get no entry
    if np.sum(which_cycles) != 0:
        selectivity['%sed' % t] = ut.compute_selectivity(
            time_ax, events, cycles[which_cycles], cs_timeframe, baseline_timeframe)

# Selectivity to licks during ITI

In [None]:
# Selectivity around the lick bouts in pre_bouts_startends: no explicit response window
# (timeframe=None) and a baseline of (-2, 0) — presumably the 2 s before each bout
# onset; confirm the reference point inside ut.compute_selectivity.
selectivity['pre_bouts'] = ut.compute_selectivity(time_ax, events, pre_bouts_startends, timeframe=None, baseline_timeframe=(-2, 0))

## Selectivity to reward, beginning at reward onset

In [None]:
# For every rewarded cycle (start s, end e): the first reward timestamp strictly after
# the cycle start, expressed relative to the cycle start and shifted by CYCLE_START.
first_reward_times = np.array([rewards[(rewards-s)>0][0] - s + CYCLE_START
                               for s, e in cycles[np.where(is_rewarded)]])

In [None]:
# Response window that starts when sucrose is first consumed (often later than the
# opening of the US window): the 2 s after each rewarded trial's first reward.

to_plot = np.r_[['reward']]

# One (start, end) pair per rewarded trial, relative to cycle start. list() so the
# windows form a real sequence; zip() is a one-shot iterator on Python 3.
reward_timeframe = list(zip(first_reward_times-CYCLE_START, first_reward_times-CYCLE_START+2))

for t in to_plot:

    # the windows above are defined per rewarded trial, so use the rewarded-trial mask
    which_cycles = is_rewarded

    if np.sum(which_cycles) == 0:
        continue

    selectivity[t+"_onset"] = ut.compute_selectivity(time_ax, events, cycles[which_cycles],
                                                   reward_timeframe, baseline_timeframe)

In [None]:
# list the conditions for which a selectivity entry was computed
selectivity.keys()

# PSTH

In [None]:
# sampling interval of the imaging time axis, taken from the first two samples
# (assumes time_ax is uniformly spaced — TODO confirm)
delta_t = np.diff(time_ax)[0]

In [None]:
def plot_licks(cycles, licks, ax0, ax1, cs_start, cs_duration, us_start, us_duration,
               colors=((0,0,0,0), (0,0,0,0))):
    """Plot lick timing across trials: a pooled histogram on ``ax0`` and a raster on ``ax1``.

    Each lick in a cycle spanning ``[s, e)`` is re-expressed as
    ``lick - (s - CYCLE_START)``, so every trial shares one time axis.

    Parameters
    ----------
    cycles : sequence of (start, end) pairs, one per trial.
    licks : array of lick timestamps (presumably on the same clock as ``cycles``).
    ax0 : axes for the pooled lick histogram; CS and US period bars are drawn on it.
    ax1 : axes for the lick raster; trial ``i`` occupies y in ``[i, i+1]``.
    cs_start, cs_duration, us_start, us_duration : periods marked by the bars.
    colors : (CS bar colour, US bar colour); fully transparent RGBA by default.

    Returns
    -------
    None. No image is produced, so there is nothing to pass to a colour bar.
    """
    aligned = [licks[(licks>=s) * (licks<e)] - (s-CYCLE_START) for s, e in cycles]

    # np.concatenate rejects an empty list, so skip the histogram when there are no trials
    if len(aligned) > 0:
        # roughly one bin per sampling interval of time_ax
        ax0.hist(np.concatenate(aligned), histtype='bar', color='k',
                 bins=int(CYCLE_DURATION/np.min(np.diff(time_ax))))

    for i, trial_licks in enumerate(aligned):
        ax1.vlines(trial_licks, i, i+1, 'r', lw=0.5)
    for s in ax0.spines.values():
        s.set_visible(False)
    pt.plot_period_bar(ax0, 75, delta_y=10, color=colors[0], start_end=(cs_start, cs_start+cs_duration))
    pt.plot_period_bar(ax0, 75, delta_y=10, color=colors[1], start_end=(us_start, us_start+us_duration))
    pt.nicer_plot(ax1)
    ax1.set_ylim(0, len(cycles))

In [None]:
def remove_plot(ax):
    """Blank out an axes: hide every spine, then clear the x and y ticks."""
    for spine in ax.spines.values():
        spine.set_visible(False)
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks(())

In [None]:
## %%time
## Plotting Events
#
#def plot_it(odon, odoff, ax0, ax1, cs_start, cs_duration, us_start, us_duration,
#            colors=((0,0,0,0), (0,0,0,0))):
#
#    psth = []
#
#    ons, offs = odon, odoff
#    for start, stop in zip(ons, offs):
#        nstart = np.min(np.where(time_ax>=(start-t_pre)))
#        nstop = nstart + int((cs_duration+t_pre+t_post)/delta_t)
#        psth.append(events[:, cell][nstart:nstop])
#    psth = np.r_[psth]
#
#    im = ax1.imshow(psth, cmap=pl.cm.viridis, vmin=0, vmax=10, origin='lower',
#               extent=(-t_pre, t_post+cs_duration, 0, len(ons)), aspect='auto')
#    ax0.bar(np.linspace(-t_pre, t_post+cs_duration, len(psth.sum(0))), (psth).sum(0), color='k', lw=1, width=0.1)
#    ax0.set_ylim(-10, 150)
#    for s in ax0.spines.values():
#        s.set_visible(False)
#    ax0.set_xticks((-t_pre, 0, t_post+cs_duration))
#    ax0.set_yticks(())
#    pt.plot_period_bar(ax0, 75, delta_y=10, color=colors[0], start_end=(cs_start, cs_start+cs_duration))
#    pt.plot_period_bar(ax0, 75, delta_y=10, color=colors[1], start_end=(us_start, us_start+us_duration))
#    pt.nicer_plot(ax1)
#
#    return im
#
#t_pre = 6
#t_post = 6
#
#to_plot = np.r_[['CSm', 'shock', 'reward', 'airpuff']]
#colors = {}
#colors['CSm'] = ((0.7, 0.7, 0.7), (0, 0, 0))  # (CS, US)
#colors['shock'] = ((1, 0.7, 0.7), (1, 0, 0))
#colors['reward'] = ((0.7, 1, 0.7), (0, 1, 0))
#colors['airpuff'] = ((0.7, 0.7, 1), (0, 0, 1))
#which_exist = [eval('len(tone_%s_ons)>0'%t) for t in to_plot]
#n_rows = np.sum(which_exist)
#
## rank them by selectivity for something, for convenience
#rank_pvalue = np.argsort([s[1] for s in selectivity['rewarded'][:, 0]])
#for cell in rank_pvalue [:10]:
## for cell in np.argsort(selectivity['reward'][:, 1])[:10]:
## cell += 1
#    
#    fig, axs = pl.subplots(n_rows*2, 4, figsize=(6, 4), gridspec_kw={'height_ratios':[1, 2]*n_rows}, sharex=True)
#    
#    # activity
#    for i, t in enumerate(to_plot[which_exist]):
#        im = eval('plot_it(tone_%s_ons, tone_%s_offs, axs[i*2][0], axs[i*2+1][0],\
#                   CS_START, CS_DURATION, US_START, US_DURATION, colors=colors[\'%s\'])'
#                  %(t, t, t))
#        eval('axs[i*2+1][0].set_ylabel(\'%s\')'%t)
#    
#    # activity (us delivered)
#    for i, t in enumerate(to_plot[which_exist]):
#        try:
#            which = eval('[[b in np.where(is_%s)[0] for b in np.where(is_%s)[0]]]' % (t+'ed', t+'t'))
#            im = eval('plot_it(tone_%s_ons[which], tone_%s_offs[which], axs[i*2][1], axs[i*2+1][1],\
#                       CS_START, CS_DURATION, US_START, US_DURATION, colors=colors[\'%s\'])'
#                      %(t, t, t))
#        except NameError:
#            remove_plot(axs[i*2][1])
#            remove_plot(axs[i*2+1][1])
#
#    # activity (us NOT delivered)
#    for i, t in enumerate(to_plot[which_exist]):
#        try:
#            which = eval('[[b in np.where(is_not_%s)[0] for b in np.where(is_%s)[0]]]' % (t+'ed', t+'t'))
#            im = eval('plot_it(tone_%s_ons[which], tone_%s_offs[which], axs[i*2][2], axs[i*2+1][2],\
#                       CS_START, CS_DURATION, US_START, US_DURATION, colors=colors[\'%s\'])'
#                      %(t, t, t))
#        except NameError:
#            remove_plot(axs[i*2][2])
#            remove_plot(axs[i*2+1][2])
#        
#    # licks
#    for i, t in enumerate(to_plot[which_exist]):
#        im = eval('plot_licks(cycles[is_%s], licks, axs[i*2][-1], axs[i*2+1][-1],\
#                   CS_START, CS_DURATION, US_START, US_DURATION, colors=colors[\'%s\'])'
#                  %(t+'t', t))
#    
#    
#    fig.suptitle("Cell #%d"%cell)
#    for i, (t, v) in enumerate(selectivity.iteritems()):
#        p = v[cell][0][1]
#        stars = "***" if p<0.001 else "**" if p<0.01 else "*" if p<0.05 else "n.s."
#        axs[1][1].text(-10, 2.05-i*.1-.1, "%s tuning = %1.2lf (%s)"%(t, selectivity[t][cell][1], stars),
#                       fontsize=5,
#                       color='red')
#    
#    ax = fig.add_axes((1., 0.2, 0.02, 0.6))
#    fig.colorbar(im, cax=ax, orientation='vertical', ticks=(0, 5, 10) )
#    ax.set_xticks((-5, 0, 5))
#    ax.set_yticks((0, 10))
#    ax.text(-0.5, 0.7, 'Event magnitude', rotation=90, ha='center', fontsize=3)
#
#    for s in ax.spines.values():
#        s.set_visible(False)
#        
##     for ax in axs.T:
##         ax[1].set_yticks(())
##         ymin, ymax = ax[1].axis()[-2:]
##         ax[1].fill_between([0, 4], ymax+ymax/20., ymax+ymax/20.+ymax/50., color='0.5', lw=0)
##     axs[1][1].set_xlabel('Time (s)', fontsize=5)
##     axs[0][1].text(0, 0, 'Cell %3d' % cell, fontsize=5, ha='center')
#    
##     axs[0][0].text(-5, -55, 'Air', fontsize=3)
##     axs[0][1].text(-5, -55, 'Flower', fontsize=3)
##     axs[0][2].text(-5, -55, 'Banana', fontsize=3)
#    
##     axs[1][0].set_ylabel('Trial #', fontsize=5)
#    # axs[0][0].set_title('Air', fontsize=5)
#    # axs[0][1].set_title('Banana', fontsize=5)
#    # axs[0][2].set_title('Flower', fontsize=5)
#    
#
#    
#    
##     fig.tight_layout()
#    
#    axs[0][0].set_title('All trials', fontsize=7)
#    axs[0][1].set_title('US delivered', fontsize=7)
#    axs[0][2].set_title('US not delivered', fontsize=7)
#    axs[0][3].set_title('Licks', fontsize=7)
#
#    fig.subplots_adjust(bottom=0.22)
#        
##    fig.savefig('../img/psth_%03d.pdf'%cell)

# PSTH alternate plotting

In [None]:
#for cell in xrange(traces.shape[1]):
## cell += 1
#    t_pre = 6
#    t_post = 6
#
#    fig, axs = pl.subplots(2, 3, figsize=(3, 3), gridspec_kw={'height_ratios':(1, 2)})
#
#    def plot_it(ons, is_it, ax0, ax1):
#
#        psth_ = []
#
#        i = 0
#        for tf in is_it:
#            if not tf:
#                psth_.append(None)
#            else:
#                start = ons[i]
#                nstart = np.min(np.where(time_ax>=(start-t_pre)))
#                nstop = nstart + int((t_pre+t_post)/delta_t)
#                psth_.append(events[:, cell][nstart:nstop])
#                i += 1
#        max_len = np.max([len(p) for p in psth_ if p is not None])
#        psth = np.ones((len(is_it), max_len))*-1
#        for i, p in enumerate(psth_):
#            if p is not None:
#                psth[i] = p
#
#        im = ax1.imshow(psth, cmap=pl.cm.RdBu_r, vmin=-10, vmax=10,
#                   extent=(-t_pre, t_post, 0, len(ons)), aspect='auto', origin='lower')
#        ax0.plot(np.clip(psth, 0, np.inf).sum(0), 'k-', lw=1)
###### or? ax0.plot(time_ax_single, np.mean(psth, 0), lw=1,)
#   #     m = np.mean(psth, 0)
#  #      s = np.std(psth, 0)#/np.sqrt(np.sum(iso)-1)
# #       ax0.fill_between(time_ax_single, m-s, m+s,
##                         lw=0, zorder=0, alpha=0.2)
#        ax0.set_ylim(-10, 300)
#        for s in ax0.spines.values():
#            s.set_visible(False)
#        ax0.set_xticks(())
#        ax0.set_yticks(())
#        pt.nicer_plot(ax1)
#        
#        return im
#
#    im = plot_it(tone_CSm_ons, is_CSmt, axs[0][0], axs[1][0])
#    plot_it(tone_reward_ons, is_rewardt, axs[0][1], axs[1][1])
#    plot_it(rewards, is_rewarded, axs[0][2], axs[1][2])
#    ax = fig.add_axes((1., 0.2, 0.02, 0.4))
#    fig.colorbar(im, cax=ax, orientation='vertical', ticks=(0, 5, 10) )
#    ax.set_xticks(())
#    ax.set_yticks((0, 10))
#    ax.text(-0.5, 0.7, 'Event magnitude', rotation=90, ha='center', fontsize=3)
#
#    for s in ax.spines.values():
#        s.set_visible(False)
#    
#    
#    for i, ax in enumerate(axs.T):
#        ax[1].set_yticks(())
#        ymin, ymax = ax[1].axis()[-2:]
#        ax[1].fill_between([0, 2], ymax+ymax/20., ymax+ymax/20.+ymax/50.,
#                           color='0.5' if i<2 else '1', lw=0)
#    axs[1][1].set_xlabel('Time (s)', fontsize=5)
#    axs[0][1].text(29, 150, 'Cell %3d' % (cell+1), fontsize=5, ha='center')
#    
#    axs[0][0].text(-5, -55, 'CS-', fontsize=3)
#    axs[0][1].text(-5, -55, 'CS+', fontsize=3)
#    axs[0][2].text(-5, -55, 'Reward del.', fontsize=3)
#    
#    axs[1][0].set_ylabel('Trial #', fontsize=5)
#    # axs[0][0].set_title('Air', fontsize=5)
#    # axs[0][1].set_title('Banana', fontsize=5)
#    # axs[0][2].set_title('Flower', fontsize=5)
#    
##     fig.tight_layout()
#    fig.subplots_adjust(bottom=0.22)
#    fig.subplots_adjust(right=0.8)
#    
##    fig.savefig('../img/psth_colored_%03d.pdf'%cell)

------------------

In [None]:
def plot_pie(n_up, n_down, n_total, colors=('r', 'r', 'r'), explode=(0.2, 0.2, 0.0), startangle=0,
             wedgeprops=None, labels=('Those cells up', 'Those cells down', 'The rest'),
             percentage=False, title=None, ax=None, title_xy=(3, 0), title_color='k',
             label_colors=('b', 'r', 'k'), **pie_kwargs):
    """Draw a three-wedge pie of up-tuned, down-tuned and remaining cells.

    Parameters
    ----------
    n_up, n_down : int
        Number of cells significantly up- / down-modulated.
    n_total : int
        Total number of cells; the third wedge is ``n_total - n_up - n_down``.
    colors, explode, startangle :
        Passed to ``ax.pie``.
    wedgeprops : dict or None
        Passed to ``ax.pie``; defaults to ``{'lw': 0}`` (no wedge outlines).
    labels : sequence of 3 str
        Base wedge labels. A count (``n=..``) or a percentage of ``n_total`` is
        appended; a wedge with no cells gets an empty label.
    percentage : bool
        Annotate wedges with percentages instead of counts.
    title, title_xy, title_color :
        Optional text drawn at data coordinates ``title_xy``.
    ax : matplotlib axes or None
        Axes to draw on; a new figure is created when None.
    label_colors : sequence of 3 colours
        Text colour of each wedge label.
    **pie_kwargs :
        Forwarded to ``ax.pie``.

    Returns
    -------
    The axes that was drawn on.
    """
    if ax is None:
        fig, ax = pl.subplots(1, 1)
    if wedgeprops is None:
        wedgeprops = {'lw': 0}

    counts = (n_up, n_down, n_total-n_up-n_down)
    labels_full = []
    for base, n in zip(labels, counts):
        if n <= 0:
            labels_full.append("")
        elif percentage:
            labels_full.append('%s\n%.2f%%' % (base, 100. * n / n_total))
        else:
            labels_full.append('%s\nn=%d' % (base, n))

    # pie returns 2 or 3 sequences depending on autopct; the label texts are always second
    result = ax.pie([n_up, n_down, n_total-(n_up+n_down)], colors=colors, explode=explode,
                    startangle=startangle, labels=labels_full, wedgeprops=wedgeprops,
                    textprops={'fontsize':5}, **pie_kwargs)
    texts = result[1]
    for t, c in zip(texts, label_colors):
        t.set_color(c)
    for t in texts:
        t.set_fontsize(5)

    if title is not None:
        ax.text(title_xy[0], title_xy[1], title, fontsize=5, color=title_color, ha='center')

    return ax

# Quick visual check of plot_pie with dummy counts (10 up, 10 down, 80 remaining).
fig, ax = pl.subplots(1, 1, figsize=(1, 1))
plot_pie(10, 10, 100, labels=['up', 'dn', 'rest'], ax=ax, title='Sit amet', title_xy=(0, -0.5), title_color='w')

In [None]:
print('percent selective cells, uncorrected for multiple comparisons')
significance = 0.01

for k, v in selectivity.items():
    # v[:, 0] holds per-cell tuples whose element [1] is the raw p-value
    frac_selective = np.mean([s[1]<significance for s in v[:, 0]])
    # the header promises a percentage, so scale the fraction by 100
    print('%s %.2f' % (k, 100. * frac_selective))

In [None]:
print('percent selective cells, corrected')
significance = 0.01

for k, v in selectivity.items():
    # fraction of cells whose multiple-comparison-adjusted p-value passes the threshold
    frac_selective = np.mean(ut.adjust_pvalues([s[1] for s in v[:, 0]])<significance)
    # the header promises a percentage, so scale the fraction by 100
    print('%s %.2f' % (k, 100. * frac_selective))

In [None]:
def count_trials_from_key(k):
    """Return the number of trials behind the selectivity entry ``k``.

    Special keys: 'any_tc' / 'any_cs' count all cycles; 'pre_bouts' counts
    pre_bouts; 'reward_onset' counts rewarded trials. The random_rewarded_*
    controls count their drawn trial indices.

    Every other key has the form '<condition>[_us|_tc]'. Conditions ending in
    'ed' (delivery splits such as 'rewarded', 'not_rewarded') use the trial
    mask ``is_<condition>``; all others use ``is_<condition>t`` (e.g. 'CSm' ->
    is_CSmt). This is the same naming the selectivity cells use to build the
    keys, so shock/airpuff conditions are covered as well.

    Raises
    ------
    KeyError
        If no matching trial mask exists in the notebook namespace.
    """
    if k in ('any_tc', 'any_cs'):
        return len(cycles)
    if k == 'pre_bouts':
        return len(pre_bouts)
    if k == 'reward_onset':
        return np.sum(is_rewarded)
    if k == 'random_rewarded_errCSm_us':
        return len(is_random_rewarded_errCSmt)
    if k == 'random_rewarded_pre_bouts_us':
        return len(is_random_rewarded_pre_boutst)

    base = k
    for suffix in ('_us', '_tc'):
        if base.endswith(suffix):
            base = base[:-len(suffix)]
            break
    if base.endswith('ed'):
        mask_name = 'is_%s' % base
    else:
        mask_name = 'is_%st' % base
    namespace = globals()
    if mask_name not in namespace:
        raise KeyError('no trial mask %r for selectivity key %r' % (mask_name, k))
    return np.sum(namespace[mask_name])
# Trial count behind each selectivity entry, keys in alphabetical order
# (sorted() works on dict keys on both Python 2 and 3; np.sort of a keys view does not).
[(k, count_trials_from_key(k)) for k in sorted(selectivity)]

In [None]:
# Wedge colours for each selectivity key, as (up, down, untuned) — passed to plot_pie.
colors = {}
colors['CSm'] = ((1, 0, 0), (1, 0, 0), (1, 0, 0))
colors['CSm_us'] = ((1, 0, 0), (1, 0, 0), (1, 0, 0))
colors['CSm_tc'] = ((1, 0, 0), (1, 0, 0), (1, 0, 0))
colors['errCSm'] = ((.5, 0, 0), (.5, 0, 0), (.5, 0, 0))
colors['errCSm_tc'] = ((.5, 0, 0), (.5, 0, 0), (0.5, 0, 0))
colors['errCSm_us'] = ((.5, 0, 0), (.5, 0, 0), (.5, 0, 0))
colors['reward'] = ((0.5, 1, 0.5), (0.5, 1, 0.5), (0.5, 1, 0.5))
colors['reward_us'] = ((0, 1, 0), (0, 1, 0), (0, 1, 0))
colors['rewarded'] = ((0, 0, 0.5), (0, 0, 0.5), (0, 0, 0.5))
colors['rewarded_us'] = ((0, 0, 1), (0, 0, 1), (0, 0, 1))
colors['not_rewarded'] = ((0.5, 1, 1), (0.5, 1, 1), (0.5, 1, 1))
colors['not_rewarded_us'] = ((0.5, 1, 1), (0.5, 1, 1), (0.5, 1, 1))
colors['not_rewarded_tc'] = ((0.5, 1, 1), (0.5, 1, 1), (0.5, 1, 1))
colors['reward_tc'] = ((0.7, 1, 0.7), (0.7, 1, 0.7), (0.7, 1, 0.7))
colors['rewarded_tc'] = ((0, 0, 0.7), (0, 0, 0.7), (0, 0, 0.7))
colors['reward_onset'] = ((0, 0, 0), (0, 0, 0), (0, 0, 0))
colors['corrCSm'] = ((0.7, 0, 0), (0.7, 0, 0), (0.7, 0, 0))
colors['corrCSm_tc'] = ((0.7, 0, 0), (0.7, 0, 0), (0.7, 0, 0))
colors['corrCSm_us'] = ((0.7, 0, 0), (0.7, 0, 0), (0.7, 0, 0))
colors['pre_bouts'] = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
colors['random_rewarded_errCSm_us'] = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
colors['random_rewarded_pre_bouts_us'] = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

significance = 0.01

# Display order of the conditions (also used by the output-file cells).
keys = ['CSm', 'CSm_tc', 'CSm_us',
        'errCSm', 'errCSm_tc', 'errCSm_us',
        'corrCSm', 'corrCSm_tc', 'corrCSm_us',
        'reward', 'reward_tc', 'reward_us',
        'rewarded', 'rewarded_tc', 'rewarded_us',
        'not_rewarded', 'not_rewarded_tc', 'not_rewarded_us',
        'reward_onset',
        'pre_bouts', 'random_rewarded_pre_bouts_us',
        'random_rewarded_errCSm_us']

# Conditions with no trials never receive a selectivity entry (the selectivity cells
# `continue` past them), so plot only the keys that exist and size the grid to match.
keys_to_plot = [k for k in keys if k in selectivity]
n_cols = 3
n_rows = max(1, int(np.ceil(len(keys_to_plot) / float(n_cols))))
fig, axs = pl.subplots(n_rows, n_cols, figsize=(4.2, n_rows), squeeze=False)

for ax, k in zip(axs.flatten(), keys_to_plot):
    corrected_ps = ut.adjust_pvalues([s[1] for s in selectivity[k][:, 0]])
    n_up = np.sum((corrected_ps<significance) * (selectivity[k][:, 1]>0))
    n_down = np.sum((corrected_ps<significance) * (selectivity[k][:, 1]<0))
    n_total = events.shape[1]
    plot_pie(n_up, n_down, n_total, labels=['Up', 'Down', 'Untuned'], colors=colors[k],
            title="N=%d"%count_trials_from_key(k), title_xy=(0, -0.5), title_color='w', ax=ax)

    ax.set_title(k, fontsize=7)
# blank out grid slots beyond the last pie
for ax in axs.flatten()[len(keys_to_plot):]:
    ax.axis('off')
fig.tight_layout()

In [None]:
outfile = os.path.join(data_folder, 'selectivity_list_2_4.txt')

# One record per condition, six lines each:
# key, n_up, n_down, n_stable (untuned), n_total, n_trials.
# Uses the `significance` threshold currently set in the notebook.
with open(outfile, 'w') as f:
    for k in keys:
        if k not in selectivity:
            # the selectivity cells skip conditions that had no trials
            print('skipping %s: no selectivity computed' % k)
            continue
        corrected_ps = ut.adjust_pvalues([s[1] for s in selectivity[k][:, 0]])
        n_up = np.sum((corrected_ps<significance) * (selectivity[k][:, 1]>0))
        n_down = np.sum((corrected_ps<significance) * (selectivity[k][:, 1]<0))
        n_total = events.shape[1]
        n_stable = n_total - n_up - n_down
        n_trials = count_trials_from_key(k)
        f.write("%s\n%d\n%d\n%d\n%d\n%d\n" % (k, n_up, n_down, n_stable, n_total,n_trials))

In [None]:
outfile = os.path.join(data_folder, 'selectivity_list_2_4_w_percent.txt')
significance = 0.01

# One record per condition, nine lines each:
# key, n_up, %up, n_down, %down, n_stable (untuned), %stable, n_total, n_trials.
with open(outfile, 'w') as f:
    for k in keys:
        if k not in selectivity:
            # the selectivity cells skip conditions that had no trials
            print('skipping %s: no selectivity computed' % k)
            continue
        corrected_ps = ut.adjust_pvalues([s[1] for s in selectivity[k][:, 0]])
        n_up = np.sum((corrected_ps<significance) * (selectivity[k][:, 1]>0))
        n_down = np.sum((corrected_ps<significance) * (selectivity[k][:, 1]<0))
        n_total = events.shape[1]
        n_stable = n_total - n_up - n_down
        # float arithmetic so integer counts are not floor-divided on Python 2
        percent_up = 100. * n_up / n_total
        percent_down = 100. * n_down / n_total
        percent_stable = 100. * n_stable / n_total
        n_trials = count_trials_from_key(k)
        f.write("%s\n%d\n%.2f\n%d\n%.2f\n%d\n%.2f\n%d\n%d\n" % (k,n_up,percent_up,n_down,percent_down,n_stable,percent_stable,n_total,n_trials))

In [None]:
# display the ordered list of condition keys used for the pie charts and output files
keys

In [None]:
# Number of cells positively tuned to BOTH the CS+ ('reward') and the reward US
# ('reward_us'). Each condition is judged with its own adjusted p-values, not the
# corrected_ps left over from whichever key a previous loop ended on.
ps_reward = np.asarray(ut.adjust_pvalues([s[1] for s in selectivity['reward'][:, 0]]))
ps_reward_us = np.asarray(ut.adjust_pvalues([s[1] for s in selectivity['reward_us'][:, 0]]))
up_reward = (ps_reward<significance) * (selectivity['reward'][:, 1]>0)
up_reward_us = (ps_reward_us<significance) * (selectivity['reward_us'][:, 1]>0)
np.sum(up_reward * up_reward_us)

In [None]:
# For every condition: [indices of positively tuned cells, indices of negatively tuned
# cells], judged by multiple-comparison-adjusted p-values below `significance`.
tuned_cells = {}
for k, v in selectivity.items():
    corrected_ps = ut.adjust_pvalues([s[1] for s in v[:, 0]])
    is_sig = corrected_ps<significance
    tuned_cells[k] = [np.where(is_sig*([s>0 for s in v[:, 1]]))[0],
                      np.where(is_sig*([s<0 for s in v[:, 1]]))[0],]

In [None]:
# event matrix dimensions; its second axis is used as the cell count (events.shape[1])
events.shape

In [None]:
# dimensions of the ROI contours passed to pt.plot_rois
contours.shape

In [None]:
# Sanity check: cells positively (orange) vs negatively (blue) tuned to CS+
# ('reward'); every other index in range(events.shape[1]) is drawn grey.
# tuned_cells splits one condition by effect sign (> 0 vs < 0), so no cell
# should appear in both lists.
# pt.plot_rois is a project helper; presumably it draws each index list's
# contours over mean_image in the matching colour.
list_of_cells = [tuned_cells['reward'][0], tuned_cells['reward'][1]]
rest_of_cells = np.delete(range(events.shape[1]), np.concatenate(list_of_cells))
list_of_cells.append(rest_of_cells)
print('sanity check. There should be no neurons that are members of both conditions')

fig, ax = pl.subplots(1, 1, figsize=(3, 3))
pt.plot_rois(mean_image, contours,
             list_of_cells,
             colors=['darkorange', 'b', '0.65'],
             ax=ax
            )
ax.set_xticks(())
ax.set_yticks(())
# Hand-drawn legend; the two half-transparent 'both' labels are overlaid to
# give a mixed colour for cells present in both lists.
ax.text(200, 15, 'pos. tuned CS+', color='darkorange', fontsize=5)
ax.text(200, 30, 'neg. tuned CS+', color='b', fontsize=5)
ax.text(200, 45, 'both', color='darkorange', fontsize=5, alpha=0.5)
ax.text(200, 45, 'both', color='b', fontsize=5, alpha=0.5)

In [None]:
# ROI map: cells positively tuned to CS+ ('reward', orange) and to CS-
# ('CSm', blue); remaining indices in grey.
list_of_cells = [tuned_cells['reward'][0], tuned_cells['CSm'][0]]
rest_of_cells = np.delete(range(events.shape[1]), np.concatenate(list_of_cells))
list_of_cells.append(rest_of_cells)

fig, ax = pl.subplots(1, 1, figsize=(3, 3))
pt.plot_rois(mean_image, contours,
             list_of_cells,
             colors=['darkorange', 'b', '0.65'],
             ax=ax
            )
ax.set_xticks(())
ax.set_yticks(())
# Hand-drawn legend; overlaid half-transparent 'both' labels mark the overlap.
ax.text(200, 15, 'pos. tuned CS+', color='darkorange', fontsize=5)
ax.text(200, 30, 'pos. tuned CS-', color='b', fontsize=5)
ax.text(200, 45, 'both', color='darkorange', fontsize=5, alpha=0.5)
ax.text(200, 45, 'both', color='b', fontsize=5, alpha=0.5)

In [None]:
# ROI map: cells positively tuned to CS+ ('reward', orange) and negatively
# tuned to CS- ('CSm', blue); remaining indices in grey.
list_of_cells = [tuned_cells['reward'][0], tuned_cells['CSm'][1]]
rest_of_cells = np.delete(range(events.shape[1]), np.concatenate(list_of_cells))
list_of_cells.append(rest_of_cells)

fig, ax = pl.subplots(1, 1, figsize=(3, 3))
pt.plot_rois(mean_image, contours,
             list_of_cells,
             colors=['darkorange', 'b', '0.65'],
             ax=ax
            )
ax.set_xticks(())
ax.set_yticks(())
# Hand-drawn legend; overlaid half-transparent 'both' labels mark the overlap.
ax.text(200, 15, 'pos. tuned CS+', color='darkorange', fontsize=5)
ax.text(200, 30, 'neg. tuned CS-', color='b', fontsize=5)
ax.text(200, 45, 'both', color='darkorange', fontsize=5, alpha=0.5)
ax.text(200, 45, 'both', color='b', fontsize=5, alpha=0.5)

In [None]:
# ROI map (rewarded trials): cells positively tuned during CS ('rewarded',
# orange) and during US ('rewarded_us', blue); remaining indices in grey.
list_of_cells = [tuned_cells['rewarded'][0], tuned_cells['rewarded_us'][0]]
rest_of_cells = np.delete(range(events.shape[1]), np.concatenate(list_of_cells))
list_of_cells.append(rest_of_cells)

fig, ax = pl.subplots(1, 1, figsize=(3, 3))
pt.plot_rois(mean_image, contours,
             list_of_cells,
             colors=['darkorange', 'b', '0.65'],
             ax=ax
            )
ax.set_xticks(())
ax.set_yticks(())
# Hand-drawn legend; overlaid half-transparent 'both' labels mark the overlap.
ax.text(200, 15, 'pos. tuned CS+ (rewarded)', color='darkorange', fontsize=5)
ax.text(200, 30, 'pos. tuned US+ (rewarded)',color='b', fontsize=5)
ax.text(200, 45, 'both', color='darkorange', fontsize=5, alpha=0.5)
ax.text(200, 45, 'both', color='b', fontsize=5, alpha=0.5)

In [None]:
# ROI map (rewarded trials): cells negatively tuned during CS ('rewarded',
# orange) and during US ('rewarded_us', blue); remaining indices in grey.
list_of_cells = [tuned_cells['rewarded'][1], tuned_cells['rewarded_us'][1]]
rest_of_cells = np.delete(range(events.shape[1]), np.concatenate(list_of_cells))
list_of_cells.append(rest_of_cells)

fig, ax = pl.subplots(1, 1, figsize=(3, 3))
pt.plot_rois(mean_image, contours,
             list_of_cells,
             colors=['darkorange', 'b', '0.65'],
             ax=ax
            )
ax.set_xticks(())
ax.set_yticks(())
# Hand-drawn legend; overlaid half-transparent 'both' labels mark the overlap.
ax.text(200, 15, 'neg. tuned CS+ (rewarded)', color='darkorange', fontsize=5)
ax.text(200, 30, 'neg. tuned US+ (rewarded)', color='b', fontsize=5)
ax.text(200, 45, 'both', color='darkorange', fontsize=5, alpha=0.5)
ax.text(200, 45, 'both', color='b', fontsize=5, alpha=0.5)

In [None]:
# ROI map (rewarded trials): cells positively tuned during the trace period
# ('rewarded_tc', orange) and during US ('rewarded_us', blue); rest in grey.
list_of_cells = [tuned_cells['rewarded_tc'][0], tuned_cells['rewarded_us'][0]]
rest_of_cells = np.delete(range(events.shape[1]), np.concatenate(list_of_cells))
list_of_cells.append(rest_of_cells)

fig, ax = pl.subplots(1, 1, figsize=(3, 3))
pt.plot_rois(mean_image, contours,
             list_of_cells,
             colors=['darkorange', 'b', '0.65'],
             ax=ax
            )
ax.set_xticks(())
ax.set_yticks(())
# Hand-drawn legend; overlaid half-transparent 'both' labels mark the overlap.
ax.text(200, 15, 'pos. tuned CS+ trace (rewarded)', color='darkorange', fontsize=5)
ax.text(200, 30, 'pos. tuned US+ (rewarded)', color='b', fontsize=5)
ax.text(200, 45, 'both', color='darkorange', fontsize=5, alpha=0.5)
ax.text(200, 45, 'both', color='b', fontsize=5, alpha=0.5)

In [None]:
# ROI map (rewarded trials): cells negatively tuned during the trace period
# ('rewarded_tc', orange) and during US ('rewarded_us', blue); rest in grey.
list_of_cells = [tuned_cells['rewarded_tc'][1], tuned_cells['rewarded_us'][1]]
rest_of_cells = np.delete(range(events.shape[1]), np.concatenate(list_of_cells))
list_of_cells.append(rest_of_cells)

fig, ax = pl.subplots(1, 1, figsize=(3, 3))
pt.plot_rois(mean_image, contours,
             list_of_cells,
             colors=['darkorange', 'b', '0.65'],
             ax=ax
            )
ax.set_xticks(())
ax.set_yticks(())
# Hand-drawn legend; overlaid half-transparent 'both' labels mark the overlap.
ax.text(200, 15, 'neg. tuned CS+ trace (rewarded)', color='darkorange', fontsize=5)
ax.text(200, 30, 'neg. tuned US+ (rewarded)', color='b', fontsize=5)
ax.text(200, 45, 'both', color='darkorange', fontsize=5, alpha=0.5)
ax.text(200, 45, 'both', color='b', fontsize=5, alpha=0.5)

In [None]:
# ROI map: cells positively tuned before bouts ('pre_bouts', orange) and at
# reward onset ('reward_onset', blue); remaining indices in grey.
list_of_cells = [tuned_cells['pre_bouts'][0], tuned_cells['reward_onset'][0]]
rest_of_cells = np.delete(range(events.shape[1]), np.concatenate(list_of_cells))
list_of_cells.append(rest_of_cells)

fig, ax = pl.subplots(1, 1, figsize=(3, 3))
pt.plot_rois(mean_image, contours,
             list_of_cells,
             colors=['darkorange', 'b', '0.65'],
             ax=ax
            )
ax.set_xticks(())
ax.set_yticks(())
# Hand-drawn legend; overlaid half-transparent 'both' labels mark the overlap.
ax.text(200, 15, 'pos. tuned prebouts ', color='darkorange', fontsize=5)
ax.text(200, 30, 'pos. tuned reward onset', color='b', fontsize=5)
ax.text(200, 45, 'both', color='darkorange', fontsize=5, alpha=0.5)
ax.text(200, 45, 'both', color='b', fontsize=5, alpha=0.5)

In [None]:
# ROI map: cells negatively tuned before bouts ('pre_bouts', orange) and at
# reward onset ('reward_onset', blue); remaining indices in grey.
list_of_cells = [tuned_cells['pre_bouts'][1], tuned_cells['reward_onset'][1]]
rest_of_cells = np.delete(range(events.shape[1]), np.concatenate(list_of_cells))
list_of_cells.append(rest_of_cells)

fig, ax = pl.subplots(1, 1, figsize=(3, 3))
pt.plot_rois(mean_image, contours,
             list_of_cells,
             colors=['darkorange', 'b', '0.65'],
             ax=ax
            )
ax.set_xticks(())
ax.set_yticks(())
# Hand-drawn legend; overlaid half-transparent 'both' labels mark the overlap.
ax.text(200, 15, 'neg. tuned prebouts ', color='darkorange', fontsize=5)
ax.text(200, 30, 'neg. tuned reward onset', color='b', fontsize=5)
ax.text(200, 45, 'both', color='darkorange', fontsize=5, alpha=0.5)
ax.text(200, 45, 'both', color='b', fontsize=5, alpha=0.5)

In [None]:
# Grid of per-cell heat maps over rewarded trials for cells positively tuned
# to reward onset, with one shared horizontal colour bar.
which_cells = tuned_cells['reward_onset'][0]

# ceil(n_cells / n_cols) rows, at least one.  The previous `len(...)/6` rows
# x 7 columns silently dropped cells whenever there were more cells than
# panels (e.g. 11 cells -> 1 row of 7), requested 0 rows for fewer than 6
# cells, and for a single row returned a 1-D axes array so `axs[0][0]` below
# failed.  squeeze=False always yields a 2-D axes array.
n_cols = 7
n_rows = max(1, (len(which_cells) + n_cols - 1) // n_cols)
fig, axs = pl.subplots(n_rows, n_cols, figsize=(7, n_rows), squeeze=False)

for cell, ax in zip(which_cells, axs.flatten()):
    plot_heat_map(ax, time_ax, events, cell, cycles[is_rewarded], time_ax_single, vmax=1)
    for s in ax.spines.values():
        s.set_visible(False)
    ax.set_xticks(())
    ax.set_yticks(())
    ax.text(-8, 2, cell+1, color='w', fontsize=6)  # 1-based cell number
# Hide the unused panels in the last row.
for ax in axs.flatten()[len(which_cells):]:
    ax.set_visible(False)

fig.subplots_adjust(bottom=0.1)
if len(which_cells) > 0:
    # Every panel is drawn with vmax=1, so the first panel's image stands in
    # for all (assumes plot_heat_map adds one image per axes -- TODO confirm).
    cax = pl.axes([0, 0, 1, 0.01])
    im = axs[0][0].images[0]
    fig.colorbar(im, cax=cax, orientation='horizontal')
fig.suptitle('US+ reward onset cells, tuned up')

In [None]:
# Grid of per-cell heat maps over rewarded trials for cells positively tuned
# during CS+ on rewarded trials, with one shared horizontal colour bar.
which_cells = tuned_cells['rewarded'][0]

# ceil(n_cells / n_cols) rows, at least one.  The previous `len(...)/6` rows
# x 7 columns silently dropped cells whenever there were more cells than
# panels, requested 0 rows for fewer than 6 cells, and for a single row
# returned a 1-D axes array so `axs[0][0]` below failed.  squeeze=False
# always yields a 2-D axes array.
n_cols = 7
n_rows = max(1, (len(which_cells) + n_cols - 1) // n_cols)
fig, axs = pl.subplots(n_rows, n_cols, figsize=(7, n_rows), squeeze=False)

for cell, ax in zip(which_cells, axs.flatten()):
    plot_heat_map(ax, time_ax, events, cell, cycles[is_rewarded], time_ax_single, vmax=2)
    for s in ax.spines.values():
        s.set_visible(False)
    ax.set_xticks(())
    ax.set_yticks(())
    ax.text(-8, 2, cell+1, color='w', fontsize=6)  # 1-based cell number
# Hide the unused panels in the last row.
for ax in axs.flatten()[len(which_cells):]:
    ax.set_visible(False)

fig.subplots_adjust(bottom=0.1)
if len(which_cells) > 0:
    # Every panel is drawn with vmax=2, so the first panel's image stands in
    # for all (assumes plot_heat_map adds one image per axes -- TODO confirm).
    cax = pl.axes([0, 0, 1, 0.01])
    im = axs[0][0].images[0]
    fig.colorbar(im, cax=cax, orientation='horizontal')
fig.suptitle('CS+ cells (rewarded trials), tuned up')

# Similarity

In [None]:
def compute_mean_activity_patterns(time_ax, activity, cycles, timeframe):
    """Average `activity` over a time window anchored at each cycle's start.

    For every ``(start, end)`` row of `cycles`, the rows of `activity` whose
    `time_ax` value lies in ``[start + timeframe[0], start + timeframe[1])``
    are averaged along axis 0.  The cycle end value is ignored.

    Returns one mean pattern per cycle, stacked along axis 0.
    """
    offset_lo, offset_hi = timeframe
    means = []
    for cycle_start, _cycle_end in cycles:
        in_window = (time_ax >= cycle_start + offset_lo) & (time_ax < cycle_start + offset_hi)
        means.append(np.mean(activity[in_window], 0))
    return np.array(means)

In [None]:
# Mean population pattern per trial for each trial type and epoch (CS, trace,
# US), plus a baseline pattern for every cycle.  Windows are offsets from each
# cycle's start (see compute_mean_activity_patterns).  Each print reports how
# many trials have an all-zero pattern; compute_similarity_matrix skips them.
# Fix: the errCS- / corrCS- trace and US lines previously printed "during CS."
# and the US lines counted the CS patterns instead of the US patterns.
_cs_window = (CS_START-CYCLE_START, CS_START-CYCLE_START+CS_DURATION)
_trace_window = (CS_START-CYCLE_START+CS_DURATION, CS_START-CYCLE_START+CS_DURATION+DELAY)
_us_window = (US_START-CYCLE_START, US_START-CYCLE_START+US_DURATION)


def _n_zero_patterns(patterns):
    """Return the number of rows (trials) of `patterns` whose sum is zero."""
    return (patterns.sum(1) == 0).sum()


patterns_is_CSmt_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_CSmt], _cs_window)
print("There are %d zero activity patterns for CS- trials during CS." % _n_zero_patterns(patterns_is_CSmt_CS))

patterns_is_CSmt_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_CSmt], _trace_window)
print("There are %d zero activity patterns for CS- trials during TRACE." % _n_zero_patterns(patterns_is_CSmt_tr))

patterns_is_CSmt_US = compute_mean_activity_patterns(time_ax, events, cycles[is_CSmt], _us_window)
print("There are %d zero activity patterns for CS- trials during US." % _n_zero_patterns(patterns_is_CSmt_US))

patterns_is_rewardt_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_rewardt], _cs_window)
print("There are %d zero activity patterns for CS+ trials during CS." % _n_zero_patterns(patterns_is_rewardt_CS))

patterns_is_rewardt_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_rewardt], _trace_window)
print("There are %d zero activity patterns for CS+ trials during TRACE." % _n_zero_patterns(patterns_is_rewardt_tr))

patterns_is_rewardt_US = compute_mean_activity_patterns(time_ax, events, cycles[is_rewardt], _us_window)
print("There are %d zero activity patterns for CS+ trials during US." % _n_zero_patterns(patterns_is_rewardt_US))

patterns_is_rewarded_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_rewarded], _cs_window)
print("There are %d zero activity patterns for rewarded trials during CS." % _n_zero_patterns(patterns_is_rewarded_CS))

patterns_is_rewarded_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_rewarded], _trace_window)
print("There are %d zero activity patterns for rewarded trials during TRACE." % _n_zero_patterns(patterns_is_rewarded_tr))

patterns_is_rewarded_US = compute_mean_activity_patterns(time_ax, events, cycles[is_rewarded], _us_window)
print("There are %d zero activity patterns for rewarded trials during US." % _n_zero_patterns(patterns_is_rewarded_US))

patterns_is_not_rewarded_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_not_rewarded], _cs_window)
print("There are %d zero activity patterns for unrewarded CS+ trials during CS." % _n_zero_patterns(patterns_is_not_rewarded_CS))

patterns_is_not_rewarded_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_not_rewarded], _trace_window)
print("There are %d zero activity patterns for unrewarded CS+ trials during TRACE." % _n_zero_patterns(patterns_is_not_rewarded_tr))

patterns_is_not_rewarded_US = compute_mean_activity_patterns(time_ax, events, cycles[is_not_rewarded], _us_window)
print("There are %d zero activity patterns for unrewarded CS+ trials during US." % _n_zero_patterns(patterns_is_not_rewarded_US))

patterns_is_errCSmt_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_errCSmt], _cs_window)
print("There are %d zero activity patterns for errCS- trials during CS." % _n_zero_patterns(patterns_is_errCSmt_CS))

patterns_is_errCSmt_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_errCSmt], _trace_window)
print("There are %d zero activity patterns for errCS- trials during TRACE." % _n_zero_patterns(patterns_is_errCSmt_tr))

patterns_is_errCSmt_US = compute_mean_activity_patterns(time_ax, events, cycles[is_errCSmt], _us_window)
print("There are %d zero activity patterns for errCS- trials during US." % _n_zero_patterns(patterns_is_errCSmt_US))

patterns_is_corrCSmt_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_corrCSmt], _cs_window)
print("There are %d zero activity patterns for corrCS- trials during CS." % _n_zero_patterns(patterns_is_corrCSmt_CS))

patterns_is_corrCSmt_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_corrCSmt], _trace_window)
print("There are %d zero activity patterns for corrCS- trials during TRACE." % _n_zero_patterns(patterns_is_corrCSmt_tr))

patterns_is_corrCSmt_US = compute_mean_activity_patterns(time_ax, events, cycles[is_corrCSmt], _us_window)
print("There are %d zero activity patterns for corrCS- trials during US." % _n_zero_patterns(patterns_is_corrCSmt_US))


# Baseline window for *all* cycles: [CS onset + baseline, CS onset +
# CS_DURATION + baseline), relative to each cycle's start.
# NOTE(review): this is the "2 - 0 s before CS onset" period only if
# CS_DURATION == 2 -- confirm.
baseline = -2
patterns_bs = compute_mean_activity_patterns(time_ax, events, cycles,
                                             (CS_START-CYCLE_START+baseline,
                                              CS_START-CYCLE_START+CS_DURATION+baseline))
print("There are %d zero activity patterns for baseline trials during baseline." % _n_zero_patterns(patterns_bs))

In [None]:
def compute_similarity_matrix(pattern_ids, all_patterns):
    """Pairwise Pearson similarity between the rows of several pattern sets.

    Parameters
    ----------
    pattern_ids : sequence of hashable labels, one per pattern set.
    all_patterns : sequence of 2-D arrays aligned with `pattern_ids`; each row
        is one pattern (presumably one trial's mean population vector).
        Sets beyond the length of `pattern_ids` are ignored by ``zip``.

    Returns
    -------
    corrmat_distr : dict mapping ``(label_a, label_b)`` to the list of Pearson
        r values over row pairs (row of set a, row of set b).  Row pairs with
        equal row index, and rows that sum to zero, are skipped.
        NOTE(review): the equal-index skip also applies when set a and set b
        differ, not only on the diagonal -- confirm this is intended.
    corrmat : ``(n_labels, n_labels)`` array holding the mean of each list.
    """
    corrmat_distr = {}
    for label_a, set_a in zip(pattern_ids, all_patterns):
        for label_b, set_b in zip(pattern_ids, all_patterns):
            r_values = []
            for row_idx_a, row_a in enumerate(set_a):
                for row_idx_b, row_b in enumerate(set_b):
                    if row_idx_a == row_idx_b:
                        continue
                    if np.sum(row_a) == 0 or np.sum(row_b) == 0:
                        continue
                    r_values.append(sstats.pearsonr(row_a, row_b)[0])
            corrmat_distr[(label_a, label_b)] = r_values

    n_labels = len(pattern_ids)
    corrmat = np.zeros((n_labels, n_labels))
    for row, p in enumerate(pattern_ids):
        for col, q in enumerate(pattern_ids):
            corrmat[row, col] = np.mean(corrmat_distr[(p, q)])
    return corrmat_distr, corrmat
            
#pattern_ids = ['ITI', 'CSm_CS', 'rewardt_CS']
#all_patterns = patterns_bs, patterns_is_CSmt_CS, patterns_is_rewardt_CS, patterns_is_not_rewarded_CS
#corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)

In [None]:
def plot_mat(all_patterns, corrmat, vmin=0, vmax=.1, ax=None, labels=None):
    """Draw a similarity matrix as an image with one labelled tick per pattern.

    Parameters
    ----------
    all_patterns : sequence; only its length is used (the number of ticks).
    corrmat : 2-D array shown with ``ax.imshow``.
    vmin, vmax : colour limits of the image.
    ax : axes to draw into; a new 2x2-inch figure is created when None.
    labels : tick labels, one per pattern.  When None this falls back to the
        notebook-global ``pattern_ids`` (the previous, implicit behaviour), so
        existing calls keep working; passing labels explicitly removes the
        hidden dependency on that global.

    Returns
    -------
    (im, ax) : the image artist and the axes drawn into.
    """
    if ax is None:
        _, ax = pl.subplots(1, 1, figsize=(2, 2))
    if labels is None:
        labels = pattern_ids

    n = len(all_patterns)
    im = ax.imshow(corrmat, vmin=vmin, vmax=vmax)
    ax.set_yticks(range(n))
    ax.set_xticks(range(n))
    ax.set_xticklabels(labels, rotation=60, fontsize=5)
    ax.set_yticklabels(labels, fontsize=5)
    # Increasing y limits: with imshow's default origin='upper' this flips the
    # image so that row 0 of corrmat is drawn at the bottom.
    ax.set_xlim(-0.5, n - 0.5)
    ax.set_ylim(-0.5, n - 0.5)
    return im, ax

In [None]:
# Similarity matrices ITI vs CS- vs CS+ trials, computed separately for the
# CS, trace and US epochs and drawn on a common colour scale (vmin/vmax).
# NOTE(review): plot_mat reads the global `pattern_ids` for its tick labels,
# so each `pattern_ids = ...` must stay before its matching plot_mat call.
fig, axs = pl.subplots(1, 3, figsize=(3, 1), sharex=True, sharey=True)

vmin, vmax = 0, 0.15

pattern_ids = ['ITI', 'CSm_CS', 'rewardt_CS']
all_patterns = patterns_bs, patterns_is_CSmt_CS, patterns_is_rewardt_CS
corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)
plot_mat(all_patterns, corrmat, ax=axs[0], vmin=vmin, vmax=vmax)

pattern_ids = ['ITI', 'CSm_tr', 'rewardt_tr']
all_patterns = patterns_bs, patterns_is_CSmt_tr, patterns_is_rewardt_tr
corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)
plot_mat(all_patterns, corrmat, ax=axs[1], vmin=vmin, vmax=vmax)

pattern_ids = ['ITI', 'CSm_US', 'rewardt_US']
all_patterns = patterns_bs, patterns_is_CSmt_US, patterns_is_rewardt_US
corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)
im, ax = plot_mat(all_patterns, corrmat, ax=axs[2], vmin=vmin, vmax=vmax)

# One colour bar from the last panel; all panels share vmin/vmax.
pl.colorbar(im, ax=ax)

# Short display names for the first panel's ticks.
axs[0].set_yticklabels(['ITI', 'CS-', 'CS+'])
axs[0].set_xticklabels(['ITI', 'CS-', 'CS+'])

axs[0].set_title('During CS', fontsize=7)
axs[1].set_title('During Trace', fontsize=7)
axs[2].set_title('During US', fontsize=7)

In [None]:
# Similarity matrices ITI vs errCS- vs rewarded CS+ trials, computed
# separately for the CS, trace and US epochs on a common colour scale.
# NOTE(review): plot_mat reads the global `pattern_ids` for its tick labels,
# so each `pattern_ids = ...` must stay before its matching plot_mat call.
fig, axs = pl.subplots(1, 3, figsize=(3, 1), sharex=True, sharey=True)

vmin, vmax = 0, 0.15

pattern_ids = ['ITI', 'err_CSm_CS', 'rewarded_CS']
all_patterns = patterns_bs, patterns_is_errCSmt_CS, patterns_is_rewarded_CS
corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)
plot_mat(all_patterns, corrmat, ax=axs[0], vmin=vmin, vmax=vmax)

pattern_ids = ['ITI', 'err_CSm_tr', 'rewarded_tr']
all_patterns = patterns_bs, patterns_is_errCSmt_tr, patterns_is_rewarded_tr
corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)
plot_mat(all_patterns, corrmat, ax=axs[1], vmin=vmin, vmax=vmax)

pattern_ids = ['ITI', 'err_CSm_US', 'rewarded_US']
all_patterns = patterns_bs, patterns_is_errCSmt_US, patterns_is_rewarded_US
corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)
im, ax = plot_mat(all_patterns, corrmat, ax=axs[2], vmin=vmin, vmax=vmax)

# One colour bar from the last panel; all panels share vmin/vmax.
pl.colorbar(im, ax=ax)

# Short display names for the first panel's ticks.
axs[0].set_yticklabels(['ITI', 'errCS-', 'rewCS+'])
axs[0].set_xticklabels(['ITI', 'errCS-', 'rewCS+'])

axs[0].set_title('During CS', fontsize=7)
axs[1].set_title('During Trace', fontsize=7)
axs[2].set_title('During US', fontsize=7)

In [None]:
# Similarity matrices ITI vs corrCS- vs rewarded CS+ trials, computed
# separately for the CS, trace and US epochs on a common colour scale.
# NOTE(review): plot_mat reads the global `pattern_ids` for its tick labels,
# so each `pattern_ids = ...` must stay before its matching plot_mat call.
fig, axs = pl.subplots(1, 3, figsize=(3, 1), sharex=True, sharey=True)

vmin, vmax = 0, 0.15

pattern_ids = ['ITI', 'corr_CSm_CS', 'rewarded_CS']
all_patterns = patterns_bs, patterns_is_corrCSmt_CS, patterns_is_rewarded_CS
corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)
plot_mat(all_patterns, corrmat, ax=axs[0], vmin=vmin, vmax=vmax)

pattern_ids = ['ITI', 'corr_CSm_tr', 'rewarded_tr']
all_patterns = patterns_bs, patterns_is_corrCSmt_tr, patterns_is_rewarded_tr
corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)
plot_mat(all_patterns, corrmat, ax=axs[1], vmin=vmin, vmax=vmax)

pattern_ids = ['ITI', 'corr_CSm_US', 'rewarded_US']
all_patterns = patterns_bs, patterns_is_corrCSmt_US, patterns_is_rewarded_US
corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)
im, ax = plot_mat(all_patterns, corrmat, ax=axs[2], vmin=vmin, vmax=vmax)

# One colour bar from the last panel; all panels share vmin/vmax.
pl.colorbar(im, ax=ax)

# Short display names for the first panel's ticks.
axs[0].set_yticklabels(['ITI', 'corrCS-', 'rewCS+'])
axs[0].set_xticklabels(['ITI', 'corrCS-', 'rewCS+'])

axs[0].set_title('During CS', fontsize=7)
axs[1].set_title('During Trace', fontsize=7)
axs[2].set_title('During US', fontsize=7)

In [None]:
#patterns_is_CSmt_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_CSmt],
#patterns_is_CSmt_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_CSmt],
#patterns_is_CSmt_US = compute_mean_activity_patterns(time_ax, events, cycles[is_CSmt],
#patterns_is_rewardt_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_rewardt],
#patterns_is_rewardt_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_rewardt],
#patterns_is_rewardt_US = compute_mean_activity_patterns(time_ax, events, cycles[is_rewardt],
#patterns_is_rewarded_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_rewarded],
#patterns_is_rewarded_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_rewarded],
#patterns_is_rewarded_US = compute_mean_activity_patterns(time_ax, events, cycles[is_rewarded],
#patterns_is_not_rewarded_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_not_rewarded],
#patterns_is_not_rewarded_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_not_rewarded],
#patterns_is_not_rewarded_US = compute_mean_activity_patterns(time_ax, events, cycles[is_not_rewarded],
#patterns_is_errCSmt_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_errCSmt],
#patterns_is_errCSmt_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_errCSmt],
#patterns_is_errCSmt_US = compute_mean_activity_patterns(time_ax, events, cycles[is_errCSmt],
#patterns_is_corrCSmt_CS = compute_mean_activity_patterns(time_ax, events, cycles[is_corrCSmt],
#patterns_is_corrCSmt_tr = compute_mean_activity_patterns(time_ax, events, cycles[is_corrCSmt],
#patterns_is_corrCSmt_US = compute_mean_activity_patterns(time_ax, events, cycles[is_corrCSmt],
#patterns_bs = compute_mean_activity_patterns(time_ax, events, cycles,

In [None]:
# Labels and pattern sets for the full similarity matrix.  Order matters: the
# i-th label names the i-th pattern set.  The 'err_reward_*' labels refer to
# the unrewarded CS+ sets (patterns_is_not_rewarded_*).
pattern_ids = [
    'ITI',
    'CSm_CS', 'CSm_tr', 'CSm_US',
    'rewardt_CS', 'rewardt_tr', 'rewardt_US',
    'rewarded_CS', 'rewarded_tr', 'rewarded_US',
    'err_reward_CS', 'err_reward_tr', 'err_reward_US',
    'err_CSm_CS', 'err_CSm_tr', 'err_CSm_US',
    'corr_CSm_CS', 'corr_CSm_tr', 'corr_CSm_US',
]
all_patterns = (
    patterns_bs,
    patterns_is_CSmt_CS, patterns_is_CSmt_tr, patterns_is_CSmt_US,
    patterns_is_rewardt_CS, patterns_is_rewardt_tr, patterns_is_rewardt_US,
    patterns_is_rewarded_CS, patterns_is_rewarded_tr, patterns_is_rewarded_US,
    patterns_is_not_rewarded_CS, patterns_is_not_rewarded_tr, patterns_is_not_rewarded_US,
    patterns_is_errCSmt_CS, patterns_is_errCSmt_tr, patterns_is_errCSmt_US,
    patterns_is_corrCSmt_CS, patterns_is_corrCSmt_tr, patterns_is_corrCSmt_US,
)

corrmat_distr, corrmat = compute_similarity_matrix(pattern_ids, all_patterns)

In [None]:
# Similarity comparisons exported to diggsie_similarity_list.txt.
pairs_to_plot = [('ITI', 'ITI'), ('ITI','rewardt_CS'), ('ITI','rewardt_tr'), ('ITI','rewardt_US'),
                ('rewardt_CS','rewardt_CS'), ('rewardt_CS','CSm_CS'), ('CSm_CS','CSm_CS'),
                ('rewardt_tr','rewardt_tr'), ('rewardt_tr','CSm_tr'), ('CSm_tr','CSm_tr'),
                ('rewardt_US','rewardt_US'), ('rewardt_US','CSm_US'), ('CSm_US','CSm_US'),
                ('rewarded_CS','rewarded_CS'), ('rewarded_tr','rewarded_tr'), ('rewarded_US','rewarded_US'),
                ('rewardt_CS','rewardt_tr'), ('rewardt_CS','rewardt_US'), ('rewardt_tr','rewardt_US'),
                ('rewarded_CS','rewarded_tr'), ('rewarded_CS','rewarded_US'), ('rewarded_tr','rewarded_US'),
                ('rewarded_CS','err_CSm_CS'), ('rewarded_tr','err_CSm_tr'), ('rewarded_US','err_CSm_US'),
                ('corr_CSm_CS','corr_CSm_CS'), ('corr_CSm_tr','corr_CSm_tr'), ('corr_CSm_US','corr_CSm_US'),
                ('err_reward_CS','corr_CSm_CS'), ('err_reward_tr','corr_CSm_tr'), ('err_reward_US','corr_CSm_US'),
                ('err_reward_CS','rewarded_CS'),  ('err_reward_tr','rewarded_tr'),  ('err_reward_US','rewarded_US'),
                ('rewarded_CS','corr_CSm_CS'), ('rewarded_tr','corr_CSm_tr'), ('rewarded_US','corr_CSm_US')]

values = [corrmat_distr[p] for p in pairs_to_plot]

# One tab-separated line per comparison: "<pair tuple>\t<mean>\t<SEM>",
# with SEM = std / sqrt(n - 1).  zip replaces the manual index counter.
sel_listb = os.path.join(data_folder, 'diggsie_similarity_list.txt')
with open(sel_listb, 'w+') as f:
    for pair, v in zip(pairs_to_plot, values):
        sem = np.std(v) / np.sqrt(len(v) - 1)
        f.write('%s\t%s\t%s\n' % (pair, str(np.mean(v)), str(sem)))

In [None]:
pairs_to_plot = [('ITI', 'ITI'), ('CSm_US', 'rewardt_tr'), ('rewardt_US', 'rewardt_US')]
colors = ['0.7', 'r', 'b']

fig, axs = pl.subplots(1, 2, figsize=(4, 2), gridspec_kw={'width_ratios':[1, 2]})
values = [corrmat_distr[p] for p in pairs_to_plot]

# Left panel: mean +/- SEM (std / sqrt(n - 1)) similarity per pair, with
# significance marks between every pair of bars.
ax = axs[0]
pt.barplot(ax, [0, 1, 2], [np.mean(v) for v in values], [np.std(v)/np.sqrt(len(v)-1) for v in values],
           color=colors)
pt.add_significance(ax, values[0], values[1], 0, 1, 0.25, thresholds=(1e-3, 1e-4, 1e-5))
pt.add_significance(ax, values[1], values[2], 1, 2, 0.27, thresholds=(1e-3, 1e-4, 1e-5))
pt.add_significance(ax, values[0], values[2], 0, 2, 0.29, thresholds=(1e-3, 1e-4, 1e-5))
ax.set_xticks(range(3))
ax.set_xticklabels(["%s\nvs\n%s"%p for p in pairs_to_plot],fontsize=5)
ax.set_ylim(0, .3)
ax.set_ylabel("Pattern similarity")
pt.nicer_plot(ax)

# Right panel: cumulative distributions of the same similarity values.
# All curves share one set of 100 bins spanning the pooled min..max, so a
# given x position is the same similarity for every curve (x is rescaled to
# [0, 1] = pooled min..max).  Previously each curve was binned over its own
# min..max and stretched to [0, 1], so the curves were on different scales.
ax = axs[1]
pooled = np.concatenate([np.asarray(v, dtype=float) for v in values])
lo, hi = pooled.min(), pooled.max()
if hi == lo:  # all values identical: widen the range as np.histogram does
    lo, hi = lo - 0.5, hi + 0.5
bin_edges = np.linspace(lo, hi, 101)
for v, c in zip(values, colors):
    y, _edges = np.histogram(v, bin_edges)
    ax.plot(np.linspace(0, 1, len(y)), 1.*np.cumsum(y)/np.sum(y), color=c, lw=1)
pt.nicer_plot(ax)
ax.set_xticks((0, 1))
ax.set_xticklabels(('min', 'max'))
ax.set_yticks((0, 1))
ax.set_xlabel('Pattern similarity')
ax.set_ylabel('Fraction of trials')

fig.tight_layout()

#import csv
#
#sel_list = os.path.join(data_folder, 'similarity_list.csv')
#with open(sel_list, "w") as output:
#        writer = csv.writer(output, lineterminator='\n')
#        for v in values:
#            writer.writerow([np.mean(v)])

In [None]:
pairs_to_plot = [('ITI', 'ITI'), ('CSm_US', 'rewardt_tr'), ('rewardt_US', 'rewardt_US')]
colors = ['r', 'b']  # one colour per non-ITI pair

fig, axs = pl.subplots(1, 2, figsize=(4, 2), gridspec_kw={'width_ratios':[1, 2]})

values = [corrmat_distr[p] for p in pairs_to_plot]

# Left panel: similarity of each non-ITI pair relative to the mean ITI-ITI
# similarity m, i.e. (v - m) / m, shown as mean +/- SEM.
# NOTE(review): the y label says "(%)" but these are fractions (not x100) --
# confirm which is intended.
ax = axs[0]
m = np.mean(values[0])
relative = [(np.asarray(v, dtype=float) - m) / m for v in values[1:]]
pt.barplot(ax, range(2),
           [np.mean(r) for r in relative],
           [np.std(r)/np.sqrt(len(r)-1) for r in relative],
           color=colors)
pt.add_significance(ax, values[1], values[2], 0, 1, .25)
# pt.add_significance(ax, values[1], values[2], 1, 2, 0.27, thresholds=(1e-3, 1e-4, 1e-5))
# pt.add_significance(ax, values[0], values[2], 0, 2, 0.29, thresholds=(1e-3, 1e-4, 1e-5))
# One tick per bar: previously range(len(values)) set 3 ticks for 2 bars and
# 2 labels.
ax.set_xticks(range(len(values) - 1))
ax.set_xticklabels(["%s\nvs\n%s"%p for p in pairs_to_plot[1:]])
# ax.set_ylim(0, .3)
ax.set_ylabel('Ensemble similarity\nw.r.t. to ITI (%)')
pt.nicer_plot(ax)

# Right panel: cumulative distributions of the two non-ITI pairs on 100 bins
# shared by both curves (pooled min..max, rescaled to x in [0, 1]), so both
# curves are on the same similarity scale.  Previously each curve was binned
# over its own min..max.
ax = axs[1]
pooled = np.concatenate([np.asarray(v, dtype=float) for v in values[1:]])
lo, hi = pooled.min(), pooled.max()
if hi == lo:  # all values identical: widen the range as np.histogram does
    lo, hi = lo - 0.5, hi + 0.5
bin_edges = np.linspace(lo, hi, 101)
for v, c in zip(values[1:], colors):
    y, _edges = np.histogram(v, bin_edges)
    ax.plot(np.linspace(0, 1, len(y)), 1.*np.cumsum(y)/np.sum(y), color=c, lw=1)
pt.nicer_plot(ax)
ax.set_xticks((0, 1))
ax.set_xticklabels(('min', 'max'))
ax.set_yticks((0, 1))
ax.set_xlabel('Ensemble similarity w.r.t. to ITI')
ax.set_ylabel('Fraction of trials')

fig.tight_layout()

# Stability

In [None]:
from itertools import product

In [None]:
# Split the session into n_chunks consecutive blocks of cycles (trials) for
# the stability analysis; chunk_ids[t] is the block index of cycle t.
n_chunks = 3
chunk_length = max(1, len(cycles) // n_chunks)

# Cycles beyond n_chunks * chunk_length (when len(cycles) is not a multiple of
# n_chunks) go into the last chunk.  Previously they kept their initial value
# 0 and were silently counted in the *first* chunk; the chunk length also used
# a hard-coded 3 instead of n_chunks.  Kept as float like the original zeros.
chunk_ids = np.minimum(np.arange(len(cycles)) // chunk_length, n_chunks - 1).astype(float)


In [None]:
# Stability of CS-epoch patterns: one pattern set per (trial type, chunk),
# plus the baseline set over all cycles.
# Trial-type masks looked up by name (replaces eval('is_%s' % od)).
_stab_masks = {'rewarded': is_rewarded, 'not_rewarded': is_not_rewarded, 'CSmt': is_CSmt}
all_patterns_stab = [compute_mean_activity_patterns(time_ax, events, cycles[_stab_masks[od] * (chunk_ids==i)],
                                         (CS_START-CYCLE_START, CS_START-CYCLE_START+CS_DURATION))
            for (od, i) in product(['rewarded', 'not_rewarded', 'CSmt'], range(n_chunks))]
baseline = -2
all_patterns_stab.append(compute_mean_activity_patterns(time_ax, events, cycles,
                                               (CS_START-CYCLE_START+baseline,
                                                CS_START-CYCLE_START+CS_DURATION+baseline)))

# One (chunk, name) label per pattern set, including the baseline appended
# above.  Without it, zip() in compute_similarity_matrix silently dropped the
# baseline set and the plot got one more tick than labels.  (-1, 'ITI') fits
# the "%s (%d)" % (name, chunk + 1) tick format, rendering as "ITI (0)".
pattern_ids_stab = [(i, od) for (od, i) in product(['rew', 'not_rew', 'CSmt'], range(n_chunks))]
pattern_ids_stab.append((-1, 'ITI'))
corrmat_distr_stab, corrmat_stab = compute_similarity_matrix(pattern_ids_stab, all_patterns_stab)

In [None]:
# Heat map of the CS-epoch stability similarity matrix.  Tick labels read
# "<trial type> (<chunk number, 1-based>)".
# NOTE(review): the tick count comes from all_patterns_stab and the labels
# from pattern_ids_stab; they must have equal length or ticks and labels
# misalign.
fig, ax = pl.subplots(1, 1, figsize=(2, 2))

# im = ax.imshow(corrmat, vmin=0.14, vmax=0.26)
im = ax.imshow(corrmat_stab)
ax.set_yticks(range(len(all_patterns_stab)))
ax.set_xticks(range(len(all_patterns_stab)))
ax.set_xticklabels(["%s (%d)"%(p[1], p[0]+1) for p in pattern_ids_stab], rotation=75)
ax.set_yticklabels(["%s (%d)"%(p[1], p[0]+1) for p in pattern_ids_stab])
pl.colorbar(im)
ax.set_title('Stability of CS\npattern similarity', fontsize=7)
# Red grid at -0.5, 2.5, 5.5, 8.5: blocks of n_chunks (= 3) sets per trial
# type; assumes the first 9 sets are 3 trial types x 3 chunks.
ax.vlines(np.linspace(0, 9, 4)-0.5, 0-0.5, 9-0.5, color='r', lw=.5)
ax.hlines(np.linspace(0, 9, 4)-0.5, 0-0.5, 9-0.5, color='r', lw=.5)
pt.nicer_plot(ax)
for v in ax.spines.itervalues():
    v.set_visible(False)

In [None]:
# Stability of trace-epoch patterns: one pattern set per (trial type, chunk),
# plus the baseline set over all cycles.
# Trial-type masks looked up by name (replaces eval('is_%s' % od)).
_stab_masks = {'rewarded': is_rewarded, 'not_rewarded': is_not_rewarded, 'CSmt': is_CSmt}
all_patterns_stab = [compute_mean_activity_patterns(time_ax, events, cycles[_stab_masks[od] * (chunk_ids==i)],
                                         (CS_START-CYCLE_START+CS_DURATION, CS_START-CYCLE_START+CS_DURATION+DELAY))
            for (od, i) in product(['rewarded', 'not_rewarded', 'CSmt'], range(n_chunks))]
baseline = -2
all_patterns_stab.append(compute_mean_activity_patterns(time_ax, events, cycles,
                                               (CS_START-CYCLE_START+baseline,
                                                CS_START-CYCLE_START+CS_DURATION+baseline)))

# One (chunk, name) label per pattern set, including the baseline appended
# above.  Without it, zip() in compute_similarity_matrix silently dropped the
# baseline set and the plot got one more tick than labels.  (-1, 'ITI') fits
# the "%s (%d)" % (name, chunk + 1) tick format, rendering as "ITI (0)".
pattern_ids_stab = [(i, od) for (od, i) in product(['rew', 'not_rew', 'CSmt'], range(n_chunks))]
pattern_ids_stab.append((-1, 'ITI'))
corrmat_distr_stab, corrmat_stab = compute_similarity_matrix(pattern_ids_stab, all_patterns_stab)

In [None]:
# Heat map of the trace-epoch stability similarity matrix.  Tick labels read
# "<trial type> (<chunk number, 1-based>)".
# NOTE(review): the tick count comes from all_patterns_stab and the labels
# from pattern_ids_stab; they must have equal length or ticks and labels
# misalign.
fig, ax = pl.subplots(1, 1, figsize=(2, 2))

# im = ax.imshow(corrmat, vmin=0.14, vmax=0.26)
im = ax.imshow(corrmat_stab)
ax.set_yticks(range(len(all_patterns_stab)))
ax.set_xticks(range(len(all_patterns_stab)))
ax.set_xticklabels(["%s (%d)"%(p[1], p[0]+1) for p in pattern_ids_stab], rotation=75)
ax.set_yticklabels(["%s (%d)"%(p[1], p[0]+1) for p in pattern_ids_stab])
pl.colorbar(im)
ax.set_title('Stability of trace\npattern similarity', fontsize=7)
# Red grid at -0.5, 2.5, 5.5, 8.5: blocks of n_chunks (= 3) sets per trial
# type; assumes the first 9 sets are 3 trial types x 3 chunks.
ax.vlines(np.linspace(0, 9, 4)-0.5, 0-0.5, 9-0.5, color='r', lw=.5)
ax.hlines(np.linspace(0, 9, 4)-0.5, 0-0.5, 9-0.5, color='r', lw=.5)
pt.nicer_plot(ax)
for v in ax.spines.itervalues():
    v.set_visible(False)

In [None]:
# Stability of US-epoch patterns: one pattern set per (trial type, chunk),
# plus the baseline set over all cycles.
# Trial-type masks looked up by name (replaces eval('is_%s' % od)).
_stab_masks = {'rewarded': is_rewarded, 'not_rewarded': is_not_rewarded, 'CSmt': is_CSmt}
all_patterns_stab = [compute_mean_activity_patterns(time_ax, events, cycles[_stab_masks[od] * (chunk_ids==i)],
                                         (US_START-CYCLE_START, US_START-CYCLE_START+US_DURATION))
            for (od, i) in product(['rewarded', 'not_rewarded', 'CSmt'], range(n_chunks))]
baseline = -2
all_patterns_stab.append(compute_mean_activity_patterns(time_ax, events, cycles,
                                               (CS_START-CYCLE_START+baseline,
                                                CS_START-CYCLE_START+CS_DURATION+baseline)))

# One (chunk, name) label per pattern set, including the baseline appended
# above.  Without it, zip() in compute_similarity_matrix silently dropped the
# baseline set and the plot got one more tick than labels.  (-1, 'ITI') fits
# the "%s (%d)" % (name, chunk + 1) tick format, rendering as "ITI (0)".
pattern_ids_stab = [(i, od) for (od, i) in product(['rew', 'not_rew', 'CSmt'], range(n_chunks))]
pattern_ids_stab.append((-1, 'ITI'))
corrmat_distr_stab, corrmat_stab = compute_similarity_matrix(pattern_ids_stab, all_patterns_stab)

In [None]:
# Heat map of the US-epoch stability similarity matrix.  Tick labels read
# "<trial type> (<chunk number, 1-based>)".
# NOTE(review): the tick count comes from all_patterns_stab and the labels
# from pattern_ids_stab; they must have equal length or ticks and labels
# misalign.
fig, ax = pl.subplots(1, 1, figsize=(2, 2))

# im = ax.imshow(corrmat, vmin=0.14, vmax=0.26)
im = ax.imshow(corrmat_stab)
ax.set_yticks(range(len(all_patterns_stab)))
ax.set_xticks(range(len(all_patterns_stab)))
ax.set_xticklabels(["%s (%d)"%(p[1], p[0]+1) for p in pattern_ids_stab], rotation=75)
ax.set_yticklabels(["%s (%d)"%(p[1], p[0]+1) for p in pattern_ids_stab])
pl.colorbar(im)
ax.set_title('Stability of US\npattern similarity', fontsize=7)
# Red grid at -0.5, 2.5, 5.5, 8.5: blocks of n_chunks (= 3) sets per trial
# type; assumes the first 9 sets are 3 trial types x 3 chunks.
ax.vlines(np.linspace(0, 9, 4)-0.5, 0-0.5, 9-0.5, color='r', lw=.5)
ax.hlines(np.linspace(0, 9, 4)-0.5, 0-0.5, 9-0.5, color='r', lw=.5)
pt.nicer_plot(ax)
for v in ax.spines.itervalues():
    v.set_visible(False)

 # Activity trends

In [None]:
def compute_activity_trend(time_ax, traces, events_of_interest,
                           which_cells=None, tpre=0, tpost=2):
    """Return one mean-activity value per event of interest.

    Snippets of `traces` around each event (window set by `tpre` / `tpost`,
    interpreted by ut.extract_traces_around_event) are averaged over their
    axis 1, then over the selected cells.

    Parameters
    ----------
    which_cells : indices of the cells (columns of `traces`) to include;
        all columns when None.

    NOTE(review): assumes the extracted array is (events, time, cells), so the
    result has one entry per event -- confirm against
    ut.extract_traces_around_event.
    """
    if which_cells is not None:
        selected = which_cells
    else:
        selected = range(traces.shape[1])

    snippets = ut.extract_traces_around_event(time_ax, traces, events_of_interest, tpre, tpost)
    mean_over_axis1 = snippets.mean(1)
    return mean_over_axis1[:, selected].mean(1)

def plot_activity_trend(trend, smooth_win=10, ax=None):
    """Plot a per-trial trend together with its moving average.

    Parameters
    ----------
    trend : 1-D sequence of per-trial values (e.g. the output of
        ``compute_activity_trend``).
    smooth_win : int
        Length of the box-car window used for the moving average.
    ax : matplotlib Axes or None
        Axes to draw into. A new 3x2-inch figure is created when None.

    Returns
    -------
    The Axes that was drawn into.

    The y-axis is forced to start at 0. Dotted vertical lines at
    ``smooth_win/2`` and ``len(trend) - smooth_win/2`` mark the edges where
    ``np.convolve(..., mode='same')`` boundary effects affect the moving average.
    """
    if ax is None:
        _, ax = pl.subplots(1, 1, figsize=(3, 2))
    n_trials = len(trend)

    # raw per-trial values in light grey
    ax.plot(trend, lw=.5, color='0.8')

    # box-car moving average on top
    kernel = np.ones(smooth_win, dtype=float) / smooth_win
    smoothed = np.convolve(trend, kernel, mode='same', )
    ax.plot(smoothed, lw=1, color='C0')

    y_top = ax.get_ylim()[1]
    edge_marks = [smooth_win/2., n_trials-smooth_win/2.]
    ax.vlines(edge_marks, 0, y_top, color='0.7', linestyles='dotted', lw=1, zorder=0)

    ax.set_xlabel('Trial #')
    ax.set_ylabel('Mean overall activity')
    ax.set_ylim((0, y_top))
    ax.set_xlim((-5, n_trials+n_trials/10))
    pt.nicer_plot(ax)
    return ax

In [None]:
# Trend, over rewarded trials, of the mean activity of ALL cells in the 2 s
# following each trial's first reward time.
which_cycles = is_rewarded
events_of_interest = cycles[which_cycles][:, 0] + first_reward_times - CYCLE_START

# which_cells defaults to every cell; ax defaults to a fresh figure
activity_trend_frt = compute_activity_trend(time_ax, traces, events_of_interest,
                                            tpre=0, tpost=2)
ax = plot_activity_trend(activity_trend_frt, smooth_win=10)
ax.set_ylabel('Mean activity of all cells in 2s \nperiod after reward delivery onset')


In [None]:
# Trend, over rewarded trials, of the mean activity of the 'reward_onset'
# tuned cells in the 2 s following each trial's first reward time.
# tuned_cells[...][0] is presumably an index array (np.where-style); TODO confirm.
which_cycles = is_rewarded
events_of_interest = cycles[which_cycles][:, 0] + first_reward_times - CYCLE_START

activity_trend_frt = compute_activity_trend(time_ax, traces, events_of_interest,
                                            which_cells=tuned_cells['reward_onset'][0],
                                            tpre=0, tpost=2)
ax = plot_activity_trend(activity_trend_frt, smooth_win=10)
ax.set_ylabel('Mean activity of +tuned rew cells in 2s \nperiod after reward delivery onset')


In [None]:
# Trend, over rewarded trials, of the mean activity of the 'reward_onset'
# tuned cells in the 2 s following CS onset.
which_cycles = is_rewarded
events_of_interest = cycles[which_cycles][:, 0] + CS_START - CYCLE_START

activity_trend_frt = compute_activity_trend(time_ax, traces, events_of_interest,
                                            which_cells=tuned_cells['reward_onset'][0],
                                            tpre=0, tpost=2)
ax = plot_activity_trend(activity_trend_frt, smooth_win=10)
ax.set_ylabel('Mean activity of +tuned rew+ cells \nduring 2s CS period of rewCS+ trials')

In [None]:
# Trend, over rewarded trials, of the mean activity of the 'rewarded_tc'
# tuned cells in the 2 s starting at CS offset (CS_START + CS_DURATION).
# NOTE(review): earlier cells use cycles[which_cycles][:, 0] as trial starts;
# cycles_starts is presumably the same quantity. Confirm.
which_cycles = is_rewarded
events_of_interest = cycles_starts[which_cycles] + CS_START + CS_DURATION - CYCLE_START

activity_trend_frt = compute_activity_trend(time_ax, traces, events_of_interest,
                                            which_cells=tuned_cells['rewarded_tc'][0],
                                            tpre=0, tpost=2)
ax = plot_activity_trend(activity_trend_frt, smooth_win=10)
ax.set_ylabel('Mean activity of +tuned tr cells \nduring 2s trace period of rewCS+ trials')

In [None]:
# Trend, over rewarded trials, of the mean activity of the 'rewarded'
# tuned cells in the 2 s following CS onset.
# NOTE(review): uses cycles_starts, whereas earlier cells use
# cycles[which_cycles][:, 0]. These are presumably equivalent; confirm.
which_cycles = is_rewarded
events_of_interest = cycles_starts[which_cycles] + CS_START - CYCLE_START

activity_trend_frt = compute_activity_trend(time_ax, traces, events_of_interest,
                                            which_cells=tuned_cells['rewarded'][0],
                                            tpre=0, tpost=2)
ax = plot_activity_trend(activity_trend_frt, smooth_win=10)
ax.set_ylabel('Mean activity of +tuned CS cells \nduring 2s CS period of rewCS+ trials')


In [None]:
# Persist the analysis state. save_workspace is defined elsewhere in this
# notebook; presumably it stores results into the `db` workspace. Confirm.
save_workspace(db)