In [68]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', '..')))
import timm
from timm.models.helpers import model_parameters
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
from torchprofile import profile_macs
from torch.profiler import profile, record_function, ProfilerActivity
import numpy as np
import json
from PIL import Image
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'

from torch_scatter import scatter_add


In [58]:
# timm identifier of the model to analyze. Other ViT variants can be
# discovered with timm.list_models('vit*').
model_name = "vit_base_patch16_224"

In [69]:
class Analyzer:
    """Wrap a timm model for single-input inference and torch.profiler tracing.

    Typical use: ``start_profiler()``, run ``inference_one(...)``,
    ``stop_profiler()``, then ``list_events()`` to print the results.
    """

    def __init__(self, model_name, pretrained=True, pretrained_cfg=None, pretrained_cfg_overlay=None,
                 checkpoint_path='', scriptable=None, exportable=None, no_jit=True):
        """Create the model with ``timm.create_model``.

        Every argument is forwarded unchanged to ``timm.create_model``.
        """
        self.model_name = model_name
        self.model = timm.create_model(
            model_name, pretrained=pretrained, pretrained_cfg=pretrained_cfg,
            pretrained_cfg_overlay=pretrained_cfg_overlay, checkpoint_path=checkpoint_path,
            scriptable=scriptable, exportable=exportable, no_jit=no_jit,
        )
        # Running or most recently finished profiler session (None until started).
        self.profiler = None
        # True between start_profiler() and stop_profiler(). It guards against
        # stopping twice and against reading results before the session ends.
        self._profiler_running = False

    def inference_one(self, input_tensor: torch.Tensor):
        """Run one forward pass in eval mode without autograd tracking.

        Switching the model to eval mode is a persistent side effect.
        Returns whatever the model's forward returns.
        """
        self.model.eval()
        with torch.no_grad():
            return self.model(input_tensor)

    @staticmethod
    def _profiler_activities() -> list:
        """Return the profiler activities usable on this machine.

        CUDA tracing is requested only when a CUDA device is available.
        """
        activities = [ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(ProfilerActivity.CUDA)
        return activities

    def start_profiler(self, trace_dir: str = './logs') -> None:
        """Create and start a torch.profiler session.

        A session that is still running is stopped first so it is not leaked.
        When the session stops, ``torch.profiler.tensorboard_trace_handler``
        writes the trace into ``trace_dir``.

        Args:
            trace_dir: directory the TensorBoard trace is written to.
        """
        if self.profiler is not None and self._profiler_running:
            self.stop_profiler()
        prof = profile(
            activities=self._profiler_activities(),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
            record_shapes=True,
            with_stack=True,
            with_flops=True,
            profile_memory=True,
            with_modules=True,
        )
        # Enter the context manually so the session can span separate calls.
        # The session is published only after it has actually started.
        prof.__enter__()
        self.profiler = prof
        self._profiler_running = True

    def stop_profiler(self) -> None:
        """Stop the running profiler session, if any.

        Stopping triggers the trace handler configured in ``start_profiler``.
        Calling this when no session is running is a no-op. The finished
        session stays in ``self.profiler`` so ``list_events`` can read it.
        """
        if self.profiler is None or not self._profiler_running:
            return
        try:
            self.profiler.__exit__(None, None, None)
        finally:
            self._profiler_running = False

    def list_events(self) -> None:
        """Print the results of the most recently finished profiler session.

        First prints the aggregated ``key_averages()`` table. Then prints one
        line per recorded event with its name, total CPU time and total device
        (CUDA) time, which torch.profiler reports in microseconds.
        """
        if self.profiler is None:
            print("Profiler has not been initialized or profiling session has ended.")
            return
        if self._profiler_running:
            # key_averages()/events() are only available once the session has stopped.
            print("Profiler is still running; call stop_profiler() before listing events.")
            return
        print(self.profiler.key_averages())
        for event in self.profiler.events():
            # Newer torch releases expose device_time_total and deprecate
            # cuda_time_total; fall back to the old name on older releases.
            device_time = getattr(event, "device_time_total", None)
            if device_time is None:
                device_time = event.cuda_time_total
            print(f"Name: {event.name}, CPU Time: {event.cpu_time_total}, CUDA Time: {device_time}")

In [66]:
analyzer = Analyzer(model_name)
# Single random input of shape (1, 3, 224, 224). This assumes the model takes
# NCHW images at 224x224, as the model name suggests; confirm via its pretrained_cfg.
input_tensor = torch.randn(1, 3, 224, 224)
analyzer.start_profiler()
try:
    analyzer.inference_one(input_tensor)
finally:
    # Always close the manually entered profiler session, even if inference raises.
    analyzer.stop_profiler()
analyzer.list_events()


AttributeError: module 'torch.profiler' has no attribute 'export_chrome_trace'