From 4b5b61ff0fa16cf6f5f729fa1f711ddc1e180164 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Mon, 15 Apr 2024 09:46:26 +0800 Subject: [PATCH] fix flax RNN interoperation, fix #663 (#665) fix flax RNN interoperation --- brainpy/_src/dnn/interoperation_flax.py | 19 ++++++----- brainpy/_src/dnn/tests/test_flax.py | 44 +++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 8 deletions(-) create mode 100644 brainpy/_src/dnn/tests/test_flax.py diff --git a/brainpy/_src/dnn/interoperation_flax.py b/brainpy/_src/dnn/interoperation_flax.py index 9804ac3b..9e051ebf 100644 --- a/brainpy/_src/dnn/interoperation_flax.py +++ b/brainpy/_src/dnn/interoperation_flax.py @@ -1,7 +1,7 @@ import jax import dataclasses -from typing import Dict +from typing import Dict, Tuple from jax.tree_util import tree_flatten, tree_map, tree_unflatten from brainpy import math as bm @@ -77,16 +77,16 @@ class ToFlaxRNNCell(RNNCellBase): model: DynamicalSystem train_params: Dict[str, jax.Array] = dataclasses.field(init=False) - def initialize_carry(self, rng, batch_dims, size=None, init_fn=None): - if len(batch_dims) == 0: + def initialize_carry(self, rng, input_shape: Tuple[int, ...]): + batch_dims = input_shape[:-1] + if len(batch_dims) == 1: batch_dims = 1 - elif len(batch_dims) == 1: - batch_dims = batch_dims[0] + elif len(batch_dims) == 0: + batch_dims = None else: - raise NotImplementedError - + raise ValueError(f'Invalid input shape: {input_shape}') _state_vars = self.model.vars().unique().not_subset(bm.TrainVar) - self.model.reset(batch_size=batch_dims) + self.model.reset(batch_dims) return [_state_vars.dict(), 0, 0.] def setup(self): @@ -131,6 +131,9 @@ def __call__(self, carry, *inputs): # carray and output return [_state_vars.dict(), i + 1, t + share.dt], out + @property + def num_feature_axes(self) -> int: + return 1 else: class ToFlaxRNNCell(object): diff --git a/brainpy/_src/dnn/tests/test_flax.py b/brainpy/_src/dnn/tests/test_flax.py new file mode 100644 index 00000000..b452d782 --- /dev/null +++ b/brainpy/_src/dnn/tests/test_flax.py @@ -0,0 +1,44 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import pytest + +pytest.skip('Skip this test because it is not implemented yet.', + allow_module_level=True) + +import jax +import jax.numpy as jnp +import flax.linen as nn + +import brainpy as bp +import brainpy.math as bm + +bm.set_platform('cpu') +bm.set_mode(bm.training_mode) + +cell = bp.dnn.ToFlaxRNNCell(bp.dyn.RNNCell(num_in=1, num_out=1, )) + + +class myRNN(nn.Module): + @nn.compact + def __call__(self, x): # x:(batch, time, features) + x = nn.RNN(cell)(x) # Use nn.RNN to unfold the recurrent cell + return x + + +def test_init(): + model = myRNN() + model.init(jax.random.PRNGKey(0), jnp.ones([1, 10, 1])) # batch,time,feature