In [4]:
import os
import glob
import itertools
import json
import importlib
import cv2
import matplotlib as mpl
import scipy.stats as spstats
import dill as pkl
import pandas as pd
import numpy as np
import seaborn as sns
import pylab as pl
import dill as pkl
import statsmodels.api as sm


In [2]:
%matplotlib notebook

In [3]:
import py3utils as p3
import plotting as pplot
import rf_utils as rfutils

In [8]:

#### Plotting params
visual_areas, area_colors = pplot.set_threecolor_palette()
pplot.set_plot_params(labelsize=6, lw_axes=0.25)


# Load source data

In [10]:
# Aggregate session metadata plus the per-cell table for the selected areas.
sdata, cells0 = p3.get_aggregate_info(visual_areas=visual_areas, return_cells=True)

# Keep only the sessions belonging to the two receptive-field experiments.
is_rf_experiment = sdata.experiment.isin(['rfs', 'rfs10'])
rf_meta = sdata[is_rf_experiment].copy()

# (visual_area, datakey) pairs that have RF data.
rf_dkeys = [key for key, _ in rf_meta.groupby(['visual_area', 'datakey'])]

# Restrict the cell table to those pairs, concatenating group by group.
cells_by_fov = cells0.groupby(['visual_area', 'datakey'])
CELLS = pd.concat([grp for key, grp in cells_by_fov if key in rf_dkeys])

# Count unique (area, datakey, cell) rows per visual area.
unique_cells = CELLS[['visual_area', 'datakey', 'cell']].drop_duplicates()
all_cell_counts = unique_cells.groupby(['visual_area']).count().reset_index()

print(all_cell_counts.groupby(['visual_area']).sum())

/n/coxfs01/julianarhee/aggregate-visual-areas/dataset_info_assigned.pkl
Segmentation: missing 3 dsets
             datakey  cell
visual_area               
Li              2599  2599
Lm              3491  3491
V1              4007  4007


# Get RF outdirs

In [14]:
aggregate_dir = '/n/coxfs01/julianarhee/aggregate-visual-areas'

response_type='dff'
traceid = 'traces001'

fit_thr=0.5
do_spherical_correction=False

In [125]:
# Tag identifying which RF analysis to load, built from the response type
# and the spherical-correction setting.
rf_fit_desc = rfutils.get_fit_desc(
    response_type=response_type,
    do_spherical_correction=do_spherical_correction,
)
data_id = '{}_{}'.format(traceid, rf_fit_desc)
print(data_id)

# Aggregate output directory for this analysis; created on first use.
# Note the directory name joins with a double underscore, data_id with a single one.
analysis_dirname = '{}__{}'.format(traceid, rf_fit_desc)
dst_dir = os.path.join(aggregate_dir, 'receptive-fields', analysis_dirname)
if not os.path.exists(dst_dir):
    os.makedirs(dst_dir)
print(dst_dir)

traces001_fit-2dgaus_dff_sphr
/n/coxfs01/julianarhee/aggregate-visual-areas/receptive-fields/traces001__fit-2dgaus_dff_sphr


In [124]:
# Output dir for Spherical Correction analysis/examples
sphr_dir = os.path.join(aggregate_dir, 'receptive-fields', 'spherical_correction')
if not os.path.exists(sphr_dir):
    os.makedirs(sphr_dir)
print(sphr_dir)
    

/n/coxfs01/julianarhee/aggregate-visual-areas/receptive-fields/spherical_correction


# Load RF fits

In [16]:
# Load RF fits (R2 fit_thr=0.5)
rfdata = rfutils.aggregate_rfdata(rf_meta, CELLS, fit_desc=rf_fit_desc)
# Get ROI positions
rfdata = p3.add_rf_positions(rfdata)

N dpaths: 65, N unfit: 0
N datasets included: 65, N sessions excluded: 4
Adding RF position info...


In [221]:
cells_w_both = pd.concat([g for (va, dk, rid), g in \
                    rfdata.groupby(['visual_area', 'datakey', 'cell'])\
                   if 'rfs' in g['experiment'].values and \
                      'rfs10' in g['experiment'].values])
cells_w_both[['visual_area', 'datakey', 'cell']].drop_duplicates()\
.groupby(['visual_area']).count()

Unnamed: 0_level_0,datakey,cell
visual_area,Unnamed: 1_level_1,Unnamed: 2_level_1
Li,33,33
Lm,10,10
V1,200,200


In [1071]:
fovs_w_both = pd.concat([g for (va, dk), g in \
                    rfdata.groupby(['visual_area', 'datakey'])\
                   if 'rfs' in g['experiment'].values and \
                      'rfs10' in g['experiment'].values])
fovs_w_both[['visual_area', 'datakey']].drop_duplicates()\
.groupby(['visual_area']).count()

Unnamed: 0_level_0,datakey
visual_area,Unnamed: 1_level_1
Li,3
Lm,1
V1,4


In [1076]:
counts_w_both = fovs_w_both.groupby(['visual_area', 'datakey', 'experiment']).count()['cell'].reset_index()

In [1086]:
drop_repeats = True

# Row count per (visual_area, datakey, experiment), taken from the 'cell' column.
counts_all = (
    rfdata.groupby(['visual_area', 'datakey', 'experiment'])
    .count()['cell']
    .reset_index()
)
counts_all = p3.split_datakey(counts_all)

# Datasets chosen by the 'max' criterion on the cell count.
unique_dsets = p3.select_best_fovs(counts_all, criterion='max', colname='cell')

key_cols = ['visual_area', 'datakey']
u_dkeys = [tuple(row) for row in unique_dsets[key_cols].values]
all_dkeys = [tuple(row) for row in counts_all[key_cols].drop_duplicates().values]

# Pick which set of (visual_area, datakey) pairs to keep, and a matching label.
if drop_repeats:
    final_dkeys, dset_str = u_dkeys, 'drop_repeats'
else:
    final_dkeys, dset_str = all_dkeys, 'all_dsets'

final_counts = pd.concat([grp for key, grp in counts_all.groupby(key_cols)
                          if key in final_dkeys])
    

