In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import Tensor

In [4]:
class NN(nn.Module):
    def __init__(self, input_size: int, num_classes: int) -> None:
        super(NN, self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50, num_classes)
        
    def forward(self, x: Tensor):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
        

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_size = 784
num_classes = 10
batch_size = 64
num_epochs = 1
learning_rate = 0.001

In [6]:
train_dataset  = datasets.MNIST(root='./datasets', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset  = datasets.MNIST(root='./datasets', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 11532870.89it/s]


Extracting ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 14704502.77it/s]


Extracting ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 11122927.32it/s]


Extracting ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 10342306.61it/s]


Extracting ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw



In [21]:
model = NN(input_size=input_size, num_classes=num_classes).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [20]:
test = torch.rand((64, 1, 28, 28))
test.reshape(test.shape[0], -1).shape

torch.Size([64, 784])

In [31]:
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        data = data.reshape(data.shape[0], -1)
        scores = model(data)
        loss = loss_fn(scores, targets)
        
        # backward
        loss.backward()
        
        # gradient descent
        optimizer.zero_grad()
        optimizer.step()
        
        print(f'loss: {loss}')
            
            


loss: 0.12298229336738586
loss: 0.12354741990566254
loss: 0.12160138785839081
loss: 0.11580488830804825
loss: 0.14661163091659546
loss: 0.11438895761966705
loss: 0.08118384331464767
loss: 0.09254367649555206
loss: 0.23436498641967773
loss: 0.19459864497184753
loss: 0.14921513199806213
loss: 0.07203206419944763
loss: 0.19980569183826447
loss: 0.12475406378507614
loss: 0.08285287767648697
loss: 0.1534738689661026
loss: 0.12209007889032364
loss: 0.06759630888700485
loss: 0.080714151263237
loss: 0.1392904371023178
loss: 0.1911219358444214
loss: 0.08377555012702942
loss: 0.14705750346183777
loss: 0.0856703445315361
loss: 0.20958732068538666
loss: 0.24413008987903595
loss: 0.1477138251066208
loss: 0.05085792392492294
loss: 0.2195497304201126
loss: 0.100023053586483
loss: 0.07757953554391861
loss: 0.026892930269241333
loss: 0.2984291911125183
loss: 0.10246266424655914
loss: 0.19634780287742615
loss: 0.07858657091856003
loss: 0.10796257853507996
loss: 0.0961274728178978
loss: 0.109975993633270

In [29]:
def check_accuracy(loader:DataLoader, model:NN):
    if loader.dataset.train:
        print('training data accuracy')
    else:
        print('test data accuracy')
        
    num_correct = 0
    num_samples = 0
    model.eval()
    
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            x = x.reshape(x.shape[0], -1)
            
            scores = model(x)
            _, pred = scores.max(1)
            num_correct += (pred == y).sum()
            num_samples += pred.size(0)
        
            print(f'correct: {num_correct}, samples: {num_samples}, accuracy: {float(num_correct)/float(num_samples)*100:.2f}')
        
        model.train()
        
check_accuracy(train_loader, model)
check_accuracy(test_loader, model)


training data accuracy
correct: 62, samples: 64, accuracy: 96.88
correct: 125, samples: 128, accuracy: 97.66
correct: 188, samples: 192, accuracy: 97.92
correct: 249, samples: 256, accuracy: 97.27
correct: 310, samples: 320, accuracy: 96.88
correct: 372, samples: 384, accuracy: 96.88
correct: 435, samples: 448, accuracy: 97.10
correct: 498, samples: 512, accuracy: 97.27
correct: 559, samples: 576, accuracy: 97.05
correct: 617, samples: 640, accuracy: 96.41
correct: 676, samples: 704, accuracy: 96.02
correct: 739, samples: 768, accuracy: 96.22
correct: 800, samples: 832, accuracy: 96.15
correct: 863, samples: 896, accuracy: 96.32
correct: 925, samples: 960, accuracy: 96.35
correct: 985, samples: 1024, accuracy: 96.19
correct: 1047, samples: 1088, accuracy: 96.23
correct: 1108, samples: 1152, accuracy: 96.18
correct: 1168, samples: 1216, accuracy: 96.05
correct: 1230, samples: 1280, accuracy: 96.09
correct: 1293, samples: 1344, accuracy: 96.21
correct: 1356, samples: 1408, accuracy: 96.3