In [11]:
import torch
import torch.cuda.nvtx as nvtx
import inspect
from inspect import currentframe, getargvalues, getfullargspec, getmembers, isfunction
import types
import re


In [16]:
class NvtxPatcher:
    
    registry = set()
    nvtx_handle = nvtx._libnvToolsExt()
    
    @staticmethod
    def nvtx_monkey_patch(func):
        def wrapper(*args, **kwargs):
            frame = currentframe()
            v = getargvalues(frame)
            print("FUNC = {}".format(func))
            argspec = getfullargspec(func)
            formal_arg_names = argspec.args
            s = "{'op':%s," % v.locals["func"].__name__
            for idx, val in enumerate(v.locals["args"]):
                name = formal_arg_names[idx]
                if isinstance(val, torch.Tensor):
                    name += "_shape"
                    val = tuple(val.size())
                s += "'%s':%s," % (name, str(val))
            num_def=len(argspec.defaults)
            defaults = dict(zip(argspec.args[-num_def:], argspec.defaults))
            overrides={k:str(v) for k, v in v.locals["kwargs"].items()}
            defaults.update(overrides)
            s += "%s}" % str(defaults).strip("{}")
            nvtx.range_push(s)
            result = func(*args, **kwargs)
            nvtx.range_pop()
            return result
        return wrapper

    @classmethod
    def list_non_builtins(cls, module, regex_filt_lst=None, log=True):
        if not isinstance(regex_filt_lst, list) and regex_filt_lst is not None:
            regex_filt_lst = list(regex_filt_lst)
        if isinstance(module, str):
            module = eval(module)
        name_list = dir(module)
        builtin_funcs_methods = [_a for _a in name_list if
                                 (isinstance(getattr(module, _a), types.BuiltinFunctionType) or
                                  isinstance(getattr(module, _a), types.BuiltinMethodType))]
        match_any = lambda txt:  any((map(lambda x: re.match(r"%s" % x, txt), regex_filt_lst)))
        if regex_filt_lst is not None:
            function_list = [_x for _x in builtin_funcs_methods if match_any(_x)]
        else: 
            function_list = [_x for _x in builtin_funcs_methods]
        return function_list 
                                 
    @classmethod
    def register_non_builtins(cls, module, regex_filt_lst=None, log=True):
        if not isinstance(regex_filt_lst, list) and regex_filt_lst is not None:
            regex_filt_lst = list(regex_filt_lst)
        if isinstance(module, str):
            module = eval(module)
        name_list = dir(module)
        non_builtin_funcs = [_a for _a in name_list if
                     isinstance(getattr(module, _a), types.FunctionType)]
        
        match_any = lambda txt:  any((map(lambda x: re.match(r"%s" % x, txt), regex_filt_lst)))
        if regex_filt_lst is not None:
            function_list = [_x for _x in non_builtin_funcs if match_any(_x)]
        else: 
            function_list = [_x for _x in mod_funcs]
            
        for name in function_list:
            if name in cls.registry:
                continue
            fqn = "{}.{}".format(module.__name__, name)
            temp = eval(fqn)
            patched = NvtxPatcher.nvtx_monkey_patch(temp)
            cls.registry.add(fqn)
            exec("{}=patched".format(fqn))
            
        print("{}\n{}\n".format("Functions registered for NVTX range annotation:", function_list))
    
patterns = ["conv[1-3]?(d|(\_transpose[1-3]d))",
     "(un)?fold",
     "(avg|max)_pool",
     "max_unpool[1-3]d",
     "lp_pool[1-3]d",
     "adaptive_(avg|max)_pool[1-3]d",
     "threshold",
     "(leaky_)?[p-s]?r?elu_?6?",
     "(hard)?tanh",
     "glu",
     "(log)?sigmoid",
     "(hard|soft|tanh)shrink",
     "soft(sign|plus|min)",
     "(gumbel_|log_)?softmax",
     "(batch|layer|instance|local_response)_norm",
     "normalize",
     "(bi)?linear",
     "(alpha_)?dropout([2-3]d)?",
     "embedding(_bag)?",
     "pairwise_distance",
     "cosine_similarity",
     "(binary_)?cross_entropy",
     "(poisson_)?nll_loss",
     "(cosine|hinge)_embedding_loss",
     "kl_div",
     "((smooth_)?l1|mse)_loss",
     "(multilabel|multi)?_margin_(soft_?)(ranking)?_loss",
     "(soft|triplet)_margin_loss",
     "pad",
     "pixel_shuffle",
     "interpolate",
     "upsample_?(bilinear|nearest)?",
     "(affine_)?grid(_sample)?"]

NvtxPatcher.register_non_builtins(
    torch.nn.functional, patterns)

print("built-ins (manual monkey-patching required):")
print(NvtxPatcher.list_non_builtins(torch.nn.functional, patterns))

