In [1]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch
import torch.nn as nn


In [16]:
class FCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.moduledict = nn.ModuleDict({
            'fc1' : nn.Linear(784, 128), # All img in MINST are 28x28, 28*28=784
            'fc2' : nn.Linear(128, 128),
            'relu': nn.ReLU(),
            'out' : nn.Linear(128, 10)
        })
    def forward(self, x):
        # x = x.view(-1, 784)
        x = x.view(x.size(0), -1)
        x = self.moduledict.fc1(x)
        x = self.moduledict.relu(x)
        x = self.moduledict.fc2(x)
        x = self.moduledict.relu(x)
        return self.moduledict.out(x)

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [11]:
train_set = datasets.MNIST('dataset/mnist/', train=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_set = datasets.MNIST('dataset/mnist/', train=False, transform=transform)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

In [None]:
from numpy import shape
from tqdm import tqdm


model = FCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
epoch_num = 5

model.train()

for epoch in range(epoch_num):
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epoch_num}'):
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

100%|██████████| 938/938 [00:01<00:00, 588.49it/s]
100%|██████████| 938/938 [00:01<00:00, 629.92it/s]
100%|██████████| 938/938 [00:01<00:00, 642.57it/s]
100%|██████████| 938/938 [00:01<00:00, 583.99it/s]
100%|██████████| 938/938 [00:01<00:00, 589.06it/s]


In [13]:
model.eval()
total_correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        output = model(images)
        _, predicted_class = torch.max(output.data, 1)

        total += labels.size(0)
        total_correct += (predicted_class == labels).sum().item()
print(f'Accuracy: {100 * total_correct / total:.2f}%')

100%|██████████| 157/157 [00:00<00:00, 670.93it/s]

Accuracy: 97.78%



