In [1]:
import torch
from modelutils.CIFAR10_LeNet import LeNet
from utils.utils import train, get_device
from torchvision import datasets, transforms
from optimizers.PSO import PSO


### Loading Data

In [2]:
batch_size = 256
DATA_ROOT = './data'
SHUFFLE_SEED = 0  # fixes the training-set batch order across Restart & Run All

# Per-channel (R, G, B) normalization statistics for CIFAR-10.
# NOTE(review): these std values look like the widely copied ones; the per-channel
# std over the full training set is commonly reported as ~(0.2470, 0.2435, 0.2616).
# Kept unchanged so results stay comparable with other runs that use the same
# preprocessing -- confirm before switching.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# CIFAR-10 images are already 32x32, so no Resize step is needed here.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Download (if not already present) and load the training and test splits.
trainset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)

# Only the training split is shuffled; a dedicated seeded generator makes the
# shuffle order reproducible independently of other RNG use in the notebook.
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


### Model Setup

In [3]:
# Seed torch's global RNG before building the model so weight initialization
# (and any later draws from the global generator) is reproducible under
# Restart & Run All.
# NOTE(review): assumes PSO draws its randomness from torch's global RNG -- confirm
# in optimizers/PSO.py.
SEED = 0
torch.manual_seed(SEED)

# PSO hyperparameters (passed straight through to the optimizer).
PSO_INERTIAL_WEIGHT = 0.9
PSO_COGNITIVE_COEFFICIENT = 1.2
PSO_SOCIAL_COEFFICIENT = 1.2
PSO_NUM_PARTICLES = 30

device = get_device()
model = LeNet().to(device)
optimizer = PSO(
    model.parameters(),
    inertial_weight=PSO_INERTIAL_WEIGHT,
    cognitive_coefficient=PSO_COGNITIVE_COEFFICIENT,
    social_coefficient=PSO_SOCIAL_COEFFICIENT,
    num_particles=PSO_NUM_PARTICLES,
)

Using MPS


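### Training

In the saved run below, PSO does not learn. Train and validation accuracy stay at or below chance (~10% for 10 classes), and validation loss plateaus at ~2.93e9 from epoch 3 onward. The captured output ends after 16 of 50 epochs, at ~44 s per epoch.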

In [None]:
# Number of training epochs. The saved run took ~44 s/epoch on MPS, so a full
# 50-epoch run is roughly 37 minutes.
NUM_EPOCHS = 50
train(NUM_EPOCHS, optimizer, model, device, train_loader, test_loader)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train Loss: 6042637248.000, Train Accuracy: 11.14%


  2%|▏         | 1/50 [00:44<36:22, 44.55s/it]

Validation Loss: 2762804115.200, Validation Accuracy: 11.10%
Epoch 2, Train Loss: 5335729484.800, Train Accuracy: 10.96%


  4%|▍         | 2/50 [01:28<35:35, 44.49s/it]

Validation Loss: 2984766809.600, Validation Accuracy: 8.67%
Epoch 3, Train Loss: 5700367503.360, Train Accuracy: 8.64%


  6%|▌         | 3/50 [02:13<34:46, 44.39s/it]

Validation Loss: 2928059596.800, Validation Accuracy: 9.22%
Epoch 4, Train Loss: 5614156989.440, Train Accuracy: 8.95%


  8%|▊         | 4/50 [02:57<33:57, 44.29s/it]

Validation Loss: 2928059596.800, Validation Accuracy: 9.22%
Epoch 5, Train Loss: 5615830131.200, Train Accuracy: 8.95%


 10%|█         | 5/50 [03:41<33:14, 44.32s/it]

Validation Loss: 2928059596.800, Validation Accuracy: 9.22%
Epoch 6, Train Loss: 5613584396.800, Train Accuracy: 8.95%


 12%|█▏        | 6/50 [04:25<32:23, 44.17s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 7, Train Loss: 5612796298.240, Train Accuracy: 8.95%


 14%|█▍        | 7/50 [05:09<31:28, 43.92s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 8, Train Loss: 5612897943.040, Train Accuracy: 8.95%


 16%|█▌        | 8/50 [05:53<30:46, 43.96s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 9, Train Loss: 5611801418.240, Train Accuracy: 8.95%


 18%|█▊        | 9/50 [06:37<30:08, 44.10s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 10, Train Loss: 5617501327.360, Train Accuracy: 8.95%


 20%|██        | 10/50 [07:22<29:40, 44.51s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 11, Train Loss: 5611737489.920, Train Accuracy: 8.95%


 22%|██▏       | 11/50 [08:07<28:59, 44.61s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 12, Train Loss: 5612954711.040, Train Accuracy: 8.95%


 24%|██▍       | 12/50 [08:52<28:15, 44.62s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 13, Train Loss: 5614686323.200, Train Accuracy: 8.95%


 26%|██▌       | 13/50 [09:37<27:32, 44.67s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 14, Train Loss: 5610874926.080, Train Accuracy: 8.95%


 28%|██▊       | 14/50 [10:22<26:53, 44.82s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 15, Train Loss: 5610874065.920, Train Accuracy: 8.95%


 30%|███       | 15/50 [11:07<26:12, 44.94s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
Epoch 16, Train Loss: 5611628899.840, Train Accuracy: 8.95%


 32%|███▏      | 16/50 [11:52<25:31, 45.03s/it]

Validation Loss: 2927898995.200, Validation Accuracy: 9.22%
