In [1]:
from src import notebook, datasets, build_models,caluclate_basis, custom_blocks, train, utils, averaging
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import wandb
import torchvision.transforms as transforms
from tqdm import tqdm   
from torchsummary import summary

In [2]:
def make(config):
    """Build the data loaders, model, loss function and optimizer for a run.

    Args:
        config: attribute-access config object (e.g. ``wandb.config``). It must
            provide ``criterion``, ``optimizer`` and ``lr``, plus whatever fields
            ``datasets.get_dataloaders`` and
            ``build_models.build_model_from_args`` read.

    Returns:
        A tuple ``(model, train_loader, val_loader, test_loader, criterion,
        optimizer)``.

    Raises:
        ValueError: if ``config.criterion`` or ``config.optimizer`` is not one of
            the supported names. Both are checked before any data is loaded or
            any model is built, so a typo fails immediately.
    """
    # Supported choices, keyed by the name used in the config.
    criteria = {
        "CrossEntropyLoss": nn.CrossEntropyLoss,
        "MSELoss": nn.MSELoss,
    }
    optimizers = {
        "Adam": optim.Adam,
        "SGD": optim.SGD,
    }

    # Validate up front, and report the *requested* name. The config value is
    # the only thing that exists at this point; the local objects are not built yet.
    if config.criterion not in criteria:
        raise ValueError(f"Criterion {config.criterion} not supported")
    if config.optimizer not in optimizers:
        raise ValueError(f"Optimizer {config.optimizer} not supported")

    train_loader, val_loader, test_loader = datasets.get_dataloaders(
        config, logfile=None, summaryfile=None, log=False
    )

    model = build_models.build_model_from_args(config)

    criterion = criteria[config.criterion]()
    optimizer = optimizers[config.optimizer](model.parameters(), lr=config.lr)

    return model, train_loader, val_loader, test_loader, criterion, optimizer

In [3]:
# Run configuration handed to wandb.init(config=...); make() reads it back
# through wandb.config.
hyperparameters = {
    # Feature extractor: conv widths, with 'M' marking a max-pool layer
    # (see the printed model: Conv(64) -> MaxPool -> Conv(128) -> MaxPool).
    "arch": ['64', 'M', '128', 'M'],
    "batch_norm": False,
    "bias": True,
    # Classifier head.
    "classifier_layers": [256],
    "classifier_bias": True,
    "classifier_dropout": 0,
    "avgpool": False,
    # Presumably only used when avgpool=True -- confirm in build_models.
    "avgpool_size": [1, 1],
    # Data. NOTE(review): 29x29 for MNIST (natively 28x28) looks deliberate,
    # perhaps so 90-degree rotations keep a centre pixel -- confirm in datasets.
    "dataset": "mnist",
    "input_height": 29,
    "input_width": 29,
    "greyscale": True,
    "n_classes": 10,
    # Training.
    "epochs": 10,
    "batch_size": 64,
    "optimizer": "Adam",
    "weight_normalization": True,  # conv weights get a _WeightNorm parametrization
    "lr": 0.001,
    "criterion": "CrossEntropyLoss",
    "seed": 42,
    # Checkpointing.
    "save_checkpoints": False,
    "save_dir": "",
    "checkpoint_type": "epoch",
}

In [4]:
# One tracked experiment: build everything from the W&B config, train, then
# evaluate on a rotated copy of the test set -- first directly, then with
# averaging over 90-degree rotations.
#
# NOTE(review): in the recorded run both rotated-test accuracies were ~10%
# (chance level for 10 classes), while validation accuracy was ~59%.
# Investigate before trusting these numbers.

# Fall back to CPU so the cell does not crash on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