[V1] Animalid does not exist: JC078 


In [1090]:
rf_colors={'rfs': 'm', 'rfs10': 'c'}

In [1097]:
# Strip plot of the per-dataset cell counts, split by RF experiment type.
fig, ax = pl.subplots()
sns.stripplot(x='visual_area', y='cell', hue='experiment', data=counts_all,
             ax=ax, dodge=True, order=visual_areas,
             palette=rf_colors)

# NOTE(review): this path hard-codes the 'dff-no-cutoff' analysis tag rather
# than using rf_fit_desc -- confirm that is the intended destination.
tmpdir = os.path.join(aggregate_dir, 'receptive-fields',
                        'traces001__fit-2dgaus_dff-no-cutoff', 'rfs5_v_rfs10')
# Create the destination first (same pattern as the other output dirs in
# this notebook); savefig does not create missing directories.
if not os.path.exists(tmpdir):
    os.makedirs(tmpdir)
pl.savefig(os.path.join(tmpdir, 'counts_by_fov.svg'))

<IPython.core.display.Javascript object>

In [1096]:
rf_fit_desc

'fit-2dgaus_dff_sphr'

## Examples

In [191]:
if not os.path.exists(os.path.join(sphr_dir, 'examples')):
    os.makedirs(os.path.join(sphr_dir, 'examples'))


In [1177]:
# va='V1'
# dk='20190616_JC097_fov1'
# rid=388
# rf_type='rfs'

# rf_type='rfs' if va in ['V1', 'Lm'] else 'rfs10'

# va = 'Li'
# dk = '20190614_JC091_fov1'
# rf_type = 'rfs10'
# rid = 211 #227 #32 #61

va = 'V1'
dk = '20190522_JC084_fov1'
rf_type = 'rfs'
rid = 93 #32 #61

session, animalid, fovn = p3.split_datakey_str(dk)
fov = 'FOV%i_zoom2p0x' % fovn
do_spherical_correction=True
curr_rois = rfdata[(rfdata.visual_area==va) & (rfdata.datakey==dk) 
                 & (rfdata.experiment==rf_type)]['cell'].unique()
print(len(curr_rois))

130


In [1178]:
fit_results, fit_params = rfutils.load_fit_results(animalid, session, fov,
                                experiment=rf_type, traceid=traceid,
                                response_type=response_type,
                                do_spherical_correction=False)

row_vals=fit_params['row_vals']
col_vals = fit_params['col_vals']


In [1179]:
ncols = len(fit_params['col_vals']) #21
nrows = len(fit_params['row_vals']) #11

In [1180]:
# Downsample screen (don't need full resolution).
# The downsample factor lives in the spherically-corrected fit parameters.
# Load them here so this cell runs on a fresh kernel instead of relying on
# `fit_params_sphr` left over from a cell further down the notebook.
fit_results_sphr, fit_params_sphr = rfutils.load_fit_results(animalid, session, fov,
                                     experiment=rf_type, traceid=traceid,
                                     response_type=response_type,
                                     do_spherical_correction=True)
downsampler = fit_params_sphr['downsample_factor']

# Screen resolution is stored in the opposite order; reverse, then downsample.
resolution_ds = [int(i/downsampler) for i \
                 in fit_params['screen']['resolution'][::-1]]
print("Screen resolution (ds=%.1fx): [%i, %i]" \
      % (downsampler, resolution_ds[0], resolution_ds[1]))

# Linear coordinates
lin_x, lin_y = rfutils.get_lin_coords(resolution=resolution_ds)

# Cartesian + spherical coordinates computed from the downsampled resolution
# (no pre-computed points are passed in).
cart_x, cart_y, sphr_x, sphr_y = rfutils.get_spherical_coords(cart_pointsX=None, 
                                        cart_pointsY=None, cm_to_degrees=True,
                                        resolution=resolution_ds) 

Screen resolution (ds=3.0x): [360, 640]
(360, 640) (360, 640)


In [1181]:
# rid = 388 #32 #61
# rid = 388
rfmap = fit_results[rid]['data'].copy()
print(rfmap.shape)
# Upsample RF map to screen pixels
rfmap_orig = rfutils.resample_map(rfmap, cart_x, cart_y, 
                          row_vals=row_vals, col_vals=col_vals,
                          resolution=resolution_ds)
print(rfmap_orig.shape)
# Warp resampled RF map
rfmap_sphr = rfutils.warp_spherical(rfmap_orig, sphr_x, sphr_y, 
                        cart_x, cart_y, normalize_range=True, method='linear')
print(rfmap_sphr.shape)

(11, 21)
(360, 640)
(360, 640)


In [1182]:
# Trim
screen_bounds_pix = rfutils.get_screen_lim_pixels(cart_x, cart_y, 
                                    row_vals=row_vals, col_vals=col_vals)

rfmap_trim  = rfutils.trim_resampled_map(rfmap_orig, screen_bounds_pix)
rfmap_trim_sphr  = rfutils.trim_resampled_map(rfmap_sphr, screen_bounds_pix)
#rfmap_corrected[pix_top_edge:pix_bottom_edge, pix_left_edge:pix_right_edge]
print(rfmap_trim.shape, rfmap_trim_sphr.shape)

# Downsample, so we don't have repeated values
rfmap_o = cv2.resize(rfmap_trim, (ncols, nrows))
rfmap_c = cv2.resize(rfmap_trim_sphr, (ncols, nrows))

(293, 561) (293, 561)


In [1183]:
sphr_y.max()

0.6069785045451996

In [1184]:
fig, axn = pl.subplots(2, 3, figsize=(6,3.5))
ax = axn[0,0]
ax.imshow(rfmap_orig, cmap='bone')
ax.set_title('1. Upsampled [%i, %i], ds=%i' \
             % (resolution_ds[1], resolution_ds[0],downsampler))

ax = axn[0,1]
ax.imshow(rfmap_sphr, cmap='bone')
ax.set_title('2. Spherical correction')

