In [2]:
import tensorflow as tf
from models import decoder, encoder, vae, classifier
import pandas as pd
import pathlib
from constants import *
from utils import data_loading

In [3]:
# Read the metadata table, then derive the image id lists (all / train /
# validation / test) from it and the data directory.
data_dir = pathlib.Path(DATA_FILEPATH)
apes_info = pd.read_csv(APES_INFO_FILEPATH)
(
    all_images_ids,
    train_ids,
    validation_ids,
    test_ids,
) = data_loading.get_image_ids(apes_info, data_dir)

In [4]:
# Image dataset covering every id returned by get_image_ids.
dataset = data_loading.load_full_dataset(DATA_FILEPATH, IMAGE_SIZE, all_images_ids)

# One feature (label) dataset per split, requested in train / validation / test order.
y_train, y_validation, y_test = (
    data_loading.get_feature_dataset(apes_info, FEATURE_NAMES, split_name)
    for split_name in ("train", "validation", "test")
)

Found 10000 files belonging to 10000 classes.


In [5]:
# Build the VAE from its two halves and restore the saved weights.
decoder_model = decoder.build_decoder(LATENT_DIM)
encoder_model = encoder.build_encoder(LATENT_DIM)
vae_model = vae.VAE(
    encoder_model,
    decoder_model,
    RECONSTRUCTION_LOSS_WEIGHT,
    KL_LOSS_WEIGHT,
)
vae_model.load_weights(MODEL_VAE_FILEPATH)

# The classifier is constructed from the same encoder object.
# NOTE(review): presumably the restored VAE weights are shared with the
# classifier through `encoder_model` -- confirm in models/vae.py.
classifier_model = classifier.build_classifier(encoder_model, N_UNIQUE_FEATURES, FEATURE_NAMES)

# One binary cross-entropy loss per entry in FEATURE_NAMES; from_logits=False
# means the losses are given probabilities rather than raw logits.
classifier_optimizer = tf.keras.optimizers.legacy.Adam()
classifier_losses = [
    tf.keras.losses.BinaryCrossentropy(from_logits=False) for _ in FEATURE_NAMES
]
classifier_model.compile(
    optimizer=classifier_optimizer,
    loss=classifier_losses,
    metrics=["accuracy"],
)

In [6]:
def _batched_split(image_ids, labels):
    """Return (images, batched (images, labels) dataset) for one split."""
    images = data_loading.load_specific_dataset(dataset, image_ids, None)
    paired = tf.data.Dataset.zip((images, labels))
    return images, paired.batch(BATCH_SIZE)


x_train, train_dataset = _batched_split(train_ids, y_train)
x_validation, validation_dataset = _batched_split(validation_ids, y_validation)

In [7]:
# Keep only the best weights seen so far. The training call supplies
# validation data, so the checkpoint is selected on validation loss rather
# than training loss (which would favour over-fitted epochs).
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=MODEL_CLASSIFIER_FILEPATH,
    save_weights_only=True,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
)

# Per-epoch metrics are appended to the history file, so repeated runs
# accumulate rows instead of overwriting earlier ones.
csv_logger = tf.keras.callbacks.CSVLogger(HISTORY_CLASSIFIER_FILEPATH, append=True)

In [9]:
# Train the classifier.
# - repeat() without a count yields batches indefinitely; each epoch's length
#   is bounded by steps_per_epoch, so the input can never run dry mid-training.
# - The checkpoint and CSV-logger callbacks must be passed here, otherwise no
#   weights or history are ever written.
# NOTE(review): epochs is hard-coded to 10 although EPOCHS_CLASSIFIER exists in
# constants -- confirm which value is intended.
history = classifier_model.fit(
    train_dataset.repeat(),
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=10,
    validation_data=validation_dataset,
    callbacks=[model_checkpoint, csv_logger],
)

Epoch 1/10


2023-05-23 22:43:15.778423: I tensorflow/core/common_runtime/executor.cc:1210] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_20' with dtype bool and shape [7000,44]
	 [[{{node Placeholder/_20}}]]
2023-05-23 22:43:15.778642: I tensorflow/core/common_runtime/executor.cc:1210] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_20' with dtype bool and shape [7000,44]
	 [[{{node Placeholder/_20}}]]




2023-05-23 22:43:43.700652: I tensorflow/core/common_runtime/executor.cc:1210] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_14' with dtype bool and shape [1500,8]
	 [[{{node Placeholder/_14}}]]
2023-05-23 22:43:43.700830: I tensorflow/core/common_runtime/executor.cc:1210] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [10000]
	 [[{{node Placeholder/_0}}]]


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
