In [2]:
import torchvision
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
from ptflops import get_model_complexity_info
import torch.nn.utils.prune as prune

# 构建 transform
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])

# 加载数据
data_train = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train = True,
                            download = True)

data_test = datasets.MNIST(root="./data/",
                           transform = transform,
                           train = False)

# 创建数据 loader
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size = 64,
                                                shuffle = True)

data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size = 128,
                                               shuffle = True)

use_gpu = torch.cuda.is_available()

class Model(torch.nn.Module):
    
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1,4,kernel_size=3,stride=1,padding=1), # in 28*28*1 out 28*28*4
                                         torch.nn.ReLU(),
                                         torch.nn.MaxPool2d(stride=2,kernel_size=2)) # out 14*14*4
        
        self.conv2 = torch.nn.Sequential(torch.nn.Conv2d(4,8,kernel_size=3,stride=1,padding=0), # in 14*14*4 out 12*12*8
                                         torch.nn.ReLU(),
                                         torch.nn.MaxPool2d(stride=2,kernel_size=2)) # out 6*6*8
        self.conv3 = torch.nn.Sequential(torch.nn.Conv2d(8,16,kernel_size=3,stride=1,padding=0), # in 6*6*8 out 4*4*16
                                         torch.nn.ReLU(),
                                         torch.nn.MaxPool2d(stride=2,kernel_size=2)) # out 2*2*16
        self.conv4 = torch.nn.Sequential(torch.nn.Conv2d(16,10,kernel_size=1,stride=1,padding=0), # in 2*2*16 out 2*2*10
                                         torch.nn.AvgPool2d(stride=2,kernel_size=2)) # out 1*1*10
#         self.dense = torch.nn.Sequential(torch.nn.Dropout(p=0.25),
#                                          torch.nn.Linear(1*1*32, 10),
#                                          torch.nn.Softmax())
#         self.dense = torch.nn.Sequential(torch.nn.Softmax())
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(-1, 1*1*10)
#         x = self.dense(x)
        return x
    
model = Model()
cost = torch.nn.CrossEntropyLoss()
if(use_gpu):
    model = model.cuda()
    cost = cost.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
###########################################################################################
# list(model.named_parameters())
# list(model.named_buffers())
# print(next(model.conv1.modules())[0])
parameters_to_prune = []
for module in model.modules():
    if hasattr(module, 'weight'):
        parameters_to_prune.append((module, 'weight'))
# parameters_to_prune = (
#     (next(model.conv1.modules())[0], 'weight'),
#     (next(model.conv2.modules())[0], 'weight'),
#     (next(model.conv3.modules())[0], 'weight'),
#     (next(model.conv4.modules())[0], 'weight'),
# )
        
# prune.global_unstructured(
#     tuple(parameters_to_prune),
#     pruning_method=prune.L1Unstructured,
#     amount=0.2,
# )

# ops, params = get_model_complexity_info(model, (1, 28, 28), as_strings=True, print_per_layer_stat=True, verbose=True)

# print(ops, params)

n_epochs = 10

for epoch in range(n_epochs):
    running_loss = 0.0
    running_correct = 0
    print("Epoch {}/{}".format(epoch, n_epochs))
    print("-"*10)
    for data in data_loader_train:
        X_train, y_train = data
        X_train, y_train = Variable(X_train), Variable(y_train)
        if (use_gpu):
            X_train,y_train = X_train.cuda(),y_train.cuda()
        outputs = model(X_train)
        if(use_gpu):
            outputs = outputs.cpu()
        _,pred = torch.max(outputs.data, 1)
        optimizer.zero_grad()
        loss = cost(outputs, y_train.cpu())
        
        loss.backward()
        optimizer.step()
        running_loss += loss.data.item()
        running_correct += torch.sum(pred == y_train.cpu().data)
    testing_correct = 0
    for data in data_loader_test:
        X_test, y_test = data
        X_test, y_test = Variable(X_test), Variable(y_test)
        if (use_gpu):
            X_test,y_test = X_test.cuda(),y_test.cuda()
        outputs = model(X_test)
        if (use_gpu):
            outputs = outputs.cpu()
        _, pred = torch.max(outputs.data, 1)
        testing_correct += torch.sum(pred == y_test.cpu().data)
    print("Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Accuracy is:{:.4f}%".format(1.*running_loss/len(data_train),
                                                                                      100.*running_correct/len(data_train),
                                                                                      100.*testing_correct/len(data_test)))
torch.save(model.state_dict(), "model_parameter.pkl")

print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(next(model.conv1.modules())[0].weight == 0)
            + torch.sum(next(model.conv2.modules())[0].weight == 0)
            + torch.sum(next(model.conv3.modules())[0].weight == 0)
            + torch.sum(next(model.conv4.modules())[0].weight == 0)
        )
        / float(
            next(model.conv1.modules())[0].weight.nelement()
            + next(model.conv2.modules())[0].weight.nelement()
            + next(model.conv3.modules())[0].weight.nelement()
            + next(model.conv4.modules())[0].weight.nelement()
        )
    )
)
# print(model.state_dict())
dummy_input = torch.randn(1, 1, 28, 28, device='cuda')
torch.onnx.export(model, dummy_input, "mnist.onnx", verbose=False, input_names=['input_1'], output_names=['output_1'])
print("train ok")

Epoch 0/10
----------
Loss is:0.0074, Train Accuracy is:84.1600%, Test Accuracy is:93.2200%
Epoch 1/10
----------
Loss is:0.0031, Train Accuracy is:93.9750%, Test Accuracy is:95.2500%
Epoch 2/10
----------
Loss is:0.0026, Train Accuracy is:94.8800%, Test Accuracy is:95.7400%
Epoch 3/10
----------
Loss is:0.0024, Train Accuracy is:95.2917%, Test Accuracy is:95.3900%
Epoch 4/10
----------
Loss is:0.0022, Train Accuracy is:95.5783%, Test Accuracy is:95.9700%
Epoch 5/10
----------
Loss is:0.0021, Train Accuracy is:95.7350%, Test Accuracy is:96.2000%
Epoch 6/10
----------
Loss is:0.0021, Train Accuracy is:95.9117%, Test Accuracy is:96.5300%
Epoch 7/10
----------
Loss is:0.0020, Train Accuracy is:96.0700%, Test Accuracy is:95.7600%
Epoch 8/10
----------
Loss is:0.0020, Train Accuracy is:96.1483%, Test Accuracy is:95.6900%
Epoch 9/10
----------
Loss is:0.0019, Train Accuracy is:96.2300%, Test Accuracy is:96.6400%
Global sparsity: 0.00%
train ok
