In [1]:
# Import Necessary Libraries

import torch

from torch import nn

from torch.utils.data import Dataset, DataLoader

from torchvision import transforms

from torchvision.datasets import CIFAR10

from transformers import NystromformerForSequenceClassification, NystromformerConfig

import torch.optim as optim

from torch.optim.lr_scheduler import CosineAnnealingLR

from tqdm import tqdm

import torch.nn.functional as F

In [2]:
# Device Configuration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
# Define Custom Dataset Wrapper

class CustomCIFAR10Dataset(Dataset):

    def __init__(self, cifar_dataset, transform=None, flatten=False):

        self.cifar_dataset = cifar_dataset

        self.transform = transform

        self.flatten = flatten



    def __len__(self):

        return len(self.cifar_dataset)



    def __getitem__(self, idx):

        image, label = self.cifar_dataset[idx]

        if self.transform:

            image = self.transform(image)

        if self.flatten:

            image = image.view(-1)  # Flatten [1, 32, 32] to [1024]

        return image, label



# Preprocessing

transform = transforms.Compose([

    transforms.Grayscale(num_output_channels=1),

    transforms.ToTensor(),

    transforms.Normalize((0.5,), (0.5,))

])


In [4]:
# Load CIFAR-10 Dataset

original_cifar_train = CIFAR10(root='./data', train=True, download=True)

original_cifar_test = CIFAR10(root='./data', train=False, download=True)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29185224.53it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
train_dataset = CustomCIFAR10Dataset(original_cifar_train, transform=transform, flatten=True)

test_dataset = CustomCIFAR10Dataset(original_cifar_test, transform=transform, flatten=True)



train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [6]:
# Define Model Configuration

config = NystromformerConfig(

    vocab_size=1024,

    hidden_size=128,

    num_hidden_layers=4,

    num_attention_heads=4,

    num_labels=10,

    max_position_embeddings=1024,

    hidden_dropout_prob=0.1

)



# Initialize Model

model = NystromformerForSequenceClassification(config)

model.to(device)



# Define Optimizer and Scheduler

optimizer = optim.AdamW(model.parameters(), lr=5e-4)

scheduler = CosineAnnealingLR(optimizer, T_max=10)

criterion = nn.CrossEntropyLoss()


In [7]:
# Training Function

scaler = torch.cuda.amp.GradScaler()



def train_epoch(model, data_loader, optimizer, scheduler, criterion, device):

    model.train()

    total_loss = 0

    correct = 0

    total = 0



    for images, labels in tqdm(data_loader):

        images, labels = images.to(device), labels.to(device)



        optimizer.zero_grad()

        # Use autocast only if CUDA is available

        if torch.cuda.is_available():

            with torch.cuda.amp.autocast():

                images = F.normalize(images.view(images.size(0), -1), dim=1, p=2)

                images = torch.clamp((images * (config.vocab_size - 1)).round(), 0, config.vocab_size - 1).long()

                attention_mask = torch.ones(images.size(), dtype=torch.long).to(device)

                outputs = model(input_ids=images, attention_mask=attention_mask)

                loss = criterion(outputs.logits, labels)

        else:

            images = F.normalize(images.view(images.size(0), -1), dim=1, p=2)

            images = torch.clamp((images * (config.vocab_size - 1)).round(), 0, config.vocab_size - 1).long()

            attention_mask = torch.ones(images.size(), dtype=torch.long).to(device)

            outputs = model(input_ids=images, attention_mask=attention_mask)

            loss = criterion(outputs.logits, labels)



        if torch.cuda.is_available():

            scaler.scale(loss).backward()

            scaler.step(optimizer)

            scaler.update()

        else:

            loss.backward()

            optimizer.step()



        total_loss += loss.item()

        preds = outputs.logits.argmax(dim=1)

        correct += (preds == labels).sum().item()

        total += labels.size(0)



    scheduler.step()

    return total_loss / len(data_loader), correct / total



