In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt

2025-11-06 15:21:34.509916: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-06 15:21:34.547115: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-06 15:21:35.421092: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Run layers in float16 while keeping variables in float32 (Keras mixed precision).
MIXED_PRECISION_POLICY_NAME = 'mixed_float16'
policy = tf.keras.mixed_precision.Policy(MIXED_PRECISION_POLICY_NAME)
tf.keras.mixed_precision.set_global_policy(policy)

In [3]:
# Cap the memory TensorFlow may allocate on each visible GPU.
GPU_MEMORY_LIMIT_MB = 3500

gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        # Configure *this* GPU (the original code reconfigured gpus[0] on every
        # iteration, leaving any other GPU uncapped).
        tf.config.set_logical_device_configuration(
            gpu,
            [tf.config.LogicalDeviceConfiguration(memory_limit=GPU_MEMORY_LIMIT_MB)],
        )
    except RuntimeError as err:
        # Logical devices can only be configured before the GPU runtime is
        # initialized; re-running this cell later hits this path instead of crashing.
        print(f"Could not set memory limit on {gpu.name}: {err}")

In [4]:
def NiN_blk(out_channels, kernel_size, stride, padding):
    """Build one NiN block: a spatial conv followed by two 1x1 convs, each with ReLU.

    Args:
        out_channels: number of filters used by all three convolutions.
        kernel_size: kernel size of the first (spatial) convolution.
        stride: stride of the first convolution.
        padding: padding mode ('valid' or 'same') of the first convolution.

    Returns:
        A ``tf.keras.models.Sequential`` containing the six layers.
    """
    block_layers = [
        tf.keras.layers.Conv2D(filters=out_channels, kernel_size=kernel_size,
                               strides=stride, padding=padding),
        tf.keras.layers.ReLU(),
    ]
    # Two 1x1 convolutions act as a per-pixel MLP across channels.
    for _ in range(2):
        block_layers.append(tf.keras.layers.Conv2D(filters=out_channels, kernel_size=1))
        block_layers.append(tf.keras.layers.ReLU())
    return tf.keras.models.Sequential(block_layers)

