### **Training A Feed Forward Neural Network On MNIST**

In [None]:
import torch 
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
# Number of samples per mini-batch handed to the DataLoader.
BATCH_SIZE = 128
# Number of full passes over the training data.
EPOCHS = 5

In [None]:
def download_mnist_dataset():
  """Return the MNIST training and validation datasets as a ``(train, valid)`` tuple.

  Both splits are stored under the local ``data`` directory, are requested with
  ``download=True``, and use ``ToTensor()`` as their transform.
  """
  def _load_split(is_train):
    # One call per split; only the ``train`` flag differs between them.
    return datasets.MNIST(root='data', train=is_train, download=True, transform=ToTensor())

  train_split = _load_split(True)
  valid_split = _load_split(False)
  return train_split, valid_split

In [None]:
class FeedForwardNet(nn.Module):
  """Fully connected classifier: flatten -> Linear(784, 256) -> ReLU -> Linear(256, 10).

  ``forward`` returns raw, unnormalised logits. ``nn.CrossEntropyLoss`` applies
  log-softmax itself, so feeding it values that were already passed through
  softmax squashes the gradients and keeps the loss from dropping below about
  1.46 for 10 classes. Use ``predict_proba`` when class probabilities are needed;
  ``argmax`` over logits and over probabilities selects the same class.
  """

  def __init__(self):
    super().__init__()
    self.flatten = nn.Flatten()
    self.dense_layers = nn.Sequential(
        nn.Linear(28*28, 256),
        nn.ReLU(),
        nn.Linear(256,10)
    )
    # Parameter-free, so keeping it does not change the state_dict layout.
    # It is only applied in predict_proba, never in forward.
    self.softmax = nn.Softmax(dim=1)

  def forward(self, input_data):
    """Return the logits for ``input_data``.

    Args:
      input_data: tensor whose dimensions after the first flatten to 784 values.

    Returns:
      Tensor of shape ``(input_data.shape[0], 10)`` holding unnormalised scores.
    """
    flattened_data = self.flatten(input_data)
    logits = self.dense_layers(flattened_data)

    return logits

  def predict_proba(self, input_data):
    """Return per-class probabilities (each row sums to 1) for ``input_data``."""
    return self.softmax(self(input_data))

In [None]:
def train_one_epoch(model, data_loader, loss_fn, optimiser, device):
  """Run one optimisation pass over ``data_loader`` and print the final batch loss.

  Args:
    model: module to optimise; it is switched to training mode first.
    data_loader: iterable yielding ``(inputs, targets)`` pairs.
    loss_fn: callable taking ``(predictions, targets)`` and returning a loss tensor.
    optimiser: optimiser holding the model parameters.
    device: device the inputs and targets are moved to before the forward pass.

  Returns:
    The loss of the last batch as a float, or ``None`` when ``data_loader``
    yielded no batches.
  """
  # An earlier inference call may have left the model in eval mode.
  model.train()

  loss = None
  for inputs, targets in data_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    predictions = model(inputs)
    loss = loss_fn(predictions, targets)

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

  if loss is None:
    # Without this guard an empty loader fails on an unbound ``loss``.
    print("Loss:n/a (data loader yielded no batches)")
    return None

  # Read the value once, after the loop, to avoid a device sync on every batch.
  last_loss = loss.item()
  print(f"Loss:{last_loss}")
  return last_loss

def train(model, data_loader, loss_fn, optimiser, device, epochs):
  """Train ``model`` for ``epochs`` consecutive passes over ``data_loader``.

  A 1-based epoch header is printed before each pass, the pass itself is
  delegated to ``train_one_epoch``, and a completion message is printed once
  all passes have finished.
  """
  for epoch_number in range(1, epochs + 1):
    print(f"Epoch:{epoch_number}")
    train_one_epoch(
        model=model,
        data_loader=data_loader,
        loss_fn=loss_fn,
        optimiser=optimiser,
        device=device,
    )

  print("Training Complete")

In [None]:
if __name__ == "__main__":
  train_data, _ = download_mnist_dataset()
  print("Dataset Downloaded")

  # Reshuffle the training samples every epoch so batches differ between passes.
  train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

  # Train on the GPU when one is available, otherwise fall back to the CPU.
  device = "cuda" if torch.cuda.is_available() else "cpu"
  print(device)
  feed_forward_net = FeedForwardNet().to(device)

  loss_fn = nn.CrossEntropyLoss()
  optimiser = torch.optim.Adam(feed_forward_net.parameters(), lr=0.0001)

  train(feed_forward_net, train_data_loader, loss_fn, optimiser, device, EPOCHS)

  # Only the weights are stored; the class definition is needed to reload them.
  torch.save(feed_forward_net.state_dict(), "feedforwardnet.pth")
  print("Model trained and stored at feedforwardnet.pth")

Dataset Downloaded
cuda
Epoch:1
Loss:1.7094453573226929
Epoch:2
Loss:1.601638913154602
Epoch:3
Loss:1.5671577453613281
Epoch:4
Loss:1.549139380455017
Epoch:5
Loss:1.5370005369186401
Training Complete
Model trained and stored at feedforwardnet.pth


### **Inference**

In [None]:
# Maps a class index (0-9) to its digit label as a string.
class_mapping = [str(digit) for digit in range(10)]


In [None]:
def predict(model, input, target, class_mapping):
    """Classify ``input`` with ``model`` and return ``(predicted, expected)`` labels.

    The model is put into eval mode and run without gradient tracking. The
    arg-max over the first row of its output selects the predicted entry of
    ``class_mapping``; ``target`` selects the expected entry. The raw output and
    the chosen index are printed along the way.
    """
    # eval() switches layers such as dropout and batch-norm to inference
    # behaviour; model.train() is the opposite switch.
    model.eval()
    with torch.no_grad():
        predictions = model(input)

    print(predictions)
    predicted_index = torch.argmax(predictions[0], dim=0)
    print(predicted_index)

    return class_mapping[predicted_index], class_mapping[target]

In [None]:
if __name__ == "__main__":
    feed_forward_net = FeedForwardNet()
    # The checkpoint may have been saved from a CUDA run. Mapping it to the CPU
    # lets it load on machines without a GPU and matches the CPU model above.
    state_dict = torch.load("feedforwardnet.pth", map_location="cpu")
    feed_forward_net.load_state_dict(state_dict)

    _, validation_data = download_mnist_dataset()

    # Fetch the first sample once; each indexing call re-runs the transform.
    # NOTE(review): `input` shadows the built-in for the rest of the kernel
    # session; the name is kept in case later cells rely on it.
    input, target = validation_data[0]
    # Show the shape instead of dumping the whole image tensor into the output.
    print(input.shape)

    predicted, expected = predict(feed_forward_net, input, target,
                                  class_mapping)
    print(f"Predicted: '{predicted}', expected: '{expected}'")

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

## 