In [6]:
import tensorflow as tf
import numpy as np
import importlib.util

In [7]:
def load_module(path):
    """Import a Python source file from an explicit filesystem path.

    The module is executed and returned but is not registered in
    ``sys.modules``; every module loaded this way gets the same
    placeholder name ``"module.name"``.

    Args:
        path: Location of the source file to import.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no import spec/loader can be built for ``path``
            (for example, the file suffix is not a recognised Python
            module suffix).
        Exception: Anything raised while reading or executing the file
            propagates unchanged.
    """
    spec = importlib.util.spec_from_file_location("module.name", path)
    # spec_from_file_location returns None when no loader matches the path;
    # fail here with a clear message instead of an AttributeError later on.
    if spec is None or spec.loader is None:
        raise ImportError(
            "cannot build an import spec for {!r}".format(path),
            path=str(path),
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

def load_arch(arch_path, bands):
    """Build the network defined in an architecture file inside its own graph.

    The file at ``arch_path`` is imported with ``load_module``, its ``CNN``
    class is instantiated, and ``create_architecture(bands=bands)`` is called
    while a newly created ``tf.Graph`` is the default graph.

    Args:
        arch_path: Path to a Python file exposing a ``CNN`` class.
        bands: Forwarded as the ``bands`` keyword of ``create_architecture``.

    Returns:
        A ``(graph, network)`` tuple: the new ``tf.Graph`` and the ``CNN``
        instance.
    """
    network = load_module(arch_path).CNN()

    graph = tf.Graph()
    # NOTE(review): presumably create_architecture adds its ops/variables to
    # the current default graph -- confirm against the arch_*.py files.
    with graph.as_default():
        network.create_architecture(bands=bands)

    return graph, network

In [8]:
# Only the graph (element 0 of load_arch's result) is kept for each
# architecture; every one is built with bands=1.
baseline = load_arch("arch_baseline.py", 1)[0]
views = load_arch("arch_views.py", 1)[0]
invariant = load_arch("arch_invariant.py", 1)[0]
residual = load_arch("arch_residual.py", 1)[0]

In [9]:
def count(graph):
    """Return the total number of elements across a graph's trainable variables.

    Args:
        graph: Object exposing ``get_collection``; it is queried with
            ``tf.GraphKeys.TRAINABLE_VARIABLES`` and each returned variable
            must provide ``get_shape().as_list()``.

    Returns:
        int: Sum over all trainable variables of the product of their shape
        dimensions. A scalar variable (empty shape) counts as one element;
        a graph without trainable variables yields 0.
    """
    variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    total = 0
    for variable in variables:
        shape = variable.get_shape().as_list()
        # np.prod([]) defaults to the float 1.0, which would turn the whole
        # total into a float; force an integer product instead.
        total += int(np.prod(shape, dtype=np.int64))
    return total

In [10]:
# Trainable-parameter totals, in order: baseline, views, invariant, residual.
tuple(count(graph) for graph in (baseline, views, invariant, residual))

(5674145, 7035308, 1753705, 4863021)