In [23]:
import torch
from tvm import relay
import tvm
from collections import namedtuple



In [24]:
Workload = namedtuple(
    "Conv2DWorkload",
    [
        "batch",
        "height",
        "width",
        "in_filter",
        "out_filter",
        "hkernel",
        "wkernel",
        "hpad",
        "wpad",
        "hstride",
        "wstride",
    ],
)

In [25]:
import re
channels_re = re.compile('.*Tensor\[\(([\d]+), ([\d]+), [\d]+, [\d]+\).*padding.*Tensor\[\([\d]+, [\d]+, ([\d]+), ([\d]+)\).*')
cast_re = re.compile('cast.*Tensor\[\([\d]+, [\d]+, ([\d]+), ([\d]+)\).*')

In [26]:
final_final = []

In [35]:
def extract_wkls(model):
    count=0
    pytorch_model = torch.hub.load('pytorch/vision', model, pretrained=True)
    pytorch_model.eval()
    workloads = []
    for layer in pytorch_model.modules():
        if type(layer) == torch.nn.modules.conv.Conv2d:
            if(layer.in_channels % 16 == 0 and layer.out_channels % 16 ==0 and layer.padding[0] == layer.padding[1]):
                workloads.append(Workload(1, 0, 0, layer.in_channels, layer.out_channels,
                                  layer.kernel_size[0], layer.kernel_size[1], layer.padding[0], layer.padding[1]
                                 , layer.stride[0], layer.stride[1]))
    input_shape = [1, 3, 299, 299]
    input_data = torch.randn(input_shape)
    scripted_model = torch.jit.trace(pytorch_model, input_data).eval()
    shape_list = [("input0", input_shape)]
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
    with tvm.transform.PassContext(opt_level=3):
        with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
             mod = relay.quantize.quantize(mod, params=params)
                
    mod_as_string = mod.astext(show_meta_data=False)
    cast_line = ""
    cast_line_idx = -1
    final_workloads = []
    for i, line in enumerate(mod_as_string.split('\n')):
        if "cast" in line and "int8" in line:
            cast_line = line
            cast_line_idx = i
        elif "conv2d" in line and "int8" in line:
            match = re.search(channels_re, line)
            if match:
                if int(match.group(1)) % 16 == 0 and int(match.group(2)) % 16 == 0:
                    match_cast = re.search(cast_re, cast_line)
                    if match_cast:
                        wkl = workloads[count]
                        final_workloads.append(
                        'Workload({}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {})'.format(
                        1, match_cast.group(1), match_cast.group(2), wkl.in_filter, wkl.out_filter,
                        wkl.hkernel,wkl.wkernel, wkl.hpad, wkl.wpad, wkl.hstride, wkl.wstride))
                        count += 1
    
    return final_workloads

In [36]:
final_final.extend(extract_wkls("inception_v3"))

Using cache found in /home/srchand/.cache/torch/hub/pytorch_vision_main


IndexError: list index out of range

In [33]:
len(final_final)

93

In [34]:
final_final = list(dict.fromkeys(final_final))
for i, wkl in enumerate(final_final):
    print('(\'workloads_{}\', {}),'.format(i, wkl))

('workloads_0', Workload(1, 149, 149, 32, 32, 3, 3, 0, 0, 1, 1)),
('workloads_1', Workload(1, 147, 147, 32, 64, 3, 3, 1, 1, 1, 1)),
('workloads_2', Workload(1, 73, 73, 64, 80, 1, 1, 0, 0, 1, 1)),
('workloads_3', Workload(1, 73, 73, 80, 192, 3, 3, 0, 0, 1, 1)),
('workloads_4', Workload(1, 35, 35, 192, 64, 1, 1, 0, 0, 1, 1)),
('workloads_5', Workload(1, 35, 35, 192, 48, 1, 1, 0, 0, 1, 1)),
('workloads_6', Workload(1, 35, 35, 48, 64, 5, 5, 2, 2, 1, 1)),
('workloads_7', Workload(1, 35, 35, 64, 96, 3, 3, 1, 1, 1, 1)),
('workloads_8', Workload(1, 35, 35, 96, 96, 3, 3, 1, 1, 1, 1)),
('workloads_9', Workload(1, 35, 35, 192, 32, 1, 1, 0, 0, 1, 1)),
('workloads_10', Workload(1, 35, 35, 256, 64, 1, 1, 0, 0, 1, 1)),
('workloads_11', Workload(1, 35, 35, 256, 48, 1, 1, 0, 0, 1, 1)),
('workloads_12', Workload(1, 35, 35, 288, 64, 1, 1, 0, 0, 1, 1)),
('workloads_13', Workload(1, 35, 35, 288, 48, 1, 1, 0, 0, 1, 1)),
('workloads_14', Workload(1, 35, 35, 288, 384, 3, 3, 0, 0, 2, 2)),
('workloads_15', Work