In [None]:
%load_ext autoreload
%autoreload 2

from loguru import logger
from tqdm import tqdm

from access_pytorch import config

In [None]:
from pathlib import Path

import torch

from models.nets import MNIST_cnn
from access_pytorch.dataset import pull_mnist

In [None]:
# Choose the compute backend: prefer CUDA, then Apple's MPS backend, and fall
# back to the CPU when neither accelerator is available.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using device: {device}")

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from access_pytorch.modeling.train import mnist_train

from access_pytorch.dataset import mnist_loaders

# Training configuration: tweak values here rather than inside the calls below.
SEED = 1    # seeds torch's global RNG (weight init, shuffling) for reproducible runs
EPOCHS = 1
torch.manual_seed(SEED)

mnist_model = MNIST_cnn().to(device)

# Shuffling belongs to the training loader only and must not depend on the device.
# It used to be set only in the CUDA branch, so CPU/MPS runs got no 'shuffle'
# option. NOTE(review): assumes mnist_loaders forwards these dicts to DataLoader
# as keyword arguments; confirm against its implementation.
train_kwargs = {'batch_size': 64, 'shuffle': True}
test_kwargs = {'batch_size': 1000}
optim_kwargs = {'lr': 1.0}

if torch.cuda.is_available():
    # Loader options that are only worthwhile when feeding a CUDA device.
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

log_kwargs = {'log_interval': 50}

mnist_train_loader, mnist_test_loader = mnist_loaders(train_kwargs, test_kwargs)

optimizer = optim.Adadelta(mnist_model.parameters(), **optim_kwargs)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)  # lr_{i+step_size} = gamma * lr_i

for epoch in range(EPOCHS):
    mnist_train(args=log_kwargs, model=mnist_model, device=device,
                train_loader=mnist_train_loader, optimizer=optimizer, epoch=epoch)
    # Advance the LR schedule once per epoch; without this the StepLR above
    # is created but never takes effect.
    scheduler.step()

In [None]:
from access_pytorch.modeling.test import mnist_test

# Evaluate the trained model on the test loader built in the training cell.
mnist_test(test_loader=mnist_test_loader, device=device, model=mnist_model)

In [None]:
# Persist the trained weights, then rebuild the model from the checkpoint on disk
# to confirm that the saved state_dict round-trips.
mnist_model_path = Path(config.MODELS_DIR) / "mnist_cnn.pth"
mnist_model_path.parent.mkdir(parents=True, exist_ok=True)  # torch.save won't create dirs
torch.save(mnist_model.state_dict(), mnist_model_path)

mnist_model = MNIST_cnn().to(device)
# torch.load is pickle-based: weights_only=True restricts unpickling to tensors and
# plain containers. map_location lets a checkpoint saved on another device
# (e.g. CUDA) load onto the current one.
mnist_model.load_state_dict(
    torch.load(mnist_model_path, map_location=device, weights_only=True)
)