In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.autograd.profiler as profiler

In [5]:
class MyModule(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super(MyModule, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

    def forward(self, input, mask):
        with profiler.record_function("LINEAR PASS"):
            out = self.linear(input)

        with profiler.record_function("MASK INDICES"):
            threshold = out.sum(axis=1).mean().item()
            hi_idx = np.argwhere(mask.cpu().numpy() > threshold)
            hi_idx = torch.from_numpy(hi_idx)

        return out, hi_idx

In [11]:
model = MyModule(500, 10)
input = torch.rand(128, 500)
mask = torch.rand((500, 500, 500), dtype=torch.float)

# warm-up
model(input, mask)

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    out, idx = model(input, mask)

print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                 MASK INDICES        99.92%        4.743s        99.96%        4.745s        4.745s           0 b        -516 b             1  
             aten::lift_fresh         0.04%       1.753ms         0.04%       1.753ms       1.753ms           0 b           0 b             1  
                  aten::addmm         0.02%     740.000us         0.02%     837.000us     837.000us       5.00 Kb       5.00 Kb             1  
                  LINEAR PASS         0.01%     673.000us         0.04%       1.686ms       1.686ms       5.00 Kb           0 b         

<function torch._C._autograd.PyCapsule.kineto_available>