In [1]:
import torch
from modelutils.MNIST_LeNet import LeNet
from utils.utils import train, get_device
from torchvision import datasets, transforms
from optimizers.PSO import PSO

### Loading Data

In [2]:
batch_size = 256

# Preprocessing applied to every MNIST image:
#   1. resize to 32x32 so the images fit the LeNet architecture,
#   2. convert to a float tensor,
#   3. normalize each pixel as (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Both splits share the same storage location, download behaviour and preprocessing;
# only the `train` flag differs. The training split is fetched first, then the test split.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
trainset = datasets.MNIST(train=True, **mnist_kwargs)
testset = datasets.MNIST(train=False, **mnist_kwargs)

# Training batches are reshuffled every epoch; the test set keeps a fixed order.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

### Model Setup

In [3]:
# Seed torch's global RNG before anything stochastic is created, so a
# Restart-&-Run-All reproduces the run. This covers:
#   - the LeNet weight initialisation below,
#   - the shuffle order of train_loader, which is drawn from the global RNG when the
#     loader is iterated during training,
#   - presumably PSO's particle sampling, if optimizers.PSO uses torch's RNG (TODO confirm).
SEED = 0
torch.manual_seed(SEED)

device = get_device()
model = LeNet().to(device)

# PSO hyperparameters, named so they can be tuned in one place.
PSO_INERTIAL_WEIGHT = 0.9
PSO_COGNITIVE_COEFFICIENT = 1.2
PSO_SOCIAL_COEFFICIENT = 1.2
PSO_NUM_PARTICLES = 30

optimizer = PSO(
    model.parameters(),
    inertial_weight=PSO_INERTIAL_WEIGHT,
    cognitive_coefficient=PSO_COGNITIVE_COEFFICIENT,
    social_coefficient=PSO_SOCIAL_COEFFICIENT,
    num_particles=PSO_NUM_PARTICLES,
)

Using MPS


### Training

In [None]:
# NOTE(review): in the recorded run of this cell, validation loss and accuracy stop
# changing after epoch 3 (2483149577.600 / 10.89%, roughly chance level) while each
# epoch takes ~50 s. Consider fewer epochs or early stopping, and revisit the PSO
# hyperparameters from the setup cell, before committing to a long run.
NUM_EPOCHS = 100
train(NUM_EPOCHS, optimizer, model, device, train_loader, test_loader)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1, Train Loss: 6565983457.280, Train Accuracy: 12.17%


  1%|          | 1/100 [00:51<1:24:22, 51.14s/it]

Validation Loss: 2634232473.600, Validation Accuracy: 12.84%
Epoch 2, Train Loss: 6109156152.320, Train Accuracy: 11.47%


  2%|▏         | 2/100 [01:40<1:21:32, 49.93s/it]

Validation Loss: 2485049414.400, Validation Accuracy: 9.66%
Epoch 3, Train Loss: 5787399521.280, Train Accuracy: 10.77%


  3%|▎         | 3/100 [02:29<1:20:17, 49.67s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 4, Train Loss: 5800277222.400, Train Accuracy: 10.83%


  4%|▍         | 4/100 [03:19<1:19:23, 49.62s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 5, Train Loss: 5799664372.480, Train Accuracy: 10.83%


  5%|▌         | 5/100 [04:08<1:18:25, 49.53s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 6, Train Loss: 5801049367.040, Train Accuracy: 10.83%


  6%|▌         | 6/100 [04:57<1:17:07, 49.23s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 7, Train Loss: 5798215219.200, Train Accuracy: 10.83%


  7%|▋         | 7/100 [05:46<1:16:32, 49.38s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 8, Train Loss: 5798954365.440, Train Accuracy: 10.83%


  8%|▊         | 8/100 [06:36<1:15:38, 49.33s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 9, Train Loss: 5799915924.480, Train Accuracy: 10.83%


  9%|▉         | 9/100 [07:24<1:14:20, 49.01s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 10, Train Loss: 5799123932.160, Train Accuracy: 10.83%


 10%|█         | 10/100 [08:13<1:13:37, 49.08s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 11, Train Loss: 5801062849.280, Train Accuracy: 10.83%


 11%|█         | 11/100 [09:03<1:13:05, 49.28s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 12, Train Loss: 5798241930.240, Train Accuracy: 10.83%


 12%|█▏        | 12/100 [09:52<1:12:18, 49.30s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 13, Train Loss: 5797436285.440, Train Accuracy: 10.83%


 13%|█▎        | 13/100 [10:43<1:11:57, 49.63s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 14, Train Loss: 5800491281.920, Train Accuracy: 10.83%


 14%|█▍        | 14/100 [11:34<1:11:51, 50.14s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 15, Train Loss: 5796277456.640, Train Accuracy: 10.83%


 15%|█▌        | 15/100 [12:24<1:11:14, 50.28s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 16, Train Loss: 5798924321.280, Train Accuracy: 10.83%


 16%|█▌        | 16/100 [13:15<1:10:31, 50.38s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 17, Train Loss: 5799295531.520, Train Accuracy: 10.83%


 17%|█▋        | 17/100 [14:05<1:09:31, 50.26s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 18, Train Loss: 5798921264.640, Train Accuracy: 10.83%


 18%|█▊        | 18/100 [14:56<1:08:45, 50.32s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 19, Train Loss: 5798649643.520, Train Accuracy: 10.83%


 19%|█▉        | 19/100 [15:46<1:08:10, 50.50s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 20, Train Loss: 5796846074.880, Train Accuracy: 10.83%


 20%|██        | 20/100 [16:37<1:07:29, 50.61s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 21, Train Loss: 5800751142.400, Train Accuracy: 10.83%


 21%|██        | 21/100 [17:28<1:06:41, 50.65s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 22, Train Loss: 5801760258.560, Train Accuracy: 10.83%


 22%|██▏       | 22/100 [18:18<1:05:41, 50.54s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 23, Train Loss: 5799322296.320, Train Accuracy: 10.83%


 23%|██▎       | 23/100 [19:09<1:04:56, 50.61s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 24, Train Loss: 5800784120.320, Train Accuracy: 10.83%


 24%|██▍       | 24/100 [19:59<1:03:54, 50.45s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 25, Train Loss: 5798367313.920, Train Accuracy: 10.83%


 25%|██▌       | 25/100 [20:50<1:03:07, 50.50s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 26, Train Loss: 5803686640.640, Train Accuracy: 10.83%


 26%|██▌       | 26/100 [21:40<1:02:12, 50.44s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 27, Train Loss: 5798815754.240, Train Accuracy: 10.83%


 27%|██▋       | 27/100 [22:30<1:01:06, 50.23s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 28, Train Loss: 5797820119.040, Train Accuracy: 10.83%


 28%|██▊       | 28/100 [23:20<1:00:20, 50.29s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 29, Train Loss: 5798145876.480, Train Accuracy: 10.83%


 29%|██▉       | 29/100 [24:11<59:36, 50.37s/it]  

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 30, Train Loss: 5800014561.280, Train Accuracy: 10.83%


 30%|███       | 30/100 [25:02<58:56, 50.52s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 31, Train Loss: 5799880391.680, Train Accuracy: 10.83%


 31%|███       | 31/100 [25:53<58:12, 50.62s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
Epoch 32, Train Loss: 5799372531.200, Train Accuracy: 10.83%


 32%|███▏      | 32/100 [26:44<57:29, 50.73s/it]

Validation Loss: 2483149577.600, Validation Accuracy: 10.89%
