In [30]:
import onnx
import numpy as np
from collections import defaultdict
from typing import Dict, Tuple

def get_shape_from_value_info(value_info) -> list:
    """Return the tensor shape stored in a ValueInfoProto-like object.

    Reads ``value_info.type.tensor_type.shape.dim`` and returns one entry per
    dimension. A dimension whose ``dim_value`` is 0 is reported as 1
    (presumably these are unset/symbolic dimensions -- verify against the
    models being analysed). An object without any dimensions yields ``[]``.
    """
    dims = value_info.type.tensor_type.shape.dim
    # ``or 1`` substitutes 1 for a zero dim_value, leaving other values intact.
    return [d.dim_value or 1 for d in dims] if dims else []

def get_node_input_shape(node, value_info: Dict, initializers: Dict) -> list:
    """Return one shape per entry of ``node.input``, in input order.

    Args:
        node: Graph node whose ``input`` names are looked up.
        value_info: Mapping from tensor name to a ValueInfoProto-like object;
            consulted first.
        initializers: Mapping from tensor name to an initializer exposing
            ``dims``; consulted when the name is not in ``value_info``.

    Returns:
        A list with exactly ``len(node.input)`` shapes. An input whose shape
        is not known (including an omitted optional input) is represented by
        an empty list, so ``result[i]`` always describes ``node.input[i]``.
    """
    shapes = []
    for input_name in node.input:
        if input_name in value_info:
            shapes.append(get_shape_from_value_info(value_info[input_name]))
        elif input_name in initializers:
            shapes.append(list(initializers[input_name].dims))
        else:
            # Keep positional alignment: silently dropping an unresolved input
            # would shift later inputs (e.g. the weight tensor) into slot 0.
            shapes.append([])
    return shapes

def get_node_output_shape(node, value_info: Dict) -> list:
    """Collect the shapes of a node's outputs that appear in ``value_info``.

    Output names missing from ``value_info`` are skipped, so the result may
    be shorter than ``node.output``.
    """
    return [
        get_shape_from_value_info(value_info[name])
        for name in node.output
        if name in value_info
    ]

def calculate_conv_macs_params(node, input_shape, output_shape) -> Tuple[int, int]:
    """Estimate MACs and parameter count for a single Conv node.

    Args:
        node: Conv node. Its ``kernel_shape`` and ``group`` attributes are
            read (defaulting to ``[1, 1]`` and ``1`` when absent) and its
            input list is checked for a third, non-empty (bias) input.
        input_shape: Shape of the data input, indexed as (N, C_in, *spatial).
        output_shape: Shape of the output, indexed as (N, C_out, *spatial).

    Returns:
        ``(macs, params)``. MACs cover the weight multiplications for one
        sample (the batch dimension is not applied and the bias add is not
        counted). Params are the weights plus ``C_out`` bias values when the
        node has a bias input. ``(0, 0)`` is returned when either shape has
        fewer than three dimensions.
    """
    attrs = {attr.name: attr for attr in node.attribute}

    kernel_shape = attrs['kernel_shape'].ints if 'kernel_shape' in attrs else [1, 1]
    group = attrs['group'].i if 'group' in attrs else 1

    # Need at least (N, C, one spatial dim) on both sides.
    if len(input_shape) < 3 or len(output_shape) < 3:
        return 0, 0

    in_channels = input_shape[1]
    out_channels = output_shape[1]

    # Multiply over every kernel / spatial dimension so 1-D, 2-D and 3-D
    # convolutions are all handled, not just the 2-D case.
    kernel_volume = 1
    for extent in kernel_shape:
        kernel_volume *= extent

    output_positions = 1
    for extent in output_shape[2:]:
        output_positions *= extent

    weight_params = in_channels * kernel_volume * out_channels // group
    macs = weight_params * output_positions

    # The bias is an optional third input; only count it when it is present.
    has_bias = len(node.input) > 2 and node.input[2] != ''
    params = weight_params + (out_channels if has_bias else 0)

    return macs, params

