How to deal with shared submodules? #8
Comments
Hey @avital!
Module sharing
It's an interesting problem; currently you would get two different copies, as you described. I tried two approaches that Treex could implement to support this, but each approach has its weaknesses:
1. Try to reuse instances when unflattening
import jax
import jax.tree_util
import treex as tx
class Channel:
    def __init__(self, value):
        self.value = value

class SharedModule(tx.Module):
    x: tx.State[int]

    def __init__(self) -> None:
        super().__init__()
        self.x = 0
        self.channel = Channel(None)

    def tree_flatten(self):
        # channel is reset when unflattening
        self.channel.value = None
        return super().tree_flatten()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        channel = aux_data["channel"]
        if channel.value is None:
            obj = super().tree_unflatten(aux_data, children)
            # share instance
            channel.value = obj
        else:
            # reuse instance
            obj = channel.value
        return obj

m = SharedModule()

@jax.jit
def f(m1, m2):
    m1.x = 100
    return m1, m2

m1, m2 = f(m, m)
assert m1.x == m2.x == 100
2. Share a reference (Channel) to submodules and update the reference when flattening / unflattening
from typing import Generic, TypeVar
import flax
import treex as tx
import jax
import jax.tree_util
A = TypeVar("A")

class Channel(Generic[A]):
    def __init__(self, value: A):
        self.value = value

class SharedModule(tx.Module):
    x: tx.State[int]

    def __init__(self, x: int) -> None:
        super().__init__()
        self.x = x
        self.channel = Channel(self)

    def tree_flatten(self):
        self.channel.value = self
        return super().tree_flatten()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = super().tree_unflatten(aux_data, children)
        obj.channel.value = obj
        return obj

    def get_shared(self):
        return self.channel

class ChildModule(tx.Module):
    def __init__(self, shared_channel: Channel[SharedModule]):
        super().__init__()
        self.shared_channel = shared_channel

    def __call__(self, x):
        return x + self.shared_channel.value.x

class ParentModule(tx.Module):
    shared_module: tx.Module  # it's going to be a SharedModule but we need an annotation that Treex recognizes for now
    child1: ChildModule
    child2: ChildModule

    def __init__(self):
        super().__init__()
        self.shared_module = SharedModule(1)
        self.child1 = ChildModule(self.shared_module.get_shared())
        self.child2 = ChildModule(self.shared_module.get_shared())

    def __call__(self, x):
        return self.child1(x) + self.child2(x)

parent = ParentModule()
assert parent(10) == 22  # (10 + 1) + (10 + 1)

@jax.jit
def f(parent):
    parent.shared_module.x = 2
    return parent(10)
assert f(parent) == 24  # (10 + 2) + (10 + 2)
I don't know if this suits your use cases but maybe just keeping a single instance and explicitly passing the shared module top-down as an argument to submodule's
Can you post a concrete example? Maybe simple solutions can be found to common patterns.
Regarding Treex
I'll review the README to make it clearer 😅
class CounterVariable(TreePart):
    pass

class ModuleWithOneCounter(tx.Module):
    value: CounterVariable[int]  # <-- TreeParts are just needed for the annotation, actual fields contain regular values

    def __init__(self):
        super().__init__()
        self.value = 0

class TwoSharedCounters(tx.Module):
    def __init__(self):
        super().__init__()
        self.counter1 = ModuleWithOneCounter()
        self.counter2 = self.counter1

    def __call__(self):
        # only increment self.counter1; self.counter2 should also grow
        # because self.counter1 == self.counter2
        self.counter1.value = self.counter1.value + 1
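To make the failure mode behind the counter example concrete, here is a minimal sketch in plain JAX, without Treex. The `Counter` class and `bump_first` function are hypothetical stand-ins for illustration, not part of any library. It shows that when the same object is passed twice through `jax.jit`, the two arguments are unflattened into separate copies, so an update to one is not seen by the other.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Counter:
    """Hypothetical stand-in for a module holding one counter value."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # children (leaves), aux_data
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Every unflatten call builds a fresh object.
        return cls(*children)


@jax.jit
def bump_first(c1, c2):
    # Only c1 is updated; inside jit, c1 and c2 are distinct objects.
    c1.value = c1.value + 1
    return c1, c2


c = Counter(0)
c1, c2 = bump_first(c, c)

# The "shared" counter has diverged into two independent copies.
print(int(c1.value), int(c2.value))
```

The original `c` is also left untouched, since JAX only ever mutates the unflattened copies.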
Hey @avital, I think this solves the problem in a general manner! 🥳 It involves some modifications to
Solution
import functools
import threading
from dataclasses import dataclass
from typing import Any, Dict, Optional
import jax.numpy as jnp
import jax
import jax.tree_util
import treex as tx
@dataclass
class Context(threading.local):
    shared: Optional[Dict[int, Any]] = None

    def __enter__(self):
        global _CONTEXT
        self._old_context = _CONTEXT
        _CONTEXT = self

    def __exit__(self, *args):
        global _CONTEXT
        _CONTEXT = self._old_context

_CONTEXT = Context()

def _share(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        with Context(shared={}):
            return f(*args, **kwargs)

    return wrapper

def _clear(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        with Context():
            outputs = f(*args, **kwargs)
        if _CONTEXT.shared is not None:
            _CONTEXT.shared.clear()
        return outputs

    return wrapper

def share(jit):
    def decorator(f):
        return _share(jit(_clear(f)))

    return decorator

class Module(tx.Module):
    def tree_flatten(self):
        children, aux_data = super().tree_flatten()
        if _CONTEXT.shared is not None:
            aux_data["__id"] = id(self)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        if _CONTEXT.shared is not None:
            if aux_data["__id"] in _CONTEXT.shared:
                return _CONTEXT.shared[aux_data["__id"]]
            else:
                obj = super().tree_unflatten(aux_data, children)
                _CONTEXT.shared[aux_data["__id"]] = obj
                return obj
        else:
            return super().tree_unflatten(aux_data, children)

class SomeModule(Module):
    x: tx.State[int]

    def __init__(self, x: int):
        super().__init__()
        self.x = x

@share(jax.jit)
def f(m1, m2):
    assert m1 is m2  # relative identity is preserved for inputs
    m1.x = 2
    return m1, m2

m = SomeModule(1)
m1, m2 = f(m, m)
assert m1 is m2  # relative identity is preserved for outputs
assert m1.x == m2.x == 2  # state is preserved
assert isinstance(m1.x, jnp.ndarray)  # assert jit happened
Wow, I am very happy with this solution. I had actually tried something like this a year ago but couldn't find it. I believe all Pytree-based libraries can use this solution!
I think |
IIUC with your new proposal, you must use
In general, I think you lose fully general interop between libraries that assume "vanilla JAX" interfaces and pure functions.
Yeah, this is true. At least it's a good first step; I will open a PR with the solution. I opened google/jax#7919 yesterday with a proposal for a more general, JAX-native solution. Hopefully the JAX team can give some support for this use case.
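For contrast, the "vanilla JAX" style mentioned in this exchange avoids object identity altogether: the shared value lives in exactly one pytree leaf and is passed explicitly to every consumer. A minimal sketch, assuming hypothetical `child` and `parent` functions and a plain dict as state (none of these names come from Treex):

```python
import jax
import jax.numpy as jnp


def child(shared_x, x):
    # Each child receives the shared value as an explicit argument.
    return x + shared_x


@jax.jit
def parent(state, x):
    # The shared value is stored once, so an update is seen by both children.
    new_state = {"shared_x": state["shared_x"] + 1}
    out = child(new_state["shared_x"], x) + child(new_state["shared_x"], x)
    return out, new_state


state = {"shared_x": jnp.asarray(1)}
out, state = parent(state, jnp.asarray(10))
print(int(out), int(state["shared_x"]))  # (10 + 2) + (10 + 2) and 2
```

This works with any library that expects pure functions over pytrees, at the cost of threading the state through by hand.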
@avital made some progress, your example now works 🥳
class ModuleWithOneCounter(tx.Module):
    value: tx.State[int]

    def __init__(self):
        super().__init__()
        self.value = 1

class TwoSharedCounters(tx.Module):
    def __init__(self):
        super().__init__()
        self.counter1 = ModuleWithOneCounter()
        self.counter2 = ModuleWithOneCounter()

    def __call__(self):
        # only increment self.counter1; self.counter2 should also grow
        # because self.counter1 == self.counter2
        self.counter1.value = self.counter1.value + 1

model = TwoSharedCounters().init(42)

def call_model(m):
    m()
    return m

model_no_jit = call_model(model)
model_no_jit.counter1 is model_no_jit.counter2
model_no_jit.counter1.value == model_no_jit.counter2.value

model_jit = jax.jit(call_model)(model)
model_jit.counter1 is model_jit.counter2
model_jit.counter1.value == model_jit.counter2.value
Update: found edge cases.
Oh! Curious what the edge cases are? (And I'm surprised by the code snippet you posted last because it doesn't set
Hi @cgarciae, Treex is cool; in particular I really like the idea of using different classes for "variable kinds" and how it plays pretty nicely with typing (though I was a bit confused about why you need both __init__ and module_init, but I probably didn't understand exactly how to use it right).
I was playing around a bit with it, and was curious whether you've thought about how to deal with shared submodules. Because JAX's tree utils (which all JAX transformations use) flatten and unflatten, when you unflatten you effectively get clones of the input pytrees.
I made a short Colab to better describe the issue: https://colab.sandbox.google.com/drive/1R77juKqGWQ3H2_yJ9wh8nNIVzbFJEuyV#scrollTo=Sh2bXFQJpdLd
Do you think this can be resolved in treex? I believe Objax solves this issue by forcing users to use their own Jit module which goes through some utility function that re-pipes the input DAG structure (it only flattens each reference to a shared instance once, whereas Treex flattens each reference to the shared instance). But a library like that is much less interoperable with the rest of the JAX ecosystem, like with Optax...
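The cloning behaviour described above can be reproduced with nothing but `jax.tree_util` and built-in containers. A minimal sketch (the variable names are illustrative only):

```python
import jax

shared = {"x": 1}
tree = {"a": shared, "b": shared}  # both keys reference the same dict

leaves, treedef = jax.tree_util.tree_flatten(tree)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# The shared dict is flattened once per reference, giving two leaves,
# and unflattening builds a separate dict for each reference.
print(len(leaves))
print(tree["a"] is tree["b"])
print(rebuilt["a"] is rebuilt["b"])
```

Since every JAX transformation round-trips its inputs and outputs this way, any aliasing in the input structure is lost by default.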