In [15]:
# Enable IPython's autoreload extension so that edits made to the project's
# modules are picked up without having to restart the notebook kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import json
import upath
import pandas as pd
import mongomock
import xarray as xr
import numpy as np
from fsspec.implementations.local import LocalFileSystem
from upath import UPath
import matplotlib.pyplot as plt

from signalstore import UnitOfWorkProvider

In [17]:
# Functions

def deserialize_dataarray(data_object):
    """Convert string-encoded attributes of a data object back to Python values.

    Attribute values are rewritten as follows:
      * the strings 'true' / 'false' / 'none' (any casing) become
        True / False / None;
      * strings starting with '{' are parsed as JSON objects;
      * numpy arrays become plain lists;
      * every other value is left untouched.

    Arguments:
        data_object {object} -- Any object exposing a dict-like ``attrs``
            attribute (the notebook passes xarray DataArrays).
    Returns:
        object -- The same ``data_object``, with ``attrs`` replaced by the
            converted copy.
    Raises:
        json.JSONDecodeError -- If a string starting with '{' is not valid
            JSON, even after single quotes are swapped for double quotes.
    """
    attrs = data_object.attrs.copy()
    # Only values of existing keys are reassigned, so iterating over the
    # copy while updating it is safe (the key set never changes).
    for key, value in attrs.items():
        if isinstance(value, np.ndarray):
            attrs[key] = value.tolist()
        elif isinstance(value, str):
            lowered = value.lower()
            if lowered == 'true':
                attrs[key] = True
            elif lowered == 'false':
                attrs[key] = False
            elif lowered == 'none':
                attrs[key] = None
            elif value.startswith('{'):
                try:
                    # Parse the raw text first: valid JSON may legitimately
                    # contain apostrophes inside its string values.
                    attrs[key] = json.loads(value)
                except json.JSONDecodeError:
                    # Fall back to the Python-repr style ({'a': 'b'}) by
                    # swapping single quotes for double quotes.
                    attrs[key] = json.loads(value.replace("'", '"'))
    data_object.attrs = attrs
    return data_object


In [18]:
# Mock DB client
mongo_client = mongomock.MongoClient()

# Demo filesystem, rooted at <parent of the working directory>/data/internal
tmpdir = upath.UPath(UPath.cwd().parent / "data" / "internal")
filesystem = LocalFileSystem(root=str(tmpdir))

# Start from an empty store directory.
# - Create it first so a fresh checkout does not fail on `ls`.
# - Remove recursively, because fsspec's `rm` refuses to delete a directory
#   entry unless `recursive=True`.
filesystem.makedirs(str(tmpdir), exist_ok=True)
for stale_path in filesystem.ls(str(tmpdir), detail=False):
    filesystem.rm(stale_path, recursive=True)

# Empty memory store for demo
memory_store = dict()

# Get uow to act on database
uow_provider = UnitOfWorkProvider(mongo_client, filesystem, memory_store)
unit_of_work = uow_provider(str(tmpdir))

# All model definitions live under <parent of cwd>/tests/data/valid_data/models
models_root = UPath.cwd().parent / "tests" / "data" / "valid_data" / "models"


def _read_json(json_path):
    """Return the parsed content of the JSON file at ``json_path``."""
    with open(json_path, 'r') as handle:
        return json.load(handle)


# Get raw_property_models (a single JSON file)
property_models_path = models_root / "property_models.json"
raw_property_models = _read_json(property_models_path)

# Get raw_metamodels (one JSON file per metamodel)
metamodels_dir = models_root / "metamodels"
metamodel_filepaths = list(metamodels_dir.glob("*.json"))
raw_metamodels = [_read_json(model_path) for model_path in metamodel_filepaths]

# Get raw_data_models (one JSON file per data model)
data_models_dir = models_root / "data_models"
data_model_filepaths = list(data_models_dir.glob("*.json"))
raw_data_models = [_read_json(model_path) for model_path in data_model_filepaths]

# Get raw_records: every row of every Excel workbook in the input directory
netcdf_dir = UPath.cwd().parent / "data" / "input"
records_files = list(netcdf_dir.glob("*.xlsx"))
raw_records = []
for workbook_path in records_files:
    # Every cell is read as a string (dtype=str).
    sheet = pd.read_excel(workbook_path, engine='openpyxl', dtype=str)
    # Round-trip through JSON so each row becomes a plain dict.
    raw_records += json.loads(sheet.to_json(orient="records"))

# Get dataarrays: every NetCDF file in the same input directory, with its
# attributes passed through deserialize_dataarray
netcdf_dir = UPath.cwd().parent / "data" / "input"
netcdf_files = list(netcdf_dir.glob("*.nc"))
dataarrays = [
    deserialize_dataarray(xr.open_dataarray(nc_path))
    for nc_path in netcdf_files
]

# Add property models, metamodels, data models, and records
with unit_of_work as uow:
    # Domain models are added in the order: property models, metamodels,
    # data models.
    for domain_model in [*raw_property_models, *raw_metamodels, *raw_data_models]:
        uow.domain_models.add(domain_model)

    # Records with a truthy "has_file" are skipped here.
    # NOTE(review): the spreadsheets are read with dtype=str, so a cell
    # holding the text "False" would be truthy and skipped as well - confirm
    # that this matches the input data.
    for record in raw_records:
        if record.get("has_file"):
            continue
        uow.data.add(record)

    # Data arrays whose schema_ref is "test" are left out.
    for dataarray in dataarrays:
        if dataarray.attrs.get("schema_ref") == "test":
            continue
        uow.data.add(dataarray)

    uow.commit()

