In [None]:
import numpy as np
from smr import File
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML


In [None]:
# Use matplotlib's interactive "notebook" backend for all figures below.
# NOTE(review): this backend targets the classic Notebook UI; under JupyterLab
# `%matplotlib widget` (ipympl) is the usual replacement — confirm the target UI.
%matplotlib notebook

### Source .smr full file path:

In [None]:
# Absolute path of the .smr recording analysed in this notebook.
# NOTE(review): machine-specific hard-coded path; consider reading it from an
# environment variable or a DATA_DIR constant so the notebook runs elsewhere.
f_name = '/mnt/papers/Herzfeld_Nat_Neurosci_2018/raw_data/2010_Adapt/Buckley_12deg/B091608/B091608_1218_Adapt.smr'
# Alternative recording:
# f_name = '/mnt/data/kkarbasi/SimpleSpike-Felix/Felix 2006.09.06/Felix 2006.09.06 1313 List.smr'

In [None]:
# Open the recording and read its channels; later cells access them through
# smr_content.get_channel(i) and smr_content.channels.
# NOTE(review): presumably read_channels() loads all channel data eagerly — confirm
# in the `smr` module if memory is a concern.
smr_content = File(f_name)
smr_content.read_channels()

In [None]:
# Channels 0, 1, 2 as used throughout this notebook: voltage, horizontal eye,
# vertical eye (fetched in that order).
voltage_chan, HE_chan, VE_chan = [smr_content.get_channel(idx) for idx in (0, 1, 2)]


### Spike detection and `d_voltage` timeseries plots


In [None]:
from sorting import sorter

# Build and run the project spike detector on the raw voltage channel.
# Later cells read ss.d_voltage and ss.dt from the resulting object.
raw_voltage = voltage_chan.data
sample_dt = voltage_chan.dt
ss = sorter.SpikeDetector(voltage=raw_voltage, dt=sample_dt)
ss.run()

In [None]:
# Analysis window: the first 1,000,000 samples of ss.d_voltage. `prang` is reused
# by the GMM and peak-finding cells below.
prang = slice(0, 1000000)

trace_fig, trace_ax = plt.subplots()
trace_ax.plot(ss.d_voltage[prang])


In [None]:
from sklearn.mixture import GaussianMixture

# Fit a 6-component 1-D Gaussian mixture to the amplitudes of d_voltage within
# `prang`, then label every sample with its most likely component.
# covariance_type='tied' makes all components share one covariance.
# random_state is fixed so the fit, and therefore the component picked as the
# "spike" cluster below, is reproducible across kernel restarts.
d_voltage_seg = ss.d_voltage[prang]
d_voltage_col = d_voltage_seg.reshape(-1, 1)  # sklearn expects (n_samples, n_features)

gmm = GaussianMixture(6, covariance_type='tied', random_state=0).fit(d_voltage_col)
cluster_labels = gmm.predict(d_voltage_col).reshape(d_voltage_seg.shape)

In [None]:
# The GMM component with the largest mean is treated as the "spike" cluster;
# its samples are overlaid in red on the d_voltage trace.
spikes_cluster = np.argmax(gmm.means_)
in_spikes_cluster = cluster_labels == spikes_cluster
spike_sample_idx = np.flatnonzero(in_spikes_cluster)
d_voltage_seg = ss.d_voltage[prang]

plt.figure()
plt.plot(d_voltage_seg)
plt.plot(spike_sample_idx, d_voltage_seg[in_spikes_cluster], '.r')

In [None]:
from scipy.signal import find_peaks

# Locate spike peaks: local maxima of d_voltage restricted to samples that the
# GMM assigned to the spike cluster.
# Non-cluster samples are masked to -inf in place rather than dropped. Dropping
# them made samples from temporally distant spikes neighbours inside find_peaks,
# which created spurious or missed peaks at the seams between segments.
d_voltage_seg = ss.d_voltage[prang]
in_spikes_cluster = cluster_labels == spikes_cluster
print(d_voltage_seg[in_spikes_cluster].shape)

masked_d_voltage = np.where(in_spikes_cluster, d_voltage_seg, -np.inf)
peaks = find_peaks(masked_d_voltage)
peak_times = peaks[0]                    # sample indices relative to prang
peak_values = d_voltage_seg[peak_times]
print(peak_times.shape)

plt.figure()
plt.plot(d_voltage_seg)
plt.plot(peak_times, peak_values, '.r')

In [None]:
# Dump the fitted GMM parameters: component means, a 'cov' header, then the
# covariance array.
for item in (gmm.means_, 'cov', gmm.covariances_):
    print(item)

In [None]:
# Inspect how many grid points a fixed 2e5 step gives over the d_voltage range.
d_voltage_seg = ss.d_voltage[prang]
lo, hi = d_voltage_seg.min(), d_voltage_seg.max()
x = np.arange(lo, hi, 2e5)
x.shape



In [None]:
from scipy.stats import norm
from kaveh.plots import axvlines

