In [3]:
import torch
from torch import nn,optim
from torch.utils.data import DataLoader
import torch.nn.functional as f
from torchvision import transforms
from torchvision import datasets
import time 

In [12]:
batch_size=32
learning_rate=1e-3
num_epochs=100

In [4]:
train=datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [7]:
test=datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor(),download=True)

In [15]:
train_loader=DataLoader(train,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test,batch_size=batch_size,shuffle=False)

In [9]:
class logisticr(nn.Module):
    def __init__(self,in_dim,n_class):
        super(logisticr,self).__init__()
        self.logistic=nn.Linear(in_dim,n_class)
    def forward(self,x):
        out=self.logistic(x)
        return out

In [13]:
model=logisticr(28*28,10)
use_gpu=torch.cuda.is_available()
if use_gpu:
    model=model.cuda()
criterion=nn.CrossEntropyLoss()
opt=optim.SGD(model.parameters(),lr=learning_rate)

In [17]:
for epoch in range(num_epochs):
    print('*'*10)
    print('epoch{}'.format(epoch+1))
    start=time.time()
    running_loss=0.0
    running_acc=0.0
    for i,data in enumerate(train_loader,1):
        img,label=data
        img=img.view(img.size(0),-1) #将图片展开成28*28
        if use_gpu:
            img=img.cuda()
            label=label.cuda()
        else:
            img=img
            label=label
    
        out=model(img)
        loss=criterion(out,label)
        running_loss+=loss.data[0]*label.size(0)
        _,pred=torch.max(out,1)
        num_correct=(pred==label).sum()
        running_acc+=num_correct.data[0]

        opt.zero_grad()
        loss.backward()
        opt.step()

        if i%300==0:
            print('[{}/{}]loss:{:.6f},Acc:{:.6f}'.format(epoch+1,num_epochs,
            running_loss/(batch_size*i),running_acc/(batch_size*i)))
    print ('Finish{}epoch,Loss:{:.6f},Acc:{:.6f}'.format(
        epoch+1,running_loss/(len(train)), running_acc/(len(train))))
    
    model.eval()
    eval_loss=0.
    eval_acc=0.
    
    for data in test_loader:
        img,label=data
        img=img.view(img.size(0),-1)
        if use_gpu:
            img=img.cuda()
            label=label.cuda()
        else:
            img=img
            label=label
        out=model(img)
        loss=criterion(out,label)
        eval_loss+=loss.data[0]*label.size(0)
        _,pred=torch.max(out,1)
        num_correct=(pred==label).sum()
        eval_acc+=num_correct.data[0]
    print('Test Loss:{:.6f},Acc:{:.6f}'.format(eval_loss/(len(
        test)),eval_acc/(len(test))))
    print('Time:{:.1f} s'.format(time.time() - start))
    

**********
epoch1




[1/100]loss:2.152880,Acc:0.000000
[1/100]loss:2.027013,Acc:0.000000
[1/100]loss:1.915277,Acc:0.000000
[1/100]loss:1.820753,Acc:0.000000
[1/100]loss:1.738571,Acc:0.000000
[1/100]loss:1.665880,Acc:0.000000
Finish1epoch,Loss:1.649044,Acc:0.000000




**********
epoch2
[2/100]loss:1.190464,Acc:0.000000
[2/100]loss:1.160336,Acc:0.000000
[2/100]loss:1.127994,Acc:0.000000
[2/100]loss:1.099509,Acc:0.000000
[2/100]loss:1.073042,Acc:0.000000
[2/100]loss:1.050043,Acc:0.000000
Finish2epoch,Loss:1.044058,Acc:0.000000
**********
epoch3
[3/100]loss:0.884864,Acc:0.000000
[3/100]loss:0.874184,Acc:0.000000
[3/100]loss:0.862356,Acc:0.000000
[3/100]loss:0.852163,Acc:0.000000
[3/100]loss:0.842104,Acc:0.000000
[3/100]loss:0.829861,Acc:0.000000
Finish3epoch,Loss:0.828113,Acc:0.000000
**********
epoch4
[4/100]loss:0.752042,Acc:0.000000
[4/100]loss:0.741598,Acc:0.000000
[4/100]loss:0.736399,Acc:0.000000
[4/100]loss:0.730298,Acc:0.000000
[4/100]loss:0.724972,Acc:0.000000
[4/100]loss:0.719925,Acc:0.000000
Finish4epoch,Loss:0.718831,Acc:0.000000
**********
epoch5
[5/100]loss:0.672201,Acc:0.000000
[5/100]loss:0.667591,Acc:0.000000
[5/100]loss:0.664328,Acc:0.000000
[5/100]loss:0.658842,Acc:0.000000
[5/100]loss:0.655960,Acc:0.000000
[5/100]loss:0.653215,Acc:0

