Skip to content

Commit

Permalink
add type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
dmeliza committed Jan 13, 2024
1 parent f10d64e commit 24af0b6
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 54 deletions.
105 changes: 57 additions & 48 deletions arf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,20 @@
This is ARF, a python library for storing and accessing audio and ephys data in
HDF5 containers.
"""
import numbers
from datetime import datetime
from pathlib import Path
from typing import Union
from time import mktime, struct_time
from typing import Iterator, Optional, Tuple, Union
from uuid import UUID

import h5py as h5
import numpy as np
import numpy.typing as npt

Timestamp = Union[datetime, struct_time, int, float, Tuple[int, int]]
ArfTimeStamp = np.ndarray
Datashape = Tuple[int, ...]

spec_version = "2.1"
__version__ = version = "2.6.5"
Expand Down Expand Up @@ -48,27 +60,24 @@ def _fromstring(cls, s):

def open_file(
path: Union[Path, str],
mode=None,
driver=None,
libver=None,
userblock_size=None,
mode: Optional[str] = None,
driver: Optional[str] = None,
libver: Optional[str] = None,
userblock_size: Optional[int] = None,
**kwargs,
):
) -> h5.File:
"""Open an ARF file, creating as necessary.
Use this instead of h5py.File to ensure that root-level attributes and group
creation property lists are set correctly.
"""
import os
import sys
from packaging.version import Version

from h5py import File, h5p

# Caution: This is a private API of h5py, subject to change without notice
from h5py._hl import files as _files
from h5py.version import version as h5py_version
from packaging.version import Version

path = Path(path)
exists = path.exists()
Expand Down Expand Up @@ -116,7 +125,9 @@ def open_file(
return fp


def create_entry(group, name, timestamp, **attributes):
def create_entry(
group: h5.Group, name: str, timestamp: Timestamp, **attributes
) -> h5.Group:
"""Create a new ARF entry under group, setting required attributes.
An entry is an abstract collection of data which all refer to the same time
Expand Down Expand Up @@ -152,16 +163,16 @@ def create_entry(group, name, timestamp, **attributes):


def create_dataset(
group,
name,
data,
units="",
group: h5.Group,
name: str,
data: npt.ArrayLike,
units: str = "",
datatype=DataTypes.UNDEFINED,
chunks=True,
maxshape=None,
compression=None,
chunks: Union[bool, Datashape] = True,
maxshape: Optional[Datashape] = None,
compression: Optional[str] = None,
**attributes,
):
) -> h5.Dataset:
"""Create an ARF dataset under group, setting required attributes
Required arguments:
Expand Down Expand Up @@ -227,14 +238,16 @@ def create_dataset(
return dset


def create_table(
    group: h5.Group, name: str, dtype: npt.DTypeLike, **attributes
) -> h5.Dataset:
    """Create a new, empty array dataset under group with maxshape=(None,).

    The dataset is created with shape (0,) and the supplied dtype (typically a
    compound datatype). Any keyword arguments are stored as attributes on the
    new dataset through set_attributes.

    Returns the newly created dataset.
    """
    # h5.File is a subclass of h5.Group, so annotating with Group accepts both
    # the file root and any entry group.
    dset = group.create_dataset(name, shape=(0,), dtype=dtype, maxshape=(None,))
    set_attributes(dset, **attributes)
    return dset


def append_data(dset, data):
def append_data(dset: h5.Dataset, data: npt.ArrayLike):
"""Append data to dset along axis 0. Data must be a single element or
a 1D array of the same type as the dataset (including compound datatypes)."""
N = data.shape[0] if hasattr(data, "shape") else 1
Expand All @@ -246,7 +259,7 @@ def append_data(dset, data):
dset[oldlen:] = data


def select_interval(dset, begin, end):
def select_interval(dset: h5.Dataset, begin: float, end: float):
"""Extracts values from dataset between [begin, end), specified in seconds. For
point process data, times are offset to the beginning of the interval.
Returns (values, offset)
Expand Down Expand Up @@ -277,7 +290,7 @@ def select_interval(dset, begin, end):
return data, begin


def check_file_version(file):
def check_file_version(file: h5.File):
"""Check the ARF version attribute of file for compatibility.
Raises DeprecationWarning for backwards-incompatible files, FutureWarning
Expand Down Expand Up @@ -318,7 +331,7 @@ def check_file_version(file):
return file_version


def set_attributes(node, overwrite=True, **attributes):
def set_attributes(node: h5.HLObject, overwrite: bool = True, **attributes) -> None:
"""Set multiple attributes on node.
If overwrite is False, and the attribute already exists, does nothing. If
Expand All @@ -336,17 +349,17 @@ def set_attributes(node, overwrite=True, **attributes):
aset[k] = v


