In [1]:
from src import notebook, datasets, build_models,caluclate_basis, custom_blocks, train, utils
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import wandb
from tqdm import tqdm   
from torchsummary import summary

In [2]:
def make(config):
    """Build the data loaders, model, loss and optimizer described by ``config``.

    Args:
        config: attribute-style config object (e.g. ``wandb.config``). Reads
            ``config.criterion`` ("CrossEntropyLoss" or "MSELoss"),
            ``config.optimizer`` ("Adam" or "SGD") and ``config.lr``; the whole
            object is also forwarded to ``datasets.get_dataloaders`` and
            ``build_models.build_model_from_args``.

    Returns:
        Tuple ``(model, train_loader, val_loader, test_loader, criterion, optimizer)``.

    Raises:
        ValueError: if ``config.criterion`` or ``config.optimizer`` is not supported.
    """
    criterion_classes = {
        "CrossEntropyLoss": nn.CrossEntropyLoss,
        "MSELoss": nn.MSELoss,
    }
    optimizer_classes = {
        "Adam": optim.Adam,
        "SGD": optim.SGD,
    }

    # Validate the names up front so a typo fails fast, before the data loaders
    # and model are built. The message reports the configured name; the local
    # loss/optimizer variables do not exist yet at this point.
    criterion_cls = criterion_classes.get(config.criterion)
    if criterion_cls is None:
        raise ValueError(f"Criterion {config.criterion} not supported")
    optimizer_cls = optimizer_classes.get(config.optimizer)
    if optimizer_cls is None:
        raise ValueError(f"Optimizer {config.optimizer} not supported")

    train_loader, val_loader, test_loader = datasets.get_dataloaders(
        config, logfile=None, summaryfile=None, log=False
    )
    model = build_models.build_model_from_args(config)

    criterion = criterion_cls()
    optimizer = optimizer_cls(model.parameters(), lr=config.lr)

    return model, train_loader, val_loader, test_loader, criterion, optimizer

In [3]:
# Experiment configuration: passed to wandb.init(config=...) below, and the
# resulting wandb.config is what make() and train.train() read.
hyperparameters = {
    "arch": ['64', 'M', '128', 'M'],
    "batch_norm": False,
    "bias": True,
    "classifier_layers": [256],
    "classifier_bias": True,
    "classifier_dropout": 0,
    "avgpool": False,
    # Adjusted to (1, 1); presumably only consulted when avgpool=True — TODO confirm in build_models.
    "avgpool_size": [1, 1],
    "dataset": "rotated_mnist",
    "input_height": 29,
    "input_width": 29,
    "greyscale": True,
    "n_classes": 10,
    "epochs": 10,
    "batch_size": 64,
    "optimizer": "Adam",
    # Conv weights end up wrapped in _WeightNorm (see the model printout below).
    "weight_normalization": True,
    "lr": 0.001,
    "criterion": "CrossEntropyLoss",
    "seed": 42,
    "save_checkpoints": False,
    "save_dir": "",
    "checkpoint_type": "epoch",
}

In [4]:
# Start a tracked wandb run; wandb.config exposes `hyperparameters` with attribute access.
run = wandb.init(project="custom-vgg-model", config=hyperparameters)
try:
    config = wandb.config

    model, train_loader, val_loader, test_loader, criterion, optimizer = make(config)

    print(model)
    print(criterion, optimizer)

    # Train, validate each epoch and evaluate on the test set (see train.train).
    train.train(config, model, train_loader, val_loader, test_loader, criterion, optimizer)
finally:
    # Finish the run even if building or training raises, so the wandb run
    # is not left open after a failed cell.
    run.finish()

