In [1]:
%matplotlib
import gzip
import pickle
from datetime import datetime
from glob import glob

import warnings

from video_writer import VideoWriter

warnings.filterwarnings("error")

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as linalg
from matplotlib import cm
from scipy.interpolate import interp1d
from scipy.ndimage import correlate1d
from scipy.stats import t as t_dist
from scipy.stats import multivariate_normal as mvnd
from scipy.signal import (
    correlate,
    cspline1d,
    find_peaks,
    cwt,
    ricker,
    medfilt,
    morlet2,
    hilbert,
    filtfilt,
    butter,
    sepfir2d,
)
import ruptures as rpt
from tqdm import trange, tqdm

# from dtw import dtw



Using matplotlib backend: Qt5Agg


In [12]:

def smooth_and_norm_real(vals, winsize=95, smoothing=0.2, zcutoff=3.0):
    win = np.ones((winsize,))
    win_count = correlate(np.ones_like(vals), win)
    signal = np.copy(vals)
    init_rms = np.sqrt(correlate(np.square(signal), win) / win_count)
    init_cond = np.where(
        init_rms[: 1 - winsize] < init_rms[winsize - 1 :],
        init_rms[: 1 - winsize],
        init_rms[winsize - 1 :],
    )
    for _ in range(8):
        moving_rms = np.sqrt(correlate(np.square(signal), win).clip(0.0) / win_count)
        min_rms = np.minimum(moving_rms[: 1 - winsize], moving_rms[winsize - 1 :]).clip(
            1e-5
        )
        magclip = (
            zcutoff * min_rms / np.maximum(min_rms * zcutoff, np.abs(signal)).clip(1e-5)
        )
        signal *= magclip
    moving_avg = correlate(signal / min_rms, win) / win_count
    min_avg = np.where(
        init_cond,
        moving_avg[: 1 - winsize],
        moving_avg[winsize - 1 :],
    )
    min_avg = moving_avg[winsize // 2 : -(winsize // 2)]

    smoothed = cspline1d(signal / min_rms - min_avg, lamb=smoothing)

    return smoothed, vals / min_rms - min_avg


def get_rms(signal, winsize=95):
    win = np.ones((winsize,))
    win_count = correlate(np.ones_like(signal), win, mode="same")
    rms = np.sqrt(correlate(np.square(signal), win, mode="same") / win_count)
    return rms

def norm_complex(real, imag, winsize=95, smoothing=0.2, zcutoff=3.0):
    win = np.ones((winsize,))
    win_count = correlate(np.ones_like(real), win)
    signal = real + 1j * imag
    for _ in range(8):
        moving_rms = np.sqrt(correlate(np.square(np.abs(signal)), win) / win_count)
        min_rms = np.minimum(moving_rms[: 1 - winsize], moving_rms[winsize - 1 :])
        magclip = (
            zcutoff * min_rms / np.maximum(min_rms * zcutoff, np.abs(signal)).clip(1e-5)
        )
        signal *= magclip
    return (real + 1.0j * imag) / min_rms


def smooth_and_norm_complex_stitch(
    real_lead, imag_lead, real_lag, imag_lag, winsize=95, smoothing=0.2, zcutoff=3.0
):
    orig_ld_re = cspline1d(real_lead, lamb=smoothing)
    orig_ld_im = cspline1d(imag_lead, lamb=smoothing)
    orig_lg_re = cspline1d(real_lag, lamb=smoothing)
    orig_lg_im = cspline1d(imag_lag, lamb=smoothing)
    win = np.ones((winsize,))
    win_count = correlate(np.ones_like(real_lead), win)
    lead_signal = orig_ld_re + 1.0j * orig_ld_im
    lag_signal = orig_lg_re + 1.0j * orig_lg_im
    # TODO: pick a better way of stitching together, eg which is waviest?
    # TODO: also should do a moving average of each individual component and subtract it
    # currently there are jumps when it switches from one to the other. These suck.
    # TODO: if a component is NOT very wavy, should shrink it by, eg, 1/3
    # currently some components are very noisy in some places and this causes problems.
    lead_rms = np.sqrt(correlate(np.square(np.abs(lead_signal)), win) / win_count)
    lag_rms = np.sqrt(correlate(np.square(np.abs(lag_signal)), win) / win_count)
    signal = np.where(
        lead_rms[: 1 - winsize] < lag_rms[winsize - 1 :], lead_signal, lag_signal
    )
    orig_sig = np.copy(signal)
    for _ in range(8):
        moving_rms = np.sqrt(correlate(np.square(np.abs(signal)), win) / win_count)
        min_rms = np.minimum(moving_rms[: 1 - winsize], moving_rms[winsize - 1 :])
        magclip = (
            zcutoff * min_rms / np.maximum(min_rms * zcutoff, np.abs(signal)).clip(1e-5)
        )
        signal *= magclip
    return orig_sig / min_rms


def fit_linear_lead_and_lag(yi, winsize=155):
    k = winsize - 1
    sd = k * (k + 1) / 2
    sdsq = k * (k + 1) * (2 * k + 1) / 6
    win = np.ones(winsize)
    invdenom = 1.0 / (sdsq * winsize - sd ** 2)
    slp_kern = np.arange(winsize)
    avg_kern = np.full(winsize, 1.0 / winsize)
    fslp_lag = correlate(yi, slp_kern, mode="full")[winsize - 1:]
    fslp_lead = correlate(yi, -slp_kern[::-1], mode="full")[:1 - winsize]
    fsum = correlate(yi, win, mode="full")
    fsum_lead = fsum[:1 - winsize]
    fsum_lag = fsum[winsize - 1:]
    yi2 = np.square(yi)
    sumsq = correlate(yi2, win, mode="full")
    sumsq_lead = sumsq[:1 - winsize]
    sumsq_lag = sumsq[winsize - 1:]
    m_lag = (fslp_lag * winsize - sd * fsum_lag) * invdenom
    b_lag = (fsum_lag * sdsq - sd * fslp_lag) * invdenom
    m_lead = (fslp_lead * winsize + sd * fsum_lead) * invdenom
    b_lead = (fsum_lead * sdsq + sd * fslp_lead) * invdenom
    serr_lag = (
        np.square(b_lag) * winsize
        - 2 * b_lag * fsum_lag
        + 2 * b_lag * m_lag * sd
        + sumsq_lag
        - 2 * m_lag * fslp_lag
        + np.square(m_lag) * sdsq
    )
    serr_lead = (
        np.square(b_lead) * winsize
        - 2 * b_lead * fsum_lead
        - 2 * b_lead * m_lead * sd
        + sumsq_lead
        - 2 * m_lead * fslp_lead
        + np.square(m_lead) * sdsq
    )
    
    sst_lead = sumsq_lead - np.square(fsum_lead) / winsize
    sst_lag = sumsq_lag - np.square(fsum_lag) / winsize
    
    r2_lead = 1.0 - serr_lead / sst_lead.clip(1e-6)
    r2_lag = 1.0 - serr_lag / sst_lag.clip(1e-6)

    return m_lead * winsize, m_lag * winsize, b_lead, b_lag, np.sqrt(serr_lead.clip(min=0.0) / winsize), np.sqrt(serr_lag.clip(min=0.0)  / winsize), r2_lead, r2_lag


def compute_stats(xi, yi, winsize=255, mode="nearest"):
    hw = winsize // 2
    hwmo = (winsize - 1) // 2
    win = np.ones((winsize,))
    xiyi = xi * yi
    xi2 = np.square(xi)
    yi2 = np.square(yi)
    win_xiyi = correlate1d(xiyi, win, mode=mode)
    win_xi2 = correlate1d(xi2, win, mode=mode)
    win_yi2 = correlate1d(yi2, win, mode=mode)
    m = win_xiyi / win_xi2.clip(1e-8)
    b = correlate1d(yi, win, mode=mode)
    idx = np.arange(xi.shape[0])[:, None] + np.arange(winsize) - hwmo
    ywin = yi[idx.clip(min=0, max=xi.size - 1)]
    yhat = idx * m[:, None] + b[:, None]
    sse = np.sum(np.square(ywin - yhat), axis=1)
    r2 = 1 - sse / win_yi2.clip(min=1e-9)

    # angles = -np.arctan(m)
    # angles = np.concatenate(
    #     (np.full((hwmo,), angles[0]), angles, np.full((hw,), angles[-1]))
    # )
    # R = np.array([[np.cos(angles), -np.sin(angles)], [np.sin(angles), np.cos(angles)]])
    # input = np.stack((xi, yi), axis=0)[:, None, :]
    # res = np.sum(R * input, axis=1)[0]

    # res3 = res ** 3
    # win_res3 = correlate(res3, win, mode=mode)
    # skew = np.concatenate(
    #     (np.full((hw,), win_res3[0]), win_res3, np.full((hw,), win_res3[-1]))
    # )

    return b, m, r2, sse / winsize


def correlate_peaks(sig, start, length, min_shift, chunk):
    seg = sig[start : start + chunk]
    comp = sig[start + min_shift : start + length + min_shift]
    return correlate(comp, seg, mode="valid"), seg, comp


def get_turning_points(sig):
    deltas = sig[1:] - sig[:-1]
    turn_pts = deltas[1:] * deltas[:-1] <= 0.0
    curvature = correlate(
        sig, np.array([5.0, 0.0, -3.0, -4.0, -3.0, 0.0, 5.0]), mode="same"
    )

    peaks = np.argwhere(np.logical_and(deltas[:-1] > 0.0, turn_pts))
    troughs = np.argwhere(np.logical_and(deltas[:-1] < 0.0, turn_pts))

    # unlikely that a true breath would be faster than 3 seconds, ie 15 frames
    dist = 15
    peak_dists = peaks[1:] - peaks[:-1]
    trough_dists = troughs[1:] - troughs[:-1]

    # find the highest peak, eliminate any peaks too closeby, which means also
    # eliminating some troughs... but which ones... fuck, I have to think about this.
    while False:
        pass


def compute_r2(data, fit, winsize=255):
    hw = winsize // 2
    win = np.ones((winsize,))
    err = data - fit
    avgs = correlate(data, win, mode="same") / winsize
    sum_err2 = correlate(np.square(err), win, mode="same")
    idx = np.arange(data.shape[0])[:, None] + np.arange(winsize)
    data_ext = np.concatenate((np.zeros(hw), data, np.zeros(hw)), axis=0)
    sum_var = np.sum(np.square(data_ext[idx] - avgs[:, None]), axis=1)
    sum_var = np.var(data_ext[idx] - avgs[:, None], axis=1) * winsize
    r2 = 1 - sum_err2 / sum_var.clip(1e-5)

    return r2, err, sum_err2, sum_var


def get_findiff_curvature(data):
    curvature = np.zeros_like(data)
    curvature[1:-1] = data[:-2] + data[2:] - 2.0 * data[1:-1]
    curvature[0] = curvature[1]
    curvature[-1] = curvature[-2]
    return curvature


def morlet_real(*args, **kwargs):
    return np.real(morlet2(*args, **kwargs))


def get_slopes(data):
    dx = data[:, 1:] - data[:, :-1]  # x right
    dy = data[1:] - data[:-1]  # y down
    de = data[1:, 1:] - data[:-1, :-1]  # y down, x right
    do = data[1:, :-1] - data[:-1, 1:]  # y down, x left

    tl = de[:-1, :-1]
    tc = dy[:-1, 1:-1]
    tr = do[:-1, 1:]
    rc = -dx[1:-1, 1:]
    br = -de[1:, 1:]
    bc = -dy[1:, 1:-1]
    bl = -do[1:, :-1]
    lc = dx[1:-1, :-1]

    res = np.zeros((8, data.shape[0], data.shape[1]))
    res[:, 1:-1, 1:-1] = np.stack((tl, tc, tr, rc, br, bc, bl, lc), axis=0)
    return res


def wavefinding_cwt(signal, widths, omega=5):
    output = np.empty((len(widths), len(signal)), dtype=np.complex128)
    for ind, width in enumerate(widths):
        # go for an odd window length about 8x the length of the width
        N = round(4 * width - 0.5) * 2 + 1
        N = np.min([N, len(signal)])
        wavelet_data = morlet2(N, width, omega)
        # using correlate instead of convolve
        output[ind] = correlate(
            signal.astype(np.complex128), wavelet_data, mode="same"
        ) * np.exp(-1.0j * omega * np.arange(len(signal)) / width)
    return output


def suppress_noise(real, imag):
    """Normalizes the traces, then compares them and weights the wavier one higher."""
    real_norm, _ = smooth_and_norm_real(real)
    imag_norm, _ = smooth_and_norm_real(imag)
    omega = 20.0
    fs = 5.0
    freqs = np.logspace(0.1, -1.4, 150)  # ~50-85 are breathing frequencies
    widths_morlet = omega * fs / (freqs[50:85] * 2 * np.pi)
    real_wave = wavefinding_cwt(real_norm, widths_morlet, omega)
    mags_real = np.sum(np.square(np.abs(real_wave)), axis=0)
    imag_wave = wavefinding_cwt(imag_norm, widths_morlet, omega)
    mags_imag = np.sum(np.square(np.abs(imag_wave)), axis=0)
    # clip at 4 - don't want to be making up bs waves
    ratio = np.square(mags_real / mags_imag).clip(0.25, 4.0)
    # normalize again after the ratio thing
    normed = norm_complex(real_norm * ratio, imag_norm / ratio)
    # these work pretty well I've found
    widths_peak = np.arange(4, 80) * 0.2
    cplx_peaks = cwt(normed, ricker, widths_peak, dtype=np.complex128)
    cplx_slopes = get_slopes(np.abs(cplx_peaks))
    peak_cond = np.sum(np.sign(cplx_slopes), axis=0) == 8.0
    peak_vals = cplx_peaks[peak_cond]
    arg = np.mod(np.angle(np.square(cplx_peaks)) * 0.5 / np.pi, 1.0)
    cols = cm.hsv(arg)[..., :3]
    mags = np.abs(cplx_peaks)
    mags *= 999.0 / np.max(mags)
    mags += 1.0
    cols *= np.log10(mags[..., None]) / 3.0
    cols[peak_cond, :] = 0.0

    amax = np.argmax(np.abs(cplx_peaks[:26]), axis=0).astype(float)
    amax_sm = np.around(cspline1d(medfilt(amax, 1023), 5000000)).astype(int)
    asym_trace = np.mean(np.abs(cplx_peaks[5:25]), axis=0)
    max_trace = np.abs(cplx_peaks[amax_sm, np.arange(amax.size)])

    # find the minima of both traces
    asym_min = np.logical_and(
        asym_trace[:-2] > asym_trace[1:-1], asym_trace[1:-1] < asym_trace[2:]
    )
    max_min = np.logical_and(
        max_trace[:-2] > max_trace[1:-1], max_trace[1:-1] < max_trace[2:]
    )
    asym_mean = np.mean(cplx_peaks[6:35], axis=0)  # empirical values
    asym_rot = np.angle(asym_mean[1:] / asym_mean[:-1])
    rot_conv = correlate(asym_rot, np.ones(3), mode="same")
    asym_pk_arg, *_ = find_peaks(
        np.abs(rot_conv), height=0.9, distance=4, prominence=0.2
    )
    asym_min_arg = np.argwhere(asym_min).squeeze()
    max_min_arg = np.argwhere(max_min).squeeze()
    # min_max_min = max_min_arg - 2
    # max_max_min = max_min_arg + 2
    # min_cond = np.any(
    #     np.logical_and(
    #         asym_min_arg[:, None] >= min_max_min[None, :],
    #         asym_min_arg[:, None] <= max_max_min[None, :],
    #     ),
    #     axis=1,
    # )
    # rem_args = asym_min_arg[min_cond] + 1
    rem_args = asym_pk_arg
    arg_diffs = rem_args[1:] - rem_args[:-1]
    mid_heights = max_trace[np.around(0.5 * (rem_args[1:] + rem_args[:-1])).astype(int)]

    # used for calculations inside
    nfold = np.sqrt(np.square(normed))
    nsq = np.abs(normed) * normed
    rotated = np.zeros_like(normed, dtype=np.complex128)

    # now we need to segment where the signal goes outside of the range (-3.0, 3.0)
    past_3 = np.argwhere(np.abs(normed) > 3.0).squeeze()
    num_outliers = len(past_3)
    outliers = np.zeros(num_outliers + 2, dtype=int)
    outliers[1:-1] = past_3
    outliers[-1] = len(normed) - 1
    for idx in range(num_outliers + 1):
        start = outliers[idx]
        stop = outliers[idx + 1]
        if start < 109805 < stop:
            print("stop here")
        inside_cond = np.logical_and(rem_args > start, rem_args < stop)
        inside_args = np.argwhere(inside_cond).squeeze()
        num_mins = np.count_nonzero(inside_cond)
        if num_mins >= 3:
            # this is where we can start to feel pretty confident.
            # But we'll still rotate each peak based on the average within the peak.
            # outside the peak, I guess just rotate based on the closest peak?
            first_min = np.min(inside_args)
            last_min = np.max(inside_args - (1 if num_mins % 2 == 0 else 0))
            diffs = arg_diffs[first_min:last_min]
            hghts = mid_heights[first_min:last_min]
            if num_mins > 4:
                diff_scores = (
                    (np.mean(diffs) - diffs) / np.std(diffs).clip(1e-5)
                ).clip(-3.0, 3.0)
                height_scores = (
                    (hghts - np.mean(hghts)) / np.std(hghts).clip(1e-5)
                ).clip(-3.0, 3.0)
                parity = (
                    np.mean(diff_scores[::2] + height_scores[::2])
                    - np.mean(diff_scores[1::2] + height_scores[1::2])
                ) < 0  # zero if even are the peaks, else odd are
            else:
                seg = normed[rem_args[first_min] : rem_args[last_min]]
                frac_re = np.sum(np.square(np.real(seg) / np.abs(seg)))
                if frac_re > 0.5:
                    nseg = np.real(seg) - np.mean(np.real(seg))
                    # frac_under = np.count_nonzero(
                    #     (np.real(seg) < nseg
                    # )) / len(seg)
                else:
                    nseg = np.imag(seg) - np.mean(np.imag(seg))
                    # frac_under = np.count_nonzero(
                    #     (np.imag(seg) < np.mean(np.imag(seg)))
                    # ) / len(seg)
                parity = (
                    np.mean(np.abs(nseg[: diffs[0]]))
                    - np.mean(np.abs(nseg[diffs[0] :]))
                    > 0
                )
            # next, average across each peak.
            rot_angles = np.zeros((num_mins - 1) // 2, dtype=np.complex128)
            for jdx in range((num_mins - 1) // 2):
                first = rem_args[inside_args[jdx * 2]]
                mid = rem_args[inside_args[jdx * 2 + 1]]
                last = rem_args[inside_args[jdx * 2 + 2]]
                parity_av = np.mean(normed[first:mid]) - np.mean(normed[mid:last])
                parity_av /= np.abs(parity_av)
                if parity:
                    parity_av *= -1.0
                # first_pk = rem_args[inside_args[jdx * 2 + parity]]
                # last_pk = rem_args[inside_args[jdx * 2 + 1 + parity]]
                # wav_av = np.mean(nfold[first:last])
                # peak_av = np.mean(normed[first_pk:last_pk])
                rot_angles[jdx] = np.conj(parity_av)
                # rot_angles[jdx] = np.conj(
                #     wav_av * np.sign(np.real(wav_av / peak_av)) / np.abs(wav_av)
                # )
                rotated[first:last] = normed[first:last] * rot_angles[jdx]
                if first < 109805 < last:
                    print(first, last, rot_angles[jdx], wav_av, parity_av)
            # finally, do the parts on either side of the peaks...
            rotated[start : rem_args[first_min]] = (
                normed[start : rem_args[first_min]] * rot_angles[0]
            )
            rotated[rem_args[last_min] : stop] = (
                normed[rem_args[last_min] : stop] * rot_angles[-1]
            )

        else:
            # hard to tell what's going on really. We have one peak or less.
            # get the distance-weighted average, and rotate by that.
            wav_av = np.mean(nfold[start:stop])
            nsq_av = np.mean(nsq[start:stop])
            rot_angle = np.conj(
                wav_av * np.sign(np.real(wav_av / nsq_av)) / np.abs(wav_av)
            )
            diff = rotated[start - 1] - normed[start] * rot_angle
            ndiff = rotated[start - 1] + normed[start] * rot_angle
            result = np.sign(np.abs(ndiff) - np.abs(diff))
            rotated[start:stop] = normed[start:stop] * rot_angle * result

    _, (ax1, ax2) = plt.subplots(2, 1, sharex="col")
    ax1.imshow(cols, aspect="auto")
    ax2.plot(np.real(asym_mean))
    ax2.plot(np.abs(rot_conv))
    ax2.plot(asym_trace)
    ax2.plot(np.real(normed))
    ax2.plot(np.imag(normed))
    ax2.plot(real_norm, linestyle=":")
    ax2.plot(imag_norm, linestyle=":")
    ax2.plot(ratio, linestyle="--")
    ax2.scatter(asym_pk_arg, -4.5 * np.ones_like(asym_pk_arg))
    ax2.plot(np.real(rotated))
    # ax2.legend(
    #     [
    #         "asymmean",
    #         "rotconv",
    #         "asymtrace",
    #         "re(normed)",
    #         "im(normed)",
    #         "realnorm",
    #         "imnorm",
    #         "ratio",
    #         "rotated",
    #         "crossing",
    #     ]
    # )
    plt.show()

    # _, ax = plt.subplots(1, 1)
    # ax.plot(real_norm)
    # ax.plot(imag_norm)
    # ax.plot(np.log2(ratio))
    # plt.show()
    return (real_norm * ratio, imag_norm / ratio, ratio, cwt_real, cwt_imag)


def flip_signal(signal):
    smoothed, normed = smooth_and_norm_real(signal, smoothing=20.0)

    r2, *_ = compute_r2(normed.clip(-3.0, 3.0), smoothed, winsize=255)

    curv = get_findiff_curvature(smoothed)

    bmp = 21 / 60  # max breaths per second
    fs = 5  # sample rate
    max_curv_exag = (2 * np.pi * bmp / fs) ** 4

    curv_exag = (curv * np.abs(curv)) / max_curv_exag

    curv_exag_sm = correlate(curv_exag, np.ones(1001) / 10.0, mode="same")

    past_3 = np.argwhere(np.abs(normed) > 5.0).squeeze()
    num_outliers = len(past_3)
    outliers = np.zeros(num_outliers + 2, dtype=int)
    outliers[1:-1] = past_3
    outliers[-1] = len(signal) - 1

    righted = np.zeros_like(signal)
    stdev = 20.0
    windows = np.zeros_like(signal)

    for idx in range(num_outliers + 1):
        start = outliers[idx]
        stop = outliers[idx + 1]

        win_range = np.minimum(np.arange(stop - start), np.arange(stop - start)[::-1])
        window = 1.0 - np.exp(-0.5 * np.square(win_range.clip(max=3 * stdev) / stdev))
        windows[start:stop] = window

        right = np.sum(curv_exag[start:stop] * window) <= 0.0

        if right:
            righted[start:stop] = normed[start:stop]
        else:
            righted[start:stop] = -normed[start:stop]

    # _, ax = plt.subplots(1, 1)
    # ax.plot(normed)
    # ax.plot(smoothed)
    # ax.plot(windows)
    # ax.plot(righted, linestyle="--")
    # ax.plot(curv_exag)
    # ax.plot(curv_exag_sm)
    # ax.plot(r2)
    # ax.legend(["norm", "sm", "win", "right", "curv", "curvsm", "r2"])
    # plt.show()

    return smoothed, normed, r2, righted, curv_exag


def flip_components_indiv_and_combine(sig_x, sig_y, segments):
    """Normalizes the traces, then compares them and weights the wavier one higher."""
    sig_x_copy = np.copy(sig_x)
    sig_y_copy = np.copy(sig_y)

    for sdx, seg in enumerate(segments[:-1]):
        scale_x = flip_signal(sig_x[seg:segments[sdx + 1]])
        scale_y = flip_signal(sig_y[seg:segments[sdx + 1]])
        sig_x_copy[seg:segments[sdx + 1]] *= scale_x
        sig_y_copy[seg:segments[sdx + 1]] *= scale_y

    omega = 20.0
    fs = 5.0
    freqs = np.logspace(0.1, -1.4, 150)  # ~50-85 are breathing frequencies
    widths_morlet = omega * fs / (freqs[55:80] * 2 * np.pi)
    real_wave = wavefinding_cwt(real_righted.clip(-3.0, 3.0), widths_morlet, omega)
    mags_real = np.sum(np.square(np.abs(real_wave)), axis=0)
    imag_wave = wavefinding_cwt(imag_righted.clip(-3.0, 3.0), widths_morlet, omega)
    mags_imag = np.sum(np.square(np.abs(imag_wave)), axis=0)

    rr_align = np.zeros_like(real_righted)
    ir_align = np.zeros_like(imag_righted)

    past_3 = np.argwhere(
        np.logical_or(np.abs(real_norm) > 5.0, np.abs(imag_norm) > 5.0)
    ).squeeze()
    num_outliers = len(past_3)
    outliers = np.zeros(num_outliers + 2, dtype=int)
    outliers[1:-1] = past_3
    outliers[-1] = len(real_righted) - 1

    stdev = 20.0

    for idx in range(num_outliers + 1):
        start = outliers[idx]
        stop = outliers[idx + 1]

        if start < 97250 < stop:
            print("shit")

        win_range = np.minimum(np.arange(stop - start), np.arange(stop - start)[::-1])
        window = 1.0 - np.exp(-0.5 * np.square(win_range.clip(max=3 * stdev) / stdev))

        agree = (
            np.sum(real_righted[start:stop] * imag_righted[start:stop] * window) >= 0.0
        )

        if agree:
            rr_align[start:stop] = real_righted[start:stop]
            ir_align[start:stop] = imag_righted[start:stop]
        else:
            # who has the higher sum of curvatures on their side?
            real_flip = np.sign(
                np.abs(np.sum(real_curv[start:stop]))
                - np.abs(np.sum(imag_curv[start:stop]))
            )
            rr_align[start:stop] = real_righted[start:stop] * real_flip
            ir_align[start:stop] = imag_righted[start:stop] * -real_flip

    real_frac = mags_real / (mags_real + mags_imag)
    imag_frac = 1.0 - real_frac

    result = real_frac * real_righted + imag_frac * imag_righted

    # _, ax = plt.subplots(1, 1)
    # ax.plot(real_righted)
    # # ax.plot(real_smooth)
    # ax.plot(imag_righted)
    # ax.plot(result)
    # # ax.plot(imag_smooth)
    # # ax.plot(result)
    # # ax.plot(real_nm_curv_conv + 3.0, linestyle="--")
    # # ax.plot(imag_nm_curv_conv + 3.0, linestyle="--")
    # # ax.plot(real_r2 - 3.0)
    # # ax.plot(imag_r2 - 3.0, linestyle="--")
    # # ax.plot(mags_real * 0.005)
    # # ax.plot(mags_imag * 0.005)
    # ax.plot(real_frac)
    # ax.plot(imag_frac)
    # ax.legend(["rr", "ir", "res", "rf", "if"])
    # plt.show()

    return result

def adjust_for_zero_trend(signal, lam=25.0):
    new_sig = np.copy(signal)
    sm = cspline1d(signal, lam)
    segs = np.argwhere(np.logical_and(sm[1:] * sm[:-1] < 0.0, sm[1:] < sm[:-1])).squeeze()
    segs[0] = 0
    for idx in range(len(segs) - 1):
        start = segs[idx]
        stop = segs[idx + 1]
        seg_offset = np.mean(signal[start:stop])
        new_sig[start:stop] -= seg_offset
        bnd_offset = np.mean(np.cumsum(new_sig[start:stop]))
        new_sig[start] -= 0.5 * bnd_offset
        new_sig[stop] += 0.5 * bnd_offset
    
    return new_sig, sm

def estimate_oob(signal_x, signal_y, r2_x, r2_y, r2_med_x, r2_med_y, ab_x, ab_y, pos_x, pos_y):
    
    sm = 81
    sm_pos_x = medfilt(pos_x, sm)
    sm_pos_y = medfilt(pos_y, sm)

    dx = np.abs(pos_x - sm_pos_x)
    dy = np.abs(pos_y - sm_pos_y)

    med_dx = np.median(dx)
    med_dy = np.median(dy)
    x_cond = dx > 12.0 * med_dx
    y_cond = dy > 12.0 * med_dy
    dx_extr = np.where(x_cond, 10.0, 0.0)
    dy_extr = np.where(y_cond, 8.0, 0.0)

    exp_k = 4
    either = np.logical_or(x_cond, y_cond).astype(float)
    expand = np.concatenate((np.full(exp_k, either[0]), either, np.full(exp_k, either[-1])))
    exp_cs = np.cumsum(expand)
    mov = exp_cs[2 * exp_k:] - exp_cs[:-2 * exp_k] > 0.5
    mov_mult = 1.0 + 9.0 * medfilt(mov.astype(float), 2 * exp_k + 1)
    
    best_r2 = np.maximum(r2_x, r2_y).clip(0.05, 0.99)
    best_med = np.maximum(r2_med_x, r2_med_y).clip(0.05, 0.99)
    sqmag = np.sum(np.square(ab_x), axis=1) + np.sum(np.square(ab_y), axis=1).clip(1e-12)
    log_mag = np.log(sqmag)
    avg_log_mag = np.mean(log_mag[best_r2 >= 0.8])
    std_log_mag = np.std(log_mag[best_r2 >= 0.8])
    
    oob_buf = 101
    mov_buf = 7
    
    # rate non-periodic by magnitude, and by med_r^2
    # if it's low r^2 and low magnitude, it's probably out of bed
    # odds that it's nonperiodic, where 0.8 = 1:1 odds
    # z-score for magnitudes, and odds that the signal is small
    log_mag_z = (avg_log_mag - log_mag) / std_log_mag
    med_zed = medfilt(log_mag_z, 501).clip(-1.6, 1.6)
    diff_z = med_zed - log_mag_z
    std_diff = np.std(diff_z[best_r2 >= 0.8])
    odds_spike = (0.667 * diff_z / std_diff - 1.0).clip(0.0) * 4.0 + np.where(np.logical_or(diff_z > 0.0, mov), 0.1, 0.0)
    # odds_spike = (0.22 * diff_z / std_diff).clip(0.333) * 3.0 
    odds_small = (log_mag_z - 1.0).clip(0.0)
    
    odds_nonper_inst = 16.0 * np.square(1.0 - best_r2)
    odds_large = 1.0 #np.square((-1.0 - log_mag_z).clip(0.01))
    odds_mov = odds_nonper_inst * odds_large * mov_mult * odds_spike
    p_mov = odds_mov / (odds_mov + 1.0)
    is_mov = medfilt((p_mov > 0.5).astype(float), mov_buf)
    
    # odds_non_per = 25.0 * np.square(1.0 - best_med)
    # inst_nonper_odds = 25.0 * np.square(np.square(1.0 - best_r2.clip(min=0.8)))
    # odds_non_per *= 625.0 * np.square(np.square(1.0 - best_r2.clip(min=0.8)))
    odds_nonper_med = 25.0 * np.square(1.0 - best_med)
    odds_nonper = odds_nonper_inst  # np.where(low_overflow > 0.0, odds_nonper_inst, odds_nonper_med)
    odds_ratio = odds_nonper * odds_small
    
    p_oob = medfilt(odds_ratio / (odds_ratio + 1.0), oob_buf)
    # let's plot that shit to see where we're at
    _, ax = plt.subplots(1, 1)
    ax.plot(odds_nonper.clip(0.0, 20.0), label="nonper odds", alpha=0.3, linestyle=":")
    # ax.plot(odds_small, label="small odds", alpha=0.3, linestyle=":")
    ax.plot(best_r2, label="inst", alpha=0.3, linestyle=":")
    ax.plot(best_med, label="med", alpha=0.3, linestyle=":")
    ax.plot(is_mov, label="med", alpha=0.3, linestyle=":")
    ax.plot(log_mag_z, label="zscore", alpha=0.6, linestyle=(0, (3, 1, 1, 1)))
    # ax.plot(odds_spike, label="spike", alpha=0.6, linestyle=(0, (2, 1, 1, 1)))
    # ax.plot(r2_x, label="r2x")
    # ax.plot(r2_y, label="r2y")
    ax.plot(1e3 * signal_x, alpha=0.2, linestyle=":")
    ax.plot(1e3 * signal_y, alpha=0.2, linestyle=":")
    # ax.plot(odds_ratio, alpha=0.5, linestyle=":", label="OR")
    ax.plot(p_oob, label="poob")
    ax.plot(p_mov, label="pmov", linestyle=(0, (3,1,1,1,1,1)))
    # ax.plot(mov)
    ax.legend()
    plt.show()
    
    mov = p_mov > 0.8
    meb_mov = p_mov > 0.55
    mov[0] = True
    mov[-1] = True
    starts = np.argwhere(np.logical_and(np.logical_not(mov[:-1]), mov[1:])).squeeze()
    stops = np.argwhere(np.logical_and(np.logical_not(mov[1:]), mov[:-1])).squeeze()
    n_segs = starts.size
    for idx in range(n_segs - 1):
        seg = np.arange(stops[idx], starts[idx])
        mov[seg] = np.all(meb_mov[seg])
    return mov
    

def find_discont(signal_x, signal_y, r2_x, r2_y, pos_x, pos_y, ab_x, ab_y, min_compare=37, max_compare=250):
    best_r2 = np.maximum(r2_x, r2_y)
    sqmag = np.sum(np.square(ab_x), axis=1) + np.sum(np.square(ab_y), axis=1).clip(1e-12)
    mag = np.sqrt(sqmag)

    # for each non periodic segment, check the probability that the position did not change
    # guarantees: per_begin[0] < per_end[0] and per_end[-1] > per_begin[-1] and lengths are the same
    thresh = 0.75
    per_begin = np.argwhere(np.logical_and(best_r2[:-1] < thresh, best_r2[1:] >= thresh)).squeeze()
    per_end = np.argwhere(np.logical_and(best_r2[:-1] >= thresh, best_r2[1:] < thresh)).squeeze()
    n_per = per_begin.size

    keep_begin = np.ones(n_per, dtype=bool)
    keep_end = np.ones(n_per, dtype=bool)
    mov_odds = np.ones(n_per + 1)
    rms_odds = np.ones(n_per + 1)
    shr_odds = np.ones(n_per + 1)
    npr_odds = np.ones(n_per + 1)
    
    # first trim the nonper that are too short (<=1.2s), and likely not movement or sighing
    # larger than the surrounding signal? Probably want to test it.
    # r^2 value still pretty high throughout? likely want to merge it
    for idx in range(n_per - 1):
        # get the lengths of the segments
        lg_dur = per_begin[idx + 1] - per_end[idx]
        short_or = np.log2(lg_dur / 8)
        # more sensitive to lower r^2 values
        rms_r2 = 1.0 - np.sqrt(np.mean(np.square(1.0 - best_r2[per_end[idx]:per_begin[idx + 1]])))
        nonper_or = np.exp(-8 * (rms_r2 - 0.65))
        rms_bef = np.sqrt(np.mean(sqmag[per_begin[idx]:per_end[idx]]))
        rms_dur = np.sqrt(np.mean(sqmag[per_end[idx]:per_begin[idx + 1]]))
        rms_aft = np.sqrt(np.mean(sqmag[per_begin[idx + 1]:per_end[idx + 1]]))
        rms_ratio = rms_dur / max(rms_bef, rms_aft) - 1.0
        # if significantly larger than the other bits of the signal, quite likely that it should be kept
        rms_or = 0.275 / np.square(max(-1.0, min(0.5, rms_ratio - 0.7)))
        keep_odds = short_or * nonper_or * rms_or
        keep_begin[idx + 1] = keep_odds > 1
        keep_end[idx] = keep_odds > 1
        mov_odds[idx + 1] = keep_odds
        rms_odds[idx + 1] = rms_or
        shr_odds[idx + 1] = short_or
        npr_odds[idx + 1] = nonper_or
        

    merged_begin = per_begin[keep_begin]
    merged_end = per_end[keep_end]
    merged_mov = mov_odds[1:][keep_end]
    merged_rms = rms_odds[1:][keep_end]
    merged_shr = shr_odds[1:][keep_end]
    merged_npr = npr_odds[1:][keep_end]
    n_merged = merged_begin.size
    
    print(n_per, n_merged)
    print(merged_end - merged_begin)
    
    # next up, pruning the periodic parts that are too short.
    keep_merged = merged_end - merged_begin >= min_compare
    final_begin = merged_begin[keep_merged]
    final_end = merged_end[keep_merged]
    n_final = final_begin.size
    p_mov = np.ones(n_final - 1)
    
    # use ks_test to find the likely discontinuities
    for idx in range(n_final - 1):
        bef_lg = final_end[idx] - final_begin[idx]
        aft_lg = final_end[idx + 1] - final_begin[idx + 1]
        bef_lg = min(max_compare, 2 * aft_lg, bef_lg)
        aft_lg = min(max_compare, 2 * bef_lg, aft_lg)
        bef_coords = np.arange(final_end[idx] - bef_lg, final_end[idx])
        aft_coords = np.arange(final_begin[idx + 1], final_begin[idx + 1] + aft_lg)
        mn = round((np.sum(bef_coords) + np.sum(aft_coords)) / (bef_lg + aft_lg))
        print(idx, final_end[idx])
        
        T_x, p_x = split_test(pos_x[bef_coords], pos_x[aft_coords], bef_coords, aft_coords, idx==85)
        T_y, p_y = split_test(pos_y[bef_coords], pos_y[aft_coords], bef_coords, aft_coords, idx==85)
        p_mov[idx] = max(T_x, T_y)
        
        
    # page 31 of this pdf is very important and helpful:
    # https://scholar.princeton.edu/sites/default/files/bstewart/files/lecture7_handout_2018.pdf
    # the sigma_squared_estimator * inv(XT X) gives the covariance matrix
    # T statistic to test if hypothesis that discontinuity doesn't exist (split param is 0)
    # is given by split/SE(split), where SE(split) = sqrt(Variance(split)) = sqrt(Cov(split,split))
    # page 32 and 24 then show how to use the t distribution. When estimating k+1 parameters, use
    # the t distribution t_(n - (k + 1)) what is the sigma_squared_estimator?
    # (uT u) / (n - (k + 1)) where u are the errors i.e. the sample average sum of square errors,
    # ie population mean square error
    
    _, (ax1,ax2) = plt.subplots(2, 1, sharex=True)
    ax1.plot(signal_x * 1e3, alpha=0.3, linestyle=":")
    ax1.plot(signal_y * 1e3, alpha=0.3, linestyle=":")
    ax1.plot(best_r2, alpha=0.3, linestyle=":")
    ax1.plot(sqmag * 1e6, alpha=0.3, linestyle=(0, (1, 3)))
    
    ax2.plot(pos_x)#medfilt(pos_x, 101))
    ax2.plot(pos_y)#medfilt(pos_y, 101))
    
    for idx in range(n_final):
        ax1.axvspan(final_begin[idx], final_end[idx], color="g", alpha=0.2)
        if idx == 0:
            continue
        ax1.plot(np.array([final_end[idx - 1], final_begin[idx]]), np.ones(2) * p_mov[idx - 1], linestyle=(0, ()))
    #     ax.plot(np.array([merged_end[idx - 1], merged_begin[idx]]), np.ones(2) * min(10.0, merged_rms[idx - 1]), linestyle=(0, (1, 1)))
    #     ax.plot(np.array([merged_end[idx - 1], merged_begin[idx]]), np.ones(2) * min(10.0, merged_shr[idx - 1]), linestyle=(0, (3, 1)))
    #     ax.plot(np.array([merged_end[idx - 1], merged_begin[idx]]), np.ones(2) * min(10.0, merged_npr[idx - 1]), linestyle=(0, (5, 1, 1, 1)))
    # ax.legend(["sigx", "sigy", "r2", "mag", "mov", "rms", "shr", "npr"])
    plt.show()
    return


    keep_start[1:] = mov_or_sigh > 0.5
    keep_stop[:-1] = mov_or_sigh > 0.5
    contig_starts = starts[keep_start]
    contig_stops = stops[keep_stop]
    discont = np.zeros(contig_starts.size - 1)

    for idx in range(contig_starts.size - 1):
        # get the lengths of the segments
        lg_bef = contig_stops[idx] - contig_starts[idx]
        lg_dur = contig_starts[idx + 1] - contig_stops[idx]
        lg_aft = contig_stops[idx + 1] - contig_starts[idx + 1]
        # segments should not differ in length by TOO much, or else the KS test will fail
        use_lg = 2 * min(250, lg_bef, lg_aft)
        bef_b = contig_stops[idx] - min(use_lg, lg_bef)
        aft_e = contig_starts[idx + 1] + min(use_lg, lg_aft)
        bef_t = np.arange(bef_b, contig_stops[idx], dtype=float)
        aft_t = np.arange(contig_starts[idx + 1], aft_e, dtype=float)
        p_x = ks_test(n[bef_b:contig_stops[idx], 2], n[contig_starts[idx + 1]:aft_e, 2], bef_t, aft_t)
        p_y = ks_test(n[bef_b:contig_stops[idx], 3], n[contig_starts[idx + 1]:aft_e, 3], bef_t, aft_t)
        discont[idx] = 1.0 - p_x * p_y


def fit_piecewise_contin(y_bef, t_bef, y_aft, t_aft, sig, prt=False):
    ctr = 0.5 * (t_bef[-1] + t_aft[0])
    rad = t_aft[0] - ctr
    xs = np.concatenate((t_bef, t_aft)) - ctr
    ys = np.concatenate((y_bef, y_aft))
    ys -= np.mean(ys)
    lg = xs.size
    szb = y_bef.size
    sza = y_aft.size
    X = np.stack(
        (
            np.minimum(xs + rad, 0.0),
            np.maximum(xs - rad, 0.0),
            np.minimum(rad, np.maximum(xs, -rad)),
        ), axis=-1)
    X -= np.mean(X, axis=0)
    cov = linalg.inv(np.dot(X.T, X)) # covariance of X, sorta
    pinv = linalg.pinv(X)
    coeffs = np.dot(pinv, ys)
    disc_var = sig * np.sqrt(cov[2, 2])
    s3 = np.sqrt(1/3)
    # tf = np.array([[s3, s3, s3],[-1, 0, 1],[0, -1, 1]])
    tf = np.array([[1.0, 0, -1], [1, 1, 0], [0, -1, 1]]) # middle one is arbitrary... suspiciously like the other one xD
    Xtf = np.dot(X, np.linalg.inv(tf))
    covtf = np.linalg.inv(np.dot(Xtf.T, Xtf))
    pinvtf = linalg.pinv(Xtf)
    coeffstf = np.dot(pinvtf, ys)
    if prt:
        _, ax = plt.subplots(1, 1)
        ax.plot(xs, X)
        ax.plot(xs, ys)
        ax.plot(xs, np.dot(X, coeffs))
        ax.plot(xs, np.dot(X, tf.T), linestyle=":")
        # ax.plot(xs, shit / mbef)
        # u,s,v = linalg.svd(X.T, full_matrices=False)
        # ax.plot(xs, v.T)
        plt.show()
        print(cov)
        print(covtf)
        print(coeffs, np.dot(tf, coeffs), coeffstf)
        print(np.dot(tf, sig**2 * cov))
    # how is that derived? let slopes of before, after and middle be x, y, z respectively
    # start off with the condition z + const > max(x, y) OR z - const < min(x, y)
    # we need to rotate this space so that it can be integrated with scipy's multivariate normal distribution
    # so first, rotate about the z axis by 1/4 turn so that x=y lies along x axis
    # next rotate about y axis by 35 degrees (asin(sqrt(1/3)) so that z=x=y lies along x axis
    # stretch along the z axis by sqrt(3) so that the dihedral angle becomes a right angle
    # finally rotate about x axis so that the appropriate boundaries are parallel to xy nd xz planes
    # now you can just integrate both regions, as they are rectangular and have appropriate boundaries

    # cond_1 = mvnd.cdf(np.array([np.inf, -disc_var, -disc_var]), mean=np.dot(tf, coeffs), cov=np.dot(tf, sig**2 * cov))
    # cond_2 = mvnd.cdf(np.array([np.inf, -disc_var, -disc_var]), mean=np.dot(tf, -coeffs), cov=-np.dot(tf, sig**2 * cov))
    # cond_1 = mvnd.cdf(np.array([np.inf, -disc_var, -disc_var]), mean=coeffstf, cov=sig**2 * covtf)
    # cond_2 = mvnd.cdf(np.array([np.inf, -disc_var, -disc_var]), mean=-coeffstf, cov=sig**2 * covtf)
    cond_1 = mvnd.cdf(np.array([0.25 * sig / rad, np.inf, -0.25 * sig / rad]), mean=coeffstf, cov=sig**2 * covtf)
    cond_2 = mvnd.cdf(np.array([0.25 * sig / rad, np.inf, -0.25 * sig / rad]), mean=-coeffstf, cov=sig**2 * covtf)
    cond_3 = mvnd.cdf(np.array([np.inf, np.inf, 0.25 * sig / rad]), mean=coeffstf, cov=sig**2 * covtf)
    cond_4 = mvnd.cdf(np.array([np.inf, np.inf, -0.25 * sig / rad]), mean=coeffstf, cov=sig**2 * covtf)
    if prt:
        print(cond_1, cond_2, cond_3, cond_4, 0.25 * sig / rad, np.sqrt(np.diag(covtf * sig ** 2)))
    return cond_1 + cond_2 + cond_3 - cond_4


def find_discont_piecewise(pos_x, pos_y, mov, min_length=100, max_length=300, buf=20, sm=81):
    px = medfilt(pos_x, sm)
    py = medfilt(pos_y, sm)
    starts = np.argwhere(np.logical_and(np.logical_not(mov[:-1]), mov[1:])).squeeze()
    stops = np.argwhere(np.logical_and(np.logical_not(mov[1:]), mov[:-1])).squeeze()
    n_seg = starts.size
    pdx = 0
    ndx = 1
    shift = np.zeros((px.size, 2))
    while ndx < n_seg:
        bend = starts[pdx] - buf
        astr = stops[ndx] + buf
        bl = starts[pdx] - 2 * buf - stops[pdx]
        al = starts[ndx] - 2 * buf - stops[ndx]
        dl = stops[ndx] + 2 * buf - starts[pdx]
        if al < min_length:
            ndx += 1
            continue
        bef = np.arange(bend - min(max_length, bl), bend)
        aft = np.arange(astr, astr + min(max_length, al))
        dur = np.arange(bend, astr)

        prob_x = fit_piecewise_contin(px[bef], bef, px[aft], aft, 1.2, pdx == 185)
        prob_y = fit_piecewise_contin(py[bef], bef, py[aft], aft, 1.2, pdx == 185)
        print(pdx, ndx, stops[pdx], stops[ndx], prob_x, prob_y)
        shift[dur, 0] = 1.0 - prob_x
        shift[dur, 1] = 1.0 - prob_y
        pdx = ndx
        ndx += 1
    return shift

    
def find_periodic_fit_best(signal_x, signal_y, freqs, pos_x, pos_y, bx=50, by=5):
    phase_freq_x, k_freq_x, r2_x, ab_x, r2_med_x = find_periodic(signal_x, freqs)
    phase_freq_y, k_freq_y, r2_y, ab_y, r2_med_y = find_periodic(signal_y, freqs)
    
    ratio = r2_x ** 2 / (r2_x ** 2 + r2_y ** 2).clip(min=1e-8)
    freq = ratio * k_freq_x + (1 - ratio) * k_freq_y
    _, ax = plt.subplots(1, 1)
    ax.plot(freq)
    plt.show()
    
    mov = estimate_oob(signal_x, signal_y, r2_x, r2_y, r2_med_x, r2_med_y, ab_x, ab_y, pos_x, pos_y)
    
    shift = find_discont_piecewise(pos_x, pos_y, mov)
    
    return 1, 2, 3, 4, 5, 6, shift
    r2 = np.square(r2_x) + np.square(r2_y)
    N = r2_x.shape[0]
    # Y = np.sum(r2 * np.arange(N)[:, None], axis=0)
    # W = np.sum(r2, axis=0)
    # x = np.arange(k) - k // 2
    # o = np.ones(k)
    # x2 = np.square(x)
    # a = correlate(W, x2, "valid")
    # d = correlate(W, o, "valid")
    # bc = correlate(W, x, "valid")
    # det = a * d - np.square(bc)
    # xwy = correlate(Y, x, "valid")
    # owy = correlate(Y, o, "valid")
    # intercept = (a * owy - xwy * bc) / det
    # plotx = np.arange(Y.size - k + 1) + k // 2
    # allx = np.arange(Y.size)
    
    # essentially a gaussian blur...
    blurx = np.exp(-0.5 * np.square((np.arange(4 * bx + 1) - 2 * bx) / bx))
    blury = np.exp(-0.5 * np.square((np.arange(4 * by + 1) - 2 * by) / by))
    gbd = sepfir2d(r2, blurx, blury)
    peaks = np.logical_and(gbd[1:-1] > gbd[:-2], gbd[1:-1] > gbd[2:])
    right = np.logical_and(peaks[:, :-1], peaks[:, 1:])
    up = np.logical_and(peaks[1:, :-1], peaks[:-1, 1:])
    down = np.logical_and(peaks[:-1, :-1], peaks[1:, 1:])
    ends = np.copy(peaks)
    ends[:, :-1] &= np.logical_not(right)
    ends[1:, :-1] &= np.logical_not(up)
    ends[:-1, :-1] &= np.logical_not(down)
    end_idxs = np.argwhere(ends)
    n_ends = end_idxs.shape[0]
    trace_scores = np.zeros((n_ends + 1, 2))
    r2pk = np.zeros_like(peaks, dtype=float)
    r2pk[peaks] = r2[1:-1][peaks]
    pkidx = np.zeros_like(peaks, dtype=int)
    traces = [None] * n_ends
    for idx in trange(1, n_ends + 1):
        y, x = end_idxs[idx - 1]
        pkidx[y, x] = idx
        traces[idx - 1] = (x, y, [y])
        while x > 0:
            l = peaks[y, x - 1]
            u = y > 0 and peaks[y - 1, x - 1]
            d = y < N - 3 and peaks[y + 1, x - 1]
            sr2pk = np.sum(r2pk[max(0, y - 1):min(y + 1, N - 2), x - 1])
            if l or u or d:
                trace_scores[idx, 0] += 1.0
                trace_scores[idx, 1] += sr2pk
                x -= 1
                y += (1 if d else 0) - (1 if u else 0)
                pkidx[y, x] = idx
                traces[idx - 1][2].append(y)
            else:
                break
    scores = trace_scores[:, 1] / np.maximum(np.sqrt(trace_scores[:, 0]), 1.0)
    score_pic = scores[pkidx]
    maxes = np.max(score_pic, axis=0)
    winners = set(maxes.tolist())
    main_trace = np.zeros_like(peaks)
    rem_traces = []
    for val in winners:
        if val == 0.0:
            continue
        main_trace |= score_pic == val
        trace_idx = np.nonzero(scores == val)[0][0] - 1
        trace = traces[trace_idx][2]
        trace_len = len(trace)
        max_count = np.count_nonzero(maxes == val)
        if max_count < 300 or trace_len < 300:
            continue
        print(val, trace_idx, trace_len)
        trace_array = np.zeros((2, trace_len))
        trace_array[1] = np.flip(np.array(trace))
        trace_array[0] = np.arange(trace_len) + traces[trace_idx][0] - trace_len + 1
        rem_traces.append(trace_array)
    
    r2normpk = np.max(gbd, axis=0) / np.sum(blurx) / np.sum(blury) / 2.0
    stat = 4.12 * r2normpk
    p_oob = 1.0 - np.square(np.square(stat)) / (np.square(np.square(stat)) + 1.0)
    
    r2norm = gbd / np.sum(gbd, axis=0).clip(min=1e-9)
    mean = np.dot(np.arange(N), r2norm)
    xmm = np.arange(N)[:, None] - mean[None, :]
    var = np.sum(np.square(xmm) * r2norm, axis=0).clip(min=1e-9)
    mu4 = np.sum(np.square(np.square(xmm)) * r2norm, axis=0)
    kurt = mu4 / np.square(var)
    
    _, ax = plt.subplots(1, 1)
    ax.imshow(r2norm, aspect="auto")
    for trace in rem_traces:
        ax.plot(trace[0], trace[1] + 1.0, c="r", linestyle=":")
    # ax.plot(r2normpk * 100.0)
    ax.plot(p_oob * 10)
    ax.plot(mean)
    ax.plot(np.sqrt(var) + 10)
    ax.plot(kurt + 20)
    plt.show()
    
    return fr2_x, fr2_y, lgs, rem_traces, r2, p_oob


def find_periodic(signal, freqs):
    m = 3.0
    tl = signal.size
    fs = 5
    coeffs = np.zeros((freqs.size, tl, 4))
    r2 = np.zeros((freqs.size, tl))
    lgs = np.around(m * fs / freqs)
    sss = np.zeros((freqs.size, tl))
    sqsig = np.square(signal)
    # for fdx, freq in enumerate(freqs):
    #     lg = round(m * fs / freq)
    #     cos = np.cos(2.0 * m * np.pi * np.arange(lg) / lg)
    #     sin = np.sin(2.0 * m * np.pi * np.arange(lg) / lg)
    #     dcos = np.cos(4.0 * m * np.pi * np.arange(lg) / lg)
    #     dsin = np.sin(4.0 * m * np.pi * np.arange(lg) / lg)
    #     a = np.correlate(signal, cos) * 2.0 / lg
    #     b = np.correlate(signal, sin) * 2.0 / lg
    #     c = np.correlate(signal, dcos) * 2.0 / lg
    #     d = np.correlate(signal, dsin) * 2.0 / lg
    #     abcd = np.stack([a, b, c, d], axis=1)
    #     abmag = np.maximum(np.sum(np.square(abcd[:, :2]), axis=1), 1e-9)
    #     cdmag = np.sum(np.square(abcd[:, 2:]), axis=1)
    #     ratio = np.maximum(cdmag / abmag, 1.0)
    #     abcd[:, 2:] /= ratio[:, None]
    #     sss[fdx, :(tl - lg + 1)] = np.correlate(sqsig, np.ones(lg))  # sum of square signal
    #     # can derive this expression, all terms of sse (acos(wt)+bsin(wt)... - sig)^2
    #     # all come out to be 0 or -N/2 * the square of (a or b or c or d)
    #     part_sse = lg * 0.5 * np.sum(np.square(abcd), axis=1)
    #     # sse = sss - part_sse, r2 = 1 - sse / sss -> simplified it
    #     # only works if we assume zero mean, which I do in the fit.
    #     r2[fdx, :(tl - lg + 1)] = part_sse / sss[fdx, :(tl - lg + 1)]
    #     coeffs[fdx, :(tl - lg + 1)] = abcd
    # best_freqs = np.argmax(correlate1d(r2, np.ones(75), mode="nearest"), axis=0)
    
#     m = 1.0
#     k_0 = int(0.5 * (m * fs / freqs[0]))
#     k_m1 = int(0.5 * (m * fs / freqs[-1]))
#     k_freqs = (np.arange(k_0, k_m1 + 1) * 2.0 + 1.0)
#     s_ab = np.zeros((k_freqs.size, tl, 2))
#     s_r2 = np.zeros((k_freqs.size, tl))
#     for kdx in range(k_m1 - k_0 + 1):
#         k = kdx + k_0
#         lg = 2 * k + 1
#         rng = np.arange(lg) - k
#         cos = np.cos(2.0 * np.pi * rng / lg)
#         sin = np.sin(2.0 * np.pi * rng / lg)
#         dcos = np.cos(4.0 * np.pi * rng / lg)
#         dsin = np.sin(4.0 * np.pi * rng / lg)
#         a = np.correlate(signal, cos, "same") * 2.0 / lg
#         b = np.correlate(signal, sin, "same") * 2.0 / lg
#         c = np.correlate(signal, dcos, "same") * 2.0 / lg
#         d = np.correlate(signal, dsin, "same") * 2.0 / lg
#         abcd = np.stack([a, b, c, d], axis=1)
#         abmag = np.maximum(np.sum(np.square(abcd[:, :2]), axis=1), 1e-9)
#         cdmag = np.sum(np.square(abcd[:, 2:]), axis=1)
#         ratio = np.maximum(cdmag / abmag, 1.0)
#         abcd[:, 2:] /= ratio[:, None]
#         sss = np.correlate(sqsig, np.ones(lg), "same")  # sum of square signal
#         # can derive this expression, all terms of sse (acos(wt)+bsin(wt)... - sig)^2
#         # all come out to be 0 or -N/2 * the square of (a or b or c or d)
#         part_sse = lg * 0.5 * np.sum(np.square(abcd), axis=1)
#         # sse = sss - part_sse, r2 = 1 - sse / sss -> simplified it
#         # only works if we assume zero mean, which I do in the fit.
#         s_r2[kdx] = part_sse / sss
#         s_ab[kdx] = abcd[:, :2]
#     best_ks = np.argmax(s_r2, axis=0)[None,:,None]
#     ab = np.take_along_axis(s_ab, best_ks, axis=0).squeeze()
#     invert = np.zeros((*ab.shape, 2))
#     invert[:, 0, :] = ab
#     invert[:, 1, 0] = -ab[:, 1]
#     invert[:, 1, 1] = ab[:, 0]
#     invert /= linalg.norm(ab, axis=1).clip(min=1e-9)[:, None, None]
#     rotated = np.matmul(invert[:-1], ab[1:, :, None]).squeeze()
#     phase_diffs = np.arctan2(rotated[:, 1], rotated[:, 0])
    
#     _, ax = plt.subplots(1, 1)
#     ax.plot(phase_diffs * 1e-3)
#     ax.plot(ab, linestyle=":")
#     ax.plot(signal, c="k")
#     ax.plot(np.arctan2(ab[:, 1], ab[:, 0]) * 1e-3, c="m", alpha=0.1)
#     plt.show()
    
    # doing this again, but with weighted least squares - 1.5 wavelengths, but weighted by exp(-0.5 * (2x/lg)^2)
    # this effectively shrinks the wings down (and the function to fit), making errors there count for less
    # the weight matrix is less simple, but the math is not much more complicated
    m = 2.0
    N = 50
    k_0 = int(0.5 * m * fs / freqs[0])
    k_m1 = int(0.5 * m * fs / freqs[-1])
    k_exact = np.linspace(k_0, k_m1, N)
    w_ab = np.zeros((N, tl, 2))
    ssws = np.zeros((N, tl))
    swbfms = np.zeros((N, tl))
    sqsws = np.zeros((N, tl))
    swdp = np.zeros((N, tl))
    w_r2 = np.zeros((N, tl))
    for kdx in trange(N):
        ke = k_exact[kdx]
        lge = 2.0 * ke + 1.0
        k = round(ke)
        lg = 2 * k + 1
        rng = np.arange(lg) - k
        cos = np.cos(2.0 * m * np.pi * rng / lge)
        sin = np.sin(2.0 * m * np.pi * rng / lge)
        dcos = np.cos(4.0 * m * np.pi * rng / lge)
        dsin = np.sin(4.0 * m * np.pi * rng / lge)
        exp = np.exp(-0.5 * np.square(3.0 * rng / ke))
        sqrt_exp = np.exp(-0.25 * np.square(3.0 * rng / ke))
        A_W = np.square(cos) * exp
        C_W = np.square(sin) * exp
        E_W = np.square(dcos) * exp
        F_W = np.square(dsin) * exp
        A = np.sum(A_W)
        B = np.sum(cos * dcos * exp)
        C = np.sum(C_W)
        D = np.sum(sin * dsin * exp)
        E = np.sum(E_W)
        F = np.sum(F_W)
        G = B ** 2 - A * E
        H = D ** 2 - C * F
        w = np.correlate(signal, cos * exp, "same")
        x = np.correlate(signal, sin * exp, "same")
        y = np.correlate(signal, dcos * exp, "same")
        z = np.correlate(signal, dsin * exp, "same")
        wxyz = np.stack([w, x, y, z], axis=1)
        # the solution, (X'T X')^-1 * X'T * y' - X'T X' is a 4x4 with a simple inverse:
        a = (B / G) * y - (E / G) * w
        b = (D / H) * z - (F / H) * x
        c = (B / G) * w - (A / G) * y
        d = (D / H) * x - (C / H) * z
        abcd = np.stack([a, b, c, d], axis=1)
        abmag = np.maximum(np.sum(np.square(abcd[:, :2]), axis=1), 1e-9)
        cdmag = np.sum(np.square(abcd[:, 2:]), axis=1)
        # ratio = np.maximum(cdmag / abmag * 100, 1.0)
        # sum of square weighted signal - (y')^2 = exp * (sig)^2
        ssws = np.correlate(sqsig, exp, "same")
        # sum of weighted best fit model, squared (X' b)^2 = exp * (cos...dsin)^2 * abcd^2
        swbfms = np.square(a) * A
        swbfms += np.square(b) * C
        swbfms += np.square(c) * E
        swbfms += np.square(d) * F
        # squared sum of whitened signal - see wiki article on weighted least squares in diagonal weighting case
        # this part is for adjusting the average - because of the weighting, can no longer assume mean = 0
        # cross term from (y' - avg(y'))^2 turns out to be -2 times the square of the average
        # -1/lg * (y * sqrt(exp))
        sqsws = np.square(np.correlate(signal, sqrt_exp / np.sqrt(lg), "same"))
        denom = ssws - sqsws
        # abcd[:, 2:] /= ratio[:, None]
        # sum of weighted dot product - cross term from (y' - X' b)^2
        # -2 * y' X' b = -2 * exp * sig * (cos...dsin) * abcd = -2 * wxyz * abcd
        swdp = np.sum(wxyz * abcd, axis=1)
        w_r2[kdx] = (2.0 * swdp - swbfms - sqsws) / denom
        w_ab[kdx] = abcd[:, :2]
        
    pbar = tqdm(total=8)
    pbar.set_description("finding best k values   ")
    # empirically, this blur is about 6 breaths long
    nbl = 37 * 3
    wr2_blur = correlate1d(w_r2, np.ones(nbl) / nbl, mode="nearest")
    k_best_spline = cspline1d(np.argmax(wr2_blur, axis=0).astype(float), 1.0)
    k_best_sm = np.around(k_best_spline).astype(int).clip(0, N - 1)
    best_ks = k_best_sm[None,:,None]
    pbar.update(1)
    pbar.set_description("creating rotation matrix")
    # ab is how  well cos and sin fit, respectively. From this we can determine phase.
    ab = np.take_along_axis(w_ab, best_ks, axis=0).squeeze()
    # start constructing a rotation matrix to rotate points by the inverse of that phase
    invert = np.zeros((*ab.shape, 2))
    invert[:, 0, :] = ab
    invert[:, 1, 0] = -ab[:, 1]
    invert[:, 1, 1] = ab[:, 0]
    invert /= linalg.norm(ab, axis=1).clip(min=1e-9)[:, None, None]
    # multiplying ab by invert with an offset of 1 gets the sin and cos components, but rotated
    pbar.update(1)
    pbar.set_description("calculating phase diffs ")
    rotated = np.matmul(invert[:-1], ab[1:, :, None]).squeeze()
    u_bnd = 2 * m * np.pi / (2 * k_0 + 1.0)
    l_bnd = 1.25 * m * np.pi / (2 * k_m1 + 1.0)
    # finally, grab the rotated sin and cos components and use arctan2 to get the phase difference
    # my math seems to have been wrong, so we need to multipy by -1
    phase_diffs = -np.arctan2(rotated[:, 1], rotated[:, 0])
    # phase_diffs = np.where(phase_diffs <= 0.0, phase_diffs + 2.0 * np.pi, phase_diffs)
    cumulative_phase = np.cumsum(phase_diffs.clip(l_bnd, u_bnd))
    
    pbar.update(1)
    pbar.set_description("estimating frequencies  ")
    mean_pd = np.mean(phase_diffs[np.logical_and(phase_diffs < u_bnd, phase_diffs > l_bnd)])
    avg_phase_offset = round(4.0 * np.pi / mean_pd) # how long would the mean phase differenc take to sum to 2 wavelengths?
    
    # provide an estimate of the instantaneous frequencies
    est_inst_freqs = np.zeros(tl)
    est_inst_freqs[:-1] = phase_diffs * (fs / (2.0 * np.pi))
    est_inst_freqs[-1] = est_inst_freqs[-2]
    
    # provide the second estimate of frequencies
    smooth_k_freqs = 0.5 * m * fs / (k_best_spline * (k_m1 - k_0) / N + k_0)
    
    pbar.update(1)
    pbar.set_description("dynamically warping time")
    # apply phase warp, find the error and then the r^2 value
    interp = np.interp(cumulative_phase - 2.0 * np.pi, cumulative_phase, np.arange(cumulative_phase.size))
    x_max = int(np.max(interp))
    resamp = np.interp(np.arange(1, x_max), interp, signal[:-1])
    err = signal[1:x_max] - resamp
    sse = np.correlate(np.square(err), np.ones(avg_phase_offset) / avg_phase_offset, mode="full")
    sst = np.correlate(sqsig[1:x_max], np.ones(avg_phase_offset) / avg_phase_offset, mode="full")
    r2_fast = 1.0 - sse / sst
    
    # go the other way, interpolate from the phase warp to the original
    pbar.update(1)
    pbar.set_description("reversing the warp      ")
    rev_interp = np.interp(interp, np.arange(tl), signal)
    
    rev_err = signal[:-1] - rev_interp
    rev_sse = np.correlate(np.square(rev_err), np.ones(avg_phase_offset) / avg_phase_offset, mode="full")
    rev_sst = np.correlate(sqsig[:-1], np.ones(avg_phase_offset) / avg_phase_offset, mode="full")
    rev_r2_fast = 1.0 - rev_sse / rev_sst
    
    gsn = [1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0]
    gsn /= np.sum(gsn)
    
    pbar.update(1)
    pbar.set_description("calculating r^2         ")
    r2_max_fast = np.maximum(r2_fast[:1 - avg_phase_offset], r2_fast[avg_phase_offset - 1:])
    rev_r2_max_fast = np.maximum(rev_r2_fast[:1 - avg_phase_offset], rev_r2_fast[avg_phase_offset - 1:])
    fwd_r2 = np.correlate(np.concatenate((np.ones(4) * r2_max_fast[0], r2_max_fast, np.ones(tl - x_max + 3) * r2_max_fast[-1])), gsn, mode="same").clip(0.0)[3:-3]
    rev_r2 = np.correlate(np.concatenate((np.ones(3) * rev_r2_max_fast[0], rev_r2_max_fast, np.ones(4) * rev_r2_max_fast[-1])), gsn, mode="same").clip(0.0)[3:-3]
    
    skew_win = 8 * avg_phase_offset + 1
    r2_mu = np.correlate(r2_max_fast, np.ones(skew_win) / skew_win, mode="full")
    r2_mu = np.maximum(r2_mu[:1-skew_win], r2_mu[skew_win - 1:])
    # r2_sig = np.sqrt(np.correlate(np.square(r2_max_fast - r2_mu), np.ones(skew_win) / skew_win, mode="same")).clip(1e-7)
    # r2_mncub = np.correlate(np.square(r2_max_fast) * r2_max_fast, np.ones(skew_win) / skew_win, mode="same")
    # r2_skew = (r2_mncub - 3.0 * r2_mu * np.square(r2_sig) - np.square(r2_mu) * r2_mu) # / (np.square(r2_sig) * r2_sig)
    # r2_med = np.zeros(tl)
    # wk = 8 * avg_phase_offset
    
    # median filter - if the median is less than the segments marked periodic, it's probably the middle of a long period
    # of movement, or a period of being out of bed.
    # Anything less than 1 minute is likely to be filtered out by the median filter. This can be changed.
    periodic_r2 = np.maximum(fwd_r2, rev_r2)
    periodic_r2[0] = 0.0
    periodic_r2[-1] = 0.0
    r2_med = medfilt(periodic_r2, skew_win)
    
    pbar.update(1)
    pbar.set_description("getting best sinusoid r2")
    r2_sin = np.take_along_axis(w_r2, best_ks.reshape(1, -1), axis=0).squeeze()
    
    pbar.update(1)
    pbar.set_description("Done. Plotting..........")
    pbar.close()
    print(mean_pd, avg_phase_offset)
    
    # _, ax = plt.subplots(1, 1)
    # ax.plot(rng, signal[rng + x_igt])
    # ax.plot(interpd, signal[rng + x_nxt])
    # ax.plot(rng, signal[rng + x_nxt], "k:")
    # ax.plot(rng, (cumulative_phase[rng + x_igt] - ph_igt) * 2e-5)
    # ax.plot(rng, (cumulative_phase[rng + x_nxt] - 2.0 * np.pi - ph_igt) * 2e-5)
    # plt.show()
    
    # _, ax = plt.subplots(1, 1)
    # ax.plot(lrng, cumulative_phase[lrng + x_igt])
    # ax.plot(rng, cumulative_phase[rng + x_nxt] - 2.0 * np.pi)
    # ax.plot(interpd, cumulative_phase[rng + x_nxt] - 2.0 * np.pi, "k:")
    # plt.show()
    
    
    _, ax = plt.subplots(1, 1)
    ax.plot(np.arange(cumulative_phase.size), signal[:-1])
    ax.plot(interp, signal[:-1], linestyle=":")
    ax.plot(np.arange(1, x_max), err, alpha=0.1)
    # ax.plot(np.arange(1, x_max), np.maximum(r2_max_fast, rev_r2_max_fast[1:x_max]) * 1e-3)
    ax.plot(np.arange(tl), fwd_r2 * 1e-3, linestyle=":")
    ax.plot(np.arange(tl), rev_r2 * 1e-3, linestyle=":")
    ax.plot(np.arange(tl), r2_med * 1e-3, alpha=0.2)
    ax.plot(np.arange(1, x_max), r2_mu * 1e-3, alpha=0.2)
    ax.plot(r2_sin * 1e-3, alpha=0.1)
    plt.show()
    
    # _, ax = plt.subplots(1, 1)
    # ax.plot(cos)
    # ax.plot(sin)
    # ax.plot(cos * np.sqrt(exp))
    # ax.plot(sin * np.sqrt(exp))
    # plt.show()
    
    # _, ax = plt.subplots(1, 1)
    # ax.plot(phase_diffs * 1e-3)
    # ax.plot(ab, linestyle=":", alpha=0.1)
    # ax.plot(signal, c="k", alpha=0.1)
    # ax.plot(np.arctan2(ab[:, 1], ab[:, 0]) * 1e-3, c="r", alpha=0.1)
    # ax.axhline(u_bnd * 1e-3)
    # ax.axhline(l_bnd * 1e-3)
    # plt.show()
    
    # _, ax = plt.subplots(1, 1)
    # ax.imshow(wr2_blur.clip(min=0.0), aspect="auto")
    # ax.plot(best_ks.squeeze())
    # plt.show()
    
    # _, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True)
    # ax1.imshow(ssws[:, 12000:13000], aspect="auto")
    # ax2.imshow(swbfms[:, 12000:13000], aspect="auto")
    # ax3.imshow((2.0 * swdp[:, 12000:13000] - swbfms[:, 12000:13000] - sqsws[:, 12000:13000]), aspect="auto")
    # ax4.imshow(swdp[:, 12000:13000], aspect="auto")
    # plt.show()
    
#     final_r2 = np.zeros((tl, ))
#     final_lg = np.zeros((tl, ))
#     sigma = 2.0  # this gets us to 12 timesteps away... meaning we can span ~5 sec gaps
#     fs = 2.0
#     n_std = 6  # 6 gets us down to 1.5e-8, which is nonzero enough for me
#     c_normal = np.exp(-0.5 * np.square((np.arange(fs * sigma * n_std * 2 + 1) - (fs * sigma * n_std)) / sigma))
#     c_normal /= np.sum(c_normal)
#     for idx, fdx in trange(0):
# #     for idx, fdx in tqdm(enumerate(best_freqs), total=best_freqs.size, ncols=100):
#         offset = round(lgs[fdx] / m)
#         lg = int(lgs[fdx])
#         final_lg[idx] = lg
#         if idx + offset + lg >= tl:
#             break

#         # calculate dynamic time warp alignment
#         alignment = dtw(signal[idx + offset:idx + lg + offset], signal[idx:idx + lg], keep_internals=True)
#         # alignment.plot(type="twoway", offset=-2 * np.sqrt(np.sum(np.square(coeffs[fdx, idx]))))
#         # this will necessarily be a function even in the worst case
#         new_xs = 0.5 * (alignment.index1 + alignment.index2)
#         new_ys = alignment.index2 - alignment.index1
#         # interpolate it, then correlate, rotate it back, then perform the warp
#         # cut off the first and last bit, helping to guarantee it's in the interp range
#         ixs = np.arange(lg * 2 - 5) * 0.5 + 1.0
#         # 1-1e-8 is a bit of a kludge to get back a function but it's fine innit
#         iys = interp1d(new_xs, new_ys * (1.0 - 1e-8))(ixs)
#         max_dev = 5.0  # can be up to 1 second off
#         miys = np.mean(iys).clip(min=-max_dev, max=max_dev)
#         siys = iys.clip(miys - max_dev, miys + max_dev)
#         spys = correlate1d(siys, c_normal, mode="constant")
#         rix = ixs - 0.5 * spys
#         riy = ixs + 0.5 * spys
#         sse_dtw = np.sum(np.square(spys)) * 16.0  # abs(abs(x-30)-15)-7.5 -> r^2 = 0
#         sst_dtw = np.sum(np.square(ixs - np.mean(ixs)))
#         qys = interp1d(np.arange(lg), signal[idx + offset:idx + lg + offset])(rix)
#         tys = interp1d(np.arange(lg), signal[idx:idx + lg])(riy)
#         sse_fin = np.sum(np.square(qys - tys))
#         sst_fin = np.sum(np.square(tys - np.mean(tys)))
#         cr2 = 1.0 - sse_dtw / sst_dtw - sse_fin / sst_fin

#         if idx == 130:
#             _, (ax1, ax2) = plt.subplots(1, 2)
#             ax1.plot(rix, riy)
#             ax1.plot(alignment.index1, alignment.index2)
#             ax2.plot(np.arange(lg), signal[idx + offset:idx + lg + offset])
#             ax2.plot(riy, qys)
#             ax2.plot(riy, tys)
#             plt.show()

        # sse = np.sum(np.square(signal[idx:idx + lg] - signal[idx + offset:idx + lg + offset]))
        # cr2 = 1.0 - sse / max(sss[fdx, idx], sss[fdx, idx + offset])
        # final_r2[idx:idx + offset + lg] = np.maximum(final_r2[idx:idx + offset + lg], cr2)
    return est_inst_freqs, smooth_k_freqs, periodic_r2, ab, r2_med

def split_test(sig_a, sig_b, coords_a, coords_b, output=False):
    # prepare X matrix
    n_a = sig_a.size
    n_b = sig_b.size
    all_samples = np.concatenate((sig_a, sig_b))
    all_coords = np.concatenate((coords_a, coords_b), dtype=float)
    all_coords -= np.mean(all_coords)
    split_level = np.concatenate((np.full(n_a, -0.5), np.full(n_b, 0.5)))
    ones = np.ones(n_a + n_b)
    X = np.stack((all_coords, split_level, ones), axis=-1)
    
    # perform the fit, get the residuals
    xcov = np.linalg.inv(np.dot(X.T, X))
    beta = np.dot(np.dot(xcov, X.T), all_samples)
    res = all_samples - np.dot(X, beta)
    
    # calculate covariance - currently using estimate of sig
    sig2_est = 0.2 ** 2 # np.sum(np.square(res)) / (n_a + n_b - 3)
    cov = sig2_est * xcov
    
    # calculate statistic and probability
    T = np.abs(beta[1] / np.sqrt(cov[1, 1]))
    p = 2.0 - 2.0 * t_dist.cdf(T, n_a + n_b - 3) # two tailed
    
    if output:
        print(cov)
        print(beta)
        print(T, p)
        print(n_a, n_b)
        print(sig2_est)
        print(res)
        
        _, ax = plt.subplots(1, 1)
        ax.plot(np.concatenate((coords_a, coords_b), dtype=float), all_samples)
        ax.plot(np.concatenate((coords_a, coords_b), dtype=float), np.dot(X, beta))
        plt.show()
    
    return T, p

def get_mse(sig, winsize=75):
    c0 = correlate(np.square(sig), np.ones(winsize)) / winsize
    return np.minimum(c0[:1 - winsize], c0[winsize - 1:])

def get_mse_fracs(sig, winsize=75):
    mse0 = get_mse(sig[:, 0], winsize)
    mse1 = get_mse(sig[:, 1], winsize)
    return mse0 / (mse0 + mse1), mse1 / (mse0 + mse1)

def get_cdf_diff(residuals_1, residuals_2, plot=False):
    ord1 = np.sort(residuals_1)
    ord2 = np.sort(residuals_2)
    n = residuals_1.size
    m = residuals_2.size
    start = 0
    if ord1[0] < ord2[0]:
        start = np.max(np.argwhere(ord1 < ord2[0]).squeeze())
        sup = (start + 1) / n
#         print(1, sup)
    elif ord2[0] < ord1[0]:
        sup = np.max(np.argwhere(ord2 < ord1[0]).squeeze()) / m
#         print(2, sup)
    else:
        sup = 0.0
    
    for idx in range(start, n-1):
        lw = ord2 <= ord1[idx]
        up = ord2 < ord1[idx + 1]
        if np.count_nonzero(up) == 0:
            # none of them are less than the upper bound
            sup = max(sup, (idx + 1) / n)
#             print(3, sup)
        else:
            if np.count_nonzero(lw) == 0:
                ldx = 0
            else:
                ldx = np.max(np.argwhere(lw).squeeze())
            udx = np.max(np.argwhere(up).squeeze())
            for jdx in range(ldx, udx + 1):
                sup = max(sup, abs((idx +  1) / n - (jdx + 1) / m))
    if ord1[-1] < ord2[-1]:
        lw = np.max(np.argwhere(ord2 <= ord1[-1]).squeeze()) / m
        sup = max(sup, 1.0 - lw)
#         print(5, sup, lw)
    if plot:
        _, ax = plt.subplots(1, 1)
        o1e = np.concatenate((np.ones(1) * ord1[0] - 0.15, np.repeat(ord1, 2), np.ones(1) * ord1[-1] + 0.15))
        o2e = np.concatenate((np.ones(1) * ord2[0] - 0.15, np.repeat(ord2, 2), np.ones(1) * ord2[-1] + 0.15))
        ax.plot(o1e, np.repeat(np.arange(n + 1) / n, 2))
        ax.plot(o2e, np.repeat(np.arange(m + 1) / m, 2))
        plt.show()
#     print(sup)
    return sup

def fit_line(samples, coords):
    A = np.stack((coords, np.ones(coords.shape)), axis=-1)
    pi = np.linalg.pinv(A)
    return np.dot(pi, samples)

def get_residuals(samples, coords, mb):
    A = np.stack((coords, np.ones(coords.shape)), axis=-1)
    pred = np.dot(A, mb)
    return samples - pred

# Kolmogorov-Smirnov test - are sample_a and sample_b continuous?
# Specifically, what is the probability that a and b are fit by a single trendline just as well as by two separate ones
# Returns the likelihood that a and b are continuous
def ks_test(sample_a, sample_b, coords_a, coords_b, plot=False):
    # concatenate y and t
    all_samples = np.concatenate((sample_a, sample_b))
    all_coords = np.concatenate((coords_a, coords_b))
    # fit lines to the first samples, the second samples, and the concatenated samples
    mba = fit_line(sample_a, coords_a)
    mbb = fit_line(sample_b, coords_b)
    mbt = fit_line(all_samples, all_coords)
    # find the residuals for the above fits
    res_a = get_residuals(sample_a, coords_a, mba)
    res_b = get_residuals(sample_b, coords_b, mbb)
    res_t = get_residuals(all_samples, all_coords, mbt)
    # find the maximum cdf difference in the distribution of residuals
    cdfd_t = get_cdf_diff(np.concatenate((res_a, res_b)), res_t, plot)
    n = all_samples.size
    # calculate the probability that they are continuous (lower means less likely to be continuous)
    p = 2.0 * np.exp(-n * cdfd_t**2)
    return min(0.999, p) # 0.999 - important for bayesian updates


In [3]:

# ricky = ricker(101, 10)
# morry = morlet2(101, 12, 1.7) * 2743 / 2168
# _, ax = plt.subplots(1, 1)
# ax.plot(ricky, c="r")
# ax.plot(morry, c="k")
# plt.show()

# %%
# 211107_015105 -- this is a good one
# 211102_003909 -- another goodie. Lots of sleep, some time not in bed at the beginning.
#               -- No wake up though.
# 211101_002730 -- Excellent. 5 Sleep cycles visible. One spot not flipped right.
######
# the above ones are all old
# 220103_232249 - this one has a long gap out of bed, very disturbed sleep, but a few deep sleep and REM blobs
# 220105_005821 - pretty solid sleep, clear sleep cycles visible relatively evenly spaced, more deep sleep at the start and more REM at the end.
# 
dt = "220119_010917"

gl = sorted(glob(f"sleepypi/run{dt}/*.pkl.gz"))

streams = []
times = []

# get timezone offset for local START of night, so that DST is handled appropriately
uctdiff = datetime.strptime(dt, "%y%m%d_%H%M%S").astimezone().utcoffset()
tzoffset = (uctdiff.days * 86400 + uctdiff.seconds) * 1000  # timezone offset from utc

for idx in trange(len(gl)):
    with gzip.open(gl[idx], "rb") as f:
        p = pickle.load(f)
        data_stream, fhist, gz, fri, tstamps, video = p
        streams.append(data_stream)
        times.append(tstamps.astype(np.int64) * 50 + 1609459200000 + tzoffset)
#         with VideoWriter(gl[idx][:-6] + "mp4") as vid:
#             vid.from_array(video)

n = np.concatenate(streams, axis=0)
# convert times back to epoch time in milliseconds, then to np.datetime64 in ms
timestamps = np.concatenate(times, axis=0).astype("<M8[ms]")
print(n.shape)

plt.style.use('dark_background')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:05<00:00,  6.65it/s]

(161815, 10)





In [4]:

# rms0 = np.sqrt(get_mse(n[:, 0]))
# rms1 = np.sqrt(get_mse(n[:, 1]))
# rms_comb = np.sqrt(np.square(rms0) + np.square(rms1))

# sig0, sm0 = adjust_for_zero_trend(n[:, 0])
# sig1, sm1 = adjust_for_zero_trend(n[:, 1])
# # sig0 = n[:, 0]
# # sig1 = n[:, 1]


In [5]:
plt.close("all")
omega = 10.0
fs = 5.0
freqs = np.logspace(0.1, -1.4, 150)  # indices ~50-85 are breathing frequencies
# widths_morlet = omega * fs / (freqs * 2 * np.pi)
# z_wave = wavefinding_cwt(np.cumsum(sig0 + 1.0j * sig1), widths_morlet, omega)
# mags_z = np.abs(z_wave) / rms_comb.clip(1e-5)

# _, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True)
# # ax1.plot(sig0)
# # ax1.plot(sig1)
# ax1.plot(n[:, 0])
# ax1.plot(n[:, 1])
# ax1.plot(rms0, linestyle=":")
# ax1.plot(rms1, linestyle=":")
# # ax1.plot(sm0, linestyle=":")
# # ax1.plot(sm1, linestyle=":")
# ax2.plot(np.cumsum(sig0))
# ax2.plot(np.cumsum(sig1))
# ax3.imshow(mags_z.clip(max=np.percentile(mags_z, 98.0)), aspect="auto")
# plt.show()

In [6]:
plt.close("all")

sm = 81
pos_x = n[:, 4]
sm_pos_x = medfilt(pos_x, sm)
pos_y = n[:, 5]
sm_pos_y = medfilt(pos_y, sm)

mx_lead, mx_lag, bx_lead, bx_lag, rmsx_lead, rmsx_lag, r2x_lead, r2x_lag = fit_linear_lead_and_lag(sm_pos_x, winsize=51)
my_lead, my_lag, by_lead, by_lag, rmsy_lead, rmsy_lag, r2y_lead, r2y_lag = fit_linear_lead_and_lag(sm_pos_y, winsize=51)

k = 40
a = k * (k + 1) * (2 * k + 1) * (3 * k ** 2 + 3 * k + 1) / 15.0
b = k * (k + 1) * (2 * k + 1) / 3.0
c = 2 * k + 1
det = a * c - b ** 2
xs = np.arange(c, dtype=float) - k

xstx0 = correlate1d(sm_pos_x, np.square(xs), mode="nearest")
xstx1 = correlate1d(sm_pos_x, xs, mode="nearest")
xstx2 = correlate1d(sm_pos_x, np.ones(c), mode="nearest")
xsty0 = correlate1d(sm_pos_y, np.square(xs), mode="nearest")
xsty1 = correlate1d(sm_pos_y, xs, mode="nearest")
xsty2 = correlate1d(sm_pos_y, np.ones(c), mode="nearest")

crv_x = xstx0 * c / det - xstx2 * b / det
crv_y = xsty0 * c / det - xsty2 * b / det

slp_x = xstx1 / b
slp_y = xsty1 / b

se = 2.0 * np.sqrt(c / det)
ses = 8.0 / np.sqrt(b)

p_x = 2.0 * t_dist.cdf(np.abs(crv_x) / se, c - 3) - 1.0
p_y = 2.0 * t_dist.cdf(np.abs(crv_y) / se, c - 3) - 1.0

ps_x = 2.0 * t_dist.cdf(np.abs(slp_x) / ses, c - 3) - 1.0
ps_y = 2.0 * t_dist.cdf(np.abs(slp_y) / ses, c - 3) - 1.0

rmsx = np.minimum(rmsx_lead, rmsx_lag)
rmsy = np.minimum(rmsy_lead, rmsy_lag)

# fit the step function, just see how it goes...

m = 401 # should be odd
# first 0.5 is for half the step size (one half is negative the other is positive)
# 4 / 2m is the coefficient in the matrix inverse, so we're still missing a factor of
inv_coeff = 0.5 * 2.0 / m
stfn = np.concatenate((np.full(m, -inv_coeff), np.zeros(m), np.full(m, inv_coeff)))
step_x = correlate1d(sm_pos_x, stfn, mode="nearest")
step_y = correlate1d(sm_pos_y, stfn, mode="nearest")

sest = 50.0 * np.sqrt(2.0 / m)

pst_x = 2.0 * t_dist.cdf(np.abs(step_x) / sest, 2 * m - 2) - 1.0
pst_y = 2.0 * t_dist.cdf(np.abs(step_y) / sest, 2 * m - 2) - 1.0

dx = np.abs(pos_x - sm_pos_x)
dy = np.abs(pos_y - sm_pos_y)

med_dx = np.median(dx)
med_dy = np.median(dy)
x_cond = dx > 12.0 * med_dx
y_cond = dy > 12.0 * med_dy
dx_extr = np.where(x_cond, 10.0, 0.0)
dy_extr = np.where(y_cond, 8.0, 0.0)

exp_k = 4
either = np.logical_or(x_cond, y_cond).astype(float)
expand = np.concatenate((np.full(exp_k, either[0]), either, np.full(exp_k, either[-1])))
exp_cs = np.cumsum(expand)
mov = exp_cs[2 * exp_k:] - exp_cs[:-2 * exp_k] > 0.5
mov_float = medfilt(mov.astype(float), 2 * exp_k + 1)

_, (ax, ax2, ax3) = plt.subplots(3, 1, sharex=True)
ax.plot(pos_x, label="nx")
ax.plot(pos_y, label="ny")
ax.plot(sm_pos_x, label="medx")
ax.plot(sm_pos_y, label="medy")
ax.plot(mov_float * 10.0, label="mov")
ax.plot(either * 8.0, label="either")
# ax2.plot(crv_x, label="crvx")
# ax2.plot(crv_y, label="crvy")
ax2.plot(step_x, label="x")
ax2.plot(step_y, label="y")
# ax3.plot(p_x, label="px", linestyle=":")
# ax3.plot(p_y, label="py", linestyle=":")
# ax3.plot(ps_x, label="psx")
# ax3.plot(ps_y, label="psy")
ax3.plot(pst_x, label="x")
ax3.plot(pst_y, label="y")
ax.legend()
ax2.legend()
ax3.legend()
plt.show()



In [13]:
plt.close("all")
r2_x, r2_y, lgs, freq_traces, r2_sin, p_notinbed, p_shift = find_periodic_fit_best(n[:, 0], n[:, 1], freqs[40:90], n[:, 4], n[:, 5], 3001, 301)
_, ax = plt.subplots(1, 1)
ax.plot(pos_x, label="nx")
ax.plot(pos_y, label="ny")
ax.plot(sm_pos_x, label="medx")
ax.plot(sm_pos_y, label="medy")
ax.plot(p_shift * 10.0, label="shift", alpha=0.5, linestyle=(0,(1,1,3,1)))
ax.legend()
plt.show()

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 17.83it/s]
Done. Plotting..........: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.95it/s]


