In [None]:
from collections.abc import Iterable
import os

from datetime import datetime, timedelta
from dateutil import tz
from hdmf.backends.hdf5.h5_utils import H5DataIO
from hdmf.container import Container
from hdmf.data_utils import DataChunkIterator
import latex
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import pickle
from pynwb import load_namespaces, get_class, register_class, NWBFile, TimeSeries, NWBHDF5IO
from pynwb.file import MultiContainerInterface, NWBContainer, Device, Subject
from pynwb.ophys import ImageSeries, OnePhotonSeries, OpticalChannel, ImageSegmentation, PlaneSegmentation, Fluorescence, DfOverF, CorrectedImageStack, MotionCorrection, RoiResponseSeries, ImagingPlane
from pynwb.core import NWBDataInterface
from pynwb.epoch import TimeIntervals
from pynwb.behavior import SpatialSeries, Position
from pynwb.image import ImageSeries
import pywt
import scipy.io as sio
import scipy
from scipy.stats import multivariate_normal
from scipy.optimize import linear_sum_assignment
import seaborn as sns
import skimage.io as skio
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from tifffile import TiffFile
import tifffile

from atlas import loadmat, NPAtlas, NWBAtlas
from process_file import get_nwb_neurons, get_dataset_neurons, get_dataset_online, combine_datasets, get_pairings, get_color_discrim, get_neur_nums
from stats import get_summary_stats, analyze_pairs, get_accuracy
from visualization import plot_num_heatmap, plot_std_heatmap, plot_summary_stats, plot_color_discrim, plot_accuracies, plot_visualizations_atlas, plot_visualizations_data, plot_atlas2d_super
from utils import covar_to_coord, convert_coordinates, maha_dist, run_linear_assignment

# ndx_multichannel_volume is the novel NWB extension for multichannel optophysiology in C. elegans
from ndx_multichannel_volume import CElegansSubject, OpticalChannelReferences, OpticalChannelPlus, ImagingVolume, VolumeSegmentation, MultiChannelVolume, MultiChannelVolumeSeries

import os
import PyQt6.QtCore
os.environ["QT_API"] = "pyqt6"


In [None]:
# Sets the QT_API environment variable to "pyqt6" (presumably so Qt-based plotting picks PyQt6).
# NOTE(review): these three lines repeat the end of the import cell verbatim; one copy should suffice.
import os
import PyQt6.QtCore
os.environ["QT_API"] = "pyqt6"


In [None]:
# Open one NeuroPAL NWB file read-only and copy the image, segmentation, label and channel
# information used by the following cells into module-level variables while the file is open
# (the `[:]` slices presumably materialise the underlying datasets as in-memory arrays).
filepath = '/Users/danielysprague/foco_lab/data/NP_nwb/56_YAaDV.nwb'
#filepath = '/Users/danielysprague/foco_lab/data/NWB_foco/2021-12-03-w00-NP1.nwb'

with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
    read_nwbfile = io.read()
    #print(read_nwbfile.processing['ProcessedImage'])
    subject = read_nwbfile.subject #get the metadata about the experiment subject
    growth_stage = subject.growth_stage
    image = read_nwbfile.acquisition['NeuroPALImageRaw'].data[:] #get the neuroPAL image as a np array
    channels = read_nwbfile.acquisition['NeuroPALImageRaw'].RGBW_channels[:] #get which channels of the image correspond to which RGBW pseudocolors
    im_vol = read_nwbfile.acquisition['NeuroPALImageRaw'].imaging_volume #get the metadata associated with the imaging acquisition
    seg = read_nwbfile.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:] #get the locations of neuron centers
    labels = read_nwbfile.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels'][:]
    optchans = im_vol.optical_channel_plus[:] #get information about all of the optical channels used in acquisition
    chan_refs = read_nwbfile.processing['NeuroPAL']['OpticalChannelRefs'].channels[:] #get the order of the optical channels in the image
    # NOTE(review): the calcium reads below are disabled, yet the next two cells use
    # `fluor`, `calc_labels` and `calcium_frames`.
    #calcium_frames = read_nwbfile.acquisition['CalciumImageSeries'].data[0:15, :,:,:] #load the first 15 frames of the calcium images
    #print(read_nwbfile.acquisition['CalciumImageSeries'].dimension[:])
    #fluor = read_nwbfile.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
    #calc_labels = read_nwbfile.processing['CalciumActivity']['NeuronIDs'].labels[:]

    NIR = read_nwbfile.processing['BF_NIR']['BrightFieldNIR'].data[:,:,:]

# Last expression of the cell: displays the file object as the cell output.
read_nwbfile




In [None]:
# Plot the trace of the neuron labelled 'AVAR' from the transposed `fluor` matrix.
# NOTE(review): `fluor` and `calc_labels` are only assigned in commented-out lines of the
# loading cell above, so on a fresh kernel this cell needs a later loading cell to have run.
ts = fluor.T

ref_ind = np.argwhere(calc_labels=='AVAR')
ref_trace = np.squeeze(ts[ref_ind,:])

# Mean-normalised copy of the trace.
# NOTE(review): `to_plot` is never used; the plot below shows the un-normalised `ref_trace`.
to_plot = ref_trace / np.mean(ref_trace)

print(ref_ind)
print(ref_trace.shape)

plt.plot(ref_trace)
plt.show()

In [None]:
# Show a max-projection of calcium frame 10 followed by slice 10 of the NIR stack.
# NOTE(review): `calcium_frames` is only assigned in a commented-out line of the loading cell.
calc_max = np.max(calcium_frames, axis=0)  # NOTE(review): computed but not used below
zmax_calc = np.max(calcium_frames[10,:,:,:], axis=2)  # max over the last axis of frame index 10
plt.imshow(zmax_calc)
plt.show()
plt.figure()

plt.imshow(NIR[10,:,:])

In [None]:
# Overlay the segmented neuron positions on maximum-intensity projections of the NeuroPAL image.
# The voxel-mask records are unpacked into five named columns; 'weight' is not needed for plotting.
blobs = pd.DataFrame.from_records(seg, columns = ['X', 'Y', 'Z', 'weight', 'ID'])
blobs = blobs.drop(['weight'], axis=1)
blobs = blobs.replace('nan', np.nan, regex=True)

print(image.shape)

# Keep every pseudocolour channel except the last entry of `channels`, scaled by the global image maximum.
RGB = image[:,:,:,channels[:-1]]/np.max(image)

print(RGB.shape)

Zmax = np.max(RGB, axis=2)  # projection along axis 2
Ymax = np.max(RGB, axis=1)  # projection along axis 1

plt.figure()

plt.imshow(np.transpose(Zmax, [1,0,2]))
# Column names are upper-case ('X', 'Y', 'Z'); the previous lower-case lookups raised KeyError.
plt.scatter(blobs['X'], blobs['Y'], s=5)
plt.xlim((0, Zmax.shape[0]))
plt.ylim((0, Zmax.shape[1]))
plt.gca().set_aspect('equal')

plt.show()

plt.figure()

plt.imshow(np.transpose(Ymax, [1,0,2]))
plt.scatter(blobs['X'], blobs['Z'], s=5)
plt.xlim((0, Ymax.shape[0]))
plt.ylim((0, Ymax.shape[1]))
plt.gca().set_aspect('equal')

plt.show()

### Wavelet analysis

%matplotlib qt

# Load one recording (NeuroPAL metadata plus the calcium response series) for the trace plots below.
#filepath = '/Users/danielysprague/foco_lab/data/NWB_Ray/20230506-15-01-45.nwb'
#filepath = '/Users/danielysprague/foco_lab/data/NWB_foco/2021-12-03-w00-NP1.nwb'
filepath = '/Users/danielysprague/foco_lab/data/Yemini_NWB/20190924_01.nwb'
#filepath = '/Users/danielysprague/foco_lab/data/kimura_full/sub-230928-02_ses-20230928T111400_ophys.nwb'

with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
    read_nwbfile = io.read()
    #print(read_nwbfile.processing['ProcessedImage'])
    subject = read_nwbfile.subject #get the metadata about the experiment subject
    growth_stage = subject.growth_stage
    channels = read_nwbfile.acquisition['NeuroPALImageRaw'].RGBW_channels[:] #get which channels of the image correspond to which RGBW pseudocolors
    im_vol = read_nwbfile.acquisition['NeuroPALImageRaw'].imaging_volume #get the metadata associated with the imaging acquisition
    seg = read_nwbfile.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:] #get the locations of neuron centers
    labels = read_nwbfile.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels'][:]
    optchans = im_vol.optical_channel_plus[:] #get information about all of the optical channels used in acquisition
    chan_refs = read_nwbfile.processing['NeuroPAL']['OpticalChannelRefs'].channels[:] #get the order of the optical channels in the image
    rate = read_nwbfile.acquisition['CalciumImageSeries'].rate
    fluor = read_nwbfile.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
    calc_labels = read_nwbfile.processing['CalciumActivity']['NeuronIDs'].labels[:]


# Transpose so that rows of `traces` are indexed the way plot_traces expects (row per label).
traces = np.transpose(fluor)

# Each ID label is joined into a single string (presumably stored as a sequence of characters — confirm).
labels = ["".join(label) for label in labels]

# Reset every matplotlib rcParam to its default before plotting.
plt.rcParams.update(plt.rcParamsDefault)

def plot_traces(traces, rate, labels, selected):
    """Plot one stacked panel per selected neuron.

    Parameters
    ----------
    traces : 2-D array whose rows are looked up by position in ``labels`` and whose
        columns are samples.
    rate : sampling rate; the sample count is floor-divided by it to get the x-axis extent.
    labels : sequence of neuron names, compared element-wise with each entry of ``selected``.
    selected : sequence of neuron names to draw, one panel each, top to bottom.

    Side effects
    ------------
    Sets the global matplotlib ``font.size`` rcParam to 20 and shows the figure.
    """
    plt.rcParams.update({'font.size':20})

    seconds = traces.shape[1]//rate
    time_axis = np.linspace(0, seconds, traces.shape[1])
    label_array = np.asarray(labels)

    # squeeze=False always returns a 2-D axes array, so a single selected neuron also works.
    fig, axs = plt.subplots(len(selected), 1, figsize=(5,6), squeeze=False)
    axs = axs[:, 0]

    for ax, neuron in zip(axs, selected):
        index = np.argwhere(label_array==neuron)
        trace = traces[np.squeeze(index),:]

        ax.plot(time_axis, trace)
        ax.set_ylabel(r'$\Delta$F/F')
        ax.set_xlim(0,seconds)
        ax.set_yticks([])
        ax.set_title(neuron, loc='left')

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

    # Label only the bottom panel. This used to be the hard-coded axs[4], which raised
    # IndexError for fewer than five neurons and labelled the wrong panel for more.
    axs[-1].set_xlabel('time (seconds)')

    plt.tight_layout()
    plt.show()