ax = axn[0,2]
ax.imshow(rfmap_orig, alpha=0.7, cmap='Reds')
ax.imshow(rfmap_sphr, alpha=0.7, cmap='Blues')
ax.set_title('Overlay')

ax = axn[1,0]
ax.imshow(rfmap, cmap='bone')
ax.set_title('Orig rfmap')

ax = axn[1,1]
ax.imshow(rfmap_c, cmap='bone')
ax.set_title('3. Trimmed and downsampled')

ax = axn[1,2]
ax.imshow(rfmap, alpha=0.7, cmap='Reds')
ax.imshow(rfmap_c, alpha=0.7, cmap='Blues')
ax.set_title('Overlay')

for ax in axn.flat:
    ax.set_aspect('equal')
    ax.invert_yaxis()
pl.subplots_adjust(wspace=0.5, hspace=0.5, top=0.8, bottom=0.2)
pplot.label_figure(fig, data_id)
fig.text(0., 0.9, '[%s] rid=%i (%s, %s)' % (rf_type, rid, va, dk))

figname = 'steps_%s_%s_rid%i_%s' % (va, dk, rid, rf_type)
pl.savefig(os.path.join(sphr_dir, 'examples', '%s.svg' % figname))

<IPython.core.display.Javascript object>

## Visualize all RFs for FOV

In [1185]:
fit_params['rfdir']

'/n/coxfs01/2p-data/JC084/20190522/FOV1_zoom2p0x/combined_rfs_static/traces/traces001_b78b04_traces001_5c60f6_traces001_7a5ecd_traces001_c80ac8_traces001_656bda_traces001_f17a03/receptive_fields/fit-2dgaus_dff-no-cutoff'

In [1176]:
# Load original rfmaps
rfmaps_arr = rfutils.load_rfmap_array(fit_params['rfdir'], 
                                      do_spherical_correction=False)
rfmaps_arr.shape

AttributeError: 'NoneType' object has no attribute 'shape'

In [1164]:
ds_factor=3
col_vals = fit_params['col_vals']
row_vals = fit_params['row_vals']
# Downsample screen resolution
resolution_ds = [int(i/ds_factor) for i in \
                 fit_params['screen']['resolution'][::-1]]
# Get linear coordinates in degrees (downsampled)
# lin_x, lin_y = get_lin_coords(resolution=resolution_ds, cm_to_deg=True)
print("Screen res (ds=%ix): [%i, %i]" % (ds_factor, resolution_ds[0], resolution_ds[1]))
# Get Spherical coordinate mapping
cart_x, cart_y, sphr_th, sphr_ph = rfutils.get_spherical_coords(
                                    cart_pointsX=None, cart_pointsY=None,
                                    cm_to_degrees=True, resolution=resolution_ds) # already in deg

args=(cart_x, cart_y, sphr_th, sphr_ph, resolution_ds, row_vals, col_vals,)

Screen res (ds=3x): [360, 640]
(360, 640) (360, 640)


In [1165]:
do_spherical_correction

True

In [1166]:
# Set up the RF output directory and fit parameters for the current example
# dataset, with spherical correction switched on.
# NOTE: this cell rebinds the notebook-level `fit_params`, `data_id` and
# `do_spherical_correction`, so cells run afterwards see the new values.
run = rf_type
print(run)
do_spherical_correction=True
# -----------------------------------------------------------------------------
# rf_param_str = 'fit-2dgaus_%s-no-cutoff' % (response_type) 
# If the run label contains 'combined', take the piece after the first
# underscore; otherwise use the label as-is.
run_name = run.split('_')[1] if 'combined' in run else run
rfdir, fit_desc = rfutils.create_rf_dir(animalid, session, fov,
                                'combined_%s_static' % run_name, traceid=traceid,
                                response_type=response_type,
                                do_spherical_correction=do_spherical_correction, fit_thr=fit_thr)
fit_params = rfutils.get_fit_params(animalid, session, fov, run=run_name,
                                traceid=traceid, response_type=response_type,
                                do_spherical_correction=do_spherical_correction)
# Get data source: the trace-ID directory is everything before
# '/receptive_fields/' in the RF directory path.
traceid_dir = rfdir.split('/receptive_fields/')[0]
data_fpath = os.path.join(traceid_dir, 'data_arrays', 'np_subtracted.npz')
data_id = '|'.join([animalid, session, fov, run, traceid, fit_desc])
if not os.path.exists(data_fpath):
    # Realign traces
    # NOTE(review): despite the "ABORT" message, nothing is raised here --
    # execution simply continues past this cell with the array file missing.
    print("*****corrected offset unfound, ABORT*****")
    print("%s | %s | %s | %s | %s" % (animalid, session, fov, run, traceid))
    # aggregate_experiment_runs(animalid, session, fov, run_name, traceid=traceid)
    # print("*****corrected offsets!*****")

rfs


In [1167]:
print(fit_params['rfdir'])

/n/coxfs01/2p-data/JC097/20190616/FOV1_zoom2p0x/combined_rfs_static/traces/traces001_e0b0a8_traces001_dd504a_traces001_0e6f9a_traces001_1c5b41_traces001_5483ef_traces001_9f6c47/receptive_fields/fit-2dgaus_dff_sphr


In [1162]:
print(len(curr_rois))
# Warp the maps of the current ROIs and save exactly that result.
# Previously the result was bound to `rfmaps_arr` while an unrelated name
# (`rfmaps_sphr`) was saved, so the save used an undefined or stale array.
rfmaps_sphr = rfutils.sphr_correct_maps(rfmaps_arr[curr_rois], ds_factor=3, 
                                        fit_params=fit_params)
rfutils.save_rfmap_array(rfmaps_sphr, fit_params['rfdir'])
# The fitting step reads `rfmaps_arr`, so keep that name pointing at the
# corrected maps, as the original assignment did.
# NOTE(review): this rebinding makes the cell non-idempotent -- re-running it
# indexes the already-subset array with `curr_rois` again.
rfmaps_arr = rfmaps_sphr
    

102
Screen res (ds=3x): [360, 640]
(360, 640) (360, 640)


KeyboardInterrupt: 

