In [1]:
import tensorflow as tf 
import tensorflow_datasets as tfds
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax
import flax.nnx
import optax
from functools import partial

2025-03-14 20:04:41.276239: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-14 20:04:41.291087: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741982681.307821    9323 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741982681.313088    9323 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1741982681.325136    9323 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
class cnn(flax.nnx.Module):
    def __init__(self, 
                 regularizer = 1e-4,
                 *args, 
                 **kwargs):
      
        self.rngs = flax.nnx.Rngs(3)
        self.regularizer = regularizer
        self.loss = 0.
        self.layers = []
        
        self.conv1 = flax.nnx.Conv(1, 32, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 16, 16,32]
        self.conv2 = flax.nnx.Conv(32, 64, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 8, 8, 64]
        self.conv3 = flax.nnx.Conv(64, 128, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 4, 4, 128]
        self.conv4 = flax.nnx.Conv(128, 256, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 2, 2, 256]
        self.conv5 = flax.nnx.Conv(256, 10, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 1, 1, 10]
    
        self.layers = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
    
    @flax.nnx.jit
    def __call__(self, x):
        conv1 = flax.nnx.relu(self.conv1(x))
        conv2 = flax.nnx.relu(self.conv2(conv1))
        conv3 = flax.nnx.relu(self.conv3(conv2))
        conv4 = flax.nnx.relu(self.conv4(conv3))
        out = self.conv5(conv4)
        
        return out.reshape(-1, 10)
    
    # addtional loss
    def kernel_bias_L2regularization(self):
        self.loss = 0
        for layer in self.layers:
            # weights regularization
            self.loss += (layer.kernel.value.sum() + layer.bias.value.sum()) ** 2 * self.regularizer
        return self.loss
    
model = cnn()
# flax.nnx.display(model)


In [3]:
# Load MNIST dataset
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

# Define a function to normalize the images
def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255.0, label

# Prepare the training dataset
ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE).repeat()
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(32, drop_remainder=True)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

# Prepare the test dataset
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE).repeat()
ds_test = ds_test.batch(32, drop_remainder=True)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

# initialize training and test dataset iterator
tr_iter = iter(ds_train)
ts_iter = iter(ds_test)


2025-03-14 20:04:51.799162: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
I0000 00:00:1741982691.799252    9323 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1747 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050 Ti Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


In [4]:
###! 不知道為什麼，loss function最前面一定要放model，不然會報錯。例如放(images,label,model)，把model放最後一個就會出錯
def loss_fn(model, x, y):
    y_hat_logits = model(jnp.array(x))
    ce = optax.losses.softmax_cross_entropy_with_integer_labels(logits=y_hat_logits, labels=jnp.array(y)).mean()
    reg = model.kernel_bias_L2regularization()
    return ce + reg, y_hat_logits

# @flax.nnx.jit
def update_weights(model, opt, tr_iter, innerSteps=100):
    for innerStep in range(innerSteps):
        pics, labels = next(tr_iter)
        # 如果loss function有回傳loss以外的東西就要把has_aux打開，
        # value_and_grad的回傳規則是 [(loss_fn Arg1, loss_fn Arg2, ...), gradient]
        grad_fn = flax.nnx.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grads = grad_fn(model, pics, labels)
        opt.update(grads)
    return loss, logits
    


In [5]:
learningRate = 1e-4
trainingStep = 10
opt = flax.nnx.Optimizer(model, optax.adamw(learningRate))

# 如果loss function有回傳loss以外的東西就要把has_aux打開，
# value_and_grad的回傳規則是 [(loss_fn Arg1, loss_fn Arg2, ...), gradient]
grad_fn = flax.nnx.value_and_grad(loss_fn, has_aux=False)

for step in range(trainingStep):
    # optimizing model
    loss, logits = update_weights(model, opt, tr_iter)
    # using test dataset to validate the model with accuracy
    pics, labels = next(ts_iter)
    loss, logits = loss_fn(model, pics, labels)
    accuracy = jnp.mean(jnp.argmax(logits, axis=1) == jnp.array(labels))
    
    print("step {} loss:{} val acc:{}".format(step * 100, loss, accuracy))





2025-03-14 20:05:14.652803: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:387] The default buffer size is 262144, which is overridden by the user specified `buffer_size` of 8388608


ValueError: Arrays leaves are not supported, at 'loss': Traced<ShapedArray(float32[])>with<JVPTrace> with
  primal = Array(0.0186766, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f42c03c0c10>, in_tracers=(Traced<ShapedArray(float32[]):JaxprTrace>, Traced<ShapedArray(float32[]):JaxprTrace>), out_tracer_refs=[<weakref at 0x7f42c03d7bf0; to 'JaxprTracer' at 0x7f42c03d7b50>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[][39m b[35m:f32[][39m. [34m[22m[1mlet[39m[22m[22m c[35m:f32[][39m = add a b [34m[22m[1min [39m[22m[22m(c,) }, 'in_shardings': (UnspecifiedValue, UnspecifiedValue), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None, None), 'out_layouts': (None,), 'resource_env': None, 'donated_invars': (False, False), 'name': 'add', 'keep_unused': False, 'inline': True, 'compiler_options_kvs': ()}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 0x7f42c03e2860>, ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types={}), xla_metadata=None))