# How to replace pre-trained model layers with custom ones?

Replace pytorch built-in layers with custom ones may seem trivial at first glance. However, it turns out that we need to treat this seriously to avoid potential pitfalls.

The take-homes are we should use `set_submodule` method, and use `get_submodule` before we do the replacement to make sure we get things go as expected.

In this tutorial, we showcase the right way to replace pre-trained model layers with quatified ones, a quite common scenario when using INT8 quantization to deploy models.

## Naive and simple case

In [1]:
import torch
from torch import nn

from pytorch_quantization import tensor_quant
import pytorch_quantization.nn as quant_nn

in_features = 100
out_features = 200
in_channels = 3
out_channels = 4
kernel_size = 3

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(in_features, out_features, bias=True)
        self.fc2 = nn.Linear(out_features, in_features, bias=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


model_original = Model()
input1 = torch.randn(1000, in_features)
output1 = model_original(input1)

In [2]:
for n, m in model_original.named_modules():
    if isinstance(m, nn.Linear):
        # print(getattr(model_original, n)
        quant_linear = quant_nn.Linear(
            m.in_features, m.out_features, bias=m.bias is not None,
            quant_desc_input=tensor_quant.QUANT_DESC_8BIT_PER_TENSOR,
            quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
        )
        setattr(quant_nn, "weight", m.weight)
        setattr(quant_nn, "bias", m.bias)
        setattr(model_original, n, quant_linear)
print(model_original)
output2 = model_original(input1)
print(torch.allclose(output1, output2), (output1 - output2).mean())

Model(
  (fc1): QuantLinear(
    in_features=100, out_features=200, bias=True
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (fc2): QuantLinear(
    in_features=200, out_features=100, bias=True
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
)
False tensor(0.0054, grad_fn=<MeanBackward0>)


In [None]:
# error is not thrown, not complied with official doc.
model_original.set_submodule("fake", nn.Linear(10,10))
print(model_original)
model_original.get_submodule("fake1")

Model(
  (fc1): QuantLinear(
    in_features=100, out_features=200, bias=True
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (fc2): QuantLinear(
    in_features=200, out_features=100, bias=True
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (fake): Linear(in_features=10, out_features=10, bias=True)
)


AttributeError: Model has no attribute `fake1`

## more complex network
The above method using setattr does not work for more complex network.

In [None]:
class ComplexModel(nn.Module):
    def __init__(self):
        super(ComplexModel, self).__init__()
        self.backbone = Model()
        self.classifier = nn.Linear(20, 20)


In [5]:
import copy
model_original = ComplexModel()
print(model_original)
print('-'*100)
new_model = copy.deepcopy(model_original)
print(new_model)
print('-'*100)
for n, m in model_original.named_modules():
    if isinstance(m, nn.Linear):
        # print(getattr(model_original, n)
        quant_linear = quant_nn.Linear(
            m.in_features, m.out_features, bias=m.bias is not None,
            quant_desc_input=tensor_quant.QUANT_DESC_8BIT_PER_TENSOR,
            quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
        )
        setattr(quant_nn, "weight", m.weight)
        setattr(quant_nn, "bias", m.bias)
        setattr(new_model, n, quant_linear)
print(new_model)

ComplexModel(
  (backbone): Model(
    (fc1): Linear(in_features=100, out_features=200, bias=True)
    (fc2): Linear(in_features=200, out_features=100, bias=True)
  )
  (classifier): Linear(in_features=20, out_features=20, bias=True)
)
----------------------------------------------------------------------------------------------------
ComplexModel(
  (backbone): Model(
    (fc1): Linear(in_features=100, out_features=200, bias=True)
    (fc2): Linear(in_features=200, out_features=100, bias=True)
  )
  (classifier): Linear(in_features=20, out_features=20, bias=True)
)
----------------------------------------------------------------------------------------------------
ComplexModel(
  (backbone): Model(
    (fc1): Linear(in_features=100, out_features=200, bias=True)
    (fc2): Linear(in_features=200, out_features=100, bias=True)
  )
  (classifier): QuantLinear(
    in_features=20, out_features=20, bias=True
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrato

### The replacement failed by using setattr method, because the way we reference sub-module is different.

In [10]:

# model_original['backbone.fc1'] # error
# getattr(model_original, 'backbone.fc1') # error
model_original.get_submodule('backbone.fc1') # correct

Linear(in_features=100, out_features=200, bias=True)

## Let's use set_submodule 

In [7]:
import copy
model_original = ComplexModel()
# print(model_original)
# print('-'*100)
new_model = copy.deepcopy(model_original)
# print(new_model)
# print('-'*100)
for n, m in model_original.named_modules():
    if isinstance(m, nn.Linear):
        # print(getattr(model_original, n)
        quant_linear = quant_nn.Linear(
            m.in_features, m.out_features, bias=m.bias is not None,
            quant_desc_input=tensor_quant.QUANT_DESC_8BIT_PER_TENSOR,
            quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
        )
        setattr(quant_nn, "weight", m.weight)
        setattr(quant_nn, "bias", m.bias)
        # NOTE: Make sure the name is correct
        _ = new_model.get_submodule(n)
        new_model.set_submodule(n, quant_linear)
print(new_model)

ComplexModel(
  (backbone): Model(
    (fc1): QuantLinear(
      in_features=100, out_features=200, bias=True
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    )
    (fc2): QuantLinear(
      in_features=200, out_features=100, bias=True
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    )
  )
  (classifier): QuantLinear(
    in_features=20, out_features=20, bias=True
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
)