Functions registered for NVTX range annotation:
['adaptive_avg_pool2d', 'adaptive_avg_pool3d', 'adaptive_max_pool1d', 'adaptive_max_pool2d', 'adaptive_max_pool3d', 'affine_grid', 'alpha_dropout', 'batch_norm', 'bilinear', 'binary_cross_entropy', 'binary_cross_entropy_with_logits', 'cosine_embedding_loss', 'cosine_similarity', 'cross_entropy', 'dropout', 'dropout2d', 'dropout3d', 'elu', 'embedding', 'embedding_bag', 'fold', 'glu', 'grid_sample', 'gumbel_softmax', 'hardshrink', 'hardtanh', 'hinge_embedding_loss', 'instance_norm', 'interpolate', 'kl_div', 'l1_loss', 'layer_norm', 'leaky_relu', 'linear', 'local_response_norm', 'log_softmax', 'lp_pool1d', 'lp_pool2d', 'max_pool1d', 'max_pool2d', 'max_pool3d', 'max_unpool1d', 'max_unpool2d', 'max_unpool3d', 'mse_loss', 'nll_loss', 'normalize', 'pad', 'pairwise_distance', 'pixel_shuffle', 'poisson_nll_loss', 'relu', 'relu6', 'rrelu', 'selu', 'sigmoid', 'smooth_l1_loss', 'soft_margin_loss', 'softmax', 'softmin', 'softsign', 'tanh', 'tanhshrink

In [None]:
# Implement the following:
adaptive_avg_pool1d

avg_pool1d
avg_pool2d
avg_pool3d

conv1d 
conv2d
conv3d

conv_transpose1d
conv_transpose2d
conv_transpose3d

elu_
hardtanh_
leaky_relu_
logsigmoid
prelu
relu_
rrelu_
selu_
softplus
softshrink
threshold_

In [13]:
a = torch.randn(1, 3, 5, 5).cuda()
b = torch.randn(4, 3, 3, 3).cuda()
c = torch.randn(4, 4)
d = torch.randn(4, 4)
result = torch.nn.functional.conv2d(a, b)

In [None]:
print(result)

In [None]:
# function_list = [elem[0] for elem in inspect.getmembers(F) if inspect.isfunction(o[1])]

match_any = lambda txt:  any((map(lambda x: re.match(r"%s" % x, txt), regex_filt_lst)))
regex_filt_lst = ["conv[0-9]d", "foo"]
print(match_any("conv3d_transpose"))

In [None]:
@nvtx_patch

# Functional list
@nvtx_patch()
def patched_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    # Interpolate numbers as strings because some can be one-elem tuples as well
    nvtx.range_push("{op: 'conv1d', input: %s, weight: %s, stride: %s, padding: %s, dilation: %s, groups:%s}" % 
                    (input_size, weight_size, str(stride), str(padding), str(dilation), str(groups)))
###### TODO
    op = fn(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    # Interpolate numbers as strings because some can be one-elem tuples as well
    nvtx.range_push("{op:'conv2d', input:%s, weight:%s, stride:%s, padding:%s, dilation:%s, groups:%s}" % 
                    (input_size, weight_size, str(stride), str(padding), str(dilation), str(groups)))
###### TODO    
    op = conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv3d():
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    # Interpolate numbers as strings because some can be one-elem tuples as well
    nvtx.range_push("{op:'conv3d', input:%s, weight:%s, stride:%s, padding:%s, dilation:%s, groups:%s}" % 
                    (input_size, weight_size, str(stride), str(padding), str(dilation), str(groups)))
###### TODO 
    op = conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    nvtx.range_push("{op:'conv_transpose1d', input:%s, weight:%s, stride:%s, padding:%s, output_padding: %s, groups:%s, dilation:%s}"
                    % (input_size, weight_size, str(stride), str(padding), str(output_padding), str(groups), str(dilation)))
###### TODO 
    op = conv_transpose1d(input, weight, bias, stride, padding, output_padding, groups, dilation)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    nvtx.range_push("{op:'conv_transpose2d', input:%s, weight:%s, stride:%s, padding:%s, output_padding:%s, groups:%s, dilation:%s}"
                    % (input_size, weight_size, str(stride), str(padding), str(output_padding), str(groups), str(dilation)))
###### TODO 
    op = conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    nvtx.range_push("{op:'conv_transpose3d', input:%s, weight:%s, stride:%s, padding:%s, output_padding: %s, groups:%s, dilation:%s}"
                    % (input_size, weight_size, str(stride), str(padding), str(output_padding), str(groups), str(dilation)))
###### TODO 
    op = conv_transpose3d(input, weight, bias, stride, padding, output_padding, groups, dilation)
    nvtx.range_pop()
    return op

@nvtx_patch()
def patched_linear(input, weight, bias=None):
    input_size = tuple(input.size())
    weight_size = tuple(weight.size())
    nvtx.range_push("{op: 'linear', input:%s}")
###### TODO 
    op = linear(input, weight, bias)
    nvtx.range_pop()
    return op

@nvtx_patch()
# (input, p=0.5, training=False, inplace=False)
def patched_dropout():
    pass

@nvtx_patch()    
# (input, inplace=False) 
def patched_relu():
    pass

@nvtx_patch()    
# (input)
def patched_relu_():
    pas



In [None]:
import inspect

def func(*args, **kwargs):
    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    print 'function name "%s"' % inspect.getframeinfo(frame)[2]
    for i in args:
        print "    %s = %s" % (i, values[i])
    return [(i, values[i]) for i in args]

func()

In [None]:
import inspect
import torch

global abc


def my_decorator(func):
    def wrapper(*args, **kwargs):
        frame = inspect.currentframe()
        v = inspect.getargvalues(frame)
        argspec = inspect.getfullargspec(func)
        formal_arg_names = argspec.args
        s = "{'op':%s," % v.locals["func"].__name__
        for idx, val in enumerate(v.locals["args"]):
            name = formal_arg_names[idx]
            if isinstance(val, torch.Tensor):
                name += "_shape"
                val = tuple(val.size())
            s += "'%s':%s," % (name, str(val))
        num_def=len(argspec.defaults)
        defaults = dict(zip(argspec.args[-num_def:], argspec.defaults))
        overrides={k:str(v) for k, v in v.locals["kwargs"].items()}
        defaults.update(overrides)
        s += "%s}" % str(defaults).strip("{}")
        print("Pushing NVTX range: %s" % s)
                
        func(*args, **kwargs)
        print("Something is happening after the function is called.")
    return wrapper

import torch 

@my_decorator
def say_whee(foo, bar, baz=42, qux=123):
    print("### In say_whee")
    print("foo = %s" % str(foo))
    print("bar = %s" % str(bar))
    print("baz = %s" % str(baz))
    print("qux = %s" % str(ham))

    print("Whee!")

say_whee(torch.randn(2,2), 1, qux=456)    

In [None]:
abc.locals['kwargs']

In [None]:
abc.locals['args']

In [None]:
abc.locals['kwargs'] 

In [None]:
abc.locals['func'].__name__

In [None]:
foo = "{"
# Need to record arg name from signature
for arg in abc.locals['args']:
    if arg(isinstance(torch.Tensor)):
        dims = tuple(arg.size())

# torch.Tensor
"{name:%s, args:%s, kwargs:%s}" % (func_name, str(args), str(kwargs)) 

In [None]:
import torch.nn.functional as F

help(F.conv2d())

In [None]:
def init():
    if not initialized:
        # do stuff
    initialized = True
    pass

In [None]:
# def patched_hardtanh():
#     pass

# def patched_hardtanh_():
#     pass

# def patched_relu6():
#     pass

# def patched_elu():
#     pass

# def patched_elu_():
#     pass

# def patched_unfold():
#     pass

# def patched_fold():
#     pass

# def patched_avg_pool1d():
#     pass

# def patched_avg_pool2d():
#     pass

# def pathched_avg_pool3d():
#     pass

# def patched_max_pool1d():
#     pass

# def patched_max_pool2d():
#     pass

# def patched_max_pool3d():
#     pass

# def patched_max_unpool1d():
#     pass

# def patched_max_unpool2d():
#     pass

# def patched_max_unpool3d():
#     pass

# def patched_lp_pool1d():
#     pass

# def patched_lp_pool2d():
#     pass

# def patched_adaptive_max_pool1d():
#     pass

# def patched_adaptive_max_pool2d():
#     pass

# def patched_adaptive_max_pool3d():
#     pass

# def patched_adaptive_avg_pool1d():
#     pass

# def patched_adaptive_avg_pool2d():
#     pass

# def patched_adaptive_avg_pool3d():
#     pass

# def patched_threshold():
#     pass

# def patched_threshold_():
#     pass



# leaky_relu
# leaky_relu_
# prelu
# rrelu
# rrelu_
# logsigmoid
# hardshrink
# tanhshrink
# softplus
# softmin
# softmax
# softshrink
# gumbel_softmax
# log_softmax
# tanh
# sigmoid
# batch_norm
# instance_norm
# normalize

# bilinear
# alpha_dropout
# dropout2d
# dropout3d
# pad


# local_response_norm

# embedding
# embedding_bag
# pairwise_distance
# cosine_similarity
# binary_cross_entropy
# poisson_nll_loss
# cosine_embedding_loss
# cross_entropy
# hinge_embedding_loss
# kl_div
# l1_loss
# mse_loss
# margin_ranking_loss
# multilabel_margin_loss
# multilabel_soft_margin_loss
# multi_margin_loss
# nll_loss
# binary_cross_entropy_with_logits
# smooth_l1_loss
# soft_margin_loss
# triplet_margin_loss
# pixel_shuffle
# grid_sample
# affine_grid
# data_parallel
# calculate_gain
# uniform_
# normal_
# constant_
# eye_
# dirac_
# xavier_uniform_
# xavier_normal
# kaiming_uniform_
# kaiming_normal_
# orthogonal_
# sparse_

# upsample

# upsample_nearest
# upsample_bilinear
# interpolate