In [13]:
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms 
from torch.utils.data import DataLoader

In [14]:
# Training hyperparameters shared by the cells below.
training_epochs = 15
batch_size = 100

# MNIST train and test splits, stored under `root`; every sample is passed
# through transforms.ToTensor() when it is read from the dataset.
root = './data'
mnist_train = dset.MNIST(
    root=root,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
mnist_test = dset.MNIST(
    root=root,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Mini-batch iterators: the training split is reshuffled, the test split keeps
# its original order.
train_loader = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)


In [15]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The whole model is one fully connected layer: 784 input features -> 10 outputs.
linear = nn.Linear(in_features=784, out_features=10, bias=True).to(device)

# Re-draw the weight matrix in place with normal_ (no mean/std given, so its
# defaults apply); the bias keeps the layer's own initialisation. Left as the
# last expression so the notebook displays the resulting parameter.
nn.init.normal_(linear.weight)

Parameter containing:
tensor([[ 0.6558, -0.9865, -0.3268,  ...,  1.3842, -0.4537,  2.0292],
        [-0.4716,  0.4212,  0.8715,  ...,  0.6578,  0.1918, -0.0359],
        [ 0.8237, -0.7472, -0.8979,  ..., -0.2953,  0.3836,  0.6029],
        ...,
        [-0.8371,  0.8613,  0.5890,  ...,  0.3800,  0.6788, -0.5509],
        [-0.5659, -0.3064, -0.8841,  ..., -0.2912,  0.2062, -1.0937],
        [ 0.6580,  0.0699,  1.1987,  ...,  1.0337, -0.7570, -0.0434]],
       requires_grad=True)

In [16]:
# Cross-entropy loss, placed on the same device as the model.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Plain SGD over the linear layer's parameters (weight and bias).
optimizer = torch.optim.SGD(params=linear.parameters(), lr=0.1)

In [None]:
# Train the linear layer with mini-batch SGD: `training_epochs` passes over
# `train_loader`, printing the loss and accuracy of the current batch every
# 100 steps.

# The evaluation cell below calls linear.eval(); switch back to training mode
# so this cell behaves the same when it is re-run after an evaluation.
linear.train()

steps_per_epoch = len(train_loader)  # loop-invariant, only used in the log line

for epoch in range(training_epochs):
    for i, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        # Flatten every image into a 784-element row to match the layer's input size.
        imgs = imgs.view(-1, 28*28)

        outputs = linear(imgs)
        loss = criterion(outputs, labels)

        # Standard update: clear stale gradients, backpropagate, take one SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            # Accuracy is only reported, never optimised, so compute it just
            # for the steps that are logged and keep it out of autograd.
            # `outputs` predates this step's weight update, so the figure
            # reflects the pre-update weights on this batch.
            with torch.no_grad():
                argmax = outputs.argmax(dim=1)
                accuracy = (labels == argmax).float().mean()

            print('Epoch [{}/{}], Step [{}/{}], Loss: {: .4f}, Accuracy: {: .2f}%'.format(
                epoch+1, training_epochs, i+1, steps_per_epoch, loss.item(), accuracy.item()* 100))


In [18]:
# Measure classification accuracy of the trained layer on the test split.
linear.eval()

correct = 0  # number of test images whose predicted class equals the label
total = 0    # number of test images seen

# No gradients are needed for evaluation.
with torch.no_grad():
    for images, targets in test_loader:
        # Move the batch to the model's device and flatten each image to 784 values.
        images = images.to(device).view(-1, 28 * 28)
        targets = targets.to(device)

        # Predicted class = index of the largest output in each row.
        predictions = linear(images).argmax(dim=1)

        total += images.size(0)
        correct += (predictions == targets).sum().item()

print('Test accuracy for {} images: {: .2f}%'.format(total, correct / total * 100))

Test accuracy for 10000 images:  89.31%
