# Lossy compression of raw ephys data

## Imports

In [128]:
from itertools import islice
import os
from pathlib import Path
import sys
import zlib

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.signal import decimate
import seaborn as sns
from tqdm import tqdm
from neurodsp.voltage import decompress_destripe_cbin, destripe

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import mtscomp
from mtscomp import Reader, Writer, compress, decompress

In [195]:
%matplotlib inline
plt.rcParams["figure.dpi"] = 180
plt.rcParams["axes.grid"] = False
sns.set_theme(style="white")

## Paths

In [18]:
EPHYS_DIR = Path(".").resolve()

In [19]:
path_cbin = list(EPHYS_DIR.glob("*.cbin"))[0]
path_cbin

PosixPath('/home/cyrille/31-ibl/qcapp/lossy/_spikeglx_ephysData_g0_t0.imec0.ap.cbin')

## Decompressing the .cbin

In [20]:
decomp = decompress(path_cbin)

In [27]:
fs = decomp.sample_rate; fs

30000.0

In [44]:
T = decomp.n_samples / sr; T

3897.896366666667

In [45]:
def t2s(t):
    return int(round(t * fs))

In [46]:
T_s = t2s(duration)

In [47]:
ms = 1e-3

In [50]:
h = 20*ms
h_s = t2s(h)

## Function to get and preprocess the data

In [154]:
DOWNSAMPLE = 4

In [178]:
def get(t, h=h, downsample=DOWNSAMPLE):  # in seconds
    t_s = t2s(t)
    h_s = t2s(h)
    
    t_s = np.clip(t_s, h_s, T_s - h_s)
    chunk = decomp[t_s - h_s:t_s + h_s, :-1].T
    chunk = destripe(chunk, fs=fs)
    if downsample > 1:
        chunk = decimate(chunk, downsample, axis=1)
    return chunk

In [162]:
chunk = get(0)
m, M = chunk.min(), chunk.max()

## Interactive raw ephys data viewer

In [164]:
@interact(t0=(0 + h, T - h), downsample=(1, 24), vmin=(m, M), vmax=(m, M))
def show(t0=h, vmin=m, vmax=M, downsample=DOWNSAMPLE):
    chunk = get(t0, downsample=downsample)
    plt.imshow(chunk, cmap="gray", aspect="auto", interpolation='none', vmin=vmin, vmax=vmax);
    k = 5
    plt.xticks(ticks=np.linspace(0, chunk.shape[1], k), labels=['%.3f' % _ for _ in np.linspace(t0-h, t0+h, k)]);

interactive(children=(FloatSlider(value=0.02, description='t0', max=3897.876366666667, min=0.02), FloatSlider(…

## Extracting data excerpts

In [165]:
n_excerpts = 10
N = int(round(T))
excerpts = np.hstack([get(t, .5) for t in .5 + np.arange(1, N - 1, N // n_excerpts)])

In [166]:
excerpts.shape

(384, 82500)

## Computing the SVD of the excerpts

In [167]:
U, sigma, V = np.linalg.svd(excerpts, full_matrices=False)

In [168]:
U.shape, V.shape

((384, 384), (384, 82500))

In [169]:
Usigma = U @ np.diag(sigma)
Usigma_inv = np.linalg.inv(Usigma)

## Compression/decompression functions

In [170]:
def compress(chunk):
    return (Usigma_inv @ chunk)

In [171]:
def reconstruct(comp, rk):
    return U[:, :rk] @ np.diag(sigma[:rk]) @ comp[:rk, :]

In [172]:
def reduce_depth(arr, cmin=None, cmax=None):
    lossy = (arr - cmin) / (cmax - cmin)
    lossy = np.clip(lossy, 0, 1)
    lossy = np.round(lossy * 255)
    lossy = lossy.astype(np.uint8)
    lossy = lossy.astype(np.float32) * (1./255)
    return cmin + lossy * (cmax - cmin)

## Visualizing the data

In [173]:
ming, maxg = excerpts.min(), excerpts.max()

In [174]:
k = .0025
min0, max0 = np.quantile(excerpts.ravel(), k), np.quantile(excerpts.ravel(), 1-k)

In [190]:
t0 = 10
RANK = nc // 2

In [198]:
@interact(
    t=(h, T-h), 
    rk=(3, nc),
    downsample=(1, 24),
    cvmin=(ming, maxg),
    cmax=(ming, maxg),
    reduce_bit_depth=True,
    vmin=(ming, maxg),
    vmax=(ming, maxg))
def compare(t=t0, rk=RANK, downsample=DOWNSAMPLE, vmin=min0, vmax=max0, cmin=min0, cmax=max0, reduce_bit_depth=None):
    orig = get(t, downsample=1)
    
    chunk = get(t, downsample=downsample)
    comp = compress(chunk)
    lossy = reconstruct(comp, rk)
    ratio = nc / float(rk) * downsample
    if reduce_bit_depth:
        lossy = reduce_depth(lossy, cmin, cmax)
        ratio *= 2
        
    fig, axes = plt.subplots(1, 2, figsize=(16, 5), sharey=True);
    kwargs = dict(cmap="gray", aspect="auto", interpolation='none', vmin=vmin, vmax=vmax)
    
    n_ticks = 5
    ticks = ['%.3f' % _ for _ in np.linspace(t-h, t+h, n_ticks)]
    
    axes[0].imshow(orig, **kwargs);
    axes[0].set_xticks(np.linspace(0, orig.shape[1], n_ticks), ticks);
    axes[0].set_xlabel(f"time (s)");
    axes[0].set_title(f"original (destriped)");
    
    axes[1].imshow(lossy, **kwargs);
    axes[1].set_xticks(np.linspace(0, lossy.shape[1], n_ticks), ticks);
    axes[1].set_xlabel(f"time (s)");
    axes[1].set_title(f"lossy, rank={rk}, downsampled {downsample}x, ratio {ratio:.1f}x");
    
    plt.show();

interactive(children=(FloatSlider(value=10.0, description='t', max=3897.876366666667, min=0.02), IntSlider(val…