In [1]:
import copy
import os
import tempfile
import time

import torch
import torch.nn as nn
import torch.quantization

## 1. Setup

In [2]:
class lstm_for_demonstration(nn.Module):
    """Minimal module wrapping one ``nn.LSTM``, used to demo dynamic quantization.

    Args:
        in_dim: number of features per time step fed to the LSTM.
        out_dim: hidden size of the LSTM (size of each output step).
        depth: number of stacked LSTM layers.
    """

    def __init__(self, in_dim, out_dim, depth):
        super().__init__()
        # The submodule lives under ``lstm`` so state_dict keys read "lstm.*".
        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=out_dim, num_layers=depth)

    def forward(self, inputs, hidden):
        """Run the wrapped LSTM.

        ``inputs`` uses nn.LSTM's default (seq_len, batch, features) layout
        (batch_first is not set); ``hidden`` is the ``(h_0, c_0)`` pair.
        Returns ``(output, (h_n, c_n))`` exactly as ``nn.LSTM`` produces it.
        """
        return self.lstm(inputs, hidden)

In [5]:
# Seed the global RNG so the random inputs below (and anything drawn from the
# RNG afterwards in this notebook) are reproducible.
torch.manual_seed(42)

# Toy problem dimensions.
model_dimension = 8   # features per time step
sequence_length = 20
batch_size = 1
lstm_depth = 1        # number of stacked LSTM layers

# Random input sequence laid out as (seq_len, batch, features).
inputs = torch.randn(sequence_length, batch_size, model_dimension)

# The LSTM's initial state is a pair (h_0, c_0): the initial hidden state and
# the initial cell state, each shaped (num_layers, batch, hidden_size).
state_shape = (lstm_depth, batch_size, model_dimension)
hidden = (torch.randn(*state_shape), torch.randn(*state_shape))

## 2. Quantization

In [10]:
# Float32 reference model; input size and hidden size are both model_dimension.
float_lstm = lstm_for_demonstration(model_dimension, model_dimension, lstm_depth)

# Dynamic quantization: weights of the listed submodule types are stored as
# int8, activations are quantized on the fly at inference time.
# quantize_dynamic returns a new module by default, so float_lstm is untouched.
quantized_lstm = torch.quantization.quantize_dynamic(
    float_lstm,
    qconfig_spec={nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)

# Show both module trees; the trailing '' yields the blank separator line.
print('Here is the floating point version of this module:', float_lstm, '', sep='\n')
print('and now the quantized version:', quantized_lstm, sep='\n')

Here is the floating point version of this module:
lstm_for_demonstration(
  (lstm): LSTM(8, 8)
)

and now the quantized version:
lstm_for_demonstration(
  (lstm): DynamicQuantizedLSTM(8, 8)
)


## 3. Size

In [11]:
def print_size_of_model(model, label=""):
    """Print and return the serialized size of ``model.state_dict()``.

    The state dict is written with ``torch.save`` to a file and measured on disk.

    Args:
        model: module whose ``state_dict()`` is serialized.
        label: tag printed next to the size (e.g. "fp32", "int8").

    Returns:
        Size of the serialized file in bytes. The printed figure is bytes / 1e3.
    """
    # Write into a private temporary directory instead of the working directory,
    # so an existing "temp.p" there is neither overwritten nor deleted. The
    # directory is also cleaned up if torch.save or getsize raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original base name: torch.save's archive layout embeds it,
        # so this keeps sizes comparable with earlier runs of this cell.
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size = os.path.getsize(path)
    print("model: ", label, ' \t', 'Size (KB):', size / 1e3)
    return size

# Serialized state_dict sizes in bytes for the float and quantized models.
f = print_size_of_model(float_lstm, "fp32")
q = print_size_of_model(quantized_lstm, "int8")

# How many times smaller the int8 checkpoint is than the fp32 one.
size_ratio = f / q
print("{0:.2f} times smaller".format(size_ratio))

model:  fp32  	 Size (KB): 3.111
model:  int8  	 Size (KB): 1.861
1.67 times smaller


## 4. Latency

In [13]:
# compare the performance
print("Floating point FP32")
%timeit float_lstm.forward(inputs, hidden)

print("Quantized INT8")
%timeit quantized_lstm.forward(inputs,hidden)

Floating point FP32
1.31 ms ± 47.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Quantized INT8
467 µs ± 4.77 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## 5. Accuracy

In [15]:
# FP32 reference run and the mean magnitude of its outputs.
out1, hidden1 = float_lstm(inputs, hidden)
mag1 = out1.abs().mean().item()
print('mean absolute value of output tensor values in the FP32 model is {0:.5f} '.format(mag1))

# INT8 run on the very same inputs and initial (h_0, c_0) state.
out2, hidden2 = quantized_lstm(inputs, hidden)
mag2 = out2.abs().mean().item()
print('mean absolute value of output tensor values in the INT8 model is {0:.5f}'.format(mag2))

# Mean absolute element-wise error, also shown relative to the FP32 magnitude.
mag3 = (out1 - out2).abs().mean().item()
relative_pct = mag3 / mag1 * 100
print('mean absolute value of the difference between the output tensors is {0:.5f} or {1:.2f} percent'.format(mag3, relative_pct))

mean absolute value of output tensor values in the FP32 model is 0.14986 
mean absolute value of output tensor values in the INT8 model is 0.14672
mean absolute value of the difference between the output tensors is 0.01802 or 12.02 percent