[34m[1mwandb[0m: Currently logged in as: [33matsou2[0m ([33matsou2-johns-hopkins-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


CustomModel(
  (features): Sequential(
    (0): ParametrizedConv2d(
      1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): _WeightNorm()
        )
      )
    )
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ParametrizedConv2d(
      64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): _WeightNorm()
        )
      )
    )
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): CustomClassifier(
    (classifier_layers): Sequential(
      (0): Sequential(
        (0): Linear(in_features=6272, out_features=256, bias=True)
        (1): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): Linear(in_features=256, out_features=10, bias=True

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0, Examples seen: 640, Loss: 2.2819
Epoch 0, Examples seen: 1280, Loss: 2.0325
Epoch 0, Examples seen: 1920, Loss: 1.9982
Epoch 0, Examples seen: 2560, Loss: 1.9805
Epoch 0, Examples seen: 3200, Loss: 1.9376
Epoch 0, Examples seen: 3840, Loss: 1.9006
Epoch 0, Examples seen: 4480, Loss: 1.9878
Epoch 0, Examples seen: 5120, Loss: 2.0487
Epoch 0, Examples seen: 5760, Loss: 1.8045
Epoch 0, Examples seen: 6400, Loss: 1.8485
Epoch 0, Examples seen: 7040, Loss: 1.8977
Epoch 0, Examples seen: 7680, Loss: 1.8851
Epoch 0, Examples seen: 8320, Loss: 1.6762
Epoch 0, Examples seen: 8960, Loss: 1.8654
Epoch 0, Examples seen: 9600, Loss: 1.6141


 10%|█         | 1/10 [00:16<02:28, 16.51s/it]

Epoch 1/10, Loss: 1.9203, Train Accuracy: 28.99%, Validation Accuracy: 32.25%
Epoch 1, Examples seen: 10240, Loss: 1.7282
Epoch 1, Examples seen: 10880, Loss: 1.8254
Epoch 1, Examples seen: 11520, Loss: 1.6890
Epoch 1, Examples seen: 12160, Loss: 1.6825
Epoch 1, Examples seen: 12800, Loss: 1.9013
Epoch 1, Examples seen: 13440, Loss: 1.3417
Epoch 1, Examples seen: 14080, Loss: 1.7777
Epoch 1, Examples seen: 14720, Loss: 1.8216
Epoch 1, Examples seen: 15360, Loss: 1.6904
Epoch 1, Examples seen: 16000, Loss: 1.6905
Epoch 1, Examples seen: 16640, Loss: 1.8016
Epoch 1, Examples seen: 17280, Loss: 1.8691
Epoch 1, Examples seen: 17920, Loss: 1.7247
Epoch 1, Examples seen: 18560, Loss: 1.4720
Epoch 1, Examples seen: 19200, Loss: 1.5967


 20%|██        | 2/10 [00:32<02:10, 16.29s/it]

Epoch 2/10, Loss: 1.7265, Train Accuracy: 32.78%, Validation Accuracy: 33.12%
Epoch 2, Examples seen: 19840, Loss: 1.6587
Epoch 2, Examples seen: 20480, Loss: 1.5990
Epoch 2, Examples seen: 21120, Loss: 1.6376
Epoch 2, Examples seen: 21760, Loss: 1.7381
Epoch 2, Examples seen: 22400, Loss: 1.7148
Epoch 2, Examples seen: 23040, Loss: 1.5847
Epoch 2, Examples seen: 23680, Loss: 1.4626
Epoch 2, Examples seen: 24320, Loss: 1.6482
Epoch 2, Examples seen: 24960, Loss: 1.4849
Epoch 2, Examples seen: 25600, Loss: 1.6813
Epoch 2, Examples seen: 26240, Loss: 1.6211
Epoch 2, Examples seen: 26880, Loss: 1.5609
Epoch 2, Examples seen: 27520, Loss: 1.7388
Epoch 2, Examples seen: 28160, Loss: 1.6993
Epoch 2, Examples seen: 28800, Loss: 1.6745


 30%|███       | 3/10 [00:48<01:52, 16.07s/it]

Epoch 3/10, Loss: 1.6513, Train Accuracy: 34.66%, Validation Accuracy: 34.75%
Epoch 3, Examples seen: 29440, Loss: 1.3524
Epoch 3, Examples seen: 30080, Loss: 1.4927
Epoch 3, Examples seen: 30720, Loss: 1.8324
Epoch 3, Examples seen: 31360, Loss: 1.6671
Epoch 3, Examples seen: 32000, Loss: 1.5008
Epoch 3, Examples seen: 32640, Loss: 1.5915
Epoch 3, Examples seen: 33280, Loss: 1.3782
Epoch 3, Examples seen: 33920, Loss: 1.5282
Epoch 3, Examples seen: 34560, Loss: 1.4509
Epoch 3, Examples seen: 35200, Loss: 1.7533
Epoch 3, Examples seen: 35840, Loss: 1.5451
Epoch 3, Examples seen: 36480, Loss: 1.5805
Epoch 3, Examples seen: 37120, Loss: 1.8190
Epoch 3, Examples seen: 37760, Loss: 1.6256
Epoch 3, Examples seen: 38400, Loss: 1.5221


 40%|████      | 4/10 [01:04<01:36, 16.06s/it]

Epoch 4/10, Loss: 1.5820, Train Accuracy: 36.67%, Validation Accuracy: 35.71%
Epoch 4, Examples seen: 39040, Loss: 1.6879
Epoch 4, Examples seen: 39680, Loss: 1.5643
Epoch 4, Examples seen: 40320, Loss: 1.7038
Epoch 4, Examples seen: 40960, Loss: 1.4949
Epoch 4, Examples seen: 41600, Loss: 1.5970
Epoch 4, Examples seen: 42240, Loss: 1.6115
Epoch 4, Examples seen: 42880, Loss: 1.6682
Epoch 4, Examples seen: 43520, Loss: 1.5450
Epoch 4, Examples seen: 44160, Loss: 1.4549
Epoch 4, Examples seen: 44800, Loss: 1.7519
Epoch 4, Examples seen: 45440, Loss: 1.5007
Epoch 4, Examples seen: 46080, Loss: 1.3307
Epoch 4, Examples seen: 46720, Loss: 1.6762
Epoch 4, Examples seen: 47360, Loss: 1.6269
Epoch 4, Examples seen: 48000, Loss: 1.3676


 50%|█████     | 5/10 [01:20<01:20, 16.06s/it]

Epoch 5/10, Loss: 1.5269, Train Accuracy: 37.74%, Validation Accuracy: 36.29%
Epoch 5, Examples seen: 48640, Loss: 1.3821
Epoch 5, Examples seen: 49280, Loss: 1.5992
Epoch 5, Examples seen: 49920, Loss: 1.5184
Epoch 5, Examples seen: 50560, Loss: 1.5809
Epoch 5, Examples seen: 51200, Loss: 1.5038
Epoch 5, Examples seen: 51840, Loss: 1.4022
Epoch 5, Examples seen: 52480, Loss: 1.4738
Epoch 5, Examples seen: 53120, Loss: 1.6346
Epoch 5, Examples seen: 53760, Loss: 1.5317
Epoch 5, Examples seen: 54400, Loss: 1.5652
Epoch 5, Examples seen: 55040, Loss: 1.6795
Epoch 5, Examples seen: 55680, Loss: 1.5376
Epoch 5, Examples seen: 56320, Loss: 1.4712
Epoch 5, Examples seen: 56960, Loss: 1.3548
Epoch 5, Examples seen: 57600, Loss: 1.7718


 60%|██████    | 6/10 [01:36<01:04, 16.01s/it]

Epoch 6/10, Loss: 1.4942, Train Accuracy: 38.48%, Validation Accuracy: 36.96%
Epoch 6, Examples seen: 58240, Loss: 1.5414
Epoch 6, Examples seen: 58880, Loss: 1.7402
Epoch 6, Examples seen: 59520, Loss: 1.3523
Epoch 6, Examples seen: 60160, Loss: 1.6668
Epoch 6, Examples seen: 60800, Loss: 1.3960
Epoch 6, Examples seen: 61440, Loss: 1.6103
Epoch 6, Examples seen: 62080, Loss: 1.3086
Epoch 6, Examples seen: 62720, Loss: 1.5566
Epoch 6, Examples seen: 63360, Loss: 1.5693
Epoch 6, Examples seen: 64000, Loss: 1.4369
Epoch 6, Examples seen: 64640, Loss: 1.3000
Epoch 6, Examples seen: 65280, Loss: 1.0326
Epoch 6, Examples seen: 65920, Loss: 1.3585
Epoch 6, Examples seen: 66560, Loss: 1.7844
Epoch 6, Examples seen: 67200, Loss: 1.4884


 70%|███████   | 7/10 [01:52<00:47, 15.96s/it]

Epoch 7/10, Loss: 1.4700, Train Accuracy: 38.92%, Validation Accuracy: 37.08%
Epoch 7, Examples seen: 67840, Loss: 1.2636
Epoch 7, Examples seen: 68480, Loss: 1.4985
Epoch 7, Examples seen: 69120, Loss: 1.3377
Epoch 7, Examples seen: 69760, Loss: 1.4831
Epoch 7, Examples seen: 70400, Loss: 1.6057
Epoch 7, Examples seen: 71040, Loss: 1.5677
Epoch 7, Examples seen: 71680, Loss: 1.5523
Epoch 7, Examples seen: 72320, Loss: 1.3867
Epoch 7, Examples seen: 72960, Loss: 1.5686
Epoch 7, Examples seen: 73600, Loss: 1.4630
Epoch 7, Examples seen: 74240, Loss: 1.3949
Epoch 7, Examples seen: 74880, Loss: 1.5622
Epoch 7, Examples seen: 75520, Loss: 1.3861
Epoch 7, Examples seen: 76160, Loss: 1.6689
Epoch 7, Examples seen: 76800, Loss: 1.3509


 80%|████████  | 8/10 [02:08<00:31, 15.96s/it]

Epoch 8/10, Loss: 1.4564, Train Accuracy: 39.12%, Validation Accuracy: 37.00%
Epoch 8, Examples seen: 77440, Loss: 1.4969
Epoch 8, Examples seen: 78080, Loss: 1.6493
Epoch 8, Examples seen: 78720, Loss: 1.3793
Epoch 8, Examples seen: 79360, Loss: 1.3786
Epoch 8, Examples seen: 80000, Loss: 1.5033
Epoch 8, Examples seen: 80640, Loss: 1.3470
Epoch 8, Examples seen: 81280, Loss: 1.3436
Epoch 8, Examples seen: 81920, Loss: 1.3915
Epoch 8, Examples seen: 82560, Loss: 1.5324
Epoch 8, Examples seen: 83200, Loss: 1.4674
Epoch 8, Examples seen: 83840, Loss: 1.5296
Epoch 8, Examples seen: 84480, Loss: 1.5711
Epoch 8, Examples seen: 85120, Loss: 1.3577
Epoch 8, Examples seen: 85760, Loss: 1.6065
Epoch 8, Examples seen: 86400, Loss: 1.1534


 90%|█████████ | 9/10 [02:24<00:15, 15.97s/it]

Epoch 9/10, Loss: 1.4460, Train Accuracy: 39.22%, Validation Accuracy: 37.25%
Epoch 9, Examples seen: 87040, Loss: 1.3525
Epoch 9, Examples seen: 87680, Loss: 1.4424
Epoch 9, Examples seen: 88320, Loss: 1.2068
Epoch 9, Examples seen: 88960, Loss: 1.4802
Epoch 9, Examples seen: 89600, Loss: 1.2880
Epoch 9, Examples seen: 90240, Loss: 1.2409
Epoch 9, Examples seen: 90880, Loss: 1.2902
Epoch 9, Examples seen: 91520, Loss: 1.4021
Epoch 9, Examples seen: 92160, Loss: 1.6415
Epoch 9, Examples seen: 92800, Loss: 1.3445
Epoch 9, Examples seen: 93440, Loss: 1.2209
Epoch 9, Examples seen: 94080, Loss: 1.1114
Epoch 9, Examples seen: 94720, Loss: 1.5219
Epoch 9, Examples seen: 95360, Loss: 1.4311
Epoch 9, Examples seen: 96000, Loss: 1.5288


100%|██████████| 10/10 [02:40<00:00, 16.01s/it]


Epoch 10/10, Loss: 1.4275, Train Accuracy: 39.65%, Validation Accuracy: 37.38%


100%|██████████| 782/782 [00:15<00:00, 51.02it/s] 

Test Accuracy: 37.56%





0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇████
loss,█▆▆▅▅▅▅▅▅▄▄▃▅▄▃▃▃▂▃▂▃▃▅▃▂▂▂▃▃▂▃▄▃▄▄▃▂▃▃▁
test_accuracy,▁
train_accuracy,▁▃▅▆▇▇████
train_loss,█▅▄▃▂▂▂▁▁▁
val_accuracy,▁▂▄▆▇▇█▇██

0,1
epoch,9.0
loss,1.52879
test_accuracy,37.562
train_accuracy,39.64583
train_loss,1.42747
val_accuracy,37.375


In [5]:
# Duplicate of the run.finish() call at the end of cell [4]. It only has work to
# do if cell [4] stopped (error/interrupt) before reaching that call.
run.finish()

In [6]:
# Load ResNet-18 model
def load_resnet18(n_classes, pretrained=True, in_channels=3):
    """Build a torchvision ResNet-18 whose final layer outputs ``n_classes`` logits.

    Args:
        n_classes: number of output units of the new final ``fc`` layer.
        pretrained: if True (default), load torchvision's default ImageNet
            weights (``ResNet18_Weights.DEFAULT``); if False, use random init.
        in_channels: number of input image channels. The stock stem expects 3
            (RGB); pass 1 for greyscale images (e.g. a ``greyscale=True`` dataset).

    Returns:
        The ResNet-18 ``nn.Module``. ``fc`` is always newly created; ``conv1``
        is newly created when ``in_channels != 3``.
    """
    # `pretrained=` is deprecated since torchvision 0.13; `weights=` is the supported API.
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    model = models.resnet18(weights=weights)

    if in_channels != 3:
        stem = model.conv1
        model.conv1 = nn.Conv2d(
            in_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=stem.bias is not None,
        )
        if pretrained:
            # Reuse the pretrained filters: summing over RGB gives the response to a
            # grey image replicated on all 3 channels; spread it evenly across the
            # new input channels so that equivalence holds exactly for in_channels=1.
            with torch.no_grad():
                rgb_sum = stem.weight.sum(dim=1, keepdim=True)
                model.conv1.weight.copy_(rgb_sum.repeat(1, in_channels, 1, 1) / in_channels)
                if stem.bias is not None:
                    model.conv1.bias.copy_(stem.bias)

    # Replace the final fully connected layer to have n_classes output units
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model