In [None]:
import torch
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from matplotlib import pyplot as plt

# custom functions and classes
from simple_mnist import Model, train, evaluate

# Step1: Load the train and test data

In [None]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# with mean 0.1307 and std 0.3081 (the commonly used MNIST statistics).
trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))])

# Training split (downloaded into ./data if missing), in shuffled batches of 64.
trainset = datasets.MNIST(root="./data", train=True, download=True, transform=trans)
# Fix: the keyword is `shuffle`; the misspelled `suffle` raised a TypeError.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Test split, in batches of 1000. Shuffling is kept so that the first batch
# drawn from this loader is a random selection of test images.
testset = datasets.MNIST(root="./data", train=False, download=True, transform=trans)
test_loader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=True)

# Step2: Visualize the MNIST data

In [None]:
# Draw one batch from the test loader to use as visualization samples.
samples = iter(test_loader)
# Fix: use the built-in next(); Python 3 iterators (and current DataLoader
# iterators) do not provide a .next() method.
sample_data, sample_targets = next(samples)

print("Sample data size: ", sample_data.shape)

# Show the first 10 images of the batch in a 2x5 grid, each titled with its label.
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flat):
    # Index [0] selects the first channel of image i for display.
    ax.imshow(sample_data[i][0], cmap="gray", interpolation="none")
    ax.set_title(f"Ground Truth: {sample_targets[i]}")
    ax.set_xticks([])
    ax.set_yticks([])
# A single layout pass once all panels exist (previously repeated per panel).
fig.tight_layout()
plt.show()

# Step3: Create the CNN model and optimizer

In [None]:
# Pick the GPU when one is available and fall back to the CPU otherwise,
# instead of hard-coding "cuda" (which fails on machines without a GPU).
# NOTE(review): the train/evaluate helpers from simple_mnist must place their
# batches on the same device as the model — confirm they do not hard-code "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
# Kept as the last expression so the cell still displays the model.
model.to(device)

In [None]:
# SGD with momentum over every parameter of the model.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.5,
)

# Step5: Train the network

In [None]:
# Train for 5 epochs, collecting the losses returned by train() and reporting
# the test-set metrics returned by evaluate() after each epoch.
train_losses = []
for epoch in range(5):
    # train() returns this epoch's losses (presumably one value per
    # iteration — confirm in simple_mnist); append them to the running list.
    train_losses.extend(train(model, train_loader, optimizer, epoch))
    test_loss, test_accuracy = evaluate(model, test_loader)
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {test_accuracy:.1f}%\n")

# Plot the collected training losses in the order they were recorded.
fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(train_losses)
ax.set_xlabel("Iterations")
ax.set_ylabel("Train loss")
plt.show()

# Step6: Visualize the trained network predictions

In [None]:
# Put the model in evaluation mode before running inference.
model.eval()

# Fix: "device" is not a valid device string. Send the samples to whichever
# device the model's parameters live on so the forward pass works on CPU or GPU.
model_device = next(model.parameters()).device
with torch.no_grad():
    output = model(sample_data.to(model_device))

# Predicted class = index of the largest output along dim 1. Computed once for
# the whole batch and moved to the CPU for plotting.
predictions = output.max(dim=1).indices.cpu()

# Show the first 10 sample images in a 2x5 grid, titled with the prediction.
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(sample_data[i][0], cmap="gray", interpolation="none")
    ax.set_title(f"Prediction: {predictions[i].item()}")
    ax.set_xticks([])
    ax.set_yticks([])
# A single layout pass once all panels exist (previously repeated per panel).
fig.tight_layout()
plt.show()

# Step7: Save the trained parameters

In [None]:
# Write only the model's state dict (not the whole model object) to disk.
torch.save(
    obj=model.state_dict(),
    f="mnist_cnn.pth",
)