In [1]:
import torch
from modelutils.CIFAR10_LeNet import LeNet
from utils.utils import train, get_device
from torchvision import datasets, transforms
from optimizers.PSO import PSO


### Loading Data

In [2]:
batch_size = 64
DATA_ROOT = './data'
# Fixes the order in which training batches are shuffled so runs are reproducible.
SHUFFLE_SEED = 0

# Per-channel CIFAR-10 normalization statistics.
# NOTE(review): these std values are the widely copied ones; they are reported to be
# the mean of per-image stds rather than the per-pixel std over the training set
# (roughly (0.2470, 0.2435, 0.2616)). Kept unchanged so results stay comparable
# with earlier runs — confirm before changing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# CIFAR-10 images are already 32x32, so no Resize step is needed here.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
])

# Download and load the training and test datasets
trainset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)

# Data loaders. The dedicated generator makes the training shuffle deterministic
# without depending on whatever global RNG state earlier cells left behind.
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


### Model Setup

In [5]:
device = get_device()

# Seed right before building the model so the initial LeNet weights (and presumably
# PSO's particle initialisation, if it draws from torch's RNG) are reproducible.
MODEL_SEED = 0
torch.manual_seed(MODEL_SEED)

# PSO hyperparameters; keys mirror the PSO constructor's keyword arguments.
PSO_CONFIG = dict(
    inertial_weight=0.75,
    cognitive_coefficient=1.2,
    social_coefficient=1.1,
    num_particles=10,
    # NOTE(review): if these bounds clamp or initialise every network parameter,
    # all weights are confined to [0, 1] (no negative weights), which could explain
    # the chance-level accuracy in the training output. Confirm against
    # optimizers/PSO before tuning.
    min_param_value=0,
    max_param_value=1,
)

model = LeNet().to(device)
optimizer = PSO(model.parameters(), **PSO_CONFIG)

Using MPS


### Training

In [6]:
# Number of full passes over the training set.
NUM_EPOCHS = 10

# NOTE(review): in the recorded run below, validation accuracy stays at exactly 10.00%
# (chance level for CIFAR-10's 10 classes) and the losses are on the order of 1e9-1e10,
# so this PSO configuration did not learn. Treat the output as a failed baseline.
train(NUM_EPOCHS, optimizer, model, device, train_loader, test_loader)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1, Train Loss: 14316768832.000, Train Accuracy: 9.61%


 10%|█         | 1/10 [00:57<08:34, 57.22s/it]

Validation Loss: 1865553182.981, Validation Accuracy: 10.00%
Epoch 2, Train Loss: 14532391552.000, Train Accuracy: 10.13%


 20%|██        | 2/10 [01:54<07:37, 57.22s/it]

Validation Loss: 1865553253.911, Validation Accuracy: 10.00%
Epoch 3, Train Loss: 14530554284.800, Train Accuracy: 10.13%


 30%|███       | 3/10 [02:50<06:37, 56.72s/it]

Validation Loss: 1865553253.911, Validation Accuracy: 10.00%
Epoch 4, Train Loss: 14531721232.640, Train Accuracy: 10.13%


 40%|████      | 4/10 [03:47<05:40, 56.77s/it]

Validation Loss: 1865553253.911, Validation Accuracy: 10.00%
Epoch 5, Train Loss: 14534150174.720, Train Accuracy: 10.13%


 50%|█████     | 5/10 [04:44<04:44, 56.82s/it]

Validation Loss: 1865553253.911, Validation Accuracy: 10.00%
Epoch 6, Train Loss: 14539696894.720, Train Accuracy: 10.13%


 60%|██████    | 6/10 [05:41<03:48, 57.00s/it]

Validation Loss: 1865553253.911, Validation Accuracy: 10.00%
Epoch 7, Train Loss: 14530697908.480, Train Accuracy: 10.13%


 70%|███████   | 7/10 [06:37<02:49, 56.55s/it]

Validation Loss: 1865553253.911, Validation Accuracy: 10.00%
Epoch 8, Train Loss: 14534378161.920, Train Accuracy: 10.13%


 80%|████████  | 8/10 [07:32<01:52, 56.23s/it]

Validation Loss: 1865553253.911, Validation Accuracy: 10.00%
Epoch 9, Train Loss: 14532790913.280, Train Accuracy: 10.13%


 90%|█████████ | 9/10 [08:27<00:55, 55.76s/it]

Validation Loss: 1865553253.911, Validation Accuracy: 10.00%
Epoch 10, Train Loss: 14531892454.400, Train Accuracy: 10.13%


100%|██████████| 10/10 [09:23<00:00, 56.31s/it]

Validation Loss: 1865553253.911, Validation Accuracy: 10.00%
Finished Training