0.32682838902755607 38


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 19.18it/s]
Done. Plotting..........: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.94it/s]


0.3310444677134384 38
0 10 132 745 0.014442281543560309 0.5514590012874816
10 13 745 1130 0.0019042894259631016 1.8981261049217697e-05
13 15 1130 1687 0.17685102687725207 5.461469737166765e-08
15 16 1687 2747 0.4873163476080391 0.9657448788258233
16 20 2747 3133 0.7412770189317792 0.9021820974918648
20 22 3133 4100 0.6201431718295891 0.4771955461016811
22 23 4100 4561 0.799459932329911 0.9651568686913149
23 36 4561 6111 0.3724845369941893 0.12372086016603634
36 38 6111 6393 0.9905858587406322 0.031295638757127264
38 40 6393 6931 0.9384429104769959 0.8049980163047269
40 41 6931 8162 0.7081660380150334 0.9757123140904198
41 42 8162 9185 0.9154878354929528 0.6577133184065
42 43 9185 9912 1.4561903016740985e-08 7.380929688034026e-36
43 45 9912 10872 0.9563862780835134 0.8362832223681058
45 46 10872 11313 0.6905532832292556 0.6268903922886249
46 49 11313 14253 0.22556263577825175 8.399658536443871e-05
49 51 14253 18616 6.865992508753305e-05 7.4687360424790345e-47
51 52 18616 18839 0.0004280

