This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

How to deal with shared submodules? #8

Open
avital opened this issue Sep 14, 2021 · 8 comments

Comments

@avital

avital commented Sep 14, 2021

Hi @cgarciae -- Treex is cool! In particular, I really like the idea of using different classes for "variable kinds" and how nicely it plays with typing (though I was a bit confused about why you need both __init__ and module_init, but I probably didn't understand exactly how to use it).

I was playing around with it a bit and was curious whether you've thought about how to deal with shared submodules. Because JAX's tree utils (which all JAX transformations use) flatten and unflatten pytrees, unflattening effectively gives you clones of the input pytrees.
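To make the cloning concrete without depending on Treex, here is a minimal sketch that uses only `jax.tree_util`. `Counter` and `bump_first` are hypothetical names standing in for a module and a jitted step. A shared instance that goes through a flatten/unflatten round trip comes back as two independent objects:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Counter:
    """Toy pytree node standing in for a module with one state leaf."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None  # one leaf, no static data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)  # always builds a brand-new instance


shared = Counter(jnp.array(0))
pair = (shared, shared)  # the same instance referenced twice

# tree_map flattens to two leaves and unflattens two separate Counters
a, b = tree_map(lambda v: v + 1, pair)
assert a is not b  # identity is lost even without jit


@jax.jit
def bump_first(m1, m2):
    m1.value = m1.value + 1  # only touches the first reference
    return m1, m2


r1, r2 = bump_first(shared, shared)
# inside jit, m1 and m2 were independent clones, so only r1 changed
```

After the call, `r1.value` is 1 while `r2.value` is still 0, which is exactly the divergence described above.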

I made a short Colab to better describe the issue: https://colab.sandbox.google.com/drive/1R77juKqGWQ3H2_yJ9wh8nNIVzbFJEuyV#scrollTo=Sh2bXFQJpdLd

Do you think this can be resolved in Treex? I believe Objax solves this issue by forcing users to use its own Jit module, which goes through a utility function that re-pipes the input DAG structure (it flattens each shared instance only once, whereas Treex flattens every reference to the shared instance). But a library like that is much less interoperable with the rest of the JAX ecosystem, such as Optax...
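The deduplication idea described above can be sketched in plain JAX. This is not Objax's actual API or implementation: `jit_dedup` and `Counter` are hypothetical, and the sketch assumes the wrapped function works by mutating its arguments. Each distinct object is flattened once, and the aliasing pattern is passed to `jax.jit` as a static argument:

```python
import functools

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Counter:
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def jit_dedup(f):
    """Hypothetical wrapper: hand each distinct argument object to jax.jit once.

    `f` runs for its side effects on its arguments; the wrapper returns the
    updated arguments with their original aliasing restored.
    """

    @functools.partial(jax.jit, static_argnums=0)
    def inner(index, *unique):
        rebuilt = [unique[i] for i in index]  # re-pipe the shared references
        f(*rebuilt)
        return unique  # each distinct object is flattened exactly once

    def wrapper(*args):
        unique, index = [], []
        for arg in args:
            # look up by identity, not equality
            pos = next((i for i, u in enumerate(unique) if u is arg), None)
            if pos is None:
                pos = len(unique)
                unique.append(arg)
            index.append(pos)
        new_unique = inner(tuple(index), *unique)
        return tuple(new_unique[i] for i in index)

    return wrapper


def bump_first(m1, m2):
    m1.value = m1.value + 1


shared = Counter(jnp.array(0))
r1, r2 = jit_dedup(bump_first)(shared, shared)
```

The caller has to go through `jit_dedup` instead of `jax.jit`, which is the interoperability cost mentioned above.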

@cgarciae
Owner

cgarciae commented Sep 14, 2021

Hey @avital!

Module sharing

It's an interesting problem. Currently you would get two different copies, as you described. I tried two approaches that Treex could implement to support this, but each has its weaknesses:

1. Try to reuse instances when unflattening
import jax
import jax.tree_util

import treex as tx


class Channel:
    def __init__(self, value):
        self.value = value


class SharedModule(tx.Module):
    x: tx.State[int]

    def __init__(self) -> None:
        super().__init__()
        self.x = 0
        self.channel = Channel(None)

    def tree_flatten(self):
        # reset the channel on flatten so the next unflatten builds a fresh instance
        self.channel.value = None
        return super().tree_flatten()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        channel = aux_data["channel"]

        if channel.value is None:
            obj = super().tree_unflatten(aux_data, children)

            # share instance
            channel.value = obj

        else:
            # reuse instance
            obj = channel.value

        return obj


m = SharedModule()


@jax.jit
def f(m1, m2):
    m1.x = 100
    return m1, m2


m1, m2 = f(m, m)

assert m1.x == m2.x == 100
2. Share a reference (Channel) to submodules and update the reference when flattening / unflattening
from typing import Generic, TypeVar
import treex as tx
import jax
import jax.tree_util

A = TypeVar("A")


class Channel(Generic[A]):
    def __init__(self, value: A):
        self.value = value


class SharedModule(tx.Module):
    x: tx.State[int]

    def __init__(self, x: int) -> None:
        super().__init__()
        self.x = x
        self.channel = Channel(self)

    def tree_flatten(self):
        self.channel.value = self
        return super().tree_flatten()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = super().tree_unflatten(aux_data, children)

        obj.channel.value = obj

        return obj

    def get_shared(self):
        return self.channel


class ChildModule(tx.Module):
    def __init__(self, shared_channel: Channel[SharedModule]):
        super().__init__()
        self.shared_channel = shared_channel

    def __call__(self, x):
        return x + self.shared_channel.value.x


class ParentModule(tx.Module):
    shared_module: tx.Module  # it's going to be a SharedModule, but we need an annotation that Treex recognizes for now
    child1: ChildModule
    child2: ChildModule

    def __init__(self):
        super().__init__()
        self.shared_module = SharedModule(1)
        self.child1 = ChildModule(self.shared_module.get_shared())
        self.child2 = ChildModule(self.shared_module.get_shared())

    def __call__(self, x):
        return self.child1(x) + self.child2(x)


parent = ParentModule()

assert parent(10) == 22  # (10 + 1) + (10 + 1)


@jax.jit
def f(parent):
    parent.shared_module.x = 2

    return parent(10)


assert f(parent) == 24  # (10 + 2) + (10 + 2)

I don't know if this suits your use cases, but simply keeping a single instance and explicitly passing the shared module top-down as an argument to each submodule's __call__ where it's needed seems like a safer solution. It doesn't require special tooling, though it's a bit more cumbersome. I'd hate to create a context-dependent operation like Objax did, because you lose compatibility with regular JAX code.
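A minimal sketch of that top-down style, written against plain `jax.tree_util` instead of Treex (so `Shared`, `Child`, and `Parent` are hypothetical stand-ins for the modules in approach 2). The shared state lives in exactly one place in the pytree and the stateless children receive it as a call argument, so no identity has to survive flattening:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Shared:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


class Child:
    # stateless: the shared module is passed in at call time
    def __call__(self, x, shared):
        return x + shared.x


@register_pytree_node_class
class Parent:
    def __init__(self, shared):
        self.shared = shared  # the single owner of the shared state
        self.child1 = Child()
        self.child2 = Child()

    def __call__(self, x):
        return self.child1(x, self.shared) + self.child2(x, self.shared)

    def tree_flatten(self):
        return (self.shared,), None  # children hold no state, so they are rebuilt

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


parent = Parent(Shared(jnp.array(1)))
assert parent(10) == 22  # (10 + 1) + (10 + 1)


@jax.jit
def f(parent):
    parent.shared.x = 2  # a single update is seen by both children
    return parent(10)


assert f(parent) == 24  # (10 + 2) + (10 + 2)
```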

Can you post a concrete example? Maybe simple solutions can be found to common patterns.

Regarding Treex

though I was a bit confused about why you need both __init__ and module_init but I probably didn't understand exactly how to use it right

I'll review the README to make it clearer 😅. module_init and the field Initializer objects are convenient if you need a unique PRNGKey that can be passed to you when the user calls init with a key, but your example doesn't need them. The code below fixes your example; the only thing wrong was that CounterVariable is only needed as an annotation:

class CounterVariable(TreePart):
    pass

class ModuleWithOneCounter(tx.Module):
    value: CounterVariable[int] # <-- TreeParts are just needed for the annotation, actual fields contain regular values

    def __init__(self):
        super().__init__()
        self.value = 0
    
class TwoSharedCounters(tx.Module):
    def __init__(self):
        super().__init__()
        self.counter1 = ModuleWithOneCounter()
        self.counter2 = self.counter1

    def __call__(self):
        # only increment self.counter1; self.counter2 should also grow
        # because self.counter1 == self.counter2
        self.counter1.value = self.counter1.value + 1

@cgarciae
Owner

cgarciae commented Sep 15, 2021

Hey @avital, I think this solves the problem in a general manner! 🥳

It involves some modifications to Module, a context manager, and a decorator called share:

Solution
import functools
import threading
from dataclasses import dataclass
from typing import Any, Dict, Optional
import jax.numpy as jnp
import jax
import jax.tree_util
import treex as tx


@dataclass
class Context(threading.local):
    shared: Optional[Dict[int, Any]] = None

    def __enter__(self):
        global _CONTEXT
        self._old_context = _CONTEXT
        _CONTEXT = self

    def __exit__(self, *args):
        global _CONTEXT
        _CONTEXT = self._old_context


_CONTEXT = Context()


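def _share(f):
    # Outer wrapper: install a fresh id -> object map for the whole call, so
    # inputs (and later outputs) that shared an identity are unflattened once.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        with Context(shared={}):
            return f(*args, **kwargs)

    return wrapper


def _clear(f):
    # Inner wrapper: runs while the transformation traces `f`.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Disable sharing inside the user function itself.
        with Context():
            outputs = f(*args, **kwargs)

        # Back in the outer context: forget the input ids so the outputs
        # are deduplicated from scratch when they are unflattened.
        if _CONTEXT.shared is not None:
            _CONTEXT.shared.clear()

        return outputs

    return wrapper


def share(jit):
    # `jit` can be any pytree-based JAX transformation, e.g. jax.jit.
    def decorator(f):
        return _share(jit(_clear(f)))

    return decorator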


class Module(tx.Module):
    def tree_flatten(self):
        children, aux_data = super().tree_flatten()

        if _CONTEXT.shared is not None:
            aux_data["__id"] = id(self)

        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):

        if _CONTEXT.shared is not None:
            if aux_data["__id"] in _CONTEXT.shared:
                return _CONTEXT.shared[aux_data["__id"]]
            else:
                obj = super().tree_unflatten(aux_data, children)
                _CONTEXT.shared[aux_data["__id"]] = obj
                return obj
        else:
            return super().tree_unflatten(aux_data, children)

share can wrap any JAX transformation and reuse objects based on their identity before the transformation, in a clean manner. With it you can do the following:

class SomeModule(Module):
    x: tx.State[int]

    def __init__(self, x: int):
        super().__init__()
        self.x = x


@share(jax.jit)
def f(m1, m2):
    assert m1 is m2  # relative identity is preserved for inputs
    m1.x = 2
    return m1, m2


m = SomeModule(1)

m1, m2 = f(m, m)

assert m1 is m2  # relative identity is preserved for outputs
assert m1.x == m2.x == 2  # state is preserved
assert isinstance(m1.x, jnp.ndarray)  # assert jit happened

Wow, I am very happy with this solution! I had actually tried something like this a year ago but couldn't find it. I believe all pytree-based libraries can use this solution!
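Since the mechanism relies only on custom `tree_flatten`/`tree_unflatten`, it can also be sketched for a plain `jax.tree_util` pytree. This is a simplified re-implementation of the idea, not Treex code: `Counter`, `share`, and the global `_REGISTRY` are hypothetical, and it uses a `contextlib` context manager instead of the thread-local `Context` class:

```python
import contextlib

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

_REGISTRY = None  # hypothetical global: original id -> unflattened object


@contextlib.contextmanager
def _registry(value):
    """Temporarily install a registry (a dict) or disable sharing (None)."""
    global _REGISTRY
    old, _REGISTRY = _REGISTRY, value
    try:
        yield
    finally:
        _REGISTRY = old


@register_pytree_node_class
class Counter:
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # store the object's identity in the static aux data
        return (self.value,), id(self)

    @classmethod
    def tree_unflatten(cls, obj_id, children):
        if _REGISTRY is None:
            return cls(*children)
        if obj_id not in _REGISTRY:
            _REGISTRY[obj_id] = cls(*children)
        return _REGISTRY[obj_id]  # reuse the instance built for this id


def share(transform):
    def decorator(f):
        def traced(*args):
            registry = _REGISTRY  # filled while the inputs were unflattened
            with _registry(None):  # plain behaviour inside the user function
                out = f(*args)
            if registry is not None:
                registry.clear()  # outputs start from an empty id map
            return out

        transformed = transform(traced)

        def wrapper(*args):
            with _registry({}):
                return transformed(*args)

        return wrapper

    return decorator


@share(jax.jit)
def f(m1, m2):
    assert m1 is m2  # aliasing is preserved for inputs
    m1.value = m1.value + 1
    return m1, m2


m = Counter(jnp.array(0))
m1, m2 = f(m, m)
```

One consequence of storing `id(self)` in the aux data is that it becomes part of the treedef, so passing a different object triggers a retrace.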

@cgarciae
Owner

I think preserve_identities might be a better name for this decorator. Enabling module sharing is a consequence; what it actually does is preserve relative identities.

@avital
Author

avital commented Sep 15, 2021

IIUC with your new proposal, you must use preserve_identities(jax.jit) rather than jax.jit. And the same is true for all other JAX transformations. But what if you call into another library that's built around jax.grad such as jaxopt or a third-party pytree-based checkpointing library? Or if you want to use Chex's variants to test both jitted and non-jitted versions of functions?

In general, I think you lose fully general interop between libraries that assume "vanilla JAX" interfaces and pure functions.

@cgarciae
Owner

Yeah, this is true. At least it's a good first step; I will open a PR with the solution.

Yesterday I opened google/jax#7919 with a proposal for a more general, JAX-native solution. Hopefully the JAX team can give some support for this use case.

@cgarciae
Owner

cgarciae commented Sep 17, 2021

@avital made some progress, your example now works 🥳

class ModuleWithOneCounter(tx.Module):
    value: tx.State[int]

    def __init__(self):
        super().__init__()
        self.value = 1

class TwoSharedCounters(tx.Module):
    def __init__(self):
        super().__init__()
        self.counter1 = ModuleWithOneCounter()
        self.counter2 = ModuleWithOneCounter()

    def __call__(self):
        # only increment self.counter1; self.counter2 should also grow
        # because self.counter1 == self.counter2
        self.counter1.value = self.counter1.value + 1

model = TwoSharedCounters().init(42)

def call_model(m):
    m()
    return m

model_no_jit = call_model(model)

model_no_jit.counter1 is model_no_jit.counter2
model_no_jit.counter1.value == model_no_jit.counter2.value

model_jit = jax.jit(call_model)(model)

model_jit.counter1 is model_jit.counter2
model_jit.counter1.value == model_jit.counter2.value

@cgarciae
Owner

Update: found edge cases.

@avital
Author

avital commented Sep 20, 2021

Oh! Curious what the edge cases are? (And I'm surprised by the code snippet you posted last because it doesn't set self.counter2 = self.counter1)