# NOTE(review): `labels` holds the NeuroPAL ID_labels, while `traces` comes from the calcium response
# series whose names were read into `calc_labels`; confirm the two share the same ordering.
plot_traces(traces, rate, labels, ['AVAR', 'SMDVR', 'AWCR','RID', 'ASHR'])

In [None]:
import scipy

import sklearn
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn import svm
import sklearn.metrics as metrics

import tvregdiff

def _windowed_xcorr(reference_trace, component, rate, window):
    """Cross-correlate two signals and keep ``window`` lags on each side of the centre lag.

    Returns ``(lag_seconds, xcorr)`` for the retained lags, where lags are divided by ``rate``.
    """
    corr = scipy.signal.correlate(reference_trace, component)
    lags = scipy.signal.correlation_lags(len(reference_trace), len(component))
    centre = len(lags) // 2
    lo, hi = centre - window, centre + window
    return lags[lo:hi] / rate, corr[lo:hi]


def run_pca(traces, labels, reference, rate, n_components):
    """Run PCA on regularised time-derivatives of the traces and sign-align the components.

    Parameters
    ----------
    traces : 2-D array, one row per entry of ``labels``, one column per sample.
    labels : sequence of neuron names (a list or an array).
    reference : name of the neuron whose derivative is used to fix the sign of each PC.
    rate : sampling rate, used to express cross-correlation lags in seconds.
    n_components : number of principal components to fit.

    Returns
    -------
    pca : the fitted ``PCA`` object.
    weights : ``pca.components_.T``, one row per kept neuron and one column per PC.
        It is a view, so the sign flips below are reflected in ``pca.components_`` too.
    transform : the (non-detrended) derivative matrix projected by the fitted pipeline,
        one row per sample and one column per PC.
    keep_labels : the labels of the neurons that were kept (no NaN in their derivative).

    Raises
    ------
    ValueError : if ``reference`` does not occur in ``labels``.
    """
    deriv_iter = 5
    deriv_alpha = 0.001
    window = 40  # lags kept on each side of the centre when inspecting cross-correlations

    # Accept plain lists as well as arrays; the boolean-mask indexing below needs an array.
    labels = np.asarray(labels)
    ref_ind = np.argwhere(labels == reference)
    if ref_ind.size == 0:
        raise ValueError("reference neuron {!r} not found in labels".format(reference))

    derivs = np.zeros(traces.shape)
    for n in range(traces.shape[0]):
        derivs[n, :] = tvregdiff.TVRegDiff(
            traces[n, :], deriv_iter, deriv_alpha, cgmaxit=100, diagflag=False, plotflag=False
        )

    # Samples x neurons, dropping every neuron whose derivative contains a NaN.
    mat = derivs.T
    keep_indices = ~np.isnan(mat).any(axis=0)
    mat = mat[:, keep_indices]
    keep_labels = labels[keep_indices]

    pca = PCA(n_components = n_components)
    scaler = StandardScaler()
    pipe = Pipeline(steps= [("scaler", scaler), ("pca", pca)])

    # The pipeline is fitted on the linearly detrended derivatives but applied to the
    # un-detrended ones, as in the original implementation.
    pipe.fit(scipy.signal.detrend(mat, axis=0))
    transform = pipe.transform(mat)

    weights = pca.components_.T
    x1 = np.squeeze(derivs[ref_ind, :])

    # PC1: flip if the largest-magnitude cross-correlation with the reference is negative.
    _, xcorr = _windowed_xcorr(x1, transform[:, 0], rate, window)
    if xcorr[np.argmax(np.abs(xcorr))] < 0:
        transform[:, 0] = -1 * transform[:, 0]
        weights[:, 0] = -1 * weights[:, 0]

    if n_components > 1:
        # PC2: flip depending on the lag (in seconds) of the cross-correlation peak.
        # NOTE(review): the original comment said "if peak xcorr > 0", but the threshold
        # in the code is -1; the threshold is kept unchanged here.
        lag_seconds, xcorr = _windowed_xcorr(x1, transform[:, 1], rate, window)
        if lag_seconds[np.argmax(xcorr)] > -1:
            transform[:, 1] = -1 * transform[:, 1]
            weights[:, 1] = -1 * weights[:, 1]

    # PC3 only exists with at least three components. It used to be read under the
    # `n_components > 1` branch, which raised IndexError for n_components == 2.
    if n_components > 2:
        _, xcorr = _windowed_xcorr(x1, transform[:, 2], rate, window)
        # Index `window` is the centre of the retained slice.
        if xcorr[window] < 0:
            transform[:, 2] = -1 * transform[:, 2]
            weights[:, 2] = -1 * weights[:, 2]

    return pca, weights, transform, keep_labels
# Disabled scratch code kept inside a bare string literal; the statements inside it are not executed.
# NOTE(review): its final line unpacks three values from run_pca, but run_pca returns four.
'''  
filepath = '/Users/danielysprague/foco_lab/data/NWB_Ray/20230506-15-01-45.nwb'
#filepath = '/Users/danielysprague/foco_lab/data/Yemini_nwb/20190924_01.nwb'
#filepath = '/Users/danielysprague/foco_lab/data/NWB_foco/2021-12-03-w00-NP1.nwb'

with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
    read_nwb = io.read()
    identifier = read_nwb.identifier
    seg = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:]
    labels = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels'][:]
    #labels_index = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels_index'][:]
    channels = read_nwb.acquisition['NeuroPALImageRaw'].RGBW_channels[:] #get which channels of the image correspond to which RGBW pseudocolors
    image = read_nwb.acquisition['NeuroPALImageRaw'].data[:]
    scale = read_nwb.imaging_planes['NeuroPALImVol'].grid_spacing[:] #get which channels of the image correspond to which RGBW pseudocolors

    rate = read_nwb.acquisition['CalciumImageSeries'].rate
    #fluor = read_nwb.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
    dfof = read_nwb.processing['CalciumActivity']['SignalDFoF']['SignalCalciumImResponseSeries'].data[:]
    calc_labels = read_nwb.processing['CalciumActivity']['NeuronIDs'].labels[:]

    labels = np.asarray(["".join(label) for label in labels])

#dfof = fluor / np.mean(fluor, axis=0)

print(dfof.shape)
print(calc_labels.shape)

pca, transform, labels = run_pca(dfof.T, calc_labels, 'AVAR', rate, 3)
'''

In [None]:
%matplotlib inline

# Manual screening pass over the Yemini recordings: for every .nwb file, run a 3-component PCA
# (sign-aligned to 'AVAL'), show the three PC time courses, and keep the file name if the user
# answers 'y' at the prompt.
# NOTE(review): the result depends on interactive input, so this cell is not reproducible from code alone.
Yemini_good_PCA = []

for file in os.listdir('/Users/danielysprague/foco_lab/data/Yemini_NWB'):
    if not file[-4:] == '.nwb':
        continue

    print(file)
    filepath = '/Users/danielysprague/foco_lab/data/Yemini_NWB/' + file

    with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
        read_nwb = io.read()
        # NOTE(review): identifier, seg, labels, channels, image and scale are read here
        # but not used anywhere else in this cell.
        identifier = read_nwb.identifier
        seg = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:]
        labels = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels'][:]
        #labels_index = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels_index'][:]
        channels = read_nwb.acquisition['NeuroPALImageRaw'].RGBW_channels[:] #get which channels of the image correspond to which RGBW pseudocolors
        image = read_nwb.acquisition['NeuroPALImageRaw'].data[:]
        scale = read_nwb.imaging_planes['NeuroPALImVol'].grid_spacing[:] #read the grid spacing of the imaging volume

        rate = read_nwb.acquisition['CalciumImageSeries'].rate
        fluor = read_nwb.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
        calc_labels = read_nwb.processing['CalciumActivity']['NeuronIDs'].labels[:]

        labels = np.asarray(["".join(label) for label in labels])

    # Drop the first 20 samples and divide every column by its own mean over the remaining samples.
    # NOTE(review): despite the name this is F / mean(F), with no baseline subtraction.
    dfof = fluor[20:,:]/ np.mean(fluor[20:,:], axis=0)

    # run_pca needs the reference neuron, so recordings without 'AVAL' are skipped.
    if not 'AVAL' in calc_labels:
        continue

    pca, weights, transform, keep_labels = run_pca(dfof.T, calc_labels, 'AVAL', rate, 3)

    # One panel per principal component, with the x-axis running from 0 to n_samples / rate.
    fig, axs = plt.subplots(3,1, sharex=True)
    axs[0].plot(np.linspace(0,transform.shape[0]/rate, transform.shape[0]), transform[:,0])
    axs[1].plot(np.linspace(0,transform.shape[0]/rate, transform.shape[0]), transform[:,1])
    axs[2].plot(np.linspace(0,transform.shape[0]/rate, transform.shape[0]), transform[:,2])

    plt.show()

    x = input('Keep?')

    if x == 'y':
        Yemini_good_PCA.append(file)

In [None]:
%matplotlib inline

Kimura_good_PCA = []

for file in os.listdir('/Users/danielysprague/foco_lab/data/kimura_full'):
    if not file[-4:] == '.nwb':
        continue

    print(file)
    filepath = '/Users/danielysprague/foco_lab/data/kimura_full/' + file

    with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
        read_nwb = io.read()
        identifier = read_nwb.identifier
        seg = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:]
        #labels = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels'][:]
        #labels_index = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels_index'][:]
        channels = read_nwb.acquisition['NeuroPALImageRaw'].RGBW_channels[:] #get which channels of the image correspond to which RGBW pseudocolors
        image = read_nwb.acquisition['NeuroPALImageRaw'].data[:]
        scale = read_nwb.imaging_planes['NeuroPALImVol'].grid_spacing[:] #get which channels of the image correspond to which RGBW pseudocolors

        rate = read_nwb.acquisition['CalciumImageSeries'].rate
        fluor = read_nwb.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
        calc_labels = read_nwb.processing['CalciumActivity']['NeuronIDs'].labels[:]

        labels = np.asarray(["".join(label) for label in labels])

    dfof = fluor[20:,:]/ np.mean(fluor[20:,:], axis=0)

    if not 'AVAL' in calc_labels:
        continue

    pca, weights, transform, keep_labels = run_pca(dfof.T, calc_labels, 'AVAL', rate, 3)

    fig, axs = plt.subplots(3,1, sharex=True)
    axs[0].plot(np.linspace(0,transform.shape[0]/rate, transform.shape[0]), transform[:,0])
    axs[1].plot(np.linspace(0,transform.shape[0]/rate, transform.shape[0]), transform[:,1])
    axs[2].plot(np.linspace(0,transform.shape[0]/rate, transform.shape[0]), transform[:,2])

    plt.show()

    x = input('Keep?')

    if x == 'y':
        Kimura_good_PCA.append(file)

