-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updates to nonlinearity module. Removed all old unused functions,
changed to using a class-based system. Users create a nonlinearity of the given type, then call fit() to learn its parameters from data and predict() to evaluate the fitted nonlinearity at new values.
- Loading branch information
Showing
1 changed file
with
133 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,194 +1,148 @@ | ||
""" | ||
Tools for fitting nonlinear functions to data | ||
""" | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from scipy.optimize import curve_fit | ||
from scipy.interpolate import interp1d | ||
from functools import wraps | ||
from itertools import zip_longest | ||
|
||
__all__ = ['gaussian', 'sigmoid', 'dprime', 'fitgaussian', | ||
'fitsigmoid', 'estdprime', 'estnln'] | ||
|
||
|
||
def gaussian(x, mu, sigma):
    """
    Evaluate a 1D gaussian bump that is not normalized (peak height of 1)

    Parameters
    ----------
    x : float or ndarray
        Point(s) at which to evaluate the function

    mu : float
        Location of the peak

    sigma : float
        Width of the bump

    Returns
    -------
    y : float or ndarray
        ``exp(-0.5 * ((x - mu) / sigma)**2)``
    """
    # standardize the input first, then apply the exponential
    z = (x - mu) / sigma
    return np.exp(-0.5 * z**2)
|
||
|
||
def sigmoid(x, threshold, slope, peak, offset):
    """
    Evaluate a sigmoidal (logistic) nonlinearity

    Parameters
    ----------
    x : float or ndarray
        Point(s) at which to evaluate the function

    threshold : float
        Input value at which the sigmoid reaches half of `peak` above `offset`

    slope : float
        Gain applied to ``x - threshold`` inside the exponential

    peak : float
        Height of the sigmoid above `offset`

    offset : float
        Constant added to the output

    Returns
    -------
    y : float or ndarray
        ``offset + peak / (1 + exp(-slope * (x - threshold)))``
    """
    denominator = 1 + np.exp(-slope * (x - threshold))
    return offset + peak / denominator
|
||
|
||
def dprime(p0, p1):
    """
    Compute d' between two distributions given their mean / standard deviation

    Parameters
    ----------
    p0 : (float, float)
        Mean and standard deviation for the first distribution

    p1 : (float, float)
        Mean and standard deviation for the second distribution

    Returns
    -------
    d : float
        Difference of the means (second minus first) divided by the square
        root of the sum of the two variances
    """
    mean_difference = p1[0] - p0[0]
    pooled_spread = np.sqrt(p1[1]**2 + p0[1]**2)
    return mean_difference / pooled_spread
|
||
|
||
def fitgaussian(xpts, ypts, p0=None):
    """
    Fit an (unnormalized) gaussian function to noisy data

    The y-values are divided by their maximum before fitting, so that the
    unit-height `gaussian` function can describe them; the returned fit is
    scaled back to the original units.

    Parameters
    ----------
    xpts : array_like
        x-values of the data to fit

    ypts : array_like
        y-values of the data to fit

    p0 : (float, float), optional
        Initial guess for the gaussian parameters (mu, sigma). If None, the
        mean of `xpts` and five times the mean spacing of `xpts` are used.

    Returns
    -------
    popt : array_like
        The best-fit gaussian parameters (mu, sigma)

    yhat : array_like
        The estimated y-values at the given locations in xpts

    pcov : array_like
        The estimated covariance of popt, as returned by
        `scipy.optimize.curve_fit`

    Raises
    ------
    ValueError
        If the maximum of `ypts` is zero, so the data cannot be normalized
    RuntimeError
        If `scipy.optimize.curve_fit` fails to converge
    """

    # accept lists / tuples as well as arrays
    xpts = np.asarray(xpts, dtype=float)
    ypts = np.asarray(ypts, dtype=float)

    # estimate initial conditions
    if p0 is None:
        p0 = (np.mean(xpts), 5 * np.mean(np.diff(xpts)))

    # normalize the max to have value 1
    scalefactor = float(np.max(ypts))
    if scalefactor == 0:
        # dividing by zero would only produce NaNs that curve_fit rejects later
        raise ValueError('ypts must have a nonzero maximum to be normalized')
    ypts = ypts / scalefactor

    # get parameters
    popt, pcov = curve_fit(gaussian, xpts, ypts, p0)

    # evaluate fit (undoing the normalization)
    yhat = gaussian(xpts, *popt) * scalefactor

    return popt, yhat, pcov
