Refactor dtypes and add float8_* dtypes (keras-team#19401)
* Refactor dtypes in codebase and add float8_* dtypes

* Update comments

Fix for JAX export on GPU. (keras-team#19404)

Fix formatting in export_lib. (keras-team#19405)

`ops/numpy.py`: Support `key` as `list` in `GetItem` (keras-team#19310)

When loading a model that contains `GetItem` nodes with multidimensional
indices/slices as `key`, the `key` argument is loaded from JSON as a `list`,
not a `tuple` (because JSON does not have the distinction).

So, treat a `list` key as equivalent to a `tuple` key. Copying the key is
important: otherwise, the later `pop()` would remove the bound slice elements
from the op itself.
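
A minimal sketch of the idea (illustrative only; the helper name `_prepare_key`
is hypothetical and not the actual `GetItem` code):

def _prepare_key(key):
    # JSON has no tuple type, so a multidimensional key round-trips as a list.
    if isinstance(key, list):
        # Converting to a tuple also copies, so popping elements from a
        # working copy later cannot mutate the key stored on the op itself.
        key = tuple(key)
    return key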

`saving/serialization_lib_test.py`:

* Add `test_numpy_get_item_layer()`:
	test for consistent serialization/deserialization of a model which
	contains `ops.numpy.GetItem`;

feat(losses): add Dice loss implementation (keras-team#19409)

* feat(losses): add Dice loss implementation

* removed smooth parameter and type casting

* adjusted casting and dot operator

Update casting
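
To illustrate what a Dice loss computes, here is a minimal standalone sketch
using `keras.ops` (illustrative only, not the exact implementation added in
this commit; like the commit, it omits a smoothing term):

from keras import ops

def dice_loss(y_true, y_pred):
    # Dice loss = 1 - Dice coefficient
    #           = 1 - 2 * intersection / (|y_true| + |y_pred|) for soft masks.
    y_true = ops.cast(y_true, y_pred.dtype)
    intersection = ops.sum(y_true * y_pred)
    return 1.0 - 2.0 * intersection / (ops.sum(y_true) + ops.sum(y_pred))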

Bump the github-actions group with 1 update (keras-team#19412)

Bumps the github-actions group with 1 update: [github/codeql-action](https://github.com/github/codeql-action).

Updates `github/codeql-action` from 3.24.6 to 3.24.9
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](github/codeql-action@8a470fd...1b1aada)

---
updated-dependencies:
- dependency-name: github/codeql-action
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

Fix issue with shared layer deserialization

Remove dead code in saving lib (keras-team#19415)

Remove unused beta param for silu, use torch op directly (keras-team#19417)

The beta param was only accepted by the tensorflow/torch backends, was not
exposed in the `keras.ops` API, and was never tested. It seems best to just
drop it, since no one could be relying on it.
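
For reference, a quick check of the identity this relies on (an illustrative
snippet, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(8)
# silu(x) = x * sigmoid(x); with no beta scaling the input,
# the torch op can be called directly.
assert torch.allclose(F.silu(x), x * torch.sigmoid(x))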

Fix print_fn for custom function (keras-team#19419)

Add fp8 to `EinsumDense`

Add test script
james77777778 committed Apr 3, 2024
1 parent 9eb9629 commit 898db1d
Showing 25 changed files with 520 additions and 94 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/scorecard.yml
@@ -56,6 +56,6 @@ jobs:

# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@8a470fddafa5cbb6266ee11b37ef4d8aae19c571 # v3.24.6
uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9
with:
sarif_file: results.sarif
73 changes: 73 additions & 0 deletions check_fp8_mnist.py
@@ -0,0 +1,73 @@
import argparse

import numpy as np

import keras
from keras import layers
from keras import models


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp8", action="store_true")
    parser.add_argument("--einsum", action="store_true")
    return parser.parse_args()


class Classifier(models.Model):
    def __init__(self, use_fp8=False):
        super().__init__()
        inputs = layers.Input(shape=[28, 28, 1])
        x = layers.Flatten()(inputs)
        x = layers.Dense(
            64, activation="relu", use_bias=False, use_fp8=use_fp8
        )(x)
        x = layers.Dense(
            64, activation="relu", use_bias=False, use_fp8=use_fp8
        )(x)
        outputs = layers.Dense(
            10, activation="softmax", use_bias=False, use_fp8=use_fp8
        )(x)
        super().__init__(inputs, outputs)


class Classifier2(models.Model):
    def __init__(self, use_fp8=False):
        super().__init__()
        inputs = layers.Input(shape=[28, 28, 1])
        x = layers.Flatten()(inputs)
        x = layers.EinsumDense(
            "ab,bc->ac", output_shape=[64], activation="relu", use_fp8=use_fp8
        )(x)
        x = layers.EinsumDense(
            "ab,bc->ac", output_shape=[64], activation="relu", use_fp8=use_fp8
        )(x)
        outputs = layers.EinsumDense(
            "ab,bc->ac",
            output_shape=[10],
            activation="softmax",
            use_fp8=use_fp8,
        )(x)
        super().__init__(inputs, outputs)


args = get_args()
if args.einsum:
    model = Classifier2(use_fp8=args.fp8)
else:
    model = Classifier(use_fp8=args.fp8)
num_classes = 10
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_split=0.1)
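
Usage, based on the flags defined in the script: `python check_fp8_mnist.py`
trains the baseline Dense classifier, `python check_fp8_mnist.py --fp8`
enables fp8, and adding `--einsum` switches to the `EinsumDense` variant.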
67 changes: 51 additions & 16 deletions keras/backend/common/dtypes.py
@@ -1,17 +1,11 @@
import functools

from keras import backend
from keras.api_export import keras_export
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.backend import config
from keras.backend.common.variables import standardize_dtype

"""
We adapted the type promotion lattice from JAX. Ref:
https://github.com/google/jax/blob/main/jax/_src/dtypes.py
"""

BOOL_TYPES = ["bool"]
INT_TYPES = [
BOOL_TYPES = ("bool",)
INT_TYPES = (
"uint8",
"uint16",
"uint32",
@@ -20,9 +14,44 @@
"int16",
"int32",
"int64",
]
FLOAT_TYPES = ["bfloat16", "float16", "float32", "float64"]
WEAK_TYPES = ["int", "float"]
)
FLOAT_TYPES = ("bfloat16", "float16", "float32", "float64")
WEAK_TYPES = ("int", "float")
# We need to separate float8 from float because there are no implicit
# conversions from float8 dtypes to other dtypes.
# Ref: https://github.com/google/jax/issues/16705
FLOAT8_TYPES = ("float8_e4m3fn", "float8_e5m2")

# All supported dtypes in Keras
ALLOWED_DTYPES = (
"float16",
"float32",
"float64",
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"bfloat16",
"bool",
"string",
"float8_e4m3fn",
"float8_e5m2",
)
PYTHON_DTYPES_MAP = {
bool: "bool",
int: "int64" if config.backend() == "tensorflow" else "int32",
float: "float32",
str: "string",
# special case for string value
"int": "int64" if config.backend() == "tensorflow" else "int32",
}

# We adapted the type promotion lattice from JAX. Ref:
# https://github.com/google/jax/blob/main/jax/_src/dtypes.py


def _type_promotion_lattice():
@@ -168,7 +197,7 @@ def _respect_weak_type(dtype, weak_type):
@functools.lru_cache(maxsize=None)
def _resolve_weak_type(dtype, precision="32"):
"""Resolve weak type by the precision of `backend.floatx()`."""
extended_allowed_dtypes = ALLOWED_DTYPES.union(WEAK_TYPES)
extended_allowed_dtypes = set(ALLOWED_DTYPES).union(WEAK_TYPES)
if dtype not in extended_allowed_dtypes:
raise ValueError(
"Invalid value for argument `dtype`. Expected one of "
@@ -234,7 +263,7 @@ def _lattice_result_type(*args):
out_weak_type = any(out_dtype is t for t in WEAK_TYPES)

out_weak_type = (out_dtype != "bool") and out_weak_type
precision = backend.floatx()[-2:]
precision = config.floatx()[-2:]
if out_weak_type:
out_dtype = _resolve_weak_type(out_dtype, precision=precision)
return out_dtype
@@ -270,7 +299,13 @@ def result_type(*dtypes):
if len(dtypes) == 0:
# If no dtypes provided, default to floatx, this matches
# `ops.convert_to_tensor([])`
return backend.floatx()
return config.floatx()
for dtype in dtypes:
if dtype in FLOAT8_TYPES:
raise ValueError(
"There is no implicit conversions from float8 dtypes to others."
f" You must cast it internally. Received: {dtypes}"
)
return _lattice_result_type(
*(backend.floatx() if arg is None else arg for arg in dtypes),
*(config.floatx() if arg is None else arg for arg in dtypes),
)
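
As a usage sketch of the new promotion rule (illustrative only, using the
module path shown in this diff):

from keras.backend.common import dtypes

# Ordinary promotion still follows the JAX-derived lattice:
print(dtypes.result_type("float16", "float32"))  # float32

# float8 dtypes are excluded from implicit promotion and raise instead;
# values must be cast out of float8 explicitly before mixed-dtype math.
try:
    dtypes.result_type("float8_e4m3fn", "bfloat16")
except ValueError as err:
    print(err)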
19 changes: 16 additions & 3 deletions keras/backend/common/dtypes_test.py
@@ -5,7 +5,6 @@
from keras import backend
from keras import ops
from keras.backend.common import dtypes
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.testing import test_case
from keras.testing.test_utils import named_product

@@ -18,14 +17,18 @@ class DtypesTest(test_case.TestCase, parameterized.TestCase):

# TODO: torch doesn't support uint64.
ALL_DTYPES = []
for x in ALLOWED_DTYPES:
for x in dtypes.ALLOWED_DTYPES:
if x not in ["string", "uint64"]:
x = str(to_torch_dtype(x)).split(".")[-1]
if x not in ALL_DTYPES: # skip duplicates created by remapping
ALL_DTYPES.append(x)
ALL_DTYPES += [None]
else:
ALL_DTYPES = [x for x in ALLOWED_DTYPES if x != "string"] + [None]
ALL_DTYPES = [x for x in dtypes.ALLOWED_DTYPES if x != "string"] + [
None
]
# Remove float8 dtypes for the following tests
ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]

def setUp(self):
from jax.experimental import enable_x64
@@ -217,3 +220,13 @@ def test_least_upper_bound_with_no_common_upper_bound(self):
ValueError, "no available implicit dtype promotion path"
):
dtypes._least_upper_bound("test_dtype1", "test_dtype2")

def test_invalid_float8_dtype(self):
with self.assertRaisesRegex(
ValueError, "There is no implicit conversions from float8 dtypes"
):
dtypes.result_type("float8_e4m3fn", "bfloat16")
with self.assertRaisesRegex(
ValueError, "There is no implicit conversions from float8 dtypes"
):
dtypes.result_type("float8_e5m2", "bfloat16")
32 changes: 3 additions & 29 deletions keras/backend/common/variables.py
@@ -2,6 +2,7 @@

from keras.api_export import keras_export
from keras.backend import config
from keras.backend.common import dtypes
from keras.backend.common import global_state
from keras.backend.common.name_scope import current_path
from keras.backend.common.stateless_scope import get_stateless_scope
@@ -397,40 +398,13 @@ def initialize_all_variables():
global_state.set_global_attribute("uninitialized_variables", [])


ALLOWED_DTYPES = {
"float16",
"float32",
"float64",
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"bfloat16",
"bool",
"string",
}

PYTHON_DTYPES_MAP = {
bool: "bool",
int: "int64" if config.backend() == "tensorflow" else "int32",
float: "float32",
str: "string",
# special case for string value
"int": "int64" if config.backend() == "tensorflow" else "int32",
}


@keras_export(
["keras.utils.standardize_dtype", "keras.backend.standardize_dtype"]
)
def standardize_dtype(dtype):
if dtype is None:
return config.floatx()
dtype = PYTHON_DTYPES_MAP.get(dtype, dtype)
dtype = dtypes.PYTHON_DTYPES_MAP.get(dtype, dtype)
if hasattr(dtype, "name"):
dtype = dtype.name
elif hasattr(dtype, "__str__") and (
@@ -440,7 +414,7 @@ def standardize_dtype(dtype):
elif hasattr(dtype, "__name__"):
dtype = dtype.__name__

if dtype not in ALLOWED_DTYPES:
if dtype not in dtypes.ALLOWED_DTYPES:
raise ValueError(f"Invalid dtype: {dtype}")
return dtype

4 changes: 2 additions & 2 deletions keras/backend/common/variables_test.py
@@ -4,7 +4,7 @@

from keras import backend
from keras import initializers
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.backend.common import dtypes
from keras.backend.common.variables import AutocastScope
from keras.backend.common.variables import KerasVariable
from keras.backend.common.variables import shape_equal
@@ -156,7 +156,7 @@ def test_autocasting(self):
self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32")

@parameterized.parameters(
*((dtype for dtype in ALLOWED_DTYPES if dtype != "string"))
*((dtype for dtype in dtypes.ALLOWED_DTYPES if dtype != "string"))
)
def test_standardize_dtype(self, dtype):
"""Tests standardize_dtype for all ALLOWED_DTYPES except string."""
4 changes: 2 additions & 2 deletions keras/backend/tensorflow/nn.py
@@ -40,8 +40,8 @@ def softsign(x):
return tf.nn.softsign(x)


def silu(x, beta=1.0):
return tf.nn.silu(x, beta=beta)
def silu(x):
return tf.nn.silu(x)


def log_sigmoid(x):
2 changes: 2 additions & 0 deletions keras/backend/torch/core.py
@@ -42,6 +42,8 @@
"int64": torch.int64,
"bfloat16": torch.bfloat16,
"bool": torch.bool,
"float8_e4m3fn": torch.float8_e4m3fn,
"float8_e5m2": torch.float8_e5m2,
}


4 changes: 2 additions & 2 deletions keras/backend/torch/nn.py
@@ -47,9 +47,9 @@ def softsign(x):
return tnn.softsign(x)


def silu(x, beta=1.0):
def silu(x):
x = convert_to_tensor(x)
return x * sigmoid(beta * x)
return tnn.silu(x)


def log_sigmoid(x):
5 changes: 4 additions & 1 deletion keras/export/export_lib.py
@@ -320,7 +320,10 @@ def stateless_fn(variables, *args, **kwargs):

def stateful_fn(*args, **kwargs):
return jax2tf_stateless_fn(
self._tf_trackable.variables, *args, **kwargs
# Change the trackable `ListWrapper` to a plain `list`
list(self._tf_trackable.variables),
*args,
**kwargs,
)

# Note: we truncate the number of parameters to what is