In [1]:
from torchtnt.utils.flops import FlopTensorDispatchMode
from torch import nn     
import torch
import copy
from brainaudio.models.transformer import Transformer

  import pkg_resources
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
neural_dim = 256
kernel_len = 32
layer_dim = 5
hidden_dim = 1024

tf_hidden_dim = 384
tf_layers = 5
tf_num_heads = 6
tf_head_size = 64
tf_mlp_dim_ratio = 4
tf_dropout = 0

In [12]:
gru_decoder = nn.GRU(
            (neural_dim) * kernel_len,
            hidden_dim,
            layer_dim,
            batch_first=True,
            dropout=0,
            bidirectional=False
        )

fc_decoder_out = nn.Linear(hidden_dim, 40 + 1) 


gru_decoder_bidirectional = nn.GRU(
            (neural_dim) * kernel_len,
            hidden_dim,
            layer_dim,
            batch_first=True,
            dropout=0,
            bidirectional=True
        )

fc_decoder_out_bi = nn.Linear(hidden_dim*2, 40 + 1) 


tf_model = Transformer(384, 5, 6, 64, 4, 
                                    0, use_relative_bias=True)


fc_decoder_out = nn.Linear(hidden_dim, 40 + 1) 

fc_decoder_out_2 = nn.Linear(384, 40 + 1) 

In [4]:
inputs = torch.randn(1,500,256) # 10 seconds of input, each bin is 20 ms, 20 ms * 500 = 10 sec

In [7]:
daySpecific = nn.Linear(256, 256)
with FlopTensorDispatchMode(daySpecific) as ftdm:
    res = daySpecific(inputs)
    flops_forward = copy.deepcopy(ftdm.flop_counts)

    
total_flops = sum(                # outer sum
    sum(inner.values())           #  ← inner sum
    for inner in flops_forward.values()
)

print(f"Total forward FLOPs: {total_flops:,}")     
print(f"≈ {total_flops/1e6/10:.2f} MFLOPs")   

Total forward FLOPs: 32,768,000
≈ 3.28 MFLOPs


In [9]:
"""
Checking the shape after strided inputs 
should be ((X_len - kernel_len) / strideLen) + 1
((500 - 32) / 4) + 1 = 118.
"""
unfolder = torch.nn.Unfold(
            (kernel_len, 1), dilation=1, padding=0, stride=4
        )
stridedInputs = torch.permute(
            unfolder(
                torch.unsqueeze(torch.permute(inputs, (0, 2, 1)), 3)
            ),
            (0, 2, 1),
        )
stridedInputs.shape


torch.Size([1, 118, 8192])

In [10]:
stridedInputs = torch.randn(1,118, neural_dim*kernel_len)
with FlopTensorDispatchMode(gru_decoder) as ftdm:
    res_gru, _ = gru_decoder(stridedInputs)
    flops_forward = copy.deepcopy(ftdm.flop_counts)


    
total_flops = sum(                # outer sum
    sum(inner.values())           #  ← inner sum
    for inner in flops_forward.values()
)

print(f"Total forward FLOPs: {total_flops:,}")        # e.g. 53,502,976
print(f"≈ {total_flops/1e6/10:.2f} MFLOPs")              # or /1e9 for GFLOPs

with FlopTensorDispatchMode(fc_decoder_out) as ftdm:
    res2 = fc_decoder_out(res_gru)
    flops_forward = copy.deepcopy(ftdm.flop_counts)

    
total_flops = sum(                # outer sum
    sum(inner.values())           #  ← inner sum
    for inner in flops_forward.values()
)

print(f"Total forward FLOPs: {total_flops:,}")        # e.g. 53,502,976
print(f"≈ {total_flops/1e6/10:.2f} MFLOPs")              # or /1e9 for GFLOPs

Total forward FLOPs: 6,310,330,368
≈ 631.03 MFLOPs
Total forward FLOPs: 4,954,112
≈ 0.50 MFLOPs