|
||
|
||
def fitsigmoid(xpts, ypts, **kwargs):
    """
    Fit a sigmoidal function to noisy data

    Parameters
    ----------
    xpts : array_like
        x-values of the data to fit

    ypts : array_like
        y-values of the data to fit

    kwargs
        Optional keyword arguments are passed to `scipy.optimize.curve_fit`, and
        can be used to control the fitting procedure more carefully. This may be
        needed, e.g., if the nonlinearities are quite noisy. If `p0` is given
        here it replaces the automatically estimated initial conditions.

    Returns
    -------
    popt : array_like
        The best-fit sigmoidal parameters (threshold, slope, peak, and offset)

    yhat : array_like
        The estimated y-values at the given locations in xpts

    pcov : array_like
        The estimated covariance of popt, as returned by
        `scipy.optimize.curve_fit`
    """

    # accept lists / tuples as well as arrays
    xpts = np.asarray(xpts, dtype=float)
    ypts = np.asarray(ypts, dtype=float)

    # estimate initial conditions. The sigmoid spans [offset, offset + peak],
    # so the offset guess is the smallest y-value and the peak guess is the
    # range of the data (not its maximum)
    ymin = np.min(ypts)
    ymax = np.max(ypts)
    kwargs.setdefault('p0', (np.mean(xpts), 1, ymax - ymin, ymin))

    # get parameters
    popt, pcov = curve_fit(sigmoid, xpts, ypts, **kwargs)

    # evaluate fit
    yhat = sigmoid(xpts, *popt)

    return popt, yhat, pcov
|
||
|
||
def estdprime(u, r, numbins=100):
    """
    Estimate d' between the raw and the spike-triggered stimulus distributions

    Both distributions are histogrammed on a common set of bins and a
    gaussian is fit to each histogram; d' is computed from the two sets of
    fitted (mean, standard deviation) parameters. If the gaussian fit does
    not converge, the sample mean and standard deviation are used instead.

    Parameters
    ----------
    u : array_like
        The 1D stimulus projection

    r : array_like
        The spiking response (same shape as u); entries greater than zero
        select the samples that make up the spike-triggered distribution

    numbins : int, optional
        Number of evenly spaced bin edges spanning the range of u, giving
        ``numbins - 1`` bins (Default: 100)

    Returns
    -------
    d : float
        The d' value between the raw and spike-triggered distributions

    Raises
    ------
    ValueError
        If no entry of r is greater than zero
    """
    import warnings

    # accept lists / tuples as well as arrays (needed for boolean indexing)
    u = np.asarray(u)
    r = np.asarray(r)

    # pick a set of bins, store centered bins
    bins = np.linspace(np.min(u), np.max(u), numbins)
    bincenters = bins[:-1] + np.mean(np.diff(bins)) * 0.5

    # bin the raw stimulus distribution
    raw, _ = np.histogram(u, bins)

    # bin the spike-triggered distribution
    data = u[r > 0]
    if data.size == 0:
        raise ValueError('r contains no positive entries, cannot estimate '
                         'the spike-triggered distribution')
    spk, _ = np.histogram(data, bins)

    # sample moments: used to initialize the fits, and as the fallback
    raw_moments = (np.mean(u), np.std(u))
    spk_moments = (np.mean(data), np.std(data))

    # estimate gaussian parameters
    try:
        raw_params = fitgaussian(bincenters, raw, raw_moments)[0]
        spk_params = fitgaussian(bincenters, spk, spk_moments)[0]
    except RuntimeError:
        warnings.warn('Gaussian curve fit did not converge',
                      RuntimeWarning, stacklevel=2)
        raw_params = raw_moments
        spk_params = spk_moments

    # estimate d'
    return dprime(raw_params, spk_params)