[32/100]loss:0.381374,Acc:0.000000
[32/100]loss:0.381035,Acc:0.000000
Finish32epoch,Loss:0.381244,Acc:0.000000
**********
epoch33
[33/100]loss:0.385419,Acc:0.000000
[33/100]loss:0.386583,Acc:0.000000
[33/100]loss:0.382629,Acc:0.000000
[33/100]loss:0.376877,Acc:0.000000
[33/100]loss:0.378888,Acc:0.000000
[33/100]loss:0.379621,Acc:0.000000
Finish33epoch,Loss:0.378973,Acc:0.000000
**********
epoch34
[34/100]loss:0.378782,Acc:0.000000
[34/100]loss:0.376603,Acc:0.000000
[34/100]loss:0.375917,Acc:0.000000
[34/100]loss:0.375455,Acc:0.000000
[34/100]loss:0.374764,Acc:0.000000
[34/100]loss:0.375765,Acc:0.000000
Finish34epoch,Loss:0.376806,Acc:0.000000
**********
epoch35
[35/100]loss:0.379777,Acc:0.000000
[35/100]loss:0.378995,Acc:0.000000
[35/100]loss:0.377365,Acc:0.000000
[35/100]loss:0.375251,Acc:0.000000
[35/100]loss:0.374383,Acc:0.000000
[35/100]loss:0.374198,Acc:0.000000
Finish35epoch,Loss:0.374739,Acc:0.000000
**********
epoch36
[36/100]loss:0.370649,Acc:0.000000
[36/100]loss:0.373573,Acc

**********
epoch63
[63/100]loss:0.343477,Acc:0.000000
[63/100]loss:0.345209,Acc:0.000000
[63/100]loss:0.339625,Acc:0.000000
[63/100]loss:0.338366,Acc:0.000000
[63/100]loss:0.338855,Acc:0.000000
[63/100]loss:0.339670,Acc:0.000000
Finish63epoch,Loss:0.338915,Acc:0.000000
**********
epoch64
[64/100]loss:0.351980,Acc:0.000000
[64/100]loss:0.343978,Acc:0.000000
[64/100]loss:0.341346,Acc:0.000000
[64/100]loss:0.336055,Acc:0.000000
[64/100]loss:0.339086,Acc:0.000000
[64/100]loss:0.337797,Acc:0.000000
Finish64epoch,Loss:0.338109,Acc:0.000000
**********
epoch65
[65/100]loss:0.345744,Acc:0.000000
[65/100]loss:0.342865,Acc:0.000000
[65/100]loss:0.338217,Acc:0.000000
[65/100]loss:0.336562,Acc:0.000000
[65/100]loss:0.336090,Acc:0.000000
[65/100]loss:0.337923,Acc:0.000000
Finish65epoch,Loss:0.337301,Acc:0.000000
**********
epoch66
[66/100]loss:0.322124,Acc:0.000000
[66/100]loss:0.330206,Acc:0.000000
[66/100]loss:0.329029,Acc:0.000000
[66/100]loss:0.337218,Acc:0.000000
[66/100]loss:0.337716,Acc:0.000

[93/100]loss:0.320663,Acc:0.000000
[93/100]loss:0.319464,Acc:0.000000
[93/100]loss:0.318797,Acc:0.000000
Finish93epoch,Loss:0.320423,Acc:0.000000
**********
epoch94
[94/100]loss:0.329136,Acc:0.000000
[94/100]loss:0.324964,Acc:0.000000
[94/100]loss:0.323633,Acc:0.000000
[94/100]loss:0.322671,Acc:0.000000
[94/100]loss:0.323658,Acc:0.000000
[94/100]loss:0.321071,Acc:0.000000
Finish94epoch,Loss:0.319952,Acc:0.000000
**********
epoch95
[95/100]loss:0.310351,Acc:0.000000
[95/100]loss:0.320572,Acc:0.000000
[95/100]loss:0.322969,Acc:0.000000
[95/100]loss:0.322431,Acc:0.000000
[95/100]loss:0.319242,Acc:0.000000
[95/100]loss:0.319221,Acc:0.000000
Finish95epoch,Loss:0.319504,Acc:0.000000
**********
epoch96
[96/100]loss:0.321550,Acc:0.000000
[96/100]loss:0.324407,Acc:0.000000
[96/100]loss:0.325277,Acc:0.000000
[96/100]loss:0.320150,Acc:0.000000
[96/100]loss:0.319709,Acc:0.000000
[96/100]loss:0.318664,Acc:0.000000
Finish96epoch,Loss:0.319061,Acc:0.000000
**********
epoch97
[97/100]loss:0.321962,Acc

In [None]:
troch.save(model.state_dice(),'logr')