
Added FLOPs in our paper (Table 8, arXiv v4). #125

Closed
ku21fan opened this issue Dec 18, 2019 · 2 comments

Comments

@ku21fan
Contributor

ku21fan commented Dec 18, 2019

Hello,

We received some requests for the FLOPs of each model, so we calculated them and updated our paper.
In this issue, we summarize the details of our FLOPs calculation.

Our FLOPs values are approximate, because:

  1. Our calculation is mainly based on THOP, which is not official PyTorch code (but a popular one).

  2. From this issue and the README of THOP, THOP seems to count MACs instead of FLOPs, so we simply use # MACs * 2 as # FLOPs.

  3. We have some irregular modules that are not covered by THOP: GridGenerator and LSTM/LSTMCell.
    Thus, we calculate the FLOPs of the GridGenerator module with this code.

def count_GridGenerator(m):
    # sizes (TPS setting, see https://arxiv.org/pdf/1904.01906.pdf)
    num_fiducial_point = 20
    image_height = 32
    image_width = 100

    # we count the Euclidean distance d_ij as 3 MACs, since d_ij = sqrt((c_i - c_j)^2)
    R = num_fiducial_point * num_fiducial_point * 3 * 3  # 3600: 20x20 (size of R) x 3 (square, *, ln) x 3 (d_ij)
    # we count matrix inversion as N^3 MACs
    inv_delta_C = (num_fiducial_point + 3) ** 3  # 12167
    T = (num_fiducial_point + 3) * (num_fiducial_point + 3) * 2  # 1058
    P = image_width * image_height * (num_fiducial_point + 3) * 2  # 147200

    total_ops = R + inv_delta_C + T + P  # 164025, about 0.164M MACs

    m.total_ops += torch.Tensor([int(total_ops)])
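
For reference, the arithmetic above can be checked with a few lines of plain Python (same constants as in the snippet):

# sanity check of the GridGenerator MAC count above
F = 20                          # num_fiducial_point
H, W = 32, 100                  # image height and width

R = F * F * 3 * 3               # 3600
inv_delta_C = (F + 3) ** 3      # 12167
T = (F + 3) * (F + 3) * 2       # 1058
P = W * H * (F + 3) * 2         # 147200
print(R + inv_delta_C + T + P)  # 164025, about 0.164M MACs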

We calculate the FLOPs of LSTM with this code.

def count_LSTM(m, x, y):
    # size
    input_size = x[0].size(-1)
    hidden_state_size = y[0].size(-1)  # = output_size
    cell_state_size = y[0].size(-1)  # = output_size

    # count calculation: https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM
    # we count the sigmoid/tanh activation functions as 0 MACs
    # 3 * hidden_state_size counts the element-wise additions in each gate (two biases + sum of the two matrix products)
    input_gate = input_size * hidden_state_size + hidden_state_size * hidden_state_size \
        + 3 * hidden_state_size
    forget_gate = input_size * hidden_state_size + hidden_state_size * hidden_state_size \
        + 3 * hidden_state_size
    cell_gate = input_size * hidden_state_size + hidden_state_size * hidden_state_size \
        + 3 * hidden_state_size
    output_gate = input_size * hidden_state_size + hidden_state_size * hidden_state_size \
        + 3 * hidden_state_size

    update_cell_state = hidden_state_size + hidden_state_size + hidden_state_size  # c_t = f_t * c_{t-1} + i_t * g_t
    update_hidden_state = hidden_state_size  # h_t = o_t * tanh(c_t)

    total_ops = input_gate + forget_gate + cell_gate + output_gate + update_cell_state + update_hidden_state

    time_step = x[0].size(-2)

    m.total_ops += torch.Tensor([int(total_ops)]) * time_step
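
As a concrete example of the LSTM counter, plugging in input_size = hidden_state_size = 256 (an illustrative size only, not necessarily the size of every model variant) gives the MACs for one time step:

# worked example of count_LSTM for a single time step,
# assuming input_size = hidden_size = 256 (illustrative only)
input_size = hidden_size = 256

per_gate = input_size * hidden_size + hidden_size * hidden_size + 3 * hidden_size  # 131840
update_cell_state = 3 * hidden_size  # c_t = f_t * c_{t-1} + i_t * g_t
update_hidden_state = hidden_size    # h_t = o_t * tanh(c_t)

total_per_step = 4 * per_gate + update_cell_state + update_hidden_state
print(total_per_step)  # 528384 MACs per time step; multiply by the number of time steps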

We attached our modified version of THOP's profile code, and we simply use the code below to calculate FLOPs.

import torch
import model
from thop import profile
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input = torch.randn(1, 1, 32, 100).to(device)
text_for_pred = torch.LongTensor(1, opt.batch_max_length + 1).fill_(0).to(device)  # opt: the argparse options namespace used in this repository
model_ = model.Model(opt).to(device)
MACs, params = profile(model_, inputs=(input, text_for_pred, ))
flops = 2 * MACs # approximate FLOPs
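
If you prefer not to modify THOP's profile code, the same counters could likely be registered through THOP's custom_ops argument instead. Below is a minimal sketch (the GridGenerator import path and the wrapper function are assumptions for illustration, not our attached modification):

# sketch: register the custom counters via THOP's custom_ops
# instead of editing THOP's profile code
import torch.nn as nn
from thop import profile
from modules.transformation import GridGenerator  # assumed import path


def count_GridGenerator_hook(m, x, y):
    # custom_ops counters are called as (module, inputs, output),
    # so wrap the fixed-size count_GridGenerator defined above
    count_GridGenerator(m)


MACs, params = profile(
    model_,
    inputs=(input, text_for_pred),
    custom_ops={GridGenerator: count_GridGenerator_hook, nn.LSTM: count_LSTM},
)
flops = 2 * MACs  # approximate FLOPs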

If you find any issues, please let us know.

Best.

@ku21fan ku21fan closed this as completed Dec 28, 2019
@piyawat-at

According to your paper, is it the number of floating-point operations (FLOPs), not floating-point operations per second (FLOPS)?

@ku21fan
Contributor Author

ku21fan commented Jun 28, 2022

You are right. It is FLOPs rather than FLOPS.

That was my mistake.

Thank you for the comment :)

@ku21fan ku21fan changed the title Added FLOPS in our paper (Table 8, arXiv v4). Added FLOPs in our paper (Table 8, arXiv v4). Jun 28, 2022