In [8]:
import Ipynb_importer
from a_basic_quant import *
from b_model import *
# from ipywidgets import IntProgress

In [9]:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
import os

### 1、定义训练

In [10]:
def train(model, device, train_loader, criterion, optimizer, epoch):
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (datas, targets) in enumerate(train_loader):
        datas, targets = datas.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(datas)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss
        pred = outputs.argmax(dim=1, keepdim=True)
        correct += pred.eq(targets.view_as(pred)).sum().item()
        
    train_loss /= len(train_loader.dataset)
        
    print('Train Epoch:{} \t  Average loss: {:.4f}, Accuracy: {}/{}'.format(
        epoch, train_loss, correct , len(train_loader.dataset)
    ))

### 2、定义测试

In [11]:
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    for batch_idx, (datas, targets) in enumerate(test_loader):
        datas, targets = datas.to(device), targets.to(device)
        outputs = model(datas)
        loss = criterion(outputs, targets)
        
        test_loss += loss.item()
        pred = outputs.argmax(dim=1, keepdim=True)
        correct += pred.eq(targets.view_as(pred)).sum().item()
        
    test_loss /= len(test_loader.dataset)
    
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{}'.format(
        test_loss,  correct ,  len(test_loader.dataset)
    ))

### 3、定义数据集和loader

In [12]:
def dataset_loader(batch_size, test_batch_size):
    train_transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])
    test_transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])
    
    train_dataset = datasets.MNIST('/home/xia/Dataset', train=True, download=True,
                                  transform=train_transform)
    test_dataset = datasets.MNIST('/home/xia/Dataset', train=False, download=True,
                                 transform=test_transform)  # r'C:\Users\xia\Documents\datasets'
    
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                              shuffle=True, num_workers=1)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size,
                                             shuffle=False, num_workers=1)
    
    return train_loader, test_loader

In [6]:
def main():
    batch_size = 256
    test_batch_size = 64
    seed = 1
    epochs = 25
    lr = 0.01
    momentum = 0.7
    save_model = True
    using_bn = True
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device == 'cuda':
        torch.cuda.manual_seed(seed)
    else:
        torch.manual_seed(seed)
    
    # 制作 loader
    train_loader, test_loader = dataset_loader(batch_size, test_batch_size)
    
    if using_bn:
        model = NetBN().to(device)
    else:
        model = Net().to(device)
        
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = torch.nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        if epoch == 18:     lr = 0.001
        train(model, device, train_loader, criterion, optimizer, epoch)
        test(model, device, test_loader, criterion)
    
    if save_model:
        if not os.path.exists('ckpt'):
            os.makedirs('ckpt')
        if using_bn:
            torch.save(model.state_dict(), 'ckpt/mnist_cnnbn.pt')
        else:
            torch.save(model.state_dict(), 'ckpt/mnist_cnn.pt')

In [7]:
if __name__ == "__main__":
    main()

Train Epoch:0 	  Average loss: 0.0010, Accuracy: 56085/60000
Test set: Average loss: 0.0016, Accuracy: 9731/10000
Train Epoch:1 	  Average loss: 0.0003, Accuracy: 58555/60000
Test set: Average loss: 0.0011, Accuracy: 9788/10000
Train Epoch:2 	  Average loss: 0.0003, Accuracy: 58858/60000
Test set: Average loss: 0.0010, Accuracy: 9819/10000
Train Epoch:3 	  Average loss: 0.0002, Accuracy: 59031/60000
Test set: Average loss: 0.0008, Accuracy: 9842/10000
Train Epoch:4 	  Average loss: 0.0002, Accuracy: 59151/60000
Test set: Average loss: 0.0008, Accuracy: 9854/10000
Train Epoch:5 	  Average loss: 0.0002, Accuracy: 59204/60000
Test set: Average loss: 0.0008, Accuracy: 9847/10000
Train Epoch:6 	  Average loss: 0.0002, Accuracy: 59283/60000
Test set: Average loss: 0.0007, Accuracy: 9858/10000
Train Epoch:7 	  Average loss: 0.0002, Accuracy: 59327/60000
Test set: Average loss: 0.0007, Accuracy: 9871/10000
Train Epoch:8 	  Average loss: 0.0001, Accuracy: 59385/60000
Test set: Average loss: 0.0

In [22]:
train_loader, test_loader = dataset_loader(64, 64)
for index,item in enumerate(train_loader):
    datas,_ = item
    print(datas.shape)
    break

torch.Size([64, 1, 28, 28])
