In [None]:
from pathlib import Path
import numpy as np
import quickspikes as qs
import quantities as pq
import quickspikes.tools as qst
import quickspikes.intracellular as qsi
import matplotlib.pyplot as plt
from core import first_index
from typing import Iterable, Union, Tuple, Optional, Iterator
# Custom display units for membrane resistance and capacitance; Rm and Cm are
# rescaled to these further down.
MOhm = pq.UnitQuantity("megaohm", pq.ohm * 1e6, symbol="MΩ")
pFarad = pq.UnitQuantity("picofarad", pq.farad * 1e-12, symbol="pF")

In [None]:
from neo import AxonIO
import nbank as nb
# Recording to analyse: neurobank cell id, 1-based epoch (index into the sorted
# .abf files of that cell) and 0-based sweep index within the file.
# Other cells examined previously:
#   "96d3fb36-424b-4e57-b371-15f091587159"
#   "f7be8d05-f5c3-4fca-98ff-371a324a99fa"
#   "4dc9441c-7f8f-4bd9-abd1-e7b42a91756d"  (was paired with epoch = 1)
#   "935feb47-e71e-4226-816b-f7f5b6d31325"
#   "30dc843d-53ae-49b3-ae65-6b12971e2b46"
cell = "42862924-3056-44fd-973e-95abfd42f81a"
epoch = 4
sweep = 1
path = nb.get(cell, local_only=True)
files = sorted(Path(path).glob("*.abf"))
# files[epoch - 1] would silently wrap around for epoch <= 0, so check explicitly.
if not 1 <= epoch <= len(files):
    raise IndexError(f"epoch {epoch} out of range: {len(files)} .abf file(s) in {path}")
ifp = AxonIO(files[epoch - 1])
block = ifp.read_block(lazy=True)
segment = block.segments[sweep]
segment = block.segments[sweep]

In [None]:
# Sampling rate of analog signal 0 of the selected sweep.
segment.analogsignals[0].sampling_rate

In [None]:
# Plot channel 0 (V) and channel 1 (I) of the selected sweep, then run the
# intracellular SpikeFinder on V and overlay the detected spikes.
fig, axes = plt.subplots(3)
V = segment.analogsignals[0].load().squeeze().magnitude   # plain ndarray, units stripped
I = segment.analogsignals[1].load().squeeze().magnitude
axes[0].plot(V)
axes[1].plot(I)
# NOTE(review): the meaning of (50, 350, 5000) is defined by
# quickspikes.intracellular.SpikeFinder — document it here.
detector = qsi.SpikeFinder(50, 350, 5000)
res = detector.calculate_threshold(V, thresh_min=-60)
print(res)
# Only extract spikes when calculate_threshold returned a result.
if res is not None:
    for time, spike in detector.extract_spikes(V, 10):
        axes[0].plot(time, V[time], "ro")   # mark each detection on the trace
        axes[2].plot(spike)                 # overlay the extracted waveform

In [None]:
# Read the stimulus protocol from the ABF file and keep analog signal 1 of the
# selected sweep (presumably the command current — confirm).
protocols = ifp.read_protocol()
Ic = protocols[sweep].analogsignals[1]

In [None]:
# Split the command waveform into runs of constant (integer-truncated) value
# and locate the baseline (== 0), depolarizing (> 0) and hyperpolarizing (< 0)
# steps.
# NOTE(review): later cells index step_start with hard-coded values (3, 4, 5)
# rather than these names.
step_len, step_start, step_val = qst.runlength_encode(Ic.squeeze().astype("i"))
base_step = first_index(lambda x: x == 0, step_val)
# NOTE(review): a missing depolarizing step falls back to 0, which cannot be
# distinguished from a genuine step at index 0 — confirm this is intended.
depol_step = first_index(lambda x: x > 0, step_val) or 0
hypol_step1 = first_index(lambda x: x < 0, step_val)
if hypol_step1 is None:
    # Otherwise `None + 1` below fails with an unhelpful TypeError.
    raise ValueError("protocol has no step with negative command value")
