In [1]:
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.config import list_physical_devices
from tensorflow.image import ResizeMethod
from tensorflow.image import resize_with_pad

import utils

In [2]:
trainX, trainY, testX, testY = utils.load_dataset()

In [3]:
def ensemble_avg():
    inputs = keras.Input(shape=(32,32,3))

    all_models_str = ['models/custom_1', 'models/custom_2']
    all_models = [keras.models.load_model(i) for i in all_models_str]
    all_model_output = []

    counter = 0
    for m in all_models:
        m.trainable = False
        all_model_output.append(m(inputs))
        print(str(counter) + ' done')
        counter += 1

    outputs = layers.average(all_model_output)
    ensemble_model = keras.Model(inputs=inputs, outputs=outputs, name='ensemble_avg')
    ensemble_model.compile(loss='categorical_crossentropy', metrics=['categorical_accuracy', keras.metrics.AUC()])
    return ensemble_model

In [4]:
model = ensemble_avg()

0 done
1 done


In [5]:
model.summary()

Model: "ensemble_avg"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
custom_1 (Functional)           (None, 10)           2524938     input_1[0][0]                    
__________________________________________________________________________________________________
custom_2 (Functional)           (None, 10)           924106      input_1[0][0]                    
__________________________________________________________________________________________________
average (Average)               (None, 10)           0           custom_1[0][0]                   
                                                                 custom_2[0][0]        

In [6]:
model.evaluate(x=testX, y=testY)



[0.716093897819519, 0.7791000008583069, 0.9688507914543152]