# Get the FLOPs for various model architectures studied in the paper

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
cd ..

/home/yiboyang/projects/code_releases/shallow-ntc


In [3]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [4]:
import tensorflow as tf
import tensorflow_compression as tfc
import numpy as np
import pandas as pd
from absl import logging
from ml_collections import ConfigDict

import json
import inspect

In [5]:
from common.profile_utils import get_flops

In [6]:
from common import transforms

In [7]:
img_shape = [1, 512, 768, 3]
x = tf.zeros(img_shape)
npixels = np.prod(img_shape[1:3])

2023-10-03 19:38:59.977750: E tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [8]:
all_flops = dict()

In [9]:
all_params = dict()
def count_params(module):
  return np.sum([np.prod(v.get_shape().as_list()) for v in module.trainable_variables])

In [10]:
# Trick to getting flops counter to work with custom keras model/module
# (wrap the module in keras functional API):
def get_model_with_known_input(module, input_shape):
  input_layer = tf.keras.Input(shape=input_shape[1:])  # Input wants the shape of input tensor without batch dim
  tmp_model = tf.keras.Model(inputs=input_layer, outputs=module(input_layer))
  return tmp_model

## Baseline - Factorized prior

In [11]:
channels_base = 192
bottleneck_size = 320

ana = transforms.CNNAnalysis(channels_base=channels_base, output_channels=bottleneck_size)
y = ana(x)
print('y shape', y.shape)

syn = transforms.CNNSynthesis(channels_base=channels_base, output_channels=3)
xhat = syn(y)
print('xhat shape', xhat.shape)

# Set FLOPs
method = 'Ballé 2017 Factorized Prior'
all_flops[method] = {}
all_params[method] = {}

for (key, T) in zip(['f', 'g'], (ana, syn)):
  flops = get_flops(T, batch_size=1).total_float_ops
  all_flops[method][key] = flops
  all_params[method][key] = count_params(T)
print(all_flops[method])
print(all_params[method])

