In [1]:
import Ipynb_importer
from a_basic_quant import *
from b_model import *
# from ipywidgets import IntProgress

importing Jupyter notebook from a_basic_quant.ipynb
importing Jupyter notebook from b_model.ipynb


In [2]:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
import os

### 1、定义训练

In [3]:
def train(model, device, train_loader, criterion, optimizer, epoch):
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (datas, targets) in enumerate(train_loader):
        datas, targets = datas.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(datas)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss
        pred = outputs.argmax(dim=1, keepdim=True)
        correct += pred.eq(targets.view_as(pred)).sum().item()
        
    train_loss /= len(train_loader.dataset)
        
    print('Train Epoch:{} \t  Average loss: {:.4f}, Accuracy: {:.0f}%'.format(
        epoch, train_loss, 100. * correct / len(train_loader.dataset)
    ))

### 2、定义测试

In [4]:
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    for batch_idx, (datas, targets) in enumerate(test_loader):
        datas, targets = datas.to(device), targets.to(device)
        outputs = model(datas)
        loss = criterion(outputs, targets)
        
        test_loss += loss.item()
        pred = outputs.argmax(dim=1, keepdim=True)
        correct += pred.eq(targets.view_as(pred)).sum().item()
        
    test_loss /= len(test_loader.dataset)
    
    print('Test set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(
        test_loss, 100. * correct / len(test_loader.dataset)
    ))

### 3、定义数据集和loader

In [5]:
def dataset_loader(batch_size, test_batch_size):
    train_transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])
    test_transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])
    
    train_dataset = datasets.MNIST(r'C:\Users\xia\Documents\datasets', train=True, download=True,
                                  transform=train_transform)
    test_dataset = datasets.MNIST(r'C:\Users\xia\Documents\datasets', train=False, download=True,
                                 transform=test_transform)
    
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                              shuffle=True, num_workers=1)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size,
                                             shuffle=False, num_workers=1)
    
    return train_loader, test_loader

In [6]:
def main():
    batch_size = 128
    test_batch_size = 64
    seed = 1
    epochs = 15
    lr = 0.01
    momentum = 0.8
    save_model = True
    using_bn = False
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device == 'cuda':
        torch.cuda.manual_seed(seed)
    else:
        torch.manual_seed(seed)
    
    # 制作 loader
    train_loader, test_loader = dataset_loader(batch_size, test_batch_size)
    
    if using_bn:
        model = NetBN().to(device)
    else:
        model = Net().to(device)
        
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = torch.nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        train(model, device, train_loader, criterion, optimizer, epoch)
        test(model, device, test_loader, criterion)
    
    if save_model:
        if not os.path.exists('ckpt'):
            os.makedirs('ckpt')
        if using_bn:
            torch.save(model.state_dict(), 'ckpt/mnist_cnnbn.pt')
        else:
            torch.save(model.state_dict(), 'ckpt/mnist_cnn.pt')

In [7]:
if __name__ == "__main__":
    main()

Train Epoch:0 	  Average loss: 0.0028, Accuracy: 89
Test set: Average loss: 0.0022, Accuracy: 96%

Train Epoch:1 	  Average loss: 0.0010, Accuracy: 96
Test set: Average loss: 0.0015, Accuracy: 97%

Train Epoch:2 	  Average loss: 0.0008, Accuracy: 97
Test set: Average loss: 0.0012, Accuracy: 97%

Train Epoch:3 	  Average loss: 0.0007, Accuracy: 97
Test set: Average loss: 0.0011, Accuracy: 98%



KeyboardInterrupt: 