Skip to content

Commit

Permalink
add foldername reader
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 23, 2024
1 parent d79ec2b commit fc42ed2
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 6 deletions.
4 changes: 0 additions & 4 deletions nitrain/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,8 @@ def split(self, p, shuffle=False):
self._base_dir,
self._base_file)

#print(f'Before. {len(ds_test.inputs.values)}')
ds_train.inputs = ds_train.inputs.select(train_indices)
ds_train.outputs = ds_train.outputs.select(train_indices)
#
#
#print(f'After. {len(ds_test.inputs)}')
ds_test.inputs = ds_test.inputs.select(test_indices)
ds_test.outputs = ds_test.outputs.select(test_indices)

Expand Down
1 change: 1 addition & 0 deletions nitrain/readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

from .column import ColumnReader
from .compose import ComposeReader
from .folder_name import FolderNameReader
from .image import ImageReader
from .memory import MemoryReader
174 changes: 174 additions & 0 deletions nitrain/readers/folder_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import glob
import os
from parse import parse
from fnmatch import fnmatch
import glob
from google.cloud import storage
from google.oauth2 import service_account

import pandas as pd
import numpy as np
import ants

class FolderNameReader:
"""
Returns the name of the folder for files based on the included pattern.
This is useful if your images are stored in folders whose names have meaning.
For instance:
cat/
img1.jpeg
img2.jpeg
...
dog/
img1.jpeg
img2.jpeg
...
Creating a `FolderNameReader('*/*.jpeg')` would return 'cat' or 'dog' for the
appropriate image. This pairs well with an ImageReader:
dataset = nt.Dataset(ImageReader('*/*.jpeg'), FolderNameReader('*/*.jpeg'))
x, y = dataset[0] # x is img1 loaded in memory; y is 'cat'
"""
def __init__(self, pattern, base_dir=None, exclude=None, label=None, level=0, format='string'):
"""
>>> import ants
>>> from nitrain.readers import ImageReader
>>> reader = FolderNameReader('volumes/*.nii')
>>> reader.map_values(base_dir='~/Desktop/kaggle-liver-ct/')
>>> img = reader[1]
"""
self.pattern = os.path.expanduser(pattern)

if base_dir:
base_dir = os.path.expanduser(base_dir)
self.base_dir = base_dir
self.exclude = exclude
self.label = label
self.level = level
self.format = format

def select(self, idx):
new_reader = FolderNameReader(self.pattern, self.base_dir, self.exclude, self.label, self.level)
new_reader.values = self.values
new_reader.values = [new_reader.values[i] for i in idx]
return new_reader

def map_gcs_values(self, bucket, credentials=None, base_dir=None, base_file=None, base_label=None):
if base_dir is None:
base_dir = self.base_dir

pattern = self.pattern
exclude = self.exclude
level = self.level

glob_pattern = pattern.replace('{id}','*')

if base_dir is not None:
if not base_dir.endswith('/'):
base_dir += '/'
glob_pattern = os.path.join(base_dir, glob_pattern)

# GCS
if isinstance(credentials, str):
credentials = service_account.Credentials.from_service_account_file(credentials)
storage_client = storage.Client(credentials=credentials)
bucket_client = storage_client.bucket(bucket)

x = storage_client.list_blobs(bucket, match_glob=glob_pattern)

x = list([blob.name.replace(base_dir, '') for blob in x])

if exclude:
x = [file for file in x if not fnmatch(file, exclude)]

if '{id}' in pattern:
ids = [parse(pattern.replace('*','{other}'), file).named['id'] for file in x]
else:
ids = None

if len(x) == 0:
raise Exception(f'No filepaths found that match {glob_pattern}')

values = [xx.split('/')[level] for xx in x]
unique_values = np.unique(values)

if self.format == 'integer':
self.values = [np.where(unique_values==v)[0][0] for v in values]
elif self.format == 'onehot':
self.values = [list(np.eye(len(unique_values),
dtype='uint32')[np.where(unique_values==v)[0][0]]) for v in values]
elif self.format == 'string':
self.values = values
else:
raise Exception('The format value must be `integer`, `onehot`, or `string`.')

self.unique_values = list(unique_values)
self.ids = ids

if self.label is None:
if base_label is not None:
self.label = base_label
else:
self.label = 'pattern'