In [1168]:
fit_params['rfdir']

'/n/coxfs01/2p-data/JC097/20190616/FOV1_zoom2p0x/combined_rfs_static/traces/traces001_e0b0a8_traces001_dd504a_traces001_0e6f9a_traces001_1c5b41_traces001_5483ef_traces001_9f6c47/receptive_fields/fit-2dgaus_dff_sphr'

In [1170]:
fit_results_sphr, fit_params_sphr = rfutils.fit_rfs(rfmaps_arr, response_type=response_type,
                            do_spherical_correction=do_spherical_correction,
                            fit_params=fit_params, data_identifier=data_id)

@@@ doing rf fits @@@


AttributeError: 'NoneType' object has no attribute 'columns'

In [932]:
print(data_id)

JC091|20190614|FOV1_zoom2p0x|rfs10|traces001|fit-2dgaus_dff_sphr


In [930]:
# fitdf_pos = rfutils.rfits_to_df(fit_results, scale_sigma=False, convert_coords=False,
#                 fit_params=fit_params, spherical=fit_params['do_spherical_correction'],
#                 row_vals=fit_params['row_vals'], col_vals=fit_params['col_vals'])


In [726]:
# #
# fig = rfutils.plot_top_rfs(fit_results, fit_params)
# pplot.label_figure(fig, data_id)
# figname = 'top%i_fit_thr_%.2f_%s_ellipse_sc_py3' % (len(fit_roi_list), fit_thr, fit_desc)
# pl.savefig(os.path.join(fit_params['rfdir'], '%s.png' % figname))
# print(figname)


In [913]:
# # Get data source
# rfdir=fit_params['rfdir']
# print(rfdir)
# traceid_dir = rfdir.split('/receptive_fields/')[0]
# soma_fpath = os.path.join(traceid_dir, 'data_arrays', 'np_subtracted.npz')

# trace_type='corrected'
# traces, labels, sdf, run_info = p3.load_dataset(soma_fpath, 
#                                     trace_type=trace_type,create_new=False)

/n/coxfs01/2p-data/JC091/20190614/FOV1_zoom2p0x/combined_rfs10_static/traces/traces001_601a1c_traces001_e2ec1e_traces001_473e42_traces001_72c645_traces001_2a113f/receptive_fields/fit-2dgaus_dff-no-cutoff


In [914]:
# Z-score or dff the traces:
# trials_by_cond = rfutils.get_trials_by_cond(labels)
# zscored_traces, zscores = p3.process_traces(traces, labels,
#                                         response_type=fit_params['response_type'],
#                                         nframes_post_onset=fit_params['nframes_post_onset'])
# # -------------------------------------------------------
# trials_by_cond = rfutils.get_trials_by_cond(labels)
# nx = len(fit_params['col_vals'])
# ny = len(fit_params['row_vals'])
# print("Error loading array, extracting now")
# print("...getting avg by cond")
# avg_resp_by_cond0 = rfutils.group_trial_values_by_cond(zscores, trials_by_cond, nx=nx, ny=ny,
#                                             do_spherical_correction=do_spherical_correction)
# if do_spherical_correction:
#     print("...doin spherical warps")
#     if n_processes>1:
#         avg_resp_by_cond0 = sphr_correct_maps_mp(avg_resp_by_cond0, fit_params,
#                                                     n_processes=n_processes, test_subset=test_subset)
#     else:
#         avg_resp_by_cond0 = sphr_correct_maps(avg_resp_by_cond0, fit_params,
#                                                     multiproc=False)
# print("...saved array")
# rfutils.save_rfmap_array(avg_resp_by_cond0, fit_params['rfdir'])

--- processed traces: dff
Error loading array, extracting now
...getting avg by cond
...saved array


In [1186]:
print(animalid, session, fov, rf_type)
fit_results, fit_params = rfutils.load_fit_results(animalid, session, fov,
                                     experiment=rf_type, traceid=traceid,
                                     response_type=response_type,
                                     do_spherical_correction=False)


JC084 20190522 FOV1_zoom2p0x rfs


In [1187]:

fit_results_sphr, fit_params_sphr = rfutils.load_fit_results(animalid, session, fov,
                                     experiment=rf_type, traceid=traceid,
                                     response_type=response_type,
                                     do_spherical_correction=True)

In [1189]:
fitdf_orig = rfutils.rfits_to_df(fit_results, fit_params=fit_params,
                               row_vals=row_vals, col_vals=col_vals, spherical=False)
fitdf_sphr = rfutils.rfits_to_df(fit_results_sphr, fit_params=fit_params_sphr,
                               row_vals=row_vals, col_vals=col_vals, spherical=True)
pass_rois_orig = fitdf_orig[fitdf_orig['r2']>0.5].dropna().index.tolist()
pass_rois_sphr = fitdf_sphr[fitdf_sphr['r2']>0.5].dropna().index.tolist()

print("N fit (orig): %i" % len(pass_rois_orig))
print("N fit (warp): %i" % len(pass_rois_sphr))

pass_rois_both = np.intersect1d(pass_rois_orig, pass_rois_sphr)
len(pass_rois_both)

Screen res (ds=3x): [360, 640]
N fit (orig): 134
N fit (warp): 97


96

In [1190]:

# fig = rfutils.plot_top_rfs(fitdf_sphr, fit_params_sphr, fit_roi_list=roi_list)
# pplot.label_figure(fig, data_id)
# figname = 'top%i_fit_thr_%.2f_%s_ellipse__p3' % (len(fit_roi_list), fit_thr, fit_desc)
# pl.savefig(os.path.join(fit_params_sphr['rfdir'], '%s.png' % figname))
# print(figname)


In [1191]:
#rid = 211 #388 #211
rid in pass_rois_both


False

In [1192]:
screen = p3.get_screen_dims()

In [1145]:
sdf = p3.get_stimuli(dk, rf_type)

In [1146]:
# Overlay original vs. spherically-corrected RF fits on the screen, with
# the example cell emphasised by a heavier outline.
c1 = '#f15a29'  # original fits
c2 = '#662d91'  # spherically-corrected fits
fig, ax = pl.subplots(figsize=(10, 5.7))
fig.patch.set_visible(False)

