In [1]:
import torch
import torchvision.models as models
import onnx

from torchvision.io import read_image

In [2]:
def fetch_model(name):
    """Load a pretrained torchvision model in eval mode together with its weights enum.

    Args:
        name: model identifier; only ``"resnet18"`` is currently supported.

    Returns:
        ``(model, weights)``: the model already switched to eval mode, and the
        torchvision weights enum it was built from (its ``transforms()`` gives
        the matching inference preprocessing).

    Raises:
        ValueError: if ``name`` is not a supported model.
    """
    if name != "resnet18":
        # Previously this fell through and returned None, which only surfaced
        # later as a confusing "cannot unpack NoneType" error at the call site.
        raise ValueError(f"unsupported model name: {name!r}")
    default_weights = models.ResNet18_Weights.DEFAULT
    model = models.resnet18(weights=default_weights)
    model.eval()
    return model, default_weights

# Pretrained ResNet-18 in eval mode, plus the weights enum whose transforms() preprocess its inputs.
model, weights = fetch_model(name="resnet18")

In [3]:
def save_to_jit(model, name="resnet18_torchscript.pt"):
    """Compile ``model`` with ``torch.jit.script`` and save the TorchScript module to ``name``."""
    torch.jit.script(model).save(name)

    
def save_to_pt(model, name="resnet18.pt"):
    """Save only ``model.state_dict()`` (parameters and buffers, not the module object) to ``name``."""
    parameters_and_buffers = model.state_dict()
    torch.save(parameters_and_buffers, name)

    
def save_to_onnx(model, weights, name="resnet18.onnx",
                 image_path="/Users/zjy/Downloads/511px-Grace_Hopper.jpg"):
    """Export ``model`` to ONNX at ``name``, using one preprocessed image as the example input.

    Args:
        model: torch module to export; switched to eval mode in place.
        weights: torchvision weights enum whose ``transforms()`` preprocess the image.
        name: output ``.onnx`` path.
        image_path: image fed through the model during export. The default is the
            original author's local file; pass a path that exists on your machine.
    """
    model.eval()

    preprocess = weights.transforms()
    img = read_image(image_path)
    # unsqueeze(0) adds a leading batch dimension of size 1.
    batch = preprocess(img).unsqueeze(0)

    torch.onnx.export(model, batch, name, verbose=True)


In [4]:
# Saves every model tensor to disk as its raw ONNX bytes (`raw_data`, which the
# ONNX spec stores little-endian), plus a CSV of tensor names and shapes.

import csv, re, uuid
import os.path
from itertools import chain
from onnx import TensorProto

def _is_valid_filename(filename):  # type: (Text) -> bool
    """Utility to check whether the provided filename is valid."""
    exp = re.compile("^[^<>:;,?\"*|/]+$")
    match = exp.match(filename)
    if match:
        return True
    else:
        return False


def _sanitize_tensor_name(name):
    return "Tensor_" + name.removeprefix("onnx::")

    
def _get_initializer_tensors(onnx_model_proto):  # type: (ModelProto) -> Iterable[TensorProto]
    """Create an iterator of initializer tensors from ONNX model."""
    for initializer in onnx_model_proto.graph.initializer:
        yield initializer


def _get_attribute_tensors(onnx_model_proto):  # type: (ModelProto) -> Iterable[TensorProto]
    """Create an iterator of tensors from node attributes of an ONNX model."""
    for node in onnx_model_proto.graph.node:
        for attribute in node.attribute:
            if attribute.HasField("t"):
                yield _massage_tensor_name(node, attribute.t)
            for tensor in attribute.tensors:
                yield _massage_tensor_name(node, tensor)

                
def _get_all_tensors(onnx_model_proto):  # type: (ModelProto) -> Iterable[TensorProto]
    """Lazily chain the model's initializer tensors followed by its node-attribute tensors."""
    initializers = _get_initializer_tensors(onnx_model_proto)
    attribute_tensors = _get_attribute_tensors(onnx_model_proto)
    return chain(initializers, attribute_tensors)

def _dims_to_str(dims):
    return ",".join(str(dim) for dim in dims)
    
    
