In [1]:
import numpy as np
import cupy as cp
from scipy.io import loadmat
import os
from pathlib import Path
import sys
import matplotlib.pyplot as plt
import shutil
import h5py
from scipy.signal import lfilter
from scipy.io import loadmat

%matplotlib inline

%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [2]:
# Local checkouts of pykilosort and phylib -- these paths need to be modified.
sys.path.extend(['../github/pykilosort', '../github/phylib'])
from pykilosort import Bunch, run, add_default_handler
from pykilosort.utils import memmap_large_array, LargeArrayWriter

In [3]:
# Raw string so the Windows backslashes are taken literally ('\k' in a normal
# string literal is an invalid escape sequence). This path needs to be modified.
dat_path = Path(r'D:\kilosort-testing\100s_data\p1\imec_385_100s.bin')
dir_path = dat_path.parent

In [4]:
# Attach pykilosort's default log handler at INFO level (log lines appear in the
# run() outputs below). NOTE(review): running this cell twice presumably adds a
# second handler, which would explain the duplicated lines in the whitening run.
add_default_handler(level='INFO')

The MATLAB version must be run first, with the `rez` output saved at each stage to `rez_path` under the names `rez_preprocess`, `rez_cluster`, `rez_learn`, `rez_merge`, `rez_split1`, `rez_split2` and `rez_cutoff`, together with the binary whitened data file (`temp_wh.dat`).

In [5]:
# Raw strings so the Windows backslashes are taken literally.
# Directory holding MATLAB's rez_*.mat files and temp_wh.dat (needs to be changed).
rez_path = Path(r'D:\kilosort-testing\100s_data\m1')
# Root directory for the per-stage Python test runs (needs to be changed).
test_path = Path(r'D:\kilosort-testing\100s_data\python_stages2.5')
dataset_name = 'imec_385_100s' # needs to be changed

In [6]:
def get_rez(rez_loc):
    """Load a MATLAB Kilosort ``rez`` struct into a flat Bunch.

    The fields of the struct and of its ``ops`` sub-struct are merged into one
    Bunch. ``ops`` fields are written second, so they overwrite same-named
    top-level fields. Nested sub-structs themselves are skipped.

    v7.3 ``.mat`` files (HDF5) are read with h5py. Any file h5py cannot open
    falls back to ``scipy.io.loadmat`` (pre-v7.3 formats).

    Parameters
    ----------
    rez_loc : path-like
        Path to the ``.mat`` file.

    Returns
    -------
    Bunch
        Field name -> squeezed numpy array.
    """
    try:
        # Explicit read-only mode: older h5py versions defaulted to append mode,
        # which can create or modify the file instead of failing cleanly.
        rez_file = h5py.File(rez_loc, 'r')
    except OSError:
        return _get_rez_loadmat(rez_loc)

    rez = Bunch()
    with rez_file:  # close the HDF5 handle once everything is read into memory
        main = rez_file[list(rez_file.keys())[-1]]
        for group in (main, main['ops']):
            for key in group.keys():
                node = group[key]
                # Only datasets hold array data. Sub-groups (e.g. 'ops' itself)
                # are skipped: indexing a Group with [()] raises TypeError in h5py 3.
                if not isinstance(node, h5py.Dataset):
                    continue
                try:
                    rez[key] = node[()].squeeze()
                except AttributeError:
                    # Values without .squeeze() (e.g. bytes) are skipped, as before.
                    pass
    return rez


def _get_rez_loadmat(rez_loc):
    """Fallback for pre-v7.3 ``.mat`` files, mirroring the HDF5 path of get_rez."""
    mat = loadmat(rez_loc, squeeze_me=True, struct_as_record=False)
    # Skip loadmat's metadata entries ('__header__', '__version__', '__globals__').
    main = mat[[key for key in mat if not key.startswith('__')][-1]]
    rez = Bunch()
    for struct in (main, main.ops):
        for key in struct._fieldnames:
            value = getattr(struct, key)
            if hasattr(value, '_fieldnames'):
                continue  # nested struct: skipped, like HDF5 sub-groups
            # h5py returns MATLAB arrays with reversed axes; transpose so both
            # paths give the same orientation to the `.T` used by the comparisons.
            rez[key] = np.asarray(value).T
    return rez

def get_ctx(ctx_loc):
    """Load every ``.npy`` file in a directory into a Bunch keyed by file stem.

    Parameters
    ----------
    ctx_loc : str or Path
        Directory to scan (non-recursive).

    Returns
    -------
    Bunch
        ``{stem: array}`` for each ``*.npy`` file found.
    """
    ctx_loc = Path(ctx_loc)  # also accept plain strings
    ctx = Bunch()
    for file in os.listdir(ctx_loc):
        # Require the real '.npy' extension. A bare 'npy' suffix check would also
        # match names like 'foo_npy' and then strip the wrong 4 characters.
        if file.endswith('.npy'):
            ctx[file[:-len('.npy')]] = np.load(ctx_loc / file)
    return ctx

def transpose_fortran(array):
    """Return ``array`` with its axes reversed, laid out in Fortran (column-major) order."""
    flipped = array.T
    return np.asfortranarray(flipped)

