출처: PyTorch 튜토리얼

In [2]:
import torch
import torch.quantization
import torch.nn as nn
import copy
import os
import time

In [11]:
class lstm_for_demonstration(nn.Module):
    """Minimal module that wraps a single ``nn.LSTM`` and nothing else.

    Args:
        in_dim: Forwarded to ``nn.LSTM`` as ``input_size``.
        out_dim: Forwarded to ``nn.LSTM`` as ``hidden_size``.
        depth: Forwarded to ``nn.LSTM`` as ``num_layers``.
    """

    def __init__(self,in_dim,out_dim,depth):
        super().__init__()
        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=out_dim, num_layers=depth)

    def forward(self,inputs,hidden):
        """Run ``inputs`` through the wrapped LSTM starting from ``hidden``.

        Returns:
            The ``(output, new_hidden)`` pair produced by ``self.lstm``.
        """
        sequence_out, next_hidden = self.lstm(inputs, hidden)
        return sequence_out, next_hidden


In [15]:
# Fix the RNG seed so the random tensors below are identical on every run.
torch.manual_seed(29592)

# Sizes shared by the demo tensors and by the LSTM built from them.
model_dimension = 8
sequence_length = 20
batch_size = 1
lstm_depth = 1

# Random input of shape (sequence_length, batch_size, model_dimension).
inputs = torch.randn(sequence_length, batch_size, model_dimension)

# Pair of random tensors, each (lstm_depth, batch_size, model_dimension),
# later passed to the models as the LSTM's initial state.
# Drawn one after the other, so the RNG order matches a literal two-element tuple.
hidden = tuple(
    torch.randn(lstm_depth, batch_size, model_dimension) for _ in range(2)
)

In [17]:
# FP32 reference model; input and hidden sizes are both model_dimension.
float_lstm = lstm_for_demonstration(model_dimension, model_dimension, lstm_depth)

# Dynamically quantize the listed module types (nn.LSTM, nn.Linear) to qint8.
# The result is bound to a separate name; float_lstm is kept for comparison.
quantized_lstm = torch.quantization.quantize_dynamic(
    float_lstm,
    qconfig_spec={nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)

# Show both module trees, separated by a blank line.
print(
    'Here is the floating point version of this module:',
    float_lstm,
    '',
    'and now the quantized version:',
    quantized_lstm,
    sep='\n',
)

Here is the floating point version of this module:
lstm_for_demonstration(
  (lstm): LSTM(8, 8)
)

and now the quantized version:
lstm_for_demonstration(
  (lstm): DynamicQuantizedLSTM(8, 8)
)


In [18]:
def print_size_of_model(model, label = ""):
    """Print and return the on-disk size of ``model``'s serialized state_dict.

    The state_dict is saved with ``torch.save`` into a private temporary
    directory, measured, and the directory is removed again. Nothing is
    written to the current working directory, so an existing ``temp.p``
    there is never overwritten or deleted, and the scratch file is cleaned
    up even if saving or measuring raises.

    Args:
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        label: Free-form tag included in the printed line.

    Returns:
        Size of the saved file in bytes, as reported by ``os.path.getsize``.
    """
    # Local import keeps this cell self-contained; only this helper needs it.
    import tempfile

    # TemporaryDirectory removes the directory and its contents on exit,
    # including when an exception propagates out of the block.
    with tempfile.TemporaryDirectory() as scratch_dir:
        # Keep the original file name; only its location changes.
        checkpoint_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), checkpoint_path)
        size = os.path.getsize(checkpoint_path)

    print("model: ",label,' \t','Size (KB):', size/1e3)
    return size

# Measure the serialized size of each model, then report the FP32 / INT8 ratio.
f = print_size_of_model(float_lstm, "fp32")
q = print_size_of_model(quantized_lstm, "int8")
size_ratio = f / q
print("{0:.2f} times smaller".format(size_ratio))

model:  fp32  	 Size (KB): 3.743
model:  int8  	 Size (KB): 2.719
1.38 times smaller


In [19]:
print("Floating point FP32")
%timeit float_lstm.forward(inputs, hidden)

Floating point FP32
The slowest run took 70.56 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 950 µs per loop


In [20]:
print("Quantized INT8")
%timeit quantized_lstm.forward(inputs,hidden)

Quantized INT8
The slowest run took 25.55 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 621 µs per loop


In [21]:
# FP32 model: forward pass, then the mean absolute value of its output.
out1, hidden1 = float_lstm(inputs, hidden)
mag1 = out1.abs().mean().item()
print('mean absolute value of output tensor values in the FP32 model is {0:.5f} '.format(mag1))

# INT8 model: same inputs and initial state, same statistic.
out2, hidden2 = quantized_lstm(inputs, hidden)
mag2 = out2.abs().mean().item()
print('mean absolute value of output tensor values in the INT8 model is {0:.5f}'.format(mag2))

# Mean absolute element-wise difference, also shown as a percentage of the
# FP32 output magnitude.
mag3 = (out1 - out2).abs().mean().item()
print('mean absolute value of the difference between the output tensors is {0:.5f} or {1:.2f} percent'.format(mag3,mag3/mag1*100))

mean absolute value of output tensor values in the FP32 model is 0.12887 
mean absolute value of output tensor values in the INT8 model is 0.12912
mean absolute value of the difference between the output tensors is 0.00156 or 1.21 percent