example_roi_list = [rid]
other_rois = [r for r in pass_rois_both if r not in example_roi_list]

# Background population, thin outlines. The ROI list and the colour list must
# describe the same cells, so pass `other_rois` for both. Previously the full
# `pass_rois_both` was plotted with len(other_rois) colours, which mismatched
# whenever the example cell passed both fits.
for curr_fitdf, curr_color in ((fitdf_orig, c1), (fitdf_sphr, c2)):
    ax = rfutils.plot_rfs_to_screen_pretty(curr_fitdf, sdf, screen,
                                   sigma_scale=1, #fit_params['sigma_scale'],
                                   fit_roi_list=other_rois, ax=ax,
                                   roi_colors=[curr_color]*len(other_rois),
                                   ellipse_lw=0.2)

# Example cell(s), thick outlines, drawn last.
for curr_fitdf, curr_color in ((fitdf_orig, c1), (fitdf_sphr, c2)):
    ax = rfutils.plot_rfs_to_screen_pretty(curr_fitdf, sdf, screen,
                                   sigma_scale=1, #fit_params['sigma_scale'],
                                   fit_roi_list=example_roi_list, ax=ax,
                                   roi_colors=[curr_color]*len(example_roi_list),
                                   ellipse_lw=2)

# Opaque grey axes background
ax.patch.set_color([0.7]*3)
ax.patch.set_alpha(1)
ax.set_aspect('equal')

# Proxy line handles for the two-entry legend
custom_lines = [mpl.lines.Line2D([0], [0], color=c1, lw=4),
                mpl.lines.Line2D([0], [0], color=c2, lw=4)]
ax.legend(custom_lines, ['Original', 'Spherical-Correction'],
          bbox_to_anchor=(1,1), loc='lower right')

pl.subplots_adjust(left=0.1, right=0.95, bottom=0.2, top=0.8)
pplot.label_figure(fig, data_id)
figname = '%s__screen_rfs__%s_%s_rid%i__p3' % (rf_type, va, dk, rid)
# `bbox_inches` (previously misspelled `bboxx_inches`) trims the saved figure.
pl.savefig(os.path.join(sphr_dir, 'examples', '%s.svg' % (figname)),
           bbox_inches='tight')

<IPython.core.display.Javascript object>



In [1040]:
va, dk

('Li', '20190614_JC091_fov1')

In [1136]:
sort_by_fit = fitdf_orig.loc[pass_rois_both]\
                .sort_values(by='r2', ascending=False).index.tolist()
plot_rois = sort_by_fit[0:30]
print([i for i in plot_rois if i not in fitdf_orig.index.tolist()])
print([i for i in plot_rois if i not in fitdf_sphr.index.tolist()])


[]
[]


In [1142]:
plot_rois = sort_by_fit[0:30]
fig, axn = pl.subplots(6, 5, figsize=(8,8))
fig.patch.set_visible(False) #(False) #('off')

for ax, rid in zip(axn.flat, plot_rois):
    ax = rfutils.plot_rfs_to_screen_pretty(fitdf_orig, sdf, screen, 
                                   sigma_scale=1, #fit_params['sigma_scale'],
                                   fit_roi_list=[rid], ax=ax, 
                                   roi_colors=[c1],
                                   ellipse_lw=1)
    ax = rfutils.plot_rfs_to_screen_pretty(fitdf_sphr, sdf, screen, 
                                   sigma_scale=1, #fit_params['sigma_scale'],
                                   fit_roi_list=[rid], ax=ax, 
                                   roi_colors=[c2],
                                    ellipse_lw=1)
    ax.set_title(rid, loc='left', fontsize=4)
for ax in axn.flat:
    ax.patch.set_color([0.7]*3)
    ax.patch.set_alpha(1)
    ax.set_aspect('equal')

custom_lines = [mpl.lines.Line2D([0], [0], color=c1, lw=4),
                mpl.lines.Line2D([0], [0], color=c2, lw=4)]
axn.flat[-1].legend(custom_lines, ['Original', 'Spherical-Correction'],
          bbox_to_anchor=(1,0), loc='upper right')

fig.text(0.05, 0.9, '%s, %s (top %i cells). R=orig, B=corr.' % (va, dk, len(plot_rois)))
    
pl.subplots_adjust(left=0.05, right=0.97, bottom=0.2, top=0.8)
pplot.label_figure(fig, data_id)
figname = '%s__top%i_compare_%s_%s__p3' % (rf_type,len(plot_rois), va,dk)
pl.savefig(os.path.join(sphr_dir, 'examples', '%s.svg'%figname))
print(figname)

<IPython.core.display.Javascript object>

rfs10__top30_compare_Li_20190614_JC091_fov1__p3


In [703]:
cells_w_both['datakey'].unique()

array(['20190602_JC091_fov1', '20191017_JC113_fov1',
       '20190618_JC097_fov1', '20190613_JC097_fov1',
       '20190616_JC097_fov1', '20190622_JC085_fov1',
       '20191006_JC110_fov1'], dtype=object)

In [704]:
def checkerboard(shape):
    """Return an integer array of `shape` filled with an alternating 0/1 pattern.

    Each entry is the parity of the sum of its indices, so the origin is 0
    and any two entries adjacent along an axis differ.
    """
    index_grids = np.indices(shape)
    index_totals = index_grids.sum(axis=0)
    return index_totals % 2
import copy


In [833]:
def _isoline_levels(values, interval):
    """Contour levels every `interval` units, spanning the finite range of `values`."""
    lowest = int(np.floor(np.nanmin(values) / interval) * interval)
    highest = int((np.ceil(np.nanmax(values) / interval) + 1) * interval)
    return range(lowest, highest, interval)


