Skip to content

Commit

Permalink
ENH: add tests for multi-input datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed Mar 28, 2024
1 parent e8220d4 commit 7d1c772
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 22 deletions.
30 changes: 10 additions & 20 deletions examples/Liver-tumor-segmentation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,19 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ANTsImage (RPI)\n",
"\t Pixel Type : float (float32)\n",
"\t Components : 1\n",
"\t Dimensions : (512, 512, 75)\n",
"\t Spacing : (0.7031, 0.7031, 5.0)\n",
"\t Origin : (-172.9, 179.2969, -368.0)\n",
"\t Direction : [ 1. 0. 0. 0. -1. 0. 0. 0. 1.]\n",
"\n",
"ANTsImage (RPI)\n",
"\t Pixel Type : float (float32)\n",
"\t Components : 1\n",
"\t Dimensions : (512, 512, 75)\n",
"\t Spacing : (0.7031, 0.7031, 5.0)\n",
"\t Origin : (-172.9, 179.2969, -368.0)\n",
"\t Direction : [ 1. 0. 0. 0. -1. 0. 0. 0. 1.]\n",
"\n"
"ename": "AttributeError",
"evalue": "'ComposeConfig' object has no attribute 'values'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnitrain\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FolderDataset\n\u001b[0;32m----> 3\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mFolderDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m~/Desktop/kaggle-liver-ct\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mpattern\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mvolumes/volume-\u001b[39;49m\u001b[38;5;132;43;01m{id}\u001b[39;49;00m\u001b[38;5;124;43m.nii\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mpattern\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mvolumes/volume-\u001b[39;49m\u001b[38;5;132;43;01m{id}\u001b[39;49;00m\u001b[38;5;124;43m.nii\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mpattern\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msegmentations/segmentation-\u001b[39;49m\u001b[38;5;132;43;01m{id}\u001b[39;49;00m\u001b[38;5;124;43m.nii\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m x, y \u001b[38;5;241m=\u001b[39m dataset[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(x)\n",
"File \u001b[0;32m~/Desktop/nitrain/nitrain/datasets/folder_dataset.py:41\u001b[0m, in \u001b[0;36mFolderDataset.__init__\u001b[0;34m(self, base_dir, x, y, x_transforms, y_transforms)\u001b[0m\n\u001b[1;32m 38\u001b[0m x_config \u001b[38;5;241m=\u001b[39m _infer_config(x, base_dir)\n\u001b[1;32m 39\u001b[0m y_config \u001b[38;5;241m=\u001b[39m _infer_config(y, base_dir)\n\u001b[0;32m---> 41\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[43mx_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(y_config\u001b[38;5;241m.\u001b[39mvalues):\n\u001b[1;32m 42\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFound that len(x) [\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(x_config\u001b[38;5;241m.\u001b[39mvalues)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m] != len(y) [\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(y_config\u001b[38;5;241m.\u001b[39mvalues)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m]. Attempting to match ids.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 43\u001b[0m x_config, y_config \u001b[38;5;241m=\u001b[39m _align_configs(x_config, y_config)\n",
"\u001b[0;31mAttributeError\u001b[0m: 'ComposeConfig' object has no attribute 'values'"
]
}
],
Expand Down
10 changes: 9 additions & 1 deletion nitrain/datasets/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
class ComposeConfig:
def __init__(self, configs):
self.configs = configs
values = [config.values for config in self.configs]
self.values = list(zip(*values))

# TODO: align ids for composed configs
if self.configs[0].ids is not None:
self.ids = self.configs[0].ids
else:
self.ids = None

def __getitem__(self, idx):
return [config[idx] for config in self.configs]
Expand Down Expand Up @@ -122,7 +130,7 @@ def _infer_config(x, base_dir=None):
>>> array = np.random.normal(40,10,(10,50,50,50))
>>> x = _infer_config(array)
>>> x = _infer_config([ants.image_read(ants.get_data('r16')) for _ in range(10)])
>>> x = _infer_config([{'pattern': '**/*.nii.gz'}, {'pattern': '{id}/anat/*.nii.gz'}], base_dir)
>>> x = _infer_config([{'pattern': '{id}/anat/*.nii.gz'}, {'pattern': '{id}/anat/*.nii.gz'}], base_dir)
>>> x = _infer_config({'pattern': '{id}/anat/*.nii.gz'}, base_dir)
>>> x = _infer_config({'pattern': '*/anat/*.nii.gz'}, base_dir)
>>> x = _infer_config({'pattern': '**/*T1w*'}, base_dir)
Expand Down
4 changes: 3 additions & 1 deletion nitrain/datasets/folder_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ def __init__(self,
-------
>>> from nitrain.datasets import FolderDataset
>>> dataset = FolderDataset('~/Desktop/openneuro/ds004711',
x={'pattern': '{id}/anat/*T1w.nii.gz',
x=[{'pattern': '{id}/anat/*T1w.nii.gz',
'exclude': '**run-02*'},
{'pattern': '{id}/anat/*T1w.nii.gz',
'exclude': '**run-02*'}],
y={'file':'participants.tsv', 'column':'age',
'id': 'participant_id'})
"""
Expand Down
13 changes: 13 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,19 @@ def test_2d(self):
x, y = dataset[:2]
self.assertTrue(len(x) == 2)

def test_double_image_input(self):
dataset = datasets.FolderDataset(
base_dir=self.tmp_dir,
x=[{'pattern': '*/img2d.nii.gz'},{'pattern': '*/img2d.nii.gz'}],
y={'file': 'participants.csv', 'column': 'age'}
)
self.assertTrue(len(dataset.x) == 5)
self.assertTrue(len(dataset.x[0]) == 2)
self.assertTrue(len(dataset.y) == 5)

x, y = dataset[:2]
self.assertTrue(len(x) == 2)
self.assertTrue(len(x[0]) == 2)

def test_2d_image_to_image(self):
dataset = datasets.FolderDataset(
Expand Down

0 comments on commit 7d1c772

Please sign in to comment.