In [1]:
import glob
import os
import time
import warnings

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds

import dataset
import metrics

In [2]:
def test_model(model_class, dataset_name="voice_dialect_ds:1.0.0",
               checkpoint_dir="checkpoint", drop_remainder=True):
    """Build a model, restore its latest checkpoint and evaluate it on the test split.

    Args:
        model_class: A model module (e.g. ``models.resnet``) exposing
            ``preprocess_dataset(ds)``, ``batch_size``, ``get_model()`` and
            ``model_name``.
        dataset_name: Name passed to ``dataset.load_dataset``. The third
            element of its return value is used as the test dataset.
        checkpoint_dir: Root directory for per-model checkpoints. Weights are
            read from ``<checkpoint_dir>/<model_name>/latest.ckpt``.
        drop_remainder: Forwarded to ``Dataset.batch``. The default (True)
            discards the final partial batch, so up to ``batch_size - 1``
            test examples are excluded from the reported metrics.

    Returns:
        Whatever ``model.evaluate`` returns for the batched test dataset
        (a list of loss and metric values ordered as ``model.metrics_names``).

    Warns:
        UserWarning: if no checkpoint file matches the expected prefix. In
        that case the model is evaluated with the weights produced by
        ``model_class.get_model()``.
    """
    _, _, test_ds = dataset.load_dataset(dataset_name)
    test_ds = (
        model_class.preprocess_dataset(test_ds)
        .batch(model_class.batch_size, drop_remainder=drop_remainder)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

    model = model_class.get_model()
    n_ages = len(dataset.age_list)
    n_dialects = len(dataset.dialect_list)
    # compile() is required so evaluate() knows the per-head losses and
    # metrics. evaluate() does not train, so the optimizer settings do not
    # affect the result. Metric order here fixes the order of the returned list.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.005),
        loss={
            "age": "categorical_crossentropy",
            "dialect": "categorical_crossentropy",
            "gender": "binary_crossentropy",
        },
        metrics={
            "age": [
                'accuracy',
                metrics.PerClassAccuracy(n_ages),
                tf.keras.metrics.TopKCategoricalAccuracy(2),
                tf.keras.metrics.TopKCategoricalAccuracy(3, name="top3"),
                tf.keras.metrics.AUC(),
                tf.keras.metrics.Precision(),
                tf.keras.metrics.Recall(),
                tfa.metrics.F1Score(n_ages),
            ],
            "dialect": [
                'accuracy',
                metrics.PerClassAccuracy(n_dialects),
                tf.keras.metrics.AUC(),
                tf.keras.metrics.Precision(),
                tf.keras.metrics.Recall(),
                tfa.metrics.F1Score(n_dialects),
            ],
            "gender": [
                'accuracy',
                tf.keras.metrics.AUC(),
                tf.keras.metrics.Precision(),
                tf.keras.metrics.Recall(),
                tfa.metrics.F1Score(1, threshold=0.5),
            ],
        },
    )

    latest_checkpoint_path = os.path.join(checkpoint_dir, model_class.model_name, "latest.ckpt")
    # Any file starting with this prefix counts as a checkpoint (presumably
    # the .index / .data-* shards of a TF-format checkpoint). Escape the
    # prefix so glob metacharacters in the path are matched literally.
    if glob.glob(glob.escape(latest_checkpoint_path) + "*"):
        model.load_weights(latest_checkpoint_path)
    else:
        # Falling through silently would report metrics for a model that was
        # never restored from training; make that visible to the reader.
        warnings.warn(
            f"No checkpoint found at {latest_checkpoint_path!r}; evaluating "
            f"{model_class.model_name!r} with the weights returned by get_model().")
    return model.evaluate(test_ds)

In [3]:
# Evaluate models.dvector_single with test_model. The cell's output is the
# list returned by model.evaluate on the test dataset.
import models.dvector_single
test_model(models.dvector_single)



In [4]:
# Evaluate models.resnet with test_model. The cell's output is the list
# returned by model.evaluate on the test dataset.
import models.resnet
test_model(models.resnet)



























In [3]:
# Evaluate models.resnet101 with test_model. The output list is unlabeled;
# its entries follow the model's metrics_names order (total loss, the three
# per-head losses, then the age, dialect and gender metrics as compiled).
import models.resnet101
test_model(models.resnet101)



[2.814854860305786,
 1.8369615077972412,
 0.7618625164031982,
 0.21603304147720337,
 0.39273107051849365,
 array([0.4218154 , 0.4202838 , 0.26511225, 0.30896345, 0.3651493 ,
        0.56559765], dtype=float32),
 0.6413671374320984,
 0.7977374792098999,
 0.7582483291625977,
 0.4288657605648041,
 0.3173138499259949,
 array([0.3913659 , 0.42723802, 0.29740858, 0.31738684, 0.34290966,
        0.54103583], dtype=float32),
 0.7991014122962952,
 array([0.99249196, 0.6638601 , 0.8972641 , 0.81135225, 0.48776907,
        0.8514799 ], dtype=float32),
 0.9555674195289612,
 0.8195003271102905,
 0.7842586636543274,
 array([0.9857955 , 0.65789473, 0.8132591 , 0.8053024 , 0.5878537 ,
        0.80389225], dtype=float32),
 0.9420731663703918,
 0.9735222458839417,
 0.9516009092330933,
 0.9415058493614197,
 array([0.94652647], dtype=float32)]

In [6]:
# Evaluate models.efficientnet with test_model. The cell's output is the
# list returned by model.evaluate on the test dataset.
import models.efficientnet
test_model(models.efficientnet)



























In [7]:
# Evaluate models.efficientnetb5 with test_model. The cell's output is the
# list returned by model.evaluate on the test dataset.
import models.efficientnetb5
test_model(models.efficientnetb5)



























In [4]:
# Evaluate models.efficientnetb7 with test_model. The output list is
# unlabeled; its entries follow the model's metrics_names order (total loss,
# the three per-head losses, then the age, dialect and gender metrics as compiled).
import models.efficientnetb7
test_model(models.efficientnetb7)



























[2.438215732574463,
 1.5943219661712646,
 0.6660344004631042,
 0.1778581440448761,
 0.40933889150619507,
 array([0.470229  , 0.3833959 , 0.31231102, 0.3094642 , 0.3318649 ,
        0.64446294], dtype=float32),
 0.6810815334320068,
 0.8390564918518066,
 0.7880060076713562,
 0.4877054691314697,
 0.2880295217037201,
 array([0.42134058, 0.40078503, 0.3295351 , 0.31570882, 0.3362261 ,
        0.60920894], dtype=float32),
 0.8113767504692078,
 array([0.98711985, 0.68996763, 0.9004474 , 0.8315263 , 0.5107527 ,
        0.86627907], dtype=float32),
 0.9614785313606262,
 0.8377140760421753,
 0.792682945728302,
 array([0.97629154, 0.67489713, 0.87714523, 0.78970295, 0.6103972 ,
        0.825277  ], dtype=float32),
 0.9528241157531738,
 0.9762876629829407,
 0.953454315662384,
 0.9602004885673523,
 array([0.95681554], dtype=float32)]