# Image classification with ConvMixer

<hr />

### Introduction

Vision Transformers (ViT; <a href="https://arxiv.org/abs/2010.11929">Dosovitskiy et al.</a>) extract small patches from the input images, linearly project them, and then apply Transformer (<a href="https://arxiv.org/abs/1706.03762">Vaswani et al.</a>) blocks. The application of ViTs to image recognition tasks is quickly becoming a promising area of research, because ViTs eliminate the need for strong inductive biases (such as convolutions) for modeling locality. This presents them as a general computation primitive capable of learning from the training data alone, with as few inductive priors as possible. ViTs yield great downstream performance when trained with proper regularization, data augmentation, and relatively large datasets.
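
To make the patch-extraction step concrete, here is a minimal pure-Python sketch. It is not the ViT implementation: it only splits a toy single-channel image into non-overlapping patches and flattens each one, which is the kind of input a linear projection would then receive. The helper `extract_patches`, the 4x4 toy image, and the patch size of 2 are illustrative assumptions.

```python
def extract_patches(image, patch_size):
    """Split an H x W image (nested lists) into flattened, non-overlapping patches."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    patches = []
    # Walk over the top-left corner of every patch, row by row.
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [
                image[row][col]
                for row in range(top, top + patch_size)
                for col in range(left, left + patch_size)
            ]
            patches.append(patch)
    return patches


# Toy 4x4 image whose pixel values are 0..15 (hypothetical data).
toy_image = [[row * 4 + col for col in range(4)] for row in range(4)]
toy_patches = extract_patches(toy_image, patch_size=2)
print(len(toy_patches), toy_patches[0])
```

Each flattened patch plays the role of one token. A real model would multiply every patch by a learned weight matrix to obtain its embedding.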

In the <a href="https://openreview.net/pdf?id=TVHS5Y4dNvM">Patches Are All You Need</a> paper (note: at the time of writing, it is a submission to the ICLR 2022 conference), the authors extend the idea of using patches to train an all-convolutional network and demonstrate competitive results. Their architecture, named <b>ConvMixer</b>, uses recipes from recent isotropic architectures like ViT and MLP-Mixer (<a href="https://arxiv.org/abs/2105.01601">Tolstikhin et al.</a>), such as using the same depth and resolution across different layers in the network, residual connections, and so on.

In this example, we will implement the ConvMixer model and demonstrate its performance on the CIFAR-10 dataset.

<hr />

### Imports

In [1]:
import keras
from keras import layers

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np


<hr />

### Hyperparameters

To keep run time short, we will train the model for only 10 epochs. To focus on the core ideas of ConvMixer, we will not use other training-specific elements like RandAugment (<a href="https://arxiv.org/abs/1909.13719">Cubuk et al.</a>). If you are interested in learning more about those details, please refer to the <a href="https://openreview.net/pdf?id=TVHS5Y4dNvM">original paper</a>.

In [2]:
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 10

<hr />

### Load the CIFAR-10 dataset

In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
val_split = 0.1

# Number of samples held out for validation (the first 10% of the training set).
val_indices = int(len(x_train) * val_split)
new_x_train, new_y_train = x_train[val_indices:], y_train[val_indices:]
x_val, y_val = x_train[:val_indices], y_train[:val_indices]

print(f"Training data samples: {len(new_x_train)}")
print(f"Validation data samples: {len(x_val)}")
print(f"Test data samples: {len(x_test)}")

Training data samples: 45000
Validation data samples: 5000
Test data samples: 10000


<hr />

### Prepare <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset">tf.data.Dataset</a> objects

Our data augmentation pipeline is different from what the authors used for the CIFAR-10 dataset, which is fine for the purpose of this example. Note that it's OK to use <b>TF APIs for data I/O and preprocessing</b> with other backends (JAX, PyTorch), since TensorFlow is feature-complete when it comes to data preprocessing.

In [4]:
image_size = 32
auto = tf.data.AUTOTUNE

augmentation_layers = [
    keras.layers.RandomCrop(image_size, image_size),
    keras.layers.RandomFlip("horizontal"),
]


def augment_images(images):
    for layer in augmentation_layers:
        images = layer(images, training=True)
    return images


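def make_datasets(images, labels, is_train=False):
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if is_train:
        # Shuffle individual samples (before batching) with a buffer of 10 batches.
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.batch(batch_size)
    if is_train:
        # Augment whole batches at once; validation and test data stay untouched.
        dataset = dataset.map(
            lambda x, y: (augment_images(x), y), num_parallel_calls=auto
        )
    # Overlap data preparation with model execution.
    return dataset.prefetch(auto)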


train_dataset = make_datasets(new_x_train, new_y_train, is_train=True)
val_dataset = make_datasets(x_val, y_val)
test_dataset = make_datasets(x_test, y_test)
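
As a quick sanity check, the number of batches each dataset yields per epoch can be worked out by hand. The sketch below assumes the batch size of 128 and the split sizes printed earlier, and relies on `Dataset.batch` keeping the final partial batch, which is its default (`drop_remainder=False`). The helper name `batches_per_epoch` is hypothetical.

```python
import math

# Values copied from the cells above; batches_per_epoch is a hypothetical helper.
batch_size = 128
split_sizes = {"train": 45000, "val": 5000, "test": 10000}


def batches_per_epoch(num_samples, batch_size):
    # The last, smaller batch is kept by default, hence the ceiling division.
    return math.ceil(num_samples / batch_size)


steps = {name: batches_per_epoch(n, batch_size) for name, n in split_sizes.items()}
print(steps)  # {'train': 352, 'val': 40, 'test': 79}
```

With 10 epochs, this works out to 3,520 training steps in total.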

