In [1]:
import tensorflow as tf
import google.protobuf

In [2]:
def count_flops(chkpoint_dir_path):
    """Count the multiply-adds performed by the convolutions of a trained graph.

    Kernel shapes are read from the latest checkpoint in ``chkpoint_dir_path``.
    Per-op output shapes are read from the ``_output_shapes`` attribute stored in
    ``<chkpoint_dir_path>/graph.pbtxt``. Only ``Conv2D`` and
    ``DepthwiseConv2dNative`` ops are counted. For each one the count is::

        Conv2D:     kernel[0] * kernel[1] * kernel[2] * (product of output dims 1..3)
        Depthwise:  kernel[0] * kernel[1]             * (product of output dims 1..3)

    This is one multiply-add per kernel tap per output element, for a single
    image: the batch dimension ``output[0]`` is ignored.

    Args:
        chkpoint_dir_path: directory holding the checkpoint files and
            ``graph.pbtxt``.

    Returns:
        The total count. It is also printed, both raw and in millions.

    Raises:
        ValueError: if ``chkpoint_dir_path`` contains no checkpoint.
        KeyError: if a convolution's kernel variable is not in the checkpoint.
        IndexError: if a counted op has no recorded ``_output_shapes``.
        AssertionError: if a kernel input is not of the form ``<var>/read``.
    """
    import os
    # Import the submodule explicitly: `import google.protobuf` alone does not
    # bind `google.protobuf.text_format`. The original code only worked because
    # TensorFlow happens to import it internally.
    from google.protobuf import text_format

    chkpoint_path = tf.train.latest_checkpoint(chkpoint_dir_path)
    if chkpoint_path is None:
        # Fail clearly here rather than passing None on to list_variables.
        raise ValueError("no checkpoint found in %s" % chkpoint_dir_path)
    # Variable name -> shape, as stored in the checkpoint.
    variables = dict(tf.train.list_variables(chkpoint_path))

    def parse_output_shape(node):
        # Dimension sizes of the node's first output, taken from `_output_shapes`.
        return [dim.size for dim in node.attr['_output_shapes'].list.shape[0].dim]

    def get_var_name(name):
        # Kernel inputs are named "<var>/read"; strip the suffix to get the checkpoint key.
        name = str(name)
        assert name.endswith('/read'), "invalid name: %s" % name
        return name[:-len('/read')]

    # Close the file deterministically instead of leaking the handle.
    with open(os.path.join(chkpoint_dir_path, "graph.pbtxt"), "r") as graph_file:
        graph = text_format.Parse(graph_file.read(), tf.GraphDef())

    flops = []
    # Iterate the repeated `node` field directly. ListFields()[0] is only the
    # node list when the graph actually contains nodes.
    for node in graph.node:
        if node.op == "DepthwiseConv2dNative":
            kernel = variables[get_var_name(node.input[1])]
            output = parse_output_shape(node)
            flops.append(kernel[0] * kernel[1] * output[1] * output[2] * output[3])
        elif node.op == "Conv2D":
            kernel = variables[get_var_name(node.input[1])]
            output = parse_output_shape(node)
            flops.append(kernel[0] * kernel[1] * kernel[2] * output[1] * output[2] * output[3])

    total = sum(flops)
    print(total, "Flops =", total / 10**6, "MFlops")
    return total

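# FLOPs of the auxiliary networks

`count_flops` counts only `Conv2D` and depthwise convolutions, as multiply-adds for a single image (the batch dimension is ignored).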

In [3]:
# Auxiliary network checkpoint: nus64_128 run.
count_flops(chkpoint_dir_path='datasets/cityscapes/exp/nus64_128_mobilenetv2_lr0.007_mt0.25/train')

17600064 Flops = 17.600064 MFlops


17600064

In [4]:
# Auxiliary network checkpoint: nus32_128 run.
count_flops(chkpoint_dir_path='datasets/cityscapes/exp/nus32_128_mobilenetv2_lr0.007_mt1/train')

124646464 Flops = 124.646464 MFlops


124646464

# FLOPs of the base network
## 128x128

In [5]:
# Base network checkpoint (no resolution suffix in the run name).
count_flops(chkpoint_dir_path='datasets/cityscapes/exp/train_on_train_set_mobilenetv2_w_nus/train')

161547264 Flops = 161.547264 MFlops


161547264

## 160x160

In [6]:
# Base network checkpoint: "_at_160" run.
count_flops(chkpoint_dir_path='datasets/cityscapes/exp/train_on_train_set_mobilenetv2_w_nus_at_160/train')

252371520 Flops = 252.37152 MFlops


252371520

## 192x192

In [7]:
# Base network checkpoint: "_at_192" run.
count_flops(chkpoint_dir_path='datasets/cityscapes/exp/train_on_train_set_mobilenetv2_w_nus_at_192/train')

363378944 Flops = 363.378944 MFlops


363378944

## 224x224

In [8]:
# Base network checkpoint: "_at_224" run.
count_flops(chkpoint_dir_path='datasets/cityscapes/exp/train_on_train_set_mobilenetv2_w_nus_at_224/train')

494569536 Flops = 494.569536 MFlops


494569536