In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

class Model:
    def __init__(self, net, cost, optimist):
        self.net = net
        self.cost = self.create_cost(cost)
        self.optimizer = self.create_optimizer(optimist)

    def create_cost(self, cost):
        support_cost = {
            'CROSS_ENTROPY': nn.CrossEntropyLoss(),
            'MSE': nn.MSELoss()
        }

        return support_cost[cost]

    def create_optimizer(self, optimist, **rests):
        support_optim = {
            'SGD': optim.SGD(self.net.parameters(), lr=0.1, **rests),
            'ADAM': optim.Adam(self.net.parameters(), lr=0.01, **rests),
            'RMSP':optim.RMSprop(self.net.parameters(), lr=0.001, **rests)
        }

        return support_optim[optimist]

    def train(self, train_loader, epoches=5):
        for epoch in range(epoches):
            running_loss = 0.0
            for i, data in enumerate(train_loader, 0): # start = 0
                inputs, labels = data

                self.optimizer.zero_grad() # Resets the gradients of all optimized tensor, set_to_none=True means instead of setting to zero, set the grads to None.

                # forward + backward + optimize
                outputs = self.net(inputs)
                loss = self.cost(outputs, labels)
                loss.backward() # Compute gradients of the parameters respect to the loss
                self.optimizer.step() # updates the parameters

                running_loss += loss.item()
                if i % 100 == 0:
                    print('[epoch %d, %.2f%%] loss: %.3f' %
                          (epoch + 1, (i + 1)*100/len(train_loader), running_loss))
                    running_loss = 0.0

        print('Finished Training')

    def evaluate(self, test_loader):
        print('Evaluating ...')
        correct = 0
        total = 0
        with torch.no_grad():  # no grad when test and predict
            for data in test_loader:
                images, labels = data

                outputs = self.net(images)
                predicted = torch.argmax(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

def mnist_load_data():
    transform = transforms.Compose(
        [transforms.ToTensor(), # Convert a PIL Image or ndarray to tensor and scale the values [0.0, 1.0] accordingly
         transforms.Normalize([0], [1])]) # Normalize a tensor image with mean and standard deviation

    trainset = torchvision.datasets.MNIST(root='./mnist', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root='./mnist', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=True, num_workers=2)
    return trainloader, testloader


class MnistNet(torch.nn.Module):
    def __init__(self):
        super(MnistNet, self).__init__()
        self.fc1 = torch.nn.Linear(28*28, 512)
        self.fc2 = torch.nn.Linear(512, 512)
        self.fc3 = torch.nn.Linear(512, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x), dim=1)
        return x

if __name__ == '__main__':
    net = MnistNet()
    model = Model(net, 'CROSS_ENTROPY', 'RMSP')
    train_loader, test_loader = mnist_load_data()
    model.train(train_loader)
    model.evaluate(test_loader)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 110554882.27it/s]


Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 17352197.94it/s]


Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 158316349.16it/s]

Extracting ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 9602081.03it/s]


Extracting ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw

[epoch 1, 0.05%] loss: 2.302
[epoch 1, 5.39%] loss: 175.183
[epoch 1, 10.72%] loss: 157.970
[epoch 1, 16.05%] loss: 154.773
[epoch 1, 21.39%] loss: 155.781
[epoch 1, 26.72%] loss: 154.937
[epoch 1, 32.05%] loss: 154.259
[epoch 1, 37.39%] loss: 153.178
[epoch 1, 42.72%] loss: 152.849
[epoch 1, 48.05%] loss: 152.493
[epoch 1, 53.39%] loss: 151.942
[epoch 1, 58.72%] loss: 152.960
[epoch 1, 64.05%] loss: 152.619
[epoch 1, 69.39%] loss: 152.015
[epoch 1, 74.72%] loss: 152.345
[epoch 1, 80.05%] loss: 151.109
[epoch 1, 85.39%] loss: 152.355
[epoch 1, 90.72%] loss: 151.746
[epoch 1, 96.05%] loss: 151.581
[epoch 2, 0.05%] loss: 1.497
[epoch 2, 5.39%] loss: 150.204
[epoch 2, 10.72%] loss: 150.440
[epoch 2, 16.05%] loss: 150.722
[epoch 2, 21.39%] loss: 151.304
[epoch 2, 26.72%] loss: 151.456
[epoch 2, 32.05%] loss: 150.535
[epoch 2, 37.39%] loss: 150.249
[epoch 2, 42.72%] loss: 150.751
[epoch 2, 48.05%] loss: 150.435
[ep