run = wandb.init(project="averaging-with-vgg-model", config=hyperparameters)
try:
    config = wandb.config

    model, train_loader, val_loader, test_loader, criterion, optimizer = make(config)

    print(model)

    print(criterion, optimizer)

    # Train the model.
    train.train(config, model, train_loader, val_loader, test_loader, criterion, optimizer)

    # Keep the unrotated loader reachable, then replace `test_loader` with the
    # output of averaging.random_rotate_dataset (presumably randomly rotated
    # samples -- confirm in src/averaging).
    clean_test_loader = test_loader
    test_loader = averaging.random_rotate_dataset(test_loader)

    train.test_model(model, test_loader, device=DEVICE)

    # Evaluate with predictions averaged over 90-degree rotations.
    averaging.average_over90degrees_and_evaluate(model, test_loader, device=DEVICE)
finally:
    # Close the run even if training or evaluation raised, so a failed cell
    # does not leave it open.
    run.finish()

[34m[1mwandb[0m: Currently logged in as: [33matsou2[0m ([33matsou2-johns-hopkins-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


CustomModel(
  (features): Sequential(
    (0): ParametrizedConv2d(
      1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): _WeightNorm()
        )
      )
    )
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ParametrizedConv2d(
      64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): _WeightNorm()
        )
      )
    )
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): CustomClassifier(
    (classifier_layers): Sequential(
      (0): Sequential(
        (0): Linear(in_features=6272, out_features=256, bias=True)
        (1): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): Linear(in_features=256, out_features=10, bias=True

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0, Examples seen: 640, Loss: 1.2896
Epoch 0, Examples seen: 1280, Loss: 1.2564
Epoch 0, Examples seen: 1920, Loss: 1.2692
Epoch 0, Examples seen: 2560, Loss: 1.1066
Epoch 0, Examples seen: 3200, Loss: 1.1233
Epoch 0, Examples seen: 3840, Loss: 1.1630
Epoch 0, Examples seen: 4480, Loss: 0.8945
Epoch 0, Examples seen: 5120, Loss: 0.9907
Epoch 0, Examples seen: 5760, Loss: 1.1783
Epoch 0, Examples seen: 6400, Loss: 1.1184
Epoch 0, Examples seen: 7040, Loss: 0.8697
Epoch 0, Examples seen: 7680, Loss: 0.9218
Epoch 0, Examples seen: 8320, Loss: 0.9798
Epoch 0, Examples seen: 8960, Loss: 1.0259
Epoch 0, Examples seen: 9600, Loss: 1.0623
Epoch 0, Examples seen: 10240, Loss: 1.2587
Epoch 0, Examples seen: 10880, Loss: 1.0381
Epoch 0, Examples seen: 11520, Loss: 1.1153
Epoch 0, Examples seen: 12160, Loss: 1.0526
Epoch 0, Examples seen: 12800, Loss: 1.2391
Epoch 0, Examples seen: 13440, Loss: 1.1174
Epoch 0, Examples seen: 14080, Loss: 1.2194
Epoch 0, Examples seen: 14720, Loss: 0.8442
Epoc

 10%|█         | 1/10 [00:53<08:04, 53.83s/it]

Epoch 1/10, Loss: 1.0259, Train Accuracy: 57.70%, Validation Accuracy: 58.67%
Epoch 1, Examples seen: 48640, Loss: 0.9154
Epoch 1, Examples seen: 49280, Loss: 0.9766
Epoch 1, Examples seen: 49920, Loss: 0.8410
Epoch 1, Examples seen: 50560, Loss: 0.9363
Epoch 1, Examples seen: 51200, Loss: 1.0907
Epoch 1, Examples seen: 51840, Loss: 0.7984
Epoch 1, Examples seen: 52480, Loss: 1.0266
Epoch 1, Examples seen: 53120, Loss: 0.9562
Epoch 1, Examples seen: 53760, Loss: 0.8553
Epoch 1, Examples seen: 54400, Loss: 0.9371
Epoch 1, Examples seen: 55040, Loss: 0.9173
Epoch 1, Examples seen: 55680, Loss: 0.7036
Epoch 1, Examples seen: 56320, Loss: 0.9137
Epoch 1, Examples seen: 56960, Loss: 0.7063
Epoch 1, Examples seen: 57600, Loss: 1.1453
Epoch 1, Examples seen: 58240, Loss: 0.7909
Epoch 1, Examples seen: 58880, Loss: 0.8045
Epoch 1, Examples seen: 59520, Loss: 1.0959
Epoch 1, Examples seen: 60160, Loss: 0.7606
Epoch 1, Examples seen: 60800, Loss: 0.9064
Epoch 1, Examples seen: 61440, Loss: 0.983

 20%|██        | 2/10 [01:41<06:41, 50.18s/it]

Epoch 2/10, Loss: 0.9686, Train Accuracy: 58.74%, Validation Accuracy: 59.19%
Epoch 2, Examples seen: 96640, Loss: 1.0208
Epoch 2, Examples seen: 97280, Loss: 0.7984
Epoch 2, Examples seen: 97920, Loss: 0.8212
Epoch 2, Examples seen: 98560, Loss: 0.6137
Epoch 2, Examples seen: 99200, Loss: 1.2654
Epoch 2, Examples seen: 99840, Loss: 0.8653
Epoch 2, Examples seen: 100480, Loss: 1.0603
Epoch 2, Examples seen: 101120, Loss: 0.9024
Epoch 2, Examples seen: 101760, Loss: 0.9028
Epoch 2, Examples seen: 102400, Loss: 1.0087
Epoch 2, Examples seen: 103040, Loss: 1.0444
Epoch 2, Examples seen: 103680, Loss: 0.8288
Epoch 2, Examples seen: 104320, Loss: 0.8676
Epoch 2, Examples seen: 104960, Loss: 0.7982
Epoch 2, Examples seen: 105600, Loss: 0.7197
Epoch 2, Examples seen: 106240, Loss: 0.9543
Epoch 2, Examples seen: 106880, Loss: 0.6863
Epoch 2, Examples seen: 107520, Loss: 1.0090
Epoch 2, Examples seen: 108160, Loss: 0.9811
Epoch 2, Examples seen: 108800, Loss: 1.0887
Epoch 2, Examples seen: 1094

 30%|███       | 3/10 [02:27<05:38, 48.33s/it]

Epoch 3/10, Loss: 0.9560, Train Accuracy: 58.96%, Validation Accuracy: 59.08%
Epoch 3, Examples seen: 144640, Loss: 1.0016
Epoch 3, Examples seen: 145280, Loss: 0.8312
Epoch 3, Examples seen: 145920, Loss: 1.1886
Epoch 3, Examples seen: 146560, Loss: 1.0165
Epoch 3, Examples seen: 147200, Loss: 0.8418
Epoch 3, Examples seen: 147840, Loss: 0.9731
Epoch 3, Examples seen: 148480, Loss: 0.8103
Epoch 3, Examples seen: 149120, Loss: 0.9455
Epoch 3, Examples seen: 149760, Loss: 0.9395
Epoch 3, Examples seen: 150400, Loss: 1.0937
Epoch 3, Examples seen: 151040, Loss: 0.8956
Epoch 3, Examples seen: 151680, Loss: 0.7552
Epoch 3, Examples seen: 152320, Loss: 1.0097
Epoch 3, Examples seen: 152960, Loss: 0.9356
Epoch 3, Examples seen: 153600, Loss: 1.2233
Epoch 3, Examples seen: 154240, Loss: 1.0795
Epoch 3, Examples seen: 154880, Loss: 0.7953
Epoch 3, Examples seen: 155520, Loss: 1.1926
Epoch 3, Examples seen: 156160, Loss: 0.8720
Epoch 3, Examples seen: 156800, Loss: 0.8642
Epoch 3, Examples seen

 40%|████      | 4/10 [03:27<05:16, 52.76s/it]

Epoch 4/10, Loss: 0.9512, Train Accuracy: 59.04%, Validation Accuracy: 59.09%
Epoch 4, Examples seen: 192640, Loss: 0.8282
Epoch 4, Examples seen: 193280, Loss: 0.7205
Epoch 4, Examples seen: 193920, Loss: 0.8647
Epoch 4, Examples seen: 194560, Loss: 0.8154
Epoch 4, Examples seen: 195200, Loss: 1.0802
Epoch 4, Examples seen: 195840, Loss: 0.8275
Epoch 4, Examples seen: 196480, Loss: 0.9059
Epoch 4, Examples seen: 197120, Loss: 0.6819
Epoch 4, Examples seen: 197760, Loss: 0.8291
Epoch 4, Examples seen: 198400, Loss: 1.0075
Epoch 4, Examples seen: 199040, Loss: 0.8275
Epoch 4, Examples seen: 199680, Loss: 0.9049
Epoch 4, Examples seen: 200320, Loss: 0.7934
Epoch 4, Examples seen: 200960, Loss: 0.9055
Epoch 4, Examples seen: 201600, Loss: 0.8997
Epoch 4, Examples seen: 202240, Loss: 0.9357
Epoch 4, Examples seen: 202880, Loss: 0.9727
Epoch 4, Examples seen: 203520, Loss: 0.8489
Epoch 4, Examples seen: 204160, Loss: 0.9801
Epoch 4, Examples seen: 204800, Loss: 0.9053
Epoch 4, Examples seen

 50%|█████     | 5/10 [04:16<04:17, 51.44s/it]

Epoch 5/10, Loss: 0.9475, Train Accuracy: 59.09%, Validation Accuracy: 59.15%
Epoch 5, Examples seen: 240640, Loss: 0.9354
Epoch 5, Examples seen: 241280, Loss: 1.0793
Epoch 5, Examples seen: 241920, Loss: 1.0435
Epoch 5, Examples seen: 242560, Loss: 1.1938
Epoch 5, Examples seen: 243200, Loss: 0.9038
Epoch 5, Examples seen: 243840, Loss: 0.9010
Epoch 5, Examples seen: 244480, Loss: 0.9354
Epoch 5, Examples seen: 245120, Loss: 1.0437
Epoch 5, Examples seen: 245760, Loss: 1.0086
Epoch 5, Examples seen: 246400, Loss: 0.7919
Epoch 5, Examples seen: 247040, Loss: 1.1154
Epoch 5, Examples seen: 247680, Loss: 1.0415
Epoch 5, Examples seen: 248320, Loss: 1.2303
Epoch 5, Examples seen: 248960, Loss: 0.9073
Epoch 5, Examples seen: 249600, Loss: 1.0899
Epoch 5, Examples seen: 250240, Loss: 0.6841
Epoch 5, Examples seen: 250880, Loss: 1.1873
Epoch 5, Examples seen: 251520, Loss: 1.2354
Epoch 5, Examples seen: 252160, Loss: 0.9749
Epoch 5, Examples seen: 252800, Loss: 0.8462
Epoch 5, Examples seen

 60%|██████    | 6/10 [04:59<03:14, 48.65s/it]

Epoch 6/10, Loss: 0.9468, Train Accuracy: 59.12%, Validation Accuracy: 59.06%
Epoch 6, Examples seen: 288640, Loss: 0.8996
Epoch 6, Examples seen: 289280, Loss: 0.8639
Epoch 6, Examples seen: 289920, Loss: 0.9178
Epoch 6, Examples seen: 290560, Loss: 0.9756
Epoch 6, Examples seen: 291200, Loss: 1.1153
Epoch 6, Examples seen: 291840, Loss: 1.0015
Epoch 6, Examples seen: 292480, Loss: 0.9007
Epoch 6, Examples seen: 293120, Loss: 0.9716
Epoch 6, Examples seen: 293760, Loss: 1.0139
Epoch 6, Examples seen: 294400, Loss: 1.1579
Epoch 6, Examples seen: 295040, Loss: 0.9871
Epoch 6, Examples seen: 295680, Loss: 1.1165
Epoch 6, Examples seen: 296320, Loss: 0.7916
Epoch 6, Examples seen: 296960, Loss: 0.7926
Epoch 6, Examples seen: 297600, Loss: 0.7556
Epoch 6, Examples seen: 298240, Loss: 0.8278
Epoch 6, Examples seen: 298880, Loss: 1.1166
Epoch 6, Examples seen: 299520, Loss: 1.0435
Epoch 6, Examples seen: 300160, Loss: 0.8675
Epoch 6, Examples seen: 300800, Loss: 0.9013
Epoch 6, Examples seen

 70%|███████   | 7/10 [06:12<02:49, 56.59s/it]

Epoch 7/10, Loss: 0.9454, Train Accuracy: 59.14%, Validation Accuracy: 59.17%
Epoch 7, Examples seen: 336640, Loss: 0.6856
Epoch 7, Examples seen: 337280, Loss: 1.0075
Epoch 7, Examples seen: 337920, Loss: 0.8851
Epoch 7, Examples seen: 338560, Loss: 0.5758
Epoch 7, Examples seen: 339200, Loss: 0.8635
Epoch 7, Examples seen: 339840, Loss: 1.1513
Epoch 7, Examples seen: 340480, Loss: 0.7556
Epoch 7, Examples seen: 341120, Loss: 1.1153
Epoch 7, Examples seen: 341760, Loss: 0.7925
Epoch 7, Examples seen: 342400, Loss: 1.0008
Epoch 7, Examples seen: 343040, Loss: 1.0434
Epoch 7, Examples seen: 343680, Loss: 0.7557
Epoch 7, Examples seen: 344320, Loss: 1.0793
Epoch 7, Examples seen: 344960, Loss: 0.9716
Epoch 7, Examples seen: 345600, Loss: 0.9354
Epoch 7, Examples seen: 346240, Loss: 0.9718
Epoch 7, Examples seen: 346880, Loss: 1.0441
Epoch 7, Examples seen: 347520, Loss: 1.0481
Epoch 7, Examples seen: 348160, Loss: 0.7918
Epoch 7, Examples seen: 348800, Loss: 0.9714
Epoch 7, Examples seen

 80%|████████  | 8/10 [07:29<02:05, 62.97s/it]

Epoch 8/10, Loss: 0.9433, Train Accuracy: 59.18%, Validation Accuracy: 59.08%
Epoch 8, Examples seen: 384640, Loss: 1.2254
Epoch 8, Examples seen: 385280, Loss: 0.9088
Epoch 8, Examples seen: 385920, Loss: 1.1570
Epoch 8, Examples seen: 386560, Loss: 1.0796
Epoch 8, Examples seen: 387200, Loss: 0.8276
Epoch 8, Examples seen: 387840, Loss: 0.8995
Epoch 8, Examples seen: 388480, Loss: 0.9354
Epoch 8, Examples seen: 389120, Loss: 1.0453
Epoch 8, Examples seen: 389760, Loss: 1.0465
Epoch 8, Examples seen: 390400, Loss: 0.6478
Epoch 8, Examples seen: 391040, Loss: 0.9719
Epoch 8, Examples seen: 391680, Loss: 0.9768
Epoch 8, Examples seen: 392320, Loss: 0.9010
Epoch 8, Examples seen: 392960, Loss: 1.1153
Epoch 8, Examples seen: 393600, Loss: 0.7307
Epoch 8, Examples seen: 394240, Loss: 1.1513
Epoch 8, Examples seen: 394880, Loss: 1.0795
Epoch 8, Examples seen: 395520, Loss: 0.9223
Epoch 8, Examples seen: 396160, Loss: 1.0558
Epoch 8, Examples seen: 396800, Loss: 0.8635
Epoch 8, Examples seen

 90%|█████████ | 9/10 [08:37<01:04, 64.58s/it]

Epoch 9/10, Loss: 0.9442, Train Accuracy: 59.16%, Validation Accuracy: 59.14%
Epoch 9, Examples seen: 432640, Loss: 1.0077
Epoch 9, Examples seen: 433280, Loss: 1.1191
Epoch 9, Examples seen: 433920, Loss: 1.1873
Epoch 9, Examples seen: 434560, Loss: 0.7209
Epoch 9, Examples seen: 435200, Loss: 0.9356
Epoch 9, Examples seen: 435840, Loss: 0.9354
Epoch 9, Examples seen: 436480, Loss: 0.7920
Epoch 9, Examples seen: 437120, Loss: 0.9002
Epoch 9, Examples seen: 437760, Loss: 1.1873
Epoch 9, Examples seen: 438400, Loss: 0.6480
Epoch 9, Examples seen: 439040, Loss: 0.9721
Epoch 9, Examples seen: 439680, Loss: 1.1732
Epoch 9, Examples seen: 440320, Loss: 0.8702
Epoch 9, Examples seen: 440960, Loss: 1.0080
Epoch 9, Examples seen: 441600, Loss: 0.8644
Epoch 9, Examples seen: 442240, Loss: 0.7556
Epoch 9, Examples seen: 442880, Loss: 0.8640
Epoch 9, Examples seen: 443520, Loss: 1.0075
Epoch 9, Examples seen: 444160, Loss: 0.8643
Epoch 9, Examples seen: 444800, Loss: 1.0105
Epoch 9, Examples seen

100%|██████████| 10/10 [10:12<00:00, 61.26s/it]


Epoch 10/10, Loss: 0.9415, Train Accuracy: 59.23%, Validation Accuracy: 59.30%


100%|██████████| 157/157 [00:11<00:00, 13.99it/s]


Test Accuracy: 10.84%


100%|██████████| 157/157 [00:29<00:00,  5.25it/s]

Test Accuracy: 9.76%





0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇████
loss,▅▆▁▇▄▂▆▆▂▅▃▂▄▂▆▆▅▁▅▆▄▃▆▅▇▆▆▇▆▇▅▃█▄▆▆▆▆▂▅
test_accuracy,▁
test_accuracy_after_averaging,▁
train_accuracy,▁▆▇▇▇▇████
train_loss,█▃▂▂▁▁▁▁▁▁
val_accuracy,▁▇▆▆▆▅▇▆▆█

0,1
epoch,9.0
loss,0.82751
test_accuracy,10.84
test_accuracy_after_averaging,9.76
train_accuracy,59.23333
train_loss,0.9415
val_accuracy,59.3


In [5]:

# Load ResNet-18 model
def load_resnet18(n_classes, pretrained=True):
    """Build a torchvision ResNet-18 whose final layer outputs ``n_classes`` logits.

    Args:
        n_classes: number of output units for the replacement ``fc`` layer;
            must be >= 1.
        pretrained: if True (the default, matching the previous behaviour),
            load torchvision's ImageNet weights. If False, use random
            initialisation, which avoids a weight download.

    Returns:
        The ResNet-18 module with ``fc`` replaced by a freshly initialised
        ``nn.Linear(model.fc.in_features, n_classes)``.

    Raises:
        ValueError: if ``n_classes`` is less than 1.

    NOTE(review): ResNet-18's stem takes 3-channel input. The config cell sets
    ``greyscale=True``, so single-channel images would need channel replication
    or a replaced ``conv1`` -- confirm before using this with that data.
    """
    if n_classes < 1:
        raise ValueError(f"n_classes must be a positive integer, got {n_classes}")

    # torchvision >= 0.13 deprecates `pretrained=` in favour of the `weights=`
    # enum; IMAGENET1K_V1 is what `pretrained=True` used to load. Older
    # releases lack the enum, so fall back to the legacy keyword there.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
    else:
        model = models.resnet18(pretrained=pretrained)

    # Swap the classification head for one sized to this task; the input width
    # is taken from the layer being replaced.
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model