DIMS TO ADD: ['spike_idx', '1']
SERIALIZED: {'schema_ref': 'spike_labels', 'data_name': 'NON-73-6_1_20180606_101138', 'has_file': 'True', 'data_dimensions': array(['spike_idx', '1'], dtype='<U9'), 'dimension_of_measure': '[nominal]', 'session_data_ref': '{"schema_ref": "session", "data_name": "NON-73-6_1_20180606_101138"}', 'animal_data_ref': '{"schema_ref": "animal", "data_name": "NON-73-6"}', 'probe_data_ref': '{"schema_ref": "probe", "data_name": "1"}'}
BEFORE CAUTION DESER: ['spike_idx' '1']
TOLIST?: ['spike_idx', '1']
AFTER CAUTION DESER: ['spike_idx', '1']
DIMS TO ADD: ['spike_idx', '1']
SERIALIZED: {'schema_ref': 'spike_times', 'data_name': 'NON-73-6_1_20180606_101138', 'has_file': 'True', 'data_dimensions': array(['spike_idx', '1'], dtype='<U9'), 'dimension_of_measure': '[time]', 'session_data_ref': '{"schema_ref": "session", "data_name": "NON-73-6_1_20180606_101138"}', 'animal_data_ref': '{"schema_ref": "animal", "data_name": "NON-73-6"}', 'probe_data_ref': '{"schema_ref": "pr

In [19]:
# Look up the stored records whose schema_ref is "session" and show the first.
query = {"schema_ref": "session"}
# NOTE(review): sorted_by is defined but never passed to find(), so the
# result order is whatever find() returns - confirm whether sorting was intended.
sorted_by = [("session_start", -1)]
with unit_of_work as uow:
    sessions = uow.data.find(query)
    print(sessions[0])

{'schema_ref': 'session', 'data_name': 'NON-73-6_1_20180606_101138', 'animal_id': 'NON-73-6', 'session_date': '20180606', 'start_time': '10:11:38', 'tetrode_depth': '3000', 'stimulus_id': 'NO', 'duration': '601.0', 'duration_unit': 'second', 'stimulus_type': 'object', 'time_of_save': datetime.datetime(2024, 4, 23, 1, 54, 44, 472136), 'time_of_removal': None}


In [25]:
# Collect everything that references one session/probe pair into a Dataset,
# with one data variable per schema_ref.
session_ref = {'schema_ref': 'session', 'data_name': 'NON-73-6_1_20180606_101138'}
probe_ref = {'schema_ref': 'probe', 'data_name': '1'}
query = {
    "session_data_ref": session_ref,
    "probe_data_ref": probe_ref,
}
with unit_of_work as uow:
    variables = {}
    for found in uow.data.find(query):
        # Fetch the full object for each matching record.
        variables[found["schema_ref"]] = uow.data.get(found["schema_ref"], found["data_name"])
    data = xr.Dataset(variables)
    print(data)

DIMS GOT: ['spike_idx' '1']
DIMS GOT: ['spike_idx' '1']
DIMS GOT: ['spike_idx' 'channel' 'sample']
<xarray.Dataset> Size: 11MB
Dimensions:          (spike_idx: 13374, 1: 1, sample: 50, channel: 4)
Coordinates:
  * spike_idx        (spike_idx) int32 53kB 0 1 2 3 ... 13370 13371 13372 13373
  * sample           (sample) float32 200B 0.0 1.0 2.0 3.0 ... 47.0 48.0 49.0
  * channel          (channel) float32 16B 0.0 1.0 2.0 3.0
Dimensions without coordinates: 1
Data variables:
    spike_labels     (spike_idx, 1) float32 53kB ...
    spike_times      (spike_idx, 1) float32 53kB ...
    spike_waveforms  (spike_idx, channel, sample) float32 11MB ...


In [28]:
def plot_colored_waveforms(data, ax=None, channel=0):
    """Plot the spike waveforms of one channel, colored by spike label.

    Arguments:
        data {mapping} -- Dataset-like object indexable by name, providing
            "spike_labels" (one label per spike; any extra axes are assumed
            to have length 1) and "spike_waveforms" with axes
            (spike_idx, channel, sample). An xarray Dataset with those
            dimension names works, as does a dict of numpy arrays in that
            axis order.
        ax {matplotlib.axes.Axes} -- Axes to draw on. A new figure and axes
            are created when None.
        channel {int} -- Positional index of the channel to plot (default 0).
    Returns:
        matplotlib.axes.Axes -- The axes the waveforms were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    waveforms = data["spike_waveforms"]
    if hasattr(waveforms, "dims"):
        # Labelled (xarray) input: put the axes in a known order by name
        # instead of trusting how the variable happens to be stored.
        waveforms = waveforms.transpose("spike_idx", "channel", "sample")
    waveforms = np.asarray(waveforms)

    # Flatten the labels to one value per spike; the stored variable carries
    # a trailing length-1 axis.
    labels = np.asarray(data["spike_labels"]).reshape(waveforms.shape[0], -1)[:, 0]
    channel_waveforms = waveforms[:, channel, :]

    for color_index, label in enumerate(np.unique(labels)):
        selected = channel_waveforms[labels == label]
        if selected.shape[0] == 0:
            # A NaN label never compares equal to itself, so it selects nothing.
            continue
        # Transpose so that every spike becomes one line over the samples.
        lines = ax.plot(
            selected.T,
            color=f"C{color_index % 10}",
            alpha=0.3,
            linewidth=0.5,
        )
        # Label only the first line so the legend has one entry per label.
        lines[0].set_label(str(label))
    ax.legend()
    return ax