In [5]:
class NiN(tf.keras.Model):
    """Network-in-Network classifier built from ``NiN_blk`` stages.

    There are no dense layers: the last NiN block emits one channel per class and
    global average pooling reduces each channel to a single per-class score.

    Args:
        num_classes: number of output classes (channels of the final NiN block).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.model = tf.keras.models.Sequential([
            NiN_blk(96, 11, 4, 'valid'),
            tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
            NiN_blk(256, 5, 2, 'same'),
            tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
            NiN_blk(384, 3, 1, 'same'),
            tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
            tf.keras.layers.Dropout(0.5),
            NiN_blk(num_classes, kernel_size=3, stride=1, padding='same'),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Flatten(),
            # When a mixed-precision global policy is active (this notebook sets
            # 'mixed_float16'), the layers above compute in float16. Emit the
            # logits in float32 so the cross-entropy loss is numerically stable;
            # without such a policy this layer is a no-op.
            tf.keras.layers.Activation('linear', dtype='float32'),
        ])

    def call(self, x, training=None):
        """Run the stacked layers; ``training`` toggles Dropout explicitly."""
        return self.model(x, training=training)

In [6]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

def preprocess(image, label):
    """Scale uint8 pixels to [0, 1], add a channel axis, and upsample to 224x224.

    Accepts a single image (H, W) or a batch (B, H, W): ``expand_dims(-1)`` adds
    the trailing channel axis and ``tf.image.resize`` takes 3-D or 4-D input.
    """
    # Cast to float32 rather than float16: tf.image.resize returns float32 for
    # its default (bilinear) method anyway, so an earlier float16 cast only
    # loses precision. Keras layers cast inputs to their compute dtype themselves.
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.expand_dims(image, axis=-1)
    image = tf.image.resize(image, [224, 224])
    return image, label

batch_size = 128

def make_dataset(images, labels, shuffle):
    """Build a batched, preprocessed, prefetched tf.data pipeline.

    Args:
        images: array of raw images, first axis indexes examples.
        labels: array of labels aligned with ``images``.
        shuffle: whether to shuffle examples (reshuffled every epoch).
    """
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if shuffle:
        # Shuffle the small raw images over the whole split *before* resizing,
        # instead of buffering 1024 already-upsampled 224x224 images.
        ds = ds.shuffle(buffer_size=len(images))
    # Batch first, then resize whole batches in parallel (vectorized map).
    return (ds.batch(batch_size)
              .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
              .prefetch(tf.data.AUTOTUNE))

train_ds = make_dataset(x_train, y_train, shuffle=True)
# Evaluation metrics do not depend on example order, so the test split is not shuffled.
test_ds = make_dataset(x_test, y_test, shuffle=False)

I0000 00:00:1762438911.051967   19491 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3500 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050 Ti Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


In [7]:
# NOTE(review): this rebinds `NiN` from the class to a model instance, so the
# class is no longer reachable and re-running this cell fails. The later cells
# use `NiN` as the instance; a rename (e.g. `nin_model`) must be applied to all
# of them together.
NiN = NiN(10)
# Leave the batch dimension unspecified (None) instead of hard-coding 128: the
# final batch of each split is smaller, and the summary should not imply a
# fixed batch size.
NiN.model.build((None, 224, 224, 1))
NiN.model.summary()

In [8]:
EPOCHS = 10

NiN.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    # The network ends in ReLU + global average pooling with no softmax, so its
    # outputs are treated as unnormalized logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Keep the History (per-epoch loss/accuracy) so training curves can be plotted
# instead of only displaying its repr.
# NOTE(review): the test split doubles as validation data here, so the reported
# val_* metrics are not an independent held-out estimate.
history = NiN.fit(train_ds, epochs=EPOCHS, validation_data=test_ds)

Epoch 1/10


2025-11-06 15:23:33.023390: E tensorflow/core/util/util.cc:131] oneDNN supports DT_HALF only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.
2025-11-06 15:23:33.106097: I external/local_xla/xla/service/service.cc:163] XLA service 0x7fc4e000aac0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-11-06 15:23:33.106117: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 3050 Ti Laptop GPU, Compute Capability 8.6
2025-11-06 15:23:33.135787: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-11-06 15:23:33.387175: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91301


[1m  3/469[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m22s[0m 47ms/step - accuracy: 0.0859 - loss: 2.3026   

I0000 00:00:1762439022.260633   19560 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 66ms/step - accuracy: 0.1822 - loss: 2.3003 - val_accuracy: 0.1915 - val_loss: 2.2958
Epoch 2/10
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 44ms/step - accuracy: 0.2313 - loss: 2.2752 - val_accuracy: 0.1919 - val_loss: 2.2100
Epoch 3/10
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 44ms/step - accuracy: 0.3314 - loss: 1.7788 - val_accuracy: 0.4921 - val_loss: 1.3374
Epoch 4/10
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 45ms/step - accuracy: 0.6199 - loss: 1.0385 - val_accuracy: 0.5808 - val_loss: 1.0541
Epoch 5/10
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 45ms/step - accuracy: 0.7062 - loss: 0.7754 - val_accuracy: 0.7230 - val_loss: 0.7345
Epoch 6/10
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 45ms/step - accuracy: 0.7497 - loss: 0.6759 - val_accuracy: 0.5369 - val_loss: 1.4476
Epoch 7/10
[1m469/469[0m 

<keras.src.callbacks.history.History at 0x7fc6140b6ae0>

In [9]:
# Loss and accuracy on the test split, returned as [loss, accuracy] in the order
# of the loss and the compile() metrics.
NiN.evaluate(x=test_ds)

[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 16ms/step - accuracy: 0.7876 - loss: 0.5525


[0.552453875541687, 0.7875999808311462]