In [None]:
from src import notebook, datasets, build_models,caluclate_basis, custom_blocks, train, utils
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import wandb
from tqdm import tqdm   
from torchsummary import summary

In [None]:
def make(config):
    """Build everything needed for a training run from a config object.

    Args:
        config: object exposing at least the attributes ``criterion``,
            ``optimizer`` and ``lr`` (e.g. ``wandb.config``); it is also
            forwarded unchanged to ``datasets.get_dataloaders`` and
            ``build_models.build_model_from_args``.

    Returns:
        Tuple ``(model, train_loader, val_loader, test_loader, criterion,
        optimizer)``.

    Raises:
        ValueError: if ``config.criterion`` is not ``"CrossEntropyLoss"`` or
            ``"MSELoss"``, or ``config.optimizer`` is not ``"Adam"`` or
            ``"SGD"``.
    """
    criterion_factories = {
        "CrossEntropyLoss": nn.CrossEntropyLoss,
        "MSELoss": nn.MSELoss,
    }
    optimizer_factories = {
        "Adam": optim.Adam,
        "SGD": optim.SGD,
    }

    # Validate the names up front so a bad config fails before any data
    # loading or model construction. The messages report the *requested*
    # name from the config (the not-yet-assigned locals cannot be used here).
    if config.criterion not in criterion_factories:
        raise ValueError(f"Criterion {config.criterion} not supported")
    if config.optimizer not in optimizer_factories:
        raise ValueError(f"Optimizer {config.optimizer} not supported")

    train_loader, val_loader, test_loader = datasets.get_dataloaders(
        config, logfile=None, summaryfile=None, log=False
    )

    model = build_models.build_model_from_args(config)

    criterion = criterion_factories[config.criterion]()
    optimizer = optimizer_factories[config.optimizer](model.parameters(), lr=config.lr)

    return model, train_loader, val_loader, test_loader, criterion, optimizer

In [None]:
# Run configuration handed to wandb.init(config=...) and read back through
# wandb.config by make() and train.train().
hyperparameters = {
    # --- network definition (interpreted by build_models) ---
    "arch": ["64", "M", "128", "M"],
    "batch_norm": False,
    "bias": True,
    "classifier_layers": [256],
    "classifier_bias": True,
    "classifier_dropout": 0,
    "avgpool": True,
    "avgpool_size": [1, 1],  # Adjusted avgpool size to (1, 1)
    # --- dataset / input description ---
    "dataset": "mnist",
    "input_height": 28,
    "input_width": 28,
    "greyscale": True,
    "n_classes": 10,
    # --- optimisation settings ---
    "epochs": 10,
    "batch_size": 64,
    "optimizer": "Adam",
    "weight_normalization": False,
    "lr": 0.001,
    "criterion": "CrossEntropyLoss",
    "seed": 42,
    # --- checkpointing ---
    "save_checkpoints": False,
    "save_dir": "",
    "checkpoint_type": "epoch",
}

In [None]:
# Start a Weights & Biases run. Using the run as a context manager guarantees
# it is finished even if model construction or training raises, and lets
# wandb record the failure instead of leaving the run open.
with wandb.init(project="custom-vgg-model", config=hyperparameters) as run:
    config = wandb.config

    # NOTE(review): config.seed is not applied anywhere in this cell; confirm
    # that datasets.get_dataloaders / train.train seed the RNGs themselves.
    model, train_loader, val_loader, test_loader, criterion, optimizer = make(config)

    print(model)

    # Train the model with the objects built from the config.
    train.train(config, model, train_loader, val_loader, test_loader, criterion, optimizer)
# The run is finished automatically when the `with` block exits.

In [None]:
# Load ResNet-18 model
def load_resnet18(n_classes):
    """Return an ImageNet-pretrained ResNet-18 with a new classification head.

    Args:
        n_classes: number of output units for the replacement ``fc`` layer.

    Returns:
        The torchvision ResNet-18 model whose final fully connected layer has
        been replaced by a freshly initialised ``nn.Linear`` with
        ``n_classes`` outputs.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates `pretrained=True` in favour of the
        # `weights=` argument; IMAGENET1K_V1 is the checkpoint that
        # `pretrained=True` maps to, so the loaded weights are unchanged.
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision releases only understand the legacy flag.
        model = models.resnet18(pretrained=True)
    # Modify the final fully connected layer to have n_classes output units
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model