In [None]:
# Re-run the PCA on every recording that passed manual screening and collect
#   * `df`    : one row per (neuron, PC) with the PCA weight and the recording name,
#   * `pcs`   : the projected PC time courses of each recording,
#   * `rates` : the sampling rate of each recording.
pcs = []
rates = []
records = []

# Neurons reported under their own name, and left/right pairs pooled by dropping the last letter.
own_name_neurons = ['RID', 'RMEV', 'VB2', 'AWCR', 'AWCL']
pooled_lr_neurons = ['SMDVR', 'SMDVL', 'AVAR', 'AVAL']

# Folder names match the spelling used when the files were listed for screening
# ('Yemini_NWB' was previously spelled 'Yemini_nwb' here, which fails on a
# case-sensitive file system).
screened = [('Yemini_NWB', Yemini_good_PCA), ('kimura_full', Kimura_good_PCA)]

for folder, good_files in screened:
    for file in good_files:
        if not file.endswith('.nwb'):
            continue

        print(file)
        filepath = '/Users/danielysprague/foco_lab/data/' + folder + '/' + file

        with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
            read_nwb = io.read()
            # NOTE(review): identifier, seg, channels, image and scale are not used in this
            # cell; they are still read so that the module-level names keep being refreshed.
            identifier = read_nwb.identifier
            seg = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:]
            channels = read_nwb.acquisition['NeuroPALImageRaw'].RGBW_channels[:] #get which channels of the image correspond to which RGBW pseudocolors
            image = read_nwb.acquisition['NeuroPALImageRaw'].data[:]
            scale = read_nwb.imaging_planes['NeuroPALImVol'].grid_spacing[:] #read the grid spacing of the imaging volume

            rate = read_nwb.acquisition['CalciumImageSeries'].rate
            fluor = read_nwb.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
            calc_labels = read_nwb.processing['CalciumActivity']['NeuronIDs'].labels[:]

            # `labels` is intentionally not touched: its read is disabled here, and re-joining
            # it only re-processed stale state from an earlier cell.

        # Drop the first 20 samples and divide every column by its own mean over the remaining samples.
        dfof = fluor[20:,:]/ np.mean(fluor[20:,:], axis=0)

        # run_pca needs the reference neuron, so recordings without 'AVAL' are skipped.
        if 'AVAL' not in calc_labels:
            continue

        pca, weights, transform, keep_labels = run_pca(dfof.T, calc_labels, 'AVAL', rate, 3)

        for i, label in enumerate(keep_labels):
            if label in own_name_neurons:
                neuron = label
            elif label in pooled_lr_neurons:
                neuron = label[:-1]
            else:
                continue
            for pc_index in range(3):
                records.append([neuron, 'PC' + str(pc_index + 1), weights[i, pc_index], file[:-4]])

        pcs.append(transform)
        rates.append(rate)

# Build the frame once instead of growing it row by row with df.loc.
df = pd.DataFrame(records, columns=['Neuron', 'PC', 'Weight', 'File'])

In [None]:
# Recording names of the screened files: each file name with its last four characters removed.
yem_files = [name[:-4] for name in Yemini_good_PCA]
kim_files = [name[:-4] for name in Kimura_good_PCA]

In [None]:
%matplotlib qt
# Box plot of PCA weights per neuron (one hue per PC), restricted to rows whose 'File'
# value appears in kim_files.
plt.figure()
sns.boxplot(data=df[df['File'].isin(kim_files)], x='Neuron', y='Weight', hue='PC', order= ['RID', 'AVA', 'SMDV', 'RMEV', 'VB2', 'AWCR', 'AWCL'])
plt.show()

In [None]:
# Run the PCA on every recording in both dataset folders (no manual screening) and collect
#   * `df`    : one row per (neuron, PC) with the PCA weight and the recording name,
#   * `pcs`   : the projected PC time courses of each recording,
#   * `rates` : the sampling rate of each recording.
pcs = []
rates = []
records = []

# Neurons reported under their own name, and left/right pairs pooled by dropping the last letter.
own_name_neurons = ['RID', 'RMEV', 'VB2', 'AWCR', 'AWCL']
pooled_lr_neurons = ['SMDVR', 'SMDVL', 'AVAR', 'AVAL']

# 'kimura_full' matches the spelling used by every other cell (it was 'Kimura_full' here,
# which fails on a case-sensitive file system).
for folder in ['Yemini_NWB', 'kimura_full']:
    for file in os.listdir('/Users/danielysprague/foco_lab/data/'+folder):
        if not file.endswith('.nwb'):
            continue

        print(file)
        filepath = '/Users/danielysprague/foco_lab/data/' + folder+ '/' + file

        with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
            read_nwb = io.read()
            # NOTE(review): identifier, seg, channels, image and scale are not used in this
            # cell; they are still read so that the module-level names keep being refreshed.
            identifier = read_nwb.identifier
            seg = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:]
            channels = read_nwb.acquisition['NeuroPALImageRaw'].RGBW_channels[:] #get which channels of the image correspond to which RGBW pseudocolors
            image = read_nwb.acquisition['NeuroPALImageRaw'].data[:]
            scale = read_nwb.imaging_planes['NeuroPALImVol'].grid_spacing[:] #read the grid spacing of the imaging volume

            rate = read_nwb.acquisition['CalciumImageSeries'].rate
            fluor = read_nwb.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
            calc_labels = read_nwb.processing['CalciumActivity']['NeuronIDs'].labels[:]

            # `labels` is intentionally not touched: its read is disabled here, and re-joining
            # it only re-processed stale state from an earlier cell.

        # (F - F0) / F0 with F0 the per-neuron mean after dropping the first 20 samples.
        # The numerator previously subtracted the scalar mean over all neurons (no axis),
        # while the denominator was already per neuron.
        baseline = np.mean(fluor[20:,:], axis=0)
        dfof = (fluor[20:,:] - baseline) / baseline

        # run_pca needs the reference neuron, so recordings without 'AVAL' are skipped.
        if 'AVAL' not in calc_labels:
            continue

        pca, weights, transform, keep_labels = run_pca(dfof.T, calc_labels, 'AVAL', rate, 3)

        for i, label in enumerate(keep_labels):
            if label in own_name_neurons:
                neuron = label
            elif label in pooled_lr_neurons:
                neuron = label[:-1]
            else:
                continue
            for pc_index in range(3):
                records.append([neuron, 'PC' + str(pc_index + 1), weights[i, pc_index], file[:-4]])

        pcs.append(transform)
        rates.append(rate)

# Build the frame once instead of growing it row by row with df.loc.
df = pd.DataFrame(records, columns=['Neuron', 'PC', 'Weight', 'File'])



In [None]:
# Box plot of PCA weights per neuron (one hue per PC) over every row currently in `df`.
plt.figure()
sns.boxplot(data=df, x='Neuron', y='Weight', hue='PC', order= ['RID', 'AVA', 'SMDV', 'RMEV', 'VB2', 'AWCR', 'AWCL'])
plt.show()

In [None]:
# 3 x 5 grid: one column per recording for the first five entries of `pcs`, one row per PC.
# The x-axis runs from 0 to n_samples / rate for that recording.
fig, axs  = plt.subplots(3,5)
for i in range(5):
    axs[0,i].plot(np.linspace(0,pcs[i].shape[0]/rates[i], pcs[i].shape[0]), pcs[i][:,0])
    axs[1,i].plot(np.linspace(0,pcs[i].shape[0]/rates[i], pcs[i].shape[0]), pcs[i][:,1])
    axs[2,i].plot(np.linspace(0,pcs[i].shape[0]/rates[i], pcs[i].shape[0]), pcs[i][:,2])
    # Row labels go on the left-most column only (re-applied on every iteration, which is harmless).
    axs[0,0].set_ylabel('PC1')
    axs[1,0].set_ylabel('PC2')
    axs[2,0].set_ylabel('PC3')
plt.show()

In [None]:
# Same 3 x 5 layout as the previous cell, but for entries 21..25 of `pcs`.
# NOTE(review): the offset 21 is a magic number — presumably the index of the first recording
# from the second dataset folder; verify against how many recordings the first folder contributed.
fig, axs  = plt.subplots(3,5)
for i in range(5):
    axs[0,i].plot(np.linspace(0,pcs[i+21].shape[0]/rates[i+21], pcs[i+21].shape[0]), pcs[i+21][:,0])
    axs[1,i].plot(np.linspace(0,pcs[i+21].shape[0]/rates[i+21], pcs[i+21].shape[0]), pcs[i+21][:,1])
    axs[2,i].plot(np.linspace(0,pcs[i+21].shape[0]/rates[i+21], pcs[i+21].shape[0]), pcs[i+21][:,2])
plt.show()

In [None]:
%matplotlib qt

# NOTE(review): `fluor` and `labels` are whatever the most recently executed loading cell left behind.
traces = np.transpose(fluor)

# Join each label into a single string (a no-op for labels that are already plain strings).
labels = ["".join(label) for label in labels]

def plot_traces(traces, rate, labels, selected):
    """Plot one stacked, fully labelled panel per selected neuron.

    Note: this definition shadows the ``plot_traces`` defined earlier in the notebook.

    Parameters
    ----------
    traces : 2-D array whose rows are looked up by position in ``labels`` and whose
        columns are samples.
    rate : sampling rate; the sample count is floor-divided by it to get the x-axis extent.
    labels : sequence of neuron names, compared element-wise with each entry of ``selected``.
    selected : sequence of neuron names to draw, one panel each, top to bottom.
    """
    seconds = traces.shape[1]//rate
    time_axis = np.linspace(0, seconds, traces.shape[1])
    label_array = np.asarray(labels)

    # squeeze=False always returns a 2-D axes array; without it a single selected neuron
    # returned a bare Axes object and the indexing below failed.
    fig, axs = plt.subplots(len(selected), 1, figsize=(5,6), squeeze=False)
    axs = axs[:, 0]

    for ax, neuron in zip(axs, selected):
        index = np.argwhere(label_array==neuron)
        trace = traces[np.squeeze(index),:]

        ax.plot(time_axis, trace)
        ax.set_xlabel('seconds')
        ax.set_ylabel('DFoF')
        ax.set_xlim(0,seconds)
        ax.set_title(neuron)

    plt.tight_layout()
    plt.show()

