In [13]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.optim as optim

import matplotlib.pyplot as plt
from FileReader import load_mnist

In [2]:
train_set = load_mnist('train')
train_imgs = torch.from_numpy(train_set['images']).type(torch.float)
train_labels = torch.from_numpy(train_set['labels']).type(torch.long)

In [3]:
# 基本参数
train_imgs_num = train_imgs.shape[0]
BATCH_SIZE = int(train_imgs_num * 0.05)
use_dropout = False

In [4]:
# 预处理(拉伸，零均值化)
train_vec_imgs = train_imgs.reshape(train_imgs_num, -1)
train_vec_imgs -= torch.mean(train_vec_imgs, dim=0)

In [14]:
class FourLayersNet(nn.Module):
    '''
    三层全连接网络进行mnist分类
    ReLU激励
    '''
    def __init__(self, n_feature, n_output, *vargs):
        super(FourLayersNet, self).__init__()
        if len(vargs) == 0:
            self.layers1 = nn.Linear(n_feature, 300)
            self.layers2 = nn.Linear(300, 100)
            self.layers3 = nn.Linear(100, 10)
        else:
            # todo: xxx
            pass
    
    def forward(self, x):
        # 第一层输出
        h1 = self.layers1(x)
        a1 = F.relu(h1)
        
        #第二层输出
        h2 = self.layers2(a1)
        a2 = F.relu(h2)
        
        #输出层输出
        h3 = self.layers3(a2)
        
        return h3

In [15]:
net = FourLayersNet(n_feature=784, n_output=10)
# print(net)

In [7]:
# 批处理
torch_dataset = Data.TensorDataset(train_vec_imgs,train_labels)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)

# for step, (batch_imgs, batch_labels) in enumerate(loader):
#     print 'step {}: '.format(step+1)
#     print batch_imgs.shape
#     print batch_labels

In [8]:
optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)
loss_function = nn.CrossEntropyLoss()

for epoch in xrange(58):
    for step, (batch_imgs, batch_labels) in enumerate(loader):
        y_pred = net(batch_imgs)
        loss = loss_function(y_pred, batch_labels)
        
        #更新参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print 'epoch {} | batch {} loss: {}'.format(epoch+1, step+1, loss.data.item())

#保存模型(参数)
torch.save(net.state_dict(), 'mnist_3LayersNet_model.pkl')

epoch 1 | batch 1 loss: 8.31102561951
epoch 1 | batch 2 loss: 4.64883327484
epoch 1 | batch 3 loss: 2.42849063873
epoch 1 | batch 4 loss: 1.58929121494
epoch 1 | batch 5 loss: 1.3609610796
epoch 1 | batch 6 loss: 1.20207250118
epoch 1 | batch 7 loss: 0.996592760086
epoch 1 | batch 8 loss: 0.972966253757
epoch 1 | batch 9 loss: 0.933719754219
epoch 1 | batch 10 loss: 0.70191681385
epoch 1 | batch 11 loss: 0.556003153324
epoch 1 | batch 12 loss: 0.658551454544
epoch 1 | batch 13 loss: 0.620850801468
epoch 1 | batch 14 loss: 0.55620354414
epoch 1 | batch 15 loss: 0.592419803143
epoch 1 | batch 16 loss: 0.467078238726
epoch 1 | batch 17 loss: 0.379297941923
epoch 1 | batch 18 loss: 0.444414645433
epoch 1 | batch 19 loss: 0.424038946629
epoch 1 | batch 20 loss: 0.428944587708
epoch 2 | batch 1 loss: 0.390865713358
epoch 2 | batch 2 loss: 0.360856890678
epoch 2 | batch 3 loss: 0.335651516914
epoch 2 | batch 4 loss: 0.33077287674
epoch 2 | batch 5 loss: 0.323341012001
epoch 2 | batch 6 loss: 

