<a href="https://colab.research.google.com/github/gulabpatel/Knowledge_Distillation/blob/main/ConvertToTorchScript.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# NOTE(review): torchstat is not imported in any visible cell of this notebook — drop this
# install, or pin a version (and prefer `%pip`, which targets the kernel's environment).
!pip install torchstat

In [None]:
import os
import torch
import torchvision
# Build a ResNet-18 with torchvision's default (randomly initialised) weights and persist
# only its state_dict; the cells below inspect the resulting file size.
model = torchvision.models.resnet18()
torch.save(obj=model.state_dict(), f='resnet18.pt')

In [None]:
# Total number of scalar parameters: add up the element count of every parameter tensor.
sum(torch.numel(param) for param in model.parameters())

11689512

In [None]:
# 'resnet18.pt' was saved with a relative path, so read it relative to the working
# directory too. In Colab that is /content; elsewhere the hard-coded '/content/' path fails.
file_size = os.path.getsize('resnet18.pt')
print("File Size is :", file_size / 1024 ** 2, "MB")

File Size is : 44.66516971588135 MB


In [None]:
# Reload the saved float32 weights. weights_only=True restricts unpickling to tensors and
# plain containers, which is all a state_dict needs.
model.load_state_dict(torch.load("resnet18.pt", weights_only=True))
model.eval()
# Dynamic quantization only replaces the listed module types. ResNet-18's only nn.Linear
# is its final `fc` layer; every convolution stays float32, which is why the saved file
# shrinks only slightly.
quantized_model = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
torch.save(quantized_model.state_dict(), 'quantized_modelv1.pt')

file_size = os.path.getsize('quantized_modelv1.pt')
print("File Size is :", file_size / 1024 ** 2, "MB")

File Size is : 43.20219898223877 MB


In [None]:
# Tracing records the ops executed for one example input, including the train/eval
# behaviour of BatchNorm, so put the model in eval mode before tracing.
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Sanity check: on a different input, the traced module must reproduce the eager model.
check_input = torch.ones(1, 3, 224, 224)
with torch.no_grad():
    output = traced_script_module(check_input)
    torch.testing.assert_close(output, model(check_input))

In [None]:
# Save the traced module's weights as a plain state_dict. A later cell reloads this file
# with load_state_dict, so this file must stay a state_dict.
torch.save(traced_script_module.state_dict(), 'traced_resnet_model.pt')
# A state_dict discards the traced graph. Also persist the actual TorchScript program
# (code + weights) so it can be reloaded with torch.jit.load without the Python class.
traced_script_module.save("traced_resnet_model_scripted.pt")

In [None]:
# The file was saved with a relative path; read it relative to the working directory
# rather than the Colab-only '/content/' prefix.
file_size = os.path.getsize('traced_resnet_model.pt')
print("File Size is :", file_size / 1024 ** 2, "MB")

File Size is : 44.666470527648926 MB


In [None]:
# Load the traced module's state_dict back into the eager ResNet-18, then apply the same
# dynamic quantization as the earlier cell.
model.load_state_dict(torch.load("traced_resnet_model.pt", weights_only=True))
model.eval()
# Only the final nn.Linear (`fc`) is quantized, so the size matches quantized_modelv1.pt.
quantized_model = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
# NOTE(review): a later cell saves to this same filename and overwrites this file.
torch.save(quantized_model.state_dict(), 'quantized_modelv2.pt')

file_size = os.path.getsize('quantized_modelv2.pt')
print("File Size is :", file_size / 1024 ** 2, "MB")

File Size is : 43.20219898223877 MB


In [None]:
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

In [None]:
import copy
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torchvision.models import resnet50
fp32_model = resnet50().eval()
model = copy.deepcopy(fp32_model)
# `qconfig` means quantization configuration, it specifies how should we
# observe the activation and weight of an operator
# `qconfig_dict`, specifies the `qconfig` for each operator in the model
# we can specify `qconfig` for certain types of modules
# we can specify `qconfig` for a specific submodule in the model
# we can specify `qconfig` for some functioanl calls in the model
# we can also set `qconfig` to None to skip quantization for some operators
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
# `prepare_fx` inserts observers in the model based on the configuration in `qconfig_dict`
model_prepared = prepare_fx(model, qconfig_dict, torch.rand(1, 3, 224, 224))
# calibration runs the model with some sample data, which allows observers to record the statistics of
# the activation and weigths of the operators
calibration_data = [torch.randn(1, 3, 224, 224) for _ in range(100)]
for i in range(len(calibration_data)):
   model_prepared(calibration_data[i])
