In [3]:
%load_ext autoreload
# Reload all modules before each cell executes, so edits to the project's
# ../src modules (engine, experiments_maker) are picked up without a kernel restart.
%autoreload 2
import torch
from torch import nn
import torchmetrics
import numpy as np

In [6]:
import os

import wandb
from dotenv import load_dotenv

# Load secrets such as WANDB_API_KEY from a local .env file into os.environ.
load_dotenv()

# Authenticate through the Python API rather than `!wandb login {key}`:
# interpolating the key into a shell command exposes it in the process list and
# shell history. Also, an unset variable would become the literal argument "None".
wandb_api_key = os.getenv("WANDB_API_KEY")
if wandb_api_key:
    wandb.login(key=wandb_api_key)
else:
    # Best-effort, like the original shell call: runs started with mode="disabled"
    # do not need a login, so warn instead of stopping the notebook.
    print("WANDB_API_KEY is not set; skipping wandb login.")

In [4]:
import random

# Single source of truth for the seed used throughout the notebook.
SEED = 42


def seed_everything(seed: int = SEED) -> None:
    """Seed every global RNG the experiments may draw from.

    Seeds Python's ``random`` module (previously left unseeded), NumPy's legacy
    global RNG, and PyTorch. ``torch.manual_seed`` already seeds all devices.
    ``torch.cuda.manual_seed_all`` replaces the original ``torch.cuda.manual_seed``
    so that every GPU is seeded, not only the current one. It is silently ignored
    when CUDA is unavailable.

    Args:
        seed: Value passed to every generator. Defaults to ``SEED``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


seed_everything()

In [16]:
import itertools
import sys

# Make the project's src/ directory importable (path is relative to the notebook's
# working directory). Guarded so re-running this cell does not keep appending
# duplicate entries to sys.path.
SRC_DIR = '../src/'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

import engine
import experiments_maker

# W&B run modes: "disabled" for dry runs (nothing is sent), "online" to log remotely.
modes = ["disabled", "online"]
wandb_mode = modes[0]  # switch to modes[1] to actually log the runs

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters shared by every run in the sweep.
epochs = 50
lr = 1e-3
classes = 10
batch_size = 16
seed = 42  # re-applied before each run so results do not depend on sweep order

loss_fn = nn.CrossEntropyLoss()
# torchmetrics metrics hold their state tensors on CPU until moved, so place the
# metric on the same device as the model outputs. (Harmless if engine.train also
# moves it — TODO confirm.)
accuracy_fn = torchmetrics.Accuracy(task='multiclass', num_classes=classes, average='macro').to(device)

# Full sweep:
# datasets_names = ['mnist', 'tmnist','fashion_mnist', 'cifar10']
# optimizers_names = ['SGD', 'HessianFree', 'PB_BFGS', 'K_BFGS', 'K_LBFGS']
# models_names = ['SmallCNN', 'DepthCNN', 'WidthCNN', 'DepthWidthCNN']

# Experimental run v0.1
datasets_names = ['tmnist']
optimizers_names = ['SGD']  # 'HessianFree' temporarily disabled
models_names = ['SmallCNN']

# Same iteration order as nested loops: dataset outermost, model innermost.
for dataset_name, optimizer_name, model_name in itertools.product(
    datasets_names, optimizers_names, models_names
):
    config = dict(
        epochs=epochs,
        classes=classes,
        learning_rate=lr,
        batch_size=batch_size,
        dataset=dataset_name,      # swept
        optimizer=optimizer_name,  # swept
        model=model_name,          # swept
        architecture="CNN",
        wandb_log="all",
        wandb_log_freq=1,
    )

    # Start every run from identical RNG and metric state, independent of which
    # runs preceded it in the sweep.
    np.random.seed(seed)
    torch.manual_seed(seed)
    accuracy_fn.reset()

    # The context manager finishes the run on exit; no extra wandb.finish() needed.
    with wandb.init(project="baselines_cnn", config=config, mode=wandb_mode):
        # Hand the W&B-managed config object (supports attribute access) downstream.
        run_config = wandb.config
        # Build the model, data loaders and optimizer, then train.
        model, train_dataloader, test_dataloader, optimizer = experiments_maker.make(run_config, device)
        engine.train(model, train_dataloader, test_dataloader, loss_fn, optimizer, accuracy_fn, device, run_config)

  0%|          | 0/50 [00:00<?, ?it/s]

KeyboardInterrupt: 