In [14]:
stridedInputs = torch.randn(1,118, neural_dim*kernel_len)
with FlopTensorDispatchMode(gru_decoder) as ftdm:
    res_gru, _ = gru_decoder_bidirectional(stridedInputs)
    flops_forward = copy.deepcopy(ftdm.flop_counts)

    
total_flops = sum(                # outer sum
    sum(inner.values())           #  ← inner sum
    for inner in flops_forward.values()
)

print(f"Total forward FLOPs: {total_flops:,}")        # e.g. 53,502,976
print(f"≈ {total_flops/1e6/10:.2f} MFLOPs")              # or /1e9 for GFLOPs

with FlopTensorDispatchMode(fc_decoder_out_bi) as ftdm:
    res2 = fc_decoder_out_bi(res_gru)
    flops_forward = copy.deepcopy(ftdm.flop_counts)

    
total_flops = sum(                # outer sum
    sum(inner.values())           #  ← inner sum
    for inner in flops_forward.values()
)

print(f"Total forward FLOPs: {total_flops:,}")        # e.g. 53,502,976
print(f"≈ {total_flops/1e6/10:.2f} MFLOPs")              # or /1e9 for GFLOPs

Total forward FLOPs: 15,590,227,968
≈ 1559.02 MFLOPs
Total forward FLOPs: 9,908,224
≈ 0.99 MFLOPs


In [None]:
print(f"MFLOPS FOR UNDIRECTIONAL GRU: {631.03 + 0.50 + 3.28}")

MFLOPS FOR UNDIRECTIONAL GRU: 634.81


In [15]:
print(f"MFLOPS FOR BIDIRECTIONAL GRU: {1559.02  + 0.99 + 3.28}")

MFLOPS FOR BIDIRECTIONAL GRU: 1563.29


In [16]:
tf_inputs = torch.randn(1,500,256*5)
patch_transform = nn.Linear(256*5, 384)
with FlopTensorDispatchMode(patch_transform) as ftdm:
    patched_inputs = patch_transform(tf_inputs)
    #seq_out = fc_decoder_out_2(res)
    flops_forward = copy.deepcopy(ftdm.flop_counts)
    
total_flops = sum(                # outer sum
    sum(inner.values())           #  ← inner sum
    for inner in flops_forward.values()
)

print(f"Total forward FLOPs: {total_flops:,}")        # e.g. 53,502,976
print(f"≈ {total_flops/1e6/10:.2f} MFLOPs")              # or /1e9 for GFLOPs

Total forward FLOPs: 245,760,000
≈ 24.58 MFLOPs


In [24]:
tf_inputs = torch.randn(1,100,384)
model = Transformer(384, 5, 6, 64, 4)

with FlopTensorDispatchMode(model) as ftdm:
    res = model(tf_inputs)
    #seq_out = fc_decoder_out_2(res)
    flops_forward = copy.deepcopy(ftdm.flop_counts)
    
total_flops = sum(                # outer sum
    sum(inner.values())           #  ← inner sum
    for inner in flops_forward.values()
)

print(f"Total forward FLOPs: {total_flops:,}")        # e.g. 53,502,976
print(f"≈ {total_flops/1e6/10:.2f} MFLOPs")              # or /1e9 for GFLOPs

Total forward FLOPs: 3,394,560,000
≈ 339.46 MFLOPs


In [None]:
with FlopTensorDispatchMode(fc_decoder_out_2) as ftdm:
    res2 = fc_decoder_out_2(res)
    #seq_out = fc_decoder_out(res)
    flops_forward = copy.deepcopy(ftdm.flop_counts)

    
total_flops = sum(                # outer sum
    sum(inner.values())           #  ← inner sum
    for inner in flops_forward.values()
)

print(f"Total forward FLOPs: {total_flops:,}")        # e.g. 53,502,976
print(f"≈ {total_flops/1e6/10:.2f} MFLOPs")              # or /1e9 for GFLOPs