def keys_by_creation(group):
"""Returns a sequence of links in group in order of creation.
def keys_by_creation(group: h5.Group) -> Iterator[str]:
"""Returns a lazy sequence of links in group in order of creation.
Raises an error if the group was not set to track creation order.
"""
from h5py import h5

out = []
out: list[bytes] = []
try:
group._id.links.iterate(
group.id.links.iterate(
out.append, idx_type=h5.INDEX_CRT_ORDER, order=h5.ITER_INC
)
except (AttributeError, RuntimeError):
Expand All @@ -355,11 +368,11 @@ def f(name):
if name.find(b"/", 1) == -1:
out.append(name)

group._id.links.visit(f, idx_type=h5.INDEX_CRT_ORDER, order=h5.ITER_INC)
group.id.links.visit(f, idx_type=h5.INDEX_CRT_ORDER, order=h5.ITER_INC)
return map(group._d, out)


def convert_timestamp(obj):
def convert_timestamp(obj: Timestamp) -> ArfTimeStamp:
"""Make an ARF timestamp from an object.
Argument can be a datetime.datetime object, a time.struct_time, an integer,
Expand All @@ -371,10 +384,6 @@ def convert_timestamp(obj):
between float and integer tuple may not be reversible.
"""
import numbers
from datetime import datetime
from time import mktime, struct_time

from numpy import zeros

out = zeros(2, dtype="int64")
Expand All @@ -388,30 +397,32 @@ def convert_timestamp(obj):
elif isinstance(obj, numbers.Real):
out[0] = obj
out[1] = (obj - out[0]) * 1e6
else:
elif isinstance(obj, tuple):
try:
out[:2] = obj[:2]
except IndexError as err:
raise TypeError("unable to convert %s to timestamp" % obj) from err
raise TypeError("tuple timestamp needs two elements") from err
else:
raise TypeError(f"unable to convert {obj} to timestamp")
return out


def timestamp_to_datetime(timestamp: ArfTimeStamp) -> datetime:
    """Convert an ARF timestamp to a datetime.datetime object (naive local time).

    timestamp[0] is passed to datetime.fromtimestamp as seconds since the
    epoch; timestamp[1] is added to the result as a number of microseconds.
    """
    # timedelta is not imported at module level, so it is pulled in here.
    from datetime import datetime, timedelta

    base = datetime.fromtimestamp(timestamp[0])
    # int() so that numpy integer scalars are accepted by timedelta
    return base + timedelta(microseconds=int(timestamp[1]))


def timestamp_to_float(timestamp: ArfTimeStamp) -> float:
    """Convert an ARF timestamp to a floating point (sec since epoch).

    The first element is weighted by 1.0 (seconds) and the second by 1e-6
    (microseconds); any further elements are ignored because zip stops at the
    shorter sequence.
    """
    return sum(t1 * t2 for t1, t2 in zip(timestamp, (1.0, 1e-6)))


def set_uuid(obj, uuid=None):
def set_uuid(obj: h5.HLObject, uuid: Union[str, bytes, UUID, None] = None):
"""Set the uuid attribute of an HDF5 object. Use this method to ensure correct dtype"""
from uuid import UUID, uuid4
from uuid import uuid4

if uuid is None:
uuid = uuid4()
Expand All @@ -426,7 +437,7 @@ def set_uuid(obj, uuid=None):
obj.attrs.create("uuid", str(uuid).encode("ascii"), dtype="|S36")


def get_uuid(obj):
def get_uuid(obj: h5.HLObject) -> UUID:
"""Return the uuid for obj, or null uuid if none is set"""
# TODO: deprecate null uuid ret val
from uuid import UUID
Expand All @@ -443,7 +454,7 @@ def get_uuid(obj):
return UUID(uuid)


def count_children(obj, type=None):
def count_children(obj: h5.HLObject, type=None) -> int:
"""Return the number of children of obj, optionally restricting by class"""
if type is None:
return len(obj)
Expand All @@ -454,7 +465,7 @@ def count_children(obj, type=None):
return sum(1 for x in obj if obj.get(x, getclass=True) is type)


def is_time_series(dset):
def is_time_series(dset: h5.Dataset) -> bool:
"""Return True if dset is a sampled time series (units are not time)"""
return (
not is_marked_pointproc(dset)
Expand All @@ -463,19 +474,17 @@ def is_time_series(dset):
)


def is_marked_pointproc(dset: h5.Dataset) -> bool:
    """Return True if dset is a marked point process.

    A dataset qualifies when its dtype is compound (dtype.names is not None)
    and one of the fields is named 'start'.
    """
    field_names = dset.dtype.names
    return field_names is not None and "start" in field_names


def is_entry(obj: h5.HLObject) -> bool:
    """Return True if the object is an entry (i.e. an hdf5 group)"""
    # h5 is the module-level ``import h5py as h5``; no local import is needed.
    return isinstance(obj, h5.Group)


def count_channels(dset):
def count_channels(dset: h5.Dataset) -> int:
"""Return the number of channels (columns) in dset"""
try:
return dset.shape[1]
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ extend-select = [
"PGH", # pygrep-hooks
"RUF", # Ruff-specific
"UP", # pyupgrade
]
]

[tool.mypy]
ignore_missing_imports = true
10 changes: 5 additions & 5 deletions tests/test_arf.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# -*- mode: python -*-

import pytest
import time
from packaging import version

import numpy as nx
import pytest
from h5py.version import version as h5py_version
from numpy.random import randint, randn
from packaging import version

import arf

Expand Down Expand Up @@ -106,7 +106,7 @@ def test_entry(test_file):

@pytest.fixture
def test_dataset(test_entry):
    """Dataset created under the test_entry fixture from the datasets[2] spec."""
    # A module-level fixture has no ``self``; the entry comes from the
    # test_entry fixture argument.
    return arf.create_dataset(test_entry, **datasets[2])


def test00_create_entries(test_file):
Expand Down Expand Up @@ -139,7 +139,7 @@ def test02_create_datasets(test_entry):
def test04_create_bad_dataset(test_entry):
    """Every spec in bad_datasets must be rejected with ValueError."""
    for dset in bad_datasets:
        with pytest.raises(ValueError):
            # The return value is irrelevant; only the raised error is checked.
            arf.create_dataset(test_entry, **dset)


def test05_set_attributes(test_entry):
Expand All @@ -148,7 +148,7 @@ def test05_set_attributes(test_entry):
assert test_entry.attrs["myint"] == 5000
assert test_entry.attrs["mystr"] == "myvalue"
arf.set_attributes(test_entry, mystr=None)
assert not "mystr" in test_entry.attrs
assert "mystr" not in test_entry.attrs


def test06_null_uuid(test_entry):
Expand Down

0 comments on commit 24af0b6

Please sign in to comment.