def dump_all_tensors_to_disk(onnx_model, location, metadata_file="tensor_metadata.csv"):
    """Write every tensor of an ONNX model to disk plus a CSV index of names and shapes.

    Each tensor's ``raw_data`` bytes go to ``<location>/tensors/<sanitized name>``;
    a CSV with columns ``name`` and ``dims`` (comma-joined) goes to
    ``<location>/metadata/<metadata_file>``. The two subdirectories are created
    if missing.

    Args:
        onnx_model: the loaded ONNX model proto.
        location: existing output root directory.
        metadata_file: file name of the CSV index inside ``metadata/``.

    Raises:
        ValueError: if a non-empty tensor has no ``raw_data`` (its values live in
            the typed data fields, which this function does not serialize).
    """
    import math

    assert os.path.isdir(location), location
    tensor_dir = os.path.join(location, "tensors")
    metadata_dir = os.path.join(location, "metadata")
    os.makedirs(tensor_dir, exist_ok=True)
    os.makedirs(metadata_dir, exist_ok=True)

    metadata = []
    written = set()
    for tensor in _get_all_tensors(onnx_model):
        tensor_name = _sanitize_tensor_name(tensor.name)
        assert _is_valid_filename(tensor_name), tensor_name
        # Two tensors mapping to the same file name would silently overwrite each other.
        assert tensor_name not in written, f"duplicate tensor file name: {tensor_name}"
        written.add(tensor_name)
        # Only raw_data is written; without this check a non-empty tensor stored in
        # the typed fields (float_data, int64_data, ...) would produce an empty file.
        if not tensor.raw_data and math.prod(tensor.dims) > 0:
            raise ValueError(
                f"tensor {tensor.name!r} has no raw_data; typed-field storage is not supported"
            )
        metadata.append({"name": tensor_name, "dims": _dims_to_str(tensor.dims)})
        with open(os.path.join(tensor_dir, tensor_name), "wb") as f:
            f.write(tensor.raw_data)

    # write tensor dimensions to disk; fixed fieldnames also work when there are no tensors.
    # newline="" lets the csv module control line endings (no blank rows on Windows).
    with open(os.path.join(metadata_dir, metadata_file), "w", newline="") as f:
        dict_writer = csv.DictWriter(f, ["name", "dims"])
        dict_writer.writeheader()
        dict_writer.writerows(metadata)
        

In [18]:
# Collects and dumps metadata about the computation graph

from onnx import AttributeProto

def collapse_items(items):
    """Return the one value that every element of ``items`` shares.

    Fails an assertion (reporting ``items``) when ``items`` is empty or holds
    more than one distinct value.
    """
    distinct_values = set(items)
    assert len(distinct_values) == 1, items
    return items[0]
    

def collect_conv_attributes(conv_node):
    """Return the uniform ``pads`` and ``strides`` of a Conv node as ``{"pads": int, "strides": int}``.

    Each attribute is collapsed to one int via ``collapse_items``, so a Conv whose
    pads or strides differ per side/axis fails an assertion. Attributes the
    exporter omitted fall back to the ONNX Conv defaults (pads 0, strides 1), so
    every Conv row carries both keys.
    """
    assert conv_node.op_type == "Conv", conv_node
    attributes_dict = {"pads": 0, "strides": 1}
    for attr in conv_node.attribute:
        if attr.name in ("pads", "strides"):
            assert attr.type == AttributeProto.INTS
            attributes_dict[attr.name] = collapse_items(attr.ints)
        elif attr.name == "auto_pad":
            # SAME_UPPER / SAME_LOWER compute padding from the input shape, so the
            # default of 0 above would be wrong for them.
            assert attr.s in (b"NOTSET", b"VALID"), attr.s
    return attributes_dict
    
    
def collect_tensor_names(node):
    """Return ``{"W": ..., "B": ...}`` sanitized weight/bias tensor names for a Conv or Gemm node.

    Any other op type yields ``{}``. The node must have exactly three inputs
    (data, weight, bias). The weight and bias names must carry the prefix this
    exporter produced for that op ("onnx::" for Conv, "fc" for Gemm); otherwise
    an assertion fails and reports ``node.input``.
    """
    expected_prefix = {"Conv": "onnx::", "Gemm": "fc"}.get(node.op_type)
    if expected_prefix is None:
        return {}
    _, weight_name, bias_name = node.input
    assert weight_name.startswith(expected_prefix) and bias_name.startswith(expected_prefix), node.input
    return {"W": _sanitize_tensor_name(weight_name), "B": _sanitize_tensor_name(bias_name)}
    