In [99]:
plt.close("all")
x_env = np.abs(hilbert(n[:, 0]))
y_env = np.abs(hilbert(n[:, 1]))

fx, fy = get_mse_fracs(n)

cutoff = 0.45

r2 = r2_x * fx + r2_y * fy
count = np.sum(r2 >= cutoff)
rmses = np.sqrt(np.sum((r2 >= cutoff).astype(float) * (np.square(n[:, 0]) + np.square(n[:, 1]))) / count)
condit = np.logical_and(r2 >= cutoff, np.logical_and(np.abs(n[:, 0]) < rms0 * 3.5, np.abs(n[:, 1]) < rms1 * 3.5))
starts = np.argwhere(np.logical_and(condit[1:], np.logical_not(condit[:-1]))).squeeze()
stops = np.argwhere(np.logical_and(np.logical_not(condit[1:]), condit[:-1])).squeeze()
lrg_cond = np.argwhere(stops - starts > 35).squeeze()
starts = starts[lrg_cond]
stops = stops[lrg_cond]
sm_cond = np.argwhere(starts[1:] - stops[:-1] > 2).squeeze()
starts = starts[np.concatenate((np.array([0]), sm_cond + 1))]
stops = stops[np.concatenate((sm_cond, np.array([stops.size - 1])))]
    

