In [12]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from sklearn.metrics import confusion_matrix

In [2]:
BATCH_SIZE=128
device = 'cpu'
EPOCHS = 10
LEARNING_RATE = 0.001

def download_mnist_datasets():
    train_data = datasets.MNIST(root='data', download=True, train=True, transform=ToTensor())
    validation_data = datasets.MNIST(root='data', download=True, train=False, transform=ToTensor())
    return train_data, validation_data

In [3]:
class_mapping = ['0','1','2','3','4','5','6','7','8','9']

In [4]:
class FeedForwardNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_layers = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, input_data):
        flatten_data = self.flatten(input_data)
        logits = self.dense_layers(flatten_data)
        predictions = self.softmax(logits)
        return predictions
    
def train_one_epoch(model, data_loader, loss_func, optimizer, device):
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        
        predictions = model(inputs)
        loss = loss_func(predictions, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('loss: ', loss.item())
    
def train(model, data_loader, loss_func, optimizer, device, epochs):
    for i in range(epochs):
        print('Epochs: ',i+1)
        train_one_epoch(model, data_loader, loss_func, optimizer, device)
        
def predict(model, input, target, class_mapping):
    model.eval()
    with torch.no_grad():
        predictions = model(input)
        predicted_idx = predictions[0].argmax(0)
        predicted = class_mapping[predicted_idx]
        expected = class_mapping[target]
    return predicted, expected


loss_func = nn.CrossEntropyLoss()
feed_forward_net = FeedForwardNet().to('cpu')
optimizer = torch.optim.Adam(feed_forward_net.parameters(), lr = LEARNING_RATE)

In [5]:
train_data,validation_data = download_mnist_datasets()

In [6]:
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE)

In [7]:
train(feed_forward_net, train_data_loader, loss_func, optimizer, device, EPOCHS)

Epochs:  1
loss:  1.508934497833252
Epochs:  2
loss:  1.5183266401290894
Epochs:  3
loss:  1.5036649703979492
Epochs:  4
loss:  1.485832691192627
Epochs:  5
loss:  1.5028905868530273
Epochs:  6
loss:  1.4873566627502441
Epochs:  7
loss:  1.483221411705017
Epochs:  8
loss:  1.4720631837844849
Epochs:  9
loss:  1.486409068107605
Epochs:  10
loss:  1.4717451333999634


In [8]:
torch.save(feed_forward_net.state_dict(),'feedforwardnet.pth')


In [9]:
state_dict = torch.load('feedforwardnet.pth')

In [10]:
feed_forward_net.load_state_dict(state_dict)

<All keys matched successfully>

In [13]:
preds = []
expects = []
for i in range(1000):
    input, target = validation_data[i][0], validation_data[i][1]
    predicted, expected = predict(feed_forward_net, input, target, class_mapping)
    preds.append(predicted)
    expects.append(expected)
    #print('predicted: ', predicted, ' expected: ', expected)

In [16]:
confusion_matrix(preds, expects) #looks good

array([[ 85,   0,   0,   0,   1,   0,   3,   0,   1,   0],
       [  0, 125,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0, 112,   0,   1,   0,   0,   1,   1,   0],
       [  0,   0,   0, 105,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,  98,   1,   0,   0,   2,   0],
       [  0,   0,   0,   1,   0,  86,   0,   0,   0,   0],
       [  0,   1,   0,   0,   3,   0,  84,   0,   0,   0],
       [  0,   0,   2,   1,   0,   0,   0,  98,   1,   0],
       [  0,   0,   1,   0,   0,   0,   0,   0,  83,   2],
       [  0,   0,   1,   0,   7,   0,   0,   0,   1,  92]])