In [1]:
import cudf
import pandas as pd
import kdephys as kd
import acr.units as au
import os

In [2]:
import kdephys.cudf_flavor as kfl

In [3]:
@kfl.register_cudf_method
def rec(self, rec):
    """Return only the rows of this frame whose ``recording`` column equals ``rec``."""
    in_recording = self.recording == rec
    return self.loc[in_recording]

In [4]:
# Spike dataframe for a single sort (sdpi-NNXr) of subject ACR_14.
df = cudf.read_parquet(
    "/Volumes/opto_loc/Data/ACR_14/sorting_data/spike_dataframes/"
    "sdpi-NNXr.parquet"
)

In [6]:
# Keep only the spikes from the 'sdpi-bl' recording, using the `rec` method
# registered on cudf DataFrames earlier in the notebook.
df.rec('sdpi-bl')

Unnamed: 0,time,cluster_id,group,note,channel,sort_id,recording,datetime,stim,state
0,0.003686,12,mua,noisy,8.0,sdpi-NNXr,sdpi-bl,2022-09-20 09:04:25.999999000,,Wake
1,0.009585,139,mua,noisy,13.0,sdpi-NNXr,sdpi-bl,2022-09-20 09:04:26.005897241,,Wake
2,0.089825,12,mua,noisy,8.0,sdpi-NNXr,sdpi-bl,2022-09-20 09:04:26.086137889,,Wake
3,0.101089,22,mua,noisy,11.0,sdpi-NNXr,sdpi-bl,2022-09-20 09:04:26.097401890,,Wake
4,0.129761,141,mua,merged,15.0,sdpi-NNXr,sdpi-bl,2022-09-20 09:04:26.126073893,,Wake
...,...,...,...,...,...,...,...,...,...,...
5072495,86494.955530,44,mua,noisy,5.0,sdpi-NNXr,sdpi-bl,2022-09-21 09:06:00.951842323,,no_state
5072496,86494.973880,139,mua,noisy,13.0,sdpi-NNXr,sdpi-bl,2022-09-21 09:06:00.970192405,,no_state
5072497,86494.977812,22,mua,noisy,11.0,sdpi-NNXr,sdpi-bl,2022-09-21 09:06:00.974124565,,no_state
5072498,86494.981867,22,mua,noisy,11.0,sdpi-NNXr,sdpi-bl,2022-09-21 09:06:00.978179606,,no_state


In [None]:
# Load the two sorts of ACR_14 separately; they are stacked with pd.concat further down.
# NOTE(review): this cell is unexecuted (In [None]) while the concat cell that uses
# df1/df2 ran — the notebook currently relies on hidden kernel state. Run top to bottom.
df1 = au.load_spike_dfs('ACR_14', 'sdpi-NNXr')
df2 = au.load_spike_dfs('ACR_14', 'sdpi-NNXo')

In [None]:
def register_dataframe_method(method, accessor_registrar=None):
    """Register a function as a method attached to a DataFrame class.

    After registration the function is callable as ``df.<method.__name__>(...)``
    and receives the DataFrame as its first argument.

    Example
    -------
    .. code-block:: python

        @register_dataframe_method
        def print_column(df, col):
            '''Print the dataframe column given'''
            print(df[col])

    Args:
        method (callable): function whose first parameter is the DataFrame.
        accessor_registrar (callable, optional): accessor-registration factory
            taking a name and returning a class decorator, e.g.
            ``cudf.api.extensions.register_dataframe_accessor``. Defaults to
            ``pandas.api.extensions.register_dataframe_accessor``.

    Returns:
        callable: ``method`` itself, unchanged, so this works as a decorator.
    """
    # Imported locally: neither `wraps` nor the registrar was in scope when this
    # cell ran, which made the original raise NameError.
    from functools import wraps

    if accessor_registrar is None:
        import pandas as pd

        accessor_registrar = pd.api.extensions.register_dataframe_accessor

    class AccessorMethod:
        """Accessor whose instances call ``method`` with the DataFrame bound first."""

        def __init__(self, pandas_obj):
            self._obj = pandas_obj

        @wraps(method)
        def __call__(self, *args, **kwargs):
            return method(self._obj, *args, **kwargs)

    accessor_registrar(method.__name__)(AccessorMethod)
    return method