def analyze_model_metrics(model_path: str) -> Dict:
    """Load an ONNX model and estimate its MACs, parameter count and GFLOPs.

    Only ``Conv`` and ``Gemm`` nodes contribute MACs/params; every other op
    type is still counted in ``layer_stats`` but with zero cost. Nodes whose
    input or output shapes cannot be resolved also contribute zero.

    Args:
        model_path: Path handed to ``onnx.load``.

    Returns:
        Dict with ``total_macs``, ``total_params``, ``total_gflops``
        (two FLOPs per MAC, divided by 1e9) and ``layer_stats`` (per op type:
        ``count``, ``macs``, ``params``).
    """
    model = onnx.load(model_path)
    graph = model.graph

    initializers = {init.name: init for init in graph.initializer}

    # Build the name -> value-info lookup. Later sources win on name clashes:
    # graph inputs, then graph outputs, then intermediate value_info entries.
    value_info = {}
    for collection in (graph.input, graph.output, graph.value_info):
        for info in collection:
            value_info[info.name] = info

    # Best effort: inferred shapes override the recorded ones, but a failure
    # here only produces a warning and the analysis continues.
    try:
        model_with_shapes = onnx.shape_inference.infer_shapes(model)
        for info in model_with_shapes.graph.value_info:
            value_info[info.name] = info
    except Exception as e:
        print(f"Warning: Shape inference failed: {e}")

    total_macs = 0
    total_params = 0
    layer_stats = defaultdict(lambda: {'count': 0, 'macs': 0, 'params': 0})

    for node in graph.node:
        op_type = node.op_type
        layer_stats[op_type]['count'] += 1

        input_shapes = get_node_input_shape(node, value_info, initializers)
        output_shapes = get_node_output_shape(node, value_info)

        if not input_shapes or not output_shapes:
            continue

        macs = params = 0

        if op_type == 'Conv':
            macs, params = calculate_conv_macs_params(node, input_shapes[0], output_shapes[0])
        elif op_type == 'Gemm':
            a_shape = input_shapes[0]
            y_shape = output_shapes[0]
            # The output needs a second dimension to read out_features from;
            # a rank-1 output is skipped instead of raising IndexError.
            if len(a_shape) > 0 and len(y_shape) >= 2:
                # int(...) keeps numpy scalars (and the float 1.0 that
                # np.prod returns for an empty slice) out of the totals.
                in_features = int(np.prod(a_shape[1:]))  # Handle flattened inputs
                out_features = y_shape[1]
                macs = in_features * out_features
                # The bias is an optional third input; count it only if present.
                has_bias = len(node.input) > 2 and node.input[2] != ''
                params = macs + (out_features if has_bias else 0)

        layer_stats[op_type]['macs'] += macs
        layer_stats[op_type]['params'] += params
        total_macs += macs
        total_params += params

    total_flops = total_macs * 2
    gflops = total_flops / 1e9

    return {
        'total_macs': total_macs,
        'total_params': total_params,
        'total_gflops': gflops,
        'layer_stats': dict(layer_stats)
    }

def format_number(num: int) -> str:
    """Format a count with a K, M or G suffix and two decimal places.

    Values of at least 1e9 are shown in billions ("G"), values of at least
    1e6 in millions ("M"), and everything smaller in thousands ("K").
    """
    if num >= 1e9:
        return f"{num/1e9:.2f}G"
    if num >= 1e6:
        return f"{num/1e6:.2f}M"
    # Scale by 1e3 so the printed value actually matches the "K" suffix.
    return f"{num/1e3:.2f}K"

def print_model_metrics(model_path: str):
    """Analyse the model at ``model_path`` and print a summary report.

    Prints the totals followed by a per-op-type table. Any exception raised
    while analysing or printing is caught and reported as a single error
    line instead of propagating to the caller.
    """
    print(f"\n{'='*20} Model Metrics {'='*20}")

    try:
        report = analyze_model_metrics(model_path)

        # Overall totals.
        print(f"\nTotal MACs: {format_number(report['total_macs'])}")
        print(f"Total Parameters: {format_number(report['total_params'])}")
        print(f"Total GFLOPs: {report['total_gflops']:.2f}")

        # Per-op-type table, sorted by op type name.
        print(f"\n{'='*20} Layer Statistics {'='*20}")
        print(f"{'Layer Type':<15} {'Count':<8} {'MACs':<12} {'Params':<12}")
        print("-" * 47)

        for layer_type, stats in sorted(report['layer_stats'].items()):
            count = stats['count']
            if count <= 0:
                continue
            macs_text = format_number(stats['macs'])
            params_text = format_number(stats['params'])
            print(f"{layer_type:<15} {count:<8} {macs_text:<12} {params_text:<12}")
    except Exception as e:
        print(f"Error analyzing model: {e}")

# Entry point: analyse one model and print its metrics report.
if __name__ == "__main__":
    # NOTE(review): machine-specific absolute Windows path -- adjust before
    # running on another machine.
    model_path =  r"C:\Users\user\Desktop\AI_npu\code\runs\pruning_onnx\tooth.onnx"
    print_model_metrics(model_path)



Total MACs: 4.04G
Total Parameters: 3.01M
Total GFLOPs: 8.08

Layer Type      Count    MACs         Params      
-----------------------------------------------
Add             9        0.00K        0.00K       
Concat          19       0.00K        0.00K       
Constant        22       0.00K        0.00K       
Conv            64       4.04G        3.01M       
Div             2        0.00K        0.00K       
Gather          1        0.00K        0.00K       
MaxPool         3        0.00K        0.00K       
Mul             60       0.00K        0.00K       
Reshape         5        0.00K        0.00K       
Resize          2        0.00K        0.00K       
Shape           1        0.00K        0.00K       
Sigmoid         58       0.00K        0.00K       
Slice           2        0.00K        0.00K       
Softmax         1        0.00K        0.00K       
Split           9        0.00K        0.00K       
Sub             2        0.00K        0.00K       
Transpose       1    

In [4]:
0.76679 / 3.01104

0.254659519634412

In [7]:
8.1 /6.5

1.2461538461538462