In [1]:
import torch
from modelutils.MNIST_LeNet import LeNet
from utils.utils import train, get_device
from torchvision import datasets, transforms
from optimizers.PSO import PSO

### Loading Data

In [2]:
batch_size = 64

# LeNet expects 32x32 inputs, so the 28x28 MNIST digits are upscaled first.
# ToTensor() produces values in [0, 1]; Normalize(mean=0.5, std=0.5) then
# rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both splits share the same root directory, download flag and preprocessing;
# torchvision only downloads into ./data when the files are not already there.
mnist_options = dict(root='./data', download=True, transform=transform)
trainset = datasets.MNIST(train=True, **mnist_options)
testset = datasets.MNIST(train=False, **mnist_options)

# Reshuffle the training set every epoch; keep evaluation order fixed.
train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=batch_size)

### Model Setup

In [3]:
# Seed torch's RNG so the model's random initial weights (assuming LeNet uses
# standard torch layer initialisers) are reproducible under Restart & Run All.
# NOTE(review): if optimizers.PSO samples particles with Python's `random` or
# NumPy rather than torch, seed those generators too — TODO confirm.
SEED = 42
torch.manual_seed(SEED)

# PSO hyperparameters, gathered in one place so they are easy to tune.
# NOTE(review): the recorded run below never gets past ~12% accuracy; if
# min/max_param_value clamp every network weight into [0, 1], that may be the
# cause — TODO confirm against the optimizers.PSO implementation.
PSO_CONFIG = dict(
    inertial_weight=0.9,
    cognitive_coefficient=1.2,
    social_coefficient=1.1,
    num_particles=10,
    min_param_value=0.,
    max_param_value=1.,
)

device = get_device()
model = LeNet().to(device)
optimizer = PSO(model.parameters(), **PSO_CONFIG)

Using MPS


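### Training

**Note:** in the recorded run below, the model does not learn.
- Train accuracy stays at 12.01% for all 50 epochs, close to the 10% chance level for MNIST's ten classes.
- Train loss hovers around 6.8 × 10⁹.
- Validation loss is identical (734261069.045) from epoch 2 onward.

**TODO:** investigate the PSO configuration before drawing conclusions from these results.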

In [4]:
# Epoch count passed to `train`. In the recorded run each epoch took about a
# minute on MPS (51 minutes in total), so lower this for quick iteration.
NUM_EPOCHS = 50
train(NUM_EPOCHS, optimizer, model, device, train_loader, test_loader)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Train Loss: 7029817639.680, Train Accuracy: 12.30%


  2%|▏         | 1/50 [01:03<52:00, 63.68s/it]

Validation Loss: 734257385.172, Validation Accuracy: 12.22%
Epoch 2, Train Loss: 6832945511.040, Train Accuracy: 12.01%


  4%|▍         | 2/50 [02:03<49:08, 61.43s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 3, Train Loss: 6833440672.000, Train Accuracy: 12.01%


  6%|▌         | 3/50 [03:03<47:31, 60.66s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 4, Train Loss: 6834701969.600, Train Accuracy: 12.01%


  8%|▊         | 4/50 [04:03<46:22, 60.50s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 5, Train Loss: 6833759107.520, Train Accuracy: 12.01%


 10%|█         | 5/50 [05:03<45:07, 60.16s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 6, Train Loss: 6833856986.240, Train Accuracy: 12.01%


 12%|█▏        | 6/50 [06:03<44:09, 60.22s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 7, Train Loss: 6834354625.280, Train Accuracy: 12.01%


 14%|█▍        | 7/50 [07:04<43:28, 60.66s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 8, Train Loss: 6833809083.840, Train Accuracy: 12.01%


 16%|█▌        | 8/50 [08:04<42:18, 60.44s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 9, Train Loss: 6833646164.480, Train Accuracy: 12.01%


 18%|█▊        | 9/50 [09:04<41:04, 60.10s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 10, Train Loss: 6833636856.320, Train Accuracy: 12.01%


 20%|██        | 10/50 [10:04<40:03, 60.08s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 11, Train Loss: 6833808558.080, Train Accuracy: 12.01%


 22%|██▏       | 11/50 [11:04<39:05, 60.14s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 12, Train Loss: 6833350022.080, Train Accuracy: 12.01%


 24%|██▍       | 12/50 [12:02<37:39, 59.47s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 13, Train Loss: 6832588891.520, Train Accuracy: 12.01%


 26%|██▌       | 13/50 [13:00<36:17, 58.86s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 14, Train Loss: 6833270834.560, Train Accuracy: 12.01%


 28%|██▊       | 14/50 [13:57<35:01, 58.38s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 15, Train Loss: 6833719049.600, Train Accuracy: 12.01%


 30%|███       | 15/50 [14:54<33:50, 58.01s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 16, Train Loss: 6832893004.800, Train Accuracy: 12.01%


 32%|███▏      | 16/50 [15:54<33:09, 58.50s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 17, Train Loss: 6833198013.120, Train Accuracy: 12.01%


 34%|███▍      | 17/50 [16:56<32:47, 59.63s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 18, Train Loss: 6833670729.600, Train Accuracy: 12.01%


 36%|███▌      | 18/50 [17:58<32:11, 60.36s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 19, Train Loss: 6833182348.800, Train Accuracy: 12.01%


 38%|███▊      | 19/50 [19:00<31:27, 60.90s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 20, Train Loss: 6834782823.680, Train Accuracy: 12.01%


 40%|████      | 20/50 [20:03<30:45, 61.50s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 21, Train Loss: 6833861167.360, Train Accuracy: 12.01%


 42%|████▏     | 21/50 [21:07<30:02, 62.16s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 22, Train Loss: 6831841560.000, Train Accuracy: 12.01%


 44%|████▍     | 22/50 [22:10<29:09, 62.49s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 23, Train Loss: 6833818783.680, Train Accuracy: 12.01%


 46%|████▌     | 23/50 [23:13<28:08, 62.55s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 24, Train Loss: 6832729560.320, Train Accuracy: 12.01%


 48%|████▊     | 24/50 [24:15<27:01, 62.36s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 25, Train Loss: 6833306586.560, Train Accuracy: 12.01%


 50%|█████     | 25/50 [25:16<25:55, 62.22s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 26, Train Loss: 6833627335.680, Train Accuracy: 12.01%


 52%|█████▏    | 26/50 [26:18<24:52, 62.17s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 27, Train Loss: 6834089412.160, Train Accuracy: 12.01%


 54%|█████▍    | 27/50 [27:21<23:55, 62.41s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 28, Train Loss: 6833444478.080, Train Accuracy: 12.01%


 56%|█████▌    | 28/50 [28:24<22:53, 62.42s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 29, Train Loss: 6833446706.560, Train Accuracy: 12.01%


 58%|█████▊    | 29/50 [29:27<21:54, 62.59s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 30, Train Loss: 6833398570.880, Train Accuracy: 12.01%


 60%|██████    | 30/50 [30:32<21:04, 63.23s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 31, Train Loss: 6833658164.160, Train Accuracy: 12.01%


 62%|██████▏   | 31/50 [31:35<20:00, 63.20s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 32, Train Loss: 6833069169.920, Train Accuracy: 12.01%


 64%|██████▍   | 32/50 [32:39<19:00, 63.39s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 33, Train Loss: 6833552661.120, Train Accuracy: 12.01%


 66%|██████▌   | 33/50 [33:42<17:55, 63.29s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 34, Train Loss: 6832742693.760, Train Accuracy: 12.01%


 68%|██████▊   | 34/50 [34:44<16:48, 63.03s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 35, Train Loss: 6833763987.520, Train Accuracy: 12.01%


 70%|███████   | 35/50 [35:46<15:40, 62.67s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 36, Train Loss: 6833354983.360, Train Accuracy: 12.01%


 72%|███████▏  | 36/50 [36:49<14:38, 62.72s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 37, Train Loss: 6832975678.080, Train Accuracy: 12.01%


 74%|███████▍  | 37/50 [37:50<13:31, 62.41s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 38, Train Loss: 6832593218.880, Train Accuracy: 12.01%


 76%|███████▌  | 38/50 [38:52<12:26, 62.21s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 39, Train Loss: 6833340826.880, Train Accuracy: 12.01%


 78%|███████▊  | 39/50 [39:54<11:22, 62.08s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 40, Train Loss: 6834172892.160, Train Accuracy: 12.01%


 80%|████████  | 40/50 [40:56<10:20, 62.08s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 41, Train Loss: 6833984840.320, Train Accuracy: 12.01%


 82%|████████▏ | 41/50 [41:59<09:20, 62.29s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 42, Train Loss: 6833811025.280, Train Accuracy: 12.01%


 84%|████████▍ | 42/50 [43:01<08:18, 62.37s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 43, Train Loss: 6833680042.880, Train Accuracy: 12.01%


 86%|████████▌ | 43/50 [44:04<07:17, 62.50s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 44, Train Loss: 6833519732.800, Train Accuracy: 12.01%


 88%|████████▊ | 44/50 [45:07<06:15, 62.60s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 45, Train Loss: 6832996206.720, Train Accuracy: 12.01%


 90%|█████████ | 45/50 [46:10<05:13, 62.66s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 46, Train Loss: 6833386403.840, Train Accuracy: 12.01%


 92%|█████████▏| 46/50 [47:13<04:10, 62.69s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 47, Train Loss: 6834357182.720, Train Accuracy: 12.01%


 94%|█████████▍| 47/50 [48:15<03:08, 62.75s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 48, Train Loss: 6833280442.880, Train Accuracy: 12.01%


 96%|█████████▌| 48/50 [49:18<02:05, 62.74s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 49, Train Loss: 6833413325.120, Train Accuracy: 12.01%


 98%|█████████▊| 49/50 [50:20<01:02, 62.51s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Epoch 50, Train Loss: 6832522717.120, Train Accuracy: 12.01%


100%|██████████| 50/50 [51:22<00:00, 61.65s/it]

Validation Loss: 734261069.045, Validation Accuracy: 12.23%
Finished Training



