In [9]:
import torch
from torch import nn
from torch.nn import functional as F

from torch.utils.data import DataLoader

import torchvision

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [3]:
# MNIST Dataset
# The two splits are built with the same root directory, tensor conversion and
# download flag; only the `train` switch differs between them.
_mnist_common = dict(
    root='data',
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

train_data = torchvision.datasets.MNIST(train=True, **_mnist_common)
test_data = torchvision.datasets.MNIST(train=False, **_mnist_common)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████████| 9912422/9912422 [00:00<00:00, 11538600.22it/s]


Extracting data\MNIST\raw\train-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 28889981.83it/s]


Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 8747527.41it/s]


Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<?, ?it/s]

Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz to data\MNIST\raw






In [5]:
# Report the size of the raw image tensor held by each split.
train_shape = train_data.data.shape
test_shape = test_data.data.shape
print(f'Shape of train data: {train_shape}')
print(f'Shape of test data: {test_shape}')

Shape of train data: torch.Size([60000, 28, 28])
Shape of test data: torch.Size([10000, 28, 28])


In [8]:
# Display the label tensor of the training split (one integer per image).
train_data.targets

tensor([5, 0, 4,  ..., 5, 6, 8])

In [12]:
# Training batches are reshuffled every epoch so that successive epochs see
# the samples in a different order.
train_dataloader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=4
)

# Evaluation only aggregates loss and accuracy over the whole split, so the
# sample order is irrelevant; reading sequentially keeps the batches (and the
# floating-point summation order) identical from run to run.
test_dataloader = DataLoader(
    test_data,
    batch_size=64,
    shuffle=False,
    num_workers=4
)

In [23]:
class Net(nn.Module):
    """Small convolutional classifier producing 10 class log-probabilities.

    Two 5x5 convolutions (1 -> 10 -> 20 channels), each followed by 2x2 max
    pooling and ReLU, feed a 320 -> 50 -> 10 fully connected head. ``fc1``
    takes exactly 320 features per sample, which is what a single-channel
    28x28 image yields after both conv/pool stages (20 channels x 4 x 4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a batch of images to log-probabilities of shape (batch, 10).

        Both dropout layers are active only while the module is in training
        mode.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten everything except the batch axis. Unlike view(-1, 320), this
        # never regroups samples: a wrongly sized input fails loudly in fc1.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalise over the class axis explicitly; omitting `dim` relies on a
        # deprecated implicit choice and emits a warning on every call.
        return F.log_softmax(x, dim=1)

In [24]:
model = Net().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Net.forward already ends in log_softmax, so the matching criterion is the
# negative log-likelihood. CrossEntropyLoss would apply log_softmax a second
# time, which yields the same value but is redundant and misleading.
loss_fn = nn.NLLLoss()

In [35]:
# Number of epochs the training loop runs for.
n_epochs = 10

In [38]:
def train(epoch):
    """Run one optimisation pass over ``train_dataloader``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn`` and
    ``DEVICE``. ``epoch`` is only used to label the progress lines, which are
    printed for every 20th batch.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = loss_fn(predictions, labels)
        batch_loss.backward()
        optimizer.step()

        # Only every 20th batch reports progress.
        if step % 20 != 0:
            continue
        print(f'Train epoch: {epoch}, [{step*len(inputs)}/{len(train_dataloader.dataset)}, ({100*step/len(train_dataloader):.3f}%)], loss: {batch_loss.item():.6f}')

def test():
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_dataloader):
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            test_loss += loss_fn(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_dataloader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_dataloader.dataset)} ({correct/(len(test_dataloader.dataset)):.3f}%)')

In [39]:
# Alternate one training pass with one evaluation pass; epochs are numbered
# from 1 so the progress lines read naturally.
for epoch in range(1, n_epochs+1):
    train(epoch)
    test()

  return F.log_softmax(x)



Test set: Average loss: 0.0012, Accuracy: 9747/10000 (0.975%)

Test set: Average loss: 0.0010, Accuracy: 9796/10000 (0.980%)

Test set: Average loss: 0.0009, Accuracy: 9828/10000 (0.983%)

Test set: Average loss: 0.0008, Accuracy: 9847/10000 (0.985%)

Test set: Average loss: 0.0007, Accuracy: 9862/10000 (0.986%)

Test set: Average loss: 0.0007, Accuracy: 9859/10000 (0.986%)

Test set: Average loss: 0.0007, Accuracy: 9866/10000 (0.987%)

Test set: Average loss: 0.0006, Accuracy: 9877/10000 (0.988%)

Test set: Average loss: 0.0006, Accuracy: 9882/10000 (0.988%)

Test set: Average loss: 0.0006, Accuracy: 9881/10000 (0.988%)


In [42]:
# Persist the learned weights and the optimizer state as two separate files
# in the current working directory.
for stateful, checkpoint_path in ((model, './model.pth'), (optimizer, './optimizer.pth')):
    torch.save(stateful.state_dict(), checkpoint_path)