Skip to content

Commit

Permalink
Fix or ignore some pytype errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557273333
  • Loading branch information
The jestimator Authors committed Aug 15, 2023
1 parent a5671ac commit 1cc482e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jestimator/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def extract_axes(variables: FrozenDict[str, Any]):
return params, params_axes_, vars_, vars_axes_


class InferState(struct.PyTreeNode):
class InferState(struct.PyTreeNode): # pytype: disable=invalid-function-definition # dataclass_transform
"""State for inference, with support for partitioning."""
step: ArrayLike
apply_fn: Callable = struct.field(pytree_node=False) # pylint: disable=g-bare-generic
Expand Down Expand Up @@ -134,7 +134,7 @@ def as_logical_axes(self) -> 'InferState':
step=None, params=self._params_axes, _vars=self._vars_axes)


class TrainState(struct.PyTreeNode):
class TrainState(struct.PyTreeNode): # pytype: disable=invalid-function-definition # dataclass_transform
"""Train state compatible with T5X partitioning and checkpointing.
Attributes:
Expand Down

0 comments on commit 1cc482e

Please sign in to comment.