Skip to content

Commit

Permalink
fix bugs in MultiStepLR (#11)
Browse files Browse the repository at this point in the history
* fix bugs in `MultiStepLR`

* update license
  • Loading branch information
chaoming0625 committed Apr 22, 2024
1 parent 5b970fe commit 4e9d942
Show file tree
Hide file tree
Showing 32 changed files with 75 additions and 39 deletions.
2 changes: 1 addition & 1 deletion braintools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/functional/_activations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/functional/_normalization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/functional/_spikes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/init/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/init/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/init/_generic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/init/_random_inits.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/init/_regular_inits.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/input/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/input/currents.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/input/currents_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_correlation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_correlation_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
11 changes: 5 additions & 6 deletions braintools/metric/_fenchel_young.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================

"""Fenchel-Young losses."""

from typing import Any, Protocol

Expand All @@ -33,15 +32,15 @@ def __call__(self, scores, *args, **kwargs: Any):
def make_fenchel_young_loss(
max_fun: MaxFun
):
"""Creates a Fenchel-Young loss from a max function.
"""Creates a 2024 BDP Ecosystem from a max function.
WARNING: The resulting loss accepts an arbitrary number of leading dimensions
with the fy_loss operating over the last dimension. The jaxopt version of this
function would instead flatten any vector in a single big 1D vector.
Examples:
Given a max function, e.g., the log sum exp, you can construct a
Fenchel-Young loss easily as follows:
2024 BDP Ecosystem easily as follows:
>>> from jax.scipy.special import logsumexp
>>> fy_loss = make_fy_loss(max_fun=logsumexp)
Expand All @@ -51,10 +50,10 @@ def make_fenchel_young_loss(
<https://arxiv.org/pdf/1901.02324.pdf>`_, 2020
Args:
max_fun: the max function on which the Fenchel-Young loss is built.
max_fun: the max function on which the 2024 BDP Ecosystem is built.
Returns:
A Fenchel-Young loss function with the same signature.
A 2024 BDP Ecosystem function with the same signature.
"""

vdot_last_dim = jnp.vectorize(jnp.vdot, signature="(n),(n)->()")
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_fenchel_young_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_firings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_firings_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_lfp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_ranking.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_ranking_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_regression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_regression_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_smoothing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_smoothing_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/metric/_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion braintools/optim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
9 changes: 5 additions & 4 deletions braintools/optim/_lr_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@

from typing import Sequence, Union

import numpy as np
import braincore as bc
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -222,13 +223,13 @@ def __init__(
'milestones should be a sequence of increasing integers.'
)
assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
self.milestones = jnp.asarray((-1,) + tuple(milestones), dtype=bc.environ.ditype())
self.milestones = jnp.asarray((-1,) + tuple(milestones) + (np.iinfo(np.int32).max,), dtype=bc.environ.ditype())
self.gamma = gamma

def __call__(self, i=None):
i = (self.last_epoch.value + 1) if i is None else i
conditions = (i > self.milestones[:-1]) & i < self.milestones[1:]
p = jnp.where(conditions, jnp.arange(0, len(self.milestones)))
conditions = jnp.logical_and((i >= self.milestones[:-1]), (i < self.milestones[1:]))
p = jnp.argmax(conditions)
return self.lr * self.gamma ** p

def extra_repr(self):
Expand Down
36 changes: 36 additions & 0 deletions braintools/optim/_lr_scheduler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import unittest

import jax.numpy as jnp

import braintools as bt


class TestMultiStepLR(unittest.TestCase):
def test1(self):
lr = bt.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
for i in range(40):
r = lr(i)
if i < 10:
self.assertEqual(r, 0.1)
elif i < 20:
self.assertTrue(jnp.allclose(r, 0.01))
elif i < 30:
self.assertTrue(jnp.allclose(r, 0.001))
else:
self.assertTrue(jnp.allclose(r, 0.0001))
2 changes: 1 addition & 1 deletion braintools/optim/_sgd_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down

0 comments on commit 4e9d942

Please sign in to comment.