# In [None]:
from torch import nn, optim
from torchvision.models import resnet50, ResNet50_Weights
from sklearn.metrics import confusion_matrix

from model import FeedforwardNeuralNet
from data import MNISTDataset, CIFAR10Dataset
from trainer import Trainer
from utils import Experiment, plot_train_losses, plot_eval_losses, plot_accuracies

# Experiment configuration shared by every optimizer run.
epochs = 1
lr = 0.001
train_batch_size = 512
eval_batch_size = 1000
model_name = "resnet"  # architecture identifier; see build_model for supported values
num_classes = 10  # CIFAR-10 has 10 target classes


def build_model(name, num_classes):
    """Return a freshly constructed model for ``name``.

    A new instance is built on every call so that each optimizer experiment
    starts from the same pretrained backbone. It does not continue from the
    parameters a previous experiment already trained.

    Args:
        name: Architecture identifier; currently only ``"resnet"`` is supported.
        num_classes: Number of output logits for the classification head.

    Returns:
        An ``nn.Module`` ready to be handed to ``Trainer``.

    Raises:
        ValueError: If ``name`` is not a supported architecture.
    """
    if name == "resnet":
        net = resnet50(weights=ResNet50_Weights.DEFAULT)
        # The pretrained head emits 1000 ImageNet logits; swap it for a
        # num_classes-way head so predictions only range over real labels.
        # Note: this new head is randomly initialised on each call.
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net
    raise ValueError(f"Unsupported model: {name!r}")


experiments = [
    Experiment("SGD", optim.SGD),
    Experiment("Adagrad", optim.Adagrad),
    Experiment("RMSprop", optim.RMSprop),
    Experiment("Adam", optim.Adam),
]

# CrossEntropyLoss has no trainable state, so one instance can serve every run.
criterion = nn.CrossEntropyLoss()

for exp in experiments:
    print(f"Training with {exp.optimizer_name} optimizer")
    # Rebuild the model each iteration. Reusing the previous instance would
    # make later optimizers fine-tune already-trained weights and skew the
    # comparison.
    model = build_model(model_name, num_classes)
    dataset = CIFAR10Dataset(train_batch_size, eval_batch_size)
    trainer = Trainer(model, dataset, criterion, exp.optimizer_cls, lr)
    trainer.train(epochs)
    # Presumably stores the trainer's loss/accuracy history for the plots below.
    exp.add_trainer_state(trainer)

plot_train_losses(experiments)
plot_eval_losses(experiments)
plot_accuracies(experiments)

