In [1]:
!pip install pytorch-lightning

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/81/d0/84a2f072cd407f93a1e50dff059656bce305f084e63a45cbbceb2fdb67b4/pytorch_lightning-1.1.0-py3-none-any.whl (665kB)
[K     |████████████████████████████████| 675kB 6.6MB/s 
[?25hCollecting fsspec>=0.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/a5/8b/1df260f860f17cb08698170153ef7db672c497c1840dcc8613ce26a8a005/fsspec-0.8.4-py3-none-any.whl (91kB)
[K     |████████████████████████████████| 92kB 8.0MB/s 
Collecting PyYAML>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 23.3MB/s 
Collecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 37.0MB/s 
Buildin

In [2]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

In [3]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()

        # MNIST images are of dimensions(1,28,28)(channel,width,height)
        self.layer_1 = nn.Linear(in_features=28 * 28, out_features=128)
        self.layer_2 = nn.Linear(in_features=128, out_features=256)
        self.layer_3 = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b,1,28,28)->(b,1*128*128)
        x = x.view(batch_size, -1)

        # layer 1
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3
        x = self.layer_3(x)

        # probabaility distribution over labels
        x = torch.log_softmax(x, dim=1)

        return x

In [4]:
# Transforms
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307), (0.3081))]
)

In [5]:
# Training , Validation Data
mnist_train = datasets.MNIST(root=os.getcwd(), train=True, download=True,transform=transform)
mnist_train, mnist_val = random_split(dataset=mnist_train, lengths=[55000, 5000])

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/train-images-idx3-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/train-labels-idx1-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw
Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!


In [6]:
# Test Data
mnist_test = datasets.MNIST(root=os.getcwd(), train=False, download=True,transform=transform)

In [7]:
# DataLoaders
mnist_train = DataLoader(dataset=mnist_train, batch_size=64)
mnist_val = DataLoader(dataset=mnist_val, batch_size=64)
mnist_test = DataLoader(dataset=mnist_test, batch_size=64)

In [8]:
# Optimizer
pytorch_model = MNISTClassifier()
optimizer = torch.optim.Adam(params=pytorch_model.parameters(), lr=1e-3)

In [9]:
# Loss
def cross_entropy_loss(logits, labels):
    return nn.functional.nll_loss(logits, labels)





In [10]:
# Training Loop
num_epochs = 10

for epoch in range(num_epochs):
  print("Epoch: ", epoch)
  
  # Training
  for train_batch in mnist_train:
      x, y = train_batch
      logits = pytorch_model(x)
      loss = cross_entropy_loss(logits, y)
      print("train loss: ", loss.item())

      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

  # Validation
  with torch.no_grad():
      val_loss = []
      for val_batch in mnist_val:
          x, y = val_batch
          logits = pytorch_model(x)
          val_loss.append(cross_entropy_loss(logits, y).item())

      val_loss = torch.mean(torch.tensor(val_loss))
      print("val_loss: ", val_loss.item())

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
train loss:  0.005644440185278654
train loss:  0.041159845888614655
train loss:  0.13122306764125824
train loss:  0.04380566254258156
train loss:  0.07701154798269272
train loss:  0.026706727221608162
train loss:  0.04122988134622574
train loss:  0.029274215921759605
train loss:  0.017999833449721336
train loss:  0.04827303811907768
train loss:  0.006754807662218809
train loss:  0.023744111880660057
train loss:  0.021815655753016472
train loss:  0.04145560786128044
train loss:  0.07162683457136154
train loss:  0.1398409754037857
train loss:  0.04645596072077751
train loss:  0.017741616815328598
train loss:  0.05203850939869881
train loss:  0.13359872996807098
train loss:  0.055770572274923325
train loss:  0.01271052099764347
train loss:  0.031201595440506935
train loss:  0.017982332035899162
train loss:  0.08215659111738205
train loss:  0.07065184414386749
train loss:  0.09248240292072296
train loss:  0.01038680225610733
