Skip to content

Commit

Permalink
TESTS: add tests to dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 24, 2024
1 parent 3996776 commit 9e88e6c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
6 changes: 0 additions & 6 deletions nitrain/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,6 @@ def split(self, p, random=True):
return ds_train, ds_test, ds_val
else:
return ds_train, ds_test

def filter(self, expr):
raise NotImplementedError('Not implemented')

def prefetch(self):
raise NotImplementedError('Not implemented')

def __getitem__(self, idx):
reduce = True
Expand Down
17 changes: 16 additions & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_memory(self):

# test repr
r = dataset.__repr__()

def test_multiple_memory(self):
x = [ants.image_read(ants.get_data('r16')) for _ in range(10)]
y = list(range(10))
Expand Down Expand Up @@ -203,6 +203,21 @@ def test_missing_file(self):
inputs=readers.ColumnReader(column='filenames_3d', is_image=True),
outputs=readers.ColumnReader(column='age')
)

def test_split(self):
import ants
import numpy as np
ds = nt.Dataset(
inputs = [ants.from_numpy(np.ones((128,128)))*i for i in range(100)],
outputs = [i for i in range(100)]
)

ds0,ds1,ds2 = ds.split((0.6,0.2,0.2))
ds0,ds1,ds2 = ds.split((0.6,0.2,0.2), random=False)

with self.assertRaises(Exception):
ds0,ds1,ds2 = ds.split((0.6,0.2,0.5))


class TestClass_FolderDataset(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 9e88e6c

Please sign in to comment.