In [25]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
import timm
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
from torchprofile import profile_macs
from torch.profiler import profile, record_function, ProfilerActivity
import numpy as np
import json
from PIL import Image
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'

In [29]:
class ModelAnalyzer:
    """Collect static and runtime profiling data for a timm model.

    The analyzer can:
      * count MACs and parameters (``get_model_profile``),
      * run one inference under ``torch.profiler`` (``profile_model``) and
        flatten the averaged events into JSON-friendly dicts
        (``parse_profiling``),
      * time every sub-module with forward hooks
        (``profile_model_with_hooks``),
      * dump everything to ``<model_name>_report.json`` (``generate_report``).
    """

    def __init__(self, model_name, pretrained=True, device='cuda'):
        """Create the timm model, move it to the device and switch to eval mode.

        Args:
            model_name: Name passed to ``timm.create_model``; also used as the
                prefix of the report file name.
            pretrained: Forwarded to ``timm.create_model``.
            device: ``'cuda'`` to use the GPU when one is available; any other
                value (or no GPU) selects ``'cpu'``.
        """
        # Fall back to CPU when CUDA is requested but not available.
        self.device = 'cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu'
        self.model = timm.create_model(model_name, pretrained=pretrained)
        self.model_name = model_name
        self.model.to(self.device)
        self.model.eval()
        # Hook bookkeeping exists from the start so remove_hooks() is always safe.
        self.hooks = []
        self.layer_times = {}
        self._layer_starts = {}

    def prepare_input(self, batch_size=1):
        """Return a random tensor of shape ``(batch_size, *input_size)`` on the device.

        ``input_size`` is read from ``self.model.default_cfg['input_size']``.
        """
        input_size = self.model.default_cfg['input_size']
        dummy_input = torch.randn(batch_size, *input_size).to(self.device)
        return dummy_input

    def get_model_profile(self, input_tensor):
        """Count MACs (via ``profile_macs``) and parameters of the model.

        Returns:
            A 2-tuple ``({'MACs': macs}, {'Params': params})``. The tuple shape
            is kept for backward compatibility; ``json.dumps`` serializes it
            as a two-element list.
        """
        macs = profile_macs(self.model, (input_tensor,))
        params = sum(p.numel() for p in self.model.parameters())
        return {'MACs': macs}, {'Params': params}

    def _on_cuda(self):
        """Return True when the analyzer runs on an available CUDA device."""
        return str(self.device).startswith('cuda') and torch.cuda.is_available()

    def profile_model(self, input_tensor):
        """Run one no-grad forward pass under ``torch.profiler`` and return the profiler.

        CUDA activity is only requested when the analyzer actually runs on a
        GPU, so CPU-only runs do not ask for an unavailable activity.
        """
        activities = [ProfilerActivity.CPU]
        if self._on_cuda():
            activities.append(ProfilerActivity.CUDA)
        with profile(activities=activities, record_shapes=True) as prof:
            with record_function("model_inference"):
                with torch.no_grad():
                    _ = self.model(input_tensor)
        return prof

    def parse_profiling(self, prof):
        """Flatten ``prof.key_averages()`` into a list of JSON-serializable dicts.

        Each event becomes a dict with a ``'Name'`` entry (``event.key``) plus
        one entry per public, non-callable attribute. ``cpu_children`` is
        skipped. Values that ``json.dumps`` rejects are stored as ``str(value)``.
        """
        prof_data = []
        for event in prof.key_averages():
            event_data = {'Name': event.key}
            for attr in dir(event):
                if attr == 'cpu_children' or attr.startswith('_'):
                    continue
                # Read each attribute once: property access can be costly or
                # emit a warning, so avoid a second getattr for the value.
                value = getattr(event, attr)
                if callable(value):
                    continue
                try:
                    json.dumps(value)  # probe serializability
                    event_data[attr] = value
                except (TypeError, OverflowError, ValueError):
                    # ValueError covers circular references reported by json.
                    event_data[attr] = str(value)
            prof_data.append(event_data)
        return prof_data

    def generate_report(self, prof_data, model_profile, layer_times):
        """Serialize the collected data and write ``<model_name>_report.json``.

        Returns:
            The JSON text that was written (indented by 4 spaces).
        """
        report = {
            'ModelProfile': model_profile,
            'LayerWiseProfiling': prof_data,
            'LayerExecutionTimes': layer_times
        }
        report_json = json.dumps(report, indent=4)
        with open(f'{self.model_name}_report.json', 'w') as f:
            f.write(report_json)
        return report_json

    # Hooks
    @staticmethod
    def _wall_clock_ms():
        """Return a monotonic wall-clock reading in milliseconds (CPU timing)."""
        import time  # local import keeps this block self-contained
        return time.perf_counter() * 1000.0

    def add_hooks(self):
        """Attach a timing pre-hook and forward hook to every module of the model.

        ``named_modules()`` also yields container modules (and the root module
        under the name ``''``), so a container's time includes its children.
        Any hooks left over from a previous call are removed first.
        """
        self.remove_hooks()
        self.layer_times = {}
        self._layer_starts = {}
        self.hooks = []
        for name, module in self.model.named_modules():
            self.hooks.append(module.register_forward_pre_hook(self._get_pre_hook(name)))
            self.hooks.append(module.register_forward_hook(self.get_hook(name)))

    def _get_pre_hook(self, name):
        """Build the pre-forward hook that stamps the start time for ``name``."""
        def pre_hook(module, args):
            if self._on_cuda():
                start = torch.cuda.Event(enable_timing=True)
                start.record()
            else:
                start = self._wall_clock_ms()
            self._layer_starts[name] = start
        return pre_hook

    def get_hook(self, name):
        """Build the forward hook that stores the elapsed time (ms) for ``name``.

        The start stamp comes from the matching pre-hook; timing inside a
        single forward hook cannot work because that hook only runs after the
        module's forward has already finished. If no start stamp exists (the
        hook was registered without its pre-hook) nothing is recorded.
        """
        def hook(module, inputs, output):
            start = getattr(self, '_layer_starts', {}).pop(name, None)
            if start is None:
                return
            if isinstance(start, float):
                elapsed_ms = self._wall_clock_ms() - start
            else:
                end = torch.cuda.Event(enable_timing=True)
                end.record()
                # Wait for the GPU to reach the end event before reading it.
                end.synchronize()
                elapsed_ms = start.elapsed_time(end)
            # A module invoked several times keeps the time of its last call.
            self.layer_times[name] = elapsed_ms
        return hook

    def remove_hooks(self):
        """Detach every hook registered by ``add_hooks``; safe to call repeatedly."""
        for hook in getattr(self, 'hooks', ()):
            hook.remove()
        self.hooks = []

    def profile_model_with_hooks(self, input_tensor):
        """Run one no-grad forward pass with timing hooks attached.

        Returns:
            Dict mapping module name (as given by ``named_modules``) to the
            elapsed time of its forward call in milliseconds.
        """
        self.add_hooks()
        try:
            with torch.no_grad():
                _ = self.model(input_tensor)
        finally:
            # Never leave hooks on the model, even if the forward pass fails.
            self.remove_hooks()
        return self.layer_times

In [30]:
# Build the analyzer for the model under study.
model_name = 'resnet50'
analyzer = ModelAnalyzer(model_name, pretrained=True, device='cuda')

# A single random sample shaped per the model's default input size.
input_tensor = analyzer.prepare_input(1)

# Static cost: MAC count and parameter count.
model_profile = analyzer.get_model_profile(input_tensor)

# Runtime cost: one inference under torch.profiler, then flatten the
# averaged events into JSON-friendly records.
prof = analyzer.profile_model(input_tensor)
prof_data = analyzer.parse_profiling(prof)

# Per-module timings from forward hooks, followed by the combined report
# (also written to '<model_name>_report.json').
layer_times = analyzer.profile_model_with_hooks(input_tensor)
report = analyzer.generate_report(
    prof_data=prof_data,
    model_profile=model_profile,
    layer_times=layer_times,
)

  if attr == 'cpu_children' or attr.startswith('_') or callable(getattr(event, attr)):
  value = getattr(event, attr)