# Assumes the second hyperpolarizing step immediately follows the first — TODO confirm.
hypol_step2 = hypol_step1 + 1
last_step = len(step_val) - 1
(base_step, depol_step, hypol_step1, hypol_step2, last_step)

In [None]:
# Inspect V during protocol step 4 (hard-coded index).
start = step_start[4]
end = start + step_len[4]
VV = V[start:end]
plt.plot(VV)

In [None]:
# Collect V and I during protocol step 3 from every sweep and plot them; the
# next cell averages these lists.
# NOTE: this loop rebinds the globals `segment`, `V` and `I`. Afterwards `V` is
# the last sweep's loaded signal (with units and sampling_period), not the
# ndarray from the earlier plotting cell.
hypol_I = []
hypol_V = []
fig, axes = plt.subplots(2, sharex=True)
for segment in block.segments:
    V = segment.analogsignals[0].load()
    I = segment.analogsignals[1].load()
    start = step_start[3]
    end = start + step_len[3]
    VV = V[start:end].squeeze()
    II = I[start:end].squeeze()
    axes[0].plot(VV)
    axes[1].plot(II)
    hypol_I.append(II)
    hypol_V.append(VV)
    

In [None]:
from chebyfit import fit_exponentials
# Fit a single exponential to the averaged response, from its first sample up
# to the point where it has covered `exponential_decay_thresh` of the drop
# towards its minimum.
exponential_decay_thresh = 0.99
hI = np.mean(hypol_I, axis=0)
VV = np.mean(hypol_V, axis=0)
# `V` is presumably the last sweep's AnalogSignal left by the extraction loop.
sampling_period = V.sampling_period.rescale("ms").magnitude
# Derive the time axis from the data length instead of assuming a 500 ms step,
# so `t` always has exactly as many samples as `VV`.
t = np.arange(VV.size) * sampling_period
plt.plot(t, VV)
i_min = VV.argmin()
thresh = VV[0] - (VV[0] - VV[i_min]) * exponential_decay_thresh
i_thresh = first_index(lambda x: x < thresh, VV[:i_min])
if i_thresh is None:
    # VV[:None] would otherwise silently fit the entire trace.
    raise ValueError("response never crossed the fit threshold before its minimum")
params, est = fit_exponentials(VV[:i_thresh], 1, deltat=V.sampling_period.rescale("ms"), axis=0)
# Alternative: fit the entire trace.
#params, est = fit_exponentials(VV, 1, deltat=V.sampling_period.rescale("ms"), axis=0)
plt.plot(t[:i_thresh], est)
params[0]["amplitude"]

In [None]:
# Membrane properties from the exponential fit: tau and dV come from the
# positive component with the smallest 'rate' value, dI from the averaged
# current trace, then Rm = dV / dI and Cm = tau / Rm.
pos = (params["amplitude"] > 0) & (params["rate"] > 0)
if not np.any(pos):
    # argmin() on an empty selection would raise an opaque ValueError.
    raise ValueError("exponential fit has no component with positive amplitude and rate")
idx = params["rate"][pos].argmin()
tau = params["rate"][pos][idx] * pq.ms
dV = params["amplitude"][pos][idx] * pq.mV
# NOTE(review): assumes the current trace is in pA and that its last 7500
# samples give the level to subtract — confirm against the protocol.
dI = (hI[0] - hI[-7500:].mean()) * pq.pA
Rm = (dV / dI).rescale(MOhm)
Cm = (tau / Rm).rescale(pFarad)
# Window (in samples) over which the fit is checked: min(25 ms, tau / 2).
# Compare plain floats in ms instead of mixing a bare 25 with a Quantity.
tau_ms = float(tau.rescale(pq.ms).magnitude)
n_check = int(min(25.0, tau_ms / 2) / sampling_period)
# RMS difference between data and fit over that window, relative to dV.
mse = np.sqrt(np.mean((VV[:n_check] - est[:n_check])**2)) / dV
n_check, mse, tau, dV, dI, Rm, Cm

