In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
BATCH_SIZE=128
device = 'cpu'
EPOCHS = 10
LEARNING_RATE = 0.001

def download_mnist_datasets():
    train_data = datasets.MNIST(root='data', download=True, train=True, transform=ToTensor())
    validation_data = datasets.MNIST(root='data', download=True, train=False, transform=ToTensor())
    return train_data, validation_data

In [4]:
class FeedForwardNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_layers = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, input_data):
        flatten_data = self.flatten(input_data)
        logits = self.dense_layers(flatten_data)
        predictions = self.softmax(logits)
        return predictions
    
def train_one_epoch(model, data_loader, loss_func, optimizer, device):
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        
        predictions = model(inputs)
        loss = loss_func(predictions, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('loss: ', loss.item())
    
def train(model, data_loader, loss_func, optimizer, device, epochs):
    for i in range(epochs):
        print('Epochs: ',i+1)
        train_one_epoch(model, data_loader, loss_func, optimizer, device)
    
loss_func = nn.CrossEntropyLoss()
feed_forward_net = FeedForwardNet().to('cpu')
optimizer = torch.optim.Adam(feed_forward_net.parameters(), lr = LEARNING_RATE)

In [5]:
train_data,_ = download_mnist_datasets()

In [6]:
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE)

In [7]:
train(feed_forward_net, train_data_loader, loss_func, optimizer, device, EPOCHS)

Epochs:  1
loss:  1.513006567955017
Epochs:  2
loss:  1.4963860511779785
Epochs:  3
loss:  1.489706039428711
Epochs:  4
loss:  1.481165885925293
Epochs:  5
loss:  1.4793280363082886
Epochs:  6
loss:  1.4768176078796387
Epochs:  7
loss:  1.475023865699768
Epochs:  8
loss:  1.4726568460464478
Epochs:  9
loss:  1.472481369972229
Epochs:  10
loss:  1.4735437631607056


In [10]:
torch.save(feed_forward_net.state_dict(),'feedforwardnet.pth')