def map_values(self, base_dir=None, base_label=None, **kwargs):
if base_dir is None:
base_dir = self.base_dir

pattern = self.pattern
exclude = self.exclude
level = self.level

glob_pattern = pattern.replace('{id}','*')

if base_dir is not None:
if not base_dir.endswith('/'):
base_dir += '/'
glob_pattern = os.path.join(base_dir, glob_pattern)

x = sorted(glob.glob(glob_pattern, recursive=True))

if base_dir is not None:
x = [os.path.relpath(xx, base_dir) for xx in x]

if exclude:
x = [file for file in x if not fnmatch(file, exclude)]

if '{id}' in pattern:
ids = [parse(pattern.replace('*','{other}'), file).named['id'] for file in x]
else:
ids = None

if len(x) == 0:
raise Exception(f'No filepaths found that match {glob_pattern}')

values = [xx.split('/')[level] for xx in x]
unique_values = np.unique(values)

if self.format == 'integer':
self.values = [np.where(unique_values==v)[0][0] for v in values]
elif self.format == 'onehot':
self.values = [list(np.eye(len(unique_values),
dtype='uint32')[np.where(unique_values==v)[0][0]]) for v in values]
elif self.format == 'string':
self.values = values
else:
raise Exception('The format value must be `integer`, `onehot`, or `string`.')

self.unique_values = list(unique_values)
self.ids = ids

if self.label is None:
if base_label is not None:
self.label = base_label
else:
self.label = 'folder_name'

def __getitem__(self, idx):
return {self.label: self.values[idx]}

def __len__(self):
return len(self.values)
68 changes: 68 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,74 @@ def test_pattern_compose(self):
self.assertEqual(x2[0].mean(), 9)
self.assertEqual(y2.mean(), 109)

class TestReader_FolderNameReader(unittest.TestCase):

def setUp(self):
pass
def tearDown(self):
pass

def test_folder_name(self):
import nitrain as nt
from nitrain import readers
base_dir = nt.fetch_data('example-01')

dataset = nt.Dataset(inputs=readers.ImageReader('*/img3d.nii.gz'),
outputs=readers.FolderNameReader('*/img3d_100.nii.gz'),
base_dir=base_dir)

data_train, data_test = dataset.split(0.8)

self.assertEqual(len(data_train), 8)
self.assertEqual(len(data_test), 2)

x,y=data_train[3]
self.assertEqual(x.mean(), 4)
self.assertEqual(y, 'sub_3')

def test_folder_name_formats(self):
import nitrain as nt
from nitrain import readers
base_dir = nt.fetch_data('example-01')

dataset = nt.Dataset(inputs=readers.ImageReader('*/img3d.nii.gz'),
outputs=readers.FolderNameReader('*/img3d_100.nii.gz',
format='integer'),
base_dir=base_dir)

x,y=dataset[3]
self.assertEqual(x.mean(), 4)
self.assertEqual(y, 3)

dataset = nt.Dataset(inputs=readers.ImageReader('*/img3d.nii.gz'),
outputs=readers.FolderNameReader('*/img3d_100.nii.gz',
format='onehot'),
base_dir=base_dir)

x,y=dataset[3]
self.assertEqual(x.mean(), 4)
self.assertEqual(len(y), 10)
self.assertEqual(y[3], 1)
self.assertEqual(sum(y), 1)

def test_folder_name_compose(self):
import nitrain as nt
from nitrain import readers
base_dir = nt.fetch_data('example-01')

dataset = nt.Dataset(inputs=readers.ImageReader('*/img3d.nii.gz'),
outputs=[readers.FolderNameReader('*/img3d.nii.gz'),
readers.FolderNameReader('*/img3d.nii.gz')],
base_dir=base_dir)

data_train, data_test = dataset.split(0.8)

x,y=dataset[3]
self.assertEqual(x.mean(), 4)
self.assertEqual(y[0], 'sub_3')
self.assertEqual(y[1], 'sub_3')


class TestOther_Bugs(unittest.TestCase):
def setUp(self):
pass
Expand Down
2 changes: 0 additions & 2 deletions tests/test_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,6 @@ def test_nested_lists(self):





class TestFunction_infer_reader_dicts(unittest.TestCase):
def setUp(self):
pass
Expand Down

0 comments on commit fc42ed2

Please sign in to comment.