# `convert_fx` converts a calibrated model to a quantized model, this includes inserting
# quantize, dequantize operators to the model and swap floating point operators with quantized operators
model_quantized = convert_fx(copy.deepcopy(model_prepared))
# benchmark
x = torch.randn(1, 3, 224, 224)
%timeit fp32_model(x)
%timeit model_quantized(x)



221 ms ± 6.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
112 ms ± 680 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
class MyModule(torch.nn.Module):
    """Toy module with data-dependent control flow, so it must be scripted, not traced.

    Holds a single (N, M) weight. forward() returns the matrix-vector product
    ``weight.mv(input)`` when ``input`` sums to a positive value, and the broadcast
    sum ``weight + input`` otherwise.

    NOTE(review): the next cell redefines this class verbatim; one copy can go.
    """

    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            return self.weight.mv(input)
        return self.weight + input

In [None]:
class MyModule(torch.nn.Module):
    """Toy module whose forward() branches on tensor data.

    Tracing would record only the branch taken for the example input, so this module is
    compiled with torch.jit.script instead. forward() returns ``weight.mv(input)`` for a
    positive-sum ``input`` and ``weight + input`` (broadcast) otherwise.
    """

    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            return self.weight.mv(input)
        return self.weight + input

# Compile the module to TorchScript; the scripted module keeps both branches.
my_module = MyModule(N=10, M=20)
sm = torch.jit.script(my_module)
sm

RecursiveScriptModule(original_name=MyModule)

In [None]:
# Serialize the full TorchScript program (code + weights) of the scripted module.
# NOTE(review): despite the filename, this is the scripted MyModule, not a traced ResNet.
torch.jit.save(sm, "traced_resnet_modelv2.pt")

In [None]:
# Number of scalar parameters in the scripted module (element counts summed over all parameters).
sum(param.nelement() for param in sm.parameters())

200

In [None]:
# The archive was saved with a relative path; read it relative to the working directory
# rather than the Colab-only '/content/' prefix.
file_size = os.path.getsize('traced_resnet_modelv2.pt')
print("File Size is :", file_size, "bytes")

File Size is : 2970 bytes


In [None]:
# Save the scripted module's weights as a plain state_dict.
# Note: despite the filename, no torch.jit.freeze is applied here. `sm` is already a
# ScriptModule, which torch.jit.script hands back unchanged, so `m` is the same object as `sm`.
print("save model...")
m = torch.jit.script(sm)
m.eval()
torch.save(m.state_dict(), 'freeze_model.pt')

save model...


In [None]:
import torch
# Reload the weights saved above into the scripted module (`m` and `sm` are the same object).
m.load_state_dict(torch.load("freeze_model.pt", weights_only=True))
m.eval()
# NOTE(review): MyModule contains no nn.Linear submodule, so dynamic quantization has
# presumably nothing to replace; the weight stays float32. Any size drop versus
# traced_resnet_modelv2.pt comes from saving a bare state_dict instead of a full
# TorchScript archive (code + weights), not from quantization. Eager-mode
# quantize_dynamic is also intended for nn.Module, not ScriptModule — confirm intent.
quantized_model = torch.ao.quantization.quantize_dynamic(m, {torch.nn.Linear}, dtype=torch.qint8)
# NOTE(review): this overwrites the ResNet-18 quantized_modelv2.pt written earlier.
torch.save(quantized_model.state_dict(), 'quantized_modelv2.pt')

In [None]:
# The file was saved with a relative path; read it relative to the working directory
# rather than the Colab-only '/content/' prefix.
file_size = os.path.getsize('quantized_modelv2.pt')
print("File Size is :", file_size, "bytes")

File Size is : 1609 bytes