# NOTE(review): `labels` is derived from the NeuroPAL ID_labels while `traces` is the transposed
# calcium `fluor` matrix; confirm that the label order matches the trace rows (or pass calc_labels).
plot_traces(traces, rate, labels, ['AVAR', 'SMDVR', 'AWCR','RID', 'ASHR'])

In [None]:
def plot_wavelet_decomp(signal, waveletname, level):
    """Plot a repeated single-level DWT of ``signal`` and return the detail coefficients.

    The signal is shown on its own first. ``pywt.dwt`` is then applied ``level`` times,
    each time to the approximation output of the previous step. Every step is drawn as one
    row: approximation coefficients on the left (red), detail coefficients on the right (green).

    Parameters
    ----------
    signal : 1-D sequence to decompose.
    waveletname : wavelet name passed straight to ``pywt.dwt``.
    level : number of decomposition steps (and of rows in the figure).

    Returns
    -------
    list with one array of detail coefficients per step, finest step first.
    """
    fig, ax = plt.subplots(figsize=(6,1))
    ax.set_title("Original Signal: ")
    ax.plot(signal)
    plt.show()

    approx = signal
    detail_coefs = []

    # One row per requested level. The grid used to be fixed at 5 rows, so level > 5 raised
    # IndexError and level < 5 left empty panels. squeeze=False keeps the axes array 2-D for
    # level == 1. The height scales with the row count and equals the old (6, 6) at level 5.
    n_rows = max(level, 1)
    fig, axarr = plt.subplots(nrows=n_rows, ncols=2, figsize=(6, 1.2 * n_rows), squeeze=False)
    for ii in range(level):
        (approx, coeff_d) = pywt.dwt(approx, waveletname)
        detail_coefs.append(coeff_d)

        approx_ax, detail_ax = axarr[ii, 0], axarr[ii, 1]
        approx_ax.plot(approx, 'r')
        detail_ax.plot(coeff_d, 'g')
        approx_ax.set_ylabel("Level {}".format(ii + 1), fontsize=14, rotation=90)
        approx_ax.set_yticklabels([])
        if ii == 0:
            approx_ax.set_title("Approximation coefficients", fontsize=14)
            detail_ax.set_title("Detail coefficients", fontsize=14)
        detail_ax.set_yticklabels([])
    plt.tight_layout()
    plt.show()

    return detail_coefs

# Decompose one trace with two wavelets (5 steps each) and keep the detail coefficients of each.
# NOTE(review): `trace` is never assigned at cell level in the visible notebook (it only exists as a
# local inside plot_traces / run_pca), so this cell presumably relies on leftover kernel state.
coefs_db5 = plot_wavelet_decomp(trace, 'db5', 5)
coefs_sym5 = plot_wavelet_decomp(trace, 'sym5', 5)

In [None]:
# Root-mean-square of the detail coefficients at every decomposition step, per wavelet.
db5_power = [np.sqrt(np.mean(step_coefs**2)) for step_coefs in coefs_db5]
sym5_power = [np.sqrt(np.mean(step_coefs**2)) for step_coefs in coefs_sym5]

# Step k is plotted at x = 2**k (1, 2, 4, ...). The x-values are derived from the number of
# steps actually computed instead of being hard-coded to five.
# NOTE(review): the axis label says seconds, but these x-values are plain powers of two —
# confirm the intended unit.
plt.plot(2.0**np.arange(len(db5_power)), db5_power, label='db5')
plt.plot(2.0**np.arange(len(sym5_power)), sym5_power, label='sym5')
plt.xlabel('Scale of wavelet (s)')
plt.ylabel('RMS power of detail coefficients')
plt.legend()

plt.show()



In [None]:
# Load a single Kimura recording and print both label sets (NeuroPAL ID labels and calcium neuron IDs).
filepath = '/Users/danielysprague/foco_lab/data/kimura_full/sub-230928-02_ses-20230928T111400_ophys.nwb'

with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
    read_nwb = io.read()
    seg = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:]
    labels = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels'][:]
    channels = read_nwb.acquisition['NeuroPALImageRaw'].RGBW_channels[:] #get which channels of the image correspond to which RGBW pseudocolors
    image = read_nwb.acquisition['NeuroPALImageRaw'].data[:]
    scale = read_nwb.imaging_planes['NeuroPALImVol'].grid_spacing[:] #read the grid spacing of the imaging volume

    fluor = read_nwb.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
    calc_labels = read_nwb.processing['CalciumActivity']['NeuronIDs'].labels[:]

print(labels)
print(calc_labels)

# Last expression of the cell: displays the file object as the cell output.
read_nwb


In [None]:
# Read the 'gclabels' entry from a MATLAB file and strip the spaces from every label.
# NOTE(review): the path handed to loadmat ends in '/', i.e. it names a directory rather than a
# .mat file — presumably the file name was dropped; confirm the intended file.
labels = sio.loadmat('/Users/danielysprague/foco_lab/data/Yemini_21/OH16230/Heads/20190924_01/')['gclabels']
gclabels = np.asarray([label.replace(" ","") for label in labels])
print(gclabels.shape)
# Wavelet decomposition (5 steps, 'sym5') of column 0 of whatever `fluor` currently holds.
coefs_sym5 = plot_wavelet_decomp(fluor[:,0], 'sym5', 5)

In [None]:
%matplotlib qt

# Collect mean-normalised fluorescence traces of five neuron classes across all Kimura recordings.
# Left/right pairs (e.g. AVAL/AVAR) are pooled into the same list.
AVA_traces = []
SMDV_traces = []
AWC_traces = []
ASH_traces = []
RID_traces = []

for file in os.listdir('/Users/danielysprague/foco_lab/data/kimura_full'):
    if not file[-4:] == '.nwb':
        continue

    filepath = '/Users/danielysprague/foco_lab/data/kimura_full/' + file

    with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
        read_nwb = io.read()
        # NOTE(review): seg, labels, channels, image, scale and rate are read but not used in this cell.
        seg = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:]
        labels = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels'][:]
        channels = read_nwb.acquisition['NeuroPALImageRaw'].RGBW_channels[:] #get which channels of the image correspond to which RGBW pseudocolors
        image = read_nwb.acquisition['NeuroPALImageRaw'].data[:]
        scale = read_nwb.imaging_planes['NeuroPALImVol'].grid_spacing[:] #read the grid spacing of the imaging volume

        fluor = read_nwb.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
        calc_labels = read_nwb.processing['CalciumActivity']['NeuronIDs'].labels[:]
        rate = read_nwb.acquisition['CalciumImageSeries'].rate

    # Column indices in `fluor` of each neuron class (empty when the recording lacks that neuron).
    index_AVA = np.argwhere((np.asarray(calc_labels)=='AVAR') | (np.asarray(calc_labels)=='AVAL'))
    index_SMDV = np.argwhere((np.asarray(calc_labels)=='SMDVR') | (np.asarray(calc_labels)=='SMDVL'))
    index_AWC = np.argwhere((np.asarray(calc_labels)=='AWCR') | (np.asarray(calc_labels)=='AWCL'))
    index_ASH = np.argwhere((np.asarray(calc_labels)=='ASHR') | (np.asarray(calc_labels)=='ASHL'))
    index_RID = np.argwhere(np.asarray(calc_labels)=='RID')
    # Each matching column is divided by its own mean and appended as a 1-D trace.
    # NOTE(review): np.argwhere yields index arrays, never None, so the `if index is not None`
    # filters look like no-ops.
    AVA_traces = AVA_traces + [np.squeeze(fluor[:,index]/np.mean(fluor[:,index])) for index in index_AVA if index is not None]
    SMDV_traces = SMDV_traces + [np.squeeze(fluor[:,index]/np.mean(fluor[:,index])) for index in index_SMDV if index is not None]
    AWC_traces = AWC_traces + [np.squeeze(fluor[:, index]/np.mean(fluor[:,index])) for index in index_AWC if index is not None]
    ASH_traces = ASH_traces + [np.squeeze(fluor[:,index]/np.mean(fluor[:,index])) for index in index_ASH if index is not None]
    RID_traces = RID_traces + [np.squeeze(fluor[:,index]/np.mean(fluor[:,index])) for index in index_RID if index is not None]

# Disabled exploratory plots of the AVA traces and of their Fourier transforms.
#for i, trace in enumerate(AVA_traces):
#    plt.plot(np.linspace(0,985,1645), trace[:1645], alpha=0.5)
#plt.show()

#for i, trace in enumerate(AVA_traces):
#    plt.plot(np.fft.fft(trace[:1645], axis=0), alpha=0.5)

In [None]:
%matplotlib qt

# Per-neuron-class lists of mean-normalised fluorescence traces, pooled over
# every NWB file in the Yemini folder (left/right homologues share one list).
AVA_traces = []
SMDV_traces = []
AWC_traces = []
ASH_traces = []
RID_traces = []

nwb_dir = '/Users/danielysprague/foco_lab/data/Yemini_NWB'

# Which NeuronIDs labels feed which trace list.
trace_groups = [
    (AVA_traces, ('AVAR', 'AVAL')),
    (SMDV_traces, ('SMDVR', 'SMDVL')),
    (AWC_traces, ('AWCR', 'AWCL')),
    (ASH_traces, ('ASHR', 'ASHL')),
    (RID_traces, ('RID',)),
]

# NOTE: os.listdir returns entries in arbitrary order, so positions inside the
# trace lists (and the file that `fluor` / `rate` are left pointing at after
# the loop) are not guaranteed to be stable across machines.
for file in os.listdir(nwb_dir):
    if not file.endswith('.nwb'):
        continue

    filepath = nwb_dir + '/' + file

    with NWBHDF5IO(filepath, mode='r', load_namespaces=True) as io:
        read_nwb = io.read()
        seg = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons'].voxel_mask[:]
        labels = read_nwb.processing['NeuroPAL']['NeuroPALSegmentation']['NeuroPALNeurons']['ID_labels'][:]
        channels = read_nwb.acquisition['NeuroPALImageRaw'].RGBW_channels[:] # which image channels correspond to which RGBW pseudocolors
        image = read_nwb.acquisition['NeuroPALImageRaw'].data[:]
        scale = read_nwb.imaging_planes['NeuroPALImVol'].grid_spacing[:] # grid spacing of the NeuroPAL imaging volume

        fluor = read_nwb.processing['CalciumActivity']['SignalRawFluor']['SignalCalciumImResponseSeries'].data[:]
        calc_labels = read_nwb.processing['CalciumActivity']['NeuronIDs'].labels[:]
        rate = read_nwb.acquisition['CalciumImageSeries'].rate

    # Build the label array once per file instead of once per comparison.
    label_array = np.asarray(calc_labels)

    for traces, names in trace_groups:
        is_match = np.logical_or.reduce([label_array == name for name in names])
        # np.argwhere yields one length-1 index row per matching label, so
        # fluor[:, index] is a single column; each column is divided by its
        # own mean and squeezed to a 1-D trace.
        for index in np.argwhere(is_match):
            column = fluor[:, index]
            traces.append(np.squeeze(column / np.mean(column)))

