# Pytorch Tutorial

Pytorch is a popular deep learning framework and it's easy to get started.

In [10]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time
import numpy as np

BATCH_SIZE = 128
NUM_EPOCHS = 10

First, we read the mnist data, preprocess them and encapsulate them into dataloader form.

In [4]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST\raw\train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Extracting ./mnist/MNIST\raw\train-images-idx3-ubyte.gz to ./mnist/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST\raw\train-labels-idx1-ubyte.gz


0it [00:00, ?it/s]

Extracting ./mnist/MNIST\raw\train-labels-idx1-ubyte.gz to ./mnist/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST\raw\t10k-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Extracting ./mnist/MNIST\raw\t10k-images-idx3-ubyte.gz to ./mnist/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz


0it [00:00, ?it/s]

Extracting ./mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz to ./mnist/MNIST\raw
Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!


Then, we define the model, object function and optimizer that we use to classify.

In [8]:
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin_1 = nn.Linear(784, 1024)
        self.lin_2 = nn.Linear(1024, 1024)
        self.lin_3 = nn.Linear(1024, 1024)
        self.lin_4 = nn.Linear(1024, 10)
        self.bn_1 = nn.LayerNorm([1024])
        self.bn_2 = nn.LayerNorm([1024])
        self.bn_3 = nn.LayerNorm([1024])

    def mod(self, x, lin, bn):
        return torch.relu(bn(lin(x)))

    def forward(self, x):
        d = x.flatten(1)
        d = self.mod(d, self.lin_1, self.bn_1)
        d = self.mod(d, self.lin_2, self.bn_2)
        d = self.mod(d, self.lin_3, self.bn_3)
        d = self.lin_4(d)
        return d

    
model = SimpleNet()

# TODO:define loss function and optimiter
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

Next, we can start to train and evaluate!

In [11]:
# train and evaluate
for epoch in range(NUM_EPOCHS):
    accs = []
    for images, labels in tqdm(train_loader, position=0, leave=True):
        # TODO:forward + backward + optimize
        with torch.enable_grad():
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            batch_acc = (logits.argmax(-1) == labels).float().mean().item()
            accs.append(batch_acc)
            optimizer.step()
    # evaluate
    # TODO:calculate the accuracy using traning and testing dataset
    print("Train acc. %.4f." % np.mean(accs))
    accs = []
    for images, labels in tqdm(test_loader, position=0, leave=True):
        with torch.no_grad():
            logits = model(images)
            batch_acc = (logits.argmax(-1) == labels).float().mean().item()
            accs.append(batch_acc)
    print("Test acc. %.4f." % np.mean(accs))


100%|████████████████████████████████████████| 468/468 [00:33<00:00, 13.96it/s]
  4%|█▋                                         | 3/78 [00:00<00:02, 26.78it/s]

Train acc. 0.9589.


100%|██████████████████████████████████████████| 78/78 [00:03<00:00, 25.93it/s]
  0%|▏                                         | 2/468 [00:00<00:39, 11.83it/s]

Test acc. 0.9636.


100%|████████████████████████████████████████| 468/468 [00:35<00:00, 13.05it/s]
  4%|█▋                                         | 3/78 [00:00<00:03, 24.79it/s]

Train acc. 0.9681.


100%|██████████████████████████████████████████| 78/78 [00:02<00:00, 26.58it/s]
  0%|▏                                         | 2/468 [00:00<00:36, 12.74it/s]

Test acc. 0.9690.


100%|████████████████████████████████████████| 468/468 [00:36<00:00, 12.95it/s]
  4%|█▋                                         | 3/78 [00:00<00:03, 25.00it/s]

Train acc. 0.9747.


100%|██████████████████████████████████████████| 78/78 [00:03<00:00, 25.88it/s]
  0%|▏                                         | 2/468 [00:00<00:35, 13.07it/s]

Test acc. 0.9755.


100%|████████████████████████████████████████| 468/468 [00:36<00:00, 12.98it/s]
  4%|█▋                                         | 3/78 [00:00<00:02, 27.52it/s]

Train acc. 0.9778.


100%|██████████████████████████████████████████| 78/78 [00:02<00:00, 26.90it/s]
  0%|▏                                         | 2/468 [00:00<00:35, 13.07it/s]

Test acc. 0.9738.


100%|████████████████████████████████████████| 468/468 [00:36<00:00, 12.84it/s]
  4%|█▋                                         | 3/78 [00:00<00:02, 25.86it/s]

Train acc. 0.9806.


100%|██████████████████████████████████████████| 78/78 [00:03<00:00, 25.84it/s]
  0%|▏                                         | 2/468 [00:00<00:35, 12.99it/s]

Test acc. 0.9754.


100%|████████████████████████████████████████| 468/468 [00:36<00:00, 12.85it/s]
  4%|█▋                                         | 3/78 [00:00<00:03, 23.25it/s]

Train acc. 0.9822.


100%|██████████████████████████████████████████| 78/78 [00:02<00:00, 26.71it/s]
  0%|▏                                         | 2/468 [00:00<00:37, 12.50it/s]

Test acc. 0.9779.


100%|████████████████████████████████████████| 468/468 [00:36<00:00, 12.84it/s]
  4%|█▋                                         | 3/78 [00:00<00:02, 27.78it/s]

Train acc. 0.9845.


100%|██████████████████████████████████████████| 78/78 [00:02<00:00, 26.35it/s]
  0%|▏                                         | 2/468 [00:00<00:37, 12.50it/s]

Test acc. 0.9707.


100%|████████████████████████████████████████| 468/468 [00:36<00:00, 12.91it/s]
  4%|█▋                                         | 3/78 [00:00<00:02, 25.42it/s]

Train acc. 0.9857.


100%|██████████████████████████████████████████| 78/78 [00:02<00:00, 26.68it/s]
  0%|▏                                         | 2/468 [00:00<00:34, 13.60it/s]

Test acc. 0.9778.


100%|████████████████████████████████████████| 468/468 [00:35<00:00, 13.14it/s]
  4%|█▋                                         | 3/78 [00:00<00:02, 28.84it/s]

Train acc. 0.9874.


100%|██████████████████████████████████████████| 78/78 [00:02<00:00, 26.64it/s]
  0%|▏                                         | 2/468 [00:00<00:35, 13.24it/s]

Test acc. 0.9751.


100%|████████████████████████████████████████| 468/468 [00:36<00:00, 12.95it/s]
  4%|█▋                                         | 3/78 [00:00<00:02, 27.03it/s]

Train acc. 0.9873.


100%|██████████████████████████████████████████| 78/78 [00:02<00:00, 26.09it/s]

Test acc. 0.9757.





#### Q5:
Please print the training and testing accuracy.