Skip to content

Commit

Permalink
new get_studies() API
Browse files Browse the repository at this point in the history
  • Loading branch information
tyarkoni committed Jan 1, 2015
1 parent d597429 commit 5a8883c
Showing 1 changed file with 148 additions and 15 deletions.
163 changes: 148 additions & 15 deletions neurosynth/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

logger = logging.getLogger('neurosynth.dataset')


def download(path='.', url=None, unpack=False):
""" Download the latest data files.
Args:
Expand All @@ -37,18 +38,20 @@ def download(path='.', url=None, unpack=False):
u = urllib2.urlopen(url)
meta = u.info()
file_size = int(meta.getheaders("Content-Length")[0])
print("Downloading the latest Neurosynth files: {0} bytes: {1}".format(url, file_size))
print("Downloading the latest Neurosynth files: {0} bytes: {1}".format(
url, file_size))

bytes_dl = 0
block_size = 8192
while True:
buffer = u.read(block_size)
if not buffer: break
if not buffer:
break
bytes_dl += len(buffer)
f.write(buffer)
p = float(bytes_dl) / file_size
status = r"{0} [{1:.2%}]".format(bytes_dl, p)
status = status + chr(8)*(len(status)+1)
status = status + chr(8) * (len(status) + 1)
sys.stdout.write(status)

f.close()
Expand Down Expand Up @@ -218,6 +221,136 @@ def delete_mappables(self, ids, remap=True):
if remap:
self.image_table = self.create_image_table()

def get_studies(self, features=None, expression=None, mask=None,
peaks=None, frequency_threshold=0.001,
activation_threshold=0.0, func=np.sum, return_type='ids',
r=6
):
""" Get IDs or data for studies that meet specific criteria.
If multiple criteria are passed, the set intersection is returned. For
example, passing expression='emotion' and mask='my_mask.nii.gz' would
return only those studies that are associated with emotion AND report
activation within the voxels indicated in the passed image.
Args:
ids (list): A list of IDs of studies to retrieve.
features (list or str): The name of a feature, or a list of
features, to use for selecting studies.
expression (str): A string expression to pass to the PEG for study
retrieval.
mask: the mask image (see Masker documentation for valid data
types).
peaks (ndarray or list): Either an n x 3 numpy array, or a list of
lists (e.g., [[-10, 22, 14]]) specifying the world (x/y/z)
coordinates of the target location(s).
frequency_threshold (float): For feature-based or expression-based
selection, the threshold for selecting studies--i.e., the
cut-off for a study to be included. Must be a float in range
[0, 1].
activation_threshold (int or float): For mask-based or peak-based
selection, threshold for a study to be included based on
amount of activation displayed. If an integer, represents the
absolute number of voxels that must be active within the mask
(or generated ROIs) in order for a study to be selected. If a
float, it represents the proportion of voxels that must be
active.
func (Callable): The function to use when aggregating over the list
of features. See documentation in FeatureTable.get_ids() for a
full explanation. Only used for feature- or expression-based
selection.
return_type (str): A string specifying what data to return. Valid
options are:
'ids': returns a list of IDs of selected studies.
'images': returns a voxel x mappable matrix of data for all
selected studies.
'weights': returns a dict where the keys are study IDs and the
values are the computed weights. Only valid when performing
feature-based selection.
r (int): For peak-based selection, the radius in millimeters of the
sphere to grow around each peak.
Returns:
When return_type is 'ids' (default), returns a list of IDs of the
selected studies. When return_type is 'data', returns a 2D numpy
array, with voxels in rows and studies in columns. When return_type
is 'weights' (valid only for expression-based selection), returns
a dict, where the keys are study IDs, and the values are the
computed weights.
Examples
--------
Select all studies tagged with the feature 'emotion':
>>> ids = dataset.get_studies(features='emotion')
Select all studies that activate at least 20% of voxels in an amygdala
mask, and retrieve activation data rather than IDs:
>>> data = dataset.get_studies(mask='amygdala_mask.nii.gz',
threshold=0.2, return_type='images')
Select studies that activate at least 5% of all voxels within 12 mm of
three specific foci:
>>> ids = dataset.get_studies(peaks=[[12, -20, 30], [-26, 22, 22], [0, 36, -20]], r=12)
"""
results = []

# Feature-based selection
if features is not None:
# Need to handle weights as a special case, because we can't
# retrieve the weights later using just the IDs.
if return_type == 'weights':
if expression is not None or mask is not None or \
peaks is not None:
raise ValueError("return_type cannot be 'weights' when "
"feature-based search is used in conjunction with "
"other search modes.")
return self.feature_table.get_ids(
features, frequency_threshold, func, get_weights=True)
else:
results.append(self.feature_table.get_ids(
features, frequency_threshold, func))

# Logical expression-based selection
if expression is not None:
_ids = self.feature_table.get_ids_by_expression(
expression, frequency_threshold, func)
results.append(list(_ids))

# Mask-based selection
if mask is not None:
mask = self.masker.mask(mask).astype(bool)
num_vox = np.sum(mask)
prop_mask_active = self.image_table.data.T.dot(mask).astype(float)
if isinstance(activation_threshold, float):
prop_mask_active /= num_vox
indices = np.where(prop_mask_active > activation_threshold)[0]
results.append([self.image_table.ids[ind] for ind in indices])

# Peak-based selection
if peaks is not None:
peaks = np.array(peaks) # Make sure we have a numpy array
peaks = transformations.xyz_to_mat(peaks)
m = self.masker
img = imageutils.map_peaks_to_image(
peaks, r, vox_dims=m.vox_dims, dims=m.dims,
header=m.get_header())
results.append(self.get_studies(mask=img,
activation_threshold=activation_threshold))

# Get intersection of all sets
ids = list(reduce(lambda x, y: set(x) & set(y), results))

if return_type == 'ids':
return ids
elif return_type == 'data':
return self.get_image_data(ids)

@deprecated("get_mappables() is deprecated and will be removed in 0.5. "
"Please use get_studies().")
def get_mappables(self, ids, get_image_data=False):
""" Takes a list of unique ids and returns corresponding Mappables.
Expand All @@ -235,8 +368,8 @@ def get_mappables(self, ids, get_image_data=False):
else:
return [m for m in self.mappables if m.id in ids]