keep_start = np.ones(starts.size, dtype=bool)
keep_stop = np.ones(starts.size, dtype=bool)
mov_or_sigh = np.zeros(starts.size - 1)

# first trim the ones that are too short (<=1.2s), and likely not movement or sighing
# note, this will merge a lot of segments within periods of being out-of-bed
# we'll remedy that later...
for idx in range(starts.size - 1):
    # get the lengths of the segments
    lg_dur = starts[idx + 1] - stops[idx]
    short_pen = 0.5 + 0.5 * np.tanh(0.33 * lg_dur)
    rms_bef = np.sqrt(np.mean(np.square(n[starts[idx]:stops[idx], :2])))
    rms_dur = np.sqrt(np.mean(np.square(n[stops[idx]:starts[idx + 1], :2])))
    rms_aft = np.sqrt(np.mean(np.square(n[starts[idx + 1]:stops[idx + 1], :2])))
    ratio = (rms_dur * 2.0 / (rms_bef + rms_aft) - 1.0).clip(min=0.0)
    vel_stat = np.square(np.tanh(2.0 * ratio)) * short_pen
    mov_or_sigh[idx] = vel_stat
        

keep_start[1:] = mov_or_sigh > 0.5
keep_stop[:-1] = mov_or_sigh > 0.5
contig_starts = starts[keep_start]
contig_stops = stops[keep_stop]
discont = np.zeros(contig_starts.size - 1)