def _save(array, path, name):
    """Save ``array`` with reversed axes, Fortran-ordered, as ``path / name`` (.npy)."""
    target = path / name
    np.save(target, np.asfortranarray(array.T))
    
def _save_largearray(array, path, name):
    """Write ``array`` to ``path / (name + '.dat')`` with pykilosort's LargeArrayWriter.

    The data is written as float32 in Fortran order. The last axis is declared
    as -1 (presumably "open-ended" for the writer; TODO confirm). The writer is
    closed even if ``append`` fails, so the file handle is not leaked.
    """
    writer = LargeArrayWriter(path / (name + '.dat'), dtype=np.float32, shape=(*array.shape[:-1], -1))
    try:
        writer.append(np.asfortranarray(array))
    finally:
        writer.close()

def setup_dir(path, name):
    """Create an empty ``path/name/.kilosort/<dataset_name>`` directory and return it.

    Any existing directory at that location is deleted first. ``dataset_name``
    is the notebook-level global set in the configuration cell.
    """
    stage_dir = path / name / '.kilosort' / dataset_name
    if stage_dir.is_dir():
        shutil.rmtree(stage_dir)
    stage_dir.mkdir(parents=True)
    return stage_dir

def _load_pair(name, rez, ctx, mapping, mapping_axes, python_name):
    """Fetch the MATLAB and Python versions of a variable, aligned for comparison.

    The MATLAB value is transposed. If ``mapping_axes`` is given, it is then
    re-indexed with ``mapping`` along each listed axis, so MATLAB entries line
    up with the Python cluster order.
    """
    if python_name is None:
        python_name = name
    var_m = np.copy(rez[name]).T
    var_p = cp.asnumpy(ctx.intermediate[python_name])
    if mapping_axes is not None:
        # Explicit check: `assert` would be skipped under `python -O`.
        if mapping is None:
            raise ValueError('mapping_axes was given without a mapping')
        for axis in mapping_axes:
            var_m = np.take(var_m, mapping, axis=axis)
    return var_m, var_p


def _allclose_same_shape(var_m, var_p, atol):
    """``np.allclose`` that returns False, instead of broadcasting, when shapes differ."""
    # np.allclose broadcasts, so e.g. an (n,) array against a (1,) array could
    # report True even though the variables clearly differ.
    return bool(var_m.shape == var_p.shape and np.allclose(var_m, var_p, atol=atol))


def test(name, rez, ctx, mapping=None, mapping_axes=None, python_name=None, atol=1e-08):
    """Check that a MATLAB ``rez`` variable matches its Python counterpart.

    Parameters
    ----------
    name : str
        Key in ``rez`` (and in ``ctx.intermediate`` unless ``python_name`` is given).
    rez, ctx : Bunch
        MATLAB results and the pykilosort run context.
    mapping : array of int, optional
        Python cluster index -> MATLAB cluster index.
    mapping_axes : iterable of int, optional
        Axes of the (transposed) MATLAB array to re-index with ``mapping``.
    python_name : str, optional
        Name of the variable in ``ctx.intermediate`` if it differs from ``name``.
    atol : float
        Absolute tolerance passed to ``np.allclose`` (default rtol applies).

    Returns
    -------
    bool
        True if the shapes agree and the values are close.

    Raises
    ------
    ValueError
        If ``mapping_axes`` is given without ``mapping``.
    """
    var_m, var_p = _load_pair(name, rez, ctx, mapping, mapping_axes, python_name)
    return _allclose_same_shape(var_m, var_p, atol)


def test_abs(name, rez, ctx, mapping=None, mapping_axes=None, python_name=None, atol=1e-08):
    """Like :func:`test`, but compares absolute values.

    Used for quantities that are only defined up to sign (e.g. SVD components).
    """
    var_m, var_p = _load_pair(name, rez, ctx, mapping, mapping_axes, python_name)
    return _allclose_same_shape(np.abs(var_m), np.abs(var_p), atol)

## Test Whitening Matrix

In [377]:
# Fresh output directory for this stage (setup_dir deletes any previous contents).
test_white_path = setup_dir(test_path, 'test_white')

In [378]:
# The probe files live next to the raw data file, so derive the directory from
# dat_path instead of repeating the hardcoded path.
dir_path = dat_path.parent

probe = Bunch()
probe.NchanTOT = 385
# WARNING: indexing mismatch with MATLAB hence the -1
probe.chanMap = np.load(dir_path / 'chanMap.npy').squeeze().astype(np.int64) - 1
# Remaining probe layout arrays, one .npy file per key.
probe.update({
    key: np.load(dir_path / f'{key}.npy').squeeze()
    for key in ('xc', 'yc', 'kcoords')
})

In [379]:
# Run pykilosort up to the whitening-matrix step. dir_path is two levels above
# the directory made by setup_dir, which added '.kilosort/<dataset_name>' below it.
ctx = run(
    dat_path,
    probe=probe,
    dir_path=test_white_path.parent.parent,
    n_channels=385,
    dtype=np.int16,
    sample_rate=3e4,
    stop_after='whitening_matrix',
)

