Skip to content

Commit

Permalink
Adds clip_negative argument to normalize_spatial
Browse files Browse the repository at this point in the history
Used for fixing ellipse fits to 2D spatial RFs with strong surrounds.
Closes #66
  • Loading branch information
Niru Maheswaranathan committed Dec 22, 2015
1 parent 6ff6ee8 commit ea945eb
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions pyret/filtertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def rolling_window(array, window, time_axis=0):
return arr


def normalize_spatial(spatial_filter, scale_factor=1.0):
def normalize_spatial(spatial_filter, scale_factor=1.0, clip_negative=False):
"""
Normalizes a spatial frame by doing the following:
1. mean subtraction using a robust estimate of the mean (ignoring outliers)
Expand All @@ -451,7 +451,10 @@ def normalize_spatial(spatial_filter, scale_factor=1.0):
scale_factor : float, optional
The given filter is resampled at a sampling rate of this ratio times
the original sampling rate (default: 1.0)
the original sampling rate (Default: 1.0)
clip_negative : boolean, optional
Whether or not to clip negative values to 0. (Default: True)
"""

Expand All @@ -470,8 +473,14 @@ def normalize_spatial(spatial_filter, scale_factor=1.0):
# normalize by the standard deviation of the pixel values
rf_centered /= rf_centered.std()

# return this normalized filter, resampled by the given amount
return resample(rf_centered, scale_factor)
# resample by the given amount
rf_resampled = resample(rf_centered, scale_factor)

# clip negative values
if clip_negative:
rf_resampled = np.maximum(rf_resampled, 0)

return rf_resampled


def get_contours(spatial_filter, threshold=10.0):
Expand Down Expand Up @@ -560,8 +569,8 @@ def get_ellipse(tx, ty, spatial_filter, pvalue=0.6827):
"""

# preprocess
zdata = normalize_spatial(spatial_filter).ravel()
zdata /= np.max(zdata)
zdata = normalize_spatial(spatial_filter, clip_negative=True).ravel()
zdata /= zdata.max()

# get initial parameters
xm, ym = np.meshgrid(tx, ty)
Expand Down

0 comments on commit ea945eb

Please sign in to comment.