In [1]:
import fvcore
import torch
import fvcore.nn
import random

In [2]:
from flop_count.flop_handlers import elementwise_flop_jit, transpose_flop_jit, softmax_flop_jit, gelu_flop_jit

In [13]:
# Custom per-operator FLOP handlers, keyed by ``aten::`` operator name.
# The handler functions come from ``flop_count.flop_handlers``; presumably these
# are ops the stock fvcore analysis does not count on its own — TODO confirm.
flop_handlers = {
    op_name: elementwise_flop_jit
    for op_name in ("aten::add", "aten::sub", "aten::div")
}
flop_handlers.update(
    (
        ("aten::numpy_T", transpose_flop_jit),
        ("aten::softmax", softmax_flop_jit),
        ("aten::gelu", gelu_flop_jit),
    )
)


class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis):
    """``fvcore.nn.FlopCountAnalysis`` with the module-level ``flop_handlers`` registered.

    Args:
        model: module to analyse; passed straight through to fvcore.
        inputs: example input(s) for ``model``; passed straight through to fvcore.
    """

    def __init__(self, model, inputs):
        super().__init__(model, inputs)
        # Attach every custom handler to this analysis instance.
        handlers = flop_handlers
        self.set_op_handle(**handlers)


In [None]:
from models.vringformer import VRingFormer, VRingFormer_CONFIGS

# Display the configuration entry used for the base ring model.
ring_config = VRingFormer_CONFIGS['VRingFormer-B_16']
ring_config

In [16]:
# Build the base ring model in inference mode (``eval()`` returns the module
# itself), so the FLOP trace follows the eval-time forward path — e.g. any
# dropout / normalisation layers behave as they would at inference.
model = VRingFormer(VRingFormer_CONFIGS['VRingFormer-B_16']).eval()

In [None]:
# One dummy 3x224x224 image (batch of 1). A locally seeded generator keeps the
# input reproducible across runs, without touching the global RNG; this matters
# if the model's traced path depends on input values.
# Named ``dummy_input`` to avoid shadowing the builtin ``input``.
dummy_input = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(0))

ring_flops = FlopCountAnalysis(model, dummy_input)

# total() is the raw count for the whole model; divide by 1e9 for GFLOPs
print("Ring Model GFLOPS: ", ring_flops.total() / 1e9)

In [9]:
from models.universal_transformer import UiT, UiT_CONFIGS

In [None]:
# Display the configuration entry used for the UiT base model.
uit_config = UiT_CONFIGS['UiT-B_16']
uit_config

In [11]:
# Build the UiT base model in inference mode (``eval()`` returns the module
# itself), so the FLOP trace follows the eval-time forward path.
uit_model = UiT(UiT_CONFIGS['UiT-B_16']).eval()

In [None]:
# One dummy 3x224x224 image (batch of 1). A locally seeded generator keeps the
# input reproducible across runs, without touching the global RNG; this matters
# if the model's traced path depends on input values.
# Named ``dummy_input`` to avoid shadowing the builtin ``input``.
dummy_input = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(0))

uit_flops = FlopCountAnalysis(uit_model, dummy_input)

# total() is the raw count for the whole model; divide by 1e9 for GFLOPs
print("UiT Model GFLOPS: ", uit_flops.total() / 1e9)

In [13]:
from models.vanilla_transformer import VisionTransformer, ViT_CONFIGS

In [None]:
# Display the configuration entry used for the ViT base model.
vit_config = ViT_CONFIGS['ViT-B_16']
vit_config

In [15]:
# Build the ViT base model in inference mode (``eval()`` returns the module
# itself), so the FLOP trace follows the eval-time forward path.
vit_model = VisionTransformer(ViT_CONFIGS['ViT-B_16']).eval()

In [None]:
# One dummy 3x224x224 image (batch of 1). A locally seeded generator keeps the
# input reproducible across runs, without touching the global RNG; this matters
# if the model's traced path depends on input values.
# Named ``dummy_input`` to avoid shadowing the builtin ``input``.
dummy_input = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(0))

vit_flops = FlopCountAnalysis(vit_model, dummy_input)

# total() is the raw count for the whole model; divide by 1e9 for GFLOPs
print("ViT Model GFLOPS: ", vit_flops.total() / 1e9)

In [19]:
from models.one_wide_feed_forward import OWF, OWF_CONFIGS

In [20]:
# Build the OWF base model in inference mode (``eval()`` returns the module
# itself), so the FLOP trace follows the eval-time forward path.
owf_model = OWF(OWF_CONFIGS['OWF-B_16']).eval()

In [None]:
# One dummy 3x224x224 image (batch of 1). A locally seeded generator keeps the
# input reproducible across runs, without touching the global RNG; this matters
# if the model's traced path depends on input values.
# Named ``dummy_input`` to avoid shadowing the builtin ``input``.
dummy_input = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(0))

owf_flops = FlopCountAnalysis(owf_model, dummy_input)

# total() is the raw count for the whole model; divide by 1e9 for GFLOPs
print("OWF Model GFLOPS: ", owf_flops.total() / 1e9)