def _draw_isolines(ax, grid_x, grid_y, values, interval, label,
                   line_color, lw, fontsize, inline):
    """Draw (and optionally label) iso-value contours of `values` on `ax`."""
    # Fully transparent image, kept from the original implementation.
    ax.imshow(values, alpha=0.0)
    contours = ax.contour(grid_x, grid_y, values, _isoline_levels(values, interval),
                          colors=line_color, linewidths=lw)
    if label:
        ax.clabel(contours, fontsize=fontsize, inline=inline, fmt='%2.1f')
    return contours


def plot_corrected_gridlines(rfmap_o, rfmap_w, 
                     cart_x_ds, cart_y_ds, sphr_x_ds, sphr_y_ds, label=True,
                     col_vals=None, row_vals=None, interval=10, inline=1,
                     lw=0.5, line_color='m', cmap='bone', fontsize=12):
    """Plot a map before/after warping, each overlaid with coordinate isolines.

    Left panel: `rfmap_o` with isolines of the linear coordinates
    (`cart_x_ds`, `cart_y_ds`). Right panel: `rfmap_w` with isolines of the
    spherical coordinates (`sphr_x_ds`, `sphr_y_ds`), which are converted
    from radians to degrees before contouring.

    Parameters
    ----------
    rfmap_o, rfmap_w : 2D arrays
        Images shown in the left and right panels, respectively.
    cart_x_ds, cart_y_ds, sphr_x_ds, sphr_y_ds : 2D arrays
        Coordinate grids sampled on the same grid as the maps.
    label : bool
        Whether to print level values on the contour lines.
    col_vals, row_vals : sequences or None
        Only their lengths are used, to size the contour grid. When either
        is None the grid size is taken from `cart_x_ds.shape` (previously
        this raised a TypeError).
    interval : int
        Spacing between contour levels, in degrees.
    inline, fontsize : passed to `clabel`.
    lw, line_color : contour line width and colour.
    cmap : colormap for the two images.

    Returns
    -------
    matplotlib Figure
    """
    if col_vals is None or row_vals is None:
        n_rows, n_cols = np.shape(cart_x_ds)
    else:
        n_rows, n_cols = len(row_vals), len(col_vals)
    mapcorX, mapcorY = np.meshgrid(range(n_cols), range(n_rows))

    fig, axn = pl.subplots(1,2, figsize=(8, 4))
    fig.suptitle('Remap monitor', fontsize=14, fontweight='bold')

    # The contour keyword is `linewidths`; the original passed `linewidth`,
    # which contour does not use, so `lw` had no effect.
    style = dict(interval=interval, label=label, line_color=line_color,
                 lw=lw, fontsize=fontsize, inline=inline)

    ax = axn[0]
    ax.set_title('Linear Map (deg)')
    _draw_isolines(ax, mapcorX, mapcorY, cart_x_ds, **style)
    _draw_isolines(ax, mapcorX, mapcorY, cart_y_ds, **style)
    ax.imshow(rfmap_o, cmap=cmap)

    ax = axn[1]
    ax.set_title('Spherical Map (deg)')
    _draw_isolines(ax, mapcorX, mapcorY, np.rad2deg(sphr_x_ds), **style)
    _draw_isolines(ax, mapcorX, mapcorY, np.rad2deg(sphr_y_ds), **style)
    ax.imshow(rfmap_w, cmap=cmap)

    pl.subplots_adjust(wspace=0.5, hspace=0.5)

    return fig

In [834]:
sphr_x1.max()

0.8899829362724926

In [707]:
# Push a checkerboard test pattern through the same resample -> warp ->
# trim -> downsample steps used for RF maps, and bring the coordinate grids
# to the same (nrows, ncols) size.
checkers = checkerboard((nrows, ncols))
checkmap = copy.copy(checkers)
screen_bounds_pix = rfutils.get_screen_lim_pixels(cart_x, cart_y,
                                        row_vals=row_vals, col_vals=col_vals)
# NOTE(review): assumes get_screen_lim_pixels returns
# (bottom, left, top, right) in this order -- confirm against rf_utils.
(pix_bottom_edge, pix_left_edge, pix_top_edge, pix_right_edge) = screen_bounds_pix

# Upsample RF map to screen pixels
check_orig = rfutils.resample_map(checkmap, cart_x, cart_y, 
                          row_vals=row_vals, col_vals=col_vals,
                          resolution=resolution_ds)
# Warp resampled RF map
check_sphr = rfutils.warp_spherical(check_orig, 
                            sphr_x, sphr_y, cart_x, cart_y, 
                            normalize_range=True, method='linear')
# Trim
check_trim  = rfutils.trim_resampled_map(check_orig, screen_bounds_pix)
check_trim_sphr  = rfutils.trim_resampled_map(check_sphr, screen_bounds_pix)
# Downsample, so we don't have repeated values
# (cv2.resize takes the target size as (width, height), hence (ncols, nrows))
check_o = cv2.resize(check_trim, (ncols, nrows))
check_c = cv2.resize(check_trim_sphr, (ncols, nrows))
# Get downsampled coordinates
# Trim and downsample coordinate space to match corrected map
cart_x_ds  = cv2.resize(cart_x[pix_top_edge:pix_bottom_edge, pix_left_edge:pix_right_edge], (ncols,nrows))
cart_y_ds  = cv2.resize(cart_y[pix_top_edge:pix_bottom_edge, pix_left_edge:pix_right_edge], (ncols, nrows))

sphr_x_ds  = cv2.resize(sphr_x[pix_top_edge:pix_bottom_edge, pix_left_edge:pix_right_edge], (ncols,nrows))
sphr_y_ds  = cv2.resize(sphr_y[pix_top_edge:pix_bottom_edge, pix_left_edge:pix_right_edge], (ncols, nrows))


In [835]:
fig = plot_corrected_gridlines(check_o, check_c, 
                    cart_x_ds, cart_y_ds, sphr_x_ds, sphr_y_ds, 
                    col_vals=col_vals, row_vals=row_vals, interval=10,
                              fontsize=12, inline=False, label=False)
# fig.text(0.05, 0.8, 'lin_x, lin_y to spherical coords')
# figname = 'gridlines__lin2sphr_%s' % rf_type

# fig.text(0.05, 0.8, 'cart_x, cart_y, full RES')
# figname = 'gridlines__cart2sphrFULL'

