Skip to content

Commit

Permalink
Merge branch 'patch-v1.3.3'
Browse files Browse the repository at this point in the history
Fix small bugs.
1) Fix bugs in RasterizeGraph edge coloring
2) Resample Node no longer results in volume shrinkage by 1 pixel
3) SpecifiedLocation won't get stuck in an infinite loop if None of the points to choose from are valid, instead throws an error.
4) removed references to `np.float128` which may not be available on all platforms (mac). Tests now pass on mac.
5) Removed code for python 2 compatibility
  • Loading branch information
pattonw committed Apr 20, 2024
1 parent ecbb63c commit 71aa879
Show file tree
Hide file tree
Showing 14 changed files with 134 additions and 63 deletions.
20 changes: 0 additions & 20 deletions gunpowder/compat.py

This file was deleted.

17 changes: 17 additions & 0 deletions gunpowder/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,20 @@ def u(self):
def v(self):
return self.__v

@property
def attrs(self):
return self.__attrs

@property
def all(self):
return self.__attrs

@classmethod
def from_attrs(cls, attrs: Dict[str, Any]):
u = attrs["u"]
v = attrs["v"]
return cls(u, v, attrs=attrs)

def __iter__(self):
return iter([self.u, self.v])

Expand Down Expand Up @@ -287,6 +297,13 @@ def node(self, id: int):
attrs = self.__graph.nodes[id]
return Node.from_attrs(attrs)

def edge(self, id: tuple[int, int]):
"""
Get specific edge
"""
attrs = self.__graph.edges[id]
return Edge.from_attrs(attrs)

def contains(self, node_id: int):
return node_id in self.__graph.nodes

Expand Down
2 changes: 1 addition & 1 deletion gunpowder/nodes/defect_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def process(self, batch, request):
if augmentation_type == "zero_out":
raw.data[section_selector] = 0

elif augmentation_type == "low_contrast":
elif augmentation_type == "lower_contrast":
section = raw.data[section_selector]

mean = section.mean()
Expand Down
12 changes: 6 additions & 6 deletions gunpowder/nodes/dvid_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ def __get_spec(self, array_key):
spec.dtype = data_dtype

