In [49]:
import json
import os
import sys
import time

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
# NOTE(review): set before the third-party imports below in case one of them
# reads this flag at import time rather than at call time — confirm.
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'

import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
from PIL import Image
from torchprofile import profile_macs

In [50]:
class ModelAnalyzer:
    """Collect size and timing information about a timm model.

    The analyzer builds the model with ``timm.create_model``, moves it to the
    selected device, switches it to eval mode, and offers three measurements:

    * MAC and parameter counts (``get_model_profile``),
    * an autograd-profiler run (``profile_model`` / ``parse_profiling``),
    * per-module forward times gathered with hooks
      (``profile_model_with_hooks``).

    Attributes:
        device: ``'cuda'`` when requested and available, otherwise ``'cpu'``.
        model: The instantiated model, on ``device`` and in eval mode.
        hooks: Handles of the currently registered profiling hooks.
        layer_times: Mapping of module name (as yielded by
            ``named_modules``; the root module is ``''``) to forward time in
            milliseconds, filled by the profiling hooks.
    """

    def __init__(self, model_name, pretrained=True, device='cuda'):
        """Create the model and prepare it for inference.

        Args:
            model_name: Model identifier passed to ``timm.create_model``.
            pretrained: Forwarded to ``timm.create_model``.
            device: Requested device; anything other than ``'cuda'`` (or
                ``'cuda'`` without CUDA available) falls back to ``'cpu'``.
        """
        self.device = 'cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu'
        self.model = timm.create_model(model_name, pretrained=pretrained)
        self.model.to(self.device)
        self.model.eval()
        # Hook state exists from the start so remove_hooks() and layer_times
        # can be used safely even if add_hooks() was never called.
        self.hooks = []
        self.layer_times = {}
        self._layer_starts = {}

    def prepare_input(self, batch_size=1):
        """Return a random tensor shaped like the model's expected input.

        Args:
            batch_size: Size of the leading (batch) dimension.

        Returns:
            A ``torch.randn`` tensor of shape
            ``(batch_size, *model.default_cfg['input_size'])`` on ``self.device``.
        """
        input_size = self.model.default_cfg['input_size']
        dummy_input = torch.randn(batch_size, *input_size).to(self.device)
        return dummy_input

    def get_model_profile(self, input_tensor):
        """Count MACs and parameters.

        Args:
            input_tensor: Example input handed to ``profile_macs``.

        Returns:
            A 2-tuple ``({'MACs': macs}, {'Params': params})``.
        """
        macs = profile_macs(self.model, (input_tensor,))
        params = sum(p.numel() for p in self.model.parameters())
        return {'MACs': macs}, {'Params': params}

    def profile_model(self, input_tensor):
        """Run one no-grad forward pass under the autograd profiler.

        Args:
            input_tensor: Input for the forward pass.

        Returns:
            The finished ``torch.autograd.profiler.profile`` object.
        """
        # 'cpu' is not an accelerator option for the profiler; only request
        # device-side profiling when the model actually runs on CUDA.
        use_device = self.device if self.device == 'cuda' else None
        with torch.autograd.profiler.profile(use_device=use_device) as prof:
            with torch.no_grad():
                _ = self.model(input_tensor)
        return prof

    def parse_profiling(self, prof):
        """Render the profiler's averaged events as a text table.

        Args:
            prof: Profiler object returned by ``profile_model``.

        Returns:
            The table string from ``prof.key_averages().table(...)``, sorted
            by ``cpu_time_total``.
        """
        prof_data = prof.key_averages().table(sort_by="cpu_time_total")
        return prof_data

    def generate_report(self, prof_data, model_profile, layer_times):
        """Bundle the three measurements into a JSON string.

        Args:
            prof_data: Stored under ``'LayerWiseProfiling'``.
            model_profile: Stored under ``'ModelProfile'``.
            layer_times: Stored under ``'LayerExecutionTimes'``.

        Returns:
            The report serialised with ``json.dumps(..., indent=4)``.
        """
        report = {
            'ModelProfile': model_profile,
            'LayerWiseProfiling': prof_data,
            'LayerExecutionTimes': layer_times
        }
        return json.dumps(report, indent=4)

    # Hooks
    def add_hooks(self):
        """Attach a timing pre-hook and forward hook to every named module.

        Any hooks left over from a previous call are removed first so they
        cannot pile up, and ``layer_times`` is reset.
        """
        self.remove_hooks()
        self.layer_times = {}
        self._layer_starts = {}
        for name, module in self.model.named_modules():
            # The pre-hook marks the start and the forward hook marks the end;
            # a forward hook alone only fires after the module has already run.
            self.hooks.append(module.register_forward_pre_hook(self.get_pre_hook(name)))
            self.hooks.append(module.register_forward_hook(self.get_hook(name)))

    def get_pre_hook(self, name):
        """Build the forward pre-hook that records the start marker for ``name``.

        Args:
            name: Module name used as the key in ``layer_times``.

        Returns:
            A callable with the ``(module, input)`` forward-pre-hook signature.
        """
        def pre_hook(module, input):
            if self.device == 'cuda':
                start = torch.cuda.Event(enable_timing=True)
                start.record()
            else:
                start = time.perf_counter()
            # A stack per name keeps start/end pairs matched if the same
            # module is entered again before it has returned.
            self._layer_starts.setdefault(name, []).append(start)
        return pre_hook

    def get_hook(self, name):
        """Build the forward hook that closes the timing for ``name``.

        The elapsed time in milliseconds is added to ``layer_times[name]``,
        so a module invoked several times in one pass reports its total.

        Args:
            name: Module name used as the key in ``layer_times``.

        Returns:
            A callable with the ``(module, input, output)`` forward-hook
            signature.
        """
        def hook(module, input, output):
            pending = self._layer_starts.get(name)
            if not pending:
                # No start marker (hook used without its pre-hook): nothing
                # meaningful can be measured, so record nothing.
                return
            start = pending.pop()
            if self.device == 'cuda':
                end = torch.cuda.Event(enable_timing=True)
                end.record()
                # Kernels run asynchronously; wait before reading the events.
                torch.cuda.synchronize()
                elapsed_ms = start.elapsed_time(end)
            else:
                # Convert seconds to milliseconds to match the CUDA event unit.
                elapsed_ms = (time.perf_counter() - start) * 1000.0
            self.layer_times[name] = self.layer_times.get(name, 0.0) + elapsed_ms
        return hook

    def remove_hooks(self):
        """Detach every registered profiling hook and forget the handles."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def profile_model_with_hooks(self, input_tensor):
        """Time each module during one no-grad forward pass.

        Args:
            input_tensor: Input for the forward pass.

        Returns:
            ``self.layer_times``: module name -> forward time in milliseconds.
        """
        self.add_hooks()
        try:
            with torch.no_grad():
                _ = self.model(input_tensor)
        finally:
            # Always detach, even if the forward pass raises, so the model is
            # not left carrying profiling hooks.
            self.remove_hooks()
        return self.layer_times



In [51]:
# Build the analyzer for 'vit_base_patch16_224' without pretrained weights,
# plus a random input tensor shaped for that model.
analyzer = ModelAnalyzer(model_name='vit_base_patch16_224', pretrained=False)
input_tensor = analyzer.prepare_input()

# Run the three measurements in the same order as before: MAC/parameter
# counts, hook-based per-layer times, then the autograd profiler.
model_profile = analyzer.get_model_profile(input_tensor)
layer_times = analyzer.profile_model_with_hooks(input_tensor)
prof = analyzer.profile_model(input_tensor)
prof_data = analyzer.parse_profiling(prof)

# Combine everything into a single JSON-formatted string.
report = analyzer.generate_report(
    prof_data=prof_data,
    model_profile=model_profile,
    layer_times=layer_times,
)

# Keep a parsed copy of the report in the notebook and persist it to disk.
report_dict = json.loads(report)
with open('model_analysis_report.json', 'w') as f:
    f.write(json.dumps(report_dict, indent=4))


