In [None]:
# torch is otherwise first imported in the following cell, so import it here
# to make this cell runnable on its own / in order.
import torch

print("Torch version:", torch.__version__)

In [None]:
import math
import typing
from collections import Counter

import torch
import clip
from fvcore.nn import flop_count
from fvcore.nn.jit_handles import batchnorm_flop_jit, generic_activation_jit, get_shape

# Helper function to compute the product of elements in a list
def prod(lst):
    """Return the product of the numbers in ``lst``.

    An empty iterable yields 1. The FLOP handles below use this to turn a
    tensor shape into an element count.
    """
    return math.prod(lst)

# Custom JIT handle definitions
def generic_pooling_jit(name, multiplier=1):
    """Build a FLOP-counting handle for pooling-style operators.

    The returned handle charges one operation per input element plus one per
    output element. It scales that count by ``multiplier`` and reports it under
    the key ``name``.

    Args:
        name: Key used in the Counter returned by the handle.
        multiplier: Scale factor applied to the element count.

    Returns:
        A callable ``(inputs, outputs) -> Counter[str]`` with the signature
        expected for entries of ``flop_count(..., supported_ops=...)``.
    """

    def pool_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
        input_shape = get_shape(inputs[0])
        output_shape = get_shape(outputs[0])
        # Only 2-D to 5-D inputs are accepted by this handle.
        assert 2 <= len(input_shape) <= 5, input_shape
        flop = prod(input_shape) + prod(output_shape)
        return Counter({name: flop * multiplier})

    # pool_jit already has the (inputs, outputs) signature, so return it
    # directly instead of wrapping it in a forwarding lambda.
    return pool_jit

def softmax_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
    """FLOP handle for ``aten::softmax``.

    Charges two operations per input element plus one per output element, and
    reports the total under the key ``"softmax"``. (Presumably the two per-input
    ops approximate exponentiation and summation; this is an estimate, not an
    exact count.)
    """
    n_in = prod(get_shape(inputs[0]))
    n_out = prod(get_shape(outputs[0]))
    return Counter({"softmax": 2 * n_in + n_out})

def bmm_flop_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
    """FLOP handle for a batched matrix multiply of two rank-3 operands.

    The operands must be shaped ``(n, c, t)`` and ``(n, t, d)``; the handle
    asserts both ranks are 3 and that the batch and inner dimensions agree.
    The cost is ``n * c * t * d``, reported under the key ``"bmm"``. Because of
    the rank-3 check, this handle does not fit ops whose operands have other
    ranks.
    """
    lhs = get_shape(inputs[0])
    rhs = get_shape(inputs[1])
    assert len(lhs) == len(rhs) == 3
    assert lhs[0] == rhs[0] and lhs[2] == rhs[1], [lhs, rhs]
    # prod(lhs) == n * c * t; multiplying by d gives one op per multiply-add.
    return Counter({"bmm": prod(lhs) * rhs[-1]})

# Wrapper class for CLIP model
class ForwardWrapper(torch.nn.Module):
    """Minimal ``nn.Module`` that forwards all positional arguments to ``model``.

    The script wraps the CLIP image encoder in this module before handing it to
    ``flop_count``.
    """

    def __init__(self, model):
        super(ForwardWrapper, self).__init__()
        # Stored as an attribute; when ``model`` is an nn.Module, torch registers
        # it as a child module of the wrapper.
        self.model = model

    def forward(self, *inputs):
        wrapped = self.model
        return wrapped(*inputs)

# Load the CLIP model on CPU and make sure it is in inference mode.
model, preprocess = clip.load("ViT-B/32", device="cpu")
model.eval()

# Wrap the CLIP image encoder (model.visual) for FLOP counting.
model_wrapper = ForwardWrapper(model.visual)

# Prepare a sample input: a batch containing a single 3x224x224 image.
# NOTE(review): 224 is assumed to match the encoder's input resolution;
# confirm when switching to another CLIP variant.
image = torch.randn(1, 3, 224, 224)

# Extra or overriding FLOP handles, applied on top of fvcore's defaults.
supported_ops = {
    "aten::batch_norm": batchnorm_flop_jit,
    "aten::relu": generic_activation_jit("relu"),
    # bmm_flop_jit asserts both operands are rank-3, which holds for aten::bmm.
    # It must NOT be registered for aten::linear: linear's weight is rank-2, so
    # that assertion would fail. fvcore's built-in aten::linear handle is used.
    "aten::bmm": bmm_flop_jit,
    "aten::adaptive_avg_pool2d": generic_pooling_jit("adaptive_avg_pool2d"),
    "aten::softmax": softmax_jit,
    # Add more operations as needed based on the CLIP model architecture
}

# Perform FLOP counting. No gradients are needed while tracing.
with torch.no_grad():
    # flop_count returns (GFLOPs per operator, Counter of ops with no handle).
    flops, skipped_ops = flop_count(model_wrapper, inputs=(image,), supported_ops=supported_ops)

# `flops` maps operator name -> GFLOPs, so sum it for the total.
# fvcore counts one fused multiply-add as one flop.
total_gflops = sum(flops.values())
print(f"Total GFLOPs for a single inference pass: {total_gflops:.3f}")
if skipped_ops:
    print(f"Ops without a FLOP handle (not counted): {dict(skipped_ops)}")