fig.text(0.05, 0.8, 'cart_x, cart_y to spherical coords (w DS=3)')
figname = 'gridlines__cart2sphr_%s' % rf_type

pl.savefig(os.path.join(sphr_dir, '%s.svg' % figname))

<IPython.core.display.Javascript object>

  


In [712]:
warp_x = rfutils.warp_spherical(cart_x, cart_x, cart_y, 
                                 sphr_x, sphr_y, normalize_range=True)
warp_y = rfutils.warp_spherical(cart_y, cart_x, cart_y, 
                                     sphr_x, sphr_y, normalize_range=True)


In [757]:
def plot_remap(lin_coord_x, lin_coord_y, deg_coord_x, deg_coord_y, cmap='viridis',
               interval=10,
               resolution=[1080, 1920], mapcorX=None, mapcorY=None):
    """Show four coordinate maps as images with iso-value contours and colorbars.

    Layout (2 x 2): linear X (top-left), linear Y (bottom-left),
    spherical X (top-right), spherical Y (bottom-right).

    Parameters
    ----------
    lin_coord_x, lin_coord_y, deg_coord_x, deg_coord_y : 2D arrays
        Coordinate maps to display.
    cmap : colormap used for all four images.
    interval : int
        Spacing between contour levels / colorbar ticks.
    resolution : [n_rows, n_cols]
        Used to build the contour grid when `mapcorX`/`mapcorY` are not
        given. The default list is only read, never modified.
    mapcorX, mapcorY : 2D arrays or None
        Pre-computed contour grid; both must be given to be used.

    Returns
    -------
    matplotlib Figure
        The figure created here. The original returned the name `fig`, which
        is not defined in this function, so it yielded a stale notebook-level
        figure or raised NameError.
    """
    if mapcorX is None or mapcorY is None:
        mapcorX, mapcorY = np.meshgrid(range(resolution[1]), range(resolution[0]))

    f1, axn = pl.subplots(2, 2, figsize=(8, 4))
    f1.suptitle('Remap monitor', fontsize=14, fontweight='bold')

    # Same panel order as the original hand-written stanzas.
    panels = [
        (axn[0, 0], 'Linear Map X (cm)', lin_coord_x),
        (axn[1, 0], 'Linear Map Y (cm)', lin_coord_y),
        (axn[0, 1], 'Spherical Map X (deg)', deg_coord_x),
        (axn[1, 1], 'Spherical Map Y (deg)', deg_coord_y),
    ]
    for ax, title, values in panels:
        ax.set_title(title)
        image = ax.imshow(values, cmap=cmap)
        # Levels every `interval` units across the finite range of the map.
        levels = list(range(int(np.floor(np.nanmin(values) / interval) * interval),
                            int((np.ceil(np.nanmax(values) / interval) + 1) * interval),
                            interval))
        # contour takes `linewidths`; the original `linewidth=2` was not applied.
        ax.contour(mapcorX, mapcorY, values, levels, colors='k', linewidths=2)
        # Attach every colorbar through the figure for consistency.
        f1.colorbar(image, ax=ax, ticks=levels)

    # Pixel-index ticks carry no information here; hide them.
    for ax in axn.flat:
        ax.tick_params(which='both', axis='both', size=0)
        ax.set_xticks([])
        ax.set_yticks([])

    return f1

In [842]:
fig = plot_remap(cart_x, cart_y, warp_x, warp_y, cmap='viridis',
          resolution=resolution_ds) 
#, np.rad2deg(sphr_pointsTh), np.rad2deg(sphr_pointsPh))

pl.savefig(os.path.join(sphr_dir, 'stimulus_v_perceived.svg'))

<IPython.core.display.Javascript object>

  app.launch_new_instance()


In [839]:
def _draw_labeled_isolines(ax, coords, color, manual, cmap, interval, fontsize,
                           inline, lw, rightside_up, mapcorX, mapcorY):
    """Draw labeled iso-contours of `coords` on `ax`, one level per `interval`.

    Returns the contour set produced by `ax.contour`.
    """
    # Fully transparent image (alpha=0): nothing visible is drawn.
    # NOTE(review): presumably kept so the axes pick up imshow's image-style
    # limits/orientation before the contours are added -- confirm.
    ax.imshow(coords, cmap=cmap, alpha=0)

    # Levels span the data range, snapped outward to multiples of `interval`.
    level_lo = int(np.floor(np.nanmin(coords) / interval) * interval)
    level_hi = int((np.ceil(np.nanmax(coords) / interval) + 1) * interval)
    levels = range(level_lo, level_hi, interval)

    # `linewidths` (plural) is the keyword contour() actually consumes.
    contours = ax.contour(mapcorX, mapcorY, coords, levels,
                          colors=color, linewidths=lw)

    label_kws = dict(fontsize=fontsize, inline=inline, fmt='%2.1f',
                     rightside_up=rightside_up)
    # Only forward `manual` when the caller supplied label positions.
    if manual is not None:
        label_kws['manual'] = manual
    ax.clabel(contours, **label_kws)

    return contours


