In [7]:
import os
os.chdir('..')

# Step 1: Download the MNIST dataset

In [8]:
!pip3 install torchvision



In [10]:
from torchvision.transforms import Compose, Normalize, ToTensor
import torchvision

transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
mnist_train = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
mnist_test = torchvision.datasets.MNIST('./data', train=False, download=True, transform=transform)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 26267556.94it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 22445005.34it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 11711964.13it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 33305120.22it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






# Step 2: Build the model

In [13]:
from Adversarial_Observation import utils

In [15]:
model = utils.build_cnn()
print(model)

Sequential(
  (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=784, out_features=10, bias=True)
  (8): Softmax(dim=1)
)


In [16]:
# define the training loop
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# define the testing loop
def test(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)

    return test_loss, test_accuracy

In [18]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64, shuffle=True)

for epoch in range(10):
    train(model, train_loader, optimizer, criterion, device)
    test_loss, test_accuracy = test(model, test_loader, criterion, device)
    print('Epoch: {}, Test Loss: {}, Test Accuracy: {}'.format(epoch, test_loss, test_accuracy))

Epoch: 0, Test Loss: 0.023607329082489012, Test Accuracy: 95.97
Epoch: 1, Test Loss: 0.0234849182844162, Test Accuracy: 96.62
Epoch: 2, Test Loss: 0.023343893492221834, Test Accuracy: 97.51
Epoch: 3, Test Loss: 0.023301265048980713, Test Accuracy: 97.69
Epoch: 4, Test Loss: 0.023230479300022124, Test Accuracy: 98.21
Epoch: 5, Test Loss: 0.023256595289707183, Test Accuracy: 98.07
Epoch: 6, Test Loss: 0.023198915433883666, Test Accuracy: 98.46
Epoch: 7, Test Loss: 0.023237432789802552, Test Accuracy: 98.1
Epoch: 8, Test Loss: 0.023247943019866945, Test Accuracy: 98.1
Epoch: 9, Test Loss: 0.023253062796592713, Test Accuracy: 98.05


In [19]:
# save the model
torch.save(model.state_dict(), './MNIST_CNN.pt')