In [None]:
def load_spike_dfs(subject, sort_id=None, base_dir="/Volumes/opto_loc/Data"):
    """
    Load sorted spike dataframes (pandas).

    If sort_id is specified, load only that one.
    If sort_id is not specified, load every ``*.parquet`` file in the
    ``sorting_data/spike_dataframes`` folder.

    Args:
        subject (str): subject name
        sort_id (optional): specific sort_id to load. Defaults to None.
        base_dir (str, optional): data root containing one folder per subject.
            Defaults to "/Volumes/opto_loc/Data".

    Returns:
        spikes_df: a DataFrame when sort_id is given, otherwise a dict mapping
        sort_id (file name up to the first ".") to its DataFrame.
    """
    path = os.path.join(base_dir, subject, "sorting_data", "spike_dataframes")
    if sort_id:
        return pd.read_parquet(os.path.join(path, sort_id + ".parquet"))

    spike_dfs = {}
    for fname in sorted(os.listdir(path)):
        # Skip hidden/OS metadata files (e.g. .DS_Store or AppleDouble "._*" files
        # that macOS can create on /Volumes) and anything that is not parquet;
        # read_parquet would otherwise fail on them.
        if fname.startswith(".") or not fname.endswith(".parquet"):
            continue
        key = fname.split(".")[0]
        spike_dfs[key] = pd.read_parquet(os.path.join(path, fname))
    return spike_dfs

In [3]:
# Stack both sorts into one pandas frame, then move it to a cudf DataFrame.
# ignore_index=True gives a fresh 0..N-1 index: without it the row labels of
# df1 and df2 overlap (each presumably starts at 0), so label-based lookups on
# the combined frame would return rows from both sorts.
df = pd.concat([df1, df2], ignore_index=True)
cdf = cudf.DataFrame.from_pandas(df)

In [None]:
def load_cudf_units(subject, sort_id=None, base_dir="/Volumes/opto_loc/Data"):
    """
    Load sorted spike dataframes as cudf DataFrames.

    If sort_id is specified, load only that one.
    If sort_id is not specified, load every ``*.parquet`` file in the
    ``sorting_data/spike_dataframes`` folder.

    Args:
        subject (str): subject name
        sort_id (optional): specific sort_id to load. Defaults to None.
        base_dir (str, optional): data root containing one folder per subject.
            Defaults to "/Volumes/opto_loc/Data".

    Returns:
        spikes_df: a cudf DataFrame when sort_id is given, otherwise a dict
        mapping sort_id (file name up to the first ".") to its cudf DataFrame.
    """
    path = os.path.join(base_dir, subject, "sorting_data", "spike_dataframes")
    if sort_id:
        return cudf.read_parquet(os.path.join(path, sort_id + ".parquet"))

    spike_dfs = {}
    for fname in sorted(os.listdir(path)):
        # Skip hidden/OS metadata files (e.g. .DS_Store or AppleDouble "._*" files
        # that macOS can create on /Volumes) and anything that is not parquet;
        # read_parquet would otherwise fail on them.
        if fname.startswith(".") or not fname.endswith(".parquet"):
            continue
        key = fname.split(".")[0]
        spike_dfs[key] = cudf.read_parquet(os.path.join(path, fname))
    return spike_dfs

In [23]:
# Spike count per cluster for the sdpi-NNXr sort.
cdf.loc[cdf["sort_id"] == "sdpi-NNXr", "cluster_id"].value_counts()

22     4271600
141    2217335
139    2148243
152    1885466
12     1863281
19      945505
154     835323
44      510140
36      229661
Name: cluster_id, dtype: int32

In [7]:
import cudf as gd
@gd.api.extensions.register_dataframe_accessor("udf")
class UDFAccessor:
    """Convenience row filters on a cudf DataFrame, exposed as ``df.udf``.

    Expects the frame to have ``sort_id``, ``recording`` and ``stim`` columns.
    NOTE(review): cudf caches the accessor after first use (see its docstring),
    so after editing this class, rebuild the frame or restart the kernel before
    relying on new/changed methods.
    """

    def __init__(self, obj):
        self._obj = obj

    def cols(self):
        """Return the wrapped frame's columns."""
        return self._obj.columns

    def sid(self, sort_id):
        """Return rows whose ``sort_id`` equals ``sort_id``."""
        return self._obj.loc[self._obj.sort_id == sort_id]

    def rec(self, rec):
        """Return rows whose ``recording`` equals ``rec``."""
        return self._obj.loc[self._obj.recording == rec]

    def stm(self):
        """Return rows where the ``stim`` column equals 1."""
        # The column lives on the wrapped frame; `self.stim` looked it up on the
        # accessor and raised AttributeError.
        return self._obj.loc[self._obj.stim == 1]



In [11]:
from cudf.api.extensions import register_dataframe_accessor

In [12]:
register_dataframe_accessor?