epoch 11 | batch 12 loss: 0.0492650717497
epoch 11 | batch 13 loss: 0.0491354167461
epoch 11 | batch 14 loss: 0.0548430643976
epoch 11 | batch 15 loss: 0.0615501664579
epoch 11 | batch 16 loss: 0.0590940192342
epoch 11 | batch 17 loss: 0.0568147413433
epoch 11 | batch 18 loss: 0.058586679399
epoch 11 | batch 19 loss: 0.0589291527867
epoch 11 | batch 20 loss: 0.0629076585174
epoch 12 | batch 1 loss: 0.0582566410303
epoch 12 | batch 2 loss: 0.0465785861015
epoch 12 | batch 3 loss: 0.0465272478759
epoch 12 | batch 4 loss: 0.0503288693726
epoch 12 | batch 5 loss: 0.0483330674469
epoch 12 | batch 6 loss: 0.0372797101736
epoch 12 | batch 7 loss: 0.0537297055125
epoch 12 | batch 8 loss: 0.0420431680977
epoch 12 | batch 9 loss: 0.0550136119127
epoch 12 | batch 10 loss: 0.052721478045
epoch 12 | batch 11 loss: 0.0450317822397
epoch 12 | batch 12 loss: 0.0563045367599
epoch 12 | batch 13 loss: 0.0575491636992
epoch 12 | batch 14 loss: 0.0465667098761
epoch 12 | batch 15 loss: 0.0498315617442
epo

epoch 21 | batch 13 loss: 0.0165460724384
epoch 21 | batch 14 loss: 0.0186941102147
epoch 21 | batch 15 loss: 0.0166941490024
epoch 21 | batch 16 loss: 0.0200935825706
epoch 21 | batch 17 loss: 0.0181957092136
epoch 21 | batch 18 loss: 0.0199995785952
epoch 21 | batch 19 loss: 0.022914448753
epoch 21 | batch 20 loss: 0.018631035462
epoch 22 | batch 1 loss: 0.0182910691947
epoch 22 | batch 2 loss: 0.0171170514077
epoch 22 | batch 3 loss: 0.0147491274402
epoch 22 | batch 4 loss: 0.0186500009149
epoch 22 | batch 5 loss: 0.01369744353
epoch 22 | batch 6 loss: 0.0169586855918
epoch 22 | batch 7 loss: 0.0187426321208
epoch 22 | batch 8 loss: 0.0191771667451
epoch 22 | batch 9 loss: 0.024947674945
epoch 22 | batch 10 loss: 0.0139393154532
epoch 22 | batch 11 loss: 0.020777579397
epoch 22 | batch 12 loss: 0.0171928983182
epoch 22 | batch 13 loss: 0.0153585327789
epoch 22 | batch 14 loss: 0.0208848807961
epoch 22 | batch 15 loss: 0.0129057364538
epoch 22 | batch 16 loss: 0.0240434166044
epoch 2

epoch 31 | batch 13 loss: 0.00913082342595
epoch 31 | batch 14 loss: 0.00903428439051
epoch 31 | batch 15 loss: 0.00723793776706
epoch 31 | batch 16 loss: 0.0141198728234
epoch 31 | batch 17 loss: 0.00625444529578
epoch 31 | batch 18 loss: 0.00760471588001
epoch 31 | batch 19 loss: 0.00696515012532
epoch 31 | batch 20 loss: 0.008530725725
epoch 32 | batch 1 loss: 0.00762152438983
epoch 32 | batch 2 loss: 0.00703844381496
epoch 32 | batch 3 loss: 0.00926111824811
epoch 32 | batch 4 loss: 0.00933346617967
epoch 32 | batch 5 loss: 0.0083819963038
epoch 32 | batch 6 loss: 0.00648722983897
epoch 32 | batch 7 loss: 0.00918870512396
epoch 32 | batch 8 loss: 0.00685628177598
epoch 32 | batch 9 loss: 0.00764460628852
epoch 32 | batch 10 loss: 0.00599967455491
epoch 32 | batch 11 loss: 0.00771780405194
epoch 32 | batch 12 loss: 0.00633315229788
epoch 32 | batch 13 loss: 0.0064616156742
epoch 32 | batch 14 loss: 0.00646134559065
epoch 32 | batch 15 loss: 0.00708282040432
epoch 32 | batch 16 loss:

