Debugging SetOperation
peytondmurray committed Mar 29, 2024
1 parent 9bfb1c7 commit 6d67cb1
Showing 5 changed files with 150 additions and 36 deletions.
131 changes: 101 additions & 30 deletions versioned_hdf5/backend.py
@@ -1,4 +1,4 @@
from typing import Dict, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Union

import numpy as np
from h5py import Dataset, File, Group, VirtualLayout, VirtualSource, h5s
@@ -7,14 +7,17 @@
from numpy.testing import assert_array_equal

from .hashtable import Hashtable
from .slicetools import spaceid_to_slice, to_slice_tuple
from .slicetools import overlap, spaceid_to_slice, to_slice_tuple

DEFAULT_CHUNK_SIZE = 2**12
DATA_VERSION = 4
# data_version 2 has broken hashtables, always need to rebuild
# data_version 3 hash collisions for string arrays which, when concatenated, give the same string
CORRUPT_DATA_VERSIONS = frozenset([2, 3])

if TYPE_CHECKING:
    from .wrappers import InMemoryDataset


class SplitResult:
"""Object which stores the result of splitting a dataset across the last chunk."""
@@ -578,6 +581,8 @@ def write_dataset_chunks(
for raw_slice, data_s in data_to_write.items():
c = (raw_slice.raw,) + tuple(slice(0, i) for i in data_s.shape[1:])
raw_data[c] = data_s

# TODO: should this return Tuple (not Slice) elements?
return slices


@@ -749,11 +754,7 @@ def __repr__(self):
return f"SetOperation:\n Index {self.index}: Data {self.arr}"

def apply(
self,
f: File,
name: str,
version: str,
slices: Dict[Tuple, Tuple]
self, f: File, name: str, _version: str, slices: Dict[Tuple, Tuple]
) -> Dict[Tuple, Tuple]:
"""Write the stored data to the dataset in chunks.
@@ -764,7 +765,8 @@ def apply(
name : str
Name of the dataset
version : str
Version of the dataset to write to
Version of the dataset to write to; unused (this is provided by the
slices dict)
slices : Dict[Tuple, Tuple]
Mapping between {slices in virtual dataset: slices in raw dataset}
in the virtual dataset initially
@@ -775,38 +777,29 @@
Mapping between {slices in virtual dataset: slices in raw dataset}
which were written by this function.
"""
# Ensure each element of the index is a slice
index = to_slice_tuple(ndindex(self.index))

# If the shape of the array doesn't match the shape of the
# index to assign the array to, broadcast it first.
index = ndindex(self.index)
index_shape = tuple(len(dim) for dim in index.args)
if self.arr.shape != index_shape:
arr = np.broadcast_to(self.arr, index_shape)
else:
arr = self.arr

raw_data = f["_version_data"][name]["raw_data"]
chunk_size = tuple(raw_data.attrs["chunks"])[0]

# Create a dictionary which is essentially a list of changes to virtual
# indices that are being made as part of the write operation.
changes = {}
for arr_chunk, virtual_chunk in zip(
partition(arr, chunk_size),
partition(index, chunk_size),
strict=True,
):
changes[virtual_chunk] = arr[arr_chunk.raw]

# Indices written by write_dataset_chunks need to be full chunks.
# Compute the affected chunks and their target values by comparing the
# existing (chunk-sized) dataset slices to the (not-chunk-sized) requested
# changes.
changed_chunks = get_affected_chunks(slices, changes)
changed_chunks = get_setitem_chunks(slices, index, arr, raw_data)

# Operations which set data only write and/or replace entire chunks,
# i.e. they do not modify the keys of this mapping. Therefore it is
# not necessary to pass in the initial slices to write_dataset_chunks.
new_slices = write_dataset_chunks(f, name, changes)
new_slices = write_dataset_chunks(f, name, changed_chunks)
return {**slices, **new_slices}
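
For instance, a scalar write such as ds[0:4] = 1.0 (ds being a hypothetical dataset) reaches apply as a 0-d array, which the broadcast step above expands to the index's shape before any chunks are computed. A plain-numpy sketch of that step (illustrative values, not the library API):

import numpy as np

# A 0-d value assigned to a 4-element index is broadcast before chunking
index_shape = (4,)
arr = np.broadcast_to(np.asarray(1.0), index_shape)
assert arr.shape == index_shape and (arr == 1.0).all()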

def show(self, data: np.ndarray) -> np.ndarray:
@@ -826,7 +819,15 @@ def show(self, data: np.ndarray) -> np.ndarray:
np.ndarray
Value of the dataset post-operation
"""
        data[self.index] = self.arr
        if isinstance(data, np.ndarray):
            data[self.index] = self.arr
        else:
            # This is a numpy scalar, which holds exactly one element,
            # so the only valid index is (0,)
            if self.index != (0,):
                raise ValueError(
                    f"Cannot set index {self.index} on a scalar dataset"
                )

            data = self.arr
        return data
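
As a concrete illustration of show, replaying a stored assignment on an in-memory array might look like this (a sketch; it assumes SetOperation(index, arr) stores its two arguments as self.index and self.arr, as the __repr__ above suggests):

import numpy as np

data = np.zeros(6)
op = SetOperation(np.s_[1:4], np.array([7.0, 8.0, 9.0]))
shown = op.show(data)
assert (shown == [0.0, 7.0, 8.0, 9.0, 0.0, 0.0]).all()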


@@ -846,7 +847,9 @@ def __init__(self, value: np.ndarray):
def __repr__(self):
return f"WriteOperation:\n {self.value}"

def apply(self, f: File, name: str, version: str) -> Dict[Tuple, Tuple]:
def apply(
self, f: File, name: str, version: str, slices: Dict[Tuple, Tuple]
) -> Dict[Tuple, Tuple]:
"""Append data the stored data to the dataset.
Parameters
@@ -889,6 +892,17 @@ def show(self, data: np.ndarray) -> np.ndarray:
)


def write_dataset_operations(
    f: File,
    version_name: str,
    name: str,
    dataset: "InMemoryDataset",
) -> tuple[Dict[Tuple, Tuple], tuple[int]]:
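    """Apply the pending write operations stored on `dataset`, then clear them.

    A thin wrapper around write_operations; returns the
    {slices in virtual dataset: slices in raw dataset} mapping and the shape
    of the resulting dataset.
    """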
    result = write_operations(f, version_name, name, dataset._operations)
    dataset._operations.clear()
    return result


def write_operations(
f: File, version_name: str, name: str, operations: List[WriteOperation]
) -> tuple[Dict[Tuple, Tuple], tuple[int]]:
@@ -921,7 +935,6 @@ def write_operations(
"Use write_dataset() if the dataset does not yet exist"
)

breakpoint()
slices = get_previous_version_slices(f, version_name, name)
for operation in operations:
slices.update(operation.apply(f, name, version_name, slices))
@@ -1323,10 +1336,68 @@ def partition(
yield from ChunkSize((chunk_size,)).as_subchunks(index, shape)
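
For reference, the ndindex primitive used by partition yields every chunk index that intersects a selection; a minimal sketch for a 1-D dataset of shape (12,) with chunk size 4:

from ndindex import ChunkSize, Slice, Tuple

# Chunks [0, 4), [4, 8), and [8, 12) all intersect Slice(2, 10)
for chunk in ChunkSize((4,)).as_subchunks(Tuple(Slice(2, 10)), (12,)):
    print(chunk)
# Expected output (roughly):
# Tuple(Slice(0, 4, 1))
# Tuple(Slice(4, 8, 1))
# Tuple(Slice(8, 12, 1))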


def get_affected_chunks(
def get_setitem_chunks(
    slices: Dict[Tuple, Tuple],
    changes: Dict[Tuple, np.ndarray],
) -> Dict[Tuple, np.ndarray]:
    index: Tuple,
    arr: np.ndarray,
    raw_data: Dataset,
) -> Dict[Tuple, Union[Tuple, np.ndarray]]:
"""Get the new chunks of the virtual dataset after arr is written to index.
Parameters
----------
slices : Dict[Tuple, Tuple]
Mapping between existing {slices in virtual dataset: slices in raw dataset}
index : Tuple
(Contiguous) virtual indices for which arr is to be set
arr : np.ndarray
Values to set at `index` indices in the virtual dataset
raw_data : Dataset
Raw data
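
    Returns
    -------
    Dict[Tuple, Union[Tuple, np.ndarray]]
        Mapping between {slices in virtual dataset: chunk values}: the
        existing raw-dataset slice for chunks the index does not touch, or
        an ndarray holding the new chunk contents for chunks it does.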
"""
    new_chunks = {}

    # Keep track of the current index along axis 0 of the data we are inserting
    next_arr_index = 0
    for vslice, rslice in slices.items():
        # Get the part of the index that overlaps with the virtual slice
        overlapping = overlap(vslice.args[0], index.args[0])
        if overlapping:
            # Find the corresponding overlap in the raw slice
            raw_overlap_start = rslice.args[0].start + (
                overlapping.start - vslice.args[0].start
            )

            # Depending on the insertion index and the length of arr,
            # the new data to insert may fall in the middle of a
            # chunk. Compute the new chunk using existing raw_data where
            # appropriate, and data from `arr` for the part of the
            # chunk where `index` overlaps with virtual slices.
            raw_index_initial = Tuple(
                Slice(rslice.args[0].start, raw_overlap_start), *rslice.args[1:]
            )
            raw_index_final = Tuple(
                Slice(raw_overlap_start + len(overlapping), rslice.args[0].stop),
                *rslice.args[1:],
            )
            overlapping_data_index = Tuple(
                Slice(next_arr_index, next_arr_index + len(overlapping)),
                *rslice.args[1:],
            )
            changed_chunk = np.concatenate(
                (
                    raw_data[raw_index_initial.raw],
                    arr[overlapping_data_index.raw],
                    raw_data[raw_index_final.raw],
                ),
            )

            assert changed_chunk.shape == tuple(len(dim) for dim in vslice.args)

            new_chunks[vslice] = changed_chunk
            next_arr_index += len(overlapping)
        else:
            # This chunk isn't changed by the index
            new_chunks[vslice] = rslice

    # Find the affected chunks in the slices of the previous virtual dataset
    for index in changes:
    return new_chunks
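
The concatenation above splices each affected chunk together from three pieces: the untouched head read from raw data, the overlapping run taken from arr, and the untouched tail read from raw data. A standalone numpy sketch of the same splice (illustrative values only):

import numpy as np

chunk = np.arange(8.0)                     # stand-in for raw_data[rslice]
new_values = np.array([-1.0, -2.0, -3.0])  # part of arr landing in this chunk
start, stop = 3, 6                         # overlap of the write with the chunk

changed_chunk = np.concatenate((chunk[:start], new_values, chunk[stop:]))
assert (changed_chunk == [0, 1, 2, -1, -2, -3, 6, 7]).all()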
41 changes: 41 additions & 0 deletions versioned_hdf5/slicetools.py
@@ -1,4 +1,5 @@
from functools import lru_cache
from typing import List, Optional

from ndindex import Integer, Slice, Tuple

@@ -67,3 +68,43 @@ def to_slice_tuple(index: Tuple) -> Tuple:
raise TypeError(f"Cannot convert type of {dim} to a Slice.")

return Tuple(*result)


def overlap(index1: Slice, index2: Slice) -> Optional[Slice]:
    """Return the overlap (if any) between two Slices.

    Parameters
    ----------
    index1 : Slice
        First slice to compare
    index2 : Slice
        Second slice to compare

    Returns
    -------
    Optional[Slice]
        None if there is no overlap, otherwise a Slice containing the
        overlapping indices
    """
    #    ----------------
    #            ------------------
    # or
    #    ----------------------
    #        -------------
    if index1.start <= index2.start < index1.stop:
        return Slice(index2.start, min(index2.stop, index1.stop))

    #        ----------------------
    #    ---------------------
    # or
    #            -----------------
    #    -----------
    if index1.start < index2.stop <= index1.stop:
        return Slice(max(index1.start, index2.start), index2.stop)

    #    ---------
    #                 ----------
    # or
    #                 ----------
    #    ---------
    return None
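
A few examples of the intended behavior (a sketch, assuming ndindex Slices with concrete bounds):

from ndindex import Slice

assert overlap(Slice(0, 10), Slice(4, 15)) == Slice(4, 10)
assert overlap(Slice(5, 10), Slice(0, 7)) == Slice(5, 7)
assert overlap(Slice(0, 5), Slice(10, 20)) is None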
2 changes: 2 additions & 0 deletions versioned_hdf5/tests/test_api.py
@@ -336,6 +336,8 @@ def test_changes_dataset(vfile):
key_ds[0] = 0

key = vfile["version2"][f"{name}/key"]
f = key[1]

assert key.shape == (2 * DEFAULT_CHUNK_SIZE,)
assert_equal(key[0], 0)
assert_equal(key[1 : 2 * DEFAULT_CHUNK_SIZE], 1.0)
7 changes: 3 additions & 4 deletions versioned_hdf5/versions.py
@@ -10,7 +10,7 @@
create_virtual_dataset,
write_dataset,
write_dataset_chunks,
write_operations,
write_dataset_operations,
)
from .wrappers import (
DatasetWrapper,
@@ -148,9 +148,7 @@ def commit_version(
for k, v in data.attrs.items():
data_copy.attrs[k] = v
else:
slices, shape = write_operations(
f, version_name, name, data._operations
)
slices, shape = write_dataset_operations(f, version_name, name, data)

elif isinstance(data, dict):
if chunks[name] is not None:
@@ -190,6 +188,7 @@ def commit_version(
)
else:
shape = data.shape

create_virtual_dataset(
f, version_name, name, shape, slices, attrs=attrs, fillvalue=fillvalue
)
5 changes: 3 additions & 2 deletions versioned_hdf5/wrappers.py
Expand Up @@ -702,10 +702,11 @@ def __getitem__(self, args, new_dtype=None):
# === END CODE FROM h5py.Dataset.__getitem__ ===

        idx = ndindex(args).expand(self.shape)
        arr = super().__getitem__(idx.raw)

        # If there are operations to carry out, replay them in memory;
        # otherwise read from the underlying dataset directly
        if self._operations:
            arr = super().__getitem__(idx.raw)[:]
            for operation in self._operations:
                arr = operation.show(arr)