[0;31mSignature:[0m [0mregister_dataframe_accessor[0m[0;34m([0m[0mname[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Extends `cudf.DataFrame` with custom defined accessor

Parameters
----------
name : str
    The name to be registered in `DataFrame` for the custom accessor

Returns
-------
decorator : callable
    Decorator function for accessor

Notes
-----
The `DataFrame` object will be passed to your custom accessor upon first
invocation. And will be cached for future calls.

If the data passed to your accessor is of wrong datatype, you should
raise an `AttributeError` in consistent with other cudf methods.


Examples
--------

In your library code:

    >>> import cudf as gd
    >>> @gd.api.extensions.register_dataframe_accessor("point")
    ... class PointsAccessor:
    ...     def __init__(self, obj):
    ...         self._validate(obj)
    ...         self._obj = obj
    ...     @staticmethod
    ...     def _validate(obj):
    ...         cols = obj.col

In [8]:
# Rows of the combined frame where stim == 1, via the `udf` accessor.
# NOTE(review): this call raised the AttributeError shown below, naming `stm`
# itself — possibly a stale accessor cached from an earlier definition of `udf`.
# Re-run after a kernel restart, and confirm `stm` reads `stim` from the wrapped frame.
cdf.udf.stm()

AttributeError: 'UDFAccessor' object has no attribute 'stm'

In [None]:
# Datetime window (string endpoints) used below to label-slice frames indexed by
# `datetime`; presumably the end of the baseline period — TODO confirm.
window_start = "2022-09-09 11:01"
bl_ends = slice(window_start, "2022-09-09 11:06")
del window_start

In [None]:
# Re-index `t` by its `datetime` column so it can be sliced with `bl_ends`.
# NOTE(review): `t` is not defined in any visible cell — this depends on hidden
# kernel state; load it explicitly earlier in the notebook.
td = t.reset_index().set_index('datetime')

In [None]:
# Re-index `p` by its `datetime` column so it can be sliced with `bl_ends`.
# NOTE(review): `p` is not defined in any visible cell, and this rebinds `p`,
# overwriting the original frame — prefer assigning to a new name.
p = p.reset_index().set_index('datetime')

In [None]:
# Rows of `td` whose datetime index falls in the `bl_ends` window (label-based slice).
td.loc[bl_ends]

Unnamed: 0_level_0,channel,time,data,timedelta,condition,state
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2022-09-09 11:01:00.000266018,1,14087.000267,39.807999,59999845171,laser1-EEGr,Wake
2022-09-09 11:01:00.000266018,2,14087.000267,41.983997,59999845171,laser1-EEGr,Wake
2022-09-09 11:01:00.001249058,1,14087.001250,105.727997,60000828211,laser1-EEGr,Wake
2022-09-09 11:01:00.001249058,2,14087.001250,41.215996,60000828211,laser1-EEGr,Wake
2022-09-09 11:01:00.002232098,1,14087.002233,-0.704000,60001811251,laser1-EEGr,Wake
...,...,...,...,...,...,...
2022-09-09 11:05:59.997525794,2,14386.997527,9.791999,359997104947,laser1-EEGr,Wake
2022-09-09 11:05:59.998508834,1,14386.998510,-7.296000,359998087987,laser1-EEGr,Wake
2022-09-09 11:05:59.998508834,2,14386.998510,-40.447998,359998087987,laser1-EEGr,Wake
2022-09-09 11:05:59.999491874,1,14386.999493,182.655991,359999071027,laser1-EEGr,Wake


In [None]:
# Same window on `p`, for comparison with `td` above.
p.loc[bl_ends]

Unnamed: 0_level_0,channel,time,data,timedelta,condition,state
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2022-09-09 11:01:00.000266018,1,14087.000267,39.807999,0 days 00:00:59.999845171,laser1-EEGr,Wake
2022-09-09 11:01:00.000266018,2,14087.000267,41.983997,0 days 00:00:59.999845171,laser1-EEGr,Wake
2022-09-09 11:01:00.001249058,1,14087.001250,105.727997,0 days 00:01:00.000828211,laser1-EEGr,Wake
2022-09-09 11:01:00.001249058,2,14087.001250,41.215996,0 days 00:01:00.000828211,laser1-EEGr,Wake
2022-09-09 11:01:00.002232098,1,14087.002233,-0.704000,0 days 00:01:00.001811251,laser1-EEGr,Wake
...,...,...,...,...,...,...
2022-09-09 11:05:59.997525794,2,14386.997527,9.791999,0 days 00:05:59.997104947,laser1-EEGr,Wake
2022-09-09 11:05:59.998508834,1,14386.998510,-7.296000,0 days 00:05:59.998087987,laser1-EEGr,Wake
2022-09-09 11:05:59.998508834,2,14386.998510,-40.447998,0 days 00:05:59.998087987,laser1-EEGr,Wake
2022-09-09 11:05:59.999491874,1,14386.999493,182.655991,0 days 00:05:59.999071027,laser1-EEGr,Wake
