## 데이터셋 생성

In [None]:
import time
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt

In [None]:
# Configuration: dataset root, loader alias, and label encoding.
root_dir = "/mnt/d/Exa/227.건설 현장 위험 상태 판단 데이터/01-1.정식개방데이터/Minisample/"
makeset = tf.keras.utils.image_dataset_from_directory
# Index 0 selects "binary", index 1 selects "categorical".
class_type = ["binary", "categorical"][1]

# Build the training and validation datasets from the "train" and
# "validation" subdirectories of root_dir.
batch_size = 64
image_size = (224, 224)
seed = 9140183  # arbitrary fixed seed
train_set = makeset(
    f"{root_dir}train",
    image_size=image_size,
    batch_size=batch_size,
    # Shuffling is requested from the loader itself. A bare
    # `train_set.shuffle(...)` statement would have no effect, because
    # Dataset.shuffle returns a new dataset instead of modifying in place.
    shuffle=True,
    seed=seed,
    label_mode=class_type,
)
valid_set = makeset(
    f"{root_dir}validation",
    image_size=image_size,
    batch_size=batch_size,
    seed=seed,
    label_mode=class_type,
)

In [None]:
# Sanity-check the dataset: look at a single training batch and confirm that
# the label width agrees with the configured class_type.
category_num = 0
for images, labels in train_set:
    # The label shape is unpacked as (batch, width); the width is used as the
    # number of categories.
    _, category_num = labels.shape
    print(f"data batch size: {images.shape} / label batch size: {labels.shape}")
    break  # one batch is enough for the check

# Explicit raises instead of `assert`, so the checks still run under
# `python -O`. The exception type (AssertionError) and messages are unchanged.
if category_num == 1:
    if class_type != "binary":
        raise AssertionError("분류 설정 오류")
elif category_num >= 2:
    if class_type != "categorical":
        raise AssertionError("분류 설정 오류")
else:
    # category_num is still 0 when the loop body never ran (no batches).
    raise Exception("검증 데이터셋 오류")

In [None]:
# Instantiate EfficientNetB0 without pretrained weights (weights=None), i.e.
# train from scratch.
ref_model = tf.keras.applications.EfficientNetB0(weights=None)

# Re-wire the network so it ends at the second-to-last layer of ref_model;
# presumably this discards the stock classification layer -- confirm against
# the Keras EfficientNetB0 layer list.
base_model = tf.keras.models.Model(
    inputs=ref_model.input,
    outputs=ref_model.layers[-2].output,
)

x = base_model.output
# Pick the output head: one sigmoid unit for binary labels, otherwise one
# softmax unit per category.
head_units, head_activation = (
    (1, "sigmoid") if class_type == "binary" else (category_num, "softmax")
)
predictions = layers.Dense(head_units, activation=head_activation)(x)

model = tf.keras.models.Model(inputs=base_model.input, outputs=predictions)
# The loss name is derived from the label mode:
# "binary_crossentropy" or "categorical_crossentropy".
loss_function = f"{class_type}_crossentropy"
model.compile(loss=loss_function, optimizer="rmsprop", metrics=["accuracy"])

In [None]:
# Benchmark: train for 5 epochs and report the elapsed time.
# perf_counter() is a monotonic, high-resolution clock, so the measurement is
# not disturbed by system clock adjustments the way time.time() can be.
start = time.perf_counter()
history = model.fit(
    train_set,
    epochs=5,
    validation_data=valid_set,
)
end = time.perf_counter()

# Split the elapsed seconds into whole minutes and remaining seconds.
minutes, seconds = divmod(end - start, 60)
print(f"Time spent: {int(minutes)}m {seconds:.3f}s")

In [None]:
# Plot training vs. validation accuracy per epoch from the fit history.
train_acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
epoch_axis = range(1, len(train_acc) + 1)  # epochs are numbered from 1

plt.figure()
# (values, matplotlib format string, legend label) for each curve:
# "bo" draws blue dots, "b" draws a blue line.
for values, fmt, legend_label in (
    (train_acc, "bo", "Training accuracy"),
    (val_acc, "b", "Validation accuracy"),
):
    plt.plot(epoch_axis, values, fmt, label=legend_label)
plt.legend()