Set up imports. MLX only runs on Apple silicon, so make sure you are running this notebook on a Mac with macOS. Also make sure all dependencies listed in `requirements.txt` are installed.

In [10]:
import glob

# Core MLX imports
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# MLX data loading library
import mlx.data as dx
from mlx.data.datasets import load_mnist

# Visualization Tools
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# import torchvision for loading datasets and transformations
import torchvision
import torchvision.transforms as transforms

Define the data-loading code: load MNIST and build resized, normalized, batched train and test streams.

In [18]:
# Load MNIST and build resized, normalized, batched train/test streams
# (no augmentation is applied — only resizing and scaling to [0, 1]).
def _to_normalized_batches(buffer, img_size, batch_size):
    """Turn an ``mlx.data`` buffer into a resized, normalized, batched stream.

    Shared by the train and test pipelines so both always receive exactly the
    same preprocessing.

    Args:
        buffer: an ``mlx.data`` buffer whose samples carry an ``"image"`` key.
        img_size: target height and width passed to ``image_resize``.
        batch_size: number of samples per batch.

    Returns:
        A prefetching stream of batches whose ``"image"`` values are float32,
        divided by 255 (presumably raw uint8 pixels, giving [0, 1] — confirm).
    """
    return (
        buffer
        .to_stream()
        .image_resize("image", h=img_size, w=img_size)
        # Scale pixel values to floats; 255.0 is the max of a uint8 pixel.
        .key_transform("image", lambda x: x.astype("float32") / 255.0)
        .batch(batch_size)
        # Background prefetching: prefetch_size=4, num_threads=2.
        .prefetch(4, 2)
    )


def load_transform_MNIST(img_size, batch_size):
    """Load MNIST and return ``(train_dataloader, test_dataloader)`` streams.

    Both streams resize images to ``img_size`` x ``img_size``, scale pixel
    values by 1/255 to float32, batch them, and prefetch in the background.

    Args:
        img_size: target square image side length (28 keeps native MNIST size).
        batch_size: number of samples per batch.

    Returns:
        Tuple ``(train_dataloader, test_dataloader)``. Only the training data
        is shuffled; the test data keeps its original order.

    NOTE(review): the training buffer is shuffled once, when this function is
    called. Confirm whether resetting the stream between epochs reshuffles it,
    or rebuild the loaders per epoch if a fresh order is needed.
    """
    # Load the raw data.
    mnist_train = load_mnist(train=True)
    mnist_test = load_mnist(train=False)

    # Build the streams with identical preprocessing; only train is shuffled.
    train_dataloader = _to_normalized_batches(mnist_train.shuffle(), img_size, batch_size)
    test_dataloader = _to_normalized_batches(mnist_test, img_size, batch_size)

    return train_dataloader, test_dataloader

Set up constants, load the data, and check that Metal (GPU) acceleration is available.

In [19]:
# 28x28 pixel images for MNIST
IMG_SIZE = 28
# 1 channel for grayscale images
IMG_CH = 1
# 128 images per batch
BATCH_SIZE = 128
# 10 output classes for digits 0-9
NUM_CLASSES = 10
# Backward-compatible alias: the old name read like an *input* count, but the
# value is the number of output classes. Prefer NUM_CLASSES in new code.
IN_CLASSES = NUM_CLASSES

train_loader, test_loader = load_transform_MNIST(IMG_SIZE, BATCH_SIZE)

# Check for Metal availability (kept as the last expression so it is displayed)
mx.metal.is_available()

True