if spec.interpolatable is None:
spec.interpolatable = spec.dtype in [
np.float32,
np.float64,
np.float128,
np.uint8, # assuming this is not used for labels
]
spec.interpolatable = spec.dtype in (
np.sctypes["float"]
+ [
np.uint8, # assuming this is not used for labels
]
)
logger.warning(
"WARNING: You didn't set 'interpolatable' for %s. "
"Based on the dtype %s, it has been set to %s. "
Expand Down
12 changes: 6 additions & 6 deletions gunpowder/nodes/hdf5like_source_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ def __read_spec(self, array_key, data_file, ds_name):
spec.dtype = dataset.dtype

if spec.interpolatable is None:
spec.interpolatable = spec.dtype in [
np.float32,
np.float64,
np.float128,
np.uint8, # assuming this is not used for labels
]
spec.interpolatable = spec.dtype in (
np.sctypes["float"]
+ [
np.uint8, # assuming this is not used for labels
]
)
logger.warning(
"WARNING: You didn't set 'interpolatable' for %s "
"(dataset %s). Based on the dtype %s, it has been "
Expand Down
12 changes: 6 additions & 6 deletions gunpowder/nodes/klb_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,12 @@ def __read_spec(self, headers):
spec.dtype = dtype

if spec.interpolatable is None:
spec.interpolatable = spec.dtype in [
np.float32,
np.float64,
np.float128,
np.uint8, # assuming this is not used for labels
]
spec.interpolatable = spec.dtype in (
np.sctypes["float"]
+ [
np.uint8, # assuming this is not used for labels
]
)
logger.warning(
"WARNING: You didn't set 'interpolatable' for %s. "
"Based on the dtype %s, it has been set to %s. "
Expand Down
39 changes: 29 additions & 10 deletions gunpowder/nodes/rasterize_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def __rasterize(
settings.mode == "ball"
and settings.inner_radius_fraction is None
and len(list(graph.edges)) == 0
and settings.color_attr is None
)

if use_fast_rasterization:
Expand Down Expand Up @@ -347,7 +348,7 @@ def __rasterize(

else:
if settings.color_attr is not None:
c = graph.nodes[node].get(settings.color_attr)
c = node.attrs.get(settings.color_attr)
if c is None:
logger.debug(f"Skipping node: {node}")
continue
Expand All @@ -363,7 +364,7 @@ def __rasterize(
if settings.edges:
for e in graph.edges:
if settings.color_attr is not None:
c = graph.edges[e].get(settings.color_attr)
c = e.attrs.get(settings.color_attr)
if c is None:
continue
elif np.isclose(c, 1) and not np.isclose(settings.fg_value, 1):
Expand All @@ -372,26 +373,44 @@ def __rasterize(
f"attribute {settings.color_attr} "
f"but color 1 will be replaced with fg_value: {settings.fg_value}"
)
else:
c = 1

u = graph.node(e.u)
v = graph.node(e.v)
u_coord = Coordinate(u.location / voxel_size)
v_coord = Coordinate(v.location / voxel_size)
line = draw.line_nd(u_coord, v_coord, endpoint=True)
rasterized_graph[line] = 1
rasterized_graph[line] = c

# grow graph
if not use_fast_rasterization:
if settings.mode == "ball":
enlarge_binary_map(
rasterized_graph,
settings.radius,
voxel_size,
settings.inner_radius_fraction,
in_place=True,
)
if settings.color_attr is not None:
for color in np.unique(rasterized_graph):
if color == 0:
continue
assert color in [2,3], np.unique(rasterized_graph)
mask = rasterized_graph == color
enlarge_binary_map(
mask,
settings.radius,
voxel_size,
settings.inner_radius_fraction,
in_place=True,
)
rasterized_graph[mask] = color
else:
enlarge_binary_map(
rasterized_graph,
settings.radius,
voxel_size,
settings.inner_radius_fraction,
in_place=True,
)

else:

sigmas = settings.radius / voxel_size

gaussian_filter(
Expand Down
3 changes: 0 additions & 3 deletions gunpowder/nodes/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ def setup(self):
+ source_voxel_size[-self.ndim :]
)

spec.roi = spec.roi.grow(
-self.pad, -self.pad
) # Pad w/ 1 voxel per side for interpolation to avoid edge effects
spec.roi = spec.roi.snap_to_grid(
np.lcm(source_voxel_size, self.target_voxel_size), mode="shrink"
)
Expand Down
2 changes: 1 addition & 1 deletion gunpowder/nodes/simple_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def process(self, batch, request):
array.data = array.data[channel_slices + mirror]

transpose = [t + num_channels for t in self.transpose]
array.data = array.data = array.data.transpose(
array.data = array.data.transpose(
list(range(num_channels)) + transpose
)

Expand Down
31 changes: 30 additions & 1 deletion gunpowder/nodes/specified_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,29 @@ class SpecifiedLocation(BatchFilter):
Default is None, which places the point in the center.
Chooses uniformly from [loc - jitter, loc + jitter] in each
direction.
attempt_factor (``int``):
If choosing randomly then given `n` points, sample
`attempt_factor * n` points at most before giving up and
throwing an error.
"""

def __init__(self, locations, choose_randomly=False, extra_data=None, jitter=None):
def __init__(
self,
locations,
choose_randomly=False,
extra_data=None,
jitter=None,
attempt_factor: int = 5,
):
self.coordinates = locations
self.choose_randomly = choose_randomly
self.jitter = jitter
self.loc_i = -1
self.upstream_spec = None
self.specified_shift = None
self.attempt_factor = attempt_factor

if extra_data is not None:
assert len(extra_data) == len(locations), (
Expand All @@ -79,13 +93,28 @@ def prepare(self, request):
request_center = total_roi.shape / 2 + total_roi.offset

self.specified_shift = self._get_next_shift(request_center, lcm_voxel_size)
loop_counter = 0
while not self.__check_shift(request):
logger.warning(
"Location %s (shift %s) skipped"
% (self.coordinates[self.loc_i], self.specified_shift)
)
self.specified_shift = self._get_next_shift(request_center, lcm_voxel_size)

loop_counter += 1
if loop_counter >= len(self.coordinates) * (
1 + (self.attempt_factor - 1) * int(self.choose_randomly)
):
if self.choose_randomly:
raise Exception(
f"Took {5*len(self.coordinates)} samples of {len(self.coordinates)} points "
"and did not find a suitible location"
)
else:
raise Exception(
"Looped through every possible location and None are valid"
)

# Set shift for all requests
for specs_type in [request.array_specs, request.graph_specs]:
for key, spec in specs_type.items():
Expand Down
12 changes: 6 additions & 6 deletions gunpowder/nodes/zarr_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ def __read_spec(self, array_key, data_file, ds_name):
spec.dtype = dataset.dtype

if spec.interpolatable is None:
spec.interpolatable = spec.dtype in [
np.float32,
np.float64,
np.float128,
np.uint8, # assuming this is not used for labels
]
spec.interpolatable = spec.dtype in (
np.sctypes["float"]
+ [
np.uint8, # assuming this is not used for labels
]
)
logger.warning(
"WARNING: You didn't set 'interpolatable' for %s "
"(dataset %s). Based on the dtype %s, it has been "
Expand Down
3 changes: 1 addition & 2 deletions gunpowder/tensorflow/nodes/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ def __init__(
self.summary_saver = None
self.log_dir = log_dir
self.log_every = log_every
# Check if optimizer is a str in python 2/3 compatible way.
if isinstance(optimizer, ("".__class__, "".__class__)):
if isinstance(optimizer, str):
self.optimizer_loss_names = (optimizer, loss)
else:
self.optimizer_func = optimizer
Expand Down
2 changes: 1 addition & 1 deletion gunpowder/version_info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__major__ = 1
__minor__ = 3
__patch__ = 2
__patch__ = 3
__tag__ = ""
__version__ = "{}.{}.{}{}".format(__major__, __minor__, __patch__, __tag__).strip(".")

Expand Down
30 changes: 30 additions & 0 deletions tests/cases/rasterize_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,36 @@
import numpy as np


def test_rasterize_graph_colors():
graph = Graph(
[
Node(id=1, location=np.array((0.5, 0.5)), attrs={"color": 2}),
Node(id=2, location=np.array((0.5, 4.5)), attrs={"color": 2}),
Node(id=3, location=np.array((4.5, 0.5)), attrs={"color": 3}),
Node(id=4, location=np.array((4.5, 4.5)), attrs={"color": 3}),
],
[Edge(1, 2, attrs={"color": 2}), Edge(3, 4, attrs={"color": 3})],
GraphSpec(roi=Roi((0, 0), (5, 5))),
)

graph_key = GraphKey("G")
array_key = ArrayKey("A")
graph_source = GraphSource(graph_key, graph)
pipeline = graph_source + RasterizeGraph(
graph_key,
array_key,
ArraySpec(roi=Roi((0, 0), (5, 5)), voxel_size=Coordinate(1, 1), dtype=np.uint8),
settings=RasterizationSettings(1, color_attr="color"),
)
with build(pipeline):
request = BatchRequest()
request[array_key] = ArraySpec(Roi((0, 0), (5, 5)))
rasterized = pipeline.request_batch(request)[array_key].data
assert rasterized[0, 0] == 2
assert rasterized[0, :].sum() == 10
assert rasterized[4, 0] == 3
assert rasterized[4, :].sum() == 15

def test_3d():
graph_key = GraphKey("TEST_GRAPH")
array_key = ArrayKey("TEST_ARRAY")
Expand Down

0 comments on commit 71aa879

Please sign in to comment.