In [1]:
import torch
import torch.cuda.nvtx as nvtx
import inspect
from inspect import currentframe, getargvalues, getfullargspec, getmembers, isfunction
import types
import re


In [2]:
class NvtxPatcher:
    
    registry = set()
    nvtx_handle = nvtx._libnvToolsExt()
    
    @staticmethod
    def nvtx_monkey_patch(func):
        def wrapper(*args, **kwargs):
            frame = currentframe()
            v = getargvalues(frame)
            argspec = getfullargspec(func)
            formal_arg_names = argspec.args
            s = "{'op':%s," % v.locals["func"].__name__
            for idx, val in enumerate(v.locals["args"]):
                name = formal_arg_names[idx]
                if isinstance(val, torch.Tensor):
                    name += "_shape"
                    val = tuple(val.size())
                s += "'%s':%s," % (name, str(val))
            num_def=len(argspec.defaults)
            defaults = dict(zip(argspec.args[-num_def:], argspec.defaults))
            overrides={k:str(v) for k, v in v.locals["kwargs"].items()}
            defaults.update(overrides)
            s += "%s}" % str(defaults).strip("{}")
            nvtx.range_push(s)
            result = func(*args, **kwargs)
            nvtx.range_pop()
            return result
        return wrapper

    @classmethod
    def list_non_builtins(cls, module, regex_filt_lst=None, log=True):
        if not isinstance(regex_filt_lst, list) and regex_filt_lst is not None:
            regex_filt_lst = list(regex_filt_lst)
        if isinstance(module, str):
            module = eval(module)
        name_list = dir(module)
        builtin_funcs_methods = [_a for _a in name_list if
                                 (isinstance(getattr(module, _a), types.BuiltinFunctionType) or
                                  isinstance(getattr(module, _a), types.BuiltinMethodType))]
        match_any = lambda txt:  any((map(lambda x: re.match(r"%s" % x, txt), regex_filt_lst)))
        if regex_filt_lst is not None:
            function_list = [_x for _x in builtin_funcs_methods if match_any(_x)]
        else: 
            function_list = [_x for _x in builtin_funcs_methods]
        return function_list 
                                 
    @classmethod
    def register_non_builtins(cls, module, regex_filt_lst=None, log=True):
        if not isinstance(regex_filt_lst, list) and regex_filt_lst is not None:
            regex_filt_lst = list(regex_filt_lst)
        if isinstance(module, str):
            module = eval(module)
        name_list = dir(module)
        non_builtin_funcs = [_a for _a in name_list if
                     isinstance(getattr(module, _a), types.FunctionType)]
        
        match_any = lambda txt:  any((map(lambda x: re.match(r"%s" % x, txt), regex_filt_lst)))
        if regex_filt_lst is not None:
            function_list = [_x for _x in non_builtin_funcs if match_any(_x)]
        else: 
            function_list = [_x for _x in mod_funcs]
            
        for name in function_list:
            if name in cls.registry:
                continue
            fqn = "{}.{}".format(module.__name__, name)
            temp = eval(fqn)
            patched = NvtxPatcher.nvtx_monkey_patch(temp)
            cls.registry.add(fqn)
            exec("{}=patched".format(fqn))
            
    @classmethod        
    # convNd is a built-in, so can't be registered using the non-builtin approach above
    def patch_conv(cls, dim_count, module=torch.nn.functional):
        fun_name = "{}.conv{}d".format(module.__name__, str(dim_count))
        # Function already patched
        if fun_name in cls.registry:
            return
        temp = eval(fun_name)
        def decorator(fun):
            def wrapper(data, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
                print("### {} ###".format(data))
                input_size = tuple(data.size())
                weight_size = tuple(weight.size())
                # Interpolate numbers as strings because some can be one-elem tuples as well
                nvtx_str = "{op:'conv%sd', input:%s, weight:%s, stride:%s, padding:%s, dilation:%s, groups:%s}" % (dim_count,input_size, weight_size, str(stride), str(padding), str(dilation), str(groups))
                nvtx.range_push(nvtx_str)
                op = fun(data, weight, bias, stride, padding, dilation, groups)
                nvtx.range_pop()
                return op
            return wrapper
        patched = decorator(temp)
        exec("{}=patched".format(fun_name))
        return patched
              
    @classmethod
    def print_registered_functions(cls):
              print("Functions registered for NVTX range annotation:\n{}\n".format(str(cls.registry)))

    
patterns = ["conv[1-3]?(d|(\_transpose[1-3]d))",
     "(un)?fold",
     "(avg|max)_pool",
     "max_unpool[1-3]d",
     "lp_pool[1-3]d",
     "adaptive_(avg|max)_pool[1-3]d",
     "threshold",
     "(leaky_)?[p-s]?r?elu_?6?",
     "(hard)?tanh",
     "glu",
     "(log)?sigmoid",
     "(hard|soft|tanh)shrink",
     "soft(sign|plus|min)",
     "(gumbel_|log_)?softmax",
     "(batch|layer|instance|local_response)_norm",
     "normalize",
     "(bi)?linear",
     "(alpha_)?dropout([2-3]d)?",
     "embedding(_bag)?",
     "pairwise_distance",
     "cosine_similarity",
     "(binary_)?cross_entropy",
     "(poisson_)?nll_loss",
     "(cosine|hinge)_embedding_loss",
     "kl_div",
     "((smooth_)?l1|mse)_loss",
     "(multilabel|multi)?_margin_(soft_?)(ranking)?_loss",
     "(soft|triplet)_margin_loss",
     "pad",
     "pixel_shuffle",
     "interpolate",
     "upsample_?(bilinear|nearest)?",
     "(affine_)?grid(_sample)?"]

NvtxPatcher.register_non_builtins(
    torch.nn.functional, patterns)
                    
for i in range(1, 4):
    NvtxPatcher.patch_conv(i)               

print("built-ins (manual monkey-patching required):")
print(NvtxPatcher.list_non_builtins(torch.nn.functional, patterns))
                    
NvtxPatcher.print_registered_functions()

built-ins (manual monkey-patching required):
['adaptive_avg_pool1d', 'avg_pool1d', 'avg_pool2d', 'avg_pool3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d', 'elu_', 'hardtanh_', 'leaky_relu_', 'logsigmoid', 'prelu', 'relu_', 'rrelu_', 'selu_', 'softplus', 'softshrink', 'threshold_']
Functions registered for NVTX range annotation:
{'torch.nn.functional.instance_norm', 'torch.nn.functional.hardtanh', 'torch.nn.functional.max_pool3d', 'torch.nn.functional.grid_sample', 'torch.nn.functional.adaptive_avg_pool3d', 'torch.nn.functional.binary_cross_entropy', 'torch.nn.functional.cosine_similarity', 'torch.nn.functional.smooth_l1_loss', 'torch.nn.functional.gumbel_softmax', 'torch.nn.functional.nll_loss', 'torch.nn.functional.softmax', 'torch.nn.functional.adaptive_avg_pool2d', 'torch.nn.functional.lp_pool2d', 'torch.nn.functional.adaptive_max_pool3d', 'torch.nn.functional.fold', 'torch.nn.functional.normalize', 'torch.nn.functional.pad', 'torch.nn.functional.max_unpool3d', 'torc

In [3]:
a = torch.randn(2, 3, 5, 5)
b = torch.randn(4, 3, 5, 5)
c = torch.nn.functional.conv2d(a, b)

### tensor([[[[ 0.6323,  0.5867, -0.6936,  0.6546, -0.8506],
          [ 0.2374, -0.0905,  0.6267,  1.8584, -0.3787],
          [ 0.1963,  1.3841,  1.2110,  1.0826,  0.5698],
          [ 0.0248, -1.3243,  2.0463, -0.3436, -0.0651],
          [-1.0432,  0.2888,  1.8106,  0.4096,  0.3805]],

         [[ 1.8079, -2.4335,  2.3370,  0.1876, -1.8761],
          [ 0.2383, -0.5957, -0.2922, -1.1921,  0.3144],
          [ 0.0701, -0.9783,  0.8120,  0.9708, -0.8644],
          [ 0.8355, -0.9670,  1.5134, -0.7500,  0.8595],
          [-1.2391,  1.6980,  0.3484, -3.1029, -0.8725]],

         [[-0.9943,  0.5533,  0.0658, -0.6878, -2.0404],
          [ 0.4509, -0.2956,  0.8767, -0.8130,  0.4259],
          [-1.6613,  1.4032, -0.6154, -0.6392, -0.3006],
          [-1.9183, -0.8526,  0.4259, -0.5146, -0.2385],
          [ 0.2143,  1.3863,  0.6244,  1.2601,  2.0780]]],


        [[[-0.1714, -0.4178,  3.2297,  1.2704, -0.5574],
          [ 0.2940,  0.1871,  1.2556, -0.4744,  0.8085],
          [ 0.1565,

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In [5]:
net = Net()


output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output

criterion = nn.MSELoss()

loss = criterion(output, target)

# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)

optimizer.zero_grad()   # zero the gradient buffers
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()

### <bound method Kernel.raw_input of <ipykernel.ipkernel.IPythonKernel object at 0x7f63cf76cd30>> ###


AttributeError: 'function' object has no attribute 'size'

In [None]:
x = torch.randn(1, 3, 5, 5)
y = torch.randn(2, 3, 3, 3)
torch.nn.functional.conv2d(x, y)

In [None]:
# conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
# conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
# conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

# conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor
# conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor
# conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

# adaptive_avg_pool1d(input, output_size) -> Tensor

# avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor
# avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor
# avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor

# elu_(input, alpha=1.) -> Tensor
# hardtanh_(input, min_val=-1., max_val=1., inplace=False):
# leaky_relu_(input, negative_slope=0.01) -> Tensor
# logsigmoid(input) -> Tensor
# prelu(input, weight) -> Tensor
# relu_(input) -> Tensor
# rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor
# selu_(input) -> Tensor
# softplus(input, beta=1, threshold=20) -> Tensor
# softshrink(input, lambd=0.5) -> Tensor
# threshold_(input, threshold, value) -> Tensor

In [None]:
a = torch.randn(1, 3, 5, 5).cuda()
b = torch.randn(4, 3, 3, 3).cuda()
c = torch.randn(4, 4)
d = torch.randn(4, 4)
result = torch.nn.functional.conv2d(a, b)

In [None]:
print(result)

In [None]:
# function_list = [elem[0] for elem in inspect.getmembers(F) if inspect.isfunction(o[1])]

match_any = lambda txt:  any((map(lambda x: re.match(r"%s" % x, txt), regex_filt_lst)))
regex_filt_lst = ["conv[0-9]d", "foo"]
print(match_any("conv3d_transpose"))

In [None]:
@nvtx_patch

# Functional list
@nvtx_patch()
def patched_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    # Interpolate numbers as strings because some can be one-elem tuples as well
    nvtx.range_push("{op: 'conv1d', input: %s, weight: %s, stride: %s, padding: %s, dilation: %s, groups:%s}" % 
                    (input_size, weight_size, str(stride), str(padding), str(dilation), str(groups)))
###### TODO
    op = fn(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    # Interpolate numbers as strings because some can be one-elem tuples as well
    nvtx.range_push("{op:'conv2d', input:%s, weight:%s, stride:%s, padding:%s, dilation:%s, groups:%s}" % 
                    (input_size, weight_size, str(stride), str(padding), str(dilation), str(groups)))
###### TODO    
    op = conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv3d():
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    # Interpolate numbers as strings because some can be one-elem tuples as well
    nvtx.range_push("{op:'conv3d', input:%s, weight:%s, stride:%s, padding:%s, dilation:%s, groups:%s}" % 
                    (input_size, weight_size, str(stride), str(padding), str(dilation), str(groups)))
###### TODO 
    op = conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    nvtx.range_push("{op:'conv_transpose1d', input:%s, weight:%s, stride:%s, padding:%s, output_padding: %s, groups:%s, dilation:%s}"
                    % (input_size, weight_size, str(stride), str(padding), str(output_padding), str(groups), str(dilation)))
###### TODO 
    op = conv_transpose1d(input, weight, bias, stride, padding, output_padding, groups, dilation)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    nvtx.range_push("{op:'conv_transpose2d', input:%s, weight:%s, stride:%s, padding:%s, output_padding:%s, groups:%s, dilation:%s}"
                    % (input_size, weight_size, str(stride), str(padding), str(output_padding), str(groups), str(dilation)))
###### TODO 
    op = conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    nvtx.range_push("{op:'conv_transpose3d', input:%s, weight:%s, stride:%s, padding:%s, output_padding: %s, groups:%s, dilation:%s}"
                    % (input_size, weight_size, str(stride), str(padding), str(output_padding), str(groups), str(dilation)))
###### TODO 
    op = conv_transpose3d(input, weight, bias, stride, padding, output_padding, groups, dilation)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_linear(input, weight, bias=None):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    nvtx.range_push("{op: 'linear', input:%s}")
###### TODO 
    op = linear(input, weight, bias)
    nvtx.range_pop()
    return op

@nvtx_patch()
# (input, p=0.5, training=False, inplace=False)
def patched_dropout():
    pass

@nvtx_patch()    
# (input, inplace=False) 
def patched_relu():
    pass

@nvtx_patch()    
# (input)
def patched_relu_():
    pas



In [None]:
import inspect

def func(*args, **kwargs):
    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    print 'function name "%s"' % inspect.getframeinfo(frame)[2]
    for i in args:
        print "    %s = %s" % (i, values[i])
    return [(i, values[i]) for i in args]

func()

In [None]:
import inspect
import torch

global abc


def my_decorator(func):
    def wrapper(*args, **kwargs):
        frame = inspect.currentframe()
        v = inspect.getargvalues(frame)
        argspec = inspect.getfullargspec(func)
        formal_arg_names = argspec.args
        s = "{'op':%s," % v.locals["func"].__name__
        for idx, val in enumerate(v.locals["args"]):
            name = formal_arg_names[idx]
            if isinstance(val, torch.Tensor):
                name += "_shape"
                val = tuple(val.size())
            s += "'%s':%s," % (name, str(val))
        num_def=len(argspec.defaults)
        defaults = dict(zip(argspec.args[-num_def:], argspec.defaults))
        overrides={k:str(v) for k, v in v.locals["kwargs"].items()}
        defaults.update(overrides)
        s += "%s}" % str(defaults).strip("{}")
        print("Pushing NVTX range: %s" % s)
                
        func(*args, **kwargs)
        print("Something is happening after the function is called.")
    return wrapper

import torch 

@my_decorator
def say_whee(foo, bar, baz=42, qux=123):
    print("### In say_whee")
    print("foo = %s" % str(foo))
    print("bar = %s" % str(bar))
    print("baz = %s" % str(baz))
    print("qux = %s" % str(ham))

    print("Whee!")

say_whee(torch.randn(2,2), 1, qux=456)    

In [None]:
abc.locals['kwargs']

In [None]:
abc.locals['args']

In [None]:
abc.locals['kwargs'] 

In [None]:
abc.locals['func'].__name__

In [None]:
foo = "{"
# Need to record arg name from signature
for arg in abc.locals['args']:
    if arg(isinstance(torch.Tensor)):
        dims = tuple(arg.size())

# torch.Tensor
"{name:%s, args:%s, kwargs:%s}" % (func_name, str(args), str(kwargs)) 

In [None]:
import torch.nn.functional as F

help(F.conv2d())

In [None]:
def init():
    if not initialized:
        # do stuff
    initialized = True
    pass

In [None]:
# def patched_hardtanh():
#     pass

# def patched_hardtanh_():
#     pass

# def patched_relu6():
#     pass

# def patched_elu():
#     pass

# def patched_elu_():
#     pass

# def patched_unfold():
#     pass

# def patched_fold():
#     pass

# def patched_avg_pool1d():
#     pass

# def patched_avg_pool2d():
#     pass

# def pathched_avg_pool3d():
#     pass

# def patched_max_pool1d():
#     pass

# def patched_max_pool2d():
#     pass

# def patched_max_pool3d():
#     pass

# def patched_max_unpool1d():
#     pass

# def patched_max_unpool2d():
#     pass

# def patched_max_unpool3d():
#     pass

# def patched_lp_pool1d():
#     pass

# def patched_lp_pool2d():
#     pass

# def patched_adaptive_max_pool1d():
#     pass

# def patched_adaptive_max_pool2d():
#     pass

# def patched_adaptive_max_pool3d():
#     pass

# def patched_adaptive_avg_pool1d():
#     pass

# def patched_adaptive_avg_pool2d():
#     pass

# def patched_adaptive_avg_pool3d():
#     pass

# def patched_threshold():
#     pass

# def patched_threshold_():
#     pass



# leaky_relu
# leaky_relu_
# prelu
# rrelu
# rrelu_
# logsigmoid
# hardshrink
# tanhshrink
# softplus
# softmin
# softmax
# softshrink
# gumbel_softmax
# log_softmax
# tanh
# sigmoid
# batch_norm
# instance_norm
# normalize

# bilinear
# alpha_dropout
# dropout2d
# dropout3d
# pad


# local_response_norm

# embedding
# embedding_bag
# pairwise_distance
# cosine_similarity
# binary_cross_entropy
# poisson_nll_loss
# cosine_embedding_loss
# cross_entropy
# hinge_embedding_loss
# kl_div
# l1_loss
# mse_loss
# margin_ranking_loss
# multilabel_margin_loss
# multilabel_soft_margin_loss
# multi_margin_loss
# nll_loss
# binary_cross_entropy_with_logits
# smooth_l1_loss
# soft_margin_loss
# triplet_margin_loss
# pixel_shuffle
# grid_sample
# affine_grid
# data_parallel
# calculate_gain
# uniform_
# normal_
# constant_
# eye_
# dirac_
# xavier_uniform_
# xavier_normal
# kaiming_uniform_
# kaiming_normal_
# orthogonal_
# sparse_

# upsample

# upsample_nearest
# upsample_bilinear
# interpolate