# Pytorch Tutorial

Pytorch is a popular deep learning framework and it's easy to get started.

In [2]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time

BATCH_SIZE = 128
NUM_EPOCHS = 10

First, we read the mnist data, preprocess them and encapsulate them into dataloader form.

In [3]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

Then, we define the model, object function and optimizer that we use to classify.

In [6]:
import torch.nn.functional as F
class SimpleNet(nn.Module):
# TODO:define model

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 500)
        # 第一个线性层输入维度：28*28-图片长*宽；输出维度500-第一层神经元个数
        self.fc2 = nn.Linear(500, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)
        # 预测层输出为10个神经元，代表10个数字
    def forward(self, x):
        x = x.view(-1, 28*28) #把torch tensor先展开成一行，再按照指定的size进行resize
        #此处就是每一列28*28（一个图片）；-1表示的是行数自动算出，行数也就是图片数目
        x = F.relu(self.fc1(x))
        # 激活
        x = F.relu(self.fc2(x))
        # 激活
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x


# 实例化    
model = SimpleNet()

# TODO:define loss function and optimiter
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
print(model)

SimpleNet(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=10, bias=True)
)


Next, we can start to train and evaluate!

In [7]:
# train and evaluate
for epoch in range(NUM_EPOCHS):
    ave_loss = 0
    for X_train, y_train in tqdm(train_loader):
        # TODO:forward + backward + optimize
       optimizer.zero_grad()                       
       out = model(X_train)                        
       loss = criterion(out, y_train)              
       ave_loss = ave_loss * 0.9 + loss.item() * 0.1
       loss.backward()                             
       optimizer.step()                            
    print("epoch_number:",epoch,"------------loss:",ave_loss)
    correct_cnt= 0
    total_cnt = 0
    for X_test, y_test in tqdm(test_loader):
        out = model(X_test)                     
        loss = criterion(out, y_test)       
        _, pred_label = torch.max(out.data, 1)   
        total_cnt += X_test.data.size()[0]
        correct_cnt += (pred_label == y_test).sum()
    accuracy = float(correct_cnt)/total_cnt
    print("epoch_number:",epoch,"------------test accuracy:",accuracy)
   
        
        
        
        
    # evaluate
    # TODO:calculate the accuracy using traning and testing dataset
    
    
    
    


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:25<00:00, 18.38it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

epoch_number: 0 ------------loss: 0.19567947985589446


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:03<00:00, 24.95it/s]
  0%|▎                                                                                 | 2/468 [00:00<00:27, 16.71it/s]

epoch_number: 0 ------------test accuracy: 0.9507211538461539


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:29<00:00, 15.95it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

epoch_number: 1 ------------loss: 0.11619742233061624


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 28.12it/s]
  0%|▏                                                                                 | 1/468 [00:00<00:53,  8.72it/s]

epoch_number: 1 ------------test accuracy: 0.9614383012820513


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:31<00:00, 14.84it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

epoch_number: 2 ------------loss: 0.09508599777936234


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 26.83it/s]
  0%|▎                                                                                 | 2/468 [00:00<00:38, 12.08it/s]

epoch_number: 2 ------------test accuracy: 0.96484375


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:30<00:00, 15.26it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:03, 24.86it/s]

epoch_number: 3 ------------loss: 0.0938442731821325


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:03<00:00, 25.33it/s]
  0%|▎                                                                                 | 2/468 [00:00<00:27, 16.85it/s]

epoch_number: 3 ------------test accuracy: 0.9692508012820513


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:30<00:00, 15.20it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:03, 24.66it/s]

epoch_number: 4 ------------loss: 0.07889131218627818


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 26.79it/s]
  0%|▎                                                                                 | 2/468 [00:00<00:29, 15.92it/s]

epoch_number: 4 ------------test accuracy: 0.9705528846153846


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:31<00:00, 14.74it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 26.62it/s]

epoch_number: 5 ------------loss: 0.05893462146167242


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 27.29it/s]
  0%|▎                                                                                 | 2/468 [00:00<00:29, 15.54it/s]

epoch_number: 5 ------------test accuracy: 0.9760616987179487


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:32<00:00, 14.57it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:03, 23.87it/s]

epoch_number: 6 ------------loss: 0.05335643544278587


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 29.37it/s]
  0%|▎                                                                                 | 2/468 [00:00<00:28, 16.44it/s]

epoch_number: 6 ------------test accuracy: 0.9777644230769231


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:36<00:00, 12.90it/s]
  3%|██▏                                                                                | 2/78 [00:00<00:04, 17.75it/s]

epoch_number: 7 ------------loss: 0.06268776397146568


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:05<00:00, 13.14it/s]
  0%|▎                                                                                 | 2/468 [00:00<00:42, 10.84it/s]

epoch_number: 7 ------------test accuracy: 0.9764623397435898


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:35<00:00, 13.25it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 26.61it/s]

epoch_number: 8 ------------loss: 0.04603005365076438


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 26.82it/s]
  0%|▎                                                                                 | 2/468 [00:00<00:27, 16.71it/s]

epoch_number: 8 ------------test accuracy: 0.9791666666666666


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:31<00:00, 14.87it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:03, 23.32it/s]

epoch_number: 9 ------------loss: 0.03712699264227906


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:03<00:00, 25.26it/s]


epoch_number: 9 ------------test accuracy: 0.9746594551282052


#### Q5:
Please print the training and testing accuracy.

In [8]:
correct_cnt1= 0
total_cnt1 = 0
for X_train, y_train in tqdm(train_loader):
        out = model(X_train)
        loss = criterion(out, y_train)
        _, pred_label = torch.max(out.data, 1)
        total_cnt1 += X_train.data.size()[0]
        correct_cnt1 += (pred_label == y_train).sum()
train_accuracy = float(correct_cnt1)/total_cnt1
print(train_accuracy)
correct_cnt2= 0
total_cnt2 = 0
for X_test, y_test in tqdm(test_loader):
        out = model(X_train)
        loss = criterion(out, y_train)
        _, pred_label = torch.max(out.data, 1)
        total_cnt2 += X_train.data.size()[0]
        correct_cnt2 += (pred_label == y_train).sum()
test_accuracy = float(correct_cnt2)/total_cnt2
print(test_accuracy)


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [00:16<00:00, 28.69it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 28.38it/s]

0.9849425747863247


100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 27.08it/s]

0.984375