# Density histogram of d_voltage amplitudes overlaid with each weighted GMM
# component (weight * N(mean, sd)), plus vertical lines at the component means.
d_voltage_seg = ss.d_voltage[prang]
means = gmm.means_.ravel()
# The shape of covariances_ depends on covariance_type. With 'tied' it is a single
# (n_features, n_features) = (1, 1) matrix shared by every component, so
# broadcast it to one sd per component. Zipping the raw flattened array truncated
# the loop to a single component, so only one Gaussian was drawn.
sds = np.broadcast_to(np.sqrt(np.asarray(gmm.covariances_)).ravel(), means.shape)

# Evaluate the pdfs on a fixed number of points across the data range. A fixed
# step in data units (previously 2e5) can produce only one point for small ranges.
x = np.linspace(np.min(d_voltage_seg), np.max(d_voltage_seg), 1000)
gauss_mixt = np.array([w * norm.pdf(x, mu, sd) for mu, sd, w in zip(means, sds, gmm.weights_)])
colors = plt.cm.jet(np.linspace(0, 1, len(gauss_mixt)))

fig, ax = plt.subplots()
for i, gmixt in enumerate(gauss_mixt):
    ax.plot(x, gmixt, label='Gaussian ' + str(i), color=colors[i])

ax.hist(d_voltage_seg.ravel(), bins=256, density=True, color='gray')
axvlines(ax, gmm.means_)
ax.legend()
plt.show()

In [None]:
from sorting import sorter

# t expressed in samples (t / ss.dt, rounded), then the detector's trigger plot for t.
t = 1
t_in_samples = round(t / float(ss.dt))
print(t_in_samples)
ss.plot_triggers(t)

In [None]:
from kaveh.toolbox import resample_to_freq

target_freq = 10000


def _resample_chan(chan_obj):
    """Log which channel is being resampled and return its data at target_freq.

    chan_obj: channel object exposing channel_number, title, units, data and
    ideal_rate (as used by resample_to_freq).
    """
    print('resampling channel {}; title: {}; units: {}'.format(chan_obj.channel_number, chan_obj.title, chan_obj.units))
    return resample_to_freq(chan_obj.data, chan_obj.ideal_rate, target_freq)


vol_resampled = _resample_chan(voltage_chan)
HE_resampled = _resample_chan(HE_chan)
VE_resampled = _resample_chan(VE_chan)



### Channel information:

In [None]:
chans = smr_content.channels

# For each channel, print whichever descriptive attributes it has, followed by
# a separator line.
for chan in chans:
    report = []
    if hasattr(chan, 'comment'):
        report.append("Channel number {}: {}".format(chan.channel_number, chan.comment))
    if hasattr(chan, 'kind'):
        report.append('kind: {}'.format(chan.kind))
    if hasattr(chan, 'units') and hasattr(chan, 'title'):
        report.append("Units: {} ({})".format(chan.title, chan.units))
    if hasattr(chan, 'dt'):
        report.append('dt: {}'.format(chan.dt))
    report.append('--------------------')
    print('\n'.join(report))
    

In [None]:
# Inspect channel 29: its data shape, then the data itself.
raster_chan = smr_content.get_channel(29)
raster_data = raster_chan.data

for item in (raster_data.shape, raster_data):
    print(item)

In [None]:
for resampled in (HE_resampled, VE_resampled, vol_resampled):
    print(resampled.shape)

# Samples 70000..79999 of each resampled channel, stacked vertically, saved to disk.
plot_range = slice(70000, 80000)
fig, axes = plt.subplots(3, 1, figsize=(30, 10))

panels = [
    (vol_resampled, 'Voltage'),
    (HE_resampled, 'Horizontal Eye'),
    (VE_resampled, 'Vertical Eye'),
]
for ax, (series, title) in zip(axes, panels):
    ax.plot(series[plot_range])
    ax.set_title(title)

fig.suptitle('Sampling rate = {} Hz'.format(target_freq))

fig.savefig("./test.jpg")

### Select channel for animated visualization:

In [None]:
# Parameters for the scrolling single-channel animation in the next cell.
# Channel number:
chan_number = 3

# Plot frame length (x axis length, in samples):
frame_l = 120

# Jump interval (in samples) between consecutive frames
frame_interval = 1

# Animation stop sample. It will animate from sample 0 up to this index
maxn = 360
# NOTE(review): a bare `[4000:8000]` statement stood here. It is a SyntaxError
# that aborted this whole cell; kept as a note in case that sample range was
# meant as an alternative animation window.

In [None]:
# Scrolling animation of the first `maxn` samples of channel `chan_number`.
# Each frame shows `frame_l` samples, and successive frames advance by
# `frame_interval` samples.
chan = smr_content.get_channel(chan_number)
chan_data = chan.data

x = chan_data[0:maxn]
t = np.arange(0, x.shape[0])

ymax = np.max(x)
ymin = np.min(x)

fig = plt.figure(figsize=(20, 3))