In [None]:
# Collect V and I during protocol step 5 from every sweep and plot them.
# NOTE(review): the hypol_* names are reused although the next cell fits
# towards the maximum (argmax) of this response — confirm whether step 5 is
# actually hyperpolarizing.
# NOTE: like the step-3 loop, this rebinds the globals `segment`, `V` and `I`.
hypol_I = []
hypol_V = []
fig, axes = plt.subplots(2, sharex=True)
for segment in block.segments:
    V = segment.analogsignals[0].load()
    I = segment.analogsignals[1].load()
    start = step_start[5]
    end = start + step_len[5]
    VV = V[start:end].squeeze()
    II = I[start:end].squeeze()
    axes[0].plot(VV)
    axes[1].plot(II)
    hypol_I.append(II)
    hypol_V.append(VV)
    

In [None]:
from chebyfit import fit_exponentials
# Fit a single exponential to the averaged step-5 response, from its first
# sample up to the point where it has covered `exponential_decay_thresh` of the
# rise towards its maximum.
exponential_decay_thresh = 0.99
hI = np.mean(hypol_I, axis=0)
VV = np.mean(hypol_V, axis=0)
# `V` is presumably the last sweep's AnalogSignal left by the extraction loop.
sampling_period = V.sampling_period.rescale("ms").magnitude
# Derive the time axis from the data length instead of assuming a 500 ms step,
# so `t` always has exactly as many samples as `VV`.
t = np.arange(VV.size) * sampling_period
plt.plot(t, VV)
i_max = VV.argmax()
thresh = VV[0] + (VV[i_max] - VV[0]) * exponential_decay_thresh
i_thresh = first_index(lambda x: x > thresh, VV[:i_max])
if i_thresh is None:
    # VV[:None] would otherwise silently fit the entire trace.
    raise ValueError("response never crossed the fit threshold before its maximum")
params, est = fit_exponentials(VV[:i_thresh], 1, deltat=V.sampling_period.rescale("ms"), axis=0)
# Alternative: fit the entire trace.
#params, est = fit_exponentials(VV, 1, deltat=V.sampling_period.rescale("ms"), axis=0)
plt.plot(t[:i_thresh], est)
params

In [None]:
# Example spike waveform (int16 samples) used to explore peak/trough timing and
# the gradient.
a_spike = np.array([-1290,  -483,  -136,  -148,  -186,   637,   328,    41,    63,
                    42,   377,   872,   639,   -17,   538,   631,   530,   693,
                    743,  3456,  6345,  5868,  4543,  3087,  1691,   830,   241,
                    -350,  -567,  -996,  -877, -1771, -1659, -1968, -2013, -2290,
                    -2143, -1715, -1526, -1108,  -500,   333,    25,  -388,  -368,
                    -435,  -817,  -858,  -793, -1089,   -16,  -430,  -529,  -252,
                    -3,  -786,   -47,  -266,  -963,  -365], dtype=np.int16)
t_peak = a_spike.argmax()      # sample index of the maximum
t_trough = a_spike.argmin()    # sample index of the minimum
plt.plot(a_spike)
plt.plot(np.gradient(a_spike))

In [None]:
# Resample the example spike with quickspikes' fftresample (second argument is
# presumably the target length, 3x the original) and plot it with its gradient.
smoothed = qst.fftresample(a_spike.reshape((1, a_spike.size)), a_spike.size * 3)[0]
plt.plot(smoothed)
plt.plot(np.gradient(smoothed))

print(t_peak, t_trough)   # indices in the original (not resampled) spike