for idx in range(contig_starts.size - 1):
    # get the lengths of the segments
    lg_bef = contig_stops[idx] - contig_starts[idx]
    lg_dur = contig_starts[idx + 1] - contig_stops[idx]
    lg_aft = contig_stops[idx + 1] - contig_starts[idx + 1]
    # segments should not differ in length by TOO much, or else the KS test will fail
    use_lg = 2 * min(250, lg_bef, lg_aft)
    bef_b = contig_stops[idx] - min(use_lg, lg_bef)
    aft_e = contig_starts[idx + 1] + min(use_lg, lg_aft)
    bef_t = np.arange(bef_b, contig_stops[idx], dtype=float)
    aft_t = np.arange(contig_starts[idx + 1], aft_e, dtype=float)
    p_x = ks_test(n[bef_b:contig_stops[idx], 2], n[contig_starts[idx + 1]:aft_e, 2], bef_t, aft_t)
    p_y = ks_test(n[bef_b:contig_stops[idx], 3], n[contig_starts[idx + 1]:aft_e, 3], bef_t, aft_t)
    discont[idx] = 1.0 - p_x * p_y
    
# Okay, now we decide what corresponds to not being in bed
# Although, as a first pass, what we have here is already good enough!
# The non-periodic segments during out-of-bed periods are largely merged,
#  save for where I get out of and into bed
for idx in range(contig.size):
    p_oob = np.mean(p_notinbed[starts[idx]:stops[idx]])
    bef_b = 0 if idx == 0 else stops[idx - 1]
    aft_e = -1 if idx == starts.size - 1 else starts[idx + 1]
    rms_bef = np.sqrt(np.mean(np.square(n[bef_b:starts[idx], :2])) * 2.0)
    rms_dur = np.sqrt(np.mean(np.square(n[starts[idx]:stops[idx], :2])) * 2.0)
    rms_aft = np.sqrt(np.mean(np.square(n[stops[idx]:aft_e, :2])) * 2.0)
    # a few criteria to classify a segment as not in bed and delete it:
    #   1. p_oob has to be pretty low
    #   2. rms of the segment has to be pretty low

    
