In [1]:
import tensorflow as tf 
import tensorflow_datasets as tfds
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax
import flax.nnx
import optax
from functools import partial

rngs = flax.nnx.Rngs(0)

2025-03-14 13:36:50.773272: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-14 13:36:50.790761: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741959410.805681   38362 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741959410.809826   38362 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1741959410.822193   38362 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [None]:
class cnn(flax.nnx.Module):
    def __init__(self, 
                 regularizer = 1e-4,
                 *args, 
                 **kwargs):
      
        self.rngs = flax.nnx.Rngs(3)
        self.regularizer = regularizer
        self.loss = 0.
        self.layers = []
        
        self.conv1 = flax.nnx.Conv(1, 32, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 16, 16,32]
        self.conv2 = flax.nnx.Conv(32, 64, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 8, 8, 64]
        self.conv3 = flax.nnx.Conv(64, 128, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 4, 4, 128]
        self.conv4 = flax.nnx.Conv(128, 256, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 2, 2, 256]
        self.conv5 = flax.nnx.Conv(256, 10, (3, 3), strides=2, padding='SAME', rngs=self.rngs) # [N, 1, 1, 10]
    
        self.layers = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
    
    @flax.nnx.jit
    def __call__(self, x):
        conv1 = flax.nnx.relu(self.conv1(x))
        conv2 = flax.nnx.relu(self.conv2(conv1))
        conv3 = flax.nnx.relu(self.conv3(conv2))
        conv4 = flax.nnx.relu(self.conv4(conv3))
        out = self.conv5(conv4)
        
        return out.reshape(-1, 10)
    
    # addtional loss
    def kernel_bias_L2regularization(self, conv):
        self.loss = 0
        for layer in self.layers:
            # weights regularization
            self.loss += (layer.kernel.value.sum() + layer.bias.value.sum()) ** 2 * self.regularizer
        return self.loss
    
model = cnn()
# flax.nnx.display(model)


In [3]:
# Load MNIST dataset
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

# Define a function to normalize the images
def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255.0, label

# Prepare the training dataset
ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE).repeat()
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(32, drop_remainder=True)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

# Prepare the test dataset
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE).repeat()
ds_test = ds_test.batch(32, drop_remainder=True)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

# initialize training iterator
tr_iter = iter(ds_train)


2025-03-14 13:36:56.181044: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
I0000 00:00:1741959416.181099   38362 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5560 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4060 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.9


In [4]:
###! 不知道為什麼，loss function最前面一定要放model，不然會報錯。例如放(images,label,model)，把model放最後一個就會出錯
def loss_fn(model, x, y):
    y_hat_logits = model(jnp.array(x))
    ce = optax.losses.softmax_cross_entropy_with_integer_labels(logits=y_hat_logits, labels=jnp.array(y)).mean()
    return ce, y_hat_logits

# @flax.nnx.jit
def update_weights(model, opt, tr_iter):
    pics, labels = next(tr_iter)
    # 如果loss function有回傳loss以外的東西就要把has_aux打開，
    # value_and_grad的回傳規則是 [(loss_fn Arg1, loss_fn Arg2, ...), gradient]
    grad_fn = flax.nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, pics, labels)
    opt.update(grads)
    return loss, logits
    


In [5]:
learningRate = 1e-4
trainingStep = 50
opt = flax.nnx.Optimizer(model, optax.adamw(learningRate))

# 如果loss function有回傳loss以外的東西就要把has_aux打開，
# value_and_grad的回傳規則是 [(loss_fn Arg1, loss_fn Arg2, ...), gradient]
grad_fn = flax.nnx.value_and_grad(loss_fn, has_aux=False)


for step in range(trainingStep):
    for innerStep in range(100):
        loss, logits = update_weights(model, opt, tr_iter)
    
    print("step {} loss:{}".format(step * 100, loss))





2025-03-14 13:36:56.825730: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:387] The default buffer size is 262144, which is overridden by the user specified `buffer_size` of 8388608


step 0 loss:1.0183837413787842
step 100 loss:0.2267848402261734
step 200 loss:0.3425450026988983
step 300 loss:0.38244152069091797
step 400 loss:0.23579832911491394
step 500 loss:0.3750866651535034
step 600 loss:0.19036740064620972
step 700 loss:0.3409792482852936
step 800 loss:0.25898292660713196
step 900 loss:0.10788832604885101
step 1000 loss:0.019472802057862282
step 1100 loss:0.05721774697303772
step 1200 loss:0.28923264145851135
step 1300 loss:0.08502659201622009
step 1400 loss:0.2867080569267273
step 1500 loss:0.16210684180259705
step 1600 loss:0.0771888941526413
step 1700 loss:0.11599244177341461
step 1800 loss:0.10030551254749298
step 1900 loss:0.20300903916358948
step 2000 loss:0.014317288994789124
step 2100 loss:0.3251025974750519
step 2200 loss:0.07950488477945328
step 2300 loss:0.04792633652687073
step 2400 loss:0.10971150547266006
step 2500 loss:0.031338781118392944
step 2600 loss:0.040094152092933655
step 2700 loss:0.17148299515247345
step 2800 loss:0.05137480050325394
s

KeyboardInterrupt: 