In [1]:
import torch

from torch import nn, optim

import torch.nn.functional as F

from torch.autograd import Variable

from torch.utils.data import DataLoader

from torchvision import transforms

from torchvision import datasets

import time

In [2]:
# Seed torch's global RNG first so that weight initialisation and DataLoader
# shuffling (both drawn from torch's RNG) repeat exactly on Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Hyperparameters
batch_size = 32
learning_rate = 1e-3
num_epoches = 100  # number of training epochs (name kept: later cells use it)

In [5]:
# Load the MNIST handwritten-digit training and test splits as tensors.
# download=True lets a fresh checkout (empty ./data) run this notebook top to
# bottom; the original download=False raised when the files were missing.
data_root = './data'
to_tensor = transforms.ToTensor()  # stateless, safe to share between the two splits

train_dataset = datasets.MNIST(
    root=data_root, train=True, transform=to_tensor, download=True)

test_dataset = datasets.MNIST(
    root=data_root, train=False, transform=to_tensor, download=True)

In [17]:
# Display the training-set summary (number of datapoints, root location, split).
train_dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train

In [18]:
# Mini-batch iterators: reshuffle the training set every epoch; keep the test
# set in a fixed order so evaluation is comparable across epochs.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

In [19]:
# Logistic regression model definition
class Logstic_Regression(nn.Module):
    """Multinomial logistic regression: one linear layer producing class logits.

    No softmax is applied in ``forward``; this notebook pairs the model with
    ``nn.CrossEntropyLoss``, which consumes the raw outputs directly.

    Args:
        in_dim: number of input features per sample.
        n_class: number of output classes.
    """

    def __init__(self, in_dim, n_class):
        super().__init__()
        # Attribute name deliberately kept as ``logstic``: it determines the
        # state_dict keys ('logstic.weight', 'logstic.bias') in the saved checkpoint.
        self.logstic = nn.Linear(in_dim, n_class)

    def forward(self, x):
        """Map ``x`` (last dimension ``in_dim``) to logits (last dimension ``n_class``)."""
        return self.logstic(x)

In [20]:
# GPU acceleration is used when CUDA is available.
use_gpu = torch.cuda.is_available()

# Each 28x28 image is flattened to 784 features; 10 output classes (digits 0-9).
model = Logstic_Regression(28 * 28, 10)
model = model.cuda() if use_gpu else model

# Loss and optimizer: cross-entropy on the raw model outputs, plain SGD.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

In [25]:
# Training: run `num_epoches` epochs, then evaluate on the test set after each one.
# Batches are moved to whatever device the model's parameters live on, so this
# cell works whether or not the setup cell moved the model to the GPU.
# (Replaces the deprecated torch.autograd.Variable wrapper, a no-op since 0.4.)
device = next(model.parameters()).device
log_interval = 300  # print running stats every 300 mini-batches

for epoch in range(num_epoches):
    print('*' * 10)
    print('epoch {}'.format(epoch + 1))
    since = time.time()

    # ---- training pass ----
    # Re-enter training mode each epoch: the evaluation pass below calls model.eval().
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    num_seen = 0  # samples processed so far; exact even when a batch is short
    for i, (img, label) in enumerate(train_loader, 1):
        # Flatten each 28x28 image into a 784-element vector.
        img = img.view(img.size(0), -1).to(device)
        label = label.to(device)

        # Forward pass
        out = model(img)
        loss = criterion(out, label)
        # Weight the batch-mean loss by batch size so dividing by num_seen gives a per-sample mean.
        running_loss += loss.item() * label.size(0)
        _, pred = torch.max(out, 1)
        running_acc += (pred == label).sum().item()
        num_seen += label.size(0)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % log_interval == 0:
            print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, num_epoches, running_loss / num_seen,
                running_acc / num_seen))
    print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1, running_loss / len(train_dataset),
        running_acc / len(train_dataset)))

    # ---- evaluation pass ----
    model.eval()
    eval_loss = 0.
    eval_acc = 0.
    # no_grad replaces the removed `volatile=True`: no autograd graph is built here.
    with torch.no_grad():
        for img, label in test_loader:
            img = img.view(img.size(0), -1).to(device)
            label = label.to(device)
            out = model(img)
            loss = criterion(out, label)
            eval_loss += loss.item() * label.size(0)
            _, pred = torch.max(out, 1)
            eval_acc += (pred == label).sum().item()
    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
        eval_loss / len(test_dataset), eval_acc / len(test_dataset)))
    print('Time:{:.1f} s'.format(time.time() - since))
    print()

