In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from transformers import FNetConfig, FNetForSequenceClassification, AdamW
from tqdm import tqdm
import torch.nn.functional as F

In [2]:
# Define Custom Dataset Wrapper

class CustomCIFAR10Dataset(Dataset):
    def __init__(self, cifar_dataset, transform=None, flatten=False):
        self.cifar_dataset = cifar_dataset
        self.transform = transform
        self.flatten = flatten

    def __len__(self):
        return len(self.cifar_dataset)

    def __getitem__(self, idx):
        image, label = self.cifar_dataset[idx]
        if self.transform:
            image = self.transform(image)
        if self.flatten:
            image = image.view(-1)  # Flatten [1, 32, 32] to [1024]
        # Convert to indices and clamp
        image = torch.clamp((image * 255).long(), min=0, max=255)
        return image, label

# Preprocessing

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


In [3]:
# Load CIFAR-10 Dataset

original_cifar_train = CIFAR10(root='./data', train=True, download=True)
original_cifar_test = CIFAR10(root='./data', train=False, download=True)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:16<00:00, 10176678.31it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
train_dataset = CustomCIFAR10Dataset(original_cifar_train, transform=transform, flatten=True)
test_dataset = CustomCIFAR10Dataset(original_cifar_test, transform=transform, flatten=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [5]:

# Model Configuration
num_labels = 10  # CIFAR-10 has 10 classes
sequence_length = 1024  # Each image is flattened into a vector of size 1024

config = FNetConfig(
    vocab_size=256,  # Pixel values are now tokenized into 256 levels (0-255)
    hidden_size=512,  # Hidden layer size
    num_hidden_layers=4,  # Number of FNet layers
    intermediate_size=1024,  # Size of intermediate FFN layers
    num_attention_heads=8,  # Not used, but part of the config
    max_position_embeddings=sequence_length,  # Maximum sequence length
    num_labels=num_labels,  # Number of classes for classification
)


In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Initialize Model
model = FNetForSequenceClassification(config).to(device)

# Optimizer and Loss
optimizer = AdamW(model.parameters(), lr=5e-4)
loss_fn = torch.nn.CrossEntropyLoss()


Using device: cuda




In [7]:
# Training Loop
epochs = 25
model.train()

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)["logits"]  # Forward pass
        loss = loss_fn(outputs, labels)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update parameters

        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = correct / total
    print(f"Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")


Epoch 1/25


100%|██████████| 1563/1563 [05:32<00:00,  4.70it/s]


Loss: 2.1653, Accuracy: 0.1748
Epoch 2/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.9533, Accuracy: 0.2650
Epoch 3/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.8837, Accuracy: 0.2947
Epoch 4/25


100%|██████████| 1563/1563 [05:31<00:00,  4.71it/s]


Loss: 1.8263, Accuracy: 0.3204
Epoch 5/25


100%|██████████| 1563/1563 [05:32<00:00,  4.70it/s]


Loss: 1.7623, Accuracy: 0.3495
Epoch 6/25


100%|██████████| 1563/1563 [05:32<00:00,  4.70it/s]


Loss: 1.7062, Accuracy: 0.3728
Epoch 7/25


100%|██████████| 1563/1563 [05:32<00:00,  4.70it/s]


Loss: 1.6678, Accuracy: 0.3900
Epoch 8/25


100%|██████████| 1563/1563 [05:32<00:00,  4.71it/s]


Loss: 1.6308, Accuracy: 0.4033
Epoch 9/25


100%|██████████| 1563/1563 [05:32<00:00,  4.71it/s]


Loss: 1.6008, Accuracy: 0.4154
Epoch 10/25


100%|██████████| 1563/1563 [05:32<00:00,  4.71it/s]


Loss: 1.5727, Accuracy: 0.4240
Epoch 11/25


100%|██████████| 1563/1563 [05:31<00:00,  4.71it/s]


Loss: 1.5386, Accuracy: 0.4379
Epoch 12/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.5082, Accuracy: 0.4504
Epoch 13/25


100%|██████████| 1563/1563 [05:31<00:00,  4.71it/s]


Loss: 1.4796, Accuracy: 0.4614
Epoch 14/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.4486, Accuracy: 0.4746
Epoch 15/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.4264, Accuracy: 0.4828
Epoch 16/25


100%|██████████| 1563/1563 [05:30<00:00,  4.72it/s]


Loss: 1.4002, Accuracy: 0.4910
Epoch 17/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.3778, Accuracy: 0.5004
Epoch 18/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.3452, Accuracy: 0.5133
Epoch 19/25


100%|██████████| 1563/1563 [05:30<00:00,  4.72it/s]


Loss: 1.3231, Accuracy: 0.5208
Epoch 20/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.3006, Accuracy: 0.5281
Epoch 21/25


100%|██████████| 1563/1563 [05:31<00:00,  4.71it/s]


Loss: 1.2714, Accuracy: 0.5385
Epoch 22/25


100%|██████████| 1563/1563 [05:31<00:00,  4.71it/s]


Loss: 1.2498, Accuracy: 0.5496
Epoch 23/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.2224, Accuracy: 0.5583
Epoch 24/25


100%|██████████| 1563/1563 [05:31<00:00,  4.72it/s]


Loss: 1.2006, Accuracy: 0.5627
Epoch 25/25


100%|██████████| 1563/1563 [05:31<00:00,  4.71it/s]

Loss: 1.1654, Accuracy: 0.5776





In [8]:
# Evaluation Loop
model.eval()
test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)["logits"]
        loss = loss_fn(outputs, labels)

        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_loss /= len(test_loader)
test_acc = correct / total
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")


100%|██████████| 313/313 [00:26<00:00, 11.87it/s]

Test Loss: 1.5721, Test Accuracy: 0.4602



