Description
Describe the issue:
When fitting a correlator using Obs and fits.least_squares directly, I would like to define a single fit function that works regardless of the number of excited-state terms I choose to include. For example, I might use a function such as the one below (or the vectorised form under 'Code example'):
def fcn(p, x):
    """
    (1 / NTerms) * sum_{alpha} A_alpha * (exp(-E_alpha * t) + exp(-E_alpha * (NT - t)))
    """
    # np and NT (the temporal extent) are assumed to be defined at module level
    NTerms = len(p) // 2
    A = p[0:NTerms]
    E_P = p[NTerms:]
    return sum(ai * (np.exp(-Ei * x) + np.exp(-Ei * (NT - x)))
               for ai, Ei in zip(A, E_P)) / len(A)
This does not work. In fits.py, pyerrors determines the number of fit parameters by calling the fit function with trial parameter vectors of increasing length until the call succeeds. My function above never raises the IndexError or TypeError that this probing relies on, so pyerrors settles on the wrong number of parameters and then exits when it checks that number against the number of priors.
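To illustrate the behaviour described above, here is a minimal standalone sketch of such a probing loop. It does not call pyerrors. `probe_n_params`, `fixed_fcn`, `flexible_fcn` and the value of `NT` are hypothetical names chosen for this example, and the loop is only an approximation of what fits.py does. The length-agnostic function uses a plain sum rather than the mean so that the empty trial vector produces no warnings.

```python
import numpy as np

NT = 32  # hypothetical temporal extent, chosen only for this example


def probe_n_params(func, x0, max_params=100):
    """Hypothetical helper: return the smallest trial length that func accepts."""
    for n in range(max_params):
        try:
            func(np.arange(n), x0)
        except (TypeError, IndexError):
            # Too few parameters for this function, try a longer vector.
            continue
        return n
    raise RuntimeError("Fit function is not valid.")


def fixed_fcn(p, x):
    # Indexes p explicitly, so a too-short trial vector raises IndexError.
    return p[0] * (np.exp(-p[1] * x) + np.exp(-p[1] * (NT - x)))


def flexible_fcn(p, x):
    # Length-agnostic: slicing never raises, even for an empty p.
    n_terms = len(p) // 2
    A = np.asarray(p[:n_terms], dtype=float)[:, np.newaxis]  # shape (n, 1)
    E = np.asarray(p[n_terms:], dtype=float)[:, np.newaxis]  # shape (n, 1)
    x = np.atleast_1d(x)[np.newaxis, :]                      # shape (1, m)
    return np.sum(A * (np.exp(-E * x) + np.exp(-E * (NT - x))), axis=0)


n_fixed = probe_n_params(fixed_fcn, 1.0)
n_flexible = probe_n_params(flexible_fcn, 1.0)
print(n_fixed, n_flexible)
```

In this sketch the explicitly indexed function is detected as having two parameters, while the length-agnostic one is accepted with zero parameters, which is the kind of mismatch that later fails against the number of priors.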
My first thought for fixing this (short of rewriting my function, which I would rather not do) is an optional argument giving the number of fit parameters (a single value, or a dict for combined fits). Alternatively, pyerrors could infer the number from the priors and/or the initial parameters.
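Until something like that exists, one possible user-side workaround is to wrap the length-agnostic function so that it raises IndexError for any parameter vector of the wrong length. The sketch below is standalone and uses plain NumPy. `require_n_params`, `probe` and `NT` are hypothetical names, the probing loop only approximates the detection described above, and whether the wrapper works with the rest of least_squares has not been checked here.

```python
import numpy as np

NT = 32  # hypothetical temporal extent, chosen only for this example


def require_n_params(func, n_params):
    """Hypothetical wrapper: reject parameter vectors of the wrong length."""
    def wrapped(p, x):
        if len(p) != n_params:
            raise IndexError(f"expected {n_params} parameters, got {len(p)}")
        return func(p, x)
    return wrapped


def fcn(p, x):
    # Length-agnostic form: first half amplitudes, second half energies.
    n_terms = len(p) // 2
    A = np.asarray(p[:n_terms], dtype=float)[:, np.newaxis]  # shape (n, 1)
    E = np.asarray(p[n_terms:], dtype=float)[:, np.newaxis]  # shape (n, 1)
    x = np.atleast_1d(x)[np.newaxis, :]                      # shape (1, m)
    return np.sum(A * (np.exp(-E * x) + np.exp(-E * (NT - x))), axis=0)


def probe(func, x0, max_params=100):
    """Approximation of the parameter-count detection described in the issue."""
    for n in range(max_params):
        try:
            func(np.arange(n), x0)
        except (TypeError, IndexError):
            continue
        return n
    raise RuntimeError("Fit function is not valid.")


# Two-state fit: two amplitudes and two energies.
two_state = require_n_params(fcn, 4)
n_found = probe(two_state, 1.0)

params = np.array([1.0, 0.5, 0.3, 0.6])
times = np.array([0, 1, 2])
value = two_state(params, times)
print(n_found, value.shape)
```

The original function stays unchanged. Only the wrapper carries the parameter count, so the same `fcn` can be reused for a different number of excited states by wrapping it again with another `n_params`.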
Code example:
def fcn(p, x):
    """
    (1 / NTerms) * sum_{alpha} A_alpha * (exp(-E_alpha * t) + exp(-E_alpha * (NT - t)))
    """
    # np and NT (the temporal extent) are assumed to be defined at module level
    # Assumes first half of terms are A
    # second half are E
    NTerms = len(p) // 2
    A = np.array(p[0:NTerms])[:, np.newaxis]  # shape (n, 1)
    E_P = np.array(p[NTerms:])[:, np.newaxis]  # shape (n, 1)
    # This if statement handles the case where x is a single value rather than an array
    if isinstance(x, (np.float64, np.int64, float, int)):
        x = np.array([x])[np.newaxis, :]  # shape (1, m)
    else:
        x = np.array(x)[np.newaxis, :]  # shape (1, m)
    exp_term = np.exp(-E_P * x) + np.exp(-E_P * (NT - x))  # shape (n, m)
    weighted_sum = A * exp_term  # shape (n, m)
    return np.mean(weighted_sum, axis=0)  # shape (m,)

Error message:
Runtime information:
system Linux
python 3.12.3
pyerrors 2.12.0
numpy 1.26.4
scipy 1.13.0
matplotlib 3.8.4
pandas 2.2.2