y shape (1, 32, 48, 320)
xhat shape (1, 512, 768, 3)
Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`
{'f': 64198115328, 'g': 64198803456}
{'f': 3394496, 'g': 3394179}


## Baseline - Mean-Scale Hyperprior

In [12]:
channels_base = 192
bottleneck_size = 320

ana_act = 'gdn'
ana = transforms.CNNAnalysis(channels_base=channels_base, output_channels=bottleneck_size, activation_type=ana_act)
y = ana(x)
print('y shape', y.shape)

hana = transforms.HyperAnalysis(bottleneck_size)

z = hana(y)
print('z shape', z.shape)

hsyn = transforms.HyperSynthesis(bottleneck_size)
py_params = hsyn(z)
syn_act = 'igdn'
syn = transforms.CNNSynthesis(channels_base=channels_base, output_channels=3, activation_type=syn_act)
xhat = syn(y)
print('xhat shape', xhat.shape)

# Set FLOPs
method = 'Minnen 2018 Hyperprior'
all_flops[method] = {}
all_params[method] = {}

for (key, T) in zip(['f', 'f_h', 'g', 'g_h'], (ana, hana, syn, hsyn)):
  flops = get_flops(T, batch_size=1).total_float_ops
  all_flops[method][key] = flops
  all_params[method][key] = count_params(T)


y shape (1, 32, 48, 320)
z shape (1, 8, 12, 320)
xhat shape (1, 512, 768, 3)


In [13]:
all_params

{'Ballé 2017 Factorized Prior': {'f': 3394496, 'g': 3394179},
 'Minnen 2018 Hyperprior': {'f': 3431552,
  'f_h': 6042560,
  'g': 3431235,
  'g_h': 9166240}}

## Linearized CNNSynthesis, for comparison with JPEG-like syn

In [14]:
channels_base = 192
bottleneck_size = 320

ana_act = None
ana = transforms.CNNAnalysis(channels_base=channels_base, output_channels=bottleneck_size, activation_type=ana_act)
y = ana(x)
print('y shape', y.shape)

hana = transforms.HyperAnalysis(bottleneck_size)

z = hana(y)
print('z shape', z.shape)


syn_act = None
syn = transforms.CNNSynthesis(channels_base=channels_base, output_channels=3, activation_type=syn_act)
xhat = syn(y)
print('xhat shape', xhat.shape)

flops = get_flops(syn, batch_size=1).total_float_ops

y shape (1, 32, 48, 320)
z shape (1, 8, 12, 320)
xhat shape (1, 512, 768, 3)


In [15]:
flops / npixels

163266.0

In [16]:
flops / npixels / 2   # in MACs

81633.0

## Proposed (1-layer) JPEG-like synthesis, with varying kernel sizes

In [17]:
jpeg_syn_flops = {}
for k in [16, 18, 20, 26, 32]:
  jpeg_syn = tf.keras.Sequential(tf.keras.layers.Conv2DTranspose(3, kernel_size=k, strides=16, padding='SAME', use_bias=True, input_shape=[32, 48, 320]))
  flops = get_flops(jpeg_syn, batch_size=1).total_float_ops
  print(k, ', parameter count =', count_params(jpeg_syn))
  jpeg_syn_flops[k] = flops

16 , parameter count = 245763
18 , parameter count = 311043
20 , parameter count = 384003
26 , parameter count = 648963
32 , parameter count = 983043


In [18]:
all_params['JPEG-like syn. (proposed)'] = {'g': 311043}

In [19]:
jpeg_syn_flops

{16: 756154368, 18: 956694528, 20: 1180827648, 26: 1994784768, 32: 3021078528}

In [20]:
# (pd.DataFrame(jpeg_syn_flops, index=['flops'])/ npixels).to_csv('results/jpeg_syn_fpp.csv')

## Proposed 2-layer syn

In [21]:
channels_base = 192
bottleneck_size = 320
C1 = 12  # for the hidden layer size

from common.elic import ElicAnalysis, ElicSynthesis

ana = tf.keras.Sequential([
  ElicAnalysis(channels=[192, 192, 192, 320])
])
y = ana(x)
print('y shape', y.shape)

hana = transforms.HyperAnalysis(bottleneck_size)

z = hana(y)
print('z shape', z.shape)

hsyn = transforms.HyperSynthesis(bottleneck_size)
py_params = hsyn(z)
syn_act = 'igdn'
syn = transforms.TwoLayerResSynthesis(channels=(C1, 3), strides=(8, 2),
                                              kernel_sizes=(13, 5), activation_type=syn_act, res_type='conv')
# flops counter crashes on a non-standard model like the above
# with a residual connection. To get around this, we do the
# following trick using keras functional API. 
syn_input = tf.keras.Input(shape=y.shape[1:])
syn = tf.keras.Model(inputs=syn_input, outputs=syn(syn_input))
xhat = syn(y)
print('xhat shape', xhat.shape)

print('synthesis parameter count =', count_params(syn))


# Set FLOPs
method = '2-layer syn. (proposed)'
all_flops[method] = {}
all_params[method] = {}

for (key, T) in zip(['f', 'f_h', 'g', 'g_h'], (ana, hana, syn, hsyn)):
  flops = get_flops(T, batch_size=1).total_float_ops
  all_flops[method][key] = flops
  all_params[method][key] = count_params(T)

y shape (1, 32, 48, 320)
z shape (1, 8, 12, 320)
xhat shape (1, 512, 768, 3)
synthesis parameter count = 1299003


In [22]:
all_params['2-layer syn. + SGA (proposed)'] = all_params['2-layer syn. (proposed)'].copy()

## Comparable 2-layer syn *without* residual connection, just two layers of conv transposed

In [23]:
channels_base = 192
bottleneck_size = 320
C1 = 24  # for the hidden layer size of 2layerRes

y = tf.random.normal([1, 32, 48, 320])
syn_act = 'igdn'
syn = transforms.TwoLayerSynthesis(channels=(C1, 3), strides=(8, 2),
                                              kernel_sizes=(13, 5), activation_type=syn_act)

# To trick flops counter to work using functional API
syn_input = tf.keras.Input(shape=y.shape[1:])
syn = tf.keras.Model(inputs=syn_input, outputs=syn(syn_input))

flops = get_flops(syn, batch_size=1).total_float_ops
flops, flops / npixels

(4462610184, 11349.004577636719)

In [24]:
count_params(syn)

1300347

In [25]:
all_flops['2-layer syn. (proposed)']['g'] / npixels

10677.001190185547

In [26]:
print('Equivalent no-residual syn uses', 11349 - 10680, 'more FLOPs per pixel')

Equivalent no-residual syn uses 669 more FLOPs per pixel


## For ELIC

In [27]:
channels_base = 192
bottleneck_size = 320
from common.elic import ElicAnalysis, ElicSynthesis

ana = tf.keras.Sequential([
  ElicAnalysis(channels=[192, 192, 192, 320])
])
y = ana(x)
print('y shape', y.shape)

hana = transforms.HyperAnalysis(bottleneck_size)

z = hana(y)
print('z shape', z.shape)

hsyn = transforms.HyperSynthesis(bottleneck_size)
py_params = hsyn(z)
syn = tf.keras.Sequential([
  ElicSynthesis(channels=[192, 192, 192, 3])
])
xhat = syn(y)
print('xhat shape', xhat.shape)

# Set FLOPs
method = 'He 2022 ELIC'
all_flops[method] = {}
all_params[method] = {}

for (key, T) in zip(['f', 'f_h', 'g', 'g_h'], (ana, hana, syn, hsyn)):
  flops = get_flops(T, batch_size=1).total_float_ops
  all_flops[method][key] = flops
  all_params[method][key] = count_params(T)

y shape (1, 32, 48, 320)
z shape (1, 8, 12, 320)
xhat shape (1, 512, 768, 3)


In [28]:
### back-of-envelope calculation for the CHARM component of the ELIC hyperprior

In [29]:
# Modified from ms2020 to allow custom slice_depth
import functools
class SliceTransform(tf.keras.layers.Layer):
    """Transform for channel-conditional params and latent residual prediction."""

    def __init__(self, slice_depth=None, latent_depth=None, num_slices=None):
        super().__init__()
        conv = functools.partial(
            tfc.SignalConv2D, corr=False, strides_up=1, padding="same_zeros",
            use_bias=True, kernel_parameter="variable")

        # Note that the number of channels in the output tensor must match the
        # size of the corresponding slice. If we have 10 slices and a bottleneck
        # with 320 channels, the output is 320 / 10 = 32 channels.
        if slice_depth is None:
          slice_depth = latent_depth // num_slices
          if slice_depth * num_slices != latent_depth:
              raise ValueError("Slices do not evenly divide latent depth (%d / %d)" % (
                  latent_depth, num_slices))

        self.transform = tf.keras.Sequential([
            conv(224, (5, 5), name="layer_0", activation=tf.nn.relu),
            conv(128, (5, 5), name="layer_1", activation=tf.nn.relu),
            conv(slice_depth, (3, 3), name="layer_2", activation=None),
        ])

    def call(self, tensor):
        return self.transform(tensor)



In [30]:
latent_means, latent_scales = tf.split(py_params, 2, axis=-1)

In [31]:
latent_means.shape

TensorShape([1, 32, 48, 320])

In [32]:
num_slices = 5
max_support_slices = 5

latent_depth = 320

# y_slices = tf.split(y, num_slices, axis=-1)   # CHARM
# cc_mean_transforms = [SliceTransform(latent_depth=latent_depth, num_slices=num_slices) for _ in range(num_slices)]

elic_slice_depths = [16, 16, 32, 64, latent_depth-128]
y_slices = tf.split(y, elic_slice_depths, axis=-1)  # ELIC
cc_mean_transforms = [SliceTransform(slice_depth=sd) for sd in elic_slice_depths]
cc_mean_transforms_Models = []  # tmp keras Models for computing flops


y_hat_slices = []

for slice_index, y_slice in enumerate(y_slices):
    # Model may condition on only a subset of previous slices.
    support_slices = (y_hat_slices if max_support_slices < 0 else
                      y_hat_slices[:max_support_slices])

    # Predict mu and sigma for the current slice.
    mean_support = tf.concat([latent_means] + support_slices, axis=-1)
    mu = cc_mean_transforms[slice_index](mean_support)
    cc_mean_transforms_Models.append(get_model_with_known_input(cc_mean_transforms[slice_index], mean_support.shape))
    # mu = mu[:, :y_shape[0], :y_shape[1], :]

    # # Note that in this implementation, `sigma` represents scale indices,
    # # not actual scale values.
    # scale_support = tf.concat([latent_scales] + support_slices, axis=-1)
    # sigma = self.cc_scale_transforms[slice_index](scale_support)
    y_hat_slice = mu
    
    y_hat_slices.append(y_hat_slice)


In [33]:
elic_flops_across_slices = [get_flops(T.transform, batch_size=1).total_float_ops for T in cc_mean_transforms]



In [34]:
np.array(elic_flops_across_slices) / npixels

array([19745.4375, 20445.4375, 21289.5   , 22977.625 , 26930.125 ])

In [35]:
elic_flops_across_slices = [get_flops(T, batch_size=1).total_float_ops for T in cc_mean_transforms_Models]

In [36]:
np.array(elic_flops_across_slices) / npixels

array([19745.4375, 20445.4375, 21289.5   , 22977.625 , 26930.125 ])

### Manually add to hyper synthesis FLOPs; not worrying about the checkerboard stuff for now

In [37]:
all_flops['He 2022 ELIC']['g_h'] = all_flops['Minnen 2018 Hyperprior']['g_h'] + np.sum(elic_flops_across_slices) * 2 # x2 for both mean and scale

### Add a few more methods with derived stats

In [38]:
method = 'JPEG-like syn. (proposed)'
all_flops[method] = all_flops['Minnen 2018 Hyperprior'].copy()  # COPY!
all_flops[method]['f'] = all_flops['He 2022 ELIC']['f']
all_flops[method]['g'] = jpeg_syn_flops[18]

In [39]:
method = '2-layer syn. + SGA (proposed)'
all_flops[method] = all_flops['2-layer syn. (proposed)'].copy()

### Add EVC based on paper Fig. 2 of the paper https://openreview.net/pdf?id=XUxad2Gj40n

In [40]:
method = 'Wang 2023 EVC' # corresponding to the "Large" configuration
all_flops[method] = dict(
  f=549.92 * 1e9 * 2, g=538.83 * 1e9 * 2, f_h=3.89 * 1e9 * 2, g_h=(44.68+28.06) * 1e9 * 2)
all_flops[method] = {k: (v / 1920 / 1088) * npixels for (k,v) in all_flops[method].items()}

all_params[method] = {'f': 3.19e6, 'g': 3.38e6}  # From Table 3 in the Appendix

# Note the paper listed MACs for 1920 x 1088 imgs; hence my conversion above.

## Save results for paper

In [41]:
import pandas as pd

In [42]:
all_fpp = pd.DataFrame(all_flops) / npixels
all_fpp = all_fpp.transpose()

In [43]:
all_fpp

Unnamed: 0,f,g,f_h,g_h
Ballé 2017 Factorized Prior,163264.25,163266.0,,
Minnen 2018 Hyperprior,187583.098145,187584.848145,13451.640625,30354.6875
2-layer syn. (proposed),510563.75,10677.00119,13451.640625,30354.6875
He 2022 ELIC,510563.75,510565.5,13451.640625,253130.9375
JPEG-like syn. (proposed),510563.75,2433.0,13451.640625,30354.6875
2-layer syn. + SGA (proposed),510563.75,10677.00119,13451.640625,30354.6875
Wang 2023 EVC,526501.22549,515883.501838,3724.341299,69642.310049


In [44]:
all_fpp['f_tot'] = all_fpp['f'] + all_fpp['f_h']
all_fpp['g_tot'] = all_fpp['g'] + all_fpp['g_h']

In [45]:
all_fpp.to_csv('results/all_fpp.csv')

In [46]:
all_fpp = pd.read_csv('results/all_fpp.csv', index_col=0)

In [47]:
all_fpp

Unnamed: 0,f,g,f_h,g_h,f_tot,g_tot
Ballé 2017 Factorized Prior,163264.25,163266.0,,,,
Minnen 2018 Hyperprior,187583.098145,187584.848145,13451.640625,30354.6875,201034.73877,217939.535645
2-layer syn. (proposed),510563.75,10677.00119,13451.640625,30354.6875,524015.390625,41031.68869
He 2022 ELIC,510563.75,510565.5,13451.640625,253130.9375,524015.390625,763696.4375
JPEG-like syn. (proposed),510563.75,2433.0,13451.640625,30354.6875,524015.390625,32787.6875
2-layer syn. + SGA (proposed),510563.75,10677.00119,13451.640625,30354.6875,524015.390625,41031.68869
Wang 2023 EVC,526501.22549,515883.501838,3724.341299,69642.310049,530225.566789,585525.811887


In [48]:
all_params_df = pd.DataFrame(all_params)
all_params_df = all_params_df.transpose()
all_params_df

Unnamed: 0,f,g,f_h,g_h
Ballé 2017 Factorized Prior,3394496.0,3394179.0,,
Minnen 2018 Hyperprior,3431552.0,3431235.0,6042560.0,9166240.0
JPEG-like syn. (proposed),,311043.0,,
2-layer syn. (proposed),7337792.0,1299003.0,6042560.0,9166240.0
2-layer syn. + SGA (proposed),7337792.0,1299003.0,6042560.0,9166240.0
He 2022 ELIC,7337792.0,7337475.0,6042560.0,9166240.0
Wang 2023 EVC,3190000.0,3380000.0,,


In [49]:
all_params_df.to_csv('results/all_params.csv')