def animate(n):
    """Redraw the axes for frame n: samples [n*frame_interval, n*frame_interval + frame_l)."""
    # The window start is derived from n directly. The old `counter` list lagged
    # one frame behind (frames 0 and 1 were identical) and grew without bound.
    start = n * frame_interval
    plt.cla()
    line = plt.plot(t[start:start + frame_l], x[start:start + frame_l], color='g')
    plt.title(chan.title)
    plt.ylabel(chan.units)
    plt.xlabel("t ({}s)".format(chan.dt))
    # NOTE(review): (ymax, ymin) puts ymax at the bottom, i.e. the y axis is
    # inverted. Kept as-is; use plt.ylim(ymin, ymax) if that was not intended.
    plt.ylim(ymax, ymin)
    return line


# FuncAnimation needs an integer frame count. Under Python 3, `/` gave a float and
# range(frames) raised TypeError. The +1 includes the final full window, which
# ends at sample len(x).
n_frames = max(1, (x.shape[0] - frame_l) // frame_interval + 1)
anim = animation.FuncAnimation(fig, animate, frames=n_frames, interval=25)

HTML(anim.to_jshtml())

In [None]:
# Animate eye position as a single marker at (horizontal, vertical), stepping
# through every 100th sample from 5000 to 54900.
time_range = np.arange(5000, 55000, 100)

HT_chan = smr_content.get_channel(1)
VT_chan = smr_content.get_channel(2)

horizontal_pos = HT_chan.data
vertical_pos = VT_chan.data

# x / y are also used by the cells that follow.
x = horizontal_pos
y = vertical_pos

# Axis limits: the data range padded by 1 unit on each side.
ymax = np.max(y) + 1
ymin = np.min(y) - 1

xmax = np.max(x) + 1
xmin = np.min(x) - 1

fig = plt.figure(figsize=(6, 6))


def animate_eye(n):
    """Redraw the axes with the eye-position marker at sample index n."""
    plt.cla()
    line = plt.plot(x[n], y[n], color='r', marker='*')
    # Title comes from the eye channels themselves. The old version used `chan`,
    # a leftover from the single-channel animation cell.
    plt.title("{} / {}, Frame number: {}".format(HT_chan.title, VT_chan.title, n))
    plt.ylabel("Vertical Position ({})".format(VT_chan.units))
    plt.xlabel("Horizontal Position ({})".format(HT_chan.units))

    # NOTE(review): both axes are inverted (max passed first). Confirm this is the
    # intended screen orientation.
    plt.ylim(ymax, ymin)
    plt.xlim(xmax, xmin)
    return line


anim = animation.FuncAnimation(fig, animate_eye, frames=time_range, interval=2)

HTML(anim.to_jshtml())


In [None]:
# First index where each eye trace becomes positive (0 if it never does), then
# the length of the horizontal trace.
for value in (np.argmax(x > 0), np.argmax(y > 0), x.shape):
    print(value)

In [None]:
# Horizontal (x) and vertical (y) eye traces over samples 14000..23999.
window = slice(14000, 24000)
window_fig, window_ax = plt.subplots(figsize=(15, 2))
window_ax.plot(x[window])

window_ax.plot(y[window])

In [None]:
# Scratch check: a negative-step slice reverses a numpy array.
a = np.arange(0, 5)
for view in (a, a[::-1]):
    print(view)

In [None]:
# Minimum of x, which the eye-animation cell last set to the horizontal eye trace
# (execution-order dependent).
np.min(x)

In [None]:
# Number of elements in y, which the eye-animation cell last set to the vertical
# eye trace.
np.size(y)

In [None]:
# Plot chan.blocks for the channel selected by chan_number (chan was last set in the
# single-channel animation cell).
# NOTE(review): the meaning of `.blocks` is defined in the smr module, not here.
plt.plot(chan.blocks)

In [None]:
# Load per-neuron event times. Later cells index columns 0 and 1, so this
# presumably has shape (n_events, 2) — TODO confirm the column meanings.
# NOTE(review): absolute machine-specific path.
neuron_0 = np.load('/mnt/papers/Herzfeld_Nat_Neurosci_2018/neurophys_python2/cs_durations/neuron_0_times.npy')

In [None]:
# Raster of rows 5..19 of neuron_0: column 1 values drawn at y=1 and column 0
# values just above at y=1.03, so paired entries can be compared by eye.
# (Presumably onset/offset times — TODO confirm.)
plot_range = slice(5, 20)
events = neuron_0[plot_range]
# Marker heights are sized from the rows actually selected. The old
# `plot_range.stop - plot_range.start` mismatched when neuron_0 had fewer than 20 rows.
n_events = events.shape[0]

plt.figure(figsize=(15, 1))
plt.plot(events[:, 1], np.full(n_events, 1.0), '.')
plt.plot(events[:, 0], np.full(n_events, 1.03), '.')
plt.ylim((0.5, 1.5))

In [None]:
# Per-row difference between columns 1 and 0 of neuron_0 (presumably event
# durations — TODO confirm the column meanings).
durations = np.subtract(neuron_0[:, 1], neuron_0[:, 0])
plt.plot(durations)