def evaluate(model, data_loader, criterion, device):

    model.eval()

    total_loss = 0

    correct = 0

    total = 0



    with torch.no_grad():

        for images, labels in tqdm(data_loader):

            images, labels = images.to(device), labels.to(device)



            if torch.cuda.is_available():

                with torch.cuda.amp.autocast():

                    images = F.normalize(images.view(images.size(0), -1), dim=1, p=2)

                    images = torch.clamp((images * (config.vocab_size - 1)).round(), 0, config.vocab_size - 1).long()

                    attention_mask = torch.ones(images.size(), dtype=torch.long).to(device)

                    outputs = model(input_ids=images, attention_mask=attention_mask)

                    loss = criterion(outputs.logits, labels)

            else:

                images = F.normalize(images.view(images.size(0), -1), dim=1, p=2)

                images = torch.clamp((images * (config.vocab_size - 1)).round(), 0, config.vocab_size - 1).long()

                attention_mask = torch.ones(images.size(), dtype=torch.long).to(device)

                outputs = model(input_ids=images, attention_mask=attention_mask)

                loss = criterion(outputs.logits, labels)



            total_loss += loss.item()

            preds = outputs.logits.argmax(dim=1)

            correct += (preds == labels).sum().item()

            total += labels.size(0)



    return total_loss / len(data_loader), correct / total




  scaler = torch.cuda.amp.GradScaler()


In [8]:
# Train the Model

num_epochs = 20



for epoch in range(num_epochs):

    print(f"Epoch {epoch + 1}/{num_epochs}")



    train_loss, train_acc = train_epoch(

        model, train_loader, optimizer, scheduler, criterion, device

    )

    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")



    val_loss, val_acc = evaluate(model, test_loader, criterion, device)

    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")



Epoch 1/20


  with torch.cuda.amp.autocast():
100%|██████████| 1563/1563 [13:01<00:00,  2.00it/s]


Train Loss: 2.2088, Train Accuracy: 0.1808


  with torch.cuda.amp.autocast():
100%|██████████| 313/313 [01:00<00:00,  5.21it/s]


Validation Loss: 2.1602, Validation Accuracy: 0.1971
Epoch 2/20


100%|██████████| 1563/1563 [13:00<00:00,  2.00it/s]


Train Loss: 2.1290, Train Accuracy: 0.2179


100%|██████████| 313/313 [01:00<00:00,  5.21it/s]


Validation Loss: 2.1000, Validation Accuracy: 0.2309
Epoch 3/20


100%|██████████| 1563/1563 [13:02<00:00,  2.00it/s]


Train Loss: 2.2161, Train Accuracy: 0.1570


100%|██████████| 313/313 [01:00<00:00,  5.17it/s]


Validation Loss: 2.3038, Validation Accuracy: 0.1000
Epoch 4/20


100%|██████████| 1563/1563 [13:03<00:00,  2.00it/s]


Train Loss: 2.2859, Train Accuracy: 0.1157


100%|██████████| 313/313 [01:00<00:00,  5.16it/s]


Validation Loss: 2.1347, Validation Accuracy: 0.2111
Epoch 5/20


100%|██████████| 1563/1563 [13:03<00:00,  1.99it/s]


Train Loss: 2.1392, Train Accuracy: 0.2092


100%|██████████| 313/313 [01:00<00:00,  5.16it/s]


Validation Loss: 2.1549, Validation Accuracy: 0.2052
Epoch 6/20


100%|██████████| 1563/1563 [13:04<00:00,  1.99it/s]


Train Loss: 2.1398, Train Accuracy: 0.2129


100%|██████████| 313/313 [01:00<00:00,  5.17it/s]


Validation Loss: 2.1516, Validation Accuracy: 0.2022
Epoch 7/20


100%|██████████| 1563/1563 [13:03<00:00,  2.00it/s]


Train Loss: 2.1932, Train Accuracy: 0.1850


100%|██████████| 313/313 [01:00<00:00,  5.17it/s]


Validation Loss: 2.1967, Validation Accuracy: 0.1831
Epoch 8/20


100%|██████████| 1563/1563 [13:03<00:00,  2.00it/s]


