In [2]:
import onnxruntime as onnxrun
import onnx
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as trans
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import os
import pickle
# Path of the ONNX network analysed by this notebook; the same path is reused
# further down to open a session and to derive the pickle file name.
ONNX_MODEL_PATH = 'vnncomp2021/benchmarks/mnistfc/mnist-net_256x2.onnx'
MODEL: onnx.ModelProto = onnx.load(ONNX_MODEL_PATH)

# Dump the model graph as a Graphviz "dot" file (top-to-bottom rank direction)
# into the current working directory for visual inspection.
pydot_graph = GetPydotGraph(
    MODEL.graph,
    name=MODEL.graph.name,
    rankdir="TB",
    node_producer=GetOpNodeProducer("docstring"),
)
pydot_graph.write_dot("graph.dot")

# Expose every intermediate tensor as an additional graph output, so that a
# single inference run returns the value produced by each node and not only
# the model's declared output(s).
SESS = onnxrun.InferenceSession(ONNX_MODEL_PATH)
original_output = [x.name for x in SESS.get_outputs()]
print(original_output)

# Names that are already graph outputs; used to avoid registering a tensor twice.
exposed_outputs = set(original_output)
for node in MODEL.graph.node:
    for output in node.output:
        # ONNX marks an omitted optional output with an empty string. Such a
        # placeholder is not a real tensor, and adding it as a graph output
        # would make the model invalid, so skip it (and any duplicate name).
        if not output or output in exposed_outputs:
            continue
        exposed_outputs.add(output)
        print(output)
        MODEL.graph.output.extend([onnx.ValueInfoProto(name=output)])

# Inference session built from the in-memory MODEL (serialized to bytes), so
# any outputs appended to MODEL.graph.output are part of this session.
inf_session = onnxrun.InferenceSession(MODEL.SerializeToString())

# Names of all session outputs and inputs, in the order the session reports them.
outputs = [out_meta.name for out_meta in inf_session.get_outputs()]
inputs = [in_meta.name for in_meta in inf_session.get_inputs()]
print(inputs)

# Only the first input is fed during inference.
input_name = inputs[0]
print(input_name)

# Load the MNIST training split from the local 'datasets/MNIST/' directory.
# download=False: the files are presumably expected to exist there already.
trns_norm = trans.ToTensor()
mnist_train = datasets.MNIST(
    root='datasets/MNIST/',
    train=True,
    download=False,
    transform=trns_norm,
)

# Only the first batch of the loader is consumed, i.e. at most 50000 samples.
train_loader = DataLoader(dataset=mnist_train, batch_size=50000)
images, labels = next(iter(train_loader))

# Dict that is pickled later: maps each session output name to the list of
# values that output produced, one entry per sample, in sample order.
immediate_values = {name: [] for name in outputs}

correct = 0
num_samples = images.shape[0]
for image, label in zip(images, labels):
    # The model is fed one sample at a time, reshaped to (1, 784, 1).
    result = inf_session.run(outputs, {input_name: image.numpy().reshape(1, 784, 1)})

    # Save the final output and every intermediate value for this sample;
    # `result` is ordered like `outputs`.
    for name, value in zip(outputs, result):
        immediate_values[name].append(value.squeeze())

    # result[0] belongs to outputs[0]; this is taken to be the model's original
    # output, which is only true if the session lists outputs in graph order
    # (NOTE(review): confirm for other models).
    pred = np.argmax(result[0])
    # Compare against a plain Python int rather than a 0-d torch tensor.
    if pred == label.item():
        correct += 1

# Report the accuracy; previously the counter was computed but never shown.
if num_samples:
    print(f"accuracy: {correct}/{num_samples} = {correct / num_samples:.4f}")
    

['12']
7
8
9
10
11
['0']
0


In [3]:
# Convert each per-sample list into a single numpy array (first axis indexes
# the samples) and print the resulting shape for a quick sanity check.
# Only values of existing keys are replaced, so iterating the dict is safe.
for name, per_sample_values in immediate_values.items():
    stacked = np.array(per_sample_values)
    immediate_values[name] = stacked
    print(stacked.shape)

(50000, 10)
(50000, 784)
(50000, 256)
(50000, 256)
(50000, 256)
(50000, 256)


In [None]:
# Persist the collected values next to the notebook (current working
# directory); the file is named after the model file, without its directory.
pickle_path = os.path.basename(ONNX_MODEL_PATH) + ".immediate_values.pickle"
with open(pickle_path, 'wb') as pickle_file:
    pickle.dump(immediate_values, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)