Skip to content

Commit

Permalink
fix flax RNN interoperation, fix #663 (#665)
Browse files Browse the repository at this point in the history
fix flax RNN interoperation
  • Loading branch information
chaoming0625 committed Apr 15, 2024
1 parent 4bd1898 commit 4b5b61f
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 8 deletions.
19 changes: 11 additions & 8 deletions brainpy/_src/dnn/interoperation_flax.py
@@ -1,7 +1,7 @@

import jax
import dataclasses
from typing import Dict
from typing import Dict, Tuple
from jax.tree_util import tree_flatten, tree_map, tree_unflatten

from brainpy import math as bm
Expand Down Expand Up @@ -77,16 +77,16 @@ class ToFlaxRNNCell(RNNCellBase):
model: DynamicalSystem
train_params: Dict[str, jax.Array] = dataclasses.field(init=False)

def initialize_carry(self, rng, batch_dims, size=None, init_fn=None):
if len(batch_dims) == 0:
def initialize_carry(self, rng, input_shape: Tuple[int, ...]):
batch_dims = input_shape[:-1]
if len(batch_dims) == 1:
batch_dims = 1
elif len(batch_dims) == 1:
batch_dims = batch_dims[0]
elif len(batch_dims) == 0:
batch_dims = None
else:
raise NotImplementedError

raise ValueError(f'Invalid input shape: {input_shape}')
_state_vars = self.model.vars().unique().not_subset(bm.TrainVar)
self.model.reset(batch_size=batch_dims)
self.model.reset(batch_dims)
return [_state_vars.dict(), 0, 0.]

def setup(self):
Expand Down Expand Up @@ -131,6 +131,9 @@ def __call__(self, carry, *inputs):
# carray and output
return [_state_vars.dict(), i + 1, t + share.dt], out

@property
def num_feature_axes(self) -> int:
return 1

else:
class ToFlaxRNNCell(object):
Expand Down
44 changes: 44 additions & 0 deletions brainpy/_src/dnn/tests/test_flax.py
@@ -0,0 +1,44 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import pytest

pytest.skip('Skip this test because it is not implemented yet.',
allow_module_level=True)

import jax
import jax.numpy as jnp
import flax.linen as nn

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
bm.set_mode(bm.training_mode)

cell = bp.dnn.ToFlaxRNNCell(bp.dyn.RNNCell(num_in=1, num_out=1, ))


class myRNN(nn.Module):
@nn.compact
def __call__(self, x): # x:(batch, time, features)
x = nn.RNN(cell)(x) # Use nn.RNN to unfold the recurrent cell
return x


def test_init():
model = myRNN()
model.init(jax.random.PRNGKey(0), jnp.ones([1, 10, 1])) # batch,time,feature

0 comments on commit 4b5b61f

Please sign in to comment.