mx_lead, mx_lag, bx_lead, bx_lag, rmsx_lead, rmsx_lag = fit_linear_lead_and_lag(n[:, 2], winsize=155)
my_lead, my_lag, by_lead, by_lag, rmsy_lead, rmsy_lag = fit_linear_lead_and_lag(n[:, 3], winsize=155)
mw_lead, mw_lag, bw_lead, bw_lag, rmsw_lead, rmsw_lag = fit_linear_lead_and_lag(n[:, 4], winsize=31)
mz_lead, mz_lag, bz_lead, bz_lag, rmsz_lead, rmsz_lag = fit_linear_lead_and_lag(n[:, 5], winsize=31)

x_rms_frac = rmsx_lead / (rmsx_lead + rmsx_lag)
y_rms_frac = rmsy_lead / (rmsy_lead + rmsy_lag)
bx_best = x_rms_frac * bx_lag + (1.0 - x_rms_frac) * bx_lead # np.where(rmsx_lead < rmsx_lag, bx_lead, bx_lag)
by_best = y_rms_frac * by_lag + (1.0 - y_rms_frac) * by_lead # np.where(rmsy_lead < rmsy_lag, by_lead, by_lag)
w_rms_frac = rmsw_lead / (rmsw_lead + rmsw_lag)
z_rms_frac = rmsz_lead / (rmsz_lead + rmsz_lag)
bw_best = w_rms_frac * bw_lag + (1.0 - w_rms_frac) * bw_lead # np.where(rmsx_lead < rmsx_lag, bx_lead, bx_lag)
bz_best = z_rms_frac * bz_lag + (1.0 - z_rms_frac) * bz_lead # np.where(rmsy_lead < rmsy_lag, by_lead, by_lag)

b_best = np.stack((bx_best, by_best), axis=-1)

# res_x = rpt.KernelCPD(kernel="linear", min_size=100).fit(n[:, 2:4]).predict(pen=2)
# res_y = rpt.KernelCPD(kernel="linear", min_size=100).fit(n[:, 4:6]).predict(pen=25)
# res_x = rpt.Window(width=100, model="l2", jump=2).fit_predict(n[:, 2:4], pen=np.log(n.shape[0]) * 2 * 0.1**2)

_, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
ax1.plot(timestamps, my_lead)
ax1.plot(timestamps, my_lag)
ax1.plot(timestamps, bx_best)
ax1.plot(timestamps, by_best)
ax1.plot(timestamps, bw_best)
ax1.plot(timestamps, bz_best)
ax1.plot(timestamps, rmsy_lead.clip(max=10.0) * 20)
ax1.plot(timestamps, rmsy_lag.clip(max=10.0) * 20)
ax1.plot(timestamps, n[:, 2:4], alpha=0.6)
ax1.plot(timestamps, n[:, 4:6], alpha=0.2)
# ax1.legend("mylead,mylag,bxbest,bybest,mrsylead,rmsylag,n2,n3,n4,n5".split(","))
ax2.plot(timestamps, n[:, 0:2])
ax2.axhline(y=rmses * 10.0)
ax2.axhline(y=rmses * -10.0)
ax2.plot(timestamps, r2 * 1e-3, c="k", alpha=0.3, linestyle=":")
# ax2.plot(timestamps, x_env)
# ax2.plot(timestamps, y_env)

# keep = np.ones(len(res_x), dtype=bool)

# for idx in range(len(res_x) - 2):
#     xs = np.arange(res_x[idx], res_x[idx + 2])
#     res_a = ks_test(n[res_x[idx]: res_x[idx + 2], 2], xs, res_x[idx + 1] - res_x[idx], 0.01, 8)
#     res_b = ks_test(n[res_x[idx]: res_x[idx + 2], 3], xs, res_x[idx + 1] - res_x[idx], 0.01, 8)
#     if res_a and res_b:
#         keep[idx + 1] = False

# res_x = np.array(res_x)[keep]

# keep = np.ones(len(res_x), dtype=bool)
# for idx in range(len(res_x) - 2):
#     xs = np.arange(res_x[idx], res_x[idx + 2])
#     res_a = ks_test(n[res_x[idx]: res_x[idx + 2], 2], xs, res_x[idx + 1] - res_x[idx], 0.02, 8)
#     res_b = ks_test(n[res_x[idx]: res_x[idx + 2], 3], xs, res_x[idx + 1] - res_x[idx], 0.02, 8)
#     if res_a and res_b:
#         keep[idx + 1] = False

# res_x = np.array(res_x)[keep]

for idx, ln in enumerate(res_x[:-1]):
    ax1.axvline(x=timestamps[ln])
    frac = np.count_nonzero(condit[ln:res_x[idx + 1]]) / (res_x[idx + 1] - ln)
    if frac < 0.5:
        ax1.axvspan(timestamps[ln], timestamps[res_x[idx + 1]], color="r", alpha=0.2, ec=None)
# for idx, ln in enumerate(res_y[:-1]):
#     ax1.axvline(x=timestamps[ln])

for idx in range(1, starts.size):
    ts = timestamps[np.array((stops[idx - 1], starts[idx]))]
    ax2.plot(ts, 0.01 * np.array((mov_or_sigh[idx - 1], mov_or_sigh[idx - 1])), c="m")
for idx in range(contig_starts.size):
    ax2.axvspan(timestamps[contig_starts[idx]], timestamps[contig_stops[idx]], color="g", alpha=0.2)
    if idx > 0:
        ts = timestamps[np.array((contig_stops[idx - 1], contig_starts[idx]))]
        ax2.plot(ts, 0.01 * np.array((discont[idx - 1], discont[idx - 1])), c="k", linestyle=":")
    
# for ln in res_y[:-1]:
#     ax1.axvline(x=timestamps[ln], c="r")

plt.show()

rms_slide_0 = np.zeros_like(rms0)
rms_slide_1 = np.zeros_like(rms1)
rms_avg_0 = np.zeros_like(starts, dtype=float)
rms_avg_1 = np.zeros_like(starts, dtype=float)
windows = np.zeros_like(rms0)
windows_wide = np.zeros_like(rms0)
rms_avg_0[2] = 1.5
parity = np.zeros_like(rms0)

for idx in range(starts.size):
    lg = stops[idx] - starts[idx]
    win = 1.0 / (1.0 + np.exp((np.abs(np.arange(lg) - lg / 2.0 - 0.5) + 4.0 - lg / 2.0).clip(min=-10.0)))
    win_wide = 1.0 / (1.0 + np.exp((0.3 * np.abs(np.arange(lg) - lg / 2.0 - 0.5) + 5.0 - 0.3 * lg / 2.0).clip(min=-10.0)))
    win_lop = 1.0 / (1.0 + np.exp((-0.015 * np.arange(lg) + 5.0).clip(min=-10.0)))
    winsum = np.sum(win)
    if winsum == 0:
        print(idx, lg, starts[idx], stops[idx], starts.size)
    windows[starts[idx]:stops[idx]] = win
    windows_wide[starts[idx]:stops[idx]] = win_wide * win_lop
    rms_avg_0[idx] = np.sqrt(np.sum(np.square(n[starts[idx]:stops[idx], 0]) * win) / winsum)
    rms_avg_1[idx] = np.sqrt(np.sum(np.square(n[starts[idx]:stops[idx], 1]) * win) / winsum)
    parity[idx] = np.sign(np.sum(n[starts[idx]:stops[idx], 0] * n[starts[idx]:stops[idx], 1]))
    rms_slide_0[starts[idx]:stops[idx]] = rms_avg_0[idx]
    rms_slide_1[starts[idx]:stops[idx]] = rms_avg_1[idx]
    if idx == 0:
        rms_slide_0[:starts[0]] = rms_avg_0[0]
        rms_slide_1[:starts[0]] = rms_avg_1[0]
    else:
        lg = starts[idx] - stops[idx - 1]
        rms_slide_0[stops[idx - 1]:starts[idx]] = np.linspace(rms_avg_0[idx - 1], rms_avg_0[idx], lg + 1)[:-1]
        rms_slide_1[stops[idx - 1]:starts[idx]] = np.linspace(rms_avg_1[idx - 1], rms_avg_1[idx], lg + 1)[:-1]
rms_slide_0[stops[-1]:] = rms_avg_0[-1]
rms_slide_1[stops[-1]:] = rms_avg_1[-1]

sig_norm_0 = n[:, 0] / rms_slide_0.clip(min=1e-9)
sig_norm_1 = n[:, 1] / rms_slide_1.clip(min=1e-9)

sqsecdrv = np.zeros_like(n)
sqsecdrv[1:-1, 0] = np.square(sig_norm_0[:-2] + sig_norm_0[2:] - 2.0 * sig_norm_0[1:-1]) * 0.25
sqsecdrv[1:-1, 1] = np.square(sig_norm_1[:-2] + sig_norm_1[2:] - 2.0 * sig_norm_1[1:-1]) * 0.25
sqsecdrv = sqsecdrv.clip(max=4.0)

spline0 = cspline1d(sig_norm_0, lamb=1)
spline1 = cspline1d(sig_norm_1, lamb=1)
noise0 = spline0 - sig_norm_0
noise1 = spline1 - sig_norm_1
b, a = butter(5, 1, fs=5, btype='low', analog=False)
filtered = filtfilt(b, a, sig_norm_0)
_, ax = plt.subplots(1)
ax.plot(sig_norm_0)
ax.plot(filtered)
plt.show()

noise0_rms = np.zeros_like(noise0)
noise1_rms = np.zeros_like(noise0)
flip0 = np.zeros_like(noise0)
flip1 = np.zeros_like(noise0)
keep = np.ones_like(starts, dtype=bool)
slopes0 = np.zeros_like(starts, dtype=float)
slopes1 = np.zeros_like(starts, dtype=float)
snr0 = np.zeros_like(starts, dtype=float)
snr1 = np.zeros_like(starts, dtype=float)
diff_sl = np.zeros_like(starts, dtype=float)
mult0 = np.zeros_like(starts, dtype=float)
mult1 = np.zeros_like(starts, dtype=float)
mid_mult0 = np.ones_like(rms_slide_0)
mid_mult1 = np.ones_like(rms_slide_1)


for idx in range(starts.size):
    lg = stops[idx] - starts[idx]
    win = 1.0 / (1.0 + np.exp((np.abs(np.arange(lg) - lg / 2.0 - 0.5) + 7.0 - lg / 2.0).clip(min=-10.0)))
    winsum = np.sum(win)
