
[regression] Pickling a jitted function with cloudpickle #5043

Closed
KristianHolsheimer opened this issue Nov 30, 2020 · 10 comments
Labels
enhancement New feature or request


@KristianHolsheimer

KristianHolsheimer commented Nov 30, 2020

Hi there,

I just stumbled upon a regression between jax==0.2.5 and jax==0.2.6. I make heavy use of cloudpickle (cloudpickle==1.6.0) for serializing objects. I do this either directly (example) or indirectly (example).

I just noticed that I can no longer pickle jitted functions after upgrading from jax==0.2.5 to jax==0.2.6.

Here's a minimal script to reproduce:

$ pip install cloudpickle==1.6.0 jax==0.2.5
$ python3 -c "
import jax
import cloudpickle as pickle
@jax.jit
def f(x):
  return x * 13
pickle.dumps(f)
print('OK')
"
OK

Same but with later version of jax:

$ pip install cloudpickle==1.6.0 jax==0.2.6
$ python3 -c "
import jax
import cloudpickle as pickle
@jax.jit
def f(x):
  return x * 13
pickle.dumps(f)
print('OK')
"

Traceback (most recent call last):
  File "<string>", line 7, in <module>
  File "/home/kris/.local/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/home/kris/.local/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 563, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'jaxlib.xla_extension.jax_jit.CompiledFunction' object

Related issue: #679

FWIW, this is a blocker for me.

Some system information:

OS: Ubuntu 20.04.1 LTS
Python: 3.8.5
Other packages: jaxlib==0.1.57+cuda110

@KristianHolsheimer
Author

KristianHolsheimer commented Nov 30, 2020

I'm not very familiar with the jax codebase, but it looks like CompiledFunction is defined here.

There's a dunder defined in there: __signature__

Would it be an option to add a __reduce__ in there as well?

(also, I'm out of my depth here, so feel free to ignore these remarks if they don't make sense)

@KristianHolsheimer
Author

KristianHolsheimer commented Nov 30, 2020

I'm not sure which project to ask for help. I came here first because the older version, jax==0.2.5, still works.

But for reference, I also opened an issue here: cloudpipe/cloudpickle#402

@hawkinsp
Collaborator

We haven't ever explicitly supported or tested pickling, so I'm not surprised that it broke. If we want to support it, at the least we would need tests to make sure it stays working!

You're correct: to make this work, we'd need to implement a __reduce__ in that C++ class.

@jblespiau as an FYI.
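For illustration, here is a Python-level sketch of what such a `__reduce__` could do. It does not touch the C++ `CompiledFunction`; it assumes a hypothetical wrapper class `PicklableJit` (not a jax API) that keeps the undecorated function and its `jax.jit` keyword arguments, and whose `__reduce__` hands the unpickler only that recipe. The compiled state is discarded, so the receiving process compiles again on first call.

```python
import pickle


def square_plus_one(x):
    # Example function; module-level, so pickle can store it by reference.
    return x * x + 1


def _rebuild(fun, jit_kwargs):
    # Called by the unpickler; recreates the wrapper from the raw function.
    return PicklableJit(fun, **jit_kwargs)


class PicklableJit:
    """Hypothetical wrapper (not a jax API) around an undecorated function."""

    def __init__(self, fun, **jit_kwargs):
        self.fun = fun
        self.jit_kwargs = jit_kwargs
        self._compiled = None  # the jitted function, created on first call

    def __call__(self, *args, **kwargs):
        if self._compiled is None:
            import jax  # deferred: (un)pickling alone never touches jax
            self._compiled = jax.jit(self.fun, **self.jit_kwargs)
        return self._compiled(*args, **kwargs)

    def __reduce__(self):
        # Serialize only the recipe (function + jit options), never the
        # compiled object, so the receiver compiles for its own backend.
        return (_rebuild, (self.fun, self.jit_kwargs))
```

In this sketch `jax` is imported only inside `__call__`, so pickling and unpickling never need it; calling `PicklableJit(f)(x)` otherwise forwards to `jax.jit(f)(x)`.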

@hawkinsp hawkinsp added the enhancement New feature or request label Nov 30, 2020
@KristianHolsheimer
Author

Thanks @hawkinsp

I had a look at it, hoping to implement CompiledFunction.__reduce__ myself, but I couldn't figure it out easily.

Please could you have a look at this?

Alternatively, could you outline your dev environment for building the tensorflow/xla code and integrating it into jax, and perhaps give me some hints on how you would implement this?

Any help to get me unstuck with this would be appreciated.
Thanks

KristianHolsheimer added a commit to microsoft/coax that referenced this issue Dec 9, 2020
@KristianHolsheimer
Author

Would it be possible to do this by serializing jitted_func.__wrapped__ and then re-decorating the function with jax.jit on deserialization?

I thought this might be easier to do, by using copyreg instead of CompiledFunction.__reduce__.

The problem is that the decorated function is of a builtin function type. If it were of some custom type, I could just register it with copyreg.

@jblespiau
Contributor

jblespiau commented Dec 9, 2020

I don't know much about your use case, but have you checked whether it's possible to do it differently?

For example, it should be possible to pickle f and then call jax.jit(f) again on the other side.

Depending on what you do, this may actually be what you want:

  • f specifies the function to execute in Python. jax.jit(f) compiles the function for the device it will run on (to be precise, only the first call to jax.jit(f)(x) does). So if you pickle a function on one machine to execute it on another, you may actually want to recompile (e.g. the first machine has only CPUs, the second has GPUs and you want to use them, or some specific vectorized options).
  • jitted_f is "jit compiled". Semantically, it's a little strange to pickle it, since compilation is expected to happen at run time.

Given this comment, could you elaborate on why you think your use-case is legitimate?
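A minimal sketch of the suggestion above: serialize the plain `f` and apply `jax.jit` on the receiving side. `pack` and `unpack_and_jit` are hypothetical helper names. Standard-library pickle stores a module-level function by reference (its qualified name), so the receiver must be able to import the same module; for lambdas, closures, or functions defined in `__main__`, `cloudpickle.dumps` serializes the function by value instead.

```python
import pickle


def f(x):
    # The plain, undecorated function from the original example.
    return x * 13


def pack(fun):
    """Sender side: serialize the raw function, not jax.jit(fun)."""
    return pickle.dumps(fun)


def unpack_and_jit(payload, **jit_kwargs):
    """Receiver side: restore the function and jit it for the local device."""
    import jax  # only the receiving side needs to compile
    return jax.jit(pickle.loads(payload), **jit_kwargs)
```

Because `jax.jit` only compiles on the first call, `unpack_and_jit(pack(f))` is cheap; the actual compilation happens on the worker, for whatever backend it has.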

@KristianHolsheimer
Author

Thanks so much for your reply @jblespiau

I see, yes I agree that serializing the internal state of CompiledFunction does not make sense. In particular, the device and backend arguments of jax.jit aren't portable to other devices.

The route of serializing the original function f is the way to go.

I have two use cases for pickling jax.jit(f):

  1. To do checkpoints where I store a couple of objects at once, see e.g. Best way to save a model/agent? microsoft/coax#2 (comment)

  2. To take a jitted function and distribute it on a spark or ray cluster, which use cloudpickle.

By the way, the objects that I'd like to use have the following structure:

class FunctionApproximator:
  def __init__(self):
    def func(params, x):
      return ...
    self.func = jax.jit(func)
    self.params = ...

  def __call__(self, x):
    return self.func(self.params, x)

@KristianHolsheimer
Copy link
Author

Come to think of it, if I had access to fun, static_argnums and donate_argnums I could implement my own __getstate__ and __setstate__ like this:

class FunctionApproximator:
  def __init__(self):
    def func(params, x):
      return ...
    self.func = jax.jit(func)
    self.params = ...

  def __call__(self, x):
    return self.func(self.params, x)

  def __getstate__(self):
    funcs = {
      k: (v.fun, v.static_argnums, v.donate_argnums)
      for k, v in self.__dict__.items() if hasattr(v, '_cpp_jitted_f')}
    other = {k: v for k, v in self.__dict__.items() if not hasattr(v, '_cpp_jitted_f')}
    return funcs, other

  def __setstate__(self, state):
    funcs, other = state
    funcs = {k: jax.jit(f, static_argnums=s, donate_argnums=d) for k, (f, s, d) in funcs.items()}
    self.__dict__.update(funcs)
    self.__dict__.update(other)

I'm guessing that fun is just func.__wrapped__, but I don't know how to access static_argnums and donate_argnums.
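One way around not knowing `static_argnums` and `donate_argnums` is to record the `jax.jit` inputs yourself when you create the jitted function, rather than reading them back from it (and relying on the private `_cpp_jitted_f` attribute). The sketch below assumes a hypothetical module-level `linear` model function and example `params`; `__getstate__` drops the jitted function, and the `func` property re-jits lazily after unpickling.

```python
import pickle


def linear(params, x):
    # Hypothetical model function standing in for `func` in the snippet above.
    w, b = params
    return w * x + b


class FunctionApproximator:
    def __init__(self, func=linear, params=(2.0, 1.0), static_argnums=()):
        # Record the jax.jit inputs explicitly instead of trying to read
        # fun/static_argnums back from the jitted object later.
        self._raw_func = func
        self._jit_kwargs = {"static_argnums": static_argnums}
        self.params = params
        self._compiled = None

    @property
    def func(self):
        # Jit lazily, so a freshly unpickled object compiles on its own device.
        if self._compiled is None:
            import jax
            self._compiled = jax.jit(self._raw_func, **self._jit_kwargs)
        return self._compiled

    def __call__(self, x):
        return self.func(self.params, x)

    def __getstate__(self):
        # Copy everything except the jitted function; default unpickling then
        # restores __dict__ from this state, with _compiled reset to None.
        state = self.__dict__.copy()
        state["_compiled"] = None
        return state
```

If `func` is a closure defined inside `__init__`, as in the snippets above, standard-library pickle cannot store it, but cloudpickle can.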

@MathieuCiancone

Hi @KristianHolsheimer ,

First it's a great occasion to say thank you for your work on coax, which I've been using quite a lot recently. The tutorial video helped a lot with getting started.

I'm experiencing the same problem you had when saving Q and pi using coax.utils.dump: I get a TypeError: can't pickle jaxlib.xla_extension.jax_jit.CompiledFunction objects.

Have you eventually found a workaround?

Some config details

jax==0.2.9
jaxlib==0.1.60
cloudpickle==1.6.0
python==3.6.1

@KristianHolsheimer
Author

@Sunalwing Thanks for letting me know. I actually fixed it a little while ago, but didn't bump the version.

A new release is now available (run pip install --upgrade coax).