epoch 41 | batch 12 loss: 0.00463056610897
epoch 41 | batch 13 loss: 0.00423014769331
epoch 41 | batch 14 loss: 0.00418936274946
epoch 41 | batch 15 loss: 0.00408421177417
epoch 41 | batch 16 loss: 0.00477712741122
epoch 41 | batch 17 loss: 0.00457600643858
epoch 41 | batch 18 loss: 0.00389662967063
epoch 41 | batch 19 loss: 0.00482527632266
epoch 41 | batch 20 loss: 0.00364958262071
epoch 42 | batch 1 loss: 0.00297097628936
epoch 42 | batch 2 loss: 0.00358231994323
epoch 42 | batch 3 loss: 0.00370639259927
epoch 42 | batch 4 loss: 0.00345232826658
epoch 42 | batch 5 loss: 0.00416194833815
epoch 42 | batch 6 loss: 0.00393915409222
epoch 42 | batch 7 loss: 0.00736629823223
epoch 42 | batch 8 loss: 0.00407224008814
epoch 42 | batch 9 loss: 0.00437977164984
epoch 42 | batch 10 loss: 0.00359717430547
epoch 42 | batch 11 loss: 0.00356913497671
epoch 42 | batch 12 loss: 0.00483965687454
epoch 42 | batch 13 loss: 0.00429230742157
epoch 42 | batch 14 loss: 0.00459922943264
epoch 42 | batch 15 

epoch 51 | batch 6 loss: 0.0026480879169
epoch 51 | batch 7 loss: 0.00320797879249
epoch 51 | batch 8 loss: 0.00278991786763
epoch 51 | batch 9 loss: 0.00254595838487
epoch 51 | batch 10 loss: 0.00330721447244
epoch 51 | batch 11 loss: 0.00226444331929
epoch 51 | batch 12 loss: 0.00233503151685
epoch 51 | batch 13 loss: 0.00268502882682
epoch 51 | batch 14 loss: 0.00394379161298
epoch 51 | batch 15 loss: 0.00257721636444
epoch 51 | batch 16 loss: 0.00319420476444
epoch 51 | batch 17 loss: 0.00265809590928
epoch 51 | batch 18 loss: 0.00289847701788
epoch 51 | batch 19 loss: 0.00229653483257
epoch 51 | batch 20 loss: 0.00211356836371
epoch 52 | batch 1 loss: 0.00323544652201
epoch 52 | batch 2 loss: 0.00290271593258
epoch 52 | batch 3 loss: 0.00244389404543
epoch 52 | batch 4 loss: 0.00248353788629
epoch 52 | batch 5 loss: 0.00220821611583
epoch 52 | batch 6 loss: 0.00360512267798
epoch 52 | batch 7 loss: 0.00233206804842
epoch 52 | batch 8 loss: 0.00236987438984
epoch 52 | batch 9 loss:

In [24]:
#加载训练好的模型
net.load_state_dict(torch.load('mnist_3LayersNet_model.pkl'))

#test

test_set = load_mnist('t10k')
origin = test_set['images']
test_imgs = torch.from_numpy(test_set['images']).type(torch.float)
test_labels = torch.from_numpy(test_set['labels']).type(torch.long)

test_vec_imgs = test_imgs.reshape(test_imgs.shape[0], -1)
test_vec_imgs -= torch.mean(test_vec_imgs, dim=0)

top_k = 1
acc = 0
y_pred = net(test_vec_imgs)
ss, indices = torch.sort(y_pred, dim=1, descending=True)
results = indices[:,0:top_k]
for idx, label in enumerate(test_labels):
    if label in results[idx]:
        acc += 1
print 'accuracy: {}'.format(float(acc)/test_labels.shape[0])
# for idx in xrange(10000):
#     plt.subplot(2,1,1)
#     plt.imshow(origin[idx], cmap='Greys')
#     plt.subplot(2,1,2)
#     plt.barh(range(y_pred.shape[1]), y_pred[idx].data.numpy(), color='blue')
#     plt.yticks(range(y_pred.shape[1]), range(y_pred.shape[1]))
#     plt.show()

accuracy: 0.9738
