In [1]:
from src import notebook, datasets, build_models,caluclate_basis, custom_blocks, train, utils
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import wandb
from tqdm import tqdm   
from torchsummary import summary

In [None]:
def make(config):
    """Build every object needed for one training run from ``config``.

    Args:
        config: Object exposing at least the attributes ``criterion``
            (``"CrossEntropyLoss"`` or ``"MSELoss"``), ``optimizer``
            (``"Adam"`` or ``"SGD"``) and ``lr``. It is also handed unchanged
            to ``datasets.get_dataloaders`` and
            ``build_models.build_model_from_args``.

    Returns:
        The tuple ``(model, train_loader, val_loader, test_loader, criterion,
        optimizer)``.

    Raises:
        ValueError: If ``config.criterion`` or ``config.optimizer`` is not one
            of the supported names listed above.
    """
    # Supported names live in lookup tables so adding one is a one-line change.
    criterion_factories = {
        "CrossEntropyLoss": nn.CrossEntropyLoss,
        "MSELoss": nn.MSELoss,
    }
    optimizer_factories = {
        "Adam": optim.Adam,
        "SGD": optim.SGD,
    }

    # Validate up front so a typo fails before data loading / model building.
    # The messages must read from `config`: the locals `criterion` and
    # `optimizer` are not bound yet at this point, so interpolating them would
    # raise UnboundLocalError instead of the intended ValueError.
    if config.criterion not in criterion_factories:
        raise ValueError(f"Criterion {config.criterion} not supported")
    if config.optimizer not in optimizer_factories:
        raise ValueError(f"Optimizer {config.optimizer} not supported")

    train_loader, val_loader, test_loader = datasets.get_dataloaders(
        config, logfile=None, summaryfile=None, log=False
    )

    model = build_models.build_model_from_args(config)

    criterion = criterion_factories[config.criterion]()
    optimizer = optimizer_factories[config.optimizer](model.parameters(), lr=config.lr)

    return model, train_loader, val_loader, test_loader, criterion, optimizer

In [5]:
# Run configuration handed to wandb.init(config=...) and, through wandb.config,
# to make() and train.train().
hyperparameters = {
    # Convolutional feature extractor
    "arch": ["64", "M", "128", "M"],
    "batch_norm": False,
    "bias": True,
    # Classifier head
    "classifier_layers": [256],
    "classifier_bias": True,
    "classifier_dropout": 0,
    # Average pooling
    "avgpool": False,
    "avgpool_size": [1, 1],  # Adjusted avgpool size to (1, 1)
    # Dataset and input geometry
    "dataset": "mnist",
    "input_height": 28,
    "input_width": 28,
    "greyscale": True,
    "n_classes": 10,
    # Optimisation
    "epochs": 10,
    "batch_size": 64,
    "optimizer": "Adam",
    "weight_normalization": True,
    "lr": 0.001,
    "criterion": "CrossEntropyLoss",
    "seed": 42,
    # Checkpointing
    "save_checkpoints": False,
    "save_dir": "",
    "checkpoint_type": "epoch",
}

In [8]:
# Open the W&B run as a context manager so it is always closed, even when
# make() or train.train() raises; otherwise a failed cell leaves the run open.
# `run`, `config`, `model`, etc. remain defined after the block for later cells.
with wandb.init(project="custom-vgg-model", config=hyperparameters) as run:
    config = wandb.config

    # Build model, data loaders, loss and optimizer from the logged config
    model, train_loader, val_loader, test_loader, criterion, optimizer = make(config)

    print(model)
    print(criterion, optimizer)

    # and use them to train the model
    train.train(config, model, train_loader, val_loader, test_loader, criterion, optimizer)
# leaving the `with` block finishes the run

