Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions flax/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
import typing as tp
from typing import Any

from flax import nnx
from flax import linen
from flax import nnx
from flax.core import FrozenDict
from flax.core import meta
from flax.nnx import graph
from flax.nnx.bridge import variables as bv
from flax.nnx.module import GraphDef, Module
from flax.nnx.object import Object
from flax.nnx.rnglib import Rngs
from flax.nnx.statelib import State
from flax.nnx.object import Object
import jax
from jax import tree_util as jtu

Expand Down Expand Up @@ -220,7 +221,7 @@ class ToLinen(linen.Module):
"""
nnx_class: tp.Callable[..., Module]
args: tp.Sequence = ()
kwargs: tp.Mapping = dataclasses.field(default_factory=dict)
kwargs: tp.Mapping[str, tp.Any] = FrozenDict({})
skip_rng: bool = False
metadata_type: tp.Type = bv.NNXMeta

Expand Down Expand Up @@ -277,4 +278,4 @@ def _update_variables(self, module):
def to_linen(nnx_class: tp.Callable[..., Module], *args,
name: str | None = None, **kwargs):
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name)
return ToLinen(nnx_class, args=args, kwargs=FrozenDict(kwargs), name=name)