Padding GP efficiently #108
-
This is a good question! Depending on your use case, it's nearly always better to re-compile your model for smaller datasets, since GPs scale poorly with the size of your data set. But I think there probably are some interesting cases where this would come in handy. In fact, I've been chatting with @aceilers about a GPLVM problem with missing data where such a structure might help performance, but we haven't coded anything up. I think you could use the "augmented data" approach like we use in the Derivative observations tutorial to achieve something like what you want. The basic idea would be that you decorate each data point with a flag saying whether or not it is a "padded" data point (you can't pad with nans, but something like zeros should be fine), and then you write a custom kernel (and a custom mean function, if you're using one!) that returns zero when either of the data points is padded. To be more explicit, here's a simple (untested, etc.) implementation:

```python
from tinygp import GaussianProcess, kernels
import jax
import jax.numpy as jnp


class PaddedKernel(kernels.Kernel):
    """Wrap a base kernel so that any pair involving a padded point evaluates to zero."""

    def __init__(self, base_kernel):
        self.base_kernel = base_kernel

    def evaluate(self, X1, X2):
        # Each input is a (coordinate, is_real) pair; the flag is False for padded points
        X1, cond1 = X1
        X2, cond2 = X2
        val = self.base_kernel.evaluate(X1, X2)
        return jnp.where(jnp.logical_and(cond1, cond2), val, 0)


# Pad a 100-point dataset with 5 dummy points (zeros, flagged as padded)
x0 = jnp.linspace(0, 5, 100)
x = jnp.concatenate((x0, jnp.zeros(5)))
cond = jnp.concatenate((jnp.ones(100, dtype=bool), jnp.zeros(5, dtype=bool)))
X = (x, cond)
diag0 = jnp.full_like(x0, 0.1)
diag = jnp.full_like(x, 0.1)
y0 = jnp.sin(x0)
y = jnp.concatenate((y0, jnp.zeros(5)))

k1 = 3.1 * kernels.Matern32(1.5)
k2 = 5.1 * kernels.Matern32(0.5)
k1_pad = PaddedKernel(k1)
k2_pad = PaddedKernel(k2)

print("The actual logprobs are different:")
print(GaussianProcess(k1, x0, diag=diag0).log_probability(y0))
print(GaussianProcess(k1_pad, X, diag=diag).log_probability(y))

print("\nThe _differences_ in logprob for different params are equal:")
print(
    GaussianProcess(k1, x0, diag=diag0).log_probability(y0)
    - GaussianProcess(k2, x0, diag=diag0).log_probability(y0)
)
print(
    GaussianProcess(k1_pad, X, diag=diag).log_probability(y)
    - GaussianProcess(k2_pad, X, diag=diag).log_probability(y)
)
```

It's important to note that the actual values of the log probability are not the same for these two models, but they only differ by a constant, so inference algorithms should work as expected. Hope this helps!
-
I've converted this to a "discussion" since it's less of an issue and more of a conceptual discussion. I hope that's ok!
-
Thanks - that’s a nicer way of doing it, though as you say the runtime may still scale (badly) with the size of the full padded data. I’m experimenting to see which approach I will use for my purposes.
I think the best solution will depend on the specifics of the application. It would be great to have the best of both worlds: avoiding recompiles while keeping the GP cost scaling with the size of the unpadded inputs. Perhaps I need to look again at the linear algebra involved, but I wonder if a special solver module could do it.
-
@andrewfowlie I wonder what the conclusion of the investigation you mentioned above was? The reason I ask is that I want to run tinygp code on a vast number of lightcurves that have different sizes. The compilation for all sizes takes an hour, and given that there is no way to save the compilation, it seems that the only options are 1. padding or 2. long recompiling. As such, I would be curious whether you found any extra info in the meantime?
-
Hi Neven, I didn't do a particularly thorough investigation and don't have anything to show for it other than what I can remember.
For my problem, I have a memory that Daniel's example didn't work quite as I first thought. I think the issue was that I used a non-zero mean function, whereas Daniel's example has a zero mean. My non-zero mean meant that the padded values did contribute to the log-likelihood. If the mean is not zero, a further little hack is required, e.g., setting the predictions at the padded values to the mean. I might not have remembered that exactly right. I still have a feeling that something smarter is possible here.
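Concretely, the hack I have in mind would look something like this (untested sketch, and assuming tinygp evaluates the mean callable once per data point): wrap the mean the same way the kernel is wrapped, so it returns zero at padded points and the residuals there vanish against the zero-padded observations.

```python
import jax.numpy as jnp
from tinygp import GaussianProcess, kernels


def make_padded_mean(base_mean):
    # base_mean is the ordinary mean function acting on a single coordinate
    def padded_mean(X):
        x, cond = X  # same (coordinate, is_real) structure used by PaddedKernel
        # Return zero at padded points so the residual (y - mean) vanishes there,
        # matching the zero-padded observations
        return jnp.where(cond, base_mean(x), 0.0)

    return padded_mean


# Example usage with a constant, non-zero base mean (illustrative value);
# PaddedKernel, X = (x, cond), and diag are as in Daniel's example above
gp = GaussianProcess(
    PaddedKernel(3.1 * kernels.Matern32(1.5)),
    X,
    diag=diag,
    mean=make_padded_mean(lambda x: 2.0),
)
```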
-
This is a great little library 👍 I’ve been playing around with jax.jit compiled tinygps.
Sometimes I want to run the tinygp again but change the length of my input coordinates, noise and observed data. That is, change N_data.
This obviously doesn’t work nicely with jax, as it means my model must be recompiled. In other contexts, the solution is ‘padding’ inputs so that they are of fixed size, and making sure padded values don’t impact the computation. See e.g. https://stackoverflow.com/a/68532890/2855071
Is there a sensible or recommended way to pad tinygp inputs?
To be clear, suppose x, y, noise etc are of length N_dim. I want to pad them to be of length > N_dim in such a way that tinygp computations/inferences are unaffected.
I tried padding my inputs with data very far away in input space from the rest of my data, and with very noisy measurements. It seemed to work. But it seems hacky and likely to be unreliable and computationally inefficient.
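Roughly, the hack I tried looks like this (a sketch, not my exact code; the padding distance and noise level are arbitrary):

```python
import jax.numpy as jnp
from tinygp import GaussianProcess, kernels

N_PAD = 128  # fixed size so a jitted model only compiles once


def pad_hack(x, y, noise):
    # Pad with points far from the real data and with huge noise so they
    # (hopefully) have negligible influence on the fit
    n_extra = N_PAD - x.shape[0]
    x_pad = jnp.concatenate((x, 1e6 + jnp.arange(n_extra)))
    y_pad = jnp.concatenate((y, jnp.zeros(n_extra)))
    noise_pad = jnp.concatenate((noise, jnp.full(n_extra, 1e6)))
    return x_pad, y_pad, noise_pad


x = jnp.linspace(0, 5, 100)
y = jnp.sin(x)
noise = jnp.full_like(x, 0.1)  # per-point noise standard deviation

x_pad, y_pad, noise_pad = pad_hack(x, y, noise)
gp = GaussianProcess(kernels.Matern32(1.5), x_pad, diag=noise_pad**2)
print(gp.log_probability(y_pad))
```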
I tried padding with nans, but got nans back out.
It would be cool if there was an easy, efficient way of doing this. What do you think?