Skip to content

Commit

Permalink
update loader tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 3, 2024
1 parent 0071cc2 commit ba51b8a
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def test_multi_image_to_image(self):
loader2 = loader.to_keras()
x_batch, y_batch = next(iter(loader2))
self.assertTrue(len(x_batch) == 2)
self.assertTrue(x_batch[0].shape == (4, 256, 256, 1))
self.assertTrue(x_batch[1].shape == (4, 256, 256, 1))
self.assertTrue(y_batch.shape == (4, 256, 256, 1))
self.assertTrue(tuple(x_batch[0].shape) == (4, 256, 256, 1))
self.assertTrue(tuple(x_batch[1].shape) == (4, 256, 256, 1))
self.assertTrue(tuple(y_batch.shape) == (4, 256, 256, 1))

gen = record_generator(loader)
xb,yb = next(iter(gen))
Expand All @@ -143,17 +143,16 @@ def test_image_to_image_with_slice_sampler(self):
dataset = nt.Dataset(x, x)
loader = nt.Loader(dataset,
images_per_batch=4,
sampler=samplers.SliceSampler(batch_size=12))
sampler=samplers.SliceSampler(batch_size=12, axis=0))

x_batch, y_batch = next(iter(loader))
self.assertTrue(x_batch.shape == (12, 218, 182, 1))
self.assertTrue(y_batch.shape == (12, 218, 182, 1))
self.assertEqual(x_batch.shape, (12, 218, 182, 1))
self.assertEqual(y_batch.shape, (12, 218, 182, 1))

loader2 = loader.to_keras()
x_batch, y_batch = next(iter(loader2))
x_batch, y_batch = next(iter(loader))
self.assertTrue(x_batch.shape == (12, 218, 182, 1))
self.assertTrue(y_batch.shape == (12, 218, 182, 1))
self.assertEqual(tuple(x_batch.shape), (12, 218, 182, 1))
self.assertEqual(tuple(y_batch.shape), (12, 218, 182, 1))

def test_multiple_image_slice(self):
base_dir = nt.fetch_data('example-01')
Expand Down

0 comments on commit ba51b8a

Please sign in to comment.