#for i, trace in enumerate(AVA_traces):
#    plt.plot(np.linspace(0,985,1645), trace[:1645], alpha=0.5)
#plt.show()

#for i, trace in enumerate(AVA_traces):
#    plt.plot(np.fft.fft(trace[:1645], axis=0), alpha=0.5)
#plt.show()

In [None]:
def plot_wavelet_decomp(signal, waveletname, level):
    """Plot a signal and its repeated single-level DWT, one row per level.

    The signal is shown in its own small figure. A second figure then shows,
    for each level, the approximation coefficients (left, red) and the detail
    coefficients (right, green). Each level is obtained by applying
    ``pywt.dwt`` to the approximation coefficients of the previous level.

    Args:
        signal: 1-D sequence/array to decompose.
        waveletname: wavelet identifier passed straight to ``pywt.dwt``.
        level: number of successive decompositions (rows in the figure).

    Returns:
        List of detail-coefficient arrays, one per level, finest scale first.
    """
    fig, ax = plt.subplots(figsize=(6,1))
    ax.set_title("Original Signal: ")
    ax.plot(signal)
    plt.show()

    data = signal
    detail_coefs = []

    # squeeze=False keeps `axarr` two-dimensional even when level == 1;
    # without it plt.subplots returns a 1-D array and axarr[ii, 0] fails.
    fig, axarr = plt.subplots(nrows=level, ncols=2, figsize=(6,6), squeeze=False)
    for ii in range(level):
        # `data` becomes the approximation that feeds the next level.
        (data, coeff_d) = pywt.dwt(data, waveletname)
        detail_coefs.append(coeff_d)

        approx_ax, detail_ax = axarr[ii, 0], axarr[ii, 1]
        approx_ax.plot(data, 'r')
        detail_ax.plot(coeff_d, 'g')
        approx_ax.set_ylabel("Level {}".format(ii + 1), fontsize=14, rotation=90)
        approx_ax.set_yticklabels([])
        detail_ax.set_yticklabels([])
        if ii == 0:
            approx_ax.set_title("Approximation coefficients", fontsize=14)
            detail_ax.set_title("Detail coefficients", fontsize=14)
    plt.tight_layout()
    plt.show()

    return detail_coefs

In [None]:
def get_wavelet_decomp(signal, waveletname, level):
    """Run `level` successive single-level DWTs and collect the results.

    Each pass applies ``pywt.dwt`` to the approximation produced by the
    previous pass (the first pass uses `signal` itself).

    Args:
        signal: 1-D sequence/array to decompose.
        waveletname: wavelet identifier passed straight to ``pywt.dwt``.
        level: number of successive decompositions.

    Returns:
        Tuple ``(approx_coefs, detail_coefs, detail_power)`` of lists with one
        entry per level; ``detail_power[k]`` is the root-mean-square of
        ``detail_coefs[k]``.
    """
    approx_coefs, detail_coefs, detail_power = [], [], []

    current = signal
    for _ in range(level):
        current, detail = pywt.dwt(current, waveletname)
        approx_coefs.append(current)
        detail_coefs.append(detail)
        # RMS amplitude of this level's detail coefficients
        detail_power.append(np.sqrt(np.mean(np.square(detail))))

    return approx_coefs, detail_coefs, detail_power


In [None]:
# Visual sanity check of the decomposition on two example traces.
# NOTE(review): `fluor` is whatever the NWB-loading loop assigned last, i.e.
# the raw fluorescence of the final file it read; column 5 is an arbitrary
# neuron from that file.
plot_wavelet_decomp(fluor[:,5], 'sym5', 5)
# Eleventh pooled AVA trace, decomposed with the Haar wavelet over 10 levels.
plot_wavelet_decomp(AVA_traces[10], 'haar', 10)

In [None]:
%matplotlib inline
# Inspect the Haar wavelet: print its decomposition filter length and the
# low-/high-pass decomposition filters, then plot both filters together.
wavelet = pywt.Wavelet('haar')
dec_lo, dec_hi, rec_lo, rec_hi = wavelet.filter_bank

for item in (wavelet.dec_len, dec_lo, dec_hi):
    print(item)

for taps in (dec_lo, dec_hi):
    plt.plot(taps)
plt.legend(labels=['decomp low', 'decomp high'])
plt.show()

In [None]:
%matplotlib qt

# Wavelet power per decomposition scale for every pooled trace, shown next to
# the traces themselves (one figure row per neuron class).
#
# NOTE: `rate` (CalciumImageSeries.rate of the last NWB file read) and the
# *_traces lists come from the NWB-loading cell, which must have run first.
waveletname = 'haar'
wavelet = pywt.Wavelet(waveletname)
wavelen = wavelet.dec_len
level = 7

skip = 50      # leading samples dropped from every trace
length = 936   # last sample (exclusive) shown in the trace panels

neurons = ['AVA', 'SMDV', 'AWC', 'ASH', 'RID']

all_traces = [AVA_traces, SMDV_traces, AWC_traces, ASH_traces, RID_traces]

# Collect one record per (trace, level) and build the frame once; the column
# names are the ones the plotting code below actually reads.
power_records = []
for neuron, traces in zip(neurons, all_traces):
    for trace in traces:
        approx, detail, power = get_wavelet_decomp(np.squeeze(trace[skip:]), waveletname, level)
        for k in range(len(power)):
            power_records.append({
                'neuron_name': neuron,
                # Scale label: (filter length / rate) * 2**k, truncated to its
                # first four characters so it stays short on the axis.
                'DecompScale': str((wavelen/rate)* 2**k)[:4],
                'Power': power[k],
            })

df = pd.DataFrame(power_records, columns=['neuron_name', 'DecompScale', 'Power'])

fig, axs = plt.subplots(len(neurons), 2, squeeze=False)

for i, (neuron, traces) in enumerate(zip(neurons, all_traces)):
    for trace in traces:
        segment = trace[skip:length]
        # Time stamp of each plotted sample in seconds. Deriving it from the
        # segment itself keeps x and y the same length for short traces.
        time_s = (skip + np.arange(len(segment))) / rate
        axs[i,0].plot(time_s, segment, alpha=0.5)
    axs[i,0].set_xlabel('Time in seconds')
    axs[i,0].set_ylabel('df/f')
    axs[i,0].title.set_text(neuron)

    sns.boxplot(ax=axs[i,1], data= df[df['neuron_name']==neuron], x='DecompScale', y='Power')
    axs[i,1].set_xlabel('Wavelet scale in seconds')
    axs[i,1].set_ylabel('Wavelet power')
    axs[i,1].title.set_text(neuron)

plt.show()

In [None]:
%matplotlib qt
def plot_accuracies(datasets, accs_NP, accs_full, labels):
    """Compare NP-atlas and consolidated-atlas accuracies with split violins.

    Left panel: top-1 accuracy per dataset. Right panel: top-1 .. top-5
    accuracy pooled over all files. In both panels a thin black line joins
    the NP and Consolidated value of the same file.

    NOTE: this definition replaces the `plot_accuracies` imported from
    `visualization` at the top of the notebook.

    Args:
        datasets: sequence of dict-like objects whose keys are file names.
        accs_NP: DataFrame with a 'Filename' column and 'Percent_top1' ..
            'Percent_top5' columns (accuracies obtained with the NP atlas).
        accs_full: same layout, accuracies obtained with the consolidated atlas.
        labels: display label for each entry of `datasets` (same order).
    """
    atlas_tables = [('NP', accs_NP), ('Consolidated', accs_full)]
    rank_columns = [
        ('top', 'Percent_top1'),
        ('top2', 'Percent_top2'),
        ('top3', 'Percent_top3'),
        ('top4', 'Percent_top4'),
        ('top5', 'Percent_top5'),
    ]

    df_dataset = pd.DataFrame(columns=['Atlas', 'Dataset', 'Accuracy'])
    df_ranks = pd.DataFrame(columns= ['Atlas', 'Rank', 'Accuracy'])

    for i, dataset in enumerate(datasets):
        for key in dataset.keys():
            # Rows of each accuracy table that belong to this file.
            matches = {atlas: table.loc[table['Filename']==key] for atlas, table in atlas_tables}

            for atlas, _ in atlas_tables:
                df_dataset.loc[len(df_dataset.index)] = [atlas, labels[i], matches[atlas].iloc[0]['Percent_top1']]

            for rank, column in rank_columns:
                for atlas, _ in atlas_tables:
                    df_ranks.loc[len(df_ranks.index)] = [atlas, rank, matches[atlas].iloc[0][column]]

    def connect_pairs(ax, frame, column, categories):
        # Join the NP and Consolidated value of each file within a category.
        for pos, category in enumerate(categories):
            in_category = frame[column]==category
            np_vals = frame[in_category&(frame['Atlas']=='NP')]['Accuracy']
            consol_vals = frame[in_category&(frame['Atlas']=='Consolidated')]['Accuracy']
            for a, b in zip(np_vals, consol_vals):
                ax.plot([pos-0.1,pos+0.1], [a,b], color='black', linewidth=0.5)

    fig, axs = plt.subplots(1,2)

    sns.set(style='white', font_scale=1.5)

    violin_kwargs = dict(y='Accuracy', hue='Atlas', gap=0.5, palette=['purple','pink'],
                         orient='v', split=True, cut=0, inner='quart', density_norm='width')

    sns.violinplot(ax=axs[0], data=df_dataset, x='Dataset', **violin_kwargs)
    connect_pairs(axs[0], df_dataset, 'Dataset', labels)

    sns.violinplot(ax=axs[1], data=df_ranks, x='Rank', **violin_kwargs)
    connect_pairs(axs[1], df_ranks, 'Rank', ['top', 'top2', 'top3', 'top4', 'top5'])

    for ax, title in zip(axs, ['Accuracy by dataset', 'Cumulative accuracy of top n assignments']):
        ax.set_ylim((0,1))
        ax.set_title(title)

    plt.show()
    
