In [None]:
import numpy as np
import torch
from torch import nn

def analyze_model_complexity(model, input_size=(1, 3, 224, 224)):
    """
    Analyze the size and compute cost of a PyTorch model.

    Parameter counts and memory are read from ``model.parameters()``. The
    compute estimate comes from one dummy forward pass (in eval mode, without
    autograd) with forward hooks on every ``nn.Conv2d`` and ``nn.Linear``
    layer. Other layer types (activations, normalization, pooling,
    element-wise ops) are not counted, and bias additions are ignored.

    Args:
        model: PyTorch ``nn.Module`` to analyze. The train/eval flag of every
            submodule is restored before returning, and all hooks are removed
            even if the forward pass raises.
        input_size: Shape of the dummy input tensor, e.g.
            (batch_size, channels, height, width). The compute estimate scales
            with batch_size.

    Returns:
        dict with keys:
            "total_parameters": number of parameter elements.
            "trainable_parameters": elements of parameters with requires_grad.
            "model_size_mb": parameter storage in MB (1024 * 1024 bytes);
                buffers such as BatchNorm running stats are not included.
            "approximate_flops": int count of multiply-accumulate operations
                (MACs) performed by the hooked Conv2d/Linear layers.
            "architecture_analysis": dict of booleans flagging grouped
                (e.g. depthwise) convolutions, SEBlock modules and
                InvertedResidual modules.
    """
    params = list(model.parameters())
    modules = list(model.modules())

    # Count parameters
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)

    # Parameter memory in MB
    model_size = sum(p.numel() * p.element_size() for p in params) / (1024 * 1024)

    # Multiply-accumulate counter shared by the hooks below.
    macs = 0

    def count_macs(module, inputs, output):
        nonlocal macs
        if isinstance(module, nn.Linear):
            # Each output element is a dot product over in_features inputs.
            macs += output.numel() * module.in_features
        else:
            # Each conv output element reads (in_channels / groups) * kh * kw
            # inputs; PyTorch guarantees in_channels is divisible by groups,
            # so integer division is exact.
            kh, kw = module.kernel_size
            macs += output.numel() * (module.in_channels // module.groups) * kh * kw

    hooks = [
        m.register_forward_hook(count_macs)
        for m in modules
        if isinstance(m, (nn.Conv2d, nn.Linear))
    ]

    # Build the dummy input on the model's device/dtype so GPU or
    # half-precision models do not fail on a CPU float32 tensor.
    first_param = params[0] if params else None
    device = first_param.device if first_param is not None else None
    dtype = (
        first_param.dtype
        if first_param is not None and first_param.is_floating_point()
        else None
    )

    # Run in eval mode: in train mode BatchNorm updates its running stats even
    # under no_grad, which would silently alter the model being analyzed.
    saved_modes = [(m, m.training) for m in modules]
    model.eval()
    try:
        with torch.no_grad():
            model(torch.rand(input_size, device=device, dtype=dtype))
    finally:
        for m, mode in saved_modes:
            m.training = mode
        for hook in hooks:
            hook.remove()

    return {
        "total_parameters": total_params,
        "trainable_parameters": trainable_params,
        "model_size_mb": model_size,
        "approximate_flops": macs,
        "architecture_analysis": {
            "use_depthwise_separable": any(isinstance(m, nn.Conv2d) and m.groups > 1 for m in modules),
            "use_squeeze_excitation": any(isinstance(m, SEBlock) for m in modules),
            "inverted_residuals": any(isinstance(m, InvertedResidual) for m in modules),
        },
    }
# Example usage: instantiate the network and build its complexity report.
model = WMobNetv2()
metrics = analyze_model_complexity(model)

# Assemble every report line first, then emit them in order.
summary_lines = [
    "\nModel Complexity Analysis:",
    f"Total Parameters: {metrics['total_parameters']:,}",
    f"Trainable Parameters: {metrics['trainable_parameters']:,}",
    f"Model Size: {metrics['model_size_mb']:.2f} MB",
    f"Approximate FLOPs: {metrics['approximate_flops']:,}",
    "\nArchitecture Features:",
]
# One "- name: Yes/No" line per architectural feature flag.
summary_lines.extend(
    f"- {feature}: {'Yes' if present else 'No'}"
    for feature, present in metrics['architecture_analysis'].items()
)
for report_line in summary_lines:
    print(report_line)


Model Complexity Analysis:
Total Parameters: 1,138,876
Trainable Parameters: 1,138,876
Model Size: 4.34 MB
Approximate FLOPs: 122,490,672.0

Architecture Features:
- use_depthwise_separable: Yes
- use_squeeze_excitation: Yes
- inverted_residuals: Yes