[0m21:42:20.832 [I] main:56              Loaded raw data with 385 channels, 3000000 samples.[0m
[0m21:42:20.832 [I] main:56              Loaded raw data with 385 channels, 3000000 samples.[0m
[0m21:42:20.834 [I] utils:334            Starting step whitening_matrix.[0m
[0m21:42:20.834 [I] utils:334            Starting step whitening_matrix.[0m
Computing the whitening matrix: 100%|████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.70s/it]
[0m21:42:30.258 [I] preprocess:275       Computed the whitening matrix.[0m
[0m21:42:30.258 [I] preprocess:275       Computed the whitening matrix.[0m
[0m21:42:30.286 [I] utils:344            Step `whitening_matrix` took 9.45s.[0m
[0m21:42:30.286 [I] utils:344            Step `whitening_matrix` took 9.45s.[0m


In [380]:
# Intermediate results stored by the run; list which ones are available.
ir = ctx.intermediate
ir.keys()

dict_keys(['Nbatch', 'igood', 'Wrot'])

In [381]:
# MATLAB state saved after preprocessing (provides the whitening matrix Wrot).
rez_pre = get_rez(rez_path / 'rez_preprocess.mat')

In [383]:
# Convert the Python result to a numpy (host) array for comparison.
Wrot_python = cp.asnumpy(ir.Wrot)
Wrot_matlab = rez_pre.Wrot

In [388]:
# The MATLAB matrix is compared transposed (presumably because get_rez reads
# MATLAB's column-major data with axes reversed); loose absolute tolerance.
np.allclose(Wrot_python, Wrot_matlab.T, atol=1e-02)

True

## Test Preprocess

In [7]:
# Fresh stage directory, plus the MATLAB state after preprocessing.
test_pre_path = setup_dir(test_path, 'test_preprocess')

rez_preprocess = get_rez(rez_path / 'rez_preprocess.mat')

In [8]:
# Seed the stage directory with MATLAB's whitening matrix (written in numpy
# layout by _save), presumably so the Python preprocessing reuses the same Wrot.
variables = ['Wrot']

for variable in variables:
    _save(rez_preprocess[variable], test_pre_path, variable)

In [9]:
# Free the struct; it is reloaded as rez_pre after the Python run.
del rez_preprocess

In [10]:
# Probe description for the 385-channel recording, read from files in dir_path.
probe = Bunch()
probe.NchanTOT = 385
# WARNING: indexing mismatch with MATLAB hence the -1
probe.chanMap = np.load(dir_path / 'chanMap.npy').squeeze().astype(np.int64) - 1
probe.update({
    key: np.load(dir_path / f'{key}.npy').squeeze()
    for key in ('xc', 'yc', 'kcoords')
})

In [12]:
# Run pykilosort up to and including the preprocess step.
ctx = run(
    dat_path,
    probe=probe,
    dir_path=test_pre_path.parent.parent,
    n_channels=385,
    dtype=np.int16,
    sample_rate=3e4,
    stop_after='preprocess',
)

[0m01:35:48.558 [I] main:56              Loaded raw data with 385 channels, 3000000 samples.[0m
[0m01:35:48.560 [I] utils:334            Starting step preprocess.[0m
[0m01:35:48.663 [I] preprocess:396       Loading raw data and applying filters.[0m
Preprocessing: 100%|███████████████████████████████████████████████████████████████████| 46/46 [01:39<00:00,  2.17s/it]
[0m01:37:28.484 [I] utils:344            Step `preprocess` took 99.92s.[0m


In [13]:
# Reload MATLAB's preprocess state and grab the Python intermediates.
rez_pre = get_rez(rez_path / 'rez_preprocess.mat')
ir = ctx.intermediate

In [14]:
# Processed int16 data from each implementation, memory-mapped read-only as
# flat 1-D arrays (batches are sliced out below).
raw_data_python = np.memmap(ir.proc_path, dtype=np.int16, mode="r", order="F")
raw_data_matlab = np.memmap(rez_path / 'temp_wh.dat', dtype=np.int16, mode='r', order='F')

In [15]:
# Batch geometry: each batch spans NT * Nchan int16 values in the files above.
Nchan = ctx.probe.Nchan
NT = ctx.params.NT
Nbatch = ir.Nbatch

In [19]:
tol = 2  # absolute tolerance, in int16 units, for the processed data
n_check = 10  # number of randomly chosen batches to compare
rng = np.random.default_rng(0)  # fixed seed so the checked batches are reproducible

print(f'Tolerance set to {tol}')

batch_len = NT * Nchan
for i in rng.choice(Nbatch, min(n_check, Nbatch), replace=False):
    batch = slice(batch_len * i, batch_len * (i + 1))
    # The two files lay a batch out differently: Python as (NT, Nchan)
    # column-major, MATLAB as (Nchan, NT) column-major (hence the .T).
    batch_python = raw_data_python[batch].reshape((-1, Nchan), order='F')
    batch_matlab = raw_data_matlab[batch].reshape((Nchan, -1), order='F').T
    # Use the configured tolerance (it used to be hardcoded to 2 here).
    print(f'Batch {i} matches: {np.allclose(batch_python, batch_matlab, atol=tol)}')

Tolerance set to 2
Batch 13 matches: True
Batch 39 matches: True
Batch 7 matches: True
Batch 32 matches: True
Batch 12 matches: True
Batch 0 matches: True
Batch 24 matches: True
Batch 11 matches: True
Batch 10 matches: True
Batch 27 matches: True


## Test Learn

## Test Merge

In [7]:
# Fresh stage directory and the MATLAB state after the learn step (input to merge).
test_merge_path = setup_dir(test_path, 'test_merge')

rez_learn = get_rez(rez_path / 'rez_learn.mat')

In [8]:
def save_rez_learn(rez, path):
    """Seed a pykilosort stage directory with MATLAB's state after the learn step.

    Writes each listed ``rez`` field via ``_save``, ``ccb`` as ``ccb0``,
    ``cProj``/``cProjPC`` as the large ``fW``/``fWPC`` .dat files, and ``st3``
    with 0-based cluster ids. ``fW.dat`` is then copied to ``proc.dat``.
    """
    fields = (
        'ccbsort', 'dWU', 'igood', 'iNeigh', 'iNeighPC', 'iorig', 'mu',
        'simScore', 'U', 'U_a', 'U_b', 'UA', 'W', 'W_a', 'W_b',
        'WA', 'Wrot', 'wPCA', 'wTEMP',
    )
    for field in fields:
        _save(rez[field], path, field)

    _save(rez.ccb, path, 'ccb0')

    for source, target in (('cProj', 'fW'), ('cProjPC', 'fWPC')):
        _save_largearray(rez[source], path, target)

    # Row 1 of st3 holds the 1-based MATLAB template/cluster ids (the column the
    # comparison cells treat as cluster labels); shift it to 0-based.
    st3 = np.copy(rez.st3)
    st3[1, :] -= 1
    _save(st3, path, 'st3')

    # Hack to avoid re-running the pre-processing stage.
    shutil.copyfile(path / 'fW.dat', path / 'proc.dat')

In [9]:
# Write MATLAB's post-learn state into the stage directory, then free it.
save_rez_learn(rez_learn, test_merge_path)
del rez_learn

In [10]:
# Probe description for the 385-channel recording, read from files in dir_path.
probe = Bunch()
probe.NchanTOT = 385
# WARNING: indexing mismatch with MATLAB hence the -1
probe.chanMap = np.load(dir_path / 'chanMap.npy').squeeze().astype(np.int64) - 1
probe.update({
    key: np.load(dir_path / f'{key}.npy').squeeze()
    for key in ('xc', 'yc', 'kcoords')
})

In [11]:
# Run pykilosort up to and including the merge step, starting from the seeded state.
ctx = run(
    dat_path,
    probe=probe,
    dir_path=test_merge_path.parent.parent,
    n_channels=385,
    dtype=np.int16,
    sample_rate=3e4,
    stop_after='merge',
)

[0m23:25:29.215 [I] main:56              Loaded raw data with 385 channels, 3000000 samples.[0m
[0m23:25:29.356 [I] utils:334            Starting step merge.[0m
  Q = (Qi / max(Q00, Q01)).min()
Finding merges: 100%|███████████████████████████████████████████████████████████████| 357/357 [00:03<00:00, 114.03it/s]
[0m23:25:32.765 [I] utils:344            Step `merge` took 3.41s.[0m


In [12]:
# MATLAB state after merging, to compare against.
rez_merge = get_rez(rez_path / 'rez_merge.mat')

For the merge phase it is enough to check that `st3` is the same in both implementations.

In [13]:
# Column 1 holds the cluster ids; MATLAB's are 1-based, so shift them to 0-based.
st3_matlab = np.copy(rez_merge.st3.T)
st3_matlab[:,1] = st3_matlab[:,1] - 1
st3_python = cp.asnumpy(ctx.intermediate.st3_m)
np.allclose(st3_matlab, st3_python)

True

In [14]:
# Rows whose cluster id differs, and the set of ids involved on either side.
ix = np.where(st3_matlab[:,1] - st3_python[:,1] != 0)[0]
bad_clusters = np.unique(np.concatenate((st3_python[:,1][ix], st3_matlab[:,1][ix])))
print(f'The following clusters were different: {bad_clusters}')

The following clusters were different: []


In [15]:
# Release the run context and MATLAB struct before the next stage.
del ctx
del rez_merge

## Test First Split

In [16]:
# Fresh stage directory for the first split, seeded from MATLAB's post-merge state.
test_split1_path = setup_dir(test_path, 'test_split1')
rez_merge = get_rez(rez_path / 'rez_merge.mat')

In [17]:
def save_rez_merge(rez, path):
    """Seed a pykilosort stage directory with MATLAB's state after the merge step.

    Same layout as the learn-stage seed. ``st3`` (0-based cluster ids) is also
    written as ``st3_m``, the name under which the merge result is read back.
    """
    fields = (
        'ccbsort', 'dWU', 'igood', 'iNeigh', 'iNeighPC', 'iorig', 'mu',
        'simScore', 'U', 'U_a', 'U_b', 'UA', 'W', 'W_a', 'W_b',
        'WA', 'Wrot', 'wPCA', 'wTEMP',
    )
    for field in fields:
        _save(rez[field], path, field)

    _save(rez.ccb, path, 'ccb0')

    for source, target in (('cProj', 'fW'), ('cProjPC', 'fWPC')):
        _save_largearray(rez[source], path, target)

    # Row 1 of st3 holds the 1-based MATLAB template/cluster ids; make them 0-based.
    st3 = np.copy(rez.st3)
    st3[1, :] -= 1
    for stage_name in ('st3', 'st3_m'):
        _save(st3, path, stage_name)

    # Reuse fW.dat as proc.dat to avoid re-running the pre-processing stage.
    shutil.copyfile(path / 'fW.dat', path / 'proc.dat')

In [18]:
# Write MATLAB's post-merge state into the stage directory, then free it.
save_rez_merge(rez_merge, test_split1_path)
del rez_merge

In [19]:
# Probe description for the 385-channel recording, read from files in dir_path.
probe = Bunch()
probe.NchanTOT = 385
# WARNING: indexing mismatch with MATLAB hence the -1
probe.chanMap = np.load(dir_path / 'chanMap.npy').squeeze().astype(np.int64) - 1
probe.update({
    key: np.load(dir_path / f'{key}.npy').squeeze()
    for key in ('xc', 'yc', 'kcoords')
})

In [20]:
# Run pykilosort up to and including the first split step.
ctx = run(
    dat_path,
    probe=probe,
    dir_path=test_split1_path.parent.parent,
    n_channels=385,
    dtype=np.int16,
    sample_rate=3e4,
    stop_after='split_1',
)

[0m23:25:36.662 [I] main:56              Loaded raw data with 385 channels, 3000000 samples.[0m
[0m23:25:36.804 [I] utils:334            Starting step split_1.[0m
[0m23:25:37.532 [I] postprocess:522      Found 0 splits, checked 0/357 clusters, nccg 0[0m
[0m23:25:42.824 [I] postprocess:522      Found 4 splits, checked 100/361 clusters, nccg 8[0m
[0m23:25:47.425 [I] postprocess:522      Found 10 splits, checked 200/367 clusters, nccg 9[0m
[0m23:25:52.510 [I] postprocess:522      Found 13 splits, checked 300/370 clusters, nccg 14[0m
[0m23:25:58.552 [I] postprocess:732      Finished splitting. Found 19 splits, checked 376/376 clusters, nccg 19[0m
[0m23:25:58.741 [I] utils:344            Step `split_1` took 21.94s.[0m


In [21]:
# MATLAB state after the first split, to compare against.
rez_split1 = get_rez(rez_path / 'rez_split1.mat')

In [22]:
# Column 1 holds the cluster ids; MATLAB's are 1-based, so shift them to 0-based.
st3_matlab = np.copy(rez_split1.st3.T)
st3_matlab[:,1] = st3_matlab[:,1] - 1
st3_python = cp.asnumpy(ctx.intermediate.st3_s1)

In [23]:
# Rows whose cluster id differs, and the set of ids involved on either side.
ix = np.where(st3_matlab[:,1] - st3_python[:,1] != 0)[0]
bad_clusters = np.unique(np.concatenate((st3_python[:,1][ix], st3_matlab[:,1][ix])))
print(f'The following clusters were different: {bad_clusters}')

The following clusters were different: [ 44. 173. 195. 246. 357. 358. 359. 360. 361. 362. 363. 365. 369. 371.
 372. 373. 374. 375.]


Some of these differences are just due to labelling, so we try to learn a mapping from the Python cluster labels to the corresponding MATLAB cluster labels:

In [24]:
# Pair each mismatching Python cluster with the MATLAB cluster that contains
# exactly the same spikes. Column 0: Python id, column 1: MATLAB id (-1 = unmatched).
cluster_mapping = np.zeros((len(bad_clusters), 2), dtype = int)
cluster_mapping[:,0] = bad_clusters
cluster_mapping[:,1] = -1

for i in range(len(bad_clusters)):
    ix_p = np.where(st3_python[:,1] == bad_clusters[i])[0]
    for j in bad_clusters:
        ix_m = np.where(st3_matlab[:,1] == j)[0]
        if np.array_equal(ix_p, ix_m):
            cluster_mapping[i,1] = j
            break

unmatched = cluster_mapping[:, 1] == -1
mapping_found = False
if unmatched.any():
    # Report the clusters actually flagged as unmatched (-1).
    print(f"No mapping found as the following Python clusters can't be matched:"
        f"{cluster_mapping[unmatched, 0]}")
elif np.max(st3_matlab[:,1]) != np.max(st3_python[:,1]):
    print("Some Matlab clusters had no corresponding Python match")
else:
    # mapping[python_id] -> matlab_id; identity for clusters that already agree.
    mapping = np.arange(np.max(st3_python[:,1] + 1), dtype=int)
    for python_id, matlab_id in cluster_mapping:
        mapping[python_id] = matlab_id
    mapping_found = True
    print("Mapping was found")

Mapping was found


In [25]:
# Every comparison below indexes with `mapping`; make sure it comes from a
# successful match in the cell above rather than an earlier, stale run.
assert mapping_found, "No cluster mapping was found; comparisons using `mapping` would be invalid"
print(f"mu matches: {test('mu', rez_split1, ctx, mapping, [0], 'mu_s')}")
print(f"simScore matches: {test('simScore', rez_split1, ctx, mapping, [0,1], 'simScore_s')}")
print(f"isplit matches: {test('isplit', rez_split1, ctx, mapping, [0,1])}")

# MATLAB rows are re-ordered into Python cluster order; neighbour ids shifted to 0-based.
print(f"iNeighPC matches: {np.allclose(cp.asnumpy(ctx.intermediate.iNeighPC_s.T), (rez_split1.iNeighPC-1)[mapping])}")

mu matches: True
simScore matches: True
isplit matches: True
iNeighPC matches: True


SVD components only match up to an arbitrary sign, so we compare absolute values and use a higher tolerance.

In [26]:
# Absolute-value comparisons, with MATLAB axis 1 re-indexed into Python cluster order.
print(f"W matches: {test_abs('W', rez_split1, ctx, mapping, [1], 'W_s', atol=1e-05)}")
print(f"U matches: {test_abs('U', rez_split1, ctx, mapping, [1], 'U_s', atol=1e-05)}")
print(f"Wphy matches: {test_abs('Wphy', rez_split1, ctx, mapping, [1], atol=1e-05)}")

W matches: True
U matches: True
Wphy matches: True


`iNeigh` stores the 32 nearest neighbours of each template. Because of rounding errors these may not always match exactly, so we report how many of the nearest neighbours (counting from the closest) match for every template. `iList` holds the same information (it is a repeated variable).

In [27]:
# For each variable, find the first neighbour rank (column) at which any template
# disagrees: every rank before it matches for all templates.
for label, neigh_matlab, neigh_python in (
    ('iNeigh', rez_split1.iNeigh, ctx.intermediate.iNeigh_s),
    ('iList', rez_split1.iList, ctx.intermediate.iList),
):
    # MATLAB rows re-ordered into Python cluster order; Python ids translated to MATLAB ids.
    mismatch = (neigh_matlab - 1)[mapping] != mapping[cp.asnumpy(neigh_python.T)]
    n_neigh = mismatch.shape[1]
    bad_ranks = np.flatnonzero(mismatch.any(axis=0))
    # When nothing differs, all neighbours match (np.min on an empty array would raise).
    n_match = bad_ranks[0] if bad_ranks.size else n_neigh
    print(f"Number of nearest neighbours that match in {label}: {n_match}/{n_neigh}")

Number of nearest neighbours that match in iNeigh: 21/32
Number of nearest neighbours that match in iList: 21/32


In [28]:
# Release the run context and MATLAB struct before the next stage.
del ctx
del rez_split1

## Test Second Split

In [29]:
# Fresh stage directory for the second split, seeded from MATLAB's first-split state.
test_split2_path = setup_dir(test_path, 'test_split2')
rez_split1 = get_rez(rez_path / 'rez_split1.mat')

In [30]:
def save_rez_split1(rez, path):
    """Seed a pykilosort stage directory with MATLAB's state after the first split.

    Writes the listed ``rez`` fields, ``ccb0``, the large ``fW``/``fWPC`` files,
    ``st3`` (0-based cluster ids) under every stage name read so far, and the
    split outputs under the ``_s`` names the split comparisons read back.
    """
    fields = (
        'ccbsort', 'dWU', 'igood', 'iList', 'iNeigh', 'iNeighPC', 'iorig', 'isplit', 'mu',
        'simScore', 'U', 'U_a', 'U_b', 'UA', 'W', 'W_a', 'W_b',
        'WA', 'Wrot', 'Wphy', 'wPCA', 'wTEMP',
    )
    for field in fields:
        _save(rez[field], path, field)

    _save(rez.ccb, path, 'ccb0')

    for source, target in (('cProj', 'fW'), ('cProjPC', 'fWPC')):
        _save_largearray(rez[source], path, target)

    # Row 1 of st3 holds the 1-based MATLAB template/cluster ids; make them 0-based.
    st3 = np.copy(rez.st3)
    st3[1, :] -= 1
    for stage_name in ('st3', 'st3_m', 'st3_s1'):
        _save(st3, path, stage_name)

    split_outputs = (
        ('iNeigh', 'iNeigh_s'), ('iNeighPC', 'iNeighPC_s'), ('mu', 'mu_s'),
        ('simScore', 'simScore_s'), ('U', 'U_s'), ('W', 'W_s'),
    )
    for source, target in split_outputs:
        _save(rez[source], path, target)

    # Reuse fW.dat as proc.dat to avoid re-running the pre-processing stage.
    shutil.copyfile(path / 'fW.dat', path / 'proc.dat')

In [31]:
# Write MATLAB's first-split state into the stage directory, then free it.
save_rez_split1(rez_split1, test_split2_path)
del rez_split1

In [32]:
# Probe description for the 385-channel recording, read from files in dir_path.
probe = Bunch()
probe.NchanTOT = 385
# WARNING: indexing mismatch with MATLAB hence the -1
probe.chanMap = np.load(dir_path / 'chanMap.npy').squeeze().astype(np.int64) - 1
probe.update({
    key: np.load(dir_path / f'{key}.npy').squeeze()
    for key in ('xc', 'yc', 'kcoords')
})

In [33]:
# Run pykilosort up to and including the second split step.
ctx = run(
    dat_path,
    probe=probe,
    dir_path=test_split2_path.parent.parent,
    n_channels=385,
    dtype=np.int16,
    sample_rate=3e4,
    stop_after='split_2',
)

[0m23:26:02.808 [I] main:56              Loaded raw data with 385 channels, 3000000 samples.[0m
[0m23:26:03.043 [I] utils:334            Starting step split_2.[0m
[0m23:26:03.308 [I] postprocess:522      Found 0 splits, checked 0/376 clusters, nccg 0[0m
[0m23:26:07.380 [I] postprocess:522      Found 1 splits, checked 100/377 clusters, nccg 9[0m
[0m23:26:10.468 [I] postprocess:522      Found 1 splits, checked 200/377 clusters, nccg 10[0m
[0m23:26:14.338 [I] postprocess:522      Found 1 splits, checked 300/377 clusters, nccg 14[0m
[0m23:26:19.683 [I] postprocess:732      Finished splitting. Found 3 splits, checked 379/379 clusters, nccg 20[0m
[0m23:26:19.844 [I] utils:344            Step `split_2` took 16.80s.[0m


In [34]:
# MATLAB state after the second split, to compare against.
rez_split2 = get_rez(rez_path / 'rez_split2.mat')

In [35]:
# Column 1 holds the cluster ids; MATLAB's are 1-based, so shift them to 0-based.
# The Python second-split result is stored as st3_s0 (the same name
# save_rez_split2 seeds for the cutoff stage).
st3_matlab = np.copy(rez_split2.st3.T)
st3_matlab[:,1] = st3_matlab[:,1] - 1
st3_python = cp.asnumpy(ctx.intermediate.st3_s0)

In [36]:
# Rows whose cluster id differs, and the set of ids involved on either side.
ix = np.where(st3_matlab[:,1] - st3_python[:,1] != 0)[0]
bad_clusters = np.unique(np.concatenate((st3_python[:,1][ix], st3_matlab[:,1][ix])))
print(f'The following clusters were different: {bad_clusters}')

The following clusters were different: []


Some of these differences are just due to labelling, so we try to learn a mapping from the Python cluster labels to the corresponding MATLAB cluster labels:

In [37]:
# Pair each mismatching Python cluster with the MATLAB cluster that contains
# exactly the same spikes. Column 0: Python id, column 1: MATLAB id (-1 = unmatched).
cluster_mapping = np.zeros((len(bad_clusters), 2), dtype = int)
cluster_mapping[:,0] = bad_clusters
cluster_mapping[:,1] = -1

for i in range(len(bad_clusters)):
    ix_p = np.where(st3_python[:,1] == bad_clusters[i])[0]
    for j in bad_clusters:
        ix_m = np.where(st3_matlab[:,1] == j)[0]
        if np.array_equal(ix_p, ix_m):
            cluster_mapping[i,1] = j
            break

unmatched = cluster_mapping[:, 1] == -1
mapping_found = False
if unmatched.any():
    # Report the clusters actually flagged as unmatched (-1).
    print(f"No mapping found as the following Python clusters can't be matched:"
        f"{cluster_mapping[unmatched, 0]}")
elif np.max(st3_matlab[:,1]) != np.max(st3_python[:,1]):
    print("Some Matlab clusters had no corresponding Python match")
else:
    # mapping[python_id] -> matlab_id; identity for clusters that already agree.
    mapping = np.arange(np.max(st3_python[:,1] + 1), dtype=int)
    for python_id, matlab_id in cluster_mapping:
        mapping[python_id] = matlab_id
    mapping_found = True
    print("Mapping was found")

Mapping was found


In [38]:
# Every comparison below indexes with `mapping`; make sure it comes from a
# successful match in the cell above rather than an earlier, stale run.
assert mapping_found, "No cluster mapping was found; comparisons using `mapping` would be invalid"
print(f"mu matches: {test('mu', rez_split2, ctx, mapping, [0], 'mu_s')}")
print(f"simScore matches: {test('simScore', rez_split2, ctx, mapping, [0,1], 'simScore_s')}")
print(f"isplit matches: {test('isplit', rez_split2, ctx, mapping, [0,1])}")

# Re-order MATLAB rows into Python cluster order, as in the first-split comparison.
print(f"iNeighPC matches: {np.allclose(cp.asnumpy(ctx.intermediate.iNeighPC_s.T), (rez_split2.iNeighPC - 1)[mapping])}")

mu matches: True
simScore matches: True
isplit matches: True
iNeighPC matches: True


SVD components only match up to an arbitrary sign, so we compare absolute values and use a higher tolerance.

In [39]:
# Absolute-value comparisons, with MATLAB axis 1 re-indexed into Python cluster order.
print(f"W matches: {test_abs('W', rez_split2, ctx, mapping, [1], 'W_s', atol=1e-05)}")
print(f"U matches: {test_abs('U', rez_split2, ctx, mapping, [1], 'U_s', atol=1e-05)}")
print(f"Wphy matches: {test_abs('Wphy', rez_split2, ctx, mapping, [1], atol=1e-05)}")

W matches: True
U matches: True
Wphy matches: True


`iNeigh` stores the 32 nearest neighbours of each template. Because of rounding errors these may not always match exactly, so we report how many of the nearest neighbours (counting from the closest) match for every template. `iList` holds the same information (it is a repeated variable).

In [40]:
# For each variable, find the first neighbour rank (column) at which any template
# disagrees: every rank before it matches for all templates.
for label, neigh_matlab, neigh_python in (
    ('iNeigh', rez_split2.iNeigh, ctx.intermediate.iNeigh_s),
    ('iList', rez_split2.iList, ctx.intermediate.iList),
):
    # MATLAB rows re-ordered into Python cluster order; Python ids translated to MATLAB ids.
    mismatch = (neigh_matlab - 1)[mapping] != mapping[cp.asnumpy(neigh_python.T)]
    n_neigh = mismatch.shape[1]
    bad_ranks = np.flatnonzero(mismatch.any(axis=0))
    # When nothing differs, all neighbours match (np.min on an empty array would raise).
    n_match = bad_ranks[0] if bad_ranks.size else n_neigh
    print(f"Number of nearest neighbours that match in {label}: {n_match}/{n_neigh}")

Number of nearest neighbours that match in iNeigh: 21/32
Number of nearest neighbours that match in iList: 21/32


In [41]:
# Release the run context and MATLAB struct before the next stage.
del ctx
del rez_split2

## Test Cutoff

In [42]:
# Fresh stage directory for the cutoff step, seeded from MATLAB's second-split state.
test_cutoff_path = setup_dir(test_path, 'test_cutoff')
rez_split2 = get_rez(rez_path / 'rez_split2.mat')

In [43]:
def save_rez_split2(rez, path):
    """Seed a pykilosort stage directory with MATLAB's state after the second split.

    Same layout as the first-split seed, with ``st3`` (0-based cluster ids)
    additionally written as ``st3_s0``, the name of the second-split result.
    """
    fields = (
        'ccbsort', 'dWU', 'igood', 'iList', 'iNeigh', 'iNeighPC', 'iorig', 'isplit', 'mu',
        'simScore', 'U', 'U_a', 'U_b', 'UA', 'W', 'W_a', 'W_b',
        'WA', 'Wrot', 'Wphy', 'wPCA', 'wTEMP',
    )
    for field in fields:
        _save(rez[field], path, field)

    _save(rez.ccb, path, 'ccb0')

    for source, target in (('cProj', 'fW'), ('cProjPC', 'fWPC')):
        _save_largearray(rez[source], path, target)

    # Row 1 of st3 holds the 1-based MATLAB template/cluster ids; make them 0-based.
    st3 = np.copy(rez.st3)
    st3[1, :] -= 1
    for stage_name in ('st3', 'st3_m', 'st3_s1', 'st3_s0'):
        _save(st3, path, stage_name)

    split_outputs = (
        ('iNeigh', 'iNeigh_s'), ('iNeighPC', 'iNeighPC_s'), ('mu', 'mu_s'),
        ('simScore', 'simScore_s'), ('U', 'U_s'), ('W', 'W_s'),
    )
    for source, target in split_outputs:
        _save(rez[source], path, target)

    # Reuse fW.dat as proc.dat to avoid re-running the pre-processing stage.
    shutil.copyfile(path / 'fW.dat', path / 'proc.dat')

In [44]:
# Write MATLAB's second-split state into the stage directory, then free it.
save_rez_split2(rez_split2, test_cutoff_path)
del rez_split2

In [45]:
# Probe description for the 385-channel recording, read from files in dir_path.
probe = Bunch()
probe.NchanTOT = 385
# WARNING: indexing mismatch with MATLAB hence the -1
probe.chanMap = np.load(dir_path / 'chanMap.npy').squeeze().astype(np.int64) - 1
probe.update({
    key: np.load(dir_path / f'{key}.npy').squeeze()
    for key in ('xc', 'yc', 'kcoords')
})

In [46]:
# Run pykilosort up to and including the cutoff step.
ctx = run(
    dat_path,
    probe=probe,
    dir_path=test_cutoff_path.parent.parent,
    n_channels=385,
    dtype=np.int16,
    sample_rate=3e4,
    stop_after='cutoff',
)

[0m23:26:23.964 [I] main:56              Loaded raw data with 385 channels, 3000000 samples.[0m
[0m23:26:24.207 [I] utils:334            Starting step cutoff.[0m
Setting cutoff: 100%|███████████████████████████████████████████████████████████████| 379/379 [00:02<00:00, 186.19it/s]
[0m23:26:26.249 [I] utils:344            Step `cutoff` took 2.04s.[0m


In [47]:
# MATLAB state after the cutoff step, to compare against.
rez_cutoff = get_rez(rez_path / 'rez_cutoff.mat')

In [48]:
# Cutoff outputs are compared directly; no cluster remapping is applied here.
print(f"Ths matches: {np.allclose(rez_cutoff.Ths, ctx.intermediate.Ths)}")
print(f"good matches: {np.allclose(rez_cutoff.good, ctx.intermediate.good)}")
print(f"est_contam_rate matches: {np.allclose(rez_cutoff.est_contam_rate, ctx.intermediate.est_contam_rate)}")

# Column 1 holds the cluster ids; MATLAB's are 1-based, so shift them to 0-based.
st3_matlab = np.copy(rez_cutoff.st3.T)
st3_matlab[:,1] = st3_matlab[:,1] - 1
st3_python = cp.asnumpy(ctx.intermediate.st3_c)
print(f"st3 matches: {np.allclose(st3_python, st3_matlab)}")

Ths matches: True
good matches: True
est_contam_rate matches: True
st3 matches: True


In [49]:
# Release the run context and MATLAB struct.
del ctx
del rez_cutoff