#plot_accuracies([chaud_dataset, Yem_dataset, old_FOCO_dataset, FOCO_dataset, kimura_dataset, flavell_dataset], accs_NP, accs_full, ['1', '2', '3', '4', '5', '6'])

In [None]:
import scipy
def gen_plots_acc(datasets, labels,accs_NP_unmatch, accs_NP, accs_full_unmatch, accs_full):
    """Plot assignment accuracy for four atlas variants and print paired t-tests.

    The four variants are tagged 'Base' (accs_NP_unmatch), 'Matched' (accs_NP),
    'Full' (accs_full_unmatch) and 'Full matched' (accs_full).

    Figure layout (3 x 2):
        [0][0] top-1 accuracy of all four variants
        [0][1] top-1 .. top-5 accuracy, Base ('NP') vs Full matched ('Consolidated')
        [1][0] Base vs Matched, per dataset
        [1][1] Base vs Full, per dataset
        [2][0] Full vs Full matched, per dataset
        [2][1] Base vs Full matched, per dataset
    Thin black lines join the two values that belong to the same file.

    Args:
        datasets: sequence of dict-like objects whose keys are file names.
        labels: display label for each entry of `datasets` (same order).
        accs_NP_unmatch, accs_NP, accs_full_unmatch, accs_full: DataFrames with
            a 'Filename' column and 'Percent_top1' .. 'Percent_top5' columns.

    Side effects:
        Updates matplotlib rcParams / seaborn style globally, shows the figure
        and prints the statistic and p-value of four paired t-tests.
    """
    plt.rcParams.update({'font.size': 22})

    variants = [
        ('Base', accs_NP_unmatch),
        ('Matched', accs_NP),
        ('Full', accs_full_unmatch),
        ('Full matched', accs_full),
    ]
    atlas_order = [name for name, _ in variants]
    rank_columns = [
        ('top', 'Percent_top1'),
        ('top2', 'Percent_top2'),
        ('top3', 'Percent_top3'),
        ('top4', 'Percent_top4'),
        ('top5', 'Percent_top5'),
    ]

    # Collect plain rows and build each frame once.
    dataset_rows = []
    rank_rows = []
    for i, dataset in enumerate(datasets):
        for key in dataset.keys():
            # First row of every accuracy table that belongs to this file.
            per_file = {name: table.loc[table['Filename']==key].iloc[0] for name, table in variants}

            for name in atlas_order:
                dataset_rows.append([name, labels[i], per_file[name]['Percent_top1']])

            # The rank panel contrasts the two extremes only.
            for rank, column in rank_columns:
                rank_rows.append(['NP', rank, per_file['Base'][column]])
                rank_rows.append(['Consolidated', rank, per_file['Full matched'][column]])

    df_dataset = pd.DataFrame(dataset_rows, columns=['Atlas', 'Dataset', 'Accuracy'])
    df_ranks = pd.DataFrame(rank_rows, columns= ['Atlas', 'Rank', 'Accuracy'])

    palette = sns.color_palette('colorblind')
    atlas_colors = {
        'Base': palette[3],
        'Matched': palette[2],
        'Full': palette[0],
        'Full matched': palette[8],
    }

    fig, axs = plt.subplots(3,2)

    sns.set(style='white', font_scale=1.5)

    sns.violinplot(ax = axs[0][0], data=df_dataset, x='Atlas', y='Accuracy', hue='Atlas',
                   order=atlas_order, hue_order=atlas_order,
                   palette=[atlas_colors[name] for name in atlas_order],
                   cut=0, inner='quart', density_norm='width')

    split_kwargs = dict(y='Accuracy', hue='Atlas', gap=0.5, orient='v', split=True,
                        cut=0, inner='quart', density_norm='width')

    sns.violinplot(ax= axs[0][1], data = df_ranks, x = 'Rank',
                   palette=[atlas_colors['Base'], atlas_colors['Full matched']], **split_kwargs)

    for j, (rank, _) in enumerate(rank_columns):
        in_rank = df_ranks['Rank']==rank
        NP_vals_rank = df_ranks[in_rank&(df_ranks['Atlas']=='NP')]['Accuracy']
        consol_vals_rank = df_ranks[in_rank&(df_ranks['Atlas']=='Consolidated')]['Accuracy']
        for a, b in zip(NP_vals_rank, consol_vals_rank):
            axs[0][1].plot([j-0.1,j+0.1], [a,b], color='black', linewidth=0.5)

    # (axis, left variant, right variant) for the four per-dataset panels.
    comparisons = [
        (axs[1][0], 'Base', 'Matched'),
        (axs[1][1], 'Base', 'Full'),
        (axs[2][0], 'Full', 'Full matched'),
        (axs[2][1], 'Base', 'Full matched'),
    ]

    for ax, left, right in comparisons:
        subset = df_dataset[df_dataset['Atlas'].isin([left, right])]
        sns.violinplot(ax=ax, data=subset, x='Dataset', hue_order=[left, right],
                       palette=[atlas_colors[left], atlas_colors[right]], **split_kwargs)

    for i, dataset in enumerate(labels):
        in_dataset = df_dataset['Dataset']==dataset
        values = {name: df_dataset[in_dataset&(df_dataset['Atlas']==name)]['Accuracy'] for name in atlas_order}
        for ax, left, right in comparisons:
            for a, b in zip(values[left], values[right]):
                ax.plot([i-0.1,i+0.1], [a,b], color='black', linewidth=0.5)

    for row in axs:
        for ax in row:
            ax.set_ylim((0,1))

    axs[0][0].set(xlabel=None)
    # These four names describe the atlas variants, so they belong on the
    # Atlas axis of the overview panel (order matches `atlas_order`).
    axs[0][0].set_xticklabels(['Original', 'Color matched', 'Consolidated', 'Consolidated & color matched'])
    axs[0][1].set_ylabel(None)
    axs[0][1].set(xlabel=None)
    axs[1][1].set_ylabel(None)
    axs[2][0].set_xlabel(None)
    axs[2][1].set_ylabel(None)
    axs[2][1].set_xlabel(None)

    # Only the last panel keeps its legend.
    for ax in (axs[0][1], axs[1][0], axs[1][1], axs[2][0]):
        ax.legend([],[], frameon=False)

    accuracy = {name: np.asarray(df_dataset[df_dataset['Atlas']==name]['Accuracy']) for name in atlas_order}

    # Paired t-tests, printed in the order: Base/Matched, Base/Full,
    # Full/Full matched, Base/Full matched.
    for _, left, right in comparisons:
        result = scipy.stats.ttest_rel(accuracy[left], accuracy[right])
        print('t-value: ' +str(result.statistic)+' pvalue: '+str(result.pvalue))

    plt.show()

gen_plots_acc([chaud_dataset, Yem_dataset, old_FOCO_dataset, FOCO_dataset, kimura_dataset, flavell_dataset], ['1', '2', '3', '4', '5', '6'], accs_NP_unmatch, accs_NP, accs_full_unmatch, accs_full)

In [None]:
%matplotlib inline
import scipy.stats as stats

def plot_accuracies_atlas_compare(datasets, accs_NP, accs_full):
    """Box/scatter plot of per-file accuracy for the two atlases, plus a paired t-test.

    Args:
        datasets: sequence of dict-like objects whose keys are file names.
        accs_NP: mapping from file name to accuracy with the NeuroPAL atlas.
        accs_full: mapping from file name to accuracy with the consolidated atlas.

    Side effects:
        Shows the figure and prints the t statistic and p-value of
        ``ttest_rel(NeuroPAL, Consolidated)``.
    """
    rows = [
        [accs_NP[key], accs_full[key], key]
        for dataset in datasets
        for key in dataset.keys()
    ]
    df = pd.DataFrame(rows, columns=['NeuroPAL','Consolidated', 'dataset'])

    # Long form: one row per (file, atlas); the 'Accuracy' column holds the atlas name.
    df_long = pd.melt(df, id_vars='dataset', value_vars=['NeuroPAL', 'Consolidated'], var_name='Accuracy', value_name='Value')

    # One colour per atlas. Drawing each plot twice with a single colour only
    # left the second colour visible.
    atlas_colors = {'NeuroPAL': 'skyblue', 'Consolidated': 'lightcoral'}

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.boxplot(ax=ax, x='Accuracy', y='Value', hue='Accuracy', data=df_long, palette=atlas_colors, width=0.4, legend=False)
    sns.scatterplot(ax=ax, x='Accuracy', y='Value', hue='Accuracy', data=df_long, palette=atlas_colors, legend=False)

    t, prob = stats.ttest_rel(np.asarray(df['NeuroPAL']), np.asarray(df['Consolidated']))

    ax.set_ylabel('Assignment accuracy')
    ax.set_xlabel('Atlas used')
    ax.set_ylim((0,1))

    print(t)
    print(prob)

    plt.show()

plot_accuracies_atlas_compare([NP_dataset, chaud_dataset, Yem_dataset, old_FOCO_dataset, FOCO_dataset], accs_NP, accs_full)

In [None]:
# Read one raw composite TIFF and reverse the order of its axes
# (np.transpose with no axes argument), then report the resulting shape.
raw_file = '/Users/danielysprague/foco_lab/data/NP_Ray/20221215-20-02-49/full_comp.tif'
data = np.transpose(skio.imread(raw_file))

print(data.shape)

## Figure 4

In [None]:
%matplotlib qt

def _set_bar_thickness(ax, thickness):
    """Give every horizontal bar on `ax` the same thickness, keeping it centred."""
    for patch in ax.patches:
        diff = patch.get_height() - thickness
        patch.set_height(thickness)
        # Shift by half the removed/added height so the bar stays centred.
        patch.set_y(patch.get_y() + diff * .5)


