<a href="https://colab.research.google.com/github/littlejacinthe/torchaudio/blob/main/Predictions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Making predictions with models

Tutorial by The Sound of AI on YT

In [12]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [13]:
class FeedForwardNet(nn.Module):

  def __init__(self):
    super().__init__() # nn.Module functions

    self.flatten = nn.Flatten()
    self.dense_layers = nn.Sequential( #trick to get several layers into one component
        nn.Linear(28*28, 256), #images in the dataset are of size 28x28 --> Flattened
        nn.ReLU(), #activation layer
        nn.Linear(256, 10) #output layer
    )
    self.softmax = nn.Softmax(dim=1)

  def forward(self, input_data): #how to manipulate the data
    flattened_data = self.flatten(input_data)
    logits = self.dense_layers(flattened_data)
    predictions = self.softmax(logits)
    return predictions
def download_mnist_datasets():

  train_data = datasets.MNIST(
      root="data",
      download=True,
      train=True,
      transform=ToTensor() # normalized btw 0 and 1
  )

  validation_data = datasets.MNIST(
      root="data",
      download=True,
      train=False,
      transform=ToTensor() # normalized btw 0 and 1
  )
  
  return train_data, validation_data

In [14]:
class_mapping = [
                 "0",
                 "1",
                 "2",
                 "3",
                 "4",
                 "5",
                 "6",
                 "7",
                 "8",
                 "9"
]

In [15]:
def predict(model, input, target, class_mapping):
  model.eval()
  with torch.no_grad():
    predictions = model(input)
    # Tensor object (1, 10)
    predicted_index = predictions[0].argmax(0)
    predicted = class_mapping[predicted_index]
    expected = class_mapping[target]

  return predicted, expected

In [16]:
if __name__ == "__main__":
  # load the model back
  feed_forward_net = FeedForwardNet()
  state_dict = torch.load("/content/feedforwardnet.pth")
  feed_forward_net.state_dict(state_dict)

  #load mnist validation dataset
  _, validation_data = download_mnist_datasets()

  #get a sample from the validation dataset
  input, target = validation_data[0][0], validation_data[0][1]

  #inference
  predicted, expected = predict(feed_forward_net, input, target, class_mapping)

  print(f"Predicted: '{predicted}', expected: '{expected}'")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  "Positional args are being deprecated, use kwargs instead. Refer to "


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Predicted: '0', expected: '7'
