# Libraries

In [None]:
# @author: innat
import os, warnings 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings("ignore")

import tensorflow as tf
# Best-effort GPU setup: enable on-demand memory growth on the first GPU and
# turn on XLA JIT compilation. Both are skipped when no GPU is visible.
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    try:
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
        tf.config.optimizer.set_jit(True)
    except (ValueError, RuntimeError) as err:
        # set_memory_growth raises ValueError for an invalid device and
        # RuntimeError once the runtime is already initialised. Report the
        # problem and carry on instead of silently hiding every exception
        # (a bare `except:` would also swallow KeyboardInterrupt).
        print(f"GPU configuration skipped: {err}")

from tensorflow.keras import applications
from tensorflow import keras

# Build Backbone for DOLG

Create a backbone model with multiple outputs: one for the local branch and the other for the global branch.

In [3]:
img_size   = 128
num_classe = 10

# EfficientNetB0 without its classification head and without pretrained
# weights, fed with square RGB images of side `img_size`.
base = applications.EfficientNetB0(
    input_tensor=keras.Input((img_size, img_size, 3)),
    include_top=False,
    weights=None,
)

# Two intermediate feature maps are exposed as the backbone outputs.
local_features = base.get_layer('block5c_add').output          # for local branch
global_features = base.get_layer('block7a_project_bn').output  # for global branch

new_base = keras.Model(
    inputs=[base.inputs],
    outputs=[local_features, global_features],
    name='base_model',
)

# Create DOLGNet 

We use the previously built backbone to build the DOLGNet.

In [8]:
from models.DOLG import DOLGNet

# Wrap the two-output backbone in DOLGNet with a softmax head over
# `num_classe` classes.
dolg_net = DOLGNet(new_base, num_classes=num_classe, activation='softmax')

# build_graph() presumably returns a functional keras.Model for inspection
# (confirm in models/DOLG.py); print its layer summary.
graph_model = dolg_net.build_graph()
graph_model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 base_model (Functional)        [(None, 8, 8, 112),  3634851     ['input_3[0][0]']                
                                 (None, 4, 4, 320)]                                               
                                                                                                  
 LocalBranch (DOLGLocalBranch)  (None, 4, 4, 1024)   4582656     ['base_model[1][0]']             
                                                                                            

# Dummy Train with MNIST

In [7]:
# prepare data 
def mnist_process(x, y):
    """Convert one MNIST sample into the format the network is fed with.

    The image is cast to float32, given a trailing channel axis that is
    repeated three times, divided by 255 and resized to
    ``img_size x img_size``. The label is one-hot encoded with
    ``num_classe`` entries.

    Args:
        x: a single image tensor without a channel axis.
        y: the integer class label for that image.

    Returns:
        An ``(image, one_hot_label)`` tuple.
    """
    image = tf.cast(x, dtype=tf.float32)
    image = tf.expand_dims(image, axis=-1)
    image = tf.repeat(image, repeats=3, axis=-1)
    image = tf.divide(image, 255)
    image = tf.image.resize(image, [img_size, img_size])

    label = tf.one_hot(y, depth=num_classe)
    return image, label

# Only the training split is used; the test split is discarded.
(x_train, y_train), _ = keras.datasets.mnist.load_data()

# Preprocess each sample, shuffle through a 100-element buffer, batch by 16.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(mnist_process)
    .shuffle(buffer_size=100)
    .batch(16)
)

# execute model
dolg_net.compile(
    optimizer='adam',
    loss=keras.losses.CategoricalCrossentropy(),  # labels are one-hot encoded
    metrics=['accuracy'],
)

# A single epoch: this is a smoke test, not a real training run.
dolg_net.fit(train_ds, epochs=1)



<keras.callbacks.History at 0x1a3ef874220>

# Custom Model with DOLGNet's Layers

In [19]:
# general 
from layers.GeM import GeneralizedMeanPooling2D

# special for dolgnet 
from layers.LocalBranch import DOLGLocalBranch
from layers.OrtholFusion import OrthogonalFusion

In [27]:
img_shape = 512

# Stem: two 3x3 convolutions on a single-channel square image.
vision_input = keras.Input(shape=(img_shape, img_shape, 1), name="img")
x = keras.layers.Conv2D(filters=16, kernel_size=3, activation="relu")(vision_input)
x = keras.layers.Conv2D(filters=32, kernel_size=3, activation="relu")(x)

# Local branch. Its output is kept in `y` for the fusion step and also
# continues down the pooled path below.
x = DOLGLocalBranch(IMG_SIZE=img_shape)(x)
y = x

# Pooled path: downsample, convolve, GeM-pool, then project to 1024 units.
x = keras.layers.MaxPooling2D(pool_size=3)(x)
x = keras.layers.Conv2D(filters=32, kernel_size=3, activation="relu")(x)
x = keras.layers.Conv2D(filters=16, kernel_size=3, activation="relu")(x)
gem_pool = GeneralizedMeanPooling2D()(x)
gem_dens = keras.layers.Dense(units=1024, activation=None)(gem_pool)

# Fuse the local feature map with the pooled descriptor.
vision_output = OrthogonalFusion()([y, gem_dens])
vision = keras.Model(inputs=vision_input, outputs=vision_output, name="vision")
vision.summary(expand_nested=True, line_length=110)

Model: "vision"
______________________________________________________________________________________________________________
 Layer (type)                       Output Shape            Param #      Connected to                         
 img (InputLayer)                   [(None, 512, 512, 1)]   0            []                                   
                                                                                                              
 conv2d_165 (Conv2D)                (None, 510, 510, 16)    160          ['img[0][0]']                        
                                                                                                              
 conv2d_166 (Conv2D)                (None, 508, 508, 32)    4640         ['conv2d_165[0][0]']                 
                                                                                                              
 LocalBranch (DOLGLocalBranch)      (None, 16, 16, 1024)    4029696      ['conv2d_166[0][0]']   