diff --git a/docs/nnx/surgery.ipynb b/docs/nnx/surgery.ipynb index 9ced966fc..5cc0b10a1 100644 --- a/docs/nnx/surgery.ipynb +++ b/docs/nnx/surgery.ipynb @@ -6,15 +6,15 @@ "source": [ "# Model surgery\n", "\n", - "This guide will demostrate how to do model surgery in NNX with a few real-scenario use cases.\n", + "In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases:\n", "\n", - "* __Module manipulation__: Pythonic ways to manipulate submodules given a model.\n", + "* __Pythonic module manipulation__: Pythonic ways to manipulate sub-modules given a model.\n", "\n", - "* __Abstact model__: A key trick to play with NNX modules and states without memory allocation.\n", + "* __Manipulating an abstract model or state__: A key trick to play with Flax NNX modules and states without memory allocation.\n", "\n", - "* __From raw state to model__: How to manipulate parameter state when they are incompatible with existing model code.\n", + "* __Checkpoint surgery: From a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code.\n", "\n", - "* __Partial initialization__: Initializing only part of the model from scratch." + "* __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method." ] }, { @@ -61,11 +61,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Pythonic module manipulations\n", + "## Pythonic module manipulation\n", "\n", - "Model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code.\n", + "Doing model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code.\n", "\n", - "You can make a variety of pythonic operations on its submodules, like swapping in/out, sharing modules/weights, monkeypatching, etc. See a few code examples below."
+ "You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching:" ] }, { @@ -78,7 +78,7 @@ "x = jax.random.normal(jax.random.key(42), (3, 4))\n", "np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))\n", "\n", - "# Submodule swapping\n", + "# Sub-module swapping\n", "original1, original2 = model.linear1, model.linear2\n", "model.linear1, model.linear2 = model.linear2, model.linear1\n", "np.testing.assert_allclose(model(x), original1(original2(x)))\n", @@ -89,14 +89,14 @@ "assert not hasattr(nnx.state(model), 'linear2')\n", "np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))\n", "\n", - "# Variable sharing (weight tying)\n", + "# Variable sharing (weight-tying)\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate\n", "assert hasattr(nnx.state(model), 'linear2')\n", "assert hasattr(nnx.state(model)['linear2'], 'bias')\n", "assert not hasattr(nnx.state(model)['linear2'], 'kernel')\n", "\n", - "# Monkey patching\n", + "# Monkey-patching\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "def awesome_layer(x): return x\n", "model.linear2 = awesome_layer\n", @@ -107,7 +107,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Create model and state without memory allocation\n", + "## Creating an abstract model or state without memory allocation\n", "\n", "For more complex model surgery, a key technique is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints.\n", "\n", @@ -115,7 +115,7 @@ "* Create a function that returns a valid NNX model; and\n", "* Run `nnx.eval_shape` (not `jax.eval_shape`) upon it.\n", "\n", - "Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model is now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information." + "Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information." ] }, { @@ -177,7 +177,7 @@ "abs_state['linear2']['kernel'].value = model.linear2.kernel\n", "abs_state['linear2']['bias'].value = model.linear2.bias\n", "nnx.update(abs_model, abs_state)\n", - "np.testing.assert_allclose(abs_model(x), model(x)) # they are equivalent now!" + "np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now!" ] }, { @@ -186,9 +186,9 @@ "source": [ "## Checkpoint surgery\n", "\n", - "With the abstract state technique in hand, we can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with our given model code, then call `nnx.update` to merge them.\n", + "With the abstract state technique in hand, you can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with your given model code, then call `nnx.update` to merge them.\n", "\n", - "This is helpful when you are to change model code significantly (e.g., migrating from Linen to NNX) so that old weights are no longer naturally compatible. Let's run a simple example here." 
+ "This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX) and old weights are no longer naturally compatible. Let's run a simple example here:" ] }, { @@ -197,7 +197,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Save a version of model into a checkpoint\n", + "# Save a version of a model into a checkpoint.\n", "checkpointer = orbax.PyTreeCheckpointer()\n", "old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True)" @@ -207,7 +207,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this new model, the submodules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure." + "In this new model, the sub-modules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure:" ] }, { @@ -293,7 +293,7 @@ " state = nnx.State.from_flat_path(state)\n", " return nnx.merge(graph_def, state)\n", "\n", - "# Make your local change on the checkpoint\n", + "# Make your local change on the checkpoint.\n", "raw = checkpointer.restore('/tmp/nnx-surgery-state')\n", "pprint(raw)\n", "raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']\n", @@ -314,7 +314,7 @@ "source": [ "## Partial initialization\n", "\n", - "In some cases (e.g., LoRA), you might want to randomly-initialize only *part of* your model parameters." + "In some cases (such as with Low-Rank Adapation (LoRA)), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization." ] }, { @@ -323,9 +323,9 @@ "source": [ "### Naive partial initialization\n", "\n", - "You can simply initialize the whole model, then swap pre-trained params in. But this approach could allocate additional memory midway, if your modification requires re-creating module params that you will later discard. See this example below.\n", + "You can simply initialize the whole model, then swap pre-trained parameters in. But this approach could allocate additional memory midway, if your modification requires re-creating module parameters that you will later discard. See this example below.\n", "\n", - "> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting kernel & running from scratch will always yield same output." + "> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting the kernel and running from scratch will always yield same output." 
] }, { @@ -350,7 +350,7 @@ "simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))\n", "print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')\n", "# On this line, extra kernel and bias is created inside the new LoRALinear!\n", - "# They are wasted b/c we are to use the kernel and bias in `old_state` anyway.\n", + "# They are wasted since you are going to use the kernel and bias in `old_state` anyway.\n", "simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))\n", "print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'\n", " ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')\n", @@ -365,7 +365,7 @@ "source": [ "### Memory-efficient partial initialization\n", "\n", - "Use `nnx.jit`'s efficiently compiled code to make sure only the state params you need are initialized." + "Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized:" ] }, { @@ -383,16 +383,16 @@ } ], "source": [ - "# Some pretrained model state\n", + "# Some pretrained model state.\n", "old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "\n", - "# Use nnx.jit (which wraps jax.jit) to automatically skip unused arrays - memory efficient!\n", + "# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!\n", "@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)\n", "def partial_init(old_state, rngs):\n", " model = TwoLayerMLP(4, rngs=rngs)\n", - " # create new state\n", + " # Create a new state.\n", " model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)\n", - " # add existing create\n", + " # Add the existing state.\n", " nnx.update(model, old_state)\n", " return model\n", "\n", diff --git a/docs/nnx/surgery.md b/docs/nnx/surgery.md index af56548cc..cd7a73a85 100644 --- a/docs/nnx/surgery.md +++ b/docs/nnx/surgery.md @@ -1,14 +1,14 @@ # Model surgery -This guide will demostrate how to do model surgery in NNX with a few real-scenario use cases. +In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases: -* __Module manipulation__: Pythonic ways to manipulate submodules given a model. +* __Pythonic module manipulation__: Pythonic ways to manipulate sub-modules given a model. -* __Abstact model__: A key trick to play with NNX modules and states without memory allocation. +* __Manipulating an abstract model or state__: A key trick to play with Flax NNX modules and states without memory allocation. -* __From raw state to model__: How to manipulate parameter state when they are incompatible with existing model code. +* __Checkpoint surgery: From a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code. -* __Partial initialization__: Initializing only part of the model from scratch. +* __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method. ```python @@ -42,11 +42,11 @@ class TwoLayerMLP(nnx.Module): return self.linear2(x) ``` -## Pythonic module manipulations +## Pythonic module manipulation -Model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code. +Doing model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code. 
-You can make a variety of pythonic operations on its submodules, like swapping in/out, sharing modules/weights, monkeypatching, etc. See a few code examples below. +You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching: ```python @@ -54,7 +54,7 @@ model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(42), (3, 4)) np.testing.assert_allclose(model(x), model.linear2(model.linear1(x))) -# Submodule swapping +# Sub-module swapping original1, original2 = model.linear1, model.linear2 model.linear1, model.linear2 = model.linear2, model.linear1 np.testing.assert_allclose(model(x), original1(original2(x))) @@ -65,14 +65,14 @@ model.linear2 = model.linear1 assert not hasattr(nnx.state(model), 'linear2') np.testing.assert_allclose(model(x), model.linear1(model.linear1(x))) -# Variable sharing (weight tying) +# Variable sharing (weight-tying) model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate assert hasattr(nnx.state(model), 'linear2') assert hasattr(nnx.state(model)['linear2'], 'bias') assert not hasattr(nnx.state(model)['linear2'], 'kernel') -# Monkey patching +# Monkey-patching model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) def awesome_layer(x): return x model.linear2 = awesome_layer @@ -80,7 +80,7 @@ np.testing.assert_allclose(model(x), model.linear1(x)) ``` -## Create model and state without memory allocation +## Creating an abstract model or state without memory allocation For more complex model surgery, a key technique is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints. @@ -88,7 +88,7 @@ To create an abstract model, * Create a function that returns a valid NNX model; and * Run `nnx.eval_shape` (not `jax.eval_shape`) upon it. -Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model is now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information. +Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information. ```python @@ -131,24 +131,24 @@ abs_state['linear1']['bias'].value = model.linear1.bias abs_state['linear2']['kernel'].value = model.linear2.kernel abs_state['linear2']['bias'].value = model.linear2.bias nnx.update(abs_model, abs_state) -np.testing.assert_allclose(abs_model(x), model(x)) # they are equivalent now! +np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now! ``` ## Checkpoint surgery -With the abstract state technique in hand, we can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with our given model code, then call `nnx.update` to merge them. +With the abstract state technique in hand, you can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make it fit with your given model code, and then call `nnx.update` to merge it into your model. -This is helpful when you are to change model code significantly (e.g., migrating from Linen to NNX) so that old weights are no longer naturally compatible. Let's run a simple example here.
+This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX), and old weights are no longer naturally compatible. Let's run a simple example here: ```python -# Save a version of model into a checkpoint +# Save a version of a model into a checkpoint. checkpointer = orbax.PyTreeCheckpointer() old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True) ``` -In this new model, the submodules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure. +In this new model, the sub-modules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure: ```python @@ -194,7 +194,7 @@ def module_from_variables_dict(module_factory, variables, map_key_fn): state = nnx.State.from_flat_path(state) return nnx.merge(graph_def, state) -# Make your local change on the checkpoint +# Make your local change on the checkpoint. raw = checkpointer.restore('/tmp/nnx-surgery-state') pprint(raw) raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2'] @@ -223,23 +223,23 @@ np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones( ## Partial initialization -In some cases (e.g., LoRA), you might want to randomly-initialize only *part of* your model parameters. +In some cases, such as with Low-Rank Adaptation (LoRA), you may want to randomly initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization. ### Naive partial initialization -You can simply initialize the whole model, then swap pre-trained params in. But this approach could allocate additional memory midway, if your modification requires re-creating module params that you will later discard. See this example below. +You can simply initialize the whole model, then swap pre-trained parameters in. But this approach could allocate additional memory midway, if your modification requires re-creating module parameters that you will later discard. See this example below. -> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting kernel & running from scratch will always yield same output. +> Note: You can use `jax.live_arrays()` to check all the arrays that are live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to the garbage collection of old Python variables), but restarting the kernel and running from scratch will always yield the same output. ```python -# Some pretrained model state +# Some pretrained model state. old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42))) print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}') # On this line, extra kernel and bias is created inside the new LoRALinear! -# They are wasted b/c we are to use the kernel and bias in `old_state` anyway. +# They are wasted since you are going to use the kernel and bias in `old_state` anyway.
simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42)) print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}' ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)') @@ -255,20 +255,20 @@ print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}' ### Memory-efficient partial initialization -Use `nnx.jit`'s efficiently compiled code to make sure only the state params you need are initialized. +Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized: ```python # Some pretrained model state old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) -# Use nnx.jit (which wraps jax.jit) to automatically skip unused arrays - memory efficient! +# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient! @functools.partial(nnx.jit, donate_argnums=0, static_argnums=1) def partial_init(old_state, rngs): model = TwoLayerMLP(4, rngs=rngs) - # create new state + # Create a new state. model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs) - # add existing create + # Add the existing state. nnx.update(model, old_state) return model
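A possible usage sketch for the `partial_init` function defined above, assuming the example's `old_state` and imports are in scope; the seed value and the `good_model` name are illustrative, and the `nnx.Rngs` object is passed positionally to match `static_argnums=1`.

```python
good_model = partial_init(old_state, nnx.Rngs(42))
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}')

# The LoRA layer was freshly initialized; the remaining weights were merged in from `old_state`.
assert isinstance(good_model.linear1, nnx.LoRALinear)
assert isinstance(good_model.linear2, nnx.Linear)
```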