def plot_num_neur_heatmap(atlas, pairs, total_dataset):
    """Three-panel summary of labelling coverage and positional variability.

    Panels (left to right):
        0. Fraction of datasets in which each neuron has a ground-truth label.
        1. Heatmap of the standard deviation of pairwise distances, with one
           outlined box per ganglion.
        2. Column-wise average of that heatmap per neuron.

    Args:
        atlas: object whose ``df`` attribute has 'ID' and 'ganglion' columns.
        pairs: pairwise-distance data forwarded to ``analyze_pairs``.
        total_dataset: dataset collection forwarded to ``get_neur_nums``.
    """
    neurons, num = get_neur_nums(total_dataset, atlas)

    fig, axs = plt.subplots(1,3)
    sns.set(style='white')

    bar_width = 0.8

    neur_df = atlas.df[['ID', 'ganglion']]

    dict_df = pd.DataFrame(list(neurons.items()), columns = ['ID', 'num'])
    dict_df['frac'] = dict_df['num']/num

    merged = pd.merge(neur_df, dict_df, on='ID') #this will preserve the order of neurons from neur_df which is sorted by ganglion and then distance along x axis

    # --- panel 0: labelling coverage -------------------------------------
    sns.barplot(ax=axs[0], y='ID', x='frac', hue='ganglion', data=merged, orient='h')
    sns.despine(ax=axs[0], bottom=True, top=True,left=True, right=True)
    axs[0].set_ylabel('Neuron IDs')
    axs[0].set_xlabel('Fraction of datasets with ground truth labeled neuron')
    axs[0].tick_params(labelleft=False, left=False)
    _set_bar_thickness(axs[0], bar_width)
    axs[0].invert_xaxis()
    axs[0].get_legend().remove()

    # --- panel 1: std heatmap with ganglion outlines -----------------------
    _, _, total_std_heatmap = analyze_pairs(pairs, neur_df, num)

    axs[1].set_facecolor('black')
    sns.heatmap(data=total_std_heatmap, ax=axs[1], cmap='Reds', cbar = False)

    for ganglion in neur_df['ganglion'].unique():
        # First and last index label of this ganglion.
        # NOTE(review): used as heatmap positions, so this assumes neur_df has
        # a positional 0..n-1 index and each ganglion is contiguous - confirm.
        members = neur_df.index[neur_df['ganglion'] == ganglion]
        lo = members[0]
        # Heatmap cell k spans [k, k+1], so the box must end at last + 1 to
        # include the ganglion's final neuron.
        hi = members[-1] + 1

        axs[1].add_patch(plt.Rectangle((lo, lo), hi - lo, hi - lo, fill=False, edgecolor='black', lw=3))
        # Put the label on whichever side of the box has more room.
        if lo<len(neur_df)/2:
            axs[1].text(hi+1, (lo + hi) / 2, ganglion, color='black', ha='left', va='center')
        else:
            axs[1].text(lo-2, (lo + hi) / 2, ganglion, color='black', ha='right', va='center')

    axs[1].set_title('Standard deviation of pairwise distances')
    axs[1].tick_params(which='both', bottom=False, left=False,labelbottom=False, labelleft=False)  # Hide tick labels

    # --- panel 2: per-neuron average of the heatmap columns ----------------
    avg_std = np.sum(total_std_heatmap, axis=0)/total_std_heatmap.shape[0]

    new_df = neur_df.copy()
    new_df['std'] = avg_std

    sns.barplot(ax=axs[2], data=new_df, y='ID', x='std', hue='ganglion',orient='h')
    sns.despine(ax=axs[2], bottom=True, top=True,left=True, right=True)
    axs[2].set_ylabel('')
    axs[2].set_xlabel('Cumulative std in um')
    axs[2].tick_params(labelleft=False, left=False)
    axs[2].get_legend().remove()
    _set_bar_thickness(axs[2], bar_width)

    plt.show()

plot_num_neur_heatmap(atlas, pair_tot, tot_dataset)

In [None]:
# Relate synaptic weight to the mean / standard deviation of the physical
# distance between the two neurons of each connection.
synap_df = pd.read_csv('/Users/danielysprague/foco_lab/data/synaptic_connecs.csv')

min_measurements = 5   # pairs with fewer distance measurements are skipped

pair_records = []
for _, row in synap_df.iterrows():
    source = row['Source']
    target = row['Target']

    if source == target:
        continue

    # Names that differ only in their last character are treated as a
    # left/right pair.
    is_lr_pair = source[:-1] == target[:-1]

    # pair_tot keys put the alphabetically smaller name first.
    pair = '-'.join(sorted((source, target)))
    if pair not in pair_tot:
        continue

    dists = pair_tot[pair]
    if len(dists) < min_measurements:
        continue

    # The synapse type is read straight from the row rather than stored in a
    # global called `type`, which would shadow the builtin for later cells.
    pair_records.append([row['Weight'], is_lr_pair, np.mean(dists), np.std(dists), row['Type']])

df = pd.DataFrame(pair_records, columns= ['Synaptic weight', 'LR pair', 'mean_dist', 'std_dist', 'Synapse type'])

# Both panels show connections between non-homologous neurons only.
non_lr = df[df['LR pair']==False]

fig, axs = plt.subplots(1,2)

sns.scatterplot(ax=axs[0], data=non_lr, x='Synaptic weight', y='mean_dist', hue='Synapse type')
axs[0].set_xlabel('synaptic weight between pair of neurons')
axs[0].set_ylabel('mean distance in um')
sns.scatterplot(ax=axs[1], data=non_lr, x='Synaptic weight', y='std_dist', hue='Synapse type')
axs[1].set_xlabel('synaptic weight between pair of neurons')
axs[1].set_ylabel('standard deviation of distance in um')
plt.show()

    


## Lineage plots

# Keep pairs of distinct neurons whose 'birth_dists' value is below 800.
# The second mask is built on the already-filtered frame, so the selection
# does not rely on pandas re-aligning a mask that belongs to lin_dist_df.
lin_dist_plot = lin_dist_df.loc[lin_dist_df['neuron1']!=lin_dist_df['neuron2']]
lin_dist_plot = lin_dist_plot.loc[lin_dist_plot['birth_dists']<800]

fig = plt.figure()

fontsize=16

# Arrays reused by the regression cell that follows.
tree_dists = np.asarray(lin_dist_plot['tree_dists'])
std = np.asarray(lin_dist_plot['Std_dist'])

sns.regplot(data=lin_dist_plot, x='tree_dists',y='Std_dist')
# Reverse the x axis so large tree distances appear on the left.
plt.gca().invert_xaxis()
plt.title('Standard deviation of pairwise physical distance and pairwise distance in cell lineage', fontsize=fontsize)
plt.ylabel('Standard deviation in um', fontsize=fontsize)
plt.xlabel('Lineal tree distance', fontsize=fontsize)

plt.show()

In [None]:
from sklearn.linear_model import LinearRegression

# Least-squares line of distance std against lineage-tree distance, using
# the arrays prepared by the preceding plotting cell.
x = tree_dists.reshape(-1, 1)
y = std
model = LinearRegression().fit(x, y)

r_sq = model.score(x, y)
print(f'R2 value = {r_sq}')
print(f'y ={model.coef_[0]}*x + {model.intercept_}')


In [None]:
# Keep pairs of distinct neurons whose 'birth_dists' value is below 800.
# The second mask is built on the already-filtered frame, so the selection
# does not rely on pandas re-aligning a mask that belongs to lin_dist_df.
lin_dist_plot = lin_dist_df.loc[lin_dist_df['neuron1']!=lin_dist_df['neuron2']]
lin_dist_plot = lin_dist_plot.loc[lin_dist_plot['birth_dists']<800]

fig = plt.figure()

fontsize=16

# Arrays reused by the regression cell that follows.
tree_dists = np.asarray(lin_dist_plot['tree_dists'])
mean = np.asarray(lin_dist_plot['Mean_dist'])

sns.regplot(data=lin_dist_plot, x='tree_dists',y='Mean_dist')
# Reverse the x axis so large tree distances appear on the left.
plt.gca().invert_xaxis()
plt.title('Mean of pairwise physical distance and pairwise distance in cell lineage', fontsize=fontsize)
plt.ylabel('Mean distance in um', fontsize=fontsize)
plt.xlabel('Lineal tree distance', fontsize=fontsize)

plt.show()

In [None]:
# Least-squares line of mean distance against lineage-tree distance, using
# the arrays prepared by the preceding plotting cell.
x = tree_dists.reshape(-1, 1)
y = mean
model = LinearRegression().fit(x, y)

r_sq = model.score(x, y)
print(f'R2 value = {r_sq}')
print(f'y ={model.coef_[0]}*x + {model.intercept_}')


In [None]:
# Keep pairs of distinct neurons whose 'birth_dists' value is below 800.
# The second mask is built on the already-filtered frame, so the selection
# does not rely on pandas re-aligning a mask that belongs to lin_dist_df.
lin_dist_plot = lin_dist_df.loc[lin_dist_df['neuron1']!=lin_dist_df['neuron2']]
lin_dist_plot = lin_dist_plot.loc[lin_dist_plot['birth_dists']<800]

fig = plt.figure()

fontsize=16

# Arrays reused by the regression cell that follows.
birth_dists = np.asarray(lin_dist_plot['birth_dists'])
std = np.asarray(lin_dist_plot['Std_dist'])

sns.regplot(data=lin_dist_plot, x='birth_dists',y='Std_dist')
plt.title('Standard deviation of pairwise physical distance and pairwise distance in cell lineage', fontsize=fontsize)
plt.ylabel('Standard deviation in um', fontsize=fontsize)
plt.xlabel('Time of birth of last shared parent cell between pair in min', fontsize=fontsize)

plt.show()

In [None]:
# Least-squares line of distance std against 'birth_dists', using the
# arrays prepared by the preceding plotting cell.
x = birth_dists.reshape(-1, 1)
y = std
model = LinearRegression().fit(x, y)

r_sq = model.score(x, y)
print(f'R2 value = {r_sq}')
print(f'y ={model.coef_[0]}*x + {model.intercept_}')


In [None]:
# Keep pairs of distinct neurons whose 'birth_dists' value is below 800.
# The second mask is built on the already-filtered frame, so the selection
# does not rely on pandas re-aligning a mask that belongs to lin_dist_df.
lin_dist_plot = lin_dist_df.loc[lin_dist_df['neuron1']!=lin_dist_df['neuron2']]
lin_dist_plot = lin_dist_plot.loc[lin_dist_plot['birth_dists']<800]

fig = plt.figure()

fontsize=16

# Arrays reused by the regression cell that follows.
birth_dists = np.asarray(lin_dist_plot['birth_dists'])
mean = np.asarray(lin_dist_plot['Mean_dist'])

sns.regplot(data=lin_dist_plot, x='birth_dists',y='Mean_dist')
plt.title('Mean of pairwise physical distance and pairwise distance in cell lineage', fontsize=fontsize)
plt.ylabel('Mean distance in um', fontsize=fontsize)
plt.xlabel('Time of birth of last shared parent cell between pair in min', fontsize=fontsize)

plt.show()

