In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [2]:
class ConvNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(16*7*7, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x

In [3]:
# sets device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ConvNet(1, 10).to(device)

In [4]:
# load data
train_data = datasets.MNIST(
    root='./datasets',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

test_data = datasets.MNIST(
    root='./datasets',
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

test_loader = DataLoader(test_data, batch_size=64, shuffle=True)

In [5]:
for x, y in train_loader:
    print(x.shape)
    print(y.shape)
    break

torch.Size([64, 1, 28, 28])
torch.Size([64])


In [6]:
# sets hyperparameters
in_features = 784
num_classes = 10
learning_rate = 0.001
num_epochs = 5

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
import time
import os

def save_checkpoint(states, output_dir=None):
    print('=> Saving checkpoints')
    if output_dir is None:
        output_dir = 'checkpoints'
        if output_dir not in os.listdir():
            os.makedirs(output_dir)
    elif output_dir not in os.listdir():
            os.makedirs(output_dir)
    output_file = 'checkpoint-' + str(time.time()) + '.pth.tar'
    output_file = os.path.join(output_dir, output_file)
    torch.save(states, output_file)


def save_best_model(states, output_dir=None, output_file=None):
    print('=> Saving best model')
    if output_dir is None:
        output_dir = 'best_models'
        os.makedirs(output_dir)
    elif output_dir not in os.listdir():
            os.makedirs(output_dir)
    if output_file is None:
        output_file = 'best_model.pth.tar'
    output = os.path.join(output_dir, output_file)
    torch.save(states, output)

In [9]:
for epoch in range(num_epochs):
    losses = []

    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        scores = model(data)
        loss = criterion(scores, targets)
        losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('Batch {}| Loss {}'.format(batch_idx, loss))
    
    states = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    save_checkpoint(states)


Batch 0| Loss 2.2918012142181396
Batch 1| Loss 2.301578998565674
Batch 2| Loss 2.3022680282592773
Batch 3| Loss 2.285740375518799
Batch 4| Loss 2.282137632369995
Batch 5| Loss 2.2516350746154785
Batch 6| Loss 2.287684440612793
Batch 7| Loss 2.2792458534240723
Batch 8| Loss 2.2448251247406006
Batch 9| Loss 2.2250280380249023
Batch 10| Loss 2.2435569763183594
Batch 11| Loss 2.2202939987182617
Batch 12| Loss 2.211193323135376
Batch 13| Loss 2.229323625564575
Batch 14| Loss 2.221078395843506
Batch 15| Loss 2.2041923999786377
Batch 16| Loss 2.1452157497406006
Batch 17| Loss 2.174067974090576
Batch 18| Loss 2.17108154296875
Batch 19| Loss 2.0646724700927734
Batch 20| Loss 2.1489622592926025
Batch 21| Loss 2.165163993835449
Batch 22| Loss 2.0403077602386475
Batch 23| Loss 2.085258960723877
Batch 24| Loss 2.0620534420013428
Batch 25| Loss 2.0384387969970703
Batch 26| Loss 1.9626272916793823
Batch 27| Loss 1.9910156726837158
Batch 28| Loss 1.9343523979187012
Batch 29| Loss 1.9267672300338745
Ba

In [21]:
checkpoint = torch.load('./checkpoints/checkpoint-1645412551.918868.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

In [22]:
def check_accuracy(loader, model):
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            scores = model(x)
            _, predictions = scores.max(1)
            correct += (predictions == y).sum()
            total += predictions.size(0)
        model.train()
        print('Accuracy: ', correct/total)

In [23]:
check_accuracy(test_loader, model)

Accuracy:  tensor(0.9872)
