In [21]:
import os
import torch
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F

# Specify the directory where the dataset is stored
root_dir = './data'

# Ensure the necessary directory structure
if not os.path.exists(os.path.join(root_dir, 'MNIST', 'raw')):
    os.makedirs(os.path.join(root_dir, 'MNIST', 'raw'))


train = datasets.MNIST(root = root_dir, train = True, download=False, transform = transforms.Compose([transforms.ToTensor()]))
test = datasets.MNIST(root = root_dir, train = False, download=False, transform = transforms.Compose([transforms.ToTensor()]))


trainset = torch.utils.data.DataLoader(train, batch_size = 10, shuffle = True)
testset = torch.utils.data.DataLoader(test, batch_size = 10, shuffle = False)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return F.log_softmax(x, dim = 1)

net = Net()
print(net)


Net(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=64, bias=True)
  (fc4): Linear(in_features=64, out_features=10, bias=True)
)


In [22]:
import torch.optim as optim

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [None]:
for epoch in range(3): 
    for data in trainset:  
        X, y = data  
        net.zero_grad() 
        output = net(X.view(-1,784))  
        loss = F.nll_loss(output, y)  
        loss.backward() 
        optimizer.step()  
    print(loss)

tensor(0.1877, grad_fn=<NllLossBackward0>)


In [20]:
correct = 0
total = 0

with torch.no_grad():
    for data in testset:
        X, y = data
        output = net(X.view(-1,784))
        for idx, i in enumerate(output):
            if torch.argmax(i) == y[idx]:
                correct += 1
            total += 1

print("Accuracy: ", round(correct/total, 3))

Accuracy:  0.968
