In [1]:
from utilities import *
import matplotlib.pyplot as plt

In [2]:
############### User Configuration ###############
##################################################

datadir = "." # All ZDA files in this directory + subdirectories are loaded
selected_filename = "05_01_05" # focus for subsequent analysis

# Spatial area to investigate
y_range = [45,-2] #[10, 40]
x_range = [0,-2] #[41, 79]

# Temporal area to investigate
time_window = [40,-1]

In [3]:
############## Driver script: begin ##############
##################################################

# Load data
processed = [] # to avoid re-processing later
data_loader = DataLoader()
data_loader.load_all_zda(data_dir=datadir + "/zda_targets")


# Select data of interest        
selected_data = data_loader.select_data_by_keyword(selected_filename)

# Clip to the time range 20 ms onward to get rid of the camera "foot"
selected_data.clip_data(t_range=[40,-1])

raw_data, meta, rli = selected_data.get_data(), selected_data.get_meta(), selected_data.get_rli()
        
        

Number of files loaded: 0


AttributeError: 'NoneType' object has no attribute 'clip_data'

In [None]:
# view frames
fig, axes = plt.subplots(2, 2)
print(raw_data.shape)
axes[1][0].imshow(raw_data[0,:,:,0],cmap='jet')
axes[1][1].imshow(raw_data[0,:,:,-1],cmap='jet')
axes[0][0].imshow(raw_data[1,:,:,0],cmap='jet')
axes[0][1].imshow(raw_data[4,:,:,-1],cmap='jet')
plt.show()

In [None]:
# view a trace
tr = Tracer()
tr.plot_trace(raw_data, 40, 40, meta['interval_between_samples'], trial=0)

In [None]:
# Run this cell at most once per ZDA load
# Need to subtract off the low-frequency voltage drift. First-order correction
tr.correct_background(meta, raw_data)

#full trace
tr.plot_trace(raw_data[:,:,:,:], 
           40, 
           40, 
           meta['interval_between_samples'], 
           trial=0)


In [None]:
# spatial and temporal filtering to handle random noise

sp = SignalProcessor()

filtered_data = sp.filter_temporal(meta, raw_data)
filtered_data = sp.filter_spatial(meta, filtered_data)

tr.plot_trace(filtered_data[:,:,:,:],
           40, 
           40,
           meta['interval_between_samples'], 
           trial=0)


In [None]:
# Examine SNR

# Spatial area to investigate
y_range = [45,-2] #[10, 40]
x_range = [0,-2] #[41, 79]

trial = filtered_data[0, x_range[0]:x_range[1], 
                    y_range[0]:y_range[1], :] 
trials = filtered_data[:, x_range[0]:x_range[1], 
                     y_range[0]:y_range[1], :] 



asnr = AnalyzerSNR(trial)

snr = asnr.get_snr(plot=True)

print("max SNR:", np.max(snr), "min SNR:", np.min(snr))

In [None]:
# impose snr cutoff

asnr.cluster_on_snr(plot=True)
"""
snr_percentile_cutoff = 0.7
k_clusters = 3
snr_cutoff = np.percentile(snr, snr_percentile_cutoff * 100)

mask = (snr >= snr_cutoff).astype(np.float)

# masked image: reasonability check
plt.imshow(snr * mask, cmap='jet', interpolation='nearest')
plt.show()

# 1-D K-means clustering of SNR groups
km = KMeans(n_clusters=k_clusters+1).fit(snr.reshape(-1,1)) # +1 for the masked 0's

clustered = np.array(km.labels_).reshape(snr.shape) + 1
clustered = clustered.astype(np.float)

plt.imshow(clustered * mask, cmap='viridis', interpolation='nearest')
plt.show()
"""

In [None]:
# SNR by cluster
avg_snr_by_cluster = [np.average(snr[np.where(clustered==i)[0]]) for  i in range(1, k_clusters+2)]
print(avg_snr_by_cluster)

cluster_indices_by_snr = np.argsort(np.array(avg_snr_by_cluster)) + 1
highest_snr_cluster = cluster_indices_by_snr[-1]
print(cluster_indices_by_snr)
print("highest_snr_cluster =", highest_snr_cluster)

In [None]:
# Examine some higher-SNR pixels
n_samples = 5
mask = (snr >= snr_cutoff).astype(np.float)

