Skip to content

Commit

Permalink
Test all CSVPointsSource functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed May 16, 2024
1 parent 4c54298 commit 18409b7
Showing 1 changed file with 59 additions and 3 deletions.
62 changes: 59 additions & 3 deletions tests/cases/csv_points_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
import csv

from gunpowder import (
BatchRequest,
Expand All @@ -22,7 +23,7 @@ def seeds():


@pytest.fixture
def test_points(tmpdir):
def test_points_2d(tmpdir):
random.seed(1234)
np.random.seed(1234)

Expand All @@ -41,8 +42,31 @@ def test_points(tmpdir):
yield fake_points_file, fake_points


def test_pipeline3(test_points):
fake_points_file, fake_points = test_points
@pytest.fixture
def test_points_3d(tmpdir):
random.seed(1234)
np.random.seed(1234)

fake_points_file = tmpdir / "shift_test.csv"
fake_points = np.random.randint(0, 100, size=(3, 3)).astype(float)
with open(fake_points_file, "w") as f:
writer = csv.DictWriter(f, fieldnames=["x", "y", "z", "id"])
writer.writeheader()
for i, point in enumerate(fake_points):
pointdict = {"x": point[0], "y": point[1], "z": point[2], "id": i}
writer.writerow(pointdict)

# This fixture will run after seeds since it is set
# with autouse=True. So make sure to reset the seeds properly at the end
# of this fixture
random.seed(12345)
np.random.seed(12345)

yield fake_points_file, fake_points


def test_pipeline_2d(test_points_2d):
fake_points_file, fake_points = test_points_2d

points_key = GraphKey("TEST_POINTS")

Expand All @@ -67,3 +91,35 @@ def test_pipeline3(test_points):
result_locs = [list(point.location) for point in result_points]

assert sorted(result_locs) == sorted(target_locs)


def test_pipeline_3d(test_points_3d):
fake_points_file, fake_points = test_points_3d

points_key = GraphKey("TEST_POINTS")
scale = 2
csv_source = CsvPointsSource(
fake_points_file,
points_key,
spatial_cols=[0, 2, 1],
delimiter=",",
id_col=3,
points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
scale=scale,
)

request = BatchRequest()
shape = Coordinate((100, 100, 100))
request.add(points_key, shape)

pipeline = csv_source
with build(pipeline) as b:
request = b.request_batch(request)

result_points = list(request[points_key].nodes)
for node in result_points:
orig_loc = fake_points[int(node.id)]
reordered_loc = orig_loc.copy()
reordered_loc[1] = orig_loc[2]
reordered_loc[2] = orig_loc[1]
assert list(node.location) == list(reordered_loc * scale)

0 comments on commit 18409b7

Please sign in to comment.