# Define Original Model

In [1]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.compression.pytorch.utils import count_flops_params
import time

from mnist_model import Net, train, test, device, optimizer_scheduler_generator

# Instantiate the network and move it to the device imported from mnist_model.
model = Net().to(device)

# Show the model structure. Note that a pruner will wrap the model layers.
print(model)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


### Pre-train model

In [2]:
# Build the optimizer and the scheduler used for pre-training.
optimizer, scheduler = optimizer_scheduler_generator(model)

# Pre-train the model on the MNIST dataset, evaluating it after every epoch.
total_epoch = 10

for epoch in range(1, total_epoch+1):
    train(model, device, optimizer=optimizer, epoch=epoch)
    test(model, device)
    scheduler.step()  # advance the scheduler once per epoch
    
# Save the whole model object (not just a state_dict) to disk.
torch.save(model, "mnist_cnn.pt")


Test set: Average loss: 0.0482, Accuracy: 9838/10000 (98.38%)




Test set: Average loss: 0.0335, Accuracy: 9895/10000 (98.95%)


Test set: Average loss: 0.0321, Accuracy: 9892/10000 (98.92%)




Test set: Average loss: 0.0336, Accuracy: 9897/10000 (98.97%)


Test set: Average loss: 0.0274, Accuracy: 9915/10000 (99.15%)




Test set: Average loss: 0.0267, Accuracy: 9915/10000 (99.15%)




Test set: Average loss: 0.0281, Accuracy: 9915/10000 (99.15%)


Test set: Average loss: 0.0268, Accuracy: 9913/10000 (99.13%)




Test set: Average loss: 0.0285, Accuracy: 9912/10000 (99.12%)


Test set: Average loss: 0.0267, Accuracy: 9919/10000 (99.19%)



### Performance and statistics of original model 

In [3]:
# Time one evaluation pass of the pre-trained model.
# time.perf_counter() is a monotonic, high-resolution clock intended for
# measuring intervals; unlike time.time() it is not affected by system clock
# adjustments that happen while the evaluation is running.
# NOTE(review): if `device` is a GPU, the interval only covers the GPU work
# that test() has waited for before returning -- confirm test() synchronizes.
start = time.perf_counter()

pre_best_acc = test(model, device)
pre_test_time = time.perf_counter() - start

# Count FLOPs and parameters by passing a random dummy batch of shape
# (3, 1, 28, 28) through the model; the third return value is not used here.
pre_flops, pre_params, _ = count_flops_params(model, torch.randn([3, 1, 28, 28]).to(device))
print(f'Pretrained model FLOPs {pre_flops/1e6:.2f} M, #Params: {pre_params/1e6:.2f}M, Accuracy: {pre_best_acc: .2f}%, Test-time: {pre_test_time: .4f}s')


Test set: Average loss: 0.0267, Accuracy: 9919/10000 (99.19%)

+-------+-------+--------+----------------+-----------------+-----------------+----------+---------+
| Index | Name  |  Type  |  Weight Shape  |    Input Size   |   Output Size   |  FLOPs   | #Params |
+-------+-------+--------+----------------+-----------------+-----------------+----------+---------+
|   0   | conv1 | Conv2d | (32, 1, 3, 3)  |  (3, 1, 28, 28) | (3, 32, 26, 26) |  194688  |   320   |
|   1   | conv2 | Conv2d | (64, 32, 3, 3) | (3, 32, 26, 26) | (3, 64, 24, 24) | 10616832 |  18496  |
|   2   | fc1   | Linear |  (128, 9216)   |    (3, 9216)    |     (3, 128)    | 1179648  | 1179776 |
|   3   | fc2   | Linear |   (10, 128)    |     (3, 128)    |     (3, 10)     |   1280   |   1290  |
+-------+-------+--------+----------------+-----------------+-----------------+----------+---------+
FLOPs total: 11992448
#Params total: 1199882
Pretrained model FLOPs 11.99 M, #Params: 1.20M, Accuracy:  99.19%, Test-time:  1.88