# CNN Model parameter comparison

In [21]:
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_hub as hub
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd

In [3]:
print(f'TensorFlow version: {tf.__version__}')
print(f'TensorFlow Federated version: {tff.__version__}')

TensorFlow version: 2.5.1
TensorFlow Federated version: 0.19.0


In [48]:
IMAGE_SHAPE = (224, 224, 3)

In [49]:
MODELS = {
    'Inception V3': 'https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4',
    'MobileNet V2': 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
    'EfficientNet B7': 'https://tfhub.dev/tensorflow/efficientnet/b7/feature-vector/1',
    'ResNet152 V2': 'https://tfhub.dev/google/imagenet/resnet_v2_152/feature_vector/5',
}

In [50]:
result_all = dict()
result_freeze = dict()
for name, url in MODELS.items():
    for i in range(1, 11):
        feature_extractor_url = url
        feature_extractor_layer = hub.KerasLayer(url, input_shape=IMAGE_SHAPE)

        model = tf.keras.Sequential([
            feature_extractor_layer,
            tf.keras.layers.Dense(i, activation='softmax')
        ])
        
        feature_extractor_layer.trainable = True
        trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights])
        data = result_all.get(name, [])
        data.append(trainable_count)
        result_all[name] = data
        
        feature_extractor_layer.trainable = False
        trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights])
        data = result_freeze.get(name, [])
        data.append(trainable_count)
        result_freeze[name] = data

In [56]:
df_all = pd.DataFrame(result_all)
df_freeze = pd.DataFrame(result_freeze)

In [57]:
df_all - df_freeze

Unnamed: 0,Inception V3,MobileNet V2,EfficientNet B7,ResNet152 V2
0,21768352,2223872,63786960,58187904
1,21768352,2223872,63786960,58187904
2,21768352,2223872,63786960,58187904
3,21768352,2223872,63786960,58187904
4,21768352,2223872,63786960,58187904
5,21768352,2223872,63786960,58187904
6,21768352,2223872,63786960,58187904
7,21768352,2223872,63786960,58187904
8,21768352,2223872,63786960,58187904
9,21768352,2223872,63786960,58187904
