[regression] Pickling a jitted function with cloudpickle #5043
I'm not very familiar with the jax codebase, but it looks like There's a dunder defined in there: Would it be an option to add a (also, I'm out of my depth here, so feel free to ignore these remarks if they don't make any sense)
I'm not sure which project to ask for help. I went here first because an older version But for reference, I also opened an issue here: cloudpipe/cloudpickle#402
We haven't ever explicitly supported or tested pickling, so I'm not surprised that it broke. If we want to support it, at the very least we would need tests to make sure it stays working! You're correct: to make this work, we'd need to implement a @jblespiau as an FYI.
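Python's pickle protocol lets a type define `__reduce__`, which returns a callable plus arguments that pickle invokes to rebuild the object on load. A minimal sketch of that mechanism, using a hypothetical pure-Python `JittedFunction` stand-in (not JAX's actual class) in which a `threading.Lock` plays the role of the unpicklable compiled state:

```python
import pickle
import threading


class JittedFunction:
    """Hypothetical stand-in for a compiled (jitted) function object.

    `_native_state` is a lock, which pickle cannot serialize; it plays the
    role of compiled native state. Only `fun` and `static_argnums` are
    needed to rebuild the object.
    """

    def __init__(self, fun, static_argnums=()):
        self.fun = fun
        self.static_argnums = tuple(static_argnums)
        self._native_state = threading.Lock()  # unpicklable on purpose

    def __call__(self, *args, **kwargs):
        return self.fun(*args, **kwargs)

    def __reduce__(self):
        # Ask pickle to rebuild the object by calling the constructor with
        # the original Python function and options; the native state is
        # recreated on load rather than serialized.
        return (type(self), (self.fun, self.static_argnums))


def add(x, y):
    return x + y


f = JittedFunction(add, static_argnums=(1,))
g = pickle.loads(pickle.dumps(f))
print(g(2, 3))  # 5
```

Without `__reduce__`, `pickle.dumps(f)` fails because the lock cannot be pickled; with it, only the Python function and its options are serialized, and any compilation would happen again after loading.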
Thanks @hawkinsp. I had a look at it, hoping to implement Please could you have a look at this? Or alternatively, please could you outline your dev environment for building and integrating the Any help to get me unstuck with this would be appreciated.
Is there perhaps a possibility to do this by serializing I thought this might be easier to do, by using The problem with this is that the decorated function is of builtin type
I do not know much about your use case, but have you looked into whether it's possible to do it differently? For example, it should be possible to pickle Depending on what you do, this may actually be what you want:
Given this comment, could you elaborate on why you think your use case is legitimate?
Thanks so much for your reply, @jblespiau. I see; yes, I agree that serializing the internal state of The route of serializing the original function I have two use cases for pickling
By the way, the objects that I'd like to use have the following structure:

```python
import jax

class FunctionApproximator:
    def __init__(self):
        def func(params, x):
            return ...
        self.func = jax.jit(func)
        self.params = ...

    def __call__(self, x):
        return self.func(self.params, x)
```
Come to think of it, if I had access to

```python
import jax

class FunctionApproximator:
    def __init__(self):
        def func(params, x):
            return ...
        self.func = jax.jit(func)
        self.params = ...

    def __call__(self, x):
        return self.func(self.params, x)

    def __getstate__(self):
        # Assumes the jitted function exposes .fun, .static_argnums and .donate_argnums.
        funcs = {
            k: (v.fun, v.static_argnums, v.donate_argnums)
            for k, v in self.__dict__.items() if hasattr(v, '_cpp_jitted_f')}
        other = {k: v for k, v in self.__dict__.items() if not hasattr(v, '_cpp_jitted_f')}
        return funcs, other

    def __setstate__(self, state):
        funcs, other = state
        funcs = {k: jax.jit(f, static_argnums=s, donate_argnums=d) for k, (f, s, d) in funcs.items()}
        self.__dict__.update(funcs)
        self.__dict__.update(other)
```

I'm guessing that
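A variant of the same idea that does not rely on attributes of the jitted object is to keep the raw function and the `jax.jit` keyword arguments on the instance yourself, drop the jitted callable in `__getstate__`, and re-jit in `__setstate__`. This is a minimal sketch under stated assumptions: `linear` is a hypothetical model function, and `jit` falls back to an identity wrapper when JAX is not installed so the snippet runs either way.

```python
import pickle

try:
    import jax
    jit = jax.jit
except ImportError:
    # Assumption: identity fallback so the sketch also runs without JAX.
    def jit(fun, **jit_kwargs):
        return fun


def linear(params, x):
    # Hypothetical model function, defined at module level so that plain
    # pickle can store it by reference.
    return params * x


class FunctionApproximator:
    def __init__(self, params, func=linear, **jit_kwargs):
        self.params = params
        self._raw_func = func          # un-jitted Python function
        self._jit_kwargs = jit_kwargs  # e.g. static_argnums, donate_argnums
        self.func = jit(func, **jit_kwargs)

    def __call__(self, x):
        return self.func(self.params, x)

    def __getstate__(self):
        # Drop the jitted callable; keep what is needed to rebuild it.
        state = self.__dict__.copy()
        del state["func"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Re-jit after loading; with JAX, compilation happens again on the first call.
        self.func = jit(self._raw_func, **self._jit_kwargs)


approx = FunctionApproximator(2.0)
restored = pickle.loads(pickle.dumps(approx))
print(float(restored(3.0)))  # 6.0
```

With cloudpickle the same `__getstate__`/`__setstate__` pair applies, and cloudpickle can additionally serialize nested functions such as `func` in the snippets above by value instead of requiring a module-level definition.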
Hi @KristianHolsheimer, first, it's a great occasion to say thank you for your work on coax, which I've been using quite a lot recently. The tutorial video helped a lot with getting started. I'm experiencing the same problem you had when it comes to saving Have you found a workaround? Some config details: jax==0.2.9
@Sunalwing Thanks for letting me know. I actually fixed it a little while ago, but didn't bump the version. A new release is now available (run
Hi there,

I just stumbled upon a regression between `jax==0.2.5` and `jax==0.2.6`. I make heavy use of cloudpickle (`cloudpickle==1.6.0`) for serializing objects. I do this either directly (example) or indirectly (example). I just noticed that I can no longer pickle jitted functions after upgrading from `jax==0.2.5` to `jax==0.2.6`.

Here's a minimal script to reproduce:

Same but with a later version of jax:
Related issue: #679
FWIW, this is a blocker for me.
Some system information:

- OS: Ubuntu 20.04.1 LTS
- Python: 3.8.5
- Other packages:
  - `jaxlib==0.1.57+cuda110`