In [1]:
import dataset
import glob
import matplotlib.pyplot as plt
import metrics
import os
import time

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds

In [2]:
# 경고 메시지 제거
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# 사용할 모델 선택

In [3]:
#from models.cnn import model_name, batch_size, preprocess_dataset, get_model
from models.efficientnetb7 import model_name, batch_size, preprocess_dataset, get_model

# 데이터 불러오기

In [4]:
#train_sql = test_sql = "SELECT FileLocation, Gender, Age, Dialect FROM json \
#        WHERE Gender != 'NotProvided' and Age != 'NotProvided' and Dialect != 'NotProvided' \
#        and FileLocation not like '%zzmt%' \
#        and cast(FileLength as real) > 6 and cast(FileLength as real) < 8 \
#        ORDER BY random()"
#train_ds, val_ds, test_ds = dataset.load_database(train_sql, test_sql)
train_ds, val_ds, test_ds = dataset.load_dataset("voice_dialect_ds:1.0.0")

In [5]:
train_ds = preprocess_dataset(train_ds)
val_ds = preprocess_dataset(val_ds)
test_ds = preprocess_dataset(test_ds)

train_ds = train_ds.batch(batch_size, drop_remainder=True)
val_ds = val_ds.batch(batch_size, drop_remainder=True)
test_ds = test_ds.batch(batch_size, drop_remainder=True)

train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
val_ds = val_ds.prefetch(tf.data.experimental.AUTOTUNE)
test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)

In [6]:
for audio, labels in train_ds.take(1):
    print(audio.shape)

(16, 498, 32, 3)


# 콜백 설정 (Checkpoint, Earlystop)

In [7]:
callbacks = []

CHECKPOINT_DIR_PREFIX = os.path.join("checkpoint", model_name)
os.makedirs(CHECKPOINT_DIR_PREFIX, exist_ok=True)

latest_checkpoint_path = os.path.join(CHECKPOINT_DIR_PREFIX, "latest.ckpt")
latest_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=latest_checkpoint_path,
    save_weights_only=True,
    save_freq="epoch"
)
callbacks.append(latest_checkpoint_callback)

current_time = time.strftime("%y%m%d-%H%M%S")
checkpoint_path = os.path.join(CHECKPOINT_DIR_PREFIX, current_time, "cp-{epoch:04d}.ckpt")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq="epoch"
)
callbacks.append(checkpoint_callback)

log_path = os.path.join(CHECKPOINT_DIR_PREFIX, f"{model_name}-{current_time}.csv")
logger_callback = tf.keras.callbacks.CSVLogger(
    log_path, separator=',', append=True
)
callbacks.append(logger_callback)

earlystop_callback = tf.keras.callbacks.EarlyStopping(
    verbose=1,
    patience=10,
    restore_best_weights=True
)
callbacks.append(earlystop_callback)

auto_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    verbose=1,
    min_lr=0.0001
)
callbacks.append(auto_lr_callback)

# 모델 불러오기

In [8]:
age_metrics = [
    'accuracy',
    tf.keras.metrics.TopKCategoricalAccuracy(2, name="top_2_accuracy"),
    tf.keras.metrics.TopKCategoricalAccuracy(3, name="top_3_accuracy"),
    tf.keras.metrics.AUC(),
    metrics.PerClassPrecision(len(dataset.age_list)),
    metrics.PerClassRecall(len(dataset.age_list)),
    tfa.metrics.F1Score(len(dataset.age_list)),
    metrics.ConfusionMatrix(len(dataset.age_list)),
]

dialect_metrics = [
    'accuracy',
    tf.keras.metrics.AUC(),
    metrics.PerClassPrecision(len(dataset.dialect_list)),
    metrics.PerClassRecall(len(dataset.dialect_list)),
    tfa.metrics.F1Score(len(dataset.dialect_list)),
    metrics.ConfusionMatrix(len(dataset.age_list))
]

gender_metrics = [
    'accuracy',
    tf.keras.metrics.AUC(),
    tf.keras.metrics.Precision(),
    tf.keras.metrics.Recall(),
    tfa.metrics.F1Score(1, threshold=0.5)
]

model = get_model()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.005),
          loss={
              "age": "categorical_crossentropy",
              "dialect": "categorical_crossentropy",
              "gender": "binary_crossentropy"
          },
          metrics={
              "age": age_metrics,
              "dialect": dialect_metrics,
              "gender": gender_metrics,
          })

In [9]:
if len(glob.glob(latest_checkpoint_path + "*")) > 0:
    model.load_weights(latest_checkpoint_path)

In [None]:
model.summary()

# 훈련 시작

In [7]:
history = model.fit(train_ds,
          validation_data=val_ds,
          callbacks=callbacks,
          epochs=30,
                   )

NameError: name 'model' is not defined

In [12]:
metrics = history.history
plt.plot(history.epoch, metrics['loss'], metrics['val_loss'])
plt.legend(['loss', 'val_loss'])
plt.show()

NameError: name 'history' is not defined

In [None]:
model.save_weights(latest_checkpoint_path)

In [10]:
model.evaluate(test_ds)



[2.447340488433838,
 1.5965542793273926,
 0.6703541874885559,
 0.18043281137943268,
 0.41318997740745544,
 0.6782734394073486,
 0.8383343815803528,
 0.7878989577293396,
 array([0.37922558, 0.41363427, 0.35721248, 0.33059126, 0.35137895,
        0.57949096], dtype=float32),
 array([0.47099236, 0.37432262, 0.3171787 , 0.3215    , 0.33692458,
        0.6540383 ], dtype=float32),
 array([0.42015663, 0.3929978 , 0.33600736, 0.32598227, 0.34399998,
        0.61451197], dtype=float32),
 array([[ 617,  388,  152,   58,   44,   51],
        [ 500,  898,  565,  202,  141,   93],
        [ 249,  600,  733,  381,  216,  132],
        [ 154,  156,  370,  643,  432,  245],
        [  76,   86,  147,  426,  688,  619],
        [  31,   43,   85,  235,  437, 1571]]),
 0.8102535009384155,
 0.9611518383026123,
 array([0.96613127, 0.66274023, 0.84806776, 0.75265557, 0.7580175 ,
        0.7832536 ], dtype=float32),
 array([0.98856735, 0.69235754, 0.89597315, 0.82735616, 0.50855744,
        0.86613756], dt