In [None]:
def rle(arr: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run-length encode a 1D sequence. Partial credit to the R rle function.

    from stackoverflow 1066758

    *arr* may be a numpy array or any sequence accepted by ``np.asarray``
    (e.g. a plain list of numbers, booleans or strings). It is converted to an
    array first, so the elementwise comparison and fancy indexing below also
    work for non-numpy input.

    returns: tuple (runlengths, startpositions, values), one entry per run of
    identical consecutive elements; ``(None, None, None)`` for empty input.

    """
    arr = np.asarray(arr)
    n = len(arr)
    if n == 0:
        return (None, None, None)
    # True wherever an element differs from its successor (string safe).
    y = arr[1:] != arr[:-1]
    # Index of the last element of every run; the final run always ends at n-1.
    i = np.append(np.flatnonzero(y), n - 1)
    z = np.diff(np.append(-1, i))          # run lengths
    p = np.cumsum(np.append(0, z))[:-1]    # run start positions
    return (z, p, arr[i])


def find_onset(spk, dV_thresh=10.0, n_baseline=100, min_rise=20):
    """Returns the onset offset of a spike, relative to the peak time.

    *spk* must be a 1D vector containing all the points up to the peak of the
    spike. At least the first *n_baseline* samples should be before the onset
    of the spike; they are used to collect statistics on the gradient.

    The onset is the start of the last run of at least *min_rise* consecutive
    samples whose gradient exceeds the baseline gradient mean by more than
    *dV_thresh* baseline standard deviations. The return value is
    ``len(spk) - onset``, so the onset sample is ``spk[len(spk) - result]``.

    Raises IndexError if *spk* has no more than *n_baseline* samples, if no
    qualifying run is found, or if the onset falls inside the baseline window.

    """
    # There are still some issues here for fast spikers, which may have ISIs
    # shorter than the pre-peak window. The only real solution right now is to
    # adjust the window accordingly.
    spk = np.asarray(spk)
    if spk.size <= n_baseline:
        # Any onset would have to lie inside the baseline. Fail up front with
        # the documented exception type instead of a ValueError/TypeError from
        # np.gradient or rle on very short input.
        raise IndexError(
            "spike has %d samples; need more than n_baseline=%d" % (spk.size, n_baseline)
        )
    dV = np.gradient(spk)
    baseline = dV[:n_baseline]
    thresh = baseline.mean() + dV_thresh * baseline.std()
    n, p, v = rle(dV > thresh)
    pp = p[(n >= min_rise) & v]
    if pp.size == 0:
        raise IndexError("no rise of at least %d samples above threshold" % min_rise)
    ind = pp[-1]
    if ind < n_baseline:
        raise IndexError("onset occurred in baseline - need more data")
    return spk.size - ind


def find_trough(spk, min_rise=5):
    """Locate the local minimum that follows a spike.

    *spk* must be a 1D vector containing the points after the peak of the
    spike. Returns the index where the gradient first becomes >= 0 and stays
    there for at least *min_rise* consecutive samples.

    Raises IndexError if no such run exists.

    """
    run_lengths, run_starts, run_rising = rle(np.gradient(spk) >= 0)
    long_rising_runs = (run_lengths >= min_rise) & run_rising
    return run_starts[long_rising_runs][0]   # IndexError when nothing qualifies


In [None]:
# Detect spike onsets for three hand-picked peak times in b_recording.
# NOTE(review): b_recording is loaded in a later cell (np.load("../intra_spike.npy")),
# so this cell fails on a fresh top-to-bottom run.
# NOTE(review): n_baseline and dV_thresh are set but not passed; find_onset is
# called with n_baseline=50 and its default dV_thresh.
n_baseline = 100
dV_thresh = 10.0
peak_time = 200   # presumably the number of samples before each peak in the window
spks = qs.peaks(b_recording.astype('d'), [7635, 8412, 33778], peak_time, 100)
for spk in spks:
    plt.plot(spk)
    # Only the samples before the peak are passed to find_onset.
    spk_on = find_onset(spk[:peak_time], n_baseline=50)
    print(peak_time-spk_on)   # onset index within the window
    plt.plot(peak_time-spk_on, spk[peak_time-spk_on], 'ro')

In [None]:
#np.save("../intra_spike_narrow.npy", c_recording)
#print(c_times)

In [None]:
# Run a quickspikes detector over each extracted window of c_recording.
# NOTE(review): c_recording and c_times are not defined in the visible cells
# (the np.save above is commented out); they must come from an earlier session.
spikes = qs.peaks(c_recording.astype("d"), c_times, 200, 700)
det = qs.detector(-20, 100)   # NOTE(review): document what -20 and 100 mean
for spike in spikes:
    # NOTE: rebinds `t`, which earlier cells used as a time axis.
    t = det.send(spike)
    print(t)
    plt.plot(spike)

In [None]:
# Detect spike onsets for the first two spikes of c_recording with a shorter
# minimum rise (13 samples) than find_onset's default.
n_baseline = 100
dV_thresh = 10.0
peak_time = 200
spks = qs.peaks(c_recording.astype('d'), c_times[:2], peak_time, 100)
for spk in spks:
    plt.plot(spk)
    spk_on = find_onset(spk[:peak_time], n_baseline=n_baseline, min_rise=13)
    print(peak_time-spk_on)   # onset index within the window
    plt.plot(peak_time-spk_on, spk[peak_time-spk_on], 'ro')

In [None]:
# Repeat find_onset's threshold computation for the last `spk` left by the
# loop above, to inspect the gradient against the threshold.
dV = np.gradient(spk)
mdV = dV[:n_baseline].mean()
sdV = dV[:n_baseline].std()
thresh = mdV + dV_thresh * sdV
plt.plot(dV)
print(thresh)

In [None]:
# Run-length encode the supra-threshold mask to see candidate onset runs.
rle(dV > thresh)

In [None]:
from collections import namedtuple
# Leftover interactive help lookup; namedtuple is not used in the visible cells.
help(namedtuple)

In [None]:
def first_index(fn, seq):
    """Return the position of the first element of *seq* for which *fn* is truthy.

    Returns None when no element matches, including when *seq* is empty.
    Note: this redefinition shadows the `first_index` imported from `core`.
    """
    for position, item in enumerate(seq):
        if fn(item):
            return position
    return None


In [None]:
# Load the example intracellular recording and the hand-annotated peak times,
# then extract and overlay a window around each peak.
b_recording = np.load("../intra_spike.npy")
b_times = [7635, 8412, 9363, 10424, 11447, 12661, 13887, 15079, 16373,
           17753, 19168, 20682, 22357, 23979, 25574, 27209, 28989,
           30508, 32088, 33778]
spks = qs.peaks(b_recording.astype('d'), b_times, 200, 400)
plt.plot(spks.T);

In [None]:
# Realign the extracted spikes with quickspikes.realign_spikes and overlay them.
# NOTE(review): document what the arguments 3 and 4 control.
times, aligned = qs.realign_spikes(b_times, spks, 3, 4)
plt.plot(aligned.T);

In [None]:
# Run the intracellular SpikeFinder on the example recording and mark the
# detected spikes, zoomed in on the first two.
fig, axes = plt.subplots(2)
axes[0].plot(b_recording)
from quickspikes.intracellular import SpikeFinder
detector = qsi.SpikeFinder(50, 350, 5000)
res = detector.calculate_threshold(b_recording)
print(res)
# Only extract spikes when calculate_threshold returned a result, matching the
# guarded per-sweep detection cell earlier in the notebook.
if res is not None:
    for time, spike in detector.extract_spikes(b_recording, 10):
        axes[1].plot(spike)
        axes[0].plot(time, b_recording[time], "ro")
axes[0].set_xlim(7000,9000)
#times = [time for time, spike in detector.extract_spikes(b_recording, 10)]