In [1]:
import os

# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax
import jax.numpy as jnp

# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np

In [2]:
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


# model = get_model()

# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

model = get_model()

2024-04-07 06:27:02.992570: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.


In [3]:
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Instantiate an optimizer.
# optimizer = keras.optimizers.Adam(learning_rate=1e-3)
optimizer = keras.optimizers.AdamW(learning_rate=1e-3, clipnorm=1.0)

In [4]:
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables


In [5]:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

In [6]:
@jax.jit
def train_step(state, data):
    # 不使用jit會有錯誤發生而無法訓練，目前原因未知
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

In [7]:
# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

Training loss (for 1 batch) at step 0: 111.7812
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 1.6654
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 1.1273
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.6074
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.6776
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2396
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.6327
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.2158
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3267
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.6049
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.3289
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.2887
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.3123
Seen so far: 38432 samples
Training loss (for 1 batch) 

2024-04-07 06:27:31.601270: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [8]:
# 透過Stateless訓練的結果只會儲存在state的變數當中，需要把這些算完的state放回到模型中，
# 才能使用keras.model的物件進行操作(例如存檔等等的行為)

print("before assign:")
print(model(np.ones([1,784])))

trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
    
print("after assign:")
print(model(np.ones([1,784])))

before assign:
[[-0.56451714 -0.41317883 -1.0507475   0.86134243 -0.3582103   0.69896895
   0.5061509  -0.8784539  -0.22967747 -0.57622534]]
after assign:
[[-0.20653455 -0.3140297   0.12025344  0.12553349  0.07457571 -0.01390003
  -0.11240885 -0.17743774  0.4671492   0.02217555]]
