<a href="https://colab.research.google.com/github/ninikvn/hackathon-project/blob/main/pytorch_template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import numpy as np
from tqdm import tqdm

In [None]:

# PyTorch has its own methods for loading and batching data, which you can see below
""" Ordinarily the commented-out code below is all you need to load the MNIST data.
      For some reason, yann.lecun.com is timing out, so we instead make our own
      dataset to save time.
train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True,
                              transform=transforms.ToTensor())
"""

class MNISTDataset(Dataset):
  """
    In-memory dataset that pairs each image with its label.

    Every PyTorch Dataset needs an __init__, __len__, and __getitem__.
    These methods are used to fetch and batch the data with a DataLoader later.
  """
  def __init__(self, images, labels):
    """
    Wrap the given images and labels with torch.Tensor.
    :param images: array-like of images, indexed along the first axis
    :param labels: array-like of labels, one entry per image
    :raises ValueError: if images and labels do not have the same length
    """
    self.images = torch.Tensor(images)
    self.labels = torch.Tensor(labels)
    # Fail fast on misaligned inputs; otherwise the mismatch would only surface
    # later as an IndexError in __getitem__ (or extra labels silently ignored).
    if len(self.images) != len(self.labels):
      raise ValueError(
          f"images and labels must have the same length, got "
          f"{len(self.images)} and {len(self.labels)}")

  def __len__(self):
    """:return: the number of (image, label) pairs"""
    return len(self.images)

  def __getitem__(self, idx):
    """:return: the (image, label) pair at position idx"""
    return self.images[idx], self.labels[idx]


# NOTE(review): train_inputs/train_labels/test_inputs/test_labels are not
# defined in the cells shown here; presumably an earlier cell creates them.

# DataLoaders are an easy way to batch datasets; the training split is also
# reshuffled every epoch, while the test split keeps a fixed order.
train_dataset = MNISTDataset(train_inputs, train_labels)
test_dataset = MNISTDataset(test_inputs, test_labels)

train_loader = DataLoader(dataset=train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False)

Now that we have our data, let's create and train a PyTorch model! It looks very similar to TensorFlow, with some small differences in naming conventions, etc.

In [None]:
class Model(torch.nn.Module):

  def __init__(self, **kwargs):
    """
    A fully connected classifier with layer sizes 784 -> 256 -> 128 -> 10.

    The model class inherits from torch.nn.Module (PyTorch's counterpart of
    tf.keras.Model). Layers assigned as attributes are registered as
    submodules, so their weights show up in model.parameters().
    """
    super(Model, self).__init__(**kwargs)

    # Initialize our torch.nn.Linear layers; again we use 256, 128, 10
    self.layer1 = torch.nn.Linear(784, 256)
    self.layer2 = torch.nn.Linear(256, 128)
    self.layer3 = torch.nn.Linear(128, 10)

    # PyTorch Linear layers don't let you attach an activation function
    #   like TF does, so we need to create these explicitly
    self.relu = torch.nn.ReLU()
    self.softmax = torch.nn.Softmax(dim=1)

  def forward(self, inputs):
    """
    Forward pass: predicts labels for a batch of images using fully connected layers.
    :param inputs: flattened images; expected shape (batch, 784) — the training
      loop reshapes each batch to this form
    :return: the probabilities of each label, shape (batch, 10); each row sums to 1
    """
    out1 = self.relu(self.layer1(inputs))
    out2 = self.relu(self.layer2(out1))
    logits = self.layer3(out2)
    return self.softmax(logits)

  def loss(self, predictions, labels):
    """
    Calculates the mean categorical cross-entropy over the batch.
    :param predictions: probabilities as returned by forward, shape (batch, classes)
    :param labels: per-class targets of the same shape (presumably one-hot)
    :return: the loss of the model as a scalar tensor
    """
    # Clamp away exact zeros so log() never yields -inf.
    nll_comps = -labels * torch.log(torch.clamp(predictions, 1e-10, 1.0))
    return torch.mean(torch.sum(nll_comps, dim=1))

  def accuracy(self, predictions, labels):
    """
    Calculates the fraction of samples whose most probable class matches the label.
    :param predictions: per-class scores or probabilities, shape (batch, classes)
    :param labels: per-class targets of the same shape; argmax gives the true class
    :return: the accuracy of the model as a scalar float32 tensor
    """
    pred_classes = torch.argmax(predictions, dim=1)
    true_classes = torch.argmax(labels, dim=1)
    # The comparison yields a bool tensor; cast to float32 so it can be averaged.
    correct_prediction = torch.eq(pred_classes, true_classes)
    return torch.mean(correct_prediction.to(torch.float32))

## END TODO
################################################################################

# Instantiate our model
model = Model()

# Create our optimizer; notice that the model's parameters are passed into its constructor.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Loop through training steps
epochs = 10

for epoch in range(epochs):
  model.train()
  for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):
    # There isn't a "GradientTape" context manager for torch.
    #   Instead, torch Tensors have a backward method which backpropagates
    #   automatically. We will talk a little about some of these differences later.
    optimizer.zero_grad()  # reset gradients left over from the previous step
    inputs = torch.reshape(inputs, (len(inputs), -1))  # flatten each image to a vector
    y_pred = model(inputs)  # calling the module runs its forward method
    loss = model.loss(y_pred, labels)  # compute the loss
    loss.backward()  # compute and store the gradients via backprop
    optimizer.step()  # update the parameters

  # Evaluate on the test set. Each batch's accuracy is weighted by its size so
  # the smaller final batch doesn't skew the average, and autograd is disabled
  # because no gradients are needed during evaluation.
  model.eval()
  num_correct = 0.0
  num_seen = 0
  with torch.no_grad():
    for inputs, labels in test_loader:
      inputs = torch.reshape(inputs, (len(inputs), -1))
      batch_acc = model.accuracy(model(inputs), labels)
      num_correct += batch_acc.item() * len(inputs)
      num_seen += len(inputs)
  test_acc = num_correct / num_seen if num_seen else float("nan")
  print(f"Accuracy on testing set after epoch {epoch}: {test_acc}")
print()
print(model)

# Note: this notebook uses the Adam optimizer, whereas the TensorFlow version
# used a more basic optimizer.

59it [00:01, 40.03it/s]


Accuracy on testing set after epoch 0: 0.9364855885505676


59it [00:01, 53.38it/s]


Accuracy on testing set after epoch 1: 0.9549247026443481


59it [00:01, 50.60it/s]


Accuracy on testing set after epoch 2: 0.9593949317932129


59it [00:01, 52.18it/s]


Accuracy on testing set after epoch 3: 0.9654256701469421


59it [00:01, 43.21it/s]


Accuracy on testing set after epoch 4: 0.9686263799667358


59it [00:01, 52.84it/s]


Accuracy on testing set after epoch 5: 0.9700155258178711


59it [00:01, 51.78it/s]


Accuracy on testing set after epoch 6: 0.9710897207260132


59it [00:01, 49.31it/s]


Accuracy on testing set after epoch 7: 0.9715701341629028


59it [00:02, 27.35it/s]


Accuracy on testing set after epoch 8: 0.9713448286056519


59it [00:01, 39.87it/s]


Accuracy on testing set after epoch 9: 0.9723971486091614

Model(
  (layer1): Linear(in_features=784, out_features=256, bias=True)
  (layer2): Linear(in_features=256, out_features=128, bias=True)
  (layer3): Linear(in_features=128, out_features=10, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=1)
)
