# Load MNIST

In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np

# Fetch the MNIST train/test splits (uint8 images, integer labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale pixel intensities from [0, 255] to [0.0, 1.0] as float32.
x_train, x_test = (split.astype("float32") / 255.0 for split in (x_train, x_test))

# Apply a 2D FFT to each image (centered log-magnitude spectrum)

In [None]:
def _log_magnitude_spectrum(batch):
    """Return log(1 + |fftshift(fft2(batch))|) computed over the last two axes."""
    # Both fft2 and fftshift must be restricted to the spatial axes: fftshift's
    # default (all axes) would also roll the batch axis and scramble image order.
    spectrum = np.fft.fftshift(np.fft.fft2(batch, axes=(-2, -1)), axes=(-2, -1))
    # log1p(x) == log(1 + x) without precision loss for small magnitudes;
    # it compresses the spectrum's large dynamic range.
    return np.log1p(np.abs(spectrum))


def fft_transform(images, batch_size=4096):
    """Return the centered log-magnitude 2D spectrum of each image.

    For every image the 2D FFT is taken over its last two axes, low
    frequencies are shifted to the center, and ``log(1 + |F|)`` is returned.

    Args:
        images: Array-like of shape ``(N, H, W)`` (a stack of images) or a
            single ``(H, W)`` image.
        batch_size: Number of images transformed per vectorized call. Bounds
            the size of the intermediate complex arrays for large datasets.

    Returns:
        A real-valued ndarray with the same shape as ``images``.

    Raises:
        ValueError: If ``images`` has fewer than 2 dimensions or
            ``batch_size`` is less than 1.
    """
    images = np.asarray(images)
    if images.ndim < 2:
        raise ValueError(
            f"expected images with at least 2 dimensions, got shape {images.shape}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if images.ndim == 2:
        # A single (H, W) image: there is no batch axis to chunk over.
        return _log_magnitude_spectrum(images)
    # One vectorized FFT per chunk instead of a Python-level loop per image,
    # while keeping peak memory bounded by batch_size.
    chunks = [
        _log_magnitude_spectrum(images[start:start + batch_size])
        for start in range(0, len(images), batch_size)
    ]
    if not chunks:
        # Empty stack: keep the (0, H, W) shape rather than collapsing to (0,).
        return np.empty(images.shape, dtype=np.float64)
    return np.concatenate(chunks)

# Transform each split to its spectrum and append a trailing channel axis so the
# arrays are (samples, height, width, 1), matching the CNN's single-channel input.
x_train_fft = np.expand_dims(fft_transform(x_train), axis=-1)
x_test_fft = np.expand_dims(fft_transform(x_test), axis=-1)

print("FFT shape:", x_train_fft.shape)


# Build a CNN on FFT images

In [None]:
from tensorflow.keras import layers, models

# Take the per-sample shape from the prepared data instead of hard-coding it,
# so the model stays in sync with the FFT/reshape step above.
input_shape = x_train_fft.shape[1:]

model = models.Sequential([
    # Declare the input explicitly; passing `input_shape=` to the first layer
    # is deprecated in Keras 3 and emits a warning.
    layers.Input(shape=input_shape),
    layers.Conv2D(32, (3, 3), activation="relu"),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation="relu"),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation="relu"),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.5),
    layers.Dense(10, activation="softmax"),
])

# Labels are the integer class ids from mnist.load_data() (never one-hot
# encoded here), hence the sparse variant of categorical cross-entropy.
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# validation_split holds out the last 10% of the training samples for validation.
history = model.fit(x_train_fft, y_train, epochs=10, batch_size=128, validation_split=0.1)


# Evaluate

In [None]:
# Score the trained model on the held-out test split. evaluate() returns
# [loss, accuracy] because the model was compiled with metrics=["accuracy"].
test_loss, test_acc = model.evaluate(x_test_fft, y_test)
print("Test accuracy:", test_acc)