In [1]:
import os
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from model import Net, train, test

In [2]:
# Working directory plus short aliases for os.path helpers.
# NOTE(review): `path` and `pjoin` are not referenced anywhere else in the
# visible notebook — keep only if later cells rely on them.
cwd = os.getcwd()
path = os.path
pjoin = os.path.join

# Accelerator availability flags; the device-selection cell reads both.
use_mps = torch.backends.mps.is_available()
use_cuda = torch.cuda.is_available()

In [3]:
# Report where relative paths (e.g. ./data) resolve and which accelerators
# PyTorch detected — one line per item.
print(
    f"The current working directory is: {cwd}",
    f"Cuda available: {use_cuda}",
    f"macOS GPU training: {use_mps}",
    sep="\n",
)

The current working directory is: /code
Cuda available: True
macOS GPU training: False


In [4]:
# Preference order: CUDA, then MPS (Apple GPU), otherwise fall back to the CPU.
device = torch.device("cuda" if use_cuda else ("mps" if use_mps else "cpu"))

In [5]:
# Show which backend the model and batches will be placed on.
print("Device is: {}".format(device))

Device is: cuda


In [6]:
# ---- Training schedule ----
epochs = 6            # number of train/test rounds in the epoch loop
batch_size = 128      # mini-batch size used by both the train and test loaders
learning_rate = 1.0   # initial Adadelta learning rate
gamma = 0.7           # StepLR decay factor applied once per epoch

# ---- Run bookkeeping ----
log_interval = 200    # forwarded to `train`; presumably batches between progress logs
save_model = False    # when True, the trained state_dict is written to mnist_cnn.pt
seed = 1              # PyTorch RNG seed for repeatable runs

# Seed PyTorch's global RNG before any model or loader consumes randomness.
torch.manual_seed(seed)

# DataLoader keyword arguments.
# The training set is reshuffled every epoch on EVERY device. Previously
# 'shuffle' lived inside the CUDA-only options, so CPU/MPS runs trained on a
# fixed sample order. The test set is read in a fixed order; shuffling an
# evaluation pass adds nothing, assuming `test` aggregates over the whole set
# (TODO confirm in model.test).
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
test_kwargs = {'batch_size': batch_size}
if use_cuda:
    # CUDA-only loader tuning: one background worker process and page-locked
    # (pinned) host memory for faster host-to-GPU copies.
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

# Preprocessing: convert each image to a tensor, then normalise it with fixed
# statistics (0.1307 mean, 0.3081 std — presumably the MNIST training-set pixel
# statistics).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Both MNIST splits live under ./data; only the training split requests a
# download (assumes that download also fetches the test files — confirm).
dataset1 = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST(root='./data', train=False, transform=transform)

# Wrap each split in a DataLoader using the keyword sets assembled above.
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

# Optimisation setup: Adadelta starting at `learning_rate`; StepLR multiplies the
# learning rate by `gamma` every time `scheduler.step()` runs (step_size=1).
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, gamma=gamma, step_size=1)

# Per epoch (1-based): one `train` call, one `test` call, then one LR decay.
for epoch in range(1, epochs + 1):
    # `train`/`test` are project helpers imported from model.py — presumably one
    # optimisation pass and one evaluation pass respectively; confirm there.
    train(log_interval, model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()

# Persist only the learned parameters (the state_dict) when requested.
if save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")


Test set: Average loss: 0.0471, Accuracy: 9835/10000 (98%)


Test set: Average loss: 0.0361, Accuracy: 9877/10000 (99%)


Test set: Average loss: 0.0306, Accuracy: 9899/10000 (99%)


Test set: Average loss: 0.0269, Accuracy: 9907/10000 (99%)


Test set: Average loss: 0.0245, Accuracy: 9909/10000 (99%)


Test set: Average loss: 0.0249, Accuracy: 9916/10000 (99%)