# Select the pixels in the highest SNR cluster, above SNR cutoff
argmaxes =  np.where(clustered * mask == highest_snr_cluster)
max_samples = argmaxes[0].shape[0]
for i in range(0, max_samples, int(max_samples / n_samples) ):

    x_max = argmaxes[0][i]
    y_max = argmaxes[1][i]
    print("Pixel at (" + str(x_max), ",", str(y_max) + ")")
    plot_trace(trials[:,:,:,time_window[0]:time_window[1]], 
               x_max, 
               y_max, 
               meta['interval_between_samples'], 
               trial=0)

    mask[x_max, y_max] *= 5 # highlight

plt.imshow(clustered * mask, cmap='jet', interpolation='nearest')
plt.show()


In [None]:
# Spike sorting on the template-matched images

# We are looking for two patterns of shapes:
#   1) sub-threshold ESPSs 
#         - lower amplitude
#         - wider half-width
#   2) spikes
#         - higher amplitude
#         - narrow half-width

k_2d_clusters = 4 # choose via silhouette coefficient

trial = filtered_data[0, x_range[0]:x_range[1], 
                    y_range[0]:y_range[1], :] 
trials = filtered_data[:, x_range[0]:x_range[1], 
                     y_range[0]:y_range[1], :] 

# Let's do 1-D clustering by max amplitude
#   looking, for now, only in the time frame near stim and highest-SNR cluster

# Select the pixels in the highest SNR cluster, above SNR cutoff
px_selector = np.zeros(clustered.shape)
clusters_selected = cluster_indices_by_snr[-2:]
for c in clusters_selected:
    px_selector += (clustered * mask == c)

argmaxes =  np.where(px_selector > 0)
features = np.zeros((argmaxes[0].shape[0], 3))  # (max amp, half-width, cluster index)
i_filled = 0
for i in range(argmaxes[0].shape[0]):

    x_max = argmaxes[0][i]
    y_max = argmaxes[1][i]

    window = trial[x_max,
                   y_max,
                   time_window[0]:time_window[1]]


    # Calculate width at half-max, assuming min is zero 
    #   (valid assumption due to our fitted lin/exp correction)
    hm = features[i,0] / 2.0
    arg_max = np.argmax(window)

    fwhm = get_half_width(arg_max, window)

    # invalid spike, do not store
    if fwhm is None:
        continue
    
    features[i_filled,0] = np.max(window)
    features[i_filled,1] = fwhm
    features[i_filled,2] = clustered[x_max, y_max]
    i_filled += 1

    print("Pixel at (" + str(x_max), 
          ",", 
          str(y_max) + ")\n\t max amplitude:", 
          features[i,0], 
          "at:",
          arg_max,
          "\n\t half width:", 
          features[i,1])
    #plot_trace(trial[:,:,time_window[0]:time_window[1]],
    #           x_max,
    #           y_max,
    #           meta['interval_between_samples'] )

features = features[:i_filled+1, :]

# 2-D K-means clustering on features (max amp, half width)
label = KMeans(n_clusters=k_2d_clusters).fit_predict(features)

for i in range(k_2d_clusters):
    filtered_label = features[label == i]

    plt.scatter(filtered_label[:,0] , filtered_label[:,1])

plt.title("Cluster on (Peak Amplitude, FWHM)")
plt.ylabel("Spike Width at Half Maximum")
plt.xlabel("Spike Max Amplitude")
plt.show()

# plot by SNR cluster
for c in clusters_selected:
    filtered_label = features[features[:, 2] == c]

    plt.scatter(filtered_label[:,0] , filtered_label[:,1])

plt.title("Plotted by SNR cluster")
plt.ylabel("Spike Width at Half Maximum")
plt.xlabel("Spike Max Amplitude")
plt.show()


In [None]:
# Silhouette analysis

print("Silhouette score:", silhouette_score(features, label))

# Instantiate a scikit-learn K-Means model
model = KMeans(random_state=0)

# Instantiate the KElbowVisualizer with the number of clusters and the metric 
visualizer = KElbowVisualizer(model, k=(2,6), metric='silhouette', timings=False)

# Fit the data and visualize
visualizer.fit(features)    
visualizer.poof()   