In [None]:
import torch
import matplotlib
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.fft

In [None]:
class FourierConv2D(nn.Module):
    """Spectral convolution with full-resolution complex kernels.

    Each (out_channel, in_channel) pair owns a complex kernel of shape
    ``image_size`` that lives directly in the 2-D Fourier domain. ``forward``
    multiplies the input's FFT by it pointwise, sums over input channels and
    returns the real part of the inverse FFT (by the convolution theorem, a
    circular convolution with a full-size kernel).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Stored as an attribute only; kernels always span ``image_size``.
        image_size: (H, W) of the inputs; ``forward`` needs ``x.shape[-2:] == image_size``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, image_size):
        super(FourierConv2D, self).__init__()
        self.kernel_size = kernel_size
        self.image_size = image_size

        # Fourier-domain kernels, one full spectrum per (out, in) channel pair.
        real_part = torch.empty(out_channels, in_channels, *image_size)
        imag_part = torch.empty(out_channels, in_channels, *image_size)

        # Glorot (Xavier) initialization, applied separately to real and imaginary parts.
        nn.init.xavier_uniform_(real_part)
        nn.init.xavier_uniform_(imag_part)

        self.kernels = nn.Parameter(real_part + 1j * imag_part)

    def forward(self, x):
        """Apply the spectral kernels to ``x`` of shape (batch, in_channels, H, W).

        Returns a real tensor of shape (batch, out_channels, H, W).
        """
        x_fft = torch.fft.fft2(x, dim=(-2, -1))

        # Pointwise product per frequency (h, w), summed over input channels i.
        # Kernel and input must use the SAME (h, w) axis order; swapping them
        # would multiply against the transposed spectrum.
        x_conv_fft = torch.einsum("oihw,bihw->bohw", self.kernels, x_fft)

        # Back to the spatial domain; keep the real part.
        return torch.fft.ifft2(x_conv_fft, dim=(-2, -1)).real

class FourierPooling2D(nn.Module):
    """Spectral pooling: downsample by keeping only the lowest spatial frequencies.

    The input is a real spatial-domain tensor (FourierConv2D returns the real
    part of an inverse FFT). Its 2-D FFT is centred with ``fftshift``, cropped to
    a ``pool_size`` fraction of each dimension around the zero-frequency (DC)
    bin, and transformed back, so the result is a low-pass, downsampled image
    rather than a crop of the image's centre.

    Args:
        pool_size: Fraction of height and width to keep (0.5 halves both).
    """

    def __init__(self, pool_size):
        super(FourierPooling2D, self).__init__()
        self.pool_size = pool_size

    def _low_freq_window(self, n):
        """Return (start, stop) of the kept band in fftshift layout, centred on DC."""
        # Output length uses the same arithmetic as a plain centre crop, so
        # downstream layer sizes (e.g. nn.Linear in_features) are unchanged.
        keep = int((0.5 + self.pool_size / 2) * n) - int((0.5 - self.pool_size / 2) * n)
        # After fftshift, DC sits at n // 2. It must sit at keep // 2 inside the
        # crop so that ifftshift of the crop moves it back to index 0.
        start = n // 2 - keep // 2
        return start, start + keep

    def forward(self, x):
        """Pool the last two dimensions of ``x``; leading dimensions are untouched."""
        height, width = x.shape[-2], x.shape[-1]
        y_min, y_max = self._low_freq_window(height)
        x_min, x_max = self._low_freq_window(width)

        # Unshifted fft2 has DC at index 0; shift so low frequencies are contiguous.
        spectrum = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1)), dim=(-2, -1))
        low = spectrum[..., y_min:y_max, x_min:x_max]
        pooled = torch.fft.ifft2(torch.fft.ifftshift(low, dim=(-2, -1)), dim=(-2, -1)).real

        # ifft2 divides by the (smaller) output size, not the input size; rescale
        # so the mean intensity of the input is preserved.
        scale = ((y_max - y_min) * (x_max - x_min)) / (height * width)
        return pooled * scale

class FCNN(nn.Module):
    """Small Fourier CNN: two spectral conv + pooling stages and a linear head.

    Args:
        image_size: (H, W) of the 3-channel input images.
        num_classes: Number of output logits.
    """

    def __init__(self, image_size=(32, 32), num_classes=10):
        super(FCNN, self).__init__()
        self.fourier_conv1 = FourierConv2D(in_channels=3, out_channels=16, kernel_size=(3, 3), image_size=image_size)
        self.pool1 = FourierPooling2D(pool_size=0.5)
        self.fourier_conv2 = FourierConv2D(in_channels=16, out_channels=32, kernel_size=(3, 3), image_size=(image_size[0]//2, image_size[1]//2))
        self.pool2 = FourierPooling2D(pool_size=0.5)

        self.flatten = nn.Flatten()
        # Assumes each 0.5 pooling exactly halves H and W (true for e.g. 32x32).
        self.fc = nn.Linear(32 * (image_size[0]//4) * (image_size[1]//4), num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        # ReLU between stages: every other layer here is linear, so without a
        # nonlinearity the whole network collapses to a single linear map.
        x = self.pool1(torch.relu(self.fourier_conv1(x)))
        x = self.pool2(torch.relu(self.fourier_conv2(x)))
        return self.fc(self.flatten(x))


In [None]:
class FourierConv2D(nn.Module):
    """Learnable spectral filter over the lowest ``kernel_size`` frequencies.

    Each (out_channel, in_channel) pair owns a complex ``kh x kw`` patch of
    Fourier coefficients centred on the zero-frequency (DC) bin. ``forward``
    takes the input's 2-D FFT, multiplies its ``kh x kw`` lowest-frequency
    window pointwise by the patch, sums over input channels, zeroes all other
    frequencies, and returns the real part of the inverse FFT. The output has
    the same spatial size as the input.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: (kh, kw) of the learnable frequency patch (an int means square).
        image_size: Nominal (H, W) of the inputs, stored as an attribute; the
            spectrum size is taken from the input tensor itself.
    """

    def __init__(self, in_channels, out_channels, kernel_size, image_size):
        super(FourierConv2D, self).__init__()
        self.kernel_size = kernel_size
        self.image_size = image_size
        kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size

        # kh x kw kernels in the Fourier domain, honouring the requested kernel_size.
        real_part = torch.empty(out_channels, in_channels, kh, kw)
        imag_part = torch.empty(out_channels, in_channels, kh, kw)

        # Glorot (Xavier) initialization, applied separately to real and imaginary parts.
        nn.init.xavier_uniform_(real_part)
        nn.init.xavier_uniform_(imag_part)

        self.kernels = nn.Parameter(real_part + 1j * imag_part)

    def forward(self, x):
        """Filter ``x`` of shape (batch, in_channels, H, W) -> (batch, out_channels, H, W)."""
        height, width = x.shape[-2], x.shape[-1]
        kh, kw = self.kernels.shape[-2], self.kernels.shape[-1]
        if kh > height or kw > width:
            raise ValueError(
                f"kernel_size {(kh, kw)} exceeds input spatial size {(height, width)}"
            )

        # Unshifted fft2 puts DC at index 0 and Nyquist in the middle. Shift so
        # the lowest frequencies form one contiguous window around height//2,
        # width//2. (Padding the patch into the middle of an unshifted spectrum
        # would instead select the highest frequencies and drop DC.)
        top = height // 2 - kh // 2
        left = width // 2 - kw // 2
        x_fft = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1)), dim=(-2, -1))
        x_low = x_fft[..., top:top + kh, left:left + kw]

        # Pointwise product per frequency, summed over input channels. Only the
        # kernel's window is non-zero, so the full-size kernel is never built.
        y_low = torch.einsum("oihw,bihw->bohw", self.kernels, x_low)

        # Re-embed in a zero spectrum of the input's size and undo the shift.
        y_fft = torch.nn.functional.pad(
            y_low, [left, width - kw - left, top, height - kh - top]
        )
        y_fft = torch.fft.ifftshift(y_fft, dim=(-2, -1))
        return torch.fft.ifft2(y_fft, dim=(-2, -1)).real

class FourierPooling2D(nn.Module):
    """Spectral pooling: downsample by keeping only the lowest spatial frequencies.

    The input is a real spatial-domain tensor (FourierConv2D returns the real
    part of an inverse FFT). Its 2-D FFT is centred with ``fftshift``, cropped to
    a ``pool_size`` fraction of each dimension around the zero-frequency (DC)
    bin, and transformed back, so the result is a low-pass, downsampled image
    rather than a crop of the image's centre.

    Args:
        pool_size: Fraction of height and width to keep (0.5 halves both).
    """

    def __init__(self, pool_size):
        super(FourierPooling2D, self).__init__()
        self.pool_size = pool_size

    def _low_freq_window(self, n):
        """Return (start, stop) of the kept band in fftshift layout, centred on DC."""
        # Output length uses the same arithmetic as a plain centre crop, so
        # downstream layer sizes (e.g. nn.Linear in_features) are unchanged.
        keep = int((0.5 + self.pool_size / 2) * n) - int((0.5 - self.pool_size / 2) * n)
        # After fftshift, DC sits at n // 2. It must sit at keep // 2 inside the
        # crop so that ifftshift of the crop moves it back to index 0.
        start = n // 2 - keep // 2
        return start, start + keep

    def forward(self, x):
        """Pool the last two dimensions of ``x``; leading dimensions are untouched."""
        height, width = x.shape[-2], x.shape[-1]
        y_min, y_max = self._low_freq_window(height)
        x_min, x_max = self._low_freq_window(width)

        # Unshifted fft2 has DC at index 0; shift so low frequencies are contiguous.
        spectrum = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1)), dim=(-2, -1))
        low = spectrum[..., y_min:y_max, x_min:x_max]
        pooled = torch.fft.ifft2(torch.fft.ifftshift(low, dim=(-2, -1)), dim=(-2, -1)).real

        # ifft2 divides by the (smaller) output size, not the input size; rescale
        # so the mean intensity of the input is preserved.
        scale = ((y_max - y_min) * (x_max - x_min)) / (height * width)
        return pooled * scale

class FCNN_AlexNet(nn.Module):
    """AlexNet-shaped classifier built from FourierConv2D / FourierPooling2D layers.

    Five spectral convolutions (96, 256, 384, 384, 256 channels), each followed
    by ReLU, with a 0.5 Fourier pooling after the 1st, 2nd and 5th. Then three
    fully connected layers (4096, 4096, num_classes) with ReLU and dropout
    between them.

    Args:
        image_size: (H, W) of the 3-channel input images.
        num_classes: Number of output logits.
    """

    def __init__(self, image_size=(224, 224), num_classes=1000):
        super().__init__()
        height, width = image_size[0], image_size[1]
        half_res = (height // 2, width // 2)
        quarter_res = (height // 4, width // 4)

        # Stage 1: full resolution, then pool.
        self.fourier_conv1 = FourierConv2D(in_channels=3, out_channels=96, kernel_size=(11, 11), image_size=image_size)
        self.pool1 = FourierPooling2D(pool_size=0.5)

        # Stage 2: half resolution, then pool.
        self.fourier_conv2 = FourierConv2D(in_channels=96, out_channels=256, kernel_size=(5, 5), image_size=half_res)
        self.pool2 = FourierPooling2D(pool_size=0.5)

        # Stages 3-5: quarter resolution; as in AlexNet only the last one is pooled.
        self.fourier_conv3 = FourierConv2D(in_channels=256, out_channels=384, kernel_size=(3, 3), image_size=quarter_res)
        self.fourier_conv4 = FourierConv2D(in_channels=384, out_channels=384, kernel_size=(3, 3), image_size=quarter_res)
        self.fourier_conv5 = FourierConv2D(in_channels=384, out_channels=256, kernel_size=(3, 3), image_size=quarter_res)
        self.pool3 = FourierPooling2D(pool_size=0.5)

        # Classifier head; three poolings leave an (H // 8) x (W // 8) map of 256 channels.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(256 * (height // 8) * (width // 8), 4096)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(4096, 4096)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        # (conv, optional pool) pairs in network order.
        feature_stages = (
            (self.fourier_conv1, self.pool1),
            (self.fourier_conv2, self.pool2),
            (self.fourier_conv3, None),
            (self.fourier_conv4, None),
            (self.fourier_conv5, self.pool3),
        )
        for conv, pool in feature_stages:
            x = torch.relu(conv(x))
            if pool is not None:
                x = pool(x)

        hidden = self.flatten(x)
        hidden = self.dropout1(torch.relu(self.fc1(hidden)))
        hidden = self.dropout2(torch.relu(self.fc2(hidden)))
        return self.fc3(hidden)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# Seed PyTorch's RNGs (weight init, dropout masks, DataLoader shuffling) so that
# Restart & Run All reproduces the same results.
torch.manual_seed(0)

# MNIST (28x28) is upsampled to 128x128. This is smaller than AlexNet's usual
# 224x224, which keeps the FFTs cheap.
image_size = (128, 128)
num_classes = 10  # MNIST has 10 classes (digits 0-9)

transform = transforms.Compose([
    transforms.Resize(image_size),
    # FCNN_AlexNet's first layer has in_channels=3 but MNIST is single-channel;
    # replicate the gray channel explicitly rather than relying on a 1-vs-3
    # channel mismatch being broadcast inside the model.
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
])

# Load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# FCNN_AlexNet is defined in the cell above.
model = FCNN_AlexNet(image_size=image_size, num_classes=num_classes)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")

# Evaluation on test dataset
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy on MNIST test set: {accuracy:.2f}%")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 11.5MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 491kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.53MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.77MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [1/5], Loss: 2.3019
Epoch [2/5], Loss: 2.3016
Epoch [3/5], Loss: 2.3014
Epoch [4/5], Loss: 2.3014
Epoch [5/5], Loss: 2.3013
Accuracy on MNIST test set: 11.35%
