In [102]:
import torch

"""Debug operator"""
class DebugOp(torch.autograd.Function):
    """Pass-through autograd op that marks a tensor as a debug tap.

    In eager mode the op returns its input unchanged. When the model is
    exported to ONNX, ``symbolic`` emits a custom ``my::Debug`` node that
    carries ``name`` as an attribute, so the tap can be located in the graph.
    """

    @staticmethod
    def forward(ctx, x, name):
        """Return ``x`` as-is; ``name`` is only consumed by ``symbolic``."""
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the incoming gradient straight through.

        One value is returned per ``forward`` argument: the gradient for ``x``
        and ``None`` for the non-tensor ``name``. Without this method any
        backward pass through an instrumented model raises
        ``NotImplementedError``.
        """
        return grad_output, None

    @staticmethod
    def symbolic(g, x, name):
        """Emit the ``my::Debug`` node for the ONNX graph.

        The ``_s`` suffix on the keyword stores ``name`` as a string attribute.
        """
        return g.op("my::Debug", x, name_s=name)


# Convenience alias: call as ``debug_apply(tensor, "tap_name")``.
debug_apply = DebugOp.apply

In [103]:
import onnx
import onnxruntime

"""Debugger"""
class Debugger():
    """Compare intermediate tensors between PyTorch and an exported ONNX model.

    Intended call order: ``debug`` (inside the model's forward, while it is
    being exported) -> ``extract_debug_model`` -> ``run_debug_model`` ->
    ``print_debug_result``.
    """

    def __init__(self) -> None:
        # tap name -> numpy snapshot recorded on the PyTorch side by ``debug``
        self.torch_value = dict()
        # tap name -> array produced by ONNX Runtime for the same tap
        self.onnx_value = dict()
        # tap names, in the same order as the extracted model's outputs
        self.output_debug_name = []

    def debug(self, x, name):
        """Record ``x`` under ``name`` and return it wrapped in a debug op.

        Args:
            x: tensor to tap.
            name: key used for the snapshot and for the ONNX node attribute.

        Returns:
            The result of ``debug_apply(x, name)``.
        """
        # ``.numpy()`` shares memory with a CPU tensor, so take a copy;
        # otherwise a later in-place op on the tensor would silently change
        # the stored reference value.
        self.torch_value[name] = x.detach().cpu().numpy().copy()
        return debug_apply(x, name)

    def extract_debug_model(self, input_path, output_path):
        """Extract debug nodes from onnx model and save to new model.

        Every node with op type ``Debug`` is rewritten to a standard
        ``Identity`` node and its output becomes an output of the saved model.
        The original graph outputs are not kept.

        Args:
            input_path: path of the exported ONNX model to read.
            output_path: path the extracted model is written to.
        """
        model = onnx.load(input_path)
        input_names = [graph_input.name for graph_input in model.graph.input]

        debug_names = []
        output_names = []
        for node in model.graph.node:
            if node.op_type != "Debug":
                continue
            # The tap name is read from the node's first attribute, as bytes.
            debug_names.append(node.attribute[0].s.decode("utf-8"))
            output_names.append(node.output[0])

            # Rewrite the custom op into Identity in the default domain and
            # drop its attributes.
            node.op_type = "Identity"
            node.domain = ""
            node.ClearField("attribute")

        # Rebuild the list on every call instead of appending: a repeated
        # extraction must not duplicate names, or the name/output pairing in
        # ``run_debug_model`` would be misaligned.
        self.output_debug_name = debug_names

        extractor = onnx.utils.Extractor(model)
        extracted_model = extractor.extract_model(input_names, output_names)
        onnx.save(extracted_model, output_path)

    def run_debug_model(self, input, debug_model):
        """Run the extracted model on CPU and store one output per debug name.

        Args:
            input: feed dict passed to ``InferenceSession.run``.
            debug_model: model argument passed to ``InferenceSession``.
        """
        sess = onnxruntime.InferenceSession(debug_model, providers=['CPUExecutionProvider'])
        onnx_outputs = sess.run(None, input)

        # Outputs are paired with names by position.
        for name, value in zip(self.output_debug_name, onnx_outputs):
            self.onnx_value[name] = value

    def print_debug_result(self):
        """Print the mean squared error between both values for each debug node."""
        for name in self.output_debug_name:
            print(f"Debug node name: {name}")
            print(f"MSE: {((self.torch_value[name] - self.onnx_value[name])**2).mean()}")
            print("\n")


In [104]:
class Model(torch.nn.Module):
    """Toy network: nine 3x3 convolutions grouped into three sequential stages."""

    @staticmethod
    def _conv_stack(*channels):
        """Chain 3x3 convolutions through the given channel sizes, in order."""
        layers = [
            torch.nn.Conv2d(c_in, c_out, 3)
            for c_in, c_out in zip(channels, channels[1:])
        ]
        return torch.nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Stages are created in the same order as they are applied.
        self.convs1 = self._conv_stack(3, 16, 16, 16)
        self.convs2 = self._conv_stack(16, 16, 16, 16)
        self.convs3 = self._conv_stack(16, 16, 16, 3)

    def forward(self, x):
        """Apply the three stages one after another."""
        return self.convs3(self.convs2(self.convs1(x)))

In [105]:
# Build the toy network and the shared Debugger instance; the patched forward
# defined further down looks `debugger` up as a global.
model = Model()
debugger = Debugger()

In [106]:
from types import MethodType

def new_forward(self, x):
    """Forward pass with debug taps after the first and the second conv stage.

    Reports to the module-level ``debugger`` under the names "x_1" and "x_2".
    """
    x = debugger.debug(self.convs1(x), "x_1")
    x = debugger.debug(self.convs2(x), "x_2")
    return self.convs3(x)

# Bind new_forward to this one instance; the Model class itself is left untouched.
model.forward = MethodType(new_forward, model)

In [107]:
# Random input (batch of 1, 3 channels, 224x224); the same tensor is fed to
# ONNX Runtime later on so both sides see identical data.
dummy_input = torch.randn(1, 3, 224, 224)
# The export runs the patched forward, so debugger.debug is called during it.
# NOTE: the warning fragment printed under this cell points at the .numpy()
# call inside Debugger.debug.
torch.onnx.export(model, dummy_input, "before_debug.onnx", 
                  input_names=["input"], output_names=["output"])

  self.torch_value[name] = x.detach().cpu().numpy()


In [108]:
# Turn the Debug nodes of the exported graph into Identity outputs and save the result as a new model.
debugger.extract_debug_model("before_debug.onnx", "after_debug.onnx")

In [109]:
# Run the extracted model on the same input that was passed to the export.
debugger.run_debug_model({"input": dummy_input.numpy()}, "after_debug.onnx")

In [110]:
# Report the mean squared error between the PyTorch and ONNX values for every debug node.
debugger.print_debug_result()

Debug node name: x_1
MSE: 5.233316408128968e-15


Debug node name: x_2
MSE: 7.869065781912498e-16