|
||
|
||
def estnln(u, r, numbins=50):
    """
    Fit a sigmoidal nonlinearity given a 1D stimulus projection u and spiking response r

    The nonlinearity is estimated as the ratio of the (normalized)
    spike-triggered stimulus histogram to the (normalized) raw stimulus
    histogram, restricted to bins that contain enough samples, and a sigmoid
    is then fit to that ratio.

    Parameters
    ----------
    u : array_like
        The 1D stimulus projection

    r : array_like
        The spiking response (same shape as u); entries greater than zero
        select the samples that make up the spike-triggered distribution

    numbins : int, optional
        Number of histogram bins (Default: 50)

    Returns
    -------
    popt : array_like
        The best-fit sigmoidal parameters returned by `fitsigmoid`

    xvals : array_like
        Centers of the bins that were kept for fitting

    ratio : array_like
        Ratio of the spike-triggered to the raw distribution in those bins

    yhat : array_like
        The fitted sigmoid evaluated at xvals

    pcov : array_like
        The parameter covariance returned by `fitsigmoid`

    Raises
    ------
    ValueError
        If no bin contains enough samples in both histograms
    """

    # bins must contain more than this many data points to be kept for fitting
    mincount = 2

    # accept lists / tuples as well as arrays (needed for boolean indexing)
    u = np.asarray(u)
    r = np.asarray(r)

    # bin the raw stimulus distribution
    raw, bins = np.histogram(u, numbins)

    # bin the spike-triggered distribution
    spk, _ = np.histogram(u[r > 0], bins)

    # find locations where there are enough data points
    locs = np.logical_and((raw > mincount), (spk > mincount))
    if not np.any(locs):
        raise ValueError('no bin contains more than %d samples in both the '
                         'raw and spike-triggered histograms' % mincount)

    # normalize the two distributions
    raw = raw / float(np.sum(raw))
    spk = spk / float(np.sum(spk))

    # take the ratio of the two distributions
    ratio = spk[locs] / raw[locs]

    # np.histogram returns numbins + 1 edges but only numbins counts, so the
    # edges cannot be indexed with `locs` directly; use the bin centers
    centers = 0.5 * (bins[:-1] + bins[1:])
    xvals = centers[locs]

    # fit a sigmoid to the results
    popt, yhat, pcov = fitsigmoid(xvals, ratio)

    # return values
    return popt, xvals, ratio, yhat, pcov


__all__ = ['Sigmoid', 'Binterp']


class Nonlinearity:
    """Base class for nonlinearities

    Subclasses implement `fit` (learn the nonlinearity from data) and
    `predict` (evaluate it at new inputs). Calling an instance is equivalent
    to calling `predict`.
    """

    def __init__(self):
        pass

    def plot(self, start, stop, n=100):
        """Plots the nonlinearity over an interval

        Parameters
        ----------
        start : float
            Lower end of the interval

        stop : float
            Upper end of the interval

        n : int, optional
            Number of evenly spaced points to evaluate (Default: 100)
        """
        x = np.linspace(start, stop, n)
        plt.plot(x, self.predict(x))

    def fit(self, x, y):
        """Fits the parameters of the nonlinearity

        Parameters
        ----------
        x : array_like
            input to the nonlinearity

        y : array_like
            output of the nonlinearity (must have the same shape as x)
        """
        raise NotImplementedError

    def predict(self, x):
        """Computes the value of the function at the given input

        Parameters
        ----------
        x : array_like
            The input to the nonlinearity

        Returns
        -------
        y : array_like
            The output of the nonlinearity
        """
        raise NotImplementedError

    # wraps copies the name and docstring of predict onto __call__
    @wraps(predict)
    def __call__(self, x):
        return self.predict(x)
||
|
||
class Sigmoid(Nonlinearity):
    def __init__(self, baseline=0., peak=1., slope=1., threshold=0.):
        r"""A sigmoidal nonlinearity

        Estimates a nonlinearity of the following form:

        .. math:: f(x) = \beta + \frac{\alpha}{(1 + \exp(-\gamma * (x - \theta)))}

        The constructor arguments are the initial guess handed to the
        optimizer when `fit` is called.

        Usage
        -----
        >>> f = Sigmoid().fit(x_train, y_train)
        >>> yhat = f.predict(x_test) # f(x_test) works as well

        Parameters
        ----------
        baseline : float
            y-offset (baseline)

        peak : float
            maximum response

        slope : float
            gain of the sigmoid

        threshold : float
            midpoint of the sigmoid
        """
        super().__init__()

        # curve_fit passes the initial guess positionally, so it must follow
        # the argument order of _sigmoid: (threshold, slope, peak, baseline)
        self.init_params = (threshold, slope, peak, baseline)

    def fit(self, x, y, **kwargs):
        """Fits the sigmoid parameters to data using scipy.optimize.curve_fit

        Parameters
        ----------
        x : array_like
            input to the nonlinearity

        y : array_like
            output of the nonlinearity (must have the same shape as x)

        kwargs
            Optional keyword arguments passed on to `scipy.optimize.curve_fit`

        Returns
        -------
        self : Sigmoid
            The fitted object; the estimates are stored in `params`, ordered
            as (threshold, slope, peak, baseline), and their covariance in
            `pcov`
        """
        self.params, self.pcov = curve_fit(self._sigmoid, x, y, self.init_params, **kwargs)
        return self

    @staticmethod
    def _sigmoid(x, threshold, slope, peak, baseline):
        """Evaluates the sigmoid at x for the given parameters"""
        return baseline + peak / (1 + np.exp(-slope * (x - threshold)))

    def predict(self, x):
        """Computes the value of the fitted sigmoid at the given input

        Parameters
        ----------
        x : array_like
            The input to the nonlinearity

        Returns
        -------
        y : array_like
            The output of the nonlinearity

        Raises
        ------
        RuntimeError
            If `fit` has not been called yet
        """
        # a missing attribute raises AttributeError (not NameError); only the
        # lookup is guarded so errors from the evaluation itself are not masked
        try:
            params = self.params
        except AttributeError:
            raise RuntimeError('No estimated parameters, call fit() first') from None
        return self._sigmoid(x, *params)