def plot_isolines(lin_coord_x, lin_coord_y, deg_coord_x, deg_coord_y, cmap='viridis',
               interval=10, fontsize=12, inline=1, c1='magenta', c2='green', lw=1,
                manual_x=None, manual_y=None,rightside_up=True,
               resolution=[1080, 1920], mapcorX=None, mapcorY=None):
    """Plot labeled isolines of the linear and spherical coordinate maps.

    Creates a 1x2 figure: the left axes overlays the iso-contours of
    `lin_coord_x` (color `c1`) and `lin_coord_y` (color `c2`); the right axes
    does the same for `deg_coord_x` and `deg_coord_y`.

    Parameters
    ----------
    lin_coord_x, lin_coord_y, deg_coord_x, deg_coord_y : 2D array-like
        Coordinate maps to contour (NaNs are ignored when picking levels).
        Presumably shaped (resolution[0], resolution[1]) -- TODO confirm.
    cmap : str
        Colormap handed to the (transparent) imshow calls.
    interval : int
        Spacing between contour levels (used as a `range` step).
    fontsize, inline, rightside_up
        Forwarded to `Axes.clabel`.
    c1, c2 : color
        Contour colors for the x maps and the y maps, respectively.
    lw : float
        Contour line width.
    manual_x, manual_y : list of (x, y) or None
        Optional label positions for the x-map / y-map contours; when None,
        `clabel` is called without `manual`.
    resolution : [rows, cols]
        Used to build the pixel grid when `mapcorX`/`mapcorY` are not given.
        The default list is only read, never mutated.
    mapcorX, mapcorY : 2D array-like or None
        Grid coordinates for the contours; default is
        np.meshgrid(range(resolution[1]), range(resolution[0])).

    Returns
    -------
    The figure created by `pl.subplots`.
    """
    if mapcorX is None or mapcorY is None:
        mapcorX, mapcorY = np.meshgrid(range(resolution[1]), range(resolution[0]))

    f1, axn = pl.subplots(1,2,figsize=(8, 4))
    f1.suptitle('Remap monitor', fontsize=14, fontweight='bold')

    shared = dict(cmap=cmap, interval=interval, fontsize=fontsize, inline=inline,
                  lw=lw, rightside_up=rightside_up, mapcorX=mapcorX, mapcorY=mapcorY)

    panels = [
        (axn[0], 'Linear Map (deg.)', lin_coord_x, lin_coord_y),
        (axn[1], 'Spherical Map (deg)', deg_coord_x, deg_coord_y),
    ]
    for ax, title, coords_x, coords_y in panels:
        ax.set_title(title)
        # x-map contours take manual_x, y-map contours take manual_y.
        _draw_labeled_isolines(ax, coords_x, c1, manual_x, **shared)
        _draw_labeled_isolines(ax, coords_y, c2, manual_y, **shared)

    for ax in axn.flat:
        ax.tick_params(which='both', axis='both', size=0)
        ax.set_xticks([])
        ax.set_yticks([])

    # Return the figure built here (not a notebook-level `fig` global).
    return f1

In [840]:
# Candidate label anchor points along each axis. They are defined here but not
# handed to plot_isolines below (manual_x/manual_y are passed as None, so
# clabel is called without manual positions).
manual_x = [(pos, 0) for pos in np.arange(-20, 30, 10)]
manual_y = [(0, pos) for pos in np.arange(-20, 30, 10)]

# Isolines of the Cartesian maps next to the warped (warp_*) maps.
fig = plot_isolines(
    cart_x, cart_y, warp_x, warp_y,
    cmap='Spectral',
    resolution=resolution_ds,
    manual_x=None,
    manual_y=None,
    rightside_up=True,
    interval=10,
    fontsize=6,
)

pl.savefig(os.path.join(sphr_dir, 'isolines_perceived.svg'))

<IPython.core.display.Javascript object>



In [841]:
# Express the sphr_* maps in degrees (they are converted with rad2deg).
corrected_x, corrected_y = (np.rad2deg(angles) for angles in (sphr_x, sphr_y))

# Isolines of the Cartesian maps next to the degree-converted sphr_* maps.
fig = plot_isolines(
    cart_x, cart_y, corrected_x, corrected_y,
    cmap='Spectral',
    resolution=resolution_ds,
    manual_x=None,
    manual_y=None,
    rightside_up=True,
    interval=10,
    fontsize=6,
)

pl.savefig(os.path.join(sphr_dir, 'isolines_corrected.svg'))

<IPython.core.display.Javascript object>



In [816]:
# Convert sphr_y_ds from radians to degrees.
deg_y_ds = np.rad2deg(sphr_y_ds)

In [710]:
# Inspect the largest value in deg_y_ds.
deg_y_ds.max()

27.346801004097372

In [711]:
# Inspect the (min, max) range of cart_y_ds.
cart_y_ds.min(), cart_y_ds.max()

(-24.836262344897445, 25.02344897442391)

In [None]:
# Build the name of the aggregated neural-data pickle from the analysis
# settings below, then load it into D.
aggregate_dir = '/n/coxfs01/julianarhee/aggregate-visual-areas'

response_type = 'dff'
responsive_test = 'ROC'
overlap_thr = None
trial_epoch = 'plushalf'
traceid = 'traces001'

# Threshold used in the file name depends on which responsivity test is selected.
if responsive_test == 'ROC':
    responsive_thr = 0.05
else:
    responsive_thr = 10.0

# No overlap threshold -> 'noRF' tag; otherwise the threshold to 2 decimals.
if overlap_thr in [None, 'None']:
    overlap_str = 'noRF'
else:
    overlap_str = 'overlap%.2f' % overlap_thr

fname = 'neuraldata_%s_%s_%s-thr-%.2f_%s_%s' % (
    traceid, response_type, responsive_test,
    responsive_thr, trial_epoch, overlap_str,
)

datafile = os.path.join(aggregate_dir, 'data-stats', 'tmp_data', fname + '.pkl')

# NOTE(review): unpickling can execute arbitrary code; only load files from a
# trusted location.
with open(datafile, 'rb') as f:
    D = pkl.load(f, encoding='latin1')
D.keys()

In [None]:
def get_dsets_with_max_rfs(rf_dsets, assigned_cells):
    """Load aggregate RF fits and pass them through the dataset-selection helpers.

    Steps: load the aggregate RF dataframes for `rf_dsets`, combine them with
    `assigned_cells` via `get_rfdata` (verbose off, repeats averaged), then
    return whatever `get_dsets_with_most_cells` selects from that result.
    NOTE(review): going by the helper's name this is presumably the dataset
    with the most cells -- confirm against its definition.
    """
    # Function-scope import: inside this function `rfutils` refers to the
    # pipeline package's rf_utils module.
    from pipeline.python.classifications import rf_utils as rfutils

    rfdfs = rfutils.load_aggregate_rfs(rf_dsets)
    rfs = get_rfdata(assigned_cells, rfdfs, verbose=False, average_repeats=True)

    return get_dsets_with_most_cells(rfs)

