In [None]:
import os

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds


In [None]:
# Load both CIFAR-10 splits plus the DatasetInfo (used below for split sizes,
# the image shape and the label feature).
(train_ds, test_ds), info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    with_info=True,       # also return the DatasetInfo object as `info`
    as_supervised=True,   # yield (image, label) tuples instead of feature dicts
    shuffle_files=True,
)


In [None]:
BATCH_SIZE = 64


def normalize_img(image, label):
    """Convert an image from `uint8` to `float32` and scale it by 1/255.

    Assumes `image` holds `uint8` pixel values, so the result lies in [0, 1].
    `label` is returned unchanged.
    """
    scaled = tf.cast(image, tf.float32) / 255.
    return scaled, label


# Training split: normalize, cache the normalized examples, shuffle with a
# buffer covering the whole split, then batch and prefetch.
train_ds = (
    train_ds
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(info.splits['train'].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Test split: no shuffling; batches are cached after normalization.
test_ds = (
    test_ds
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
L2_REG = 0
KERNEL_SIZE = 3
STRIDE_SIZE = 1


def _l2_regularizer():
    """Return a new L2 kernel regularizer with strength ``L2_REG``."""
    return tf.keras.regularizers.l2(L2_REG)


def _conv_relu(filters):
    """Same-padded ReLU Conv2D using the shared KERNEL_SIZE / STRIDE_SIZE / L2 settings."""
    return tf.keras.layers.Conv2D(
        filters=filters,
        kernel_size=KERNEL_SIZE,
        strides=STRIDE_SIZE,
        padding='same',
        activation='relu',
        kernel_regularizer=_l2_regularizer(),
    )


# Two conv/pool stages followed by a small dense head. The last Dense layer has
# no activation because the loss below is configured with from_logits=True.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=info.features['image'].shape),  # type: ignore
    _conv_relu(64),
    tf.keras.layers.MaxPool2D(),
    _conv_relu(128),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        64,
        activation='relu',
        kernel_regularizer=_l2_regularizer(),
    ),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(10, kernel_regularizer=_l2_regularizer()),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)


In [None]:
import os

# One log directory per hyper-parameter combination, named after the settings
# used to build the model (presumably so TensorBoard runs can be compared).
run_name = f'filter_{KERNEL_SIZE}_stride_{STRIDE_SIZE}_l2_{L2_REG}'
logdir = os.path.join('logs', 'cifar10', run_name)
os.makedirs(logdir, exist_ok=True)

In [None]:
# Save the model architecture as JSON next to the TensorBoard logs.
with open(os.path.join(logdir, 'model.json'), 'w', encoding='utf-8') as f:
    print(model.to_json(), file=f)


def log_summary(string):
    """Print ``string`` and append it, newline-terminated, to ``<logdir>/summary.txt``.

    Intended as the ``print_fn`` of ``model.summary`` so the summary is both
    shown in the notebook and saved alongside the other run artifacts.
    """
    print(string)
    with open(os.path.join(logdir, 'summary.txt'), 'a', encoding='utf-8') as f:
        print(string, file=f)


# log_summary appends, so truncate any summary left by an earlier execution of
# this cell; otherwise re-running it would accumulate duplicate summaries.
with open(os.path.join(logdir, 'summary.txt'), 'w', encoding='utf-8'):
    pass
model.summary(print_fn=log_summary)

In [None]:
EPOCHS = 100

tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=logdir,
    histogram_freq=1,  # write weight histograms every epoch
    write_graph=True,
)

# Keep the History object so per-epoch metrics can be inspected or plotted
# later (assigning it also avoids displaying its bare repr as cell output).
# NOTE(review): the test split doubles as validation data, so the reported
# val_* metrics are not an untouched held-out estimate.
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=test_ds,
    callbacks=[tensorboard_callback],
)


In [None]:
import matplotlib.pyplot as plt

# Display the first test image the model classifies correctly.
class_names = info.features['label'].names  # type: ignore
for image, label in test_ds.unbatch():
    # Add a batch dimension of 1 and run the model in inference mode.
    logits = model(image[tf.newaxis, ...], training=False)[0]
    pred = int(tf.argmax(logits))
    label = int(label)

    if pred == label:
        print(f'pred: {pred}, label: {label}')
        # Images are RGB, so no colormap is applied (matplotlib ignores cmap for RGB).
        plt.imshow(image.numpy())
        plt.title(f'correct: {class_names[pred]}')
        plt.axis('off')
        break
else:
    # Loop exhausted without a match: say so instead of silently showing nothing.
    print('No correctly classified example found in test_ds.')


In [None]:
# Display the first test image the model gets wrong, with both class names.
class_names = info.features['label'].names  # type: ignore
for image, label in test_ds.unbatch():
    # Add a batch dimension of 1 and run the model in inference mode.
    logits = model(image[tf.newaxis, ...], training=False)[0]
    pred = int(tf.argmax(logits))
    label = int(label)

    if pred != label:
        print(f'pred: {pred}, label: {label}')
        # Images are RGB, so no colormap is applied (matplotlib ignores cmap for RGB).
        plt.imshow(image.numpy())
        plt.title(f'pred: {class_names[pred]}, label: {class_names[label]}')
        plt.axis('off')
        break
else:
    # Loop exhausted without a mismatch: say so instead of silently showing nothing.
    print('No misclassified example found in test_ds.')

In [None]:
FEATURE_MAP_COLS = 8  # subplot columns per feature-map figure

# Take one training example and show it before its feature maps.
sample_data, sample_label = next(iter(train_ds.unbatch().take(1)))
sample_data = sample_data[tf.newaxis, ...]  # add a batch dimension of 1
plt.imshow(sample_data.numpy().squeeze())
plt.title(info.features['label'].names[int(sample_label)])  # type: ignore
plt.axis('off')
plt.show()  # render now so the input appears above the feature maps
print(sample_label.numpy())

# Push the sample through the layers one at a time, keeping each Conv2D output.
# A separate name is used so `sample_data` still holds the input afterwards.
features = {}
activations = sample_data
for layer in model.layers:
    activations = layer(activations)
    if isinstance(layer, tf.keras.layers.Conv2D):
        features[layer.name] = activations

for name, feature in features.items():
    print(name)
    print(feature.shape)

    # Size the grid from the channel count instead of a fixed 16x8 grid, which
    # left empty rows for 64 channels and would fail beyond 128.
    n_channels = feature.shape[-1]
    n_rows = -(-n_channels // FEATURE_MAP_COLS)  # ceiling division
    figure, axes = plt.subplots(
        n_rows,
        FEATURE_MAP_COLS,
        figsize=(15, 15 * n_rows / FEATURE_MAP_COLS),
        squeeze=False,
    )
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i < n_channels:
            ax.imshow(feature[0, :, :, i].numpy(), cmap='gray')
    figure.suptitle(name)
    plt.show()