CustomModel(
  (features): Sequential(
    (0): ParametrizedConv2d(
      1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): _WeightNorm()
        )
      )
    )
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ParametrizedConv2d(
      64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): _WeightNorm()
        )
      )
    )
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): CustomClassifier(
    (classifier_layers): Sequential(
      (0): Sequential(
        (0): Linear(in_features=6272, out_features=256, bias=True)
        (1): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): Linear(in_features=256, out_features=10, bias=True

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0, Examples seen: 640, Loss: 1.8454
Epoch 0, Examples seen: 1280, Loss: 1.1765
Epoch 0, Examples seen: 1920, Loss: 0.8440
Epoch 0, Examples seen: 2560, Loss: 0.9053
Epoch 0, Examples seen: 3200, Loss: 0.8668
Epoch 0, Examples seen: 3840, Loss: 0.6454
Epoch 0, Examples seen: 4480, Loss: 1.0601
Epoch 0, Examples seen: 5120, Loss: 0.9273
Epoch 0, Examples seen: 5760, Loss: 0.9158
Epoch 0, Examples seen: 6400, Loss: 0.9566
Epoch 0, Examples seen: 7040, Loss: 0.6538
Epoch 0, Examples seen: 7680, Loss: 0.6869
Epoch 0, Examples seen: 8320, Loss: 0.9507
Epoch 0, Examples seen: 8960, Loss: 0.8323
Epoch 0, Examples seen: 9600, Loss: 0.7963
Epoch 0, Examples seen: 10240, Loss: 1.0615
Epoch 0, Examples seen: 10880, Loss: 0.8431
Epoch 0, Examples seen: 11520, Loss: 0.8355
Epoch 0, Examples seen: 12160, Loss: 1.0000
Epoch 0, Examples seen: 12800, Loss: 0.8068
Epoch 0, Examples seen: 13440, Loss: 0.8505
Epoch 0, Examples seen: 14080, Loss: 0.6801
Epoch 0, Examples seen: 14720, Loss: 0.6303
Epoc

 10%|█         | 1/10 [01:25<12:49, 85.54s/it]

Epoch 1/10, Loss: 0.6494, Train Accuracy: 74.38%, Validation Accuracy: 79.84%
Epoch 1, Examples seen: 48640, Loss: 0.3379
Epoch 1, Examples seen: 49280, Loss: 0.5203
Epoch 1, Examples seen: 49920, Loss: 0.5032
Epoch 1, Examples seen: 50560, Loss: 0.5389
Epoch 1, Examples seen: 51200, Loss: 0.3934
Epoch 1, Examples seen: 51840, Loss: 0.4826
Epoch 1, Examples seen: 52480, Loss: 0.5194
Epoch 1, Examples seen: 53120, Loss: 0.4699
Epoch 1, Examples seen: 53760, Loss: 0.3827
Epoch 1, Examples seen: 54400, Loss: 0.8256
Epoch 1, Examples seen: 55040, Loss: 0.4773
Epoch 1, Examples seen: 55680, Loss: 0.3485
Epoch 1, Examples seen: 56320, Loss: 0.5173
Epoch 1, Examples seen: 56960, Loss: 0.3703
Epoch 1, Examples seen: 57600, Loss: 0.5073
Epoch 1, Examples seen: 58240, Loss: 0.4353
Epoch 1, Examples seen: 58880, Loss: 0.3856
Epoch 1, Examples seen: 59520, Loss: 0.5032
Epoch 1, Examples seen: 60160, Loss: 0.6110
Epoch 1, Examples seen: 60800, Loss: 0.4938
Epoch 1, Examples seen: 61440, Loss: 0.346

 20%|██        | 2/10 [02:06<07:53, 59.20s/it]

Epoch 2/10, Loss: 0.4739, Train Accuracy: 80.13%, Validation Accuracy: 80.08%
Epoch 2, Examples seen: 96640, Loss: 0.6942
Epoch 2, Examples seen: 97280, Loss: 0.4534
Epoch 2, Examples seen: 97920, Loss: 0.3989
Epoch 2, Examples seen: 98560, Loss: 0.4004
Epoch 2, Examples seen: 99200, Loss: 0.4768
Epoch 2, Examples seen: 99840, Loss: 0.4431
Epoch 2, Examples seen: 100480, Loss: 0.5303
Epoch 2, Examples seen: 101120, Loss: 0.5848
Epoch 2, Examples seen: 101760, Loss: 0.3017
Epoch 2, Examples seen: 102400, Loss: 0.6178
Epoch 2, Examples seen: 103040, Loss: 0.3481
Epoch 2, Examples seen: 103680, Loss: 0.4745
Epoch 2, Examples seen: 104320, Loss: 0.5842
Epoch 2, Examples seen: 104960, Loss: 0.4714
Epoch 2, Examples seen: 105600, Loss: 0.5986
Epoch 2, Examples seen: 106240, Loss: 0.5055
Epoch 2, Examples seen: 106880, Loss: 0.5433
Epoch 2, Examples seen: 107520, Loss: 0.5130
Epoch 2, Examples seen: 108160, Loss: 0.5456
Epoch 2, Examples seen: 108800, Loss: 0.4326
Epoch 2, Examples seen: 1094

 30%|███       | 3/10 [02:50<06:06, 52.36s/it]

Epoch 3/10, Loss: 0.4607, Train Accuracy: 80.48%, Validation Accuracy: 80.07%
Epoch 3, Examples seen: 144640, Loss: 0.4940
Epoch 3, Examples seen: 145280, Loss: 0.5658
Epoch 3, Examples seen: 145920, Loss: 0.4114
Epoch 3, Examples seen: 146560, Loss: 0.3626
Epoch 3, Examples seen: 147200, Loss: 0.3659
Epoch 3, Examples seen: 147840, Loss: 0.2978
Epoch 3, Examples seen: 148480, Loss: 0.2896
Epoch 3, Examples seen: 149120, Loss: 0.5722
Epoch 3, Examples seen: 149760, Loss: 0.5215
Epoch 3, Examples seen: 150400, Loss: 0.2901
Epoch 3, Examples seen: 151040, Loss: 0.3256
Epoch 3, Examples seen: 151680, Loss: 0.2888
Epoch 3, Examples seen: 152320, Loss: 0.2532
Epoch 3, Examples seen: 152960, Loss: 0.5786
Epoch 3, Examples seen: 153600, Loss: 0.5128
Epoch 3, Examples seen: 154240, Loss: 0.5618
Epoch 3, Examples seen: 154880, Loss: 0.7223
Epoch 3, Examples seen: 155520, Loss: 0.5064
Epoch 3, Examples seen: 156160, Loss: 0.2652
Epoch 3, Examples seen: 156800, Loss: 0.3687
Epoch 3, Examples seen

 40%|████      | 4/10 [03:28<04:39, 46.57s/it]

Epoch 4/10, Loss: 0.4544, Train Accuracy: 80.59%, Validation Accuracy: 80.28%
Epoch 4, Examples seen: 192640, Loss: 0.4798
Epoch 4, Examples seen: 193280, Loss: 0.4320
Epoch 4, Examples seen: 193920, Loss: 0.2936
Epoch 4, Examples seen: 194560, Loss: 0.3243
Epoch 4, Examples seen: 195200, Loss: 0.2642
Epoch 4, Examples seen: 195840, Loss: 0.2887
Epoch 4, Examples seen: 196480, Loss: 0.5475
Epoch 4, Examples seen: 197120, Loss: 0.8112
Epoch 4, Examples seen: 197760, Loss: 0.4328
Epoch 4, Examples seen: 198400, Loss: 0.4720
Epoch 4, Examples seen: 199040, Loss: 0.3332
Epoch 4, Examples seen: 199680, Loss: 0.2559
Epoch 4, Examples seen: 200320, Loss: 0.5079
Epoch 4, Examples seen: 200960, Loss: 0.5815
Epoch 4, Examples seen: 201600, Loss: 0.3478
Epoch 4, Examples seen: 202240, Loss: 0.4007
Epoch 4, Examples seen: 202880, Loss: 0.5306
Epoch 4, Examples seen: 203520, Loss: 0.5054
Epoch 4, Examples seen: 204160, Loss: 0.3239
Epoch 4, Examples seen: 204800, Loss: 0.3644
Epoch 4, Examples seen

 50%|█████     | 5/10 [04:05<03:36, 43.28s/it]

Epoch 5/10, Loss: 0.4494, Train Accuracy: 80.75%, Validation Accuracy: 80.28%
Epoch 5, Examples seen: 240640, Loss: 0.6494
Epoch 5, Examples seen: 241280, Loss: 0.3271
Epoch 5, Examples seen: 241920, Loss: 0.5940
Epoch 5, Examples seen: 242560, Loss: 0.4563
Epoch 5, Examples seen: 243200, Loss: 0.3838
Epoch 5, Examples seen: 243840, Loss: 0.4329
Epoch 5, Examples seen: 244480, Loss: 0.2521
Epoch 5, Examples seen: 245120, Loss: 0.3706
Epoch 5, Examples seen: 245760, Loss: 0.5425
Epoch 5, Examples seen: 246400, Loss: 0.3819
Epoch 5, Examples seen: 247040, Loss: 0.5428
Epoch 5, Examples seen: 247680, Loss: 0.5405
Epoch 5, Examples seen: 248320, Loss: 0.5764
Epoch 5, Examples seen: 248960, Loss: 0.3610
Epoch 5, Examples seen: 249600, Loss: 0.4327
Epoch 5, Examples seen: 250240, Loss: 0.3623
Epoch 5, Examples seen: 250880, Loss: 0.3603
Epoch 5, Examples seen: 251520, Loss: 0.2881
Epoch 5, Examples seen: 252160, Loss: 0.4347
Epoch 5, Examples seen: 252800, Loss: 0.4685
Epoch 5, Examples seen

 60%|██████    | 6/10 [04:43<02:45, 41.41s/it]

Epoch 6/10, Loss: 0.4472, Train Accuracy: 80.79%, Validation Accuracy: 80.27%
Epoch 6, Examples seen: 288640, Loss: 0.4326
Epoch 6, Examples seen: 289280, Loss: 0.4643
Epoch 6, Examples seen: 289920, Loss: 0.5105
Epoch 6, Examples seen: 290560, Loss: 0.4688
Epoch 6, Examples seen: 291200, Loss: 0.3602
Epoch 6, Examples seen: 291840, Loss: 0.3790
Epoch 6, Examples seen: 292480, Loss: 0.3966
Epoch 6, Examples seen: 293120, Loss: 0.2262
Epoch 6, Examples seen: 293760, Loss: 0.4361
Epoch 6, Examples seen: 294400, Loss: 0.4346
Epoch 6, Examples seen: 295040, Loss: 0.6477
Epoch 6, Examples seen: 295680, Loss: 0.4342
Epoch 6, Examples seen: 296320, Loss: 0.5757
Epoch 6, Examples seen: 296960, Loss: 0.4680
Epoch 6, Examples seen: 297600, Loss: 0.5039
Epoch 6, Examples seen: 298240, Loss: 0.6950
Epoch 6, Examples seen: 298880, Loss: 0.3600
Epoch 6, Examples seen: 299520, Loss: 0.3627
Epoch 6, Examples seen: 300160, Loss: 0.3633
Epoch 6, Examples seen: 300800, Loss: 0.4708
Epoch 6, Examples seen

 70%|███████   | 7/10 [05:27<02:06, 42.18s/it]

Epoch 7/10, Loss: 0.4447, Train Accuracy: 80.88%, Validation Accuracy: 80.24%
Epoch 7, Examples seen: 336640, Loss: 0.5053
Epoch 7, Examples seen: 337280, Loss: 0.3241
Epoch 7, Examples seen: 337920, Loss: 0.7231
Epoch 7, Examples seen: 338560, Loss: 0.4326
Epoch 7, Examples seen: 339200, Loss: 0.4321
Epoch 7, Examples seen: 339840, Loss: 0.3959
Epoch 7, Examples seen: 340480, Loss: 0.4337
Epoch 7, Examples seen: 341120, Loss: 0.4614
Epoch 7, Examples seen: 341760, Loss: 0.3840
Epoch 7, Examples seen: 342400, Loss: 0.3631
Epoch 7, Examples seen: 343040, Loss: 0.5415
Epoch 7, Examples seen: 343680, Loss: 0.2884
Epoch 7, Examples seen: 344320, Loss: 0.3962
Epoch 7, Examples seen: 344960, Loss: 0.3255
Epoch 7, Examples seen: 345600, Loss: 0.6314
Epoch 7, Examples seen: 346240, Loss: 0.2896
Epoch 7, Examples seen: 346880, Loss: 0.5038
Epoch 7, Examples seen: 347520, Loss: 0.3241
Epoch 7, Examples seen: 348160, Loss: 0.3971
Epoch 7, Examples seen: 348800, Loss: 0.6478
Epoch 7, Examples seen

 80%|████████  | 8/10 [06:08<01:23, 41.80s/it]

Epoch 8/10, Loss: 0.4430, Train Accuracy: 80.91%, Validation Accuracy: 80.03%
Epoch 8, Examples seen: 384640, Loss: 0.5595
Epoch 8, Examples seen: 385280, Loss: 0.4446
Epoch 8, Examples seen: 385920, Loss: 0.5037
Epoch 8, Examples seen: 386560, Loss: 0.3261
Epoch 8, Examples seen: 387200, Loss: 0.2968
Epoch 8, Examples seen: 387840, Loss: 0.4067
Epoch 8, Examples seen: 388480, Loss: 0.3816
Epoch 8, Examples seen: 389120, Loss: 0.3465
Epoch 8, Examples seen: 389760, Loss: 0.5046
Epoch 8, Examples seen: 390400, Loss: 0.6117
Epoch 8, Examples seen: 391040, Loss: 0.3958
Epoch 8, Examples seen: 391680, Loss: 0.4318
Epoch 8, Examples seen: 392320, Loss: 0.4677
Epoch 8, Examples seen: 392960, Loss: 0.3969
Epoch 8, Examples seen: 393600, Loss: 0.1815
Epoch 8, Examples seen: 394240, Loss: 0.4352
Epoch 8, Examples seen: 394880, Loss: 0.4319
Epoch 8, Examples seen: 395520, Loss: 0.4326
Epoch 8, Examples seen: 396160, Loss: 0.5757
Epoch 8, Examples seen: 396800, Loss: 0.3971
Epoch 8, Examples seen

 90%|█████████ | 9/10 [06:47<00:41, 41.05s/it]

Epoch 9/10, Loss: 0.4412, Train Accuracy: 80.96%, Validation Accuracy: 80.23%
Epoch 9, Examples seen: 432640, Loss: 0.7378
Epoch 9, Examples seen: 433280, Loss: 0.5405
Epoch 9, Examples seen: 433920, Loss: 0.2892
Epoch 9, Examples seen: 434560, Loss: 0.4677
Epoch 9, Examples seen: 435200, Loss: 0.4360
Epoch 9, Examples seen: 435840, Loss: 0.5039
Epoch 9, Examples seen: 436480, Loss: 0.3958
Epoch 9, Examples seen: 437120, Loss: 0.5039
Epoch 9, Examples seen: 437760, Loss: 0.2176
Epoch 9, Examples seen: 438400, Loss: 0.3805
Epoch 9, Examples seen: 439040, Loss: 0.3598
Epoch 9, Examples seen: 439680, Loss: 0.2910
Epoch 9, Examples seen: 440320, Loss: 0.3628
Epoch 9, Examples seen: 440960, Loss: 0.4679
Epoch 9, Examples seen: 441600, Loss: 0.2777
Epoch 9, Examples seen: 442240, Loss: 0.3257
Epoch 9, Examples seen: 442880, Loss: 0.6117
Epoch 9, Examples seen: 443520, Loss: 0.4269
Epoch 9, Examples seen: 444160, Loss: 0.5045
Epoch 9, Examples seen: 444800, Loss: 0.4677
Epoch 9, Examples seen

100%|██████████| 10/10 [07:27<00:00, 44.73s/it]


Epoch 10/10, Loss: 0.4413, Train Accuracy: 80.95%, Validation Accuracy: 80.32%


100%|██████████| 157/157 [00:04<00:00, 35.04it/s] 

Test Accuracy: 80.84%





0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇█████
loss,▇██▂▇▅▅▄▄▄▃▂▅▂▃▄▅▅▅▄▃▂▄▅▆▂▂▂▃▄▃▄▄▄▂▁▃▃▅▃
test_accuracy,▁
train_accuracy,▁▇▇███████
train_loss,█▂▂▁▁▁▁▁▁▁
val_accuracy,▁▄▄██▇▇▄▇█

0,1
epoch,9.0
loss,0.33137
test_accuracy,80.84
train_accuracy,80.95208
train_loss,0.4413
val_accuracy,80.31667


In [7]:
# Close the current W&B run held in `run`.
# NOTE(review): this cell's execution count (7) is lower than the training
# cell's (8), and its saved output (epoch 2, ~9.9% accuracy) does not match the
# training run shown above -- it appears to belong to an earlier run; confirm.
# The training cell already finishes its own run, so this call looks redundant
# on a fresh Restart & Run All.
run.finish()

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅████████
loss,█▇▇▆▆▅█▆▄▃▃▃▄▂▅▅▄▇▆▄▆▂▁▁▄▅▅▄▆▅▄▁▃▅▃▃▄▇▅▃
train_accuracy,▁█
train_loss,█▁
val_accuracy,▁▁

0,1
epoch,2.0
loss,1.9371
train_accuracy,9.89167
train_loss,2.10063
val_accuracy,9.79167


In [None]:
# Load ResNet-18 model
def load_resnet18(n_classes, pretrained=True):
    """Return a torchvision ResNet-18 whose final layer outputs ``n_classes`` logits.

    Args:
        n_classes: Number of output units for the replacement ``fc`` layer.
        pretrained: When True (the default, matching the previous behaviour)
            the backbone is created with the ImageNet ``IMAGENET1K_V1``
            weights; when False it is created with ``weights=None``.

    Returns:
        The ResNet-18 module with ``model.fc`` replaced by a freshly
        initialised ``nn.Linear(model.fc.in_features, n_classes)``.
    """
    # `pretrained=True` is the deprecated torchvision spelling; IMAGENET1K_V1
    # is the checkpoint that flag selected, so the loaded weights are unchanged.
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)
    # Modify the final fully connected layer to have n_classes output units
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model