In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

writer = SummaryWriter()

num_epochs = 10**10
batch_size = 32
lr = 0.01
path = "./mnist_model.pt"

class EarlyStopping():
    def __init__(self, patience=2, save_path=path):
        self.patience = patience
        self.save_path = save_path
        self.min_loss = float("inf")
        self.count = 0
    
    def should_stop(self, model, loss):
        if loss < self.min_loss:
            self.min_loss = loss
            self.count = 0
            torch.save(model.state_dict(), self.save_path)
        elif loss > self.min_loss:
            self.count += 1
            if self.count >= self.patience:
                return True
        return False

    def load(self, model):
        model.load_state_dict(torch.load(self.save_path))

early_stopper = EarlyStopping(patience=3)

train_data = datasets.MNIST("./", download=True, train=True, transform=transforms.ToTensor())
test_data = datasets.MNIST("./", download=True, train=False, transform=transforms.ToTensor())

train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test = DataLoader(test_data, batch_size=batch_size, shuffle=False)

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
    nn.ReLU(),
    nn.Linear(300, 100),
    nn.ReLU(),
    nn.Linear(100, 10)
).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

print(f"Train on {len(train_data)}, test on {test_data} samples")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    accuracy = 0
    for datas, labels in train:
        datas = datas.cuda()
        labels = labels.cuda()
        result: torch.Tensor = model(datas)
        answer = result.argmax(dim=1)
        correct = torch.sum(answer == labels)
        loss: torch.Tensor = criterion(result, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        accuracy += correct.item()
        total_loss += loss.item()
    writer.add_scalar("Loss/train", total_loss, epoch)
    print(f"Epoch {epoch + 1}, loss: {total_loss / len(train)}, accuracy: {accuracy / len(train_data)}")
    if early_stopper.should_stop(model, total_loss):
        print(f"EarlyStopping: [Epoch: {epoch - early_stopper.count}]")
        break

writer.close()
early_stopper.load(model)

with torch.no_grad():
    model.eval()
    total_loss = 0
    accuracy = 0
    for datas, labels in test:
        datas = datas.cuda()
        labels = labels.cuda()
        result: torch.Tensor = model(datas)
        answer = result.argmax(dim=1)
        correct = torch.sum(answer == labels)
        loss: torch.Tensor = criterion(result, labels)
        accuracy += correct.item()
        total_loss += loss.item()
    print(f"test_loss: {total_loss / len(test)}, test_accuracy: {accuracy / len(test_data)}")


Train on 60000, test on Dataset MNIST
    Number of datapoints: 10000
    Root location: ./
    Split: Test
    StandardTransform
Transform: ToTensor() samples
Epoch 1, loss: 0.4149372245132923, accuracy: 0.8867666666666667
Epoch 2, loss: 0.3132713982482751, accuracy: 0.9122833333333333
Epoch 3, loss: 0.2967062597036362, accuracy: 0.9169333333333334
Epoch 4, loss: 0.28775123942693076, accuracy: 0.92
Epoch 5, loss: 0.28159374766548473, accuracy: 0.9213666666666667
Epoch 6, loss: 0.2772076519846916, accuracy: 0.9220333333333334
Epoch 7, loss: 0.2738807370642821, accuracy: 0.9236166666666666
Epoch 8, loss: 0.2709792521973451, accuracy: 0.92435
Epoch 9, loss: 0.2690669713020325, accuracy: 0.92515
Epoch 10, loss: 0.26696224703987437, accuracy: 0.9258833333333333
Epoch 11, loss: 0.26454849247336387, accuracy: 0.9266833333333333
Epoch 12, loss: 0.2633892964243889, accuracy: 0.9269833333333334
Epoch 13, loss: 0.26228023982842763, accuracy: 0.9269666666666667
Epoch 14, loss: 0.260826555532217, 