Train Loss: 2.2284, Train Accuracy: 0.1575


100%|██████████| 313/313 [01:00<00:00,  5.19it/s]


Validation Loss: 2.2400, Validation Accuracy: 0.1462
Epoch 9/20


100%|██████████| 1563/1563 [13:03<00:00,  2.00it/s]


Train Loss: 2.2588, Train Accuracy: 0.1352


100%|██████████| 313/313 [01:00<00:00,  5.18it/s]


Validation Loss: 2.2639, Validation Accuracy: 0.1340
Epoch 10/20


100%|██████████| 1563/1563 [13:02<00:00,  2.00it/s]


Train Loss: 2.2605, Train Accuracy: 0.1360


100%|██████████| 313/313 [01:00<00:00,  5.19it/s]


Validation Loss: 2.2586, Validation Accuracy: 0.1336
Epoch 11/20


100%|██████████| 1563/1563 [13:02<00:00,  2.00it/s]


Train Loss: 2.2603, Train Accuracy: 0.1339


100%|██████████| 313/313 [01:00<00:00,  5.18it/s]


Validation Loss: 2.2586, Validation Accuracy: 0.1336
Epoch 12/20


100%|██████████| 1563/1563 [13:02<00:00,  2.00it/s]


Train Loss: 2.2616, Train Accuracy: 0.1332


100%|██████████| 313/313 [01:00<00:00,  5.18it/s]


Validation Loss: 2.2563, Validation Accuracy: 0.1358
Epoch 13/20


100%|██████████| 1563/1563 [13:03<00:00,  2.00it/s]


Train Loss: 2.2250, Train Accuracy: 0.1560


100%|██████████| 313/313 [01:00<00:00,  5.18it/s]


Validation Loss: 2.2191, Validation Accuracy: 0.1595
Epoch 14/20


100%|██████████| 1563/1563 [13:02<00:00,  2.00it/s]


Train Loss: 2.2280, Train Accuracy: 0.1554


100%|██████████| 313/313 [01:00<00:00,  5.18it/s]


Validation Loss: 2.2036, Validation Accuracy: 0.1727
Epoch 15/20


100%|██████████| 1563/1563 [13:03<00:00,  2.00it/s]


Train Loss: 2.1902, Train Accuracy: 0.1880


100%|██████████| 313/313 [01:00<00:00,  5.17it/s]


Validation Loss: 2.1884, Validation Accuracy: 0.1829
Epoch 16/20


100%|██████████| 1563/1563 [13:02<00:00,  2.00it/s]


Train Loss: 2.2598, Train Accuracy: 0.1394


100%|██████████| 313/313 [01:00<00:00,  5.19it/s]


Validation Loss: 2.2505, Validation Accuracy: 0.1474
Epoch 17/20


100%|██████████| 1563/1563 [13:02<00:00,  2.00it/s]


Train Loss: 2.2603, Train Accuracy: 0.1403


100%|██████████| 313/313 [01:00<00:00,  5.18it/s]


Validation Loss: 2.2623, Validation Accuracy: 0.1378
Epoch 18/20


100%|██████████| 1563/1563 [13:04<00:00,  1.99it/s]


Train Loss: 2.2473, Train Accuracy: 0.1473


100%|██████████| 313/313 [01:00<00:00,  5.15it/s]


Validation Loss: 2.2050, Validation Accuracy: 0.1883
Epoch 19/20


100%|██████████| 1563/1563 [13:04<00:00,  1.99it/s]


Train Loss: 2.1450, Train Accuracy: 0.2106


100%|██████████| 313/313 [01:00<00:00,  5.18it/s]


Validation Loss: 2.1294, Validation Accuracy: 0.2205
Epoch 20/20


100%|██████████| 1563/1563 [13:05<00:00,  1.99it/s]


Train Loss: 2.1188, Train Accuracy: 0.2202


100%|██████████| 313/313 [01:00<00:00,  5.15it/s]

Validation Loss: 2.1291, Validation Accuracy: 0.2203