|
||
|
||
class Binterp(Nonlinearity):
    def __init__(self, nbins, method='linear', fill_value='extrapolate'):
        """Interpolated nonlinearity by sorting and binning the data

        Given samples (x, y) from the nonlinearity, bin the values using
        variable-sized bins with roughly equal counts, and then interpolates
        between the mean y-value in each bin using scipy.interpolate.interp1d.

        Parameters
        ----------
        nbins : int
            How many bins to use along the input axis. The samples per bin
            are computed as ``int(x.size / nbins)`` (at least 1), so slightly
            more than `nbins` bins are produced when the number of samples
            is not a multiple of `nbins`.

        method : str, optional
            How to do the interpolation (Default: 'linear'). Possible values: 'linear',
            'quadratic', 'cubic', 'nearest', 'slinear', 'zero'. See scipy.interpolate.interp1d
            for details.

        fill_value : str or value, optional
            How to fill in values outside the range of bins (Default: 'extrapolate')
            Note: 'extrapolate' only works for the 'linear' or 'nearest' methods,
            see scipy.interpolate.interp1d for details
        """
        super().__init__()
        self.nbins = nbins
        self.method = method
        self.fill_value = fill_value

    @staticmethod
    def _grouper(iterable, n, fillvalue=None):
        """Collect data into fixed-length chunks or blocks

        Retained for backward compatibility; `fit` groups the data with a
        vectorized reshape instead.
        """
        # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
        args = [iter(iterable)] * n
        return np.array(list(zip_longest(*args, fillvalue=fillvalue)))

    def fit(self, x, y):
        """Bins the samples and builds the interpolating function

        Parameters
        ----------
        x : array_like
            input to the nonlinearity

        y : array_like
            output of the nonlinearity (must have the same size as x)

        Returns
        -------
        self : Binterp
            The fitted object. `bins` holds the smallest x-value of each bin,
            `values` the mean y-value of each bin, and `predict` is replaced
            by the scipy.interpolate.interp1d interpolant through them.

        Raises
        ------
        ValueError
            If x and y do not contain the same number of samples
        """
        # accept lists / tuples as well as arrays
        x = np.asarray(x).ravel()
        y = np.asarray(y, dtype=float).ravel()
        if x.size != y.size:
            raise ValueError('x and y must contain the same number of samples')

        # at least one sample per bin, otherwise the slice step below is zero
        binsize = max(1, int(x.size / self.nbins))

        # sort the x-values and create variable bin edges with equal counts
        indices = np.argsort(x)
        self.bins = x[indices][::binsize]

        # pad the sorted y-values with NaN up to a multiple of binsize, so
        # each row of the reshaped array is one bin (NaNs are ignored below)
        y_sorted = y[indices]
        padding = np.full((-y_sorted.size) % binsize, np.nan)
        y_grouped = np.concatenate((y_sorted, padding)).reshape(-1, binsize)
        self.values = np.nanmean(y_grouped, axis=1)

        # set the predict function using scipy.interpolate.interp1d
        self.predict = interp1d(self.bins, self.values, kind=self.method, fill_value=self.fill_value)
        return self

    def predict(self, x):
        """Placeholder that is replaced by the interpolant once fit() is called"""
        raise RuntimeError('No estimated parameters, call fit() first')