diff --git a/README.md b/README.md index 0cd113a..631d015 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# WizardHat +# WizardHat ![logo](https://github.com/merlin-neurotech/WizardHat/blob/master/WizardHatLogoSmall.jpg) WizardHat is a library for the streaming and handling of EEG data from consumer-grade devices using the Lab Streaming Layer (LSL) protocol. WizardHat's purpose is to enable users, especially first-timers, to flexibly build brain-computer interfaces (BCIs) without the fuss of configuring a streaming environment. WizardHat was built by Merlin Neurotech at Queen's University. Currently, WizardHat supports the Muse (2016) brain-sensing headband and the OpenBCI Ganglion, and runs on Python 3.6. WizardHat is easy to use and only requires three lines of code to get started. WizardHat's framework enables streaming, manipulation, and visualization of online EEG data.
@@ -8,26 +8,49 @@ For first time python users, please refer to our [beginner's guide](https://docs ## Note: Active Development Our dedicated team at Merlin Neurotech is continuously working to improve WizardHat and add new functionality. Current ongoing projects: -- Frequency Spectrum Data Class - MNE Library Compatibility - Implementing simple filters -- Power spectrum transformer Check back soon if the feature you are looking for is under development! ## Getting Started -To set up WizardHat, begin by cloning this repository on your local environment. Once cloned, ensure you are in a new virtual environment and download the required dependencies. +The procedure for installing WizardHat depends on whether you will be contributing to its development. In either case, begin by creating and activating a new Python virtual environment. +### Installing for use only +Simply run + + pip install wizardhat + +This will automatically install the most recent release of WizardHat along with the required dependencies. + +### Installing for development +To set up WizardHat for development, begin by forking the repository on GitHub, then clone your fork: + + git clone https://github.com/<your-username>/WizardHat.git + +If you are also developing for ble2lsl, fork and then clone the ble2lsl repository as well, and install its dependencies: + + git clone https://github.com/<your-username>/ble2lsl.git + cd ble2lsl pip install -r requirements.txt + pip install -e . + cd .. + +Whether or not you cloned ble2lsl, install the remaining dependencies for WizardHat: + + cd WizardHat + pip install -r requirements.txt + +### Finally -For more details on how to set up your python environment on Windows/MacOS/Linux please refer to our detailed instructions in the documentation file. +For more details on how to set up your Python environment on Windows/MacOS/Linux, please refer to our detailed instructions in the documentation file. Next, to ensure a bug-free experience, open [your virtual env name]/lib/python3.6/site-packages/pygatt/backends/bgapi/bgapi.py in a text or code editor and add: time.sleep(0.25) -between line 200 and 201 and save the file. This ensures that the bluetooth protocol will be given adequate time to connect to the Muse before timing out. +between lines 200 and 201 and save the file. This ensures that the Bluetooth protocol will be given adequate time to connect to the Muse before timing out. Now you are ready to use WizardHat! 
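The README's "three lines of code" claim corresponds to a minimal session along the following lines. This is only a sketch, pieced together from the example scripts added under `scripts/` later in this diff; the choice of device module (`muse2016`) and stream name (`'EEG'`) is illustrative, and a physical headset (or a `ble2lsl.Dummy` stand-in) is assumed to be available.

```python
import ble2lsl
from ble2lsl.devices import muse2016  # any module from ble2lsl.devices
from wizardhat import acquire, plot

streamer = ble2lsl.Streamer(muse2016)  # connect over BLE and publish device data as LSL streams
receiver = acquire.Receiver()          # subscribe to those LSL streams and buffer incoming samples
plot.Lines(receiver.buffers['EEG'])    # live plot of the buffered EEG time series
```

For testing without hardware, the same flow works with `streamer = ble2lsl.Dummy(muse2016)`, as in the new `scripts/example_psd.py`.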
@@ -83,4 +106,4 @@ Chris, Hamada ## Acknowledgements -This project was inspired by Alexander Barachant's [muse-lsl](https://github.com/alexandrebarachant/muse-lsl) from which some of the modules are derived or informed (particularly `ble2lsl` and some of `wizardhat.acquire`). The device specification for the OpenBCI Ganglion is largely derived from [OpenBCI_Python](https://github.com/OpenBCI/OpenBCI_Python). +This project was inspired by Alexander Barachant's [muse-lsl](https://github.com/alexandrebarachant/muse-lsl) from which some of the modules were originally based. The device specification for the OpenBCI Ganglion is largely derived from [OpenBCI_Python](https://github.com/OpenBCI/OpenBCI_Python). diff --git a/ble2lsl/__init__.py b/ble2lsl/__init__.py deleted file mode 100644 index ae85927..0000000 --- a/ble2lsl/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ - -from ble2lsl.ble2lsl import * diff --git a/ble2lsl/ble2lsl.py b/ble2lsl/ble2lsl.py deleted file mode 100644 index 5382d2a..0000000 --- a/ble2lsl/ble2lsl.py +++ /dev/null @@ -1,500 +0,0 @@ -"""Interfacing between Bluetooth Low Energy and Lab Streaming Layer protocols. - -Interfacing with devices over Bluetooth Low Energy (BLE) is achieved using the -`Generic Attribute Profile`_ (GATT) standard procedures for data transfer. -Reading and writing of GATT descriptors is provided by the `pygatt`_ module. - -All classes streaming data through an LSL outlet should subclass -`BaseStreamer`. - -Also includes dummy streamer objects, which do not acquire data over BLE but -pass local data through an LSL outlet, e.g. for testing. - -TODO: - * AttrDict for attribute-like dict access from device PARAMS? - -.. _Generic Attribute Profile: - https://www.bluetooth.com/specifications/gatt/generic-attributes-overview - -.. _pygatt: - https://github.com/peplin/pygatt -""" - -from queue import Queue -from struct import error as StructError -import threading -import time -from warnings import warn - -import numpy as np -import pygatt -from pygatt.backends.bgapi.exceptions import ExpectedResponseTimeout -import pylsl as lsl - -INFO_ARGS = ['type', 'channel_count', 'nominal_srate', 'channel_format'] - - -class BaseStreamer: - """Base class for streaming data through an LSL outlet. - - Prepares `pylsl.StreamInfo` and `pylsl.StreamOutlet` objects as well as - data buffers for handling of incoming chunks. - - Subclasses must implement `start` and `stop` methods for stream control. - - TODO: - * Public access to outlets and stream info? - * Push chunks, not samples (have to generate intra-chunk timestamps anyway) - """ - - def __init__(self, device, subscriptions=None, time_func=time.time, - ch_names=None): - """Construct a `BaseStreamer` object. - - Args: - device: A device module in `ble2lsl.devices`. - time_func (function): Function for generating timestamps. - subscriptions (Iterable[str]): Types of device data to stream. - Some subset of `SUBSCRIPTION_NAMES`. - ch_names (dict[Iterable[str]]): User-defined channel names. - e.g. `{'EEG': ('Ch1', 'Ch2', 'Ch3', 'Ch4')}`. 
- """ - self._device = device - if subscriptions is None: - subscriptions = get_default_subscriptions(device) - self._subscriptions = tuple(subscriptions) - self._time_func = time_func - self._user_ch_names = ch_names if ch_names is not None else {} - self._stream_params = self._device.PARAMS['streams'] - - self._chunk_idxs = stream_idxs_zeros(self._subscriptions) - self._chunks = empty_chunks(self._stream_params, - self._subscriptions) - - # StreamOutlet.push_chunk doesn't like single-sample chunks... - # but want to keep using push_chunk for intra-chunk timestamps - # doing this beforehand to avoid a chunk size check for each push - chunk_size = self._stream_params["chunk_size"] - self._push_func = {name: (self._push_chunk_as_sample - if chunk_size[name] == 1 - else self._push_chunk) - for name in self._subscriptions} - - def start(self): - """Begin streaming through the LSL outlet.""" - raise NotImplementedError() - - def stop(self): - """Stop/pause streaming through the LSL outlet.""" - raise NotImplementedError() - - def _init_lsl_outlets(self): - """Call in subclass after acquiring address.""" - source_id = "{}-{}".format(self._device.NAME, self._address) - self._info = {} - self._outlets = {} - for name in self._subscriptions: - info = {arg: self._stream_params[arg][name] for arg in INFO_ARGS} - outlet_name = '{}-{}'.format(self._device.NAME, name) - self._info[name] = lsl.StreamInfo(outlet_name, **info, - source_id=source_id) - self._add_device_info(name) - chunk_size = self._stream_params["chunk_size"][name] - self._outlets[name] = lsl.StreamOutlet(self._info[name], - chunk_size=chunk_size, - max_buffered=360) - - def _push_chunk(self, name, timestamp): - self._outlets[name].push_chunk(self._chunks[name].tolist(), - timestamp) - - def _push_chunk_as_sample(self, name, timestamp): - self._outlets[name].push_sample(self._chunks[name].tolist()[0], - timestamp) - - def _add_device_info(self, name): - """Adds device-specific parameters to `info`.""" - desc = self._info[name].desc() - try: - desc.append_child_value("manufacturer", self._device.MANUFACTURER) - except KeyError: - warn("Manufacturer not specified in device file") - - channels = desc.append_child("channels") - try: - ch_names = self._stream_params["ch_names"][name] - # use user-specified ch_names if available and right no. channels - if name in self._user_ch_names: - user_ch_names = self._user_ch_names[name] - if len(user_ch_names) == len(ch_names): - if len(user_ch_names) == len(set(user_ch_names)): - ch_names = user_ch_names - else: - print("Non-unique names in user-defined {} ch_names; " - .format(name), "using default ch_names.") - else: - print("Wrong # of channels in user-defined {} ch_names; " - .format(name), "using default ch_names.") - - for c, ch_name in enumerate(ch_names): - unit = self._stream_params["units"][name][c] - type_ = self._stream_params["type"][name] - channels.append_child("channel") \ - .append_child_value("label", ch_name) \ - .append_child_value("unit", unit) \ - .append_child_value("type", type_) - except KeyError: - raise ValueError("Channel names, units, or types not specified") - - @property - def subscriptions(self): - """The names of the subscribed streams.""" - return self._subscriptions - - -class Streamer(BaseStreamer): - """Streams data to an LSL outlet from a BLE device. - - TODO: - * Try built-in LSL features for intra-chunk timestamps (StreamOutlet) - * initialize_timestamping: should indices be reset to 0 mid-streaming? 
- """ - - def __init__(self, device, address=None, backend='bgapi', interface=None, - autostart=True, scan_timeout=10.5, internal_timestamps=False, - **kwargs): - """Construct a `Streamer` instance for a given device. - - Args: - device (dict): A device module in `ble2lsl.devices`. - For example, `ble2lsl.devices.muse2016`. - Provides info on BLE characteristics and device metadata. - address (str): Device MAC address for establishing connection. - By default, this is acquired automatically using device name. - backend (str): Which `pygatt` backend to use. - Allowed values are `'bgapi'` or `'gatt'`. The `'gatt'` backend - only works on Linux under the BlueZ protocol stack. - interface (str): The identifier for the BLE adapter interface. - When `backend='gatt'`, defaults to `'hci0'`. - autostart (bool): Whether to start streaming on instantiation. - scan_timeout (float): Seconds before timeout of BLE adapter scan. - internal_timestamps (bool): Use internal timestamping. - If `False` (default), uses initial timestamp, nominal sample - rate, and device-provided sample ID to determine timestamp. - If `True` (or when sample IDs not provided), generates - timestamps at the time of chunk retrieval, only using - nominal sample rate as need to determine timestamps within - chunks. - """ - BaseStreamer.__init__(self, device=device, **kwargs) - self._transmit_queue = Queue() - self._ble_params = self._device.PARAMS["ble"] - self._address = address - - # use internal timestamps if requested, or if stream is variable rate - # (LSL uses nominal_srate=0.0 for variable rates) - nominal_srates = self._stream_params["nominal_srate"] - self._internal_timestamps = {name: (internal_timestamps - if nominal_srates[name] else True) - for name in device.STREAMS} - self._start_time = stream_idxs_zeros(self._subscriptions) - self._first_chunk_idxs = stream_idxs_zeros(self._subscriptions) - - # initialize gatt adapter - if backend == 'bgapi': - self._adapter = pygatt.BGAPIBackend(serial_port=interface) - elif backend in ['gatt', 'bluez']: - # only works on Linux - interface = self.interface or 'hci0' - self._adapter = pygatt.GATTToolBackend(interface) - else: - raise(ValueError("Invalid backend specified; use bgapi or gatt.")) - self._backend = backend - self._scan_timeout = scan_timeout - - self._transmit_thread = threading.Thread(target=self._transmit_chunks) - - if autostart: - self.connect() - self.start() - - def _init_timestamp(self, name, chunk_idx): - """Set the starting timestamp and chunk index for a subscription.""" - self._first_chunk_idxs[name] = chunk_idx - self._start_time[name] = self._time_func() - - def start(self): - """Start streaming by writing to the send characteristic.""" - self._transmit_thread.start() - self._ble_device.char_write(self._ble_params['send'], - value=self._ble_params['stream_on'], - wait_for_response=False) - - def stop(self): - """Stop streaming by writing to the send characteristic.""" - self._ble_device.char_write(self._ble_params["send"], - value=self._ble_params["stream_off"], - wait_for_response=False) - - def send_command(self, value): - """Write some value to the send characteristic.""" - self._ble_device.char_write(self._ble_params["send"], - value=value, - wait_for_response=False) - - def disconnect(self): - """Disconnect from the BLE device and stop the adapter. - - Note: - After disconnection, `start` will not resume streaming. 
- - TODO: - * enable device reconnect with `connect` - """ - self.stop() # stream_off command - self._ble_device.disconnect() # BLE disconnect - self._adapter.stop() - - def connect(self): - """Establish connection to BLE device (prior to `start`). - - Starts the `pygatt` adapter, resolves the device address if necessary, - connects to the device, and subscribes to the channels specified in the - device parameters. - """ - adapter_started = False - while not adapter_started: - try: - self._adapter.start() - adapter_started = True - except (ExpectedResponseTimeout, StructError): - continue - - if self._address is None: - # get the device address if none was provided - self._address = self._resolve_address(self._device.NAME) - try: - self._ble_device = self._adapter.connect(self._address, - address_type=self._ble_params['address_type'], - interval_min=self._ble_params['interval_min'], - interval_max=self._ble_params['interval_max']) - - except pygatt.exceptions.NotConnectedError: - e_msg = "Unable to connect to device at address {}" \ - .format(self._address) - raise(IOError(e_msg)) - - # initialize LSL outlets and packet handler - self._init_lsl_outlets() - self._packet_handler = self._device.PacketHandler(self) - - # subscribe to receive characteristic notifications - process_packet = self._packet_handler.process_packet - for name in self._subscriptions: - try: - uuids = [self._ble_params[name] + ''] - except TypeError: - uuids = self._ble_params[name] - for uuid in uuids: - if uuid: - self._ble_device.subscribe(uuid, callback=process_packet) - # subscribe to recieve simblee command from ganglion doc - - def _resolve_address(self, name): - list_devices = self._adapter.scan(timeout=self._scan_timeout) - for device in list_devices: - if name in device['name']: - return device['address'] - raise(ValueError("No devices found with name `{}`".format(name))) - - def _transmit_chunks(self): - """TODO: missing chunk vs. missing sample""" - # nominal duration of chunks for progressing non-internal timestamps - chunk_period = {name: (self._stream_params["chunk_size"][name] - / self._stream_params["nominal_srate"][name]) - for name in self._subscriptions - if not self._internal_timestamps[name]} - first_idx = self._first_chunk_idxs - while True: - name, chunk_idx, chunk = self._transmit_queue.get() - self._chunks[name][:, :] = chunk - - # update chunk index records and report missing chunks - # passing chunk_idx=-1 to the queue averts this (ex. status stream) - if not chunk_idx == -1: - if self._chunk_idxs[name] == 0: - self._init_timestamp(name, chunk_idx) - self._chunk_idxs[name] = chunk_idx - 1 - if not chunk_idx == self._chunk_idxs[name] + 1: - print("Missing {} chunk {}: {}" - .format(name, chunk_idx, self._chunk_idxs[name])) - self._chunk_idxs[name] = chunk_idx - else: - # track number of received chunks for non-indexed streams - self._chunk_idxs[name] += 1 - - # generate timestamp; either internally or - if self._internal_timestamps[name]: - timestamp = self._time_func() - else: - timestamp = chunk_period[name] * (chunk_idx - first_idx[name]) - timestamp += self._start_time[name] - - self._push_func[name](name, timestamp) - - @property - def backend(self): - """The name of the `pygatt` backend used by the instance.""" - return self._backend - - @property - def address(self): - """The MAC address of the device.""" - return self._address - - -class Dummy(BaseStreamer): - """Mimicks a device and pushes local data into an LSL outlet. 
- - TODO: - * verify timestamps/delays (seems too fast in plot.Lines) - """ - - def __init__(self, device, chunk_iterator=None, subscriptions=None, - autostart=True, **kwargs): - """Construct a `Dummy` instance. - - Args: - device: BLE device to impersonate (i.e. from `ble2lsl.devices`). - chunk_iterator (generator): Class that iterates through chunks. - autostart (bool): Whether to start streaming on instantiation. - """ - nominal_srate = device.PARAMS["streams"]["nominal_srate"] - if subscriptions is None: - subscriptions = get_default_subscriptions(device) - subscriptions = {name for name in subscriptions - if nominal_srate[name] > 0} - - BaseStreamer.__init__(self, device=device, subscriptions=subscriptions, - **kwargs) - - self._address = "DUMMY" - self._init_lsl_outlets() - - chunk_shapes = {name: self._chunks[name].shape - for name in self._subscriptions} - self._delays = {name: 1 / (nominal_srate[name] / chunk_shapes[name][1]) - for name in self._subscriptions} - - # generate or load fake data - if chunk_iterator is None: - chunk_iterator = NoisySinusoids - self._chunk_iter = {name: chunk_iterator(chunk_shapes[name], - nominal_srate[name]) - for name in self._subscriptions} - - # threads to mimic incoming BLE data - self._threads = {name: threading.Thread(target=self._stream, - kwargs=dict(name=name)) - for name in self._subscriptions} - - if autostart: - self.start() - - def start(self): - """Start pushing data into the LSL outlet.""" - self._proceed = True - for name in self._subscriptions: - self._threads[name].start() - - def stop(self): - """Stop pushing data. Ends execution of chunk streaming threads. - - Restart requires a new `Dummy` instance. - """ - self._proceed = False - - def _stream(self, name): - """Run in thread to mimic periodic hardware input.""" - for chunk in self._chunk_iter[name]: - if not self._proceed: - break - self._chunks[name] = chunk - timestamp = time.time() - self._push_func[name](name, timestamp) - time.sleep(self._delays[name]) - - def make_chunk(self, chunk_ind): - """Prepare a chunk from the totality of local data. - - TODO: - * replaced when using an iterator - """ - self._chunks - # TODO: more realistic timestamps - timestamp = self._time_func() - self._timestamps = np.array([timestamp]*self._chunk_size) - - -def stream_idxs_zeros(subscriptions): - """Initialize an integer index for each subscription.""" - idxs = {name: 0 for name in subscriptions} - return idxs - - -def empty_chunks(stream_params, subscriptions): - """Initialize an empty chunk array for each subscription.""" - chunks = {name: np.zeros((stream_params["chunk_size"][name], - stream_params["channel_count"][name]), - dtype=stream_params["numpy_dtype"][name]) - for name in subscriptions} - return chunks - - -def get_default_subscriptions(device): - # look for default list; if unavailable, subscribe to all - try: - subscriptions = device.DEFAULT_SUBSCRIPTIONS - except AttributeError: - subscriptions = device.STREAMS - return subscriptions - - -class ChunkIterator: - """Generator object (i.e. iterator) that yields chunks. - - Placeholder until I figure out how this might work as a base class. 
- """ - - def __init__(self, chunk_shape, srate): - self._chunk_shape = chunk_shape - self._srate = srate - - -class NoisySinusoids(ChunkIterator): - """Iterator class to provide noisy sinusoidal chunks of data.""" - - def __init__(self, chunk_shape, srate, freqs=[5, 10, 12, 20], noise_std=1): - super().__init__(chunk_shape=chunk_shape, srate=srate) - self._ang_freqs = 2 * np.pi * np.array(freqs) - self._speriod = 1 / self._srate - self._chunk_t_incr = (1 + chunk_shape[0]) / self._srate - self._freq_amps = np.random.randint(1, 5, len(freqs)) - self._noise_std = noise_std - - def __iter__(self): - self._t = (np.arange(self._chunk_shape[0]).reshape((-1, 1)) - * self._speriod) - return self - - def __next__(self): - # start with noise - chunk = np.random.normal(0, self._noise_std, self._chunk_shape) - - # sum frequencies with random amplitudes - for i, freq in enumerate(self._ang_freqs): - chunk += self._freq_amps[i] * np.sin(freq * self._t) - - self._t += self._chunk_t_incr - - return chunk diff --git a/ble2lsl/devices/__init__.py b/ble2lsl/devices/__init__.py deleted file mode 100644 index e925783..0000000 --- a/ble2lsl/devices/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""BLE/LSL interfacing parameters for specific devices. - -TODO: - * Simple class (or specification/template) for device parameters -""" diff --git a/ble2lsl/devices/device.py b/ble2lsl/devices/device.py deleted file mode 100644 index 87ec2e8..0000000 --- a/ble2lsl/devices/device.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Specification/abstract parent for BLE2LSL device files. - -To support a new BLE device, create a new module in `ble2lsl.devices` and -include the following module-level attributes: - - NAME (str): The device's name. - Used for automatically finding the device address, so should be some - substring of the name found by `pygatt`'s '`adapter.scan`. - MANUFACTURER (str): The name of the device's manufacturer. - STREAMS (List[str]): Names of data sources provided by the device. - PARAMS (dict): Device-specific parameters, containing two dictionaries: - streams (dict): Contains general per-stream parameters. Each member - should be a dictionary with the names of the potential stream - subscriptions as keys for the stream's respective value(s) of the - parameter. (See `muse2016` module for an example of how to do this - without being very messy.) - - type (dict[str]): The type of data to be streamed. Should be an - `XDF` format string when possible. - channel_count (dict[int]): The number of channels in the stream. - nominal_srate (dict[float]): The stream's design sample rate (Hz). - Used to generate dejittered timestamps for incoming samples. - channel_format (dict[str]): The LSL datatype of the stream's data. - LSL streams are of a single datatype, so one string should be - given for each stream. - numpy_dtype (dict[str or numpy.dtype]): The Numpy datatype for - stream data. - This will not always be identical to `channel_format`; for - example, `'string'` is the string type in LSL but not in Numpy. - units (dict[Iterable[str]]): Units for each channel in the stream. - ch_names (dict[Iterable[str]]): Name of each channel in the stream. - chunk_size (dict[int]): No. of samples pushed at once through LSL. - - ble (dict): Contains BLE-specific device parameters. Must contain - keys for each of the streams named in `STREAMS`, with values of - one or more characteristic UUIDs that `ble2lsl` must subscribe - to when providing that stream. 
Some of these may be redundant or - empty strings, as long as the device's `PacketHandler` separates - incoming packets' data into respective streams (for example, - see `ganglion`). - - address_type (BLEAddressType): One of `BLEAddressType.public` or - `BLEAddressType.random`, depending on the device. - interval_min (int): Minimum BLE connection interval. - interval_max (int): Maximum BLE connection interval. - Connection intervals are multiples of 1.25 ms. A good choice of - `interval_min` and `interval_max` may be necessary to prevent - dropped packets. - send (str): UUID for the send/control characteristic. - Control commands (e.g. to start streaming) are written to this - characteristic. - stream_on: Command to write to start streaming. - stream_off: Command to write to end streaming. - - -As devices typically do not share a common format for the packets sent over -BLE, include a subclass of `PacketHandler` in the device file. This subclass -should provide a `process_packet` method, to which BLE2LSL will pass incoming -packets and the handles of the BLE characteristics from which they were -received. This method should perform any necessary processing on the packets, -delegating to other methods in the device file if necessary. After filling the -`_chunks` and `_chunk_idxs` attributes for a given stream, the chunk may be -enqueued for processing by `ble2lsl` by passing the stream name to -`_enqueue_chunk()`. - -Summary of necessary inclusions to support a data source provided by a device: - * A name for the stream in `STREAMS`. - * Corresponding entries in each member of `PARAMS["streams"]`, and an entry - in `PARAMS["ble"]` containing one or more UUIDs for characteristics - that constitute a BLE subscription to the data source. - * A means for `process_packet` to map an incoming packet to the appropriate - stream name, typically using its handle; by passing this name with the - enqueued chunk, this ensures it is pushed to the appropriate LSL stream - by `ble2lsl.Streamer`. - * Methods to process the contents of the packets and render them into - appropriately-sized chunks as specified by the `channel_count` and - `chunk_size` members of `PARAMS["streams"]`, and to return the chunks - at the appropriate time (e.g. if multiple packets must be received to - fill a single chunk.) - -See `ble2lsl.devices.muse2016` for an example device implementation. - -When a user instantiates `ble2lsl.Streamer`, they may provide a list -`DEFAULT_SUBSCRIPTIONS` of stream names to which to subscribe, which should be -some subset of the `STREAMS` attribute of the respective device file. - -.. _XDF: - https://github.com/sccn/xdf/wiki/Specifications -""" - -from ble2lsl import empty_chunks, stream_idxs_zeros - -import numpy as np - - -class BasePacketHandler: - """Abstract parent for device-specific packet manager classes.""" - - def __init__(self, stream_params, streamer, **kwargs): - """Construct a `PacketHandler` instance. - - Args: - stream_params (dict): Stream parameters. - Pass `PARAMS["streams"]` from the device file. - streamer (ble2lsl.Streamer): The master `Streamer` instance. 
- """ - self._streamer = streamer - self._transmit_queue = streamer._transmit_queue - - subscriptions = self._streamer.subscriptions - self._chunks = empty_chunks(stream_params, subscriptions) - self._chunk_idxs = stream_idxs_zeros(subscriptions) - - def process_packet(self, handle, packet): - """BLE2LSL passes incoming BLE packets to this method for parsing.""" - raise NotImplementedError() - - def _enqueue_chunk(self, name): - """Ensure copies are returned.""" - self._transmit_queue.put((name, - self._chunk_idxs[name], - np.copy(self._chunks[name]) - )) diff --git a/ble2lsl/devices/ganglion/LICENSE b/ble2lsl/devices/ganglion/LICENSE deleted file mode 100644 index aff6d3c..0000000 --- a/ble2lsl/devices/ganglion/LICENSE +++ /dev/null @@ -1,24 +0,0 @@ -The contents of ganglion.py are derived from software distributed under the -following license: - -The MIT License (MIT) - -Copyright (c) 2015 OpenBCI - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
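To make the device-file specification documented in `ble2lsl/devices/device.py` above more concrete, here is a hypothetical minimal device module. Every name, UUID, and channel detail below is a placeholder rather than a real device; the deleted `ganglion` and `muse2016` modules that follow are the actual reference implementations.

```python
"""Hypothetical device file sketching the ble2lsl.devices interface."""

from ble2lsl.devices.device import BasePacketHandler

from pygatt import BLEAddressType

NAME = "ExampleDevice"         # substring of the BLE-advertised device name
MANUFACTURER = "Example Corp"
STREAMS = ["EEG"]              # data sources available for subscription

PARAMS = dict(
    streams=dict(
        type={"EEG": "EEG"},
        channel_count={"EEG": 4},
        nominal_srate={"EEG": 250},
        channel_format={"EEG": "float32"},
        numpy_dtype={"EEG": "float32"},
        units={"EEG": ("uV",) * 4},
        ch_names={"EEG": ("Ch1", "Ch2", "Ch3", "Ch4")},
        chunk_size={"EEG": 1},
    ),
    ble=dict(
        address_type=BLEAddressType.public,
        interval_min=60,
        interval_max=76,
        EEG=["00000000-0000-0000-0000-000000000000"],  # placeholder receive UUID
        send="00000000-0000-0000-0000-000000000001",   # placeholder control UUID
        stream_on=b"b",   # placeholder start-streaming command
        stream_off=b"s",  # placeholder stop-streaming command
    ),
)


class PacketHandler(BasePacketHandler):
    """Parse this device's BLE packets into chunks for `ble2lsl.Streamer`."""

    def __init__(self, streamer, **kwargs):
        super().__init__(PARAMS["streams"], streamer, **kwargs)

    def process_packet(self, handle, packet):
        # Device-specific parsing goes here: fill a row of self._chunks["EEG"],
        # set self._chunk_idxs["EEG"] from the packet's sample counter, then
        # hand the chunk off to the Streamer's transmit queue.
        self._enqueue_chunk("EEG")
```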
diff --git a/ble2lsl/devices/ganglion/__init__.py b/ble2lsl/devices/ganglion/__init__.py deleted file mode 100644 index a699954..0000000 --- a/ble2lsl/devices/ganglion/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ - -from ble2lsl.devices.ganglion.ganglion import * diff --git a/ble2lsl/devices/ganglion/ganglion.py b/ble2lsl/devices/ganglion/ganglion.py deleted file mode 100644 index 65c55fd..0000000 --- a/ble2lsl/devices/ganglion/ganglion.py +++ /dev/null @@ -1,372 +0,0 @@ -"""Interfacing parameters for the OpenBCI Ganglion Board.""" - -from ble2lsl.devices.device import BasePacketHandler -from ble2lsl.utils import bad_data_size, dict_partial_from_keys - -import struct -from warnings import warn - -import numpy as np -from pygatt import BLEAddressType - -NAME = "Ganglion" - -MANUFACTURER = "OpenBCI" - -STREAMS = ["EEG", "accelerometer", "messages"] -"""Data provided by the OpenBCI Ganglion, and available for subscription.""" - -DEFAULT_SUBSCRIPTIONS = ["EEG", "messages"] -"""Streams to which to subscribe by default.""" - -# for constructing dicts with STREAMS as keys -streams_dict = dict_partial_from_keys(STREAMS) - -PARAMS = dict( - streams=dict( - type=streams_dict(STREAMS), # same as stream names - channel_count=streams_dict([4, 3, 1]), - nominal_srate=streams_dict([200, 10, 0.0]), - channel_format=streams_dict(['float32', 'float32', 'string']), - numpy_dtype=streams_dict(['float32', 'float32', 'object']), - units=streams_dict([('uV',) * 4, ('g\'s',) * 3, ('',)]), - ch_names=streams_dict([('A', 'B', 'C', 'D'), ('x', 'y', 'z'), - ('message',)]), - chunk_size=streams_dict([1, 1, 1]), - ), - ble=dict( - address_type=BLEAddressType.random, - # service='fe84', - interval_min=6, # OpenBCI suggest 9 - interval_max=11, # suggest 10 - - # receive characteristic UUIDs - EEG=["2d30c082f39f4ce6923f3484ea480596"], - accelerometer='', # placeholder; already subscribed through eeg - messages='', # placeholder; subscription not required - - # send characteristic UUID and commands - send="2d30c083f39f4ce6923f3484ea480596", - stream_on=b'b', - stream_off=b's', - accelerometer_on=b'n', - accelerometer_off=b'N', - # impedance_on=b'z', - # impedance_off=b'Z', - - # other characteristics - # disconnect="2d30c084f39f4ce6923f3484ea480596", - ), -) -"""OpenBCI Ganglion LSL- and BLE-related parameters.""" - -INT_SIGN_BYTE = (b'\x00', b'\xff') -SCALE_FACTOR = streams_dict([1.2 / (8388608.0 * 1.5 * 51.0), - 0.016, - 1 # not used (messages) - ]) -"""Scale factors for conversion of EEG and accelerometer data to mV.""" - -ID_TURNOVER = streams_dict([201, 10]) -"""The number of samples processed before the packet ID cycles back to zero.""" - - -class PacketHandler(BasePacketHandler): - """Process packets from the OpenBCI Ganglion into chunks.""" - - def __init__(self, streamer, **kwargs): - super().__init__(PARAMS["streams"], streamer, **kwargs) - - self._sample_ids = streams_dict([-1] * len(STREAMS)) - - if "EEG" in self._streamer.subscriptions: - self._last_eeg_data = np.zeros(self._chunks["EEG"].shape[1]) - - if "messages" in self._streamer.subscriptions: - self._chunks["messages"][0] = "" - self._chunk_idxs["messages"] = -1 - - if "accelerometer" in self._streamer.subscriptions: - # queue accelerometer_on command - self._streamer.send_command(PARAMS["ble"]["accelerometer_on"]) - - # byte ID ranges for parsing function selection - self._byte_id_ranges = {(101, 200): self._parse_compressed_19bit, - (0, 0): self._parse_uncompressed, - (206, 207): self._parse_message, - (1, 100): self._parse_compressed_18bit, - (201, 205): 
self._parse_impedance, - (208, -1): self._unknown_packet_warning} - - def process_packet(self, handle, packet): - """Process incoming data packet. - - Calls the corresponding parsing function depending on packet format. - """ - start_byte = packet[0] - for r in self._byte_id_ranges: - if start_byte >= r[0] and start_byte <= r[1]: - self._byte_id_ranges[r](start_byte, packet[1:]) - break - - def _update_counts_and_enqueue(self, name, sample_id): - """Update last packet ID and dropped packets""" - if self._sample_ids[name] == -1: - self._sample_ids[name] = sample_id - self._chunk_idxs[name] = 1 - return - # sample IDs loops every 101 packets - self._chunk_idxs[name] += sample_id - self._sample_ids[name] - if sample_id < self._sample_ids[name]: - self._chunk_idxs[name] += ID_TURNOVER[name] - self._sample_ids[name] = sample_id - - if name == "EEG": - self._chunks[name][0, :] = np.copy(self._last_eeg_data) - self._chunks[name] *= SCALE_FACTOR[name] - self._enqueue_chunk(name) - - def _unknown_packet_warning(self, start_byte, packet): - """Print if incoming byte ID is unknown.""" - warn("Unknown Ganglion packet byte ID: {}".format(start_byte)) - - def _parse_message(self, start_byte, packet): - """Parse a partial ASCII message.""" - if "messages" in self._streamer.subscriptions: - self._chunks["messages"] += str(packet) - if start_byte == 207: - self._enqueue_chunk("messages") - self._chunks["messages"][0] = "" - - def _parse_uncompressed(self, packet_id, packet): - """Parse a raw uncompressed packet.""" - if bad_data_size(packet, 19, "uncompressed data"): - return - # 4 channels of 24bits - self._last_eeg_data[:] = [int_from_24bits(packet[i:i + 3]) - for i in range(0, 12, 3)] - # = np.array([chan_data], dtype=np.float32).T - self._update_counts_and_enqueue("EEG", packet_id) - - def _update_data_with_deltas(self, packet_id, deltas): - for delta_id in [0, 1]: - # convert from packet to sample ID - sample_id = (packet_id - 1) * 2 + delta_id + 1 - # 19bit packets hold deltas between two samples - self._last_eeg_data += np.array(deltas[delta_id]) - self._update_counts_and_enqueue("EEG", sample_id) - - def _parse_compressed_19bit(self, packet_id, packet): - """Parse a 19-bit compressed packet without accelerometer data.""" - if bad_data_size(packet, 19, "19-bit compressed data"): - return - - packet_id -= 100 - # should get 2 by 4 arrays of uncompressed data - deltas = decompress_deltas_19bit(packet) - self._update_data_with_deltas(packet_id, deltas) - - def _parse_compressed_18bit(self, packet_id, packet): - """ Dealing with "18-bit compression without Accelerometer" """ - if bad_data_size(packet, 19, "18-bit compressed data"): - return - - # set appropriate accelerometer byte - id_ones = packet_id % 10 - 1 - if id_ones in [0, 1, 2]: - value = int8_from_byte(packet[18]) - self._chunks["accelerometer"][0, id_ones] = value - if id_ones == 2: - self._update_counts_and_enqueue("accelerometer", - packet_id // 10) - - # deltas: should get 2 by 4 arrays of uncompressed data - deltas = decompress_deltas_18bit(packet[:-1]) - self._update_data_with_deltas(packet_id, deltas) - - def _parse_impedance(self, packet_id, packet): - """Parse impedance data. - - After turning on impedance checking, takes a few seconds to complete. - """ - raise NotImplementedError # until this is sorted out... 
- - if packet[-2:] != 'Z\n': - print("Wrong format for impedance: not ASCII ending with 'Z\\n'") - - # convert from ASCII to actual value - imp_value = int(packet[:-2]) - # from 201 to 205 codes to the right array size - self.last_impedance[packet_id - 201] = imp_value - self.push_sample(packet_id - 200, self._data, - self.last_accelerometer, self.last_impedance) - - -def int_from_24bits(unpacked): - """Convert 24-bit data coded on 3 bytes to a proper integer.""" - if bad_data_size(unpacked, 3, "3-byte buffer"): - raise ValueError("Bad input size for byte conversion.") - - # FIXME: quick'n dirty, unpack wants strings later on - int_bytes = INT_SIGN_BYTE[unpacked[0] > 127] + struct.pack('3B', *unpacked) - - # unpack little endian(>) signed integer(i) (-> platform independent) - int_unpacked = struct.unpack('>i', int_bytes)[0] - - return int_unpacked - - -def int32_from_19bit(three_byte_buffer): - """Convert 19-bit data coded on 3 bytes to a proper integer.""" - if bad_data_size(three_byte_buffer, 3, "3-byte buffer"): - raise ValueError("Bad input size for byte conversion.") - - # if LSB is 1, negative number - if three_byte_buffer[2] & 0x01 > 0: - prefix = 0b1111111111111 - int32 = ((prefix << 19) | (three_byte_buffer[0] << 16) - | (three_byte_buffer[1] << 8) | three_byte_buffer[2]) \ - | ~0xFFFFFFFF - else: - prefix = 0 - int32 = (prefix << 19) | (three_byte_buffer[0] << 16) \ - | (three_byte_buffer[1] << 8) | three_byte_buffer[2] - - return int32 - - -def int32_from_18bit(three_byte_buffer): - """Convert 18-bit data coded on 3 bytes to a proper integer.""" - if bad_data_size(three_byte_buffer, 3, "3-byte buffer"): - raise ValueError("Bad input size for byte conversion.") - - # if LSB is 1, negative number, some hasty unsigned to signed conversion to do - if three_byte_buffer[2] & 0x01 > 0: - prefix = 0b11111111111111 - int32 = ((prefix << 18) | (three_byte_buffer[0] << 16) - | (three_byte_buffer[1] << 8) | three_byte_buffer[2]) \ - | ~0xFFFFFFFF - else: - prefix = 0 - int32 = (prefix << 18) | (three_byte_buffer[0] << 16) \ - | (three_byte_buffer[1] << 8) | three_byte_buffer[2] - - return int32 - - -def int8_from_byte(byte): - """Convert one byte to signed integer.""" - if byte > 127: - return (256 - byte) * (-1) - else: - return byte - - -def decompress_deltas_19bit(buffer): - """Parse packet deltas from 19-bit compression format.""" - if bad_data_size(buffer, 19, "19-byte compressed packet"): - raise ValueError("Bad input size for byte conversion.") - - deltas = np.zeros((2, 4)) - - # Sample 1 - Channel 1 - minibuf = [(buffer[0] >> 5), - ((buffer[0] & 0x1F) << 3 & 0xFF) | (buffer[1] >> 5), - ((buffer[1] & 0x1F) << 3 & 0xFF) | (buffer[2] >> 5)] - deltas[0][0] = int32_from_19bit(minibuf) - - # Sample 1 - Channel 2 - minibuf = [(buffer[2] & 0x1F) >> 2, - (buffer[2] << 6 & 0xFF) | (buffer[3] >> 2), - (buffer[3] << 6 & 0xFF) | (buffer[4] >> 2)] - deltas[0][1] = int32_from_19bit(minibuf) - - # Sample 1 - Channel 3 - minibuf = [((buffer[4] & 0x03) << 1 & 0xFF) | (buffer[5] >> 7), - ((buffer[5] & 0x7F) << 1 & 0xFF) | (buffer[6] >> 7), - ((buffer[6] & 0x7F) << 1 & 0xFF) | (buffer[7] >> 7)] - deltas[0][2] = int32_from_19bit(minibuf) - - # Sample 1 - Channel 4 - minibuf = [((buffer[7] & 0x7F) >> 4), - ((buffer[7] & 0x0F) << 4 & 0xFF) | (buffer[8] >> 4), - ((buffer[8] & 0x0F) << 4 & 0xFF) | (buffer[9] >> 4)] - deltas[0][3] = int32_from_19bit(minibuf) - - # Sample 2 - Channel 1 - minibuf = [((buffer[9] & 0x0F) >> 1), - (buffer[9] << 7 & 0xFF) | (buffer[10] >> 1), - (buffer[10] << 7 & 0xFF) | 
(buffer[11] >> 1)] - deltas[1][0] = int32_from_19bit(minibuf) - - # Sample 2 - Channel 2 - minibuf = [((buffer[11] & 0x01) << 2 & 0xFF) | (buffer[12] >> 6), - (buffer[12] << 2 & 0xFF) | (buffer[13] >> 6), - (buffer[13] << 2 & 0xFF) | (buffer[14] >> 6)] - deltas[1][1] = int32_from_19bit(minibuf) - - # Sample 2 - Channel 3 - minibuf = [((buffer[14] & 0x38) >> 3), - ((buffer[14] & 0x07) << 5 & 0xFF) | ((buffer[15] & 0xF8) >> 3), - ((buffer[15] & 0x07) << 5 & 0xFF) | ((buffer[16] & 0xF8) >> 3)] - deltas[1][2] = int32_from_19bit(minibuf) - - # Sample 2 - Channel 4 - minibuf = [(buffer[16] & 0x07), buffer[17], buffer[18]] - deltas[1][3] = int32_from_19bit(minibuf) - - return deltas - - -def decompress_deltas_18bit(buffer): - """Parse packet deltas from 18-byte compression format.""" - if bad_data_size(buffer, 18, "18-byte compressed packet"): - raise ValueError("Bad input size for byte conversion.") - - deltas = np.zeros((2, 4)) - - # Sample 1 - Channel 1 - minibuf = [(buffer[0] >> 6), - ((buffer[0] & 0x3F) << 2 & 0xFF) | (buffer[1] >> 6), - ((buffer[1] & 0x3F) << 2 & 0xFF) | (buffer[2] >> 6)] - deltas[0][0] = int32_from_18bit(minibuf) - - # Sample 1 - Channel 2 - minibuf = [(buffer[2] & 0x3F) >> 4, - (buffer[2] << 4 & 0xFF) | (buffer[3] >> 4), - (buffer[3] << 4 & 0xFF) | (buffer[4] >> 4)] - deltas[0][1] = int32_from_18bit(minibuf) - - # Sample 1 - Channel 3 - minibuf = [(buffer[4] & 0x0F) >> 2, - (buffer[4] << 6 & 0xFF) | (buffer[5] >> 2), - (buffer[5] << 6 & 0xFF) | (buffer[6] >> 2)] - deltas[0][2] = int32_from_18bit(minibuf) - - # Sample 1 - Channel 4 - minibuf = [(buffer[6] & 0x03), buffer[7], buffer[8]] - deltas[0][3] = int32_from_18bit(minibuf) - - # Sample 2 - Channel 1 - minibuf = [(buffer[9] >> 6), - ((buffer[9] & 0x3F) << 2 & 0xFF) | (buffer[10] >> 6), - ((buffer[10] & 0x3F) << 2 & 0xFF) | (buffer[11] >> 6)] - deltas[1][0] = int32_from_18bit(minibuf) - - # Sample 2 - Channel 2 - minibuf = [(buffer[11] & 0x3F) >> 4, - (buffer[11] << 4 & 0xFF) | (buffer[12] >> 4), - (buffer[12] << 4 & 0xFF) | (buffer[13] >> 4)] - deltas[1][1] = int32_from_18bit(minibuf) - - # Sample 2 - Channel 3 - minibuf = [(buffer[13] & 0x0F) >> 2, - (buffer[13] << 6 & 0xFF) | (buffer[14] >> 2), - (buffer[14] << 6 & 0xFF) | (buffer[15] >> 2)] - deltas[1][2] = int32_from_18bit(minibuf) - - # Sample 2 - Channel 4 - minibuf = [(buffer[15] & 0x03), buffer[16], buffer[17]] - deltas[1][3] = int32_from_18bit(minibuf) - - return deltas diff --git a/ble2lsl/devices/muse2016.py b/ble2lsl/devices/muse2016.py deleted file mode 100644 index 390fce2..0000000 --- a/ble2lsl/devices/muse2016.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Interfacing parameters for the Muse headband (2016 version). - -More information on the data provided by the Muse 2016 headband can be found -at `Available Data - Muse Direct`_ - -TODO: - * Figure out maximum string size for status messages, or split into fields - (can't send dict over LSL) - * return standard acceleration units and not g's... - * verify telemetry and IMU conversions and units - * DRL/REF characteristic - * don't use lambdas for CONVERT_FUNCS? - * save Muse address to minimize connect time? - * packet ID rollover (uint16) -- generalize in device file? - -.. 
_Available Data - Muse Direct: - http://developer.choosemuse.com/tools/windows-tools/available-data-muse-direct -""" - -from ble2lsl.devices.device import BasePacketHandler -from ble2lsl.utils import dict_partial_from_keys - -import bitstring -import numpy as np -from pygatt import BLEAddressType - -NAME = 'Muse' -MANUFACTURER = 'Interaxon' - -STREAMS = ['EEG', 'accelerometer', 'gyroscope', 'telemetry', 'status'] -"""Data sources provided by the Muse 2016 headset.""" - -DEFAULT_SUBSCRIPTIONS = STREAMS -"""Sources to which to subscribe by default.""" - -# for constructing dicts with STREAMS as keys -streams_dict = dict_partial_from_keys(STREAMS) - -PARAMS = dict( - streams=dict( - type=streams_dict(STREAMS), # identity mapping. best solution? - channel_count=streams_dict([5, 3, 3, 4, 1]), - nominal_srate=streams_dict([256, 52, 52, 0.1, 0.0]), - channel_format=streams_dict(['float32', 'float32', 'float32', - 'float32', 'string']), - numpy_dtype=streams_dict(['float32', 'float32', 'float32', 'float32', - 'object']), - units=streams_dict([('uV',) * 5, - ('g\'s',) * 3, - ('deg/s',) * 3, - ('%', 'mV', 'mV', 'C'), - ('',)]), - ch_names=streams_dict([('TP9', 'AF7', 'AF8', 'TP10', 'Right AUX'), - ('x', 'y', 'z'), - ('x', 'y', 'z'), - ('battery', 'fuel_gauge', 'adc_volt', - 'temperature'), - ('message',)]), - chunk_size=streams_dict([12, 3, 3, 1, 1]), - ), - - ble=dict( - address_type=BLEAddressType.public, - interval_min=60, # pygatt default, seems fine - interval_max=76, # pygatt default - - # receive characteristic UUIDs - EEG=['273e0003-4c4d-454d-96be-f03bac821358', - '273e0004-4c4d-454d-96be-f03bac821358', - '273e0005-4c4d-454d-96be-f03bac821358', - '273e0006-4c4d-454d-96be-f03bac821358', - '273e0007-4c4d-454d-96be-f03bac821358'], - # reference=['273e0008-4c4d-454d-96be-f03bac821358'], - accelerometer='273e000a-4c4d-454d-96be-f03bac821358', - gyroscope='273e0009-4c4d-454d-96be-f03bac821358', - telemetry='273e000b-4c4d-454d-96be-f03bac821358', - status='273e0001-4c4d-454d-96be-f03bac821358', # same as send - - # send characteristic UUID and commands - send='273e0001-4c4d-454d-96be-f03bac821358', - stream_on=(0x02, 0x64, 0x0a), # b'd' - stream_off=(0x02, 0x68, 0x0a), # ? - # keep_alive=(0x02, 0x6b, 0x0a), # (?) 
b'k' - # request_info=(0x03, 0x76, 0x31, 0x0a), - # request_status=(0x02, 0x73, 0x0a), - # reset=(0x03, 0x2a, 0x31, 0x0a) - ) -) -"""Muse 2016 LSL- and BLE-related parameters.""" - -HANDLE_NAMES = {14: "status", 26: "telemetry", 23: "accelerometer", - 20: "gyroscope", 32: "EEG", 35: "EEG", 38: "EEG", 41: "EEG", - 44: "EEG"} -"""Stream name associated with each packet handle.""" - -PACKET_FORMATS = streams_dict(['uint:16' + ',uint:12' * 12, - 'uint:16' + ',int:16' * 9, - 'uint:16' + ',int:16' * 9, - 'uint:16' + ',uint:16' * 4, - ','.join(['uint:8'] * 20)]) -"""Byte formats of the incoming packets.""" - -CONVERT_FUNCS = streams_dict([lambda data: 0.48828125 * (data - 2048), - lambda data: 0.0000610352 * data.reshape((3, 3)), - lambda data: 0.0074768 * data.reshape((3, 3)), - lambda data: np.array([data[0] / 512, - 2.2 * data[1], - data[2], data[3]]).reshape((1, 4)), - lambda data: None]) -"""Functions to render unpacked data into the appropriate shape and units.""" - -EEG_HANDLE_CH_IDXS = {32: 0, 35: 1, 38: 2, 41: 3, 44: 4} -EEG_HANDLE_RECEIVE_ORDER = [44, 41, 38, 32, 35] -"""Channel indices and receipt order of EEG packets.""" - - -class PacketHandler(BasePacketHandler): - """Process packets from the Muse 2016 headset into chunks.""" - - def __init__(self, streamer, **kwargs): - super().__init__(PARAMS["streams"], streamer, **kwargs) - - if "status" in self._streamer.subscriptions: - self._chunks["status"][0] = "" - self._chunk_idxs["status"] = -1 - - def process_packet(self, handle, packet): - """Unpack, convert, and return packet contents.""" - name = HANDLE_NAMES[handle] - unpacked = _unpack(packet, PACKET_FORMATS[name]) - - if name not in self._streamer.subscriptions: - return - - if name == "status": - self._process_status(unpacked) - else: - data = np.array(unpacked[1:], - dtype=PARAMS["streams"]["numpy_dtype"][name]) - - if name == "EEG": - idx = EEG_HANDLE_CH_IDXS[handle] - self._chunks[name][:, idx] = CONVERT_FUNCS[name](data) - if not handle == EEG_HANDLE_RECEIVE_ORDER[-1]: - return - else: - try: - self._chunks[name][:, :] = CONVERT_FUNCS[name](data) - except ValueError: - print(name) - - self._chunk_idxs[name] = unpacked[0] - self._enqueue_chunk(name) - - def _process_status(self, unpacked): - message_chars = [chr(i) for i in unpacked[1:]] - status_message_partial = "".join(message_chars)[:unpacked[0]] - self._chunks["status"] += status_message_partial.replace('\n', '') - if status_message_partial[-1] == '}': - # ast.literal_eval(self._message)) - # parse and enqueue dict - self._enqueue_chunk("status") - self._chunks["status"][0] = "" - - -def _unpack(packet, packet_format): - packet_bits = bitstring.Bits(bytes=packet) - unpacked = packet_bits.unpack(packet_format) - return unpacked diff --git a/ble2lsl/utils.py b/ble2lsl/utils.py deleted file mode 100644 index d7b2217..0000000 --- a/ble2lsl/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Utilities for use within BLE2LSL.""" - -from warnings import warn - -def invert_map(dict_): - """Invert the keys and values in a dict.""" - inverted = {v: k for k, v in dict_.items()} - return inverted - - -def bad_data_size(data, size, data_type="packet"): - """Return `True` if length of `data` is not `size`.""" - if len(data) != size: - warn('Wrong size for {}, {} instead of {} bytes' - .format(data_type, len(data), size)) - return True - return False - - -def dict_partial_from_keys(keys): - """Return a function that constructs a dict with predetermined keys.""" - def dict_partial(values): - return dict(zip(keys, values)) - return dict_partial 
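For orientation, the `dict_partial_from_keys` helper removed here (it now ships with the standalone ble2lsl package) is what lets the device files above express their per-stream `PARAMS` entries as plain lists. The values in this small illustration mirror the Ganglion parameters:

```python
from ble2lsl.utils import dict_partial_from_keys

# keys fixed once, in STREAMS order (cf. ganglion.py above)
streams_dict = dict_partial_from_keys(["EEG", "accelerometer", "messages"])

channel_count = streams_dict([4, 3, 1])
# -> {'EEG': 4, 'accelerometer': 3, 'messages': 1}
```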
diff --git a/example_plot.py b/scripts/example_plot.py similarity index 100% rename from example_plot.py rename to scripts/example_plot.py diff --git a/example_plot_dummy.py b/scripts/example_plot_dummy.py similarity index 100% rename from example_plot_dummy.py rename to scripts/example_plot_dummy.py diff --git a/scripts/example_plot_psd.py b/scripts/example_plot_psd.py new file mode 100644 index 0000000..6539bba --- /dev/null +++ b/scripts/example_plot_psd.py @@ -0,0 +1,16 @@ +"""Plot time series data streamed through a dummy LSL outlet. +""" + +import ble2lsl +from ble2lsl.devices import muse2016 +from wizardhat import acquire, plot, transform + +device = muse2016 +plot_stream = "EEG" + +if __name__ == '__main__': + dummy_outlet = ble2lsl.Dummy(device) + receiver = acquire.Receiver() + psd = transform.PSD(receiver.buffers[plot_stream]) + plot.Spectra(psd.buffer_out) + #plot.Lines(receiver.buffers[plot_stream]) diff --git a/scripts/example_psd.py b/scripts/example_psd.py new file mode 100644 index 0000000..104c665 --- /dev/null +++ b/scripts/example_psd.py @@ -0,0 +1,15 @@ +import ble2lsl +from ble2lsl.devices import muse2016 +from wizardhat import acquire, plot, transform + +import pylsl as lsl + +device = muse2016 +plot_stream = 'EEG' + +if __name__ == '__main__': + streamer = ble2lsl.Dummy(device) + receiver = acquire.Receiver() + psd_transformer = transform.PSD(receiver.buffers['EEG'], n_samples=256) + psd_averaged = transform.MovingAverage(psd_transformer.buffer_out, n_avg=5) + plotter = plot.Spectra(psd_averaged.buffer_out) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..e7b83ec --- /dev/null +++ b/setup.py @@ -0,0 +1,131 @@ +import io +import os +import sys +from shutil import rmtree + +from setuptools import find_packages, setup, Command + +# Package meta-data. +NAME = 'WizardHat' +DESCRIPTION = ('Real-time processing and plotting of data streamed over LSL, ' + + 'with a focus on student-led BCI projects.') +URL = 'https://github.com/merlin-neurotech/WizardHat' +EMAIL = '' +AUTHOR = 'Merlin Neurotech' +REQUIRES_PYTHON = '>=3.5.0' +VERSION = '0.2.0' + +# What packages are required for this module to be executed? +REQUIRED = [ + 'ble2lsl', 'numpy==1.14.0', 'scipy==1.0.0', 'pylsl==1.10.5', 'mne==0.15.2', + 'bokeh==0.13.0', +] + +# What packages are optional? +EXTRAS = { + # 'fancy feature': ['django'], +} + +# The rest you shouldn't have to touch too much :) +# ------------------------------------------------ +# Except, perhaps the License and Trove Classifiers! +# If you do change the License, remember to change the Trove Classifier for that! + +here = os.path.abspath(os.path.dirname(__file__)) + +# Import the README and use it as the long-description. +# Note: this will only work if 'README.md' is present in your MANIFEST.in file! +try: + with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: + long_description = '\n' + f.read() +except FileNotFoundError: + long_description = DESCRIPTION + +# Load the package's __version__.py module as a dictionary. +about = {} +if not VERSION: + with open(os.path.join(here, NAME, '__version__.py')) as f: + exec(f.read(), about) +else: + about['__version__'] = VERSION + + +class UploadCommand(Command): + """Support setup.py upload.""" + + description = 'Build and publish the package.' 
+ user_options = [] + + @staticmethod + def status(s): + """Prints things in bold.""" + print('\033[1m{0}\033[0m'.format(s)) + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + try: + self.status('Removing previous builds…') + rmtree(os.path.join(here, 'dist')) + except OSError: + pass + + self.status('Building Source and Wheel (universal) distribution…') + os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) + + self.status('Uploading the package to PyPI via Twine…') + os.system('twine upload dist/*') + + self.status('Pushing git tags…') + os.system('git tag v{0}'.format(about['__version__'])) + os.system('git push --tags') + + sys.exit() + + +# Where the magic happens: +setup( + name=NAME, + version=about['__version__'], + description=DESCRIPTION, + long_description=long_description, + long_description_content_type='text/markdown', + author=AUTHOR, + author_email=EMAIL, + python_requires=REQUIRES_PYTHON, + url=URL, + packages=find_packages(exclude=('tests',)), + # If your package is a single module, use this instead of 'packages': + # py_modules=['mypackage'], + + # entry_points={ + # 'console_scripts': ['mycli=mymodule:cli'], + # }, + install_requires=REQUIRED, + extras_require=EXTRAS, + include_package_data=True, + license='BSD 3-Clause License', + classifiers=[ + # Trove classifiers + # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers + 'License :: OSI Approved :: BSD License', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: Implementation :: CPython', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Human Machine Interfaces', + 'Topic :: Scientific/Engineering :: Information Analysis', + 'Topic :: Scientific/Engineering :: Visualization', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + ], + # $ setup.py publish support. + cmdclass={ + 'upload': UploadCommand, + }, +) diff --git a/wizardhat/acquire.py b/wizardhat/acquire.py index 0992282..06d8508 100644 --- a/wizardhat/acquire.py +++ b/wizardhat/acquire.py @@ -26,6 +26,7 @@ from serial.serialutil import SerialException import threading +import warnings import numpy as np import pylsl as lsl @@ -52,7 +53,7 @@ class Receiver: """ def __init__(self, source_id=None, with_types=('',), dejitter=True, - max_chunklen=0, autostart=True, **kwargs): + max_chunklen=0, autostart=True, window=10, **kwargs): """Instantiate LSLStreamer given length of data store in seconds. Args: @@ -101,9 +102,16 @@ def __init__(self, source_id=None, with_types=('',), dejitter=True, # acquire inlet parameters self.sfreq, self.n_chan, self.ch_names, self.buffers = {}, {}, {}, {} - for name, inlet in self._inlets.items(): + for name, inlet in list(self._inlets.items()): info = inlet.info() self.sfreq[name] = info.nominal_srate() + # TODO: include message/status streams? 
+ if self.sfreq[name] < 1 / window: + warn_msg = ("Stream '{}' sampling period larger".format(name) + + " than buffer window: will not be stored") + print(warn_msg) + self._inlets.pop(name) + continue self.n_chan[name] = info.channel_count() self.ch_names[name] = get_ch_names(info) if '' in self.ch_names[name]: @@ -119,6 +127,7 @@ def __init__(self, source_id=None, with_types=('',), dejitter=True, self.sfreq[name], metadata=metadata, label=info.name(), + window=window, **kwargs) self._dejitter = dejitter @@ -127,6 +136,11 @@ def __init__(self, source_id=None, with_types=('',), dejitter=True, if autostart: self.start() + @classmethod + def record(cls, duration, **kwargs): + """Collect data over a finite interval, then stop.""" + return cls(window=duration, store_once=True, **kwargs) + def start(self): """Start data streaming. @@ -165,8 +179,11 @@ def _receive(self, name): #print(name, samples, timestamps) if timestamps: if self._dejitter: - timestamps = self._dejitter_timestamps(name, - timestamps) + try: + timestamps = self._dejitter_timestamps(name, + timestamps) + except IndexError: + print(name) self.buffers[name].update(timestamps, samples) except SerialException: diff --git a/wizardhat/buffers/buffers.py b/wizardhat/buffers/buffers.py index dcef639..2d8e26c 100644 --- a/wizardhat/buffers/buffers.py +++ b/wizardhat/buffers/buffers.py @@ -14,10 +14,10 @@ import wizardhat.utils as utils import atexit -import datetime import json import os import threading +import time import numpy as np @@ -62,7 +62,6 @@ class Buffer: label (str): User-defined addition to standard filename. Attributes: - updated (threading.Event): Flag for threads waiting for data updates. filename (str): Final (generated or specified) filename for writing. metadata (dict): All metadata included in instance's `.json`. @@ -79,7 +78,7 @@ def __init__(self, metadata=None, filename=None, data_dir='./data', # thread control self._lock = threading.Lock() - self.updated = threading.Event() + self.event_hook = utils.EventHook() # file output preparations if not data_dir[0] in ['.', '/']: @@ -122,13 +121,8 @@ def data(self): @property def unstructured(self): - """Return structured as regular `np.ndarray`. - - TODO: - * dtype (np.float64?) based on context - * ValueError if self._data is not structured? - """ - return self.data.view((np.float64, self.n_chan + 1)) + """Return structured data as regular `np.ndarray`.""" + raise NotImplementedError() @property def dtype(self): @@ -168,12 +162,13 @@ def _write_metadata_to_file(self): f.write(metadata_json) def _new_filename(self, data_dir='data', label=''): - date = datetime.date.today().isoformat() + time_str = time.strftime("%y%m%d-%H%M%S", time.localtime()) classname = type(self).__name__ if label: label += '_' - filename = '{}/{}_{}_{}{{}}'.format(data_dir, date, classname, label) + filename = '{}/{}_{}_{}{{}}'.format(data_dir, time_str, + classname, label) # incremental counter to prevent overwrites # (based on existence of metadata file) count = 0 @@ -187,7 +182,6 @@ def __deepcopy__(self, memo): # threading objects cannot be copied normally # & a new filename is needed mask = {'_lock': threading.Lock(), - 'updated': threading.Event(), 'filename': self._new_filename(self._data_dir, self._label)} return utils.deepcopy_mask(self, memo, mask) @@ -202,10 +196,13 @@ class TimeSeries(Buffer): TODO: * Warning (error?) when timestamps are out of order + * Marker channels + * Per-channel units? + * store_once behaviour is a bit awkward. What about long windows? 
""" - def __init__(self, ch_names, n_samples=2560, record=True, channel_fmt='f8', - **kwargs): + def __init__(self, ch_names, n_samples=2560, sfreq=None, record=True, + channel_fmt='f8', store_once=False, **kwargs): """Create a new `TimeSeries` object. Args: @@ -219,19 +216,29 @@ def __init__(self, ch_names, n_samples=2560, record=True, channel_fmt='f8', Strings should conform to NumPy string datatype specifications; for example, a 64-bit float is specified as `'f8'`. Types may be Python base types (e.g. `float`) or NumPy base dtypes - ( e.g. `np.float64`). + (e.g. `np.float64`). + store_once (bool): Whether to stop storing data when window filled. """ Buffer.__init__(self, **kwargs) - if str(channel_fmt) == channel_fmt: # quack + self.sfreq = sfreq + # if single dtype given, expand to number of channels + try: + np.dtype(channel_fmt) channel_fmt = [channel_fmt] * len(ch_names) + except TypeError: + pass + try: - self._dtype = np.dtype({'names': ["time"] + ch_names, - 'formats': [np.float64] + channel_fmt}) + self._dtype = np.dtype({'names': ["time"] + list(ch_names), + 'formats': [np.float64] + list(channel_fmt) + }) except ValueError: raise ValueError("Number of formats must match number of channels") self._record = record + self._write = True + self._store_once = store_once # write remaining data to file on program exit (e.g. quit()) if record: atexit.register(self.write_to_file) @@ -250,7 +257,7 @@ def with_window(cls, ch_names, sfreq, window=10, **kwargs): This constructor also expects to be passed a nominal sampling frequency so that it can determine the number of samples corresponding to the desired duration. Note that duration is usually not evenly divisible by - sampling frequency, so that the number of samples stored + sampling frequency, and the number of samples stored will be rounded. Args: ch_names (List[str]): List of channel names. @@ -258,7 +265,7 @@ def with_window(cls, ch_names, sfreq, window=10, **kwargs): window (float): Desired duration of live storage. """ n_samples = int(window * sfreq) - return cls(ch_names, n_samples, **kwargs) + return cls(ch_names, n_samples, sfreq, **kwargs) def initialize(self, n_samples=None): """Initialize NumPy structured array for data storage. @@ -280,10 +287,22 @@ def update(self, timestamps, samples): samples (Iterable): Channel data. Data type(s) in `Iterable` correspond to the type(s) specified in `dtype`. + + TODO: + * Sort timestamps/warn if unsorted? """ - self._new = self._format_samples(timestamps, samples) + new = self._format_samples(timestamps, samples) + self.update_with_structured(new) + + def update_with_structured(self, new): + """Append already structured data to stored data. + + Args: + new (np.ndarray): Structured data (`dtype=self.dtype`). + """ + self._new = new self._split_append(self._new) - self.updated.set() + self.event_hook.fire() def write_to_file(self, force=False): """Write any unwritten samples to file. 
@@ -305,10 +324,14 @@ def _split_append(self, new): # however, last chunk added may push out some unwritten samples # therefore split appends before and after write_to_file cutoff = self._count - self._append(new[:cutoff]) - if self._count == 0: - self.write_to_file() - self._append(new[cutoff:]) + if self._write: + self._append(new[:cutoff]) + if self._count == 0: + self.write_to_file() + if self._store_once: + self._write = False + else: + self._append(new[cutoff:]) def _append(self, new): with self._lock: @@ -324,11 +347,75 @@ def _format_samples(self, timestamps, samples): raise ValueError(str(stacked)) return stacked_ + def get_samples(self, last_n=0): + """Return copy of channel data, without timestamps. + + Args: + last_n (int): Number of most recent samples to return. + """ + with self._lock: + return np.copy(self._data[list(self.ch_names)][-last_n:]) + + def get_unstructured(self, last_n=0): + """Return unstructured copy of channel data, without timestamps. + + Args: + last_n (int): Number of most recent samples to return. + """ + samples = self.get_samples(last_n=last_n) + try: + return np.array(samples.tolist()) + #return samples.view((samples.dtype[0], self.n_chan)) + except ValueError as e: + print(samples.shape, samples.dtype, self.n_chan) + raise e + raise ValueError("Cannot return unstructured data for " + + "channels with different datatypes/sample shapes") + + def get_timestamps(self, last_n=0): + """Return copy of timestamps. + + Args: + last_n (int): Number of most recent timestamps to return. + """ + with self._lock: + return np.copy(self._data['time'][-last_n:]) + + @property + def samples(self): + """Copy of channel data, without timestamps.""" + return self.get_samples() + + @property + def unstructured(self): + """Unstructured copy of channel data, without timestamps.""" + return self.get_unstructured() + + @property + def timestamps(self): + """Copy of timestamps.""" + return self.get_timestamps() + + @property + def last_samples(self): + return np.copy(self._new) + + @property + def last_sample(self): + """Last-stored row (timestamp and sample).""" + with self._lock: + return np.copy(self._data[-1]) + @property def n_samples(self): """Number of samples stored in the NumPy array.""" return self._data.shape[0] + @property + def n_new(self): + """Number of samples received on last update.""" + return self._new.shape[0] + @property def ch_names(self): """Channel names. @@ -343,24 +430,71 @@ def n_chan(self): """Number of channels.""" return len(self.ch_names) - @property - def samples(self): - """Return copy of channel data, without timestamps.""" - with self._lock: - return np.copy(self._data[list(self.ch_names)]) - @property - def timestamps(self): - """Return copy of timestamps.""" - with self._lock: - return np.copy(self._data['time']) +class Spectra(TimeSeries): + """Manages a time series of spectral (e.g. frequency-domain) data. + + This is a constrained subclass of `TimeSeries`. Spectral data may be + stored for multiple channels, but all channels will share the same + spectral range (the `range` property). + + TODO: + * What do timestamps mean here? Transformer-dependent? + """ + + def __init__(self, ch_names, indep_range, indep_name="Frequency", + values_dtype=None, **kwargs): + """Create a new `Spectra` object. + + Args: + ch_names (List[str]): List of channel names. + indep_range (Iterable): Values of the independent variable. + n_samples (int): Number of spectra updates to keep. + indep_name (str): Name of the independent variable. + Default: `"freq"`. 
+ values_dtype (type or np.dtype): Spectrum datatype. + Default: `np.float64`. + """ + if values_dtype is None: + values_dtype = np.float64 + + #try: + # if not sorted(indep_range) == list(indep_range): + # raise TypeError + #except TypeError: + # raise TypeError("indep_range not a monotonic increasing sequence") + + super().__init__(ch_names=ch_names, + channel_fmt=(values_dtype, len(indep_range)), + **kwargs) + + self._range = indep_range + self._indep_name = indep_name + + def update(self, timestamp, spectra): + """Append a spectrum to stored data. + + Args: + timestamp (np.float64): Timestamp for the current spectra. + spectrum: Spectra for each of the channels. + Should be a 2D iterable structure (e.g. list of lists, or + `np.ndarray`) where the first dimension corresponds to channels + and the second to the spectrum range. + + TODO: + * May be able to remove this method if `TimeSeries` update method + appends based on channel data type (see `TimeSeries` TODOs) + """ + try: + super(Spectra, self).update([timestamp], [spectra]) + except ValueError: + msg = "cannot update with spectra of incorrect/inconsistent length" + raise ValueError(msg) @property - def last_samples(self): - return np.copy(self._new) + def range(self): + return np.copy(self._range) @property - def last_sample(self): - """Last-stored row (timestamp and sample).""" - with self._lock: - return np.copy(self._data[-1]) + def indep_name(self): + return self._indep_name diff --git a/wizardhat/plot/plot.py b/wizardhat/plot/plot.py index 0a382e1..8b412fc 100644 --- a/wizardhat/plot/plot.py +++ b/wizardhat/plot/plot.py @@ -15,14 +15,10 @@ the gridplot method seems to work well for this. TODO: - * Figure out sampling method- possibly using Data's self.updated attribute - to trigger an update? Maybe we can update everything "in-place" because - buffer.data already has a built-in window.. * Automatically determine device name/set to title? """ from functools import partial -from threading import Thread from bokeh.layouts import row,gridplot, widgetbox from bokeh.models.widgets import Button, RadioButtonGroup @@ -30,11 +26,11 @@ from bokeh.palettes import all_palettes as palettes from bokeh.plotting import figure from bokeh.server.server import Server +import numpy as np from tornado import gen -import time -class Plotter(): +class Plotter: """Base class for plotting.""" def __init__(self, buffer, autostart=True): @@ -45,11 +41,16 @@ def __init__(self, buffer, autostart=True): plot_params (dict): Plot display parameters. 
""" self.buffer = buffer + self.buffer.event_hook += self._buffer_update_callback # output_file('WizardHat Plotter.html') self.server = Server({'/': self._app_manager}) #self.add_widgets() self.autostart = autostart + def _buffer_update_callback(self): + """Called by `buffer` when new data is available.""" + raise NotImplementedError() + def add_widgets(self): self.stream_option = RadioButtonGroup(labels=['EEG', 'ACC', 'GYR'], active=0) self.filter_option = RadioButtonGroup(labels=['Low Pass', 'High Pass', 'Band Pass'], active=0) @@ -60,7 +61,6 @@ def add_widgets(self): def run_server(self): self.server.start() self.server.io_loop.add_callback(self.server.show, '/') - self._update_thread.start() self.server.io_loop.start() def _app_manager(self, curdoc): @@ -106,7 +106,6 @@ def __init__(self, buffer, n_samples=5000, palette='Category10', data_dict = {name: [] # [self.buffer.data[name][:n_samples]] for name in self.buffer.dtype.names} self._source = ColumnDataSource(data_dict) - self._update_thread = Thread(target=self._get_new_samples) self._n_samples = n_samples self._colors = palettes[palette][len(self.buffer.ch_names)] @@ -131,22 +130,76 @@ def _set_layout(self): color=self._colors[i], source=self._source) self.plots.append([p]) - @gen.coroutine def _update(self, data_dict): self._source.stream(data_dict, self._n_samples) - def _get_new_samples(self): + def _buffer_update_callback(self): #TODO Time delay of 1 second is necessary because there seems to be plotting issue related to server booting #time delay allows the server to boot before samples get sent to it. - time.sleep(1) - while True: - self.buffer.updated.wait() - data_dict = {name: self.buffer.last_samples[name] - for name in self.buffer.dtype.names} - try: # don't freak out if IOLoop - self._curdoc.add_next_tick_callback(partial(self._update, - data_dict)) - except AttributeError: - pass - self.buffer.updated.clear() + data_dict = {name: self.buffer.last_samples[name] + for name in self.buffer.dtype.names} + try: # don't freak out if IOLoop + self._curdoc.add_next_tick_callback(partial(self._update, + data_dict)) + except AttributeError: + pass + + +class Spectra(Plotter): + """ + """ + + def __init__(self, buffer, palette='Category10', bgcolor="white", + **kwargs): + """ + Args: + buffer (buffers.Buffer or List[buffers.Buffer]): Objects with data + to be plotted. Multiple objects may be passed in a list, in + which case the plot can cycle through plotting the data in + each object by pressing 'd'. However, all data objects passed + should have a similar form (e.g. `TimeSeries` with same number + of rows/samples and channels). + plot_params (dict): Plot display parameters. 
+ """ + + super().__init__(buffer, **kwargs) + + # TODO: initialize with existing samples in self.buffer.data + data_dict = {name: np.zeros_like(self.buffer.range) + for name in self.buffer.dtype.names if not name == 'time'} + data_dict["range"] = self.buffer.range + self._source = ColumnDataSource(data_dict) + + self._colors = palettes[palette][len(self.buffer.ch_names)] + self._bgcolor = bgcolor + + if self.autostart: + self.run_server() + + def _set_layout(self): + self.plots = [] + for i, ch in enumerate(self.buffer.ch_names): + p = figure(plot_height=100, + tools="xpan,xwheel_zoom,xbox_zoom,reset", + y_axis_location="right") + p.yaxis.axis_label = ch + p.background_fill_color = self._bgcolor + p.line(x='range', y=ch, alpha=0.8, line_width=2, + color=self._colors[i], source=self._source) + self.plots.append([p]) + + @gen.coroutine + def _update(self, data_dict): + self._source.data = data_dict + + def _buffer_update_callback(self): + last_samples = self.buffer.last_samples + data_dict = {name: last_samples[name].T + for name in self.buffer.dtype.names if not name == 'time'} + data_dict['range'] = self.buffer.range + try: # don't freak out if IOLoop + self._curdoc.add_next_tick_callback(partial(self._update, + data_dict)) + except AttributeError: + pass diff --git a/wizardhat/transform/transform.py b/wizardhat/transform/transform.py index f8b9bb1..43ac7d5 100644 --- a/wizardhat/transform/transform.py +++ b/wizardhat/transform/transform.py @@ -1,29 +1,27 @@ """Applying arbitrary transformations/calculations to `Data` objects. - - -TODO: - * Switch from threading.Thread to multiprocessing.Process (not a good idea - to use threads for CPU-intensive stuff) """ +from wizardhat.buffers import Spectra +import wizardhat.utils as utils + import copy -import threading import mne import numpy as np +import scipy.signal as spsig -class Transformer(threading.Thread): - """Base class for transforming data stored in `Buffer` objects. +class Transformer: + """Base class for transforming data handled by `Buffer` objects. Attributes: - buffer_in (buffers.Buffer): Input data. - buffer_out (buffers.Buffer): Output data. + buffer_in (buffers.Buffer): Input data buffer. + buffer_out (buffers.Buffer): Output data buffer. """ def __init__(self, buffer_in): - threading.Thread.__init__(self) self.buffer_in = buffer_in + self.buffer_in.event_hook += self._buffer_update_callback def similar_output(self): """Called in `__init__` when `buffer_out` has same form as `buffer_in`. @@ -32,7 +30,8 @@ def similar_output(self): self.buffer_out.update_pipeline_metadata(self) self.buffer_out.update_pipeline_metadata(self.buffer_out) - def run(self): + def _buffer_update_callback(self): + """Called by `buffer_in` when new data is available to filter.""" raise NotImplementedError() @@ -91,7 +90,7 @@ def _from_mne_array(self, mne_array): class MNEFilter(MNETransformer): - """Apply MNE filters to TimeSeries buffer objects.""" + """Apply MNE filters to `TimeSeries` buffer objects.""" def __init__(self, buffer_in, l_freq, h_freq, sfreq, update_interval=10): """Construct an `MNEFilter` instance. 
@@ -111,25 +110,112 @@ def __init__(self, buffer_in, l_freq, h_freq, sfreq, update_interval=10):
         self._update_interval = update_interval
         self._count = 0
 
-        self._proceed = True
-        self.start()
-
-    def run(self):
-        # wait until buffer_in is updated
-        while self._proceed:
-            self.buffer_in.updated.wait()
-            self.buffer_in.updated.clear()
-            self._count += 1
-            if self._count == self._update_interval:
-                data = self.buffer_in.unstructured
-                timestamps, samples = data[:, 1], data[:, 1:]
-                filtered = mne.filter.filter_data(samples.T, self._sfreq,
-                                                  *self._band)
-                # samples_mne = self._to_mne_array(samples)
-                # filtered_mne = samples_mne.filter(*self._band)
-                # filtered = self._from_mne_array(filtered_mne)
-                self.buffer_out.update(timestamps, filtered.T)
-                self._count = 0
-
-    def stop(self):
-        self._proceed = False
+
+    def _buffer_update_callback(self):
+        self._count += 1
+        if self._count == self._update_interval:
+            samples = self.buffer_in.unstructured
+            timestamps = self.buffer_in.timestamps
+            filtered = mne.filter.filter_data(samples.T, self._sfreq,
+                                              *self._band)
+            # samples_mne = self._to_mne_array(samples)
+            # filtered_mne = samples_mne.filter(*self._band)
+            # filtered = self._from_mne_array(filtered_mne)
+            self.buffer_out.update(timestamps, filtered.T)
+            self._count = 0
+
+
+class PSD(Transformer):
+    """Calculate the power spectral density for time series data.
+
+    TODO:
+        * control over update frequency?
+    """
+
+    def __init__(self, buffer_in, n_samples=256, pow2=True, window=np.hamming):
+        self.sfreq = buffer_in.sfreq
+        if pow2:
+            n_samples = utils.next_pow2(n_samples)
+        self.n_fft = n_samples
+        self.window = window(self.n_fft).reshape((self.n_fft, 1))
+        self.indep_range = np.fft.rfftfreq(self.n_fft, 1 / self.sfreq)
+        self.buffer_out = Spectra(buffer_in.ch_names, self.indep_range)
+
+        Transformer.__init__(self, buffer_in=buffer_in)
+
+    def _buffer_update_callback(self):
+        """Called by `buffer_in` when new data is available."""
+        timestamp = self.buffer_in.last_sample["time"]
+        data = self.buffer_in.get_unstructured(last_n=self.n_fft)
+        psd = self._get_power_spectrum(data)
+        self.buffer_out.update(timestamp, psd.T)
+
+    def _get_windowed(self, data):
+        data_centered = data - np.mean(data, axis=0)
+        data_windowed = data_centered * self.window
+        return data_windowed
+
+    def _get_power_spectrum(self, data):
+        data_windowed = self._get_windowed(data)
+        data_fft = np.fft.rfft(data_windowed, n=self.n_fft, axis=0)
+        data_fft /= self.n_fft
+        psd = 2 * np.abs(data_fft)
+        return psd
+
+
+class Convolve(Transformer):
+    """Convolve a time series of data.
+
+    Currently only convolves across the sampling dimension (e.g. the rows in
+    unstructured data returned by a `buffers.TimeSeries` object) of all
+    channels, and assumes that all channels have the same shape (i.e. as
+    returned by the `get_unstructured` method).
+    """
+
+    def __init__(self, buffer_in, conv_arr, conv_mode='valid',
+                 conv_method='direct'):
+        """Create a new `Convolve` object.
+
+        Args:
+            buffer_in (buffers.Buffer): Buffer managing data to convolve.
+            conv_arr (np.ndarray): Array to convolve data with.
+                Should not be longer than `buffer_in.n_samples`.
+            conv_mode (str): Mode for `scipy.signal.convolve`.
+                Default: `'valid'`.
+            conv_method (str): Method for `scipy.signal.convolve`.
+                Default: `'direct'`. For many channels and very large
+                convolution windows, it may be faster to use `'fft'`.
+        """
+        Transformer.__init__(self, buffer_in=buffer_in)
+        self.similar_output()
+        self.conv_mode = conv_mode
+        self.conv_method = conv_method
+
+        # expand convolution array across independent (non-sampling) dims
+        ch_shape = self.buffer_in.unstructured.shape[1:]
+        self.conv_arr = np.array(conv_arr).reshape([-1] + [1] * len(ch_shape))
+        self._conv_n_edge = len(self.conv_arr) - 1
+
+        if self.conv_mode == 'valid':
+            self._timestamp_slice = slice(self._conv_n_edge,
+                                          -self._conv_n_edge)
+        else:
+            raise NotImplementedError()
+
+    def _buffer_update_callback(self):
+        """Called by `buffer_in` when new data is available."""
+        n_new = self.buffer_in.n_new
+        last_n = max(n_new + 2 * self._conv_n_edge, self.buffer_in.n_samples)
+        data = self.buffer_in.get_unstructured(last_n=last_n)
+        timestamps = self.buffer_in.get_timestamps(last_n=last_n)
+        data_conv = spsig.convolve(data, self.conv_arr, mode=self.conv_mode,
+                                   method=self.conv_method)
+        self.buffer_out.update(timestamps[self._timestamp_slice], data_conv)
+
+
+class MovingAverage(Convolve):
+    """Calculate a uniformly-weighted moving average over a data series."""
+
+    def __init__(self, buffer_in, n_avg):
+        conv_arr = np.array([1 / n_avg] * n_avg)
+        Convolve.__init__(self, buffer_in=buffer_in, conv_arr=conv_arr)
diff --git a/wizardhat/utils.py b/wizardhat/utils.py
index 8bc736c..845a44c 100644
--- a/wizardhat/utils.py
+++ b/wizardhat/utils.py
@@ -12,6 +12,47 @@
 import numpy as np
 
 
+class EventHook:
+    """Handler for multiple callbacks triggered by a single event.
+
+    Callbacks may be registered with an `EventHook` instance using the
+    in-place addition operator (`event_hook_instance += some_callback_function`)
+    and deregistered by in-place subtraction (`-=`). When the instance's `fire`
+    method is called (i.e. upon some event), all of the registered callback
+    functions will also be called.
+
+    The primary use for this class is in `Buffer` classes, whose `EventHook`
+    instances allow them to call the update functions of all downstream objects
+    (e.g. `Plotter` or `Transformer` instances).
+
+    TODO:
+        * multiprocessing: spread the workload over several processes; maybe
+        give the option to use either threading or multiprocessing for a given
+        callback, depending on its complexity (IO vs. calculations)
+    """
+    def __init__(self):
+        self._handlers = []
+
+    def __iadd__(self, handler):
+        self._handlers.append(handler)
+        return self
+
+    def __isub__(self, handler):
+        self._handlers.remove(handler)
+        return self
+
+    def fire(self, *args, **keywargs):
+        """Call all registered callback functions."""
+        for handler in self._handlers:
+            handler(*args, **keywargs)
+
+    def clear_handlers(self, in_object):
+        """Deregister all methods of a given object."""
+        # iterate over a copy so handlers can be removed during iteration
+        for handler in list(self._handlers):
+            if handler.__self__ == in_object:
+                self -= handler
+
+
 def deepcopy_mask(obj, memo, mask=None):
     """Generalized method for deep copies of objects.
 
@@ -63,3 +104,8 @@ def makedirs(filepath):
         filepath (str): The path for which to create directories.
     """
     os.makedirs(os.path.dirname(filepath), exist_ok=True)
+
+
+def next_pow2(n):
+    """Return the smallest power of 2 greater than or equal to a number."""
+    return int(2 ** np.ceil(np.log2(n)))
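The `EventHook` added to `wizardhat/utils.py` replaces the old `threading.Event` flag as the way buffers notify downstream consumers: a `Buffer` fires its hook after every update, and all registered callbacks run in turn. A minimal sketch of the registration pattern follows; the callback name is illustrative, not part of the library:

    from wizardhat.utils import EventHook

    def on_update():
        # called each time the owning buffer fires its hook
        print("buffer updated")

    hook = EventHook()
    hook += on_update   # register a callback
    hook.fire()         # calls on_update()
    hook -= on_update   # deregister it again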
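The same hook mechanism is what lets the new transformers and plotters chain off a live `Receiver` without explicit threads: each `Transformer` registers its `_buffer_update_callback` on its input buffer, and each `Plotter` does the same on the buffer it displays. The sketch below wires a `PSD` transformer and a `Spectra` plotter to an EEG stream; the `'EEG'` buffer key and the package-level import paths are assumptions about the surrounding package, not guarantees made by this diff:

    from wizardhat import acquire, transform, plot

    # Stream from a connected LSL source into 10-second TimeSeries buffers.
    receiver = acquire.Receiver(window=10)

    # PSD subscribes to the EEG buffer's event hook and writes each power
    # spectrum into its own Spectra buffer.
    psd = transform.PSD(receiver.buffers['EEG'], n_samples=256)

    # The Spectra plotter subscribes to the transformer's output buffer.
    plot.Spectra(psd.buffer_out)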
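For finite acquisitions, the new `Receiver.record` classmethod constructs a receiver with `window=duration` and `store_once=True`, so each buffer stops appending once its window has filled. A short sketch under the same assumptions as above:

    from wizardhat import acquire

    # Collect roughly 30 seconds of data, then leave the buffers as they are.
    recorder = acquire.Receiver.record(30)

    # ... once the window has filled:
    eeg = recorder.buffers['EEG'].get_unstructured()   # shape (n_samples, n_chan)
    times = recorder.buffers['EEG'].get_timestamps()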