Skip to content

Commit

Permalink
Fix pylint errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 364545783
  • Loading branch information
hbq1 authored and OptaxDev committed Mar 23, 2021
1 parent 0eda253 commit 8d405b9
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion optax/_src/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def init_fn(params):
k: jnp.asarray(_convert_floats(v, dtype))
for k, v in numeric_hps.items()}
hparams.update(schedule_fn(count, dtype))
return InjectHyperparamsState(
return InjectHyperparamsState( # pylint:disable=too-many-function-args
count, hparams, inner_factory(**other_hps, **hparams).init(params))

def update_fn(updates, state, params=None):
Expand All @@ -459,7 +459,10 @@ def update_fn(updates, state, params=None):
hparams.update(schedule_fn(count_inc, dtype))
updates, inner_state = inner_factory(**other_hps, **hparams).update(
updates, state.inner_state, params)

# pylint:disable=too-many-function-args
return updates, InjectHyperparamsState(count_inc, hparams, inner_state)
# pylint:enable=too-many-function-args

return transform.GradientTransformation(init_fn, update_fn)

Expand Down

0 comments on commit 8d405b9

Please sign in to comment.