@deprecated("get_ids_by_features() is deprecated and will be removed in " \
"0.5. Please use get_studies(features=...).")
@deprecated("get_ids_by_features() is deprecated and will be removed in "
"0.5. Please use get_studies(features=...).")
def get_ids_by_features(self, features, threshold=0.001, func=np.sum,
get_image_data=False, get_weights=False):
""" A wrapper for FeatureTable.get_ids().
Expand All @@ -255,16 +388,16 @@ def get_ids_by_features(self, features, threshold=0.001, func=np.sum,
features, threshold, func, get_weights)
return self.get_image_data(ids) if get_image_data else ids

@deprecated("get_ids_by_expression() is deprecated and will be removed in " \
"0.5. Please use get_studies(expression=...).")
@deprecated("get_ids_by_expression() is deprecated and will be removed in "
"0.5. Please use get_studies(expression=...).")
def get_ids_by_expression(self, expression, threshold=0.001, func=np.sum,
get_image_data=False):
ids = self.feature_table.get_ids_by_expression(
expression, threshold, func)
return self.get_image_data(ids) if get_image_data else ids

@deprecated("get_ids_by_mask() is deprecated and will be removed in " \
"0.5. Please use get_studies(mask=...).")
@deprecated("get_ids_by_mask() is deprecated and will be removed in "
"0.5. Please use get_studies(mask=...).")
def get_ids_by_mask(self, mask, threshold=0.0, get_image_data=False):
""" Return all mappable objects that activate within the bounds
defined by the mask image.
Expand All @@ -289,8 +422,8 @@ def get_ids_by_mask(self, mask, threshold=0.0, get_image_data=False):
else:
return [self.image_table.ids[ind] for ind in indices]

@deprecated("get_ids_by_peaks() is deprecated and will be removed in " \
"0.5. Please use get_studies(peaks=...).")
@deprecated("get_ids_by_peaks() is deprecated and will be removed in "
"0.5. Please use get_studies(peaks=...).")
def get_ids_by_peaks(self, peaks, r=10, threshold=0.0,
get_image_data=False):
""" A wrapper for get_ids_by_mask. Takes a set of xyz coordinates and
Expand Down Expand Up @@ -541,8 +674,8 @@ def __init__(self, dataset, **kwargs):
if kwargs:
self.add_features(features, **kwargs)

def add_features(self, features, merge='outer', duplicates='ignore',
min_studies=0, threshold=0.0001):
def add_features(self, features, merge='outer', duplicates='ignore',
min_studies=0, threshold=0.0001):
""" Add new features to FeatureTable.
Args:
features (str, DataFrame): A filename to load data from, or a
Expand Down Expand Up @@ -653,7 +786,7 @@ def get_ordered_names(self, features):
np.in1d(self.data.columns.values, np.array(features)))[0]
return list(self.data.columns[idxs].values)

def get_ids(self, features, threshold=None, func=np.sum, get_weights=False):
def get_ids(self, features, threshold=0.0, func=np.sum, get_weights=False):
""" Returns a list of all Mappables in the table that meet the desired
feature-based criteria.
Expand Down Expand Up @@ -719,7 +852,7 @@ def get_ids_by_expression(self, expression, threshold=0.001, func=np.sum):
parser = lp.Parser(
lexer, self.dataset, threshold=threshold, func=func)
parser.build()
return parser.parse(expression).keys()
return parser.parse(expression).keys().values

def get_features_by_ids(self, ids=None, threshold=0.0001, func=np.mean,
get_weights=False):
Expand Down

0 comments on commit 5a8883c

Please sign in to comment.