Skip to content

Commit

Permalink
Merge pull request #166 from kyleabeauchamp/hesswarning
Browse files Browse the repository at this point in the history
Fix warnings
  • Loading branch information
kyleabeauchamp committed Jan 23, 2015
2 parents f401e44 + e649c0b commit b378f06
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pymbar/mbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def __init__(self, u_kn, N_k, maximum_iterations=10000, relative_tolerance=1.0e-
print(self.f_k)

self.f_k = mbar_solvers.solve_mbar_with_subsampling(self.u_kn, self.N_k, self.f_k, solver_protocol, subsampling_protocol, subsampling, x_kindices=self.x_kindices)
self.Log_W_nk = np.log(mbar_solvers.mbar_W_nk(self.u_kn, self.N_k, self.f_k))
self.Log_W_nk = mbar_solvers.mbar_log_W_nk(self.u_kn, self.N_k, self.f_k)

# Print final dimensionless free energies.
if self.verbose:
Expand Down
35 changes: 30 additions & 5 deletions pymbar/mbar_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,33 @@ def mbar_hessian(u_kn, N_k, f_k):
return -1.0 * H


def mbar_log_W_nk(u_kn, N_k, f_k):
"""Calculate the log weight matrix.
Parameters
----------
u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float'
The reduced potential energies, i.e. -log unnormalized probabilities
N_k : np.ndarray, shape=(n_states), dtype='int'
The number of samples in each state
f_k : np.ndarray, shape=(n_states), dtype='float'
The reduced free energies of each state
Returns
-------
logW_nk : np.ndarray, dtype='float', shape=(n_samples, n_states)
The normalized log weights.
Notes
-----
Equation (9) in JCP MBAR paper.
"""
u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k)

log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1)
logW = f_k - u_kn.T - log_denominator_n[:, np.newaxis]
return logW

def mbar_W_nk(u_kn, N_k, f_k):
"""Calculate the weight matrix.
Expand All @@ -200,11 +227,7 @@ def mbar_W_nk(u_kn, N_k, f_k):
-----
Equation (9) in JCP MBAR paper.
"""
u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k)

log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1)
W = np.exp(f_k - u_kn.T - log_denominator_n[:, np.newaxis])
return W
return np.exp(mbar_log_W_nk(u_kn, N_k, f_k))


def precondition_u_kn(u_kn, N_k, f_k):
Expand Down Expand Up @@ -291,6 +314,8 @@ def solve_mbar_once(u_kn_nonzero, N_k_nonzero, f_k_nonzero, method="hybr", tol=1
hess = lambda x: mbar_hessian(u_kn_nonzero, N_k_nonzero, pad(x))[1:][:, 1:] # Hessian of objective function

if method in ["L-BFGS-B", "dogleg", "CG", "BFGS", "Newton-CG", "TNC", "trust-ncg", "SLSQP"]:
if method in ["L-BFGS-B", "CG"]:
hess = None # To suppress warning from passing a hessian function.
results = scipy.optimize.minimize(grad_and_obj, f_k_nonzero[1:], jac=True, hess=hess, method=method, tol=tol, options=options)
else:
results = scipy.optimize.root(grad, f_k_nonzero[1:], jac=hess, method=method, tol=tol, options=options)
Expand Down

0 comments on commit b378f06

Please sign in to comment.