
Parsing arguments other than x into log-posterior probability functions #93

Closed
bjricketts opened this issue Feb 9, 2023 · 5 comments · Fixed by #116
@bjricketts

Hi guys, thanks for the really cool code that I'm currently hoping to implement into my own workflow!

Is it possible to pass arguments other than x into log-posterior probability functions? This is particularly important when sampling a distribution that results from comparing observed data to a generated model that depends on the sampled parameters.

As an example, the docs specify that a potential target distribution might be:

def log_posterior(x):
    return -0.5 * jnp.sum(x ** 2)

but it seems impossible to write a posterior function that compares the model to some observed data. As a very simple example:

import numpy as np
import jax.numpy as jnp

n_dim = 2

data_x = np.arange(1,10)
m = 2
c = 5
observation = m*data_x + c

def log_posterior(x, observation):
    x_model = np.arange(1,10)
    m = x[0]
    c = x[1]
    y = m*x_model+c
    gaussian = -((y - observation) ** 2) / (2 * np.sqrt(observation) ** 2)
    return jnp.sum(gaussian)

Is this a possible posterior function that just isn't documented? I appreciate that it would hypothetically be possible to "generate" or load the observed data within the posterior function, but this becomes much less practical when generating the data is computationally expensive and it makes sense to do that operation only once.

@kazewong
Owner

kazewong commented Feb 9, 2023

@bjricketts Thanks for bringing up this issue. The current recommended way to compare with observations is to "pre-bake" the data into the function instead of passing it as an argument.

In your code, since you have defined observation in the global scope, I think your log_posterior function will still work without observation in the argument list. If you want to generate the data dynamically, we would recommend constructing the log_posterior function with another function, something like this:

def make_posterior_function(data):
    # f is the underlying two-argument log-density, e.g. f(x, data)
    def log_posterior(x):
        return f(x, data)
    return log_posterior

The reason behind this seemingly complicated syntax is mainly to avoid triggering recompilation when the data size changes. Even though the current code does not directly support what you want to do, I think having a version where we can use syntax like log_posterior(x, observation) could be useful. Implementing this will require some thought on how to safeguard against unnecessary recompilation, and it will take some work to restructure the sampler.
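To make the closure pattern concrete, here is a minimal runnable sketch. The body of the inner function is my own filling-in of f, using the toy linear model from the original post (same 2*data_x + 5 data, same simplified variance term):

```python
import jax
import jax.numpy as jnp
import numpy as np

def make_posterior_function(data):
    # Close over the fixed observations so the sampler only ever sees
    # a single-argument function of the sampled parameters.
    def log_posterior(x):
        x_model = jnp.arange(1, 10)
        y = x[0] * x_model + x[1]
        # Gaussian log-likelihood up to an additive constant
        return jnp.sum(-((y - data) ** 2) / (2.0 * data))
    return log_posterior

observation = 2 * np.arange(1, 10) + 5  # toy data from the example above
log_posterior = jax.jit(make_posterior_function(observation))
value = log_posterior(jnp.array([2.0, 5.0]))
print(value)  # the true parameters fit exactly, so this is 0.0
```

Because the data is a closed-over constant, jit traces it once; only a change in the data's shape would trigger a retrace.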

There are some community examples out there, here is one of mine on gravitational-wave parameter estimation
https://github.com/kazewong/jim/blob/5f5c51622121cc3b175a0cf0ee00be0f8040b23f/example/ParameterEstimation/GW150914.py#L148
At some point in the future, we are going to link these community examples in the doc.

Please let me know if this is sufficient for your use case. For now, we do not have the bandwidth to restructure the code to accommodate the suggested syntax, but we may come back to this issue at some point.

@bjricketts
Author

Thanks @kazewong ! I think this is sufficient for my use case (or at least I can work around the issue with a bit of legwork on my end). This was mostly a suggestion, as there is such a feature in the emcee package; it let me "generalize" some work I am currently doing with hierarchical Bayesian statistics and reduced the amount of rewriting needed for some very complex hierarchical Bayesian structures.

While I only suggested a use case where the observations are an external argument, my own work generally passes a large number of custom functions into another overarching function. Being able to pass a series of functions into another function is where this feature becomes quite powerful. That being said, I totally understand that this would require a major restructuring of the code. I really hope you get the time to implement it at some point!

@dfm
Collaborator

dfm commented Feb 9, 2023

@bjricketts — Another tip is that you could update your example above as follows:

n_dim = 2

data_x = np.arange(1,10)
m = 2
c = 5
observation = m*data_x + c

def log_posterior(x, observation):
    x_model = np.arange(1,10)
    m = x[0]
    c = x[1]
    y = m*x_model+c
    gaussian = -((y - observation) ** 2) / (2 * jnp.sqrt(observation) ** 2)
    return jnp.sum(gaussian)

from functools import partial
log_prob_func = partial(log_posterior, observation=observation)

(Where all I've done is add the two final lines and change an np to a jnp in your function.) Then use the existing implementation (with log_prob_func as the input to the sampler) without any change. This doesn't seem too onerous to me!
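For completeness, here is a self-contained check of the functools.partial approach, mirroring the constants from the example above (the call to jax.jit is my own addition to show the frozen function still composes with JAX transforms):

```python
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial

data_x = np.arange(1, 10)
observation = 2 * data_x + 5

def log_posterior(x, observation):
    y = x[0] * jnp.arange(1, 10) + x[1]
    gaussian = -((y - observation) ** 2) / (2 * jnp.sqrt(observation) ** 2)
    return jnp.sum(gaussian)

# Freeze the data argument; the result is a function of x alone.
log_prob_func = partial(log_posterior, observation=observation)

value = jax.jit(log_prob_func)(jnp.array([2.0, 5.0]))
print(value)  # the true parameters recover the data exactly, so this is 0.0
```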

@bjricketts
Author

That's a nice solution @dfm! That seems like the best course of action for me for now.

That being said, a perhaps hacky but generalized solution to this issue for implementation into flowMC might be to use something like this:

class _FunctionWrapper(object):
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args or []
        self.kwargs = kwargs or {}

    def __call__(self, x):
        return self.f(x, *self.args, **self.kwargs)

where the logpdf is replaced by a _FunctionWrapper object when initializing the sampler. For instance, in the case of MALA, it could be changed to:

class MALA(LocalSamplerBase):

    def __init__(self, logpdf: Callable, jit: bool, params: dict, args=None, kwargs=None) -> None:
        super().__init__(logpdf, jit, params)
        self.params = params
        self.logpdf = _FunctionWrapper(logpdf, args, kwargs)

where args and kwargs are optional arguments that can be passed through and called as normal. I think this would mean flowMC would need minimal or no rewriting, and users could add as many arguments as they wish to their log-posterior functions. Whether this will work with JAX, I'll admit I don't know.
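As a quick standalone sanity check of the wrapper idea (outside any sampler, with a made-up scale parameter): the wrapped call is an ordinary Python closure over static extra arguments, so jax.jit can trace the callable object like any function.

```python
import jax
import jax.numpy as jnp

class _FunctionWrapper:
    """Bind extra positional/keyword arguments to a log-density."""

    def __init__(self, f, args=None, kwargs=None):
        self.f = f
        self.args = args or []
        self.kwargs = kwargs or {}

    def __call__(self, x):
        return self.f(x, *self.args, **self.kwargs)

def log_posterior(x, scale):
    # Isotropic Gaussian log-density (up to a constant) with a tunable scale
    return -0.5 * jnp.sum((x / scale) ** 2)

logpdf = _FunctionWrapper(log_posterior, args=[2.0])
value = jax.jit(logpdf)(jnp.ones(3))
print(value)  # -0.5 * 3 * (0.5 ** 2) = -0.375
```

One caveat: the bound args are baked in as trace-time constants, so changing them means building (and compiling) a new wrapper.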

@kazewong
Owner

kazewong commented May 4, 2023

@bjricketts Some updates on this issue: in order to make sure we do not trigger recompilation every time we call the likelihood function, we do need to overhaul the API a bit, so that the sampler can accept data as its input and be aware of it during compilation.

I made some progress along this direction in #116 . Currently the sampler works with MALA in the way you describe. I am going to clean up other parts of the current API, such as the examples, and implement the functionality for other local samplers such as HMC; then it should be in the released version.

@kazewong kazewong linked a pull request May 4, 2023 that will close this issue