**********
epoch 1
[1/100] Loss: 1.200962, Acc: 0.798854
[1/100] Loss: 1.162999, Acc: 0.803594
[1/100] Loss: 1.126618, Acc: 0.808854
[1/100] Loss: 1.098802, Acc: 0.811667
[1/100] Loss: 1.070919, Acc: 0.814500
[1/100] Loss: 1.048679, Acc: 0.816111
Finish 1 epoch, Loss: 1.043446, Acc: 0.816683




Test Loss: 0.880771, Acc: 0.841400
Time:16.8 s

**********
epoch 2
[2/100] Loss: 0.889827, Acc: 0.833125
[2/100] Loss: 0.877464, Acc: 0.833281
[2/100] Loss: 0.867640, Acc: 0.832431
[2/100] Loss: 0.850078, Acc: 0.836224
[2/100] Loss: 0.841344, Acc: 0.836542
[2/100] Loss: 0.829250, Acc: 0.837917
Finish 2 epoch, Loss: 0.826896, Acc: 0.837667
Test Loss: 0.735379, Acc: 0.854800
Time:16.8 s

**********
epoch 3
[3/100] Loss: 0.764338, Acc: 0.844375
[3/100] Loss: 0.748265, Acc: 0.845729
[3/100] Loss: 0.739678, Acc: 0.845972
[3/100] Loss: 0.733746, Acc: 0.846406
[3/100] Loss: 0.727405, Acc: 0.847083
[3/100] Loss: 0.719760, Acc: 0.848125
Finish 3 epoch, Loss: 0.717587, Acc: 0.848750
Test Loss: 0.652738, Acc: 0.862900
Time:16.6 s

**********
epoch 4
[4/100] Loss: 0.669465, Acc: 0.853958
[4/100] Loss: 0.665729, Acc: 0.856510
[4/100] Loss: 0.661176, Acc: 0.856736
[4/100] Loss: 0.657324, Acc: 0.855885
[4/100] Loss: 0.654610, Acc: 0.855583
[4/100] Loss: 0.650837, Acc: 0.856476
Finish 4 epoch, Loss: 0

[25/100] Loss: 0.395697, Acc: 0.893792
[25/100] Loss: 0.396207, Acc: 0.893646
Finish 25 epoch, Loss: 0.397179, Acc: 0.893633
Test Loss: 0.374961, Acc: 0.900000
Time:18.7 s

**********
epoch 26
[26/100] Loss: 0.394464, Acc: 0.892604
[26/100] Loss: 0.388958, Acc: 0.896823
[26/100] Loss: 0.396175, Acc: 0.894757
[26/100] Loss: 0.393523, Acc: 0.894375
[26/100] Loss: 0.392632, Acc: 0.895021
[26/100] Loss: 0.392705, Acc: 0.894618
Finish 26 epoch, Loss: 0.394062, Acc: 0.894333
Test Loss: 0.372177, Acc: 0.901000
Time:17.7 s

**********
epoch 27
[27/100] Loss: 0.402270, Acc: 0.893229
[27/100] Loss: 0.396686, Acc: 0.893385
[27/100] Loss: 0.392314, Acc: 0.895174
[27/100] Loss: 0.392686, Acc: 0.894089
[27/100] Loss: 0.391384, Acc: 0.894500
[27/100] Loss: 0.390557, Acc: 0.894722
Finish 27 epoch, Loss: 0.391121, Acc: 0.894817
Test Loss: 0.369522, Acc: 0.901100
Time:16.5 s

**********
epoch 28
[28/100] Loss: 0.398959, Acc: 0.892292
[28/100] Loss: 0.397799, Acc: 0.891875
[28/100] Loss: 0.395182, Acc: 0

Time:17.4 s

**********
epoch 49
[49/100] Loss: 0.357041, Acc: 0.902396
[49/100] Loss: 0.357104, Acc: 0.902188
[49/100] Loss: 0.357457, Acc: 0.901285
[49/100] Loss: 0.355793, Acc: 0.901432
[49/100] Loss: 0.354926, Acc: 0.902271
[49/100] Loss: 0.351357, Acc: 0.903177
Finish 49 epoch, Loss: 0.351451, Acc: 0.903033
Test Loss: 0.333915, Acc: 0.909100
Time:16.8 s

