In [7]:
import re
import torch
from transformer_net import Student, mods, TransformerNet

def _load_style_model(model_path, model_class):
    """Build the selected network, load the checkpoint at `model_path` into it, and return it in eval mode.

    `model_class == TransformerNet` builds a `TransformerNet`; any other value
    builds `Student(mods)`.
    """
    if model_class == TransformerNet:
        style_model = TransformerNet()
    else:
        style_model = Student(mods)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # Remove the deprecated running_mean / running_var entries saved for the
    # InstanceNorm layers (keys like "in1.running_mean"); presumably the current
    # modules do not expect them, so a strict load_state_dict would reject them.
    for key in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', key):
            del state_dict[key]
    style_model.load_state_dict(state_dict)
    style_model.eval()
    return style_model


def profile_model(model_path='./models_pth/mosaic.pth', model_class=TransformerNet,
                  input_shape=(1, 3, 640, 480)):
    """Profile CPU inference of a style-transfer model and print the hottest ops.

    Args:
        model_path: checkpoint file to load (mapped onto the CPU).
        model_class: `TransformerNet` selects a `TransformerNet`; any other value
            selects `Student(mods)`.
        input_shape: shape of the random input tensor fed to the model.

    Prints the top 20 ops, sorted by self CPU time and grouped by input shape,
    for the single forward pass recorded by the profiler schedule. Returns None.
    """
    style_model = _load_style_model(model_path, model_class)

    # torch.rand rather than torch.Tensor(*shape): the latter is uninitialised
    # memory that may contain NaN/inf or denormals, which skew CPU timings.
    dummy_input = torch.rand(*input_shape)

    def trace_handler(prof):
        print(prof.key_averages(group_by_input_shape=True).table(
            sort_by="self_cpu_time_total", row_limit=20))

    # wait=1, warmup=1, active=1: steps 0 and 1 are not recorded, step 2 is.
    schedule = torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1)

    # Keep autograd disabled while profiling. Otherwise every forward pass builds
    # a graph and saves tensors, so the table would not reflect pure inference.
    with torch.no_grad(), torch.profiler.profile(
        with_stack=True,
        record_shapes=True,
        activities=[torch.profiler.ProfilerActivity.CPU],
        schedule=schedule,
        on_trace_ready=trace_handler,
    ) as profiler:
        for _ in range(4):
            style_model(dummy_input)
            # Signal the profiler that the next iteration has started.
            profiler.step()

In [9]:
# Profile the compressed Student checkpoint (mosaic_compressed_36.pth) on CPU.
profile_model(model_class=Student,
              model_path='./models_pth/mosaic_compressed_36.pth')

STAGE:2023-10-12 02:39:10 70106:8714535 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-10-12 02:39:10 70106:8714535 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-10-12 02:39:10 70106:8714535 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  
      aten::_slow_conv2d_forward        85.93%     164.551ms        86.17%     165.002ms     165.002ms             1                               [[1, 12, 648, 488], [3, 12, 9, 9], [], [3], [], []]  
      aten::_slow_conv2d_forward         4.69%       8.981ms         4.77%       9.127ms       9.127ms             1                                 [[1, 3, 648, 488], [9, 3, 9, 9], [], [9], [], [