def collect_graph_info(onnx_model_proto):
    """Return one dict per graph node, in graph order, describing its key attributes.

    Every dict has ``name``, ``type``, ``input`` and ``output``. Conv nodes add
    ``pads``/``strides``, and Conv/Gemm nodes add ``W``/``B`` tensor names.
    NOTE(review): only ``node.input[0]`` is recorded, so for multi-input nodes
    such as Add the other operands are not captured.
    """
    def describe(node):
        row = {
            "name": node.name,
            "type": node.op_type,
            "input": _sanitize_tensor_name(node.input[0]),
            # collapse_items asserts that every output of the node is the same tensor.
            "output": _sanitize_tensor_name(collapse_items(node.output)),
        }
        if node.op_type == "Conv":
            row.update(collect_conv_attributes(node))
        row.update(collect_tensor_names(node))
        return row

    return [describe(node) for node in onnx_model_proto.graph.node]


def write_graph_info(graph_info, model_dump_location, metadata_name="graph_metadata.csv"):
    """Write per-node graph metadata rows as CSV to ``<model_dump_location>/metadata/<metadata_name>``.

    The header is the union of all row keys in first-seen order. Rows of different
    op types (e.g. Conv with pads/strides/W/B, Relu with none of those) therefore
    share one header, and any key a row lacks is written as an empty cell. The
    ``metadata`` directory is created if missing.

    Raises:
        ValueError: if ``graph_info`` is empty. This is checked before the output
            file is opened, so an existing file is not truncated.
    """
    if not graph_info:
        raise ValueError("graph_info is empty; nothing to write")
    # dict.fromkeys de-duplicates while keeping first-seen order.
    fieldnames = list(dict.fromkeys(key for row in graph_info for key in row))
    metadata_dir = os.path.join(model_dump_location, "metadata")
    os.makedirs(metadata_dir, exist_ok=True)
    # newline="" lets the csv module control line endings (no blank rows on Windows).
    with open(os.path.join(metadata_dir, metadata_name), "w", newline="") as f:
        dict_writer = csv.DictWriter(f, fieldnames)
        dict_writer.writeheader()
        dict_writer.writerows(graph_info)



In [19]:
# Preview the first few graph-metadata rows. The model is loaded here so this cell
# does not depend on `onnx_model_proto` from a later (or deleted) cell.
onnx_model_proto = onnx.load(os.path.join("model_dump/resnet18", "original_onnx/production_resnet18.onnx"))
collect_graph_info(onnx_model_proto)[:8]

[{'name': 'Conv_0',
  'type': 'Conv',
  'input': 'Tensor_input.1',
  'output': 'Tensor_input.4',
  'pads': 3,
  'strides': 2,
  'W': 'Tensor_Conv_193',
  'B': 'Tensor_Conv_194'},
 {'name': 'Relu_1',
  'type': 'Relu',
  'input': 'Tensor_input.4',
  'output': 'Tensor_MaxPool_125'},
 {'name': 'MaxPool_2',
  'type': 'MaxPool',
  'input': 'Tensor_MaxPool_125',
  'output': 'Tensor_input.8'},
 {'name': 'Conv_3',
  'type': 'Conv',
  'input': 'Tensor_input.8',
  'output': 'Tensor_input.16',
  'pads': 1,
  'strides': 1,
  'W': 'Tensor_Conv_196',
  'B': 'Tensor_Conv_197'},
 {'name': 'Relu_4',
  'type': 'Relu',
  'input': 'Tensor_input.16',
  'output': 'Tensor_Conv_129'},
 {'name': 'Conv_5',
  'type': 'Conv',
  'input': 'Tensor_Conv_129',
  'output': 'Tensor_Add_198',
  'pads': 1,
  'strides': 1,
  'W': 'Tensor_Conv_199',
  'B': 'Tensor_Conv_200'},
 {'name': 'Add_6',
  'type': 'Add',
  'input': 'Tensor_Add_198',
  'output': 'Tensor_Relu_132'},
 {'name': 'Relu_7',
  'type': 'Relu',
  'input': 'Tens

In [20]:
MODEL_DUMP_DIR = "model_dump/resnet18"
onnx_model_proto = onnx.load(os.path.join(MODEL_DUMP_DIR, "original_onnx/production_resnet18.onnx"))
dump_all_tensors_to_disk(onnx_model_proto, MODEL_DUMP_DIR)
graph_info = collect_graph_info(onnx_model_proto)
write_graph_info(graph_info, MODEL_DUMP_DIR)

# List the dump directory in Python. The previous `!{ls model_dump/resnet18}` put the
# shell command inside `{}`, which IPython treats as a Python expression to interpolate.
sorted(os.listdir(MODEL_DUMP_DIR))


README.txt    metadata      original_onnx tensors
