From 7d1c772cc4f505db62b043af6d4a56bcb3d7457f Mon Sep 17 00:00:00 2001 From: ncullen93 Date: Thu, 28 Mar 2024 12:43:04 +0100 Subject: [PATCH] ENH: add tests for multi-input datasets --- examples/Liver-tumor-segmentation.ipynb | 30 +++++++++---------------- nitrain/datasets/configs.py | 10 ++++++++- nitrain/datasets/folder_dataset.py | 4 +++- tests/test_datasets.py | 13 +++++++++++ 4 files changed, 35 insertions(+), 22 deletions(-) diff --git a/examples/Liver-tumor-segmentation.ipynb b/examples/Liver-tumor-segmentation.ipynb index 7050ea3..26ae890 100644 --- a/examples/Liver-tumor-segmentation.ipynb +++ b/examples/Liver-tumor-segmentation.ipynb @@ -53,29 +53,19 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "ANTsImage (RPI)\n", - "\t Pixel Type : float (float32)\n", - "\t Components : 1\n", - "\t Dimensions : (512, 512, 75)\n", - "\t Spacing : (0.7031, 0.7031, 5.0)\n", - "\t Origin : (-172.9, 179.2969, -368.0)\n", - "\t Direction : [ 1. 0. 0. 0. -1. 0. 0. 0. 1.]\n", - "\n", - "ANTsImage (RPI)\n", - "\t Pixel Type : float (float32)\n", - "\t Components : 1\n", - "\t Dimensions : (512, 512, 75)\n", - "\t Spacing : (0.7031, 0.7031, 5.0)\n", - "\t Origin : (-172.9, 179.2969, -368.0)\n", - "\t Direction : [ 1. 0. 0. 0. -1. 0. 0. 0. 1.]\n", - "\n" + "ename": "AttributeError", + "evalue": "'ComposeConfig' object has no attribute 'values'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnitrain\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FolderDataset\n\u001b[0;32m----> 3\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mFolderDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m~/Desktop/kaggle-liver-ct\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mpattern\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mvolumes/volume-\u001b[39;49m\u001b[38;5;132;43;01m{id}\u001b[39;49;00m\u001b[38;5;124;43m.nii\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mpattern\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mvolumes/volume-\u001b[39;49m\u001b[38;5;132;43;01m{id}\u001b[39;49;00m\u001b[38;5;124;43m.nii\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mpattern\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msegmentations/segmentation-\u001b[39;49m\u001b[38;5;132;43;01m{id}\u001b[39;49;00m\u001b[38;5;124;43m.nii\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m x, y \u001b[38;5;241m=\u001b[39m dataset[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(x)\n", + "File \u001b[0;32m~/Desktop/nitrain/nitrain/datasets/folder_dataset.py:41\u001b[0m, in \u001b[0;36mFolderDataset.__init__\u001b[0;34m(self, base_dir, x, y, x_transforms, y_transforms)\u001b[0m\n\u001b[1;32m 38\u001b[0m x_config \u001b[38;5;241m=\u001b[39m _infer_config(x, base_dir)\n\u001b[1;32m 39\u001b[0m y_config \u001b[38;5;241m=\u001b[39m _infer_config(y, base_dir)\n\u001b[0;32m---> 41\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[43mx_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(y_config\u001b[38;5;241m.\u001b[39mvalues):\n\u001b[1;32m 42\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFound that len(x) [\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(x_config\u001b[38;5;241m.\u001b[39mvalues)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m] != len(y) [\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(y_config\u001b[38;5;241m.\u001b[39mvalues)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m]. Attempting to match ids.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 43\u001b[0m x_config, y_config \u001b[38;5;241m=\u001b[39m _align_configs(x_config, y_config)\n", + "\u001b[0;31mAttributeError\u001b[0m: 'ComposeConfig' object has no attribute 'values'" ] } ], diff --git a/nitrain/datasets/configs.py b/nitrain/datasets/configs.py index b83382f..bbec8b8 100644 --- a/nitrain/datasets/configs.py +++ b/nitrain/datasets/configs.py @@ -11,6 +11,14 @@ class ComposeConfig: def __init__(self, configs): self.configs = configs + values = [config.values for config in self.configs] + self.values = list(zip(*values)) + + # TODO: align ids for composed configs + if self.configs[0].ids is not None: + self.ids = self.configs[0].ids + else: + self.ids = None def __getitem__(self, idx): return [config[idx] for config in self.configs] @@ -122,7 +130,7 @@ def _infer_config(x, base_dir=None): >>> array = np.random.normal(40,10,(10,50,50,50)) >>> x = _infer_config(array) >>> x = _infer_config([ants.image_read(ants.get_data('r16')) for _ in range(10)]) - >>> x = _infer_config([{'pattern': '**/*.nii.gz'}, {'pattern': '{id}/anat/*.nii.gz'}], base_dir) + >>> x = _infer_config([{'pattern': '{id}/anat/*.nii.gz'}, {'pattern': '{id}/anat/*.nii.gz'}], base_dir) >>> x = _infer_config({'pattern': '{id}/anat/*.nii.gz'}, base_dir) >>> x = _infer_config({'pattern': '*/anat/*.nii.gz'}, base_dir) >>> x = _infer_config({'pattern': '**/*T1w*'}, base_dir) diff --git a/nitrain/datasets/folder_dataset.py b/nitrain/datasets/folder_dataset.py index 16aac85..99b5468 100644 --- a/nitrain/datasets/folder_dataset.py +++ b/nitrain/datasets/folder_dataset.py @@ -27,8 +27,10 @@ def __init__(self, ------- >>> from nitrain.datasets import FolderDataset >>> dataset = FolderDataset('~/Desktop/openneuro/ds004711', - x={'pattern': '{id}/anat/*T1w.nii.gz', + x=[{'pattern': '{id}/anat/*T1w.nii.gz', 'exclude': '**run-02*'}, + {'pattern': '{id}/anat/*T1w.nii.gz', + 'exclude': '**run-02*'}], y={'file':'participants.tsv', 'column':'age', 'id': 'participant_id'}) """ diff --git a/tests/test_datasets.py b/tests/test_datasets.py index f9f270f..105c87d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -72,6 +72,19 @@ def test_2d(self): x, y = dataset[:2] self.assertTrue(len(x) == 2) + def test_double_image_input(self): + dataset = datasets.FolderDataset( + base_dir=self.tmp_dir, + x=[{'pattern': '*/img2d.nii.gz'},{'pattern': '*/img2d.nii.gz'}], + y={'file': 'participants.csv', 'column': 'age'} + ) + self.assertTrue(len(dataset.x) == 5) + self.assertTrue(len(dataset.x[0]) == 2) + self.assertTrue(len(dataset.y) == 5) + + x, y = dataset[:2] + self.assertTrue(len(x) == 2) + self.assertTrue(len(x[0]) == 2) def test_2d_image_to_image(self): dataset = datasets.FolderDataset(