**********
epoch 50
[50/100] Loss: 0.342911, Acc: 0.906875
[50/100] Loss: 0.345451, Acc: 0.904271
[50/100] Loss: 0.346953, Acc: 0.904444
[50/100] Loss: 0.344793, Acc: 0.905156
[50/100] Loss: 0.349179, Acc: 0.903563
[50/100] Loss: 0.349879, Acc: 0.903403
Finish 50 epoch, Loss: 0.350307, Acc: 0.903417
Test Loss: 0.332861, Acc: 0.908800
Time:16.8 s

**********
epoch 51
[51/100] Loss: 0.340070, Acc: 0.905625
[51/100] Loss: 0.348517, Acc: 0.904479
[51/100] Loss: 0.349517, Acc: 0.904757
[51/100] Loss: 0.348001, Acc: 0.905234
[51/100] Loss: 0.346901, Acc: 0.905333
[51/100] Loss: 0.347878, Acc: 0.904323
Finish 51 epoch, Loss: 0.349191, Ac

[72/100] Loss: 0.330253, Acc: 0.908620
[72/100] Loss: 0.331255, Acc: 0.908146
[72/100] Loss: 0.331242, Acc: 0.908125
Finish 72 epoch, Loss: 0.331333, Acc: 0.908167
Test Loss: 0.316535, Acc: 0.913500
Time:17.7 s

**********
epoch 73
[73/100] Loss: 0.337311, Acc: 0.904583
[73/100] Loss: 0.337458, Acc: 0.906146
[73/100] Loss: 0.333586, Acc: 0.907188
[73/100] Loss: 0.332114, Acc: 0.907839
[73/100] Loss: 0.330921, Acc: 0.908333
[73/100] Loss: 0.329978, Acc: 0.908628
Finish 73 epoch, Loss: 0.330704, Acc: 0.908233
Test Loss: 0.315931, Acc: 0.913700
Time:17.0 s

**********
epoch 74
[74/100] Loss: 0.332301, Acc: 0.906771
[74/100] Loss: 0.329401, Acc: 0.908125
[74/100] Loss: 0.326632, Acc: 0.909479
[74/100] Loss: 0.328873, Acc: 0.908698
[74/100] Loss: 0.328092, Acc: 0.908917
[74/100] Loss: 0.329012, Acc: 0.908663
Finish 74 epoch, Loss: 0.330051, Acc: 0.908300
Test Loss: 0.315346, Acc: 0.913600
Time:17.2 s

**********
epoch 75
[75/100] Loss: 0.333188, Acc: 0.909583
[75/100] Loss: 0.338020, Acc: 0

Test Loss: 0.306254, Acc: 0.915800
Time:17.8 s

**********
epoch 96
[96/100] Loss: 0.317030, Acc: 0.912917
[96/100] Loss: 0.319376, Acc: 0.911667
[96/100] Loss: 0.320567, Acc: 0.910729
[96/100] Loss: 0.316608, Acc: 0.911328
[96/100] Loss: 0.318847, Acc: 0.911021
[96/100] Loss: 0.318134, Acc: 0.911563
Finish 96 epoch, Loss: 0.318517, Acc: 0.911517
Test Loss: 0.305904, Acc: 0.916100
Time:17.2 s

**********
epoch 97
[97/100] Loss: 0.312580, Acc: 0.914062
[97/100] Loss: 0.315009, Acc: 0.911302
[97/100] Loss: 0.314533, Acc: 0.911389
[97/100] Loss: 0.313742, Acc: 0.912161
[97/100] Loss: 0.316391, Acc: 0.911625
[97/100] Loss: 0.318328, Acc: 0.911285
Finish 97 epoch, Loss: 0.318088, Acc: 0.911600
Test Loss: 0.305573, Acc: 0.915500
Time:17.8 s

**********
epoch 98
[98/100] Loss: 0.309891, Acc: 0.910729
[98/100] Loss: 0.317463, Acc: 0.911406
[98/100] Loss: 0.318615, Acc: 0.911632
[98/100] Loss: 0.318083, Acc: 0.912292
[98/100] Loss: 0.318435, Acc: 0.911687
[98/100] Loss: 0.317190, Acc: 0.911701


In [26]:
# Save the model: persist only the learned parameters (state_dict), not the module object.
checkpoint_path = './logstic.pth'
torch.save(model.state_dict(), checkpoint_path)