#     n0_val = np.sqrt(np.sum(np.square(noise0[starts[idx]:stops[idx]]) * win) / winsum)
#     n1_val = np.sqrt(np.sum(np.square(noise1[starts[idx]:stops[idx]]) * win) / winsum)
    n0_val = np.sum(sqsecdrv[starts[idx]:stops[idx], 0] * win) / winsum
    n1_val = np.sum(sqsecdrv[starts[idx]:stops[idx], 1] * win) / winsum
    per_val_0 = np.sum(r2_x[starts[idx]:stops[idx]] * win) / winsum
    per_val_1 = np.sum(r2_y[starts[idx]:stops[idx]] * win) / winsum
    noise0_rms[starts[idx]:stops[idx]] = n0_val
    noise1_rms[starts[idx]:stops[idx]] = n1_val
    m0 = spline0[starts[idx] + 1:stops[idx]] - spline0[starts[idx]:stops[idx] - 1]
    m1 = spline1[starts[idx] + 1:stops[idx]] - spline1[starts[idx]:stops[idx] - 1]
    slopes0[idx] = np.mean(m0[m0>0]) + np.mean(m0[m0<0])
    slopes1[idx] = np.mean(m1[m1>0]) + np.mean(m1[m1<0])
    diff_sl[idx] = np.sign(slopes0[idx] + parity[idx] * slopes1[idx])
    snr0[idx] = n0_val
    snr1[idx] = n1_val
    if min(n0_val, n1_val) > 0.03:  # empirical value that does pretty well
        windows[starts[idx]:stops[idx]] = 0.0
        keep[idx] = False
    mult0[idx] = np.exp(-35.0 * n0_val) * per_val_0
    mult1[idx] = np.exp(-35.0 * n1_val) * per_val_1
    mid_mult0[starts[idx]:stops[idx]] = mult0[idx]
    mid_mult1[starts[idx]:stops[idx]] = mult1[idx]
    flip0[starts[idx]:stops[idx]] = diff_sl[idx]
    flip1[starts[idx]:stops[idx]] = diff_sl[idx] * parity[idx]
    if idx == 0:
        mid_mult0[:starts[0]] = mult0[idx]
        mid_mult1[:starts[0]] = mult1[idx]
        flip0[:starts[0]] = diff_sl[idx]
        flip1[:starts[0]] = diff_sl[idx] * parity[idx]
    else:
        lg = starts[idx] - stops[idx - 1]
        mid_mult0[stops[idx - 1]:starts[idx]] = np.linspace(mult0[idx - 1], mult0[idx], lg + 1)[:-1]
        mid_mult1[stops[idx - 1]:starts[idx]] = np.linspace(mult1[idx - 1], mult1[idx], lg + 1)[:-1]
        flip0[stops[idx - 1]:stops[idx-1] + lg // 2] = diff_sl[idx - 1]
        flip1[stops[idx - 1]:stops[idx-1] + lg // 2] = diff_sl[idx - 1] * parity[idx - 1]
        flip0[stops[idx-1] + lg // 2:starts[idx]] = diff_sl[idx]
        flip1[stops[idx-1] + lg // 2:starts[idx]] = diff_sl[idx] * parity[idx]
flip0[stops[-1]:] = diff_sl[-1]
flip1[stops[-1]:] = diff_sl[-1] * parity[-1]

starts = starts[keep]
stops = stops[keep]
slopes0 = slopes0[keep]
slopes1 = slopes1[keep]
rms_avg_0 = rms_avg_0[keep]
rms_avg_1 = rms_avg_1[keep]
snr0 = snr0[keep]
snr1 = snr1[keep]
mult0 = mult0[keep]
mult1 = mult1[keep]

for idx in range(starts.size):
    rms_slide_0[starts[idx]:stops[idx]] = rms_avg_0[idx]
    rms_slide_1[starts[idx]:stops[idx]] = rms_avg_1[idx]
    mid_mult0[starts[idx]:stops[idx]] = mult0[idx]
    mid_mult1[starts[idx]:stops[idx]] = mult1[idx]
    if idx == 0:
        rms_slide_0[:starts[0]] = rms_avg_0[0]
        rms_slide_1[:starts[0]] = rms_avg_1[0]
        mid_mult0[:starts[0]] = mult0[idx]
        mid_mult1[:starts[0]] = mult1[idx]
    else:
        lg = starts[idx] - stops[idx - 1]
        rms_slide_0[stops[idx - 1]:starts[idx]] = np.linspace(rms_avg_0[idx - 1], rms_avg_0[idx], lg + 1)[:-1]
        rms_slide_1[stops[idx - 1]:starts[idx]] = np.linspace(rms_avg_1[idx - 1], rms_avg_1[idx], lg + 1)[:-1]
        mid_mult0[stops[idx - 1]:starts[idx]] = np.linspace(mult0[idx - 1], mult0[idx], lg + 1)[:-1]
        mid_mult1[stops[idx - 1]:starts[idx]] = np.linspace(mult1[idx - 1], mult1[idx], lg + 1)[:-1]
rms_slide_0[stops[-1]:] = rms_avg_0[-1]
rms_slide_1[stops[-1]:] = rms_avg_1[-1]
mid_mult0[stops[-1]:] = mult0[-1]
mid_mult1[stops[-1]:] = mult1[-1]

sig_norm_0 = n[:, 0] / rms_slide_0.clip(min=1e-9)
sig_norm_1 = n[:, 1] / rms_slide_1.clip(min=1e-9)

omega = 6.0
fs = 5.0
freqs = np.logspace(0.1, -1.4, 150)  # indices ~50-85 are breathing frequencies
widths_morlet = omega * fs / (freqs * 2 * np.pi)[30:90]
x_wave = wavefinding_cwt(sig_norm_0 * windows, widths_morlet, omega)
y_wave = wavefinding_cwt(sig_norm_1 * windows, widths_morlet, omega)
mags_z = np.abs(x_wave) + np.abs(y_wave)



NameError: name 'contig' is not defined

In [51]:
plt.close("all")
_, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True)
ax1.plot(n[:, 0])
ax1.plot(n[:, 1])
ax1.plot(rms_slide_0)
ax1.plot(rms_slide_1)
ax1.legend(["x", "y", "rmsx", "rmsy"])
ax2.plot(sig_norm_0 * windows)
ax2.plot(sig_norm_1 * windows)
ax2.plot(sig_norm_0.clip(min=-5.0, max=5.0), alpha=0.2)
ax2.plot(sig_norm_1.clip(min=-5.0, max=5.0), alpha=0.2)
# ax2.plot(flip0)
# ax2.plot(flip1)
# ax2.plot(sqsecdrv[:, 0] * windows)
# ax2.plot(sqsecdrv[:, 1] * windows)
# ax2.plot(r2_x * windows)
# ax2.plot(r2_y * windows)
# ax2.plot(mid_mult0)
# ax2.plot(mid_mult1)
ax2.plot(np.minimum(noise0_rms, noise1_rms) * 100.0)
ax2.plot(filtered.clip(-4.9, 4.9))
ax2.legend(["sn0", "sn1", "c0", "c1", "r2_swap", "filt", "r2y", "mm0", "mm1", "n0", "n1"])
ax3.imshow(mags_z.clip(max=np.percentile(mags_z, 99.9)), aspect="auto")
plt.show()

In [None]:
# next up, we want to flip the signals upright
# and then combine them into one signal (after adjusting for their SNR and r2 of their periodicity)
# For each segment, cspline1d, determine if up-slopes or down-slopes have larger slope magnitudes
# upright segments should have more gradual upwards slopes
# (there's a slight pause after exhaling before inhaling, doesn't happen at inhale-exhale transition)
plt.close("all")

angles = np.arctan2(mid_mult1, mid_mult0)

sig0_scale = -n[:, 0] * flip0 * np.square(np.cos(angles)) / rms_slide_0.clip(min=1e-9)
sig1_scale = -n[:, 1] * flip1 * np.square(np.sin(angles)) / rms_slide_1.clip(min=1e-9)
sig = sig0_scale + sig1_scale
_, ax = plt.subplots(1, 1)
ax.plot(sig_norm_0 * windows)
ax.plot(sig_norm_1 * windows)
ax.plot(sig * windows)
ax.plot(np.square(np.cos(angles)) * windows)
ax.plot(np.square(np.sin(angles)) * windows)
ax.legend("sn0,sn1,comb,cos2,sin2".split(","))
plt.show()

# still TO DO:
# funky shit going on with scales of signorm traces at 1050, ~4000, 114000, 115000, probably other places
# 118800 is upside down
# add in r2 from periodicity - not sure of the best way to do this - adding their effects? multiplying? It is exp so probs adding...
# make sure at least that the parity is correct - get correlation between two signals and make sure that result follows
# somehow 110000 gets into the periodic signal section?? How?? It's so angular and small
# 95000 - high frequency but clearly noise. Should not have made it into periodic.
# 140600 - also should not have made it in
# 137200 - also should not have made it in
# 142775 - nopers
# 142600 - should have  made it in
# 15350 - no
# 14990 - no
# 2100 - kinda no
# 4600-4800 - no, out of bed
# 6000 - no
# 8950 - 9250 - no



In [None]:
# after combining, find breathing frequency
omega = 5.0
fs = 5.0
freqs = np.logspace(0.1, -1.4, 150)  # indices ~50-85 are breathing frequencies
widths_morlet = omega * fs / (freqs * 2 * np.pi)
wave = wavefinding_cwt(sig * windows, widths_morlet, omega)
wave = wavefinding_cwt(sig.clip(min=-2.5, max=2.5), widths_morlet, omega)

angles = np.angle(wave)
mags = np.abs(np.abs(wave))
cols = np.mod(angles * 0.5 / np.pi + 0.5, 1)
col_spec = cols - 0.5
dcol = col_spec[:, 1:] - col_spec[:, :-1]
dcol[dcol < -0.5] += 1.0
dcol[dcol > 0.5] -= 1.0
# dcols = cm.hsv(dcol.clip(-0.00625, 0.00625) * 80.0 + 0.5)[..., :3]
# dcols *= (mags[:, 1:, None] / np.percentile(mags, 95.0)).clip(max=1.0)
# cols = cm.hsv(cols)[..., :3]
# cols *= mags[..., None]

# ddcol = dcol[50:85] - dcol[49:84]
scale = 500.0
ddcol = dcol[1:] - dcol[:-1]
sharpen_data = correlate1d(np.exp(-scale * np.abs(dcol)), np.exp(-0.5 * np.square((np.arange(101) - 50.0) / 25.0)), axis=1)
rate_data = np.square(mags[50:90, 1:]) * ddcol[49:89].clip(0.0) * sharpen_data[50:90]
inst_rate = np.sum(np.arange(50, 90)[:, None] * rate_data, axis=0) / np.sum(rate_data, axis=0).clip(1e-5)
inst_var = np.sum(np.square(np.arange(50, 90)[:, None] - inst_rate) * rate_data, axis=0) / np.sum(rate_data, axis=0).clip(1e-5)
inst_std = np.sqrt(inst_var)
inst_std_sm = correlate(inst_std, np.ones(31), mode="same")
inst_rate_fix = np.copy(inst_rate)

for idx in range(starts.size):
    lgb = min(stops[idx] - starts[idx], 100)
    lge = min(stops[idx - 1] - starts[idx - 1], 100)
    begin_avg = np.mean(inst_rate[starts[idx]:starts[idx] + lgb])
    if idx == 0:
        start_side = begin_avg
        inst_rate_fix[:starts[0]] = begin_avg
    elif idx == starts.size - 1:
        final_side = np.mean(inst_rate[stops[idx] - lge:stops[idx]])
        inst_rate_fix[stops[-1]:] = final_side
    else:
        end_avg = np.mean(inst_rate[stops[idx - 1] - lge:stops[idx - 1]])
        space = starts[idx] - stops[idx - 1]
        inst_rate_fix[stops[idx - 1]:starts[idx]] = np.linspace(end_avg, begin_avg, space + 1)[:-1]

inst_rate_fix = cspline1d(inst_rate_fix, 0.1)
inst_rate_sm = cspline1d(inst_rate_fix, 1e10)
        
stats_size = 1300
interp_rate = np.concatenate((np.ones(stats_size) * start_side, inst_rate_fix, np.ones(stats_size) * final_side))

curve_size = 15
curve_fw = curve_size * 2 + 1
x = np.arange(curve_fw) - curve_size
A = np.stack((np.square(x), x, np.ones(curve_fw)), axis=1)
krn = np.dot(np.linalg.inv(np.dot(A.T, A)), A.T)
gskrn = np.exp(-0.5 * np.square((np.arange(401) - 200)/100))

crv = np.square(100.0 * correlate(interp_rate, krn[0], mode="same")[stats_size:-stats_size])

In [None]:
plt.close("all")
sm = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
gr = np.array([-1.0, -2.0, 0.0, 2.0, 1.0])
# sm = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
# gr = np.array([-1.0, -4.0, -5.0, 0.0, 5.0, 4.0, 1.0])

# gx = sepfir2d(ddcol, gr, sm)
# gy = sepfir2d(ddcol, sm, gr)

# g2 = np.square(gx) + np.square(gy)
ffff = np.square(mags[1:, 1:])#ddcol * sharpen_data[1:] * 

# ctd = ffff - np.mean(ffff, axis=1, keepdims=True)
cov = np.cov(ffff)
evals, evecs = linalg.eigh(cov)
evecs = evecs[:, -20:]
something = np.dot(evecs, np.dot(evecs.T, ffff))
print(something.shape)

_, ax = plt.subplots(1, 1)
ax.imshow(something, aspect="auto")
plt.show()

_, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True)
ax1.imshow((ddcol * scale).clip(0.0, 1.0), aspect="auto")
# ax1.imshow(dcols, aspect="auto")
# ax1.imshow(cols, aspect="auto")
# ax2.imshow(sharpen_data, aspect="auto")
# ax2.imshow(np.exp(-50 * g2), aspect="auto")
ax2.imshow(ffff.clip(min=0.0, max=np.percentile(ffff, 99.0)), aspect="auto")
ax3.imshow(np.square(mags).clip(max=np.percentile(np.square(mags), 99.0)), aspect="auto")
ax4.imshow(rate_data.clip(max=np.percentile(rate_data, 99.0)), aspect="auto")

In [60]:

plt.close("all")

slp = np.square(10.0 * correlate(interp_rate, krn[1], mode="same"))[curve_fw:-curve_fw]
slp[stats_size-curve_fw:curve_fw-stats_size] *= windows_wide[:-1]
slp_sm = correlate(slp, gskrn, mode="same")[curve_fw:-curve_fw] / np.sum(gskrn)

# thoughts:
# if slp_sm > 0.85 (empirical) then possibly counts
# Join together groups of those which are > 0.85 if they're within a certain distance
# If not large enough, get rid of it.
# But then there's also a physiology question...
# if there's a good bout of movement, you probably don't go into REM right away? Or no...
rem_starts = np.argwhere(np.logical_and(slp_sm[1:] >= 0.85, slp_sm[:-1] < 0.85)).squeeze()
rem_stops = np.argwhere(np.logical_and(slp_sm[1:] < 0.85, slp_sm[:-1] >= 0.85)).squeeze()

keep = np.ones(starts.size + 1, dtype=bool)
for idx in range(starts.size - 1):
    rem_lg = starts[idx + 1] - stops[idx]
    if rem_lg <= 50:
        keep[idx + 1] = False
        print("ignoring", idx, "(", rem_lg, ")")

ns_starts = starts[keep[:-1]] + stats_size
ns_stops = stops[keep[1:]] + stats_size

keep = np.ones(rem_starts.size + 1, dtype=bool)
for idx in range(rem_starts.size - 1):
    gap_lg = rem_starts[idx + 1] - rem_stops[idx]
    if gap_lg <= 360:
        keep[idx + 1] = False

rem_starts = rem_starts[keep[:-1]]
rem_stops = rem_stops[keep[1:]]

rem_starts_cut = []
rem_stops_cut = []

# 6 possibilities:
# starts and stops without any gaps - accept both start and stop.
# starts within a gap, gap ends before stop - move the start to the end of the gap.
# gap starts after start, before stop, gap ends after stop - move stop up to the start of the gap.
# gap encompasses start and stop - delete start and stop.
# start and stop encompass gap - add in gap stop and start into start and stop.
# gap ends after start and a new one begins before stop - cut off the start and the end.

for idx in range(rem_starts.size):
    starts_in_gap = np.logical_and(ns_starts > rem_starts[idx], ns_starts < rem_stops[idx])
    stops_in_gap = np.logical_and(ns_stops > rem_starts[idx], ns_stops < rem_stops[idx])
    starts_in_gap_args = np.argwhere(starts_in_gap).squeeze(1)
    stops_in_gap_args = np.argwhere(stops_in_gap).squeeze(1)
    print(idx, starts_in_gap_args, stops_in_gap_args)
    
    if not (np.any(starts_in_gap) or np.any(stops_in_gap)):
        last_start = np.argmax(ns_starts <= rem_starts[idx])
        if ns_stops[last_start] < rem_starts[idx]:
            rem_starts_cut.append(rem_starts[idx])
            rem_stops_cut.append(rem_stops[idx])
        continue
    
    if np.count_nonzero(starts_in_gap) == np.count_nonzero(stops_in_gap):
        if np.all(starts_in_gap_args == stops_in_gap_args):
            # we have gaps to open up in the thing, but all of them are contained
            rem_starts_cut.extend(ns_starts[starts_in_gap].tolist())
            rem_stops_cut.extend(ns_stops[stops_in_gap].tolist())
        else:
            # keep the start and end
            rem_starts_cut.append(rem_starts[idx])
            rem_starts_cut.extend(ns_starts[starts_in_gap].tolist())
            rem_stops_cut.extend(ns_stops[stops_in_gap].tolist())
            rem_stops_cut.append(rem_stops[idx])
        continue
    elif np.count_nonzero(starts_in_gap) == 0:
        rem_starts_cut.append(rem_starts[idx])
        rem_stops_cut.extend(ns_stops[stops_in_gap].tolist())
    elif np.count_nonzero(stops_in_gap) == 0:
        rem_starts_cut.extend(ns_starts[starts_in_gap].tolist())
        rem_stops_cut.append(rem_stops[idx])
    else:
        if starts_in_gap_args[0] == stops_in_gap_args[0]:
            rem_starts_cut.extend(ns_starts[starts_in_gap].tolist())
            rem_stops_cut.extend(ns_stops[stops_in_gap].tolist())
            rem_stops_cut.append(rem_stops[idx])
        else:
            rem_starts_cut.append(rem_starts[idx])
            rem_starts_cut.extend(ns_starts[starts_in_gap].tolist())
            rem_stops_cut.extend(ns_stops[stops_in_gap].tolist())
        continue

rem_starts = np.array(rem_starts_cut)
rem_stops = np.array(rem_stops_cut)

m_lead, m_lag, b_lead, b_lag, rmse_lead, rmse_lag = fit_linear_lead_and_lag(interp_rate, stats_size + 1)

rmse_cutoff = (np.minimum(rmse_lead[stats_size:-stats_size], rmse_lag[stats_size:-stats_size]) - 2.2).clip(min=0.0)
possible_rem = correlate(rmse_cutoff, gskrn, mode="same") / np.sum(gskrn)

keep = np.ones(rem_starts.size, dtype=bool)

for idx in range(rem_starts.size):
    rem_lg = rem_stops[idx] - rem_starts[idx]
    print(rem_lg)
    if rem_lg <= 360 or not np.any(possible_rem[rem_starts[idx] - stats_size:rem_stops[idx] - stats_size] > 0.075) or not np.any(inst_rate_sm[rem_starts[idx] - stats_size:rem_stops[idx] - stats_size] > 66.0):
        keep[idx] = False
        print("deleting", idx)

rem_starts = rem_starts[keep] - stats_size
rem_stops = rem_stops[keep] - stats_size
        
print(rem_starts.size)

slp = slp[stats_size-curve_fw:curve_fw-stats_size]
slp_sm = slp_sm[stats_size-curve_fw:curve_fw-stats_size]

_, (ax3, ax4) = plt.subplots(2, 1, sharex=True)
ax3.imshow(np.square(mags).clip(max=np.percentile(np.square(mags), 99.0)), aspect="auto")
ax4.imshow(rate_data.clip(max=np.percentile(rate_data, 99.0)), aspect="auto")
ax4.plot((inst_rate - 50.0) * np.where(windows[:-1] > 0.0, 1.0, np.nan), c="r", linestyle=":")
ax4.plot(inst_std_sm * 0.04, c="y", linestyle=":")
ax4.plot(crv + 30, c="g")
ax4.plot(slp + 40, c="m")
ax4.plot(slp_sm + 40, c="b")
#// ax3.plot(m_lead[stats_size:-stats_size] + 100, c="r", linestyle=":")
#// ax3.plot(m_lag[stats_size:-stats_size] + 100, c="r", linestyle=":")
ax3.plot(np.minimum(rmse_lead[stats_size:-stats_size], rmse_lag[stats_size:-stats_size]).clip(0.0, 10.0) * 10.0, c="g", linestyle=":")
ax3.plot(np.maximum(rmse_lead[stats_size:-stats_size], rmse_lag[stats_size:-stats_size]) / np.minimum(rmse_lead[stats_size:-stats_size], rmse_lag[stats_size:-stats_size]).clip(1.5e-1) * 10.0, c="m", linestyle=":")
ax3.plot(possible_rem * 40.0, c="y", linestyle=":")
ax4.plot(windows_wide * 5, c="b")
ax3.plot(inst_rate_fix, c="r")
ax3.plot(inst_rate_sm, c="g")
ax3.plot((inst_rate_sm[1:] - inst_rate_sm[:-1]) * 1000 + 100, c="r")


for idx in range(rem_starts.size):
    ax4.axvspan(rem_starts[idx], rem_stops[idx], color="g", alpha=0.2)

for idx in range(res_x.size):
    ax3.axvline(x=res_x[idx])
# for idx in range(ns_starts.size):
#     ax3.axvspan(ns_starts[idx] - stats_size, ns_stops[idx] - stats_size, color="g", alpha=0.2)
    
print(rem_starts)

plt.show()

# fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(10, 10))
# ax1.imshow(mags.clip(max=np.percentile(mags, 95.0)), aspect="auto")
# ax2.plot(final_sm)
# ax2.scatter(apneas, final_sm[apneas], c="k", marker="+", zorder=1000)
# ax3.imshow(dcols, aspect="auto")
# plt.show()

# _, ax = plt.subplots(1, 1)
# ax.hist(pk_dist * 0.2, bins=120, range=(0.1, 24.1))
# plt.show()

# _, ax = plt.subplots(1, 1, sharex=True)
# ax.plot(timestamps, final)
# plt.show()

# plt.close("all")
# _, ax = plt.subplots(1, 1)
# ax.imshow(mags_z.clip(max=np.percentile(mags_z, 99.9)), aspect="auto")
# plt.show()

ignoring 0 ( 9 )
ignoring 3 ( 38 )
ignoring 4 ( 5 )
ignoring 6 ( 33 )
ignoring 8 ( 26 )
ignoring 9 ( 23 )
ignoring 11 ( 26 )
ignoring 14 ( 43 )
ignoring 16 ( 4 )
ignoring 17 ( 48 )
ignoring 20 ( 33 )
ignoring 21 ( 3 )
ignoring 22 ( 3 )
ignoring 24 ( 48 )
ignoring 25 ( 44 )
ignoring 27 ( 9 )
ignoring 30 ( 14 )
ignoring 36 ( 22 )
ignoring 37 ( 17 )
ignoring 39 ( 21 )
ignoring 40 ( 17 )
ignoring 41 ( 8 )
ignoring 42 ( 26 )
ignoring 43 ( 17 )
ignoring 45 ( 12 )
ignoring 47 ( 46 )
ignoring 48 ( 21 )
ignoring 49 ( 18 )
ignoring 52 ( 19 )
ignoring 56 ( 18 )
ignoring 57 ( 47 )
ignoring 59 ( 7 )
ignoring 61 ( 3 )
ignoring 64 ( 37 )
ignoring 67 ( 38 )
ignoring 68 ( 41 )
ignoring 69 ( 18 )
ignoring 70 ( 27 )
ignoring 71 ( 35 )
ignoring 76 ( 21 )
ignoring 80 ( 15 )
ignoring 81 ( 25 )
ignoring 82 ( 49 )
ignoring 83 ( 37 )
ignoring 84 ( 9 )
ignoring 86 ( 36 )
ignoring 89 ( 17 )
ignoring 90 ( 4 )
ignoring 95 ( 28 )
ignoring 97 ( 41 )
ignoring 98 ( 7 )
ignoring 101 ( 3 )
ignoring 102 ( 3 )
ignoring 10

In [59]:
plt.close("all")
# classify stages of wakefulness and sleep, movement, and out-of-bed periods
nps = -n[:, 0] * flip0 - n[:, 1] * flip1
non_per = np.where(windows == 0.0, sig, np.nan)
print(np.count_nonzero(windows[7000:] == 0.0) / np.size(windows[7000:]))
winsize = 51
mse = correlate(np.square(nps), np.ones(winsize), mode="full") / winsize
mse_lead = mse[:1 - winsize]
mse_lag = mse[winsize - 1:]
sm_min_mse = np.correlate(np.sqrt(np.minimum(mse_lead, mse_lag)), np.ones(1501), mode="same") / 1501
sm_max_mse = np.correlate(np.sqrt(np.maximum(mse_lead, mse_lag)), np.ones(1501), mode="same") / 1501
mins = np.where(windows == 0.0, np.sqrt(np.minimum(mse_lead, mse_lag)), np.nan)
maxs = np.where(windows == 0.0, np.sqrt(np.maximum(mse_lead, mse_lag)), np.nan)
prob_mvmt = 1.0 / (1.0 + np.exp(4.0 - 2.0 * sm_max_mse))
prob_nib = 1.0 / (1.0 + np.exp(6.0 * sm_min_mse - 4.5))
nib_cond = np.logical_and(np.sqrt(np.minimum(mse_lead, mse_lag)) < 1e-4, windows == 0.0)
nib_maybe = np.logical_and(np.sqrt(np.minimum(mse_lead, mse_lag)) < 2e-4, windows == 0.0)
nib_cond[0] = False
nib_cond[-1] = False
nib_starts = np.argwhere(np.logical_and(nib_cond[1:], np.logical_not(nib_cond[:-1]))).squeeze()
nib_stops = np.argwhere(np.logical_and(nib_cond[:-1], np.logical_not(nib_cond[1:]))).squeeze()

# to properly categorize me as being out of bed...
# first, take the windows *between* where I'm classified as moving/in bed/etc
# if they're a) pretty short and b) don't go too much above the threshold and c) if they do, it's brief
# then remove that stop and next start to join them up
# finally, if any window is less than 10 seconds, it doesn't count.
# Apneas might also be caught by this (and the under-10-second segments might also be apneas)
# worth flagging anything that might fit!
# also, out-of-bed signals probably need to be bracketed  by movement on either side...
# getting movement right is important.

# movement is anything that's not periodic, and has larger amplitude than everything around it.
# I'm not sure what the shortest duration movement could be...
# for sure as low as 1 second, but more than, say, 2 frames?

# there is one other condition, where the signal was definitely part of the breathing signal
# but was not really periodic - some abnormality. Worth flagging those.
# not always something that weird... sometimes just a breath that I didn't pause at the bottom, or did
# or a particularly fast inhale or exhale, like a sigh.
# How do we tell?
# first of all, they're short - typically 2 seconds or less.
# second of all, they aren't significantly bigger or smaller than the surrounding signal, about the same rms
# might be worth, once I collect enough data, classifying them with a machine learning classifier of sorts

keep = nib_stops - nib_starts > 50
nib_starts = nib_starts[keep]
nib_stops = nib_stops[keep]
print(nib_starts)
print(nib_stops)

_, ax = plt.subplots(1, 1)
ax.plot(sig, alpha=0.2, linestyle=":")
ax.plot(non_per)
# ax.plot(prob_mvmt, linestyle=":")
# ax.plot(prob_nib, linestyle=":")
ax.plot(mins)
ax.plot(maxs)
ax.plot(nib_cond.astype(float))
ax.legend(["sig", "sig nonper", "min", "max", "notinbed"])

for idx in range(nib_starts.size):
    ax.axvspan(nib_starts[idx], nib_stops[idx], color="g", alpha=0.2)

plt.show()
_, ax = plt.subplots(1, 1)
ax.scatter(mins, maxs)
ax.set_xscale("log")
ax.set_yscale("log")
plt.show()

# the next step is to categorize the periodic segments
# Artificially segment sleep segments into overlapping chunks
# try to classify everything within each chunk - maybe assign probability to all classifications somehow
# for each overlap piece, choose the highest overall probability weighted classification
# this happens with at least three statistics - mean, slope and stderr of the breathing frequency
# for a decreasing slope with fairly small stderr, it's falling asleep
# for flat slope, small stderr, low frequency, it's deep sleep
# for moderate slope, high stderr, REM sleep
# for mild slope, low to moderate stderr, high frequency, it's awake
# not sure how to categorize "light sleep", I don't have enough data yet
# I think I can probably use the breathing style as well
# ie how much time does each segment spend near 0 compared to at its peaks
# I think when I'm awake, I don't spend very much time in the exhaled state
# whereas in deep sleep, there very often is a pause between one exhale and the next inhale

# how to do probability? First off, decide on a threshold and sharpness, use sigmoid to determine which side
# each factor I measure can then be given eg a bayes factor or something. This gets tedious but it would be good
# give each type of sleep a weight with which it corresponds to that factor - eg low frequency is critical for deep
# the pause between breaths is a good signal for deep sleep, but only about 80% confidence. Reverse for awake, roughly
# etc.

# can I do this for the non-periodic stuff too? Two factors: duration, and rms

0.056396344023511936
[ 77820  78148 152156 153199]
[ 77875  78206 152210 153310]


In [None]:
# find apneas, and cheyne-stokes style breathing
# ie this is peak picking - can we use scipy or do I need to do it manually by breathing freq (sliding min/max filts)?

# around 55900 there's a weird bit
# 36600 cheynish
# 338750 - cheynish, also some got excluded

In [15]:

final = flip_components_indiv_and_combine(n[:, :2], res_x)
final_sm = cspline1d(final, lamb=5.0)

pks, *_ = find_peaks(final_sm, prominence=0.7)
pk_dist = pks[1:] - pks[:-1]
apneas = pks[:-1][pk_dist > 30]

lgth = final.shape[0]
omega = 10.0
fs = 5.0
freqs = np.logspace(0.1, -1.4, 150)
widths_morlet = omega * fs / (freqs * 2 * np.pi)
cwt_morlet = wavefinding_cwt(final, widths_morlet, omega)
angles = np.angle(cwt_morlet)
mags = np.square(np.abs(cwt_morlet))
cols = np.mod(angles * 0.5 / np.pi + 0.5, 1)
col_spec = cols - 0.5
dcol = col_spec[:, 1:] - col_spec[:, :-1]
dcol[dcol < -0.5] += 1.0
dcol[dcol > 0.5] -= 1.0
dcols = cm.hsv(dcol.clip(-0.00625, 0.00625) * 80.0 + 0.5)[..., :3]
dcols *= (mags[:, 1:, None] / np.percentile(mags, 95.0)).clip(max=1.0)
cols = cm.hsv(cols)[..., :3]
cols *= mags[..., None]

# ddcol = dcol[50:85] - dcol[49:84]
scale = 500.0
ddcol = dcol[1:] - dcol[:-1]
rate_data = mags[50:85, 1:] * ddcol[49:84].clip(0.0) * np.exp(-scale * np.abs(dcol[50:85]))
inst_rate = np.sum(np.arange(50, 85)[:, None] * rate_data, axis=0) / np.sum(rate_data, axis=0).clip(1e-5)
inst_var = np.sum(np.square(np.arange(50, 85)[:, None] - inst_rate) * rate_data, axis=0) / np.sum(rate_data, axis=0).clip(1e-5)
inst_std = np.sqrt(inst_var)
inst_std_sm = correlate(inst_std, np.ones(31), mode="same")

_, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True)
ax1.imshow((ddcol * scale).clip(0.0, 1.0), aspect="auto")
ax2.imshow(np.exp(-scale * np.abs(dcol)), aspect="auto")
ax3.imshow(mags.clip(max=np.percentile(mags, 98.0)), aspect="auto")
ax4.imshow(rate_data, aspect="auto")
ax4.plot(inst_rate - 50.0, c="r", linestyle=":")
# ax4.plot(inst_std_sm, c="y", linestyle=":")
plt.show()

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(10, 10))
ax1.imshow(mags.clip(max=np.percentile(mags, 95.0)), aspect="auto")
ax2.plot(final_sm)
ax2.scatter(apneas, final_sm[apneas], c="k", marker="+", zorder=1000)
ax3.imshow(dcols, aspect="auto")
plt.show()

