In [1]:
import pretty_errors
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
from matplotlib.pylab import f
import tensorflow as tf
import numpy as np
import cvnn.layers as complex_layers
import matplotlib.pyplot as plt
from cvnn.activations import modrelu, zrelu, crelu, cart_softmax
from cvnn.losses import ComplexAverageCrossEntropy


def get_model() -> tf.keras.Model:
    """Build the complex-valued fully connected MNIST classifier.

    Architecture, in order:
      * ``ComplexInput`` declaring a (28, 28, 1) ``tf.complex64`` input,
      * ``ComplexFlatten``,
      * three ``ComplexDense`` hidden layers of width 128, 256 and 128 with ``modrelu``,
      * a 10-unit ``ComplexDense`` output layer with ``cart_softmax``.

    Every dense layer uses a bias and ``tf.complex64`` weights.

    Returns:
        tf.keras.Model: an uncompiled ``Sequential`` model.
    """
    network = tf.keras.models.Sequential()

    hidden_widths = (128, 256, 128)
    stack = [
        complex_layers.ComplexInput(input_shape=(28, 28, 1), dtype=tf.complex64),
        complex_layers.ComplexFlatten(),
    ]
    stack.extend(
        complex_layers.ComplexDense(width, activation=modrelu, use_bias=True, dtype=tf.complex64)
        for width in hidden_widths
    )
    stack.append(
        complex_layers.ComplexDense(10, activation=cart_softmax, use_bias=True, dtype=tf.complex64)
    )

    for layer in stack:
        network.add(layer)
    return network

def load_complex_dataset(x_train, y_train, x_test, y_test, chunk_size: int = 4096):
    """Apply a centered 2D Discrete Fourier Transform (DFT) to every image of a train/test split.

    Nothing is loaded from disk here: the caller passes image arrays already in
    memory (in this notebook, the output of ``tf.keras.datasets.mnist.load_data()``).
    Each image is transformed with ``np.fft.fft2`` over its last two axes, and the
    spectrum is re-centered with ``np.fft.fftshift`` so the zero-frequency (DC)
    component sits in the middle of the image. Labels are passed through unchanged.

    Args:
        x_train (array-like): Training images, shape (num_samples, H, W), e.g. (60000, 28, 28).
        y_train: Labels for the training images; returned as-is.
        x_test (array-like): Test images, shape (num_samples, H, W).
        y_test: Labels for the test images; returned as-is.
        chunk_size (int): Number of images transformed per FFT call. This bounds
            the size of NumPy's temporary FFT arrays without changing the result.

    Returns:
        tuple: ``((x_train_complex, y_train), (x_test_complex, y_test))``. The image
        arrays are ``np.complex64`` and have the same shape as the inputs.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        # A zero step would crash range(); a negative one would silently skip the
        # loop and return an uninitialised buffer.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    return (
        (_centered_dft(x_train, chunk_size), y_train),
        (_centered_dft(x_test, chunk_size), y_test),
    )


def _centered_dft(images, chunk_size):
    """Return the fftshift-ed 2D DFT of each image in ``images`` as a complex64 array."""
    images = np.asarray(images)
    spectra = np.empty(images.shape, dtype=np.complex64)
    for start in range(0, images.shape[0], chunk_size):
        stop = start + chunk_size
        # axes=(-2, -1): transform and shift only the spatial axes, never the sample axis.
        spectrum = np.fft.fft2(images[start:stop], axes=(-2, -1))
        # Assigning into the complex64 buffer performs the downcast (previously tf.cast).
        spectra[start:stop] = np.fft.fftshift(spectrum, axes=(-2, -1))
    return spectra

2025-06-11 13:43:01.420268: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-06-11 13:43:01.833152: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-11 13:43:01.833504: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-11 13:43:01.892462: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-11 13:43:02.007546: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-06-11 13:43:02.010357: I tensorflow/core/platform/cpu_feature_guard.cc:1

In [2]:

(real_images_train, labels_train), (real_images_test, labels_test) = tf.keras.datasets.mnist.load_data()  # real data
# Complex data: centered 2D DFT of every image (labels are passed through and discarded here).
(complex_images_train, _), (complex_images_test, _) = load_complex_dataset(
    real_images_train, labels_train, real_images_test, labels_test
)

print(f"Complex number Ex: {complex_images_train[0, 0, 0]}")

# Convert labels to one-hot encoding
labels_train = tf.keras.utils.to_categorical(labels_train, 10)
labels_test = tf.keras.utils.to_categorical(labels_test, 10)

# Sanity checks: the DFT keeps each image's geometry, and there is one label row per image.
assert complex_images_train.shape == real_images_train.shape
assert complex_images_test.shape == real_images_test.shape
assert complex_images_train.shape[0] == labels_train.shape[0]
assert complex_images_test.shape[0] == labels_test.shape[0]

# Report shapes; images are still 2D here (ComplexFlatten in get_model flattens them).
print(f'\nTrain data shape: {complex_images_train.shape}, Train labels shape: {labels_train.shape}')
print(f'Test data shape: {complex_images_test.shape}, Test labels shape: {labels_test.shape}\n')


Complex number Ex: (323+0j)

Train data shape: (60000, 28, 28), Train labels shape: (60000, 10)
Test data shape: (10000, 28, 28), Test labels shape: (10000, 10)



In [3]:

# ------------ model setup ------------
SEED = 42
epochs = 100  # upper bound on training epochs

# Seed Python, NumPy and TensorFlow RNGs so the weight initialisation is reproducible.
tf.keras.utils.set_random_seed(SEED)

model = get_model()  # complex-valued dense network defined above

# Compile as any TensorFlow model, using the complex cross-entropy loss from cvnn.losses.
model.compile(optimizer='adam', metrics=['accuracy'],
              loss=ComplexAverageCrossEntropy())
model.summary()


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 complex_flatten (ComplexFl  (None, 784)               0         
 atten)                                                          
                                                                 
 complex_dense (ComplexDens  (None, 128)               200960    
 e)                                                              
                                                                 
 complex_dense_1 (ComplexDe  (None, 256)               66048     
 nse)                                                            
                                                                 
 complex_dense_2 (ComplexDe  (None, 128)               65792     
 nse)                                                            
                                                                 
 complex_dense_3 (ComplexDe  (None, 10)                2

In [4]:

# Train and evaluate.
# Validation comes from a split of the training data, so the test set is only
# touched once, by the final evaluation below. Training stops early once the
# validation loss stops improving, instead of always running `epochs` epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True
)
history = model.fit(
    complex_images_train, labels_train,
    epochs=epochs,
    validation_split=0.1,
    callbacks=[early_stopping],
)
test_loss, test_acc = model.evaluate(complex_images_test, labels_test, verbose=2)
print(f'\nTest accuracy: {test_acc:.4f}')
print(f'Test loss: {test_loss:.4f}')

# Summarise the last epoch instead of dumping the full per-epoch history dict.
epochs_run = len(history.history['loss'])
last_epoch_metrics = {name: round(values[-1], 4) for name, values in history.history.items()}
print(f'Epochs run: {epochs_run}/{epochs}')
print(f'Last epoch metrics: {last_epoch_metrics}')

2025-06-11 13:37:28.334471: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 376320000 exceeds 10% of free system memory.


Epoch 1/100
Epoch 2/100
Epoch 3/100

KeyboardInterrupt: 