In [None]:
# Least-squares line of mean distance against 'birth_dists', using the
# arrays prepared by the preceding plotting cell.
x = birth_dists.reshape(-1, 1)
y = mean
model = LinearRegression().fit(x, y)

r_sq = model.score(x, y)
print(f'R2 value = {r_sq}')
print(f'y ={model.coef_[0]}*x + {model.intercept_}')


In [None]:
# Keep pairs of distinct neurons whose 'birth_dists' value is below 800.
# The second filter is applied to the result of the first; restarting from
# lin_dist_df would silently bring the self-pairs back in.
lin_dist_plot = lin_dist_df.loc[lin_dist_df['neuron1']!=lin_dist_df['neuron2']]
lin_dist_plot = lin_dist_plot.loc[lin_dist_plot['birth_dists']<800]

fig = plt.figure()

fontsize=16

birth_dists = np.asarray(lin_dist_plot['birth_dists'])
# Not used by the plot below; kept as a global in case a later cell reads it.
norm_std = np.asarray(lin_dist_plot['Norm_std'])

sns.scatterplot(data=lin_dist_plot, x='birth_dists',y='Mean_dist')
plt.ylabel('Mean distance in um', fontsize=fontsize)
plt.xlabel('Time of birth of last shared parent cell between pair in min', fontsize=fontsize)
plt.show()

In [None]:
# Ganglion-pair categories in display order: within-ganglion pairs first,
# then between-ganglion pairs.
# NOTE(review): the last entry is spelled 'Postbulb' while every other entry
# uses 'PostBulb' - confirm which spelling the data uses.
order = [
    # within one ganglion
    'AntBulb-AntBulb',
    'Ant-Ant',
    'Dors-Dors',
    'Lat-Lat',
    'Vent-Vent',
    'RVG-RVG',
    'PostBulb-PostBulb',
    # between two ganglia
    'Ant-AntBulb',
    'Ant-Dors',
    'Ant-Lat',
    'Dors-Lat',
    'Lat-Vent',
    'Lat-RVG',
    'Lat-PostBulb',
    'RVG-Vent',
    'PostBulb-Vent',
    'Postbulb-RVG',
]

## OLD Color statistics analysis

In [None]:
def get_color_stats(folder):
    """Collect colour statistics over every ``.nwb`` file directly inside `folder`.

    For each file the image returned by ``get_nwb_neurons`` is min-max
    normalised per channel. Each neuron then gets the per-channel median of a
    small patch around its (x, y, z) position, and a 32-bin histogram of the
    first three channels is accumulated over all files.

    Args:
        folder: directory to scan (not recursive).

    Returns:
        Tuple ``(rgbhist, neur_colors)``:
            rgbhist: (32, 3) array; each column is the pooled histogram of one
                channel over the range [0, 1], normalised to sum to 1.
            neur_colors: dict mapping each non-empty neuron ID to a list of
                length-3 arrays ``[Rnorm, Gnorm, Bnorm]``, one per occurrence.
    """
    rgbhist = np.zeros((32,3))

    neur_colors = {}

    for file in os.listdir(folder):
        if not file.endswith('.nwb'):
            continue

        print(file)

        # presumably blobs is a DataFrame with x/y/z/ID columns and rgb_data a
        # 4-D (x, y, z, channel) array - TODO confirm against get_nwb_neurons
        blobs, rgb_data = get_nwb_neurons(folder+'/'+file)

        channel_min = np.min(rgb_data, axis=(0,1,2))
        channel_max = np.max(rgb_data, axis=(0,1,2))
        color_norm = (rgb_data - channel_min) / (channel_max - channel_min)

        blobs[['Rnorm', 'Gnorm','Bnorm']] = np.nan

        for i, row in blobs.iterrows():
            # NOTE(review): the upper slice bounds are exclusive, so the patch
            # covers x-2..x+1, y-2..y+1 and z-1..z and never the last voxel
            # along an axis - confirm whether a symmetric window was intended.
            colors = color_norm[max(row['x']-2,0):min(row['x']+2,rgb_data.shape[0]-1),max(row['y']-2,0):min(row['y']+2,rgb_data.shape[1]-1),max(row['z']-1,0):min(row['z']+1,rgb_data.shape[2]-1),:]

            # One row per voxel, one column per channel.
            flat_colors = colors.reshape(-1, colors.shape[-1])

            # Median over the voxels of the patch for each channel; indexing
            # flat_colors[0] would pick the first voxel instead of channel 0.
            blobs.loc[i, 'Rnorm'] = np.median(flat_colors[:, 0])
            blobs.loc[i, 'Gnorm'] = np.median(flat_colors[:, 1])
            blobs.loc[i, 'Bnorm'] = np.median(flat_colors[:, 2])

        # Only neurons that carry an ID contribute to the per-neuron colours.
        for _, row in blobs[blobs['ID']!=''].iterrows():
            colors = np.asarray(row[['Rnorm', 'Gnorm', 'Bnorm']])
            neur_colors.setdefault(row['ID'], []).append(colors)

        image = np.asarray(color_norm)
        im_flat = image.reshape(-1, image.shape[-1])

        for channel in range(3):
            counts, _ = np.histogram(im_flat[:,channel], bins=32, range=(0,1))
            rgbhist[:,channel] += counts

    # Turn pooled counts into per-channel probabilities.
    for channel in range(3):
        rgbhist[:,channel] = rgbhist[:,channel]/np.sum(rgbhist[:,channel])

    return rgbhist, neur_colors

In [None]:
# Colour histograms and per-neuron colours for each dataset folder.
foco_rgb, foco_colors = get_color_stats('/Users/danielysprague/foco_lab/data/NWB_Ray')
yem_rgb, yem_colors = get_color_stats('/Users/danielysprague/foco_lab/data/Yemini_NWB')
# NOTE(review): this call points at the top-level data folder; get_color_stats
# only reads .nwb files that sit directly in the given folder (no recursion) -
# confirm this is the intended location of the original FOCO files.
foco_og_rgb, foco_og_colors = get_color_stats('/Users/danielysprague/foco_lab/data/')


In [None]:
def plot_histograms(dataset_rgbs, labels):
    """Overlay the per-channel color histograms of several datasets.

    Draws a 1x3 grid of log-scaled bar charts (red, green, blue) that share
    the y axis, with one semi-transparent bar series per dataset, then shows
    the figure.

    Parameters
    ----------
    dataset_rgbs : sequence of array-like
        One 2-D array per dataset, indexed as ``[bin, channel]`` with the
        red, green and blue channels in columns 0, 1 and 2. The bins are
        assumed to evenly partition the range [0, 1], as produced by
        ``np.histogram(..., range=(0, 1))``.
    labels : sequence of str
        Legend label for each dataset, in the same order as ``dataset_rgbs``.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    sns.set(style="white")
    fig, axs = plt.subplots(1, 3, sharey=True)

    # Look the palette up once; indices wrap so that passing more datasets
    # than the palette has colors does not raise an IndexError.
    palette = sns.color_palette("pastel")

    for channel, ax in enumerate(axs):
        for j, dataset in enumerate(dataset_rgbs):
            dataset = np.asarray(dataset)
            # Derive the binning from the data instead of hard-coding 32 bins.
            n_bins = dataset.shape[0]
            left_edges = np.arange(n_bins) / n_bins
            # align='edge' makes each bar span its own bin [k/n, (k+1)/n);
            # centering bars on the left edges would shift the whole
            # histogram half a bin to the left.
            ax.bar(
                left_edges,
                dataset[:, channel],
                alpha=0.5,
                align='edge',
                width=1 / n_bins,
                color=palette[j % len(palette)],
                log=True,
                label=labels[j],
            )
        # Avoid matplotlib's "no labelled artists" warning when nothing was drawn.
        if len(dataset_rgbs) > 0:
            ax.legend()

    axs[0].set_title('Red histogram')
    axs[0].set_xlabel('Normalized color')
    axs[0].set_ylabel('log probability')
    axs[1].set_title('Green histogram')
    axs[1].set_xlabel('Normalized color')
    axs[2].set_title('Blue histogram')
    axs[2].set_xlabel('Normalized color')

    plt.tight_layout()
    plt.show()

# Compare the FOCO and Yemini per-channel histograms side by side.
# (foco_og_rgb is computed in an earlier cell but is not plotted here.)
plot_histograms([foco_rgb, yem_rgb], ['FOCO', 'Yemini'])

In [None]:
# Compute color-discrimination results for the FOCO and Yemini NWB folders using
# get_color_discrim (imported from process_file).
# NOTE(review): the meaning of the second argument (6) is not visible in this
# notebook — confirm against process_file.get_color_discrim and consider naming it.
# NOTE(review): absolute, machine-specific paths — consider a configurable data directory.
color_discrim_FOCO = get_color_discrim('/Users/danielysprague/foco_lab/data/NWB_Ray', 6)
color_discrim_Yemini = get_color_discrim('/Users/danielysprague/foco_lab/data/Yemini_NWB', 6)


In [None]:
# Summarize color discrimination as one row per entry of each dataset mapping:
# the mean of the entry's values plus an identifier derived from its key.
# NOTE(review): color_discrim_FOCO_og and color_discrim_NP are not defined in the
# neighbouring cells — confirm they are created earlier in the notebook so this
# cell survives Restart & Run All.
datasets = [color_discrim_FOCO, color_discrim_Yemini, color_discrim_FOCO_og, color_discrim_NP]

# Collect every row first and build the frame once: growing a DataFrame with
# pd.concat inside a loop is quadratic, and concatenating onto an empty frame
# degrades the numeric column to object dtype.
color_records = [
    {
        'avg_col_discrim': np.mean(value),
        # Drop the last four characters of the key
        # (presumably a file extension such as '.nwb' — TODO confirm).
        'identifier': key[:-4],
    }
    for dataset in datasets
    for key, value in dataset.items()
]

# Passing the columns explicitly keeps the column order and yields a
# well-formed (empty) frame when there are no records.
color_df = pd.DataFrame(color_records, columns=['avg_col_discrim', 'identifier'])

print(color_df)

In [None]:
# Plot the color-discrimination results of the four datasets with
# plot_color_discrim (imported from visualization), labelled in the same order.
# NOTE(review): color_discrim_FOCO_og and color_discrim_NP are not defined in the
# neighbouring cells — confirm they are created earlier in the notebook.
plot_color_discrim([color_discrim_FOCO, color_discrim_Yemini, color_discrim_FOCO_og, color_discrim_NP], ['FOCO', 'Yemini', 'FOCO_og', 'NP'])