_, ax = plt.subplots(1, 1)
ax.hist(pk_dist * 0.2, bins=120, range=(0.1, 24.1))
plt.show()

_, ax = plt.subplots(1, 1, sharex=True)
ax.plot(timestamps, final)
plt.show()

quit()

n1 = norm_complex(n0_resz, n1_resz)
n2 = norm_complex(n2_resz, n3_resz)

n1 = smooth_and_norm_complex_stitch(n0_resz, n1_resz, n2_resz, n3_resz)

m1, r1, res1, skew1 = compute_stats(n[:, 0], n[:, 1], winsize)
m2, r2, res2, skew2 = compute_stats(n[:, 2], n[:, 3], winsize)

# sig1 = smooth_and_norm(res1)
# sig2 = smooth_and_norm(res2)


# %%
# CWT

lgth = n1.shape[0]
omega = 2.0
fs = 5.0
freqs = np.logspace(0.1, -1.4, 150)
widths_morlet = omega * fs / (freqs * 2 * np.pi)
cwt_morlet = wavefinding_cwt(n1, widths_morlet, omega)
angles = np.angle(cwt_morlet)
mags = np.square(np.abs(cwt_morlet))
cols = np.mod(angles * 0.5 / np.pi + 0.5, 1)
col_spec = cols - 0.5
dcol = col_spec[:, 1:] - col_spec[:, :-1]
dcol[dcol < -0.5] += 1.0
dcol[dcol > 0.5] -= 1.0
dcols = cm.hsv(dcol.clip(-0.0125, 0.0125) * 40.0 + 0.5)[..., :3]
# dcols = cm.twilight_shifted(dcol.clip(-0.05, 0.05) * 10.0 + 0.5)[..., :3]
dcols *= (mags[:, 1:, None] / np.percentile(mags, 95.0)).clip(max=1.0)
cols = cm.hsv(cols)[..., :3]
cols *= mags[..., None]

sharps = np.sum(mags[:50], axis=0)
sharps *= 0.001  # empirical constant
sharps[0] = 0.0
sharps[-1] = 0.0

fatties = np.sum(mags[-50:], axis=0)
fatties *= 0.002
fatties[0] = 0.0
fatties[-1] = 0.0

mainline = np.sum(mags[50:-50], axis=0)
mainline *= 0.004
mainline[0] = 1.0
mainline[-1] = 1.0

summed = np.sum(mags, axis=0, keepdims=True)
wt_mean = np.sum(np.arange(freqs.size)[:, None] * mags, axis=0, keepdims=True) / summed
stdev = np.sqrt(
    np.sum(mags * np.square(np.arange(freqs.size)[:, None] - wt_mean), axis=0) / summed
)

movement = np.logical_or(sharps > 0.5, np.abs(n1) > 5.00).astype(int)
stationary = (mainline < 0.5).astype(int)

mvmt_start = np.argwhere(movement[1:] - movement[:-1] == 1).squeeze()
mvmt_end = np.argwhere(movement[1:] - movement[:-1] == -1).squeeze()
stn_start = np.concatenate(
    (
        np.array([-2]),
        np.argwhere(stationary[1:] - stationary[:-1] == 1).squeeze(),
        np.array([lgth]),
    )
)
stn_end = np.concatenate(
    (
        np.array([-1]),
        np.argwhere(stationary[1:] - stationary[:-1] == -1).squeeze(),
        np.array([lgth + 1]),
    )
)

before_args = np.argmax(
    np.mod(stn_end[:, None] - mvmt_start[None, :], lgth * 2), axis=0
)
before_okay = np.logical_and(
    0 < mvmt_start - stn_end[before_args], mvmt_start - stn_end[before_args] < 15
)
before_dist = (mvmt_start - stn_start[before_args]) * before_okay.astype(int)
mvmt_start_adj = mvmt_start - before_dist

after_args = np.argmax(np.mod(mvmt_end[None, :] - stn_start[:, None], lgth * 2), axis=0)
after_okay = np.logical_and(
    0 < stn_start[after_args] - mvmt_end, stn_start[after_args] - mvmt_end < 15
)
after_dist = (stn_end[after_args] - mvmt_end) * after_okay.astype(int)
mvmt_end_adj = mvmt_end + after_dist

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(10, 10))
ax1.imshow(mags.clip(max=np.percentile(mags, 95.0)), aspect="auto")
ax2.plot(np.real(n1))
ax2.plot(np.imag(n1))
ax2.plot(np.abs(n1))
ax2.plot(sharps.clip(max=2.0) * 10.0, linestyle=":")
# ax2.plot(mainline, c="k")
# ax2.plot(stdev[0], c="b")
# ax2.plot(movement * 3.0, c="k")
# ax2.plot(stationary * 2.0, c="b")
ax2.scatter(mvmt_start_adj, np.ones_like(mvmt_start_adj), c="k", zorder=1000)
ax2.scatter(mvmt_end_adj, np.ones_like(mvmt_end_adj), c="b", zorder=1001)
ax3.imshow(dcols, aspect="auto")
plt.show()

fig, ax = plt.subplots(1, 1)
ax.imshow(dcols, aspect="auto")
fig.set_size_inches(10, 4)
plt.show()


# %%
# dto = datetime.strptime(dt, "%y%m%d_%H%M%S")
# start = dto.hour + dto.minute / 60.0 + dto.second / 3600.0
# end = start + n.shape[0] / (5.0 * 3600.0)
# t = start + np.arange(n.shape[0]) / (5.0 * 3600.0) - (24.0 if start > 12 else 0.0)

# nfft = 64
# f, tfft, s = spectrogram(sig1, fs=5.0, nfft=nfft, mode="magnitude", nperseg=nfft)
# spec = 10.0 * np.log10(np.square(s))

# %%
# _, (ax1, ax2) = plt.subplots(2, 1)
# ax1.imshow(
#     np.flip(spec, axis=0),
#     vmin=np.percentile(spec, 20.0),
#     vmax=np.percentile(spec, 99.5),
#     extent=[t[0], t[-1], f[0], f[-1]],
# )
# ax2.plot(m1)
# ax2.plot(r2 * 2.0 + 3)
# plt.show()


# %%

omega = 10.0
widths_morlet = omega * fs / (freqs * 2 * np.pi)
cwt_morlet = wavefinding_cwt(n1, widths_morlet, omega)
angles = np.angle(cwt_morlet)
mags = np.abs(cwt_morlet)
cols = np.mod(angles * 0.5 / np.pi + 0.5, 1)
col_spec = cols - 0.5
dcol = col_spec[:, 1:] - col_spec[:, :-1]
dcol[dcol < -0.5] += 1.0
dcol[dcol > 0.5] -= 1.0
dcols = cm.hsv(dcol.clip(-0.0125, 0.0125) * 40.0 + 0.5)[..., :3]
# dcols = cm.twilight_shifted(dcol.clip(-0.05, 0.05) * 10.0 + 0.5)[..., :3]
dcols *= (mags[:, 1:, None] / np.percentile(mags, 95.0)).clip(max=1.0)
cols = cm.hsv(cols)[..., :3]
cols *= mags[..., None]

sharps = np.sum(mags[:50], axis=0)
sharps *= 0.001 if omega > 4.0 else 0.005

fatties = np.sum(mags[-50:], axis=0)
fatties *= 0.002 if omega > 4.0 else 0.01

mainline = np.sum(mags[50:-50], axis=0)
mainline *= 0.004 if omega > 4.0 else 0.01

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(10, 10))
ax1.imshow(mags.clip(max=np.percentile(mags, 95.0)), aspect="auto")
ax2.plot(np.real(n1))
ax2.plot(np.imag(n1))
ax2.plot(sharps)
ax2.plot(fatties)
ax2.plot(mainline, c="k")
ax3.imshow(dcols, aspect="auto")
fig.set_size_inches(10, 10)
plt.show()


TypeError: flip_components_indiv_and_combine() missing 1 required positional argument: 'segments'