In [None]:
import tensorflow as tf
from models import decoder, encoder, vae
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pathlib
from constants import *

In [185]:
apes_info = pd.read_csv(APES_INFO_FILEPATH)
train_ids = apes_info.loc[apes_info["dataset"] == "train", "image"].to_list()
validation_ids = apes_info.loc[apes_info["dataset"] == "validation", "image"].to_list()
test_ids = apes_info.loc[apes_info["dataset"] == "test", "image"].to_list()
images_ids = sorted([item.stem for item in pathlib.Path(DATA_FILEPATH).iterdir() if item.suffix == ".png"])

In [186]:
dataset = (
    tf.keras.utils.image_dataset_from_directory(
        directory=DATA_FILEPATH,
        batch_size=1,
        image_size=IMAGE_SIZE,
        shuffle=False,
        labels=images_ids,
    )
    .unbatch()
    .map(lambda x, y: (x / 255, y))
)


@tf.autograph.experimental.do_not_convert
def select_x(x, _):
    return x

Found 10000 files belonging to 10000 classes.


In [187]:
decoder_model = decoder.build_decoder(LATENT_DIM)
encoder_model = encoder.build_encoder(LATENT_DIM)

vae_model = vae.VAE(encoder_model, decoder_model, 100, 1)
vae_model.load_weights("../data/models/vae/")

for i in range(len(vae_model.layers)):
    vae_model.layers[i].trainable = False

In [264]:
apes_info["Eyes"].value_counts()

23

In [268]:
bck = pd.concat([pd.get_dummies(apes_info["Background"]), apes_info[["dataset"]]], axis=1)
bck_train = bck[bck["dataset"] == "train"].drop(columns="dataset")
bck_val = bck[bck["dataset"] == "validation"].drop(columns="dataset")

mth = pd.concat([pd.get_dummies(apes_info["Mouth"]), apes_info[["dataset"]]], axis=1)
mth_train = mth[mth["dataset"] == "train"].drop(columns="dataset")
mth_val = mth[mth["dataset"] == "validation"].drop(columns="dataset")

hat = pd.concat([pd.get_dummies(apes_info["Hat"]), apes_info[["dataset"]]], axis=1)
hat_train = hat[hat["dataset"] == "train"].drop(columns="dataset")
hat_val = hat[hat["dataset"] == "validation"].drop(columns="dataset")

eyes = pd.concat([pd.get_dummies(apes_info["Eyes"]), apes_info[["dataset"]]], axis=1)
eyes_train = eyes[eyes["dataset"] == "train"].drop(columns="dataset")
eyes_val = eyes[eyes["dataset"] == "validation"].drop(columns="dataset")

In [269]:
inp = tf.keras.layers.Input((256, 256, 3))
enc = encoder_model(inp)
concat = tf.keras.layers.Concatenate()([enc[0], enc[1]])
x = tf.keras.layers.Dense(128, activation="relu")(concat)
x = tf.keras.layers.Dense(128, activation="relu")(x)
x = tf.keras.layers.Dense(128, activation="relu")(x)
x = tf.keras.layers.Dense(128, activation="relu")(x)
out1 = tf.keras.layers.Dense(8, activation="sigmoid", name="bck")(x)
out2 = tf.keras.layers.Dense(33, activation="sigmoid", name="mth")(x)
out3 = tf.keras.layers.Dense(37, activation="sigmoid", name="hat")(x)
out4 = tf.keras.layers.Dense(23, activation="sigmoid", name="eyes")(x)
mod2 = tf.keras.Model(inp, [out1, out2, out3, out4], name="abc")

In [270]:
mod2.compile(
    optimizer="adam",
    loss=[tf.keras.losses.BinaryCrossentropy(from_logits=False),
          tf.keras.losses.BinaryCrossentropy(from_logits=False),
          tf.keras.losses.BinaryCrossentropy(from_logits=False),
          tf.keras.losses.BinaryCrossentropy(from_logits=False)],
    metrics=["accuracy"],
)

In [271]:
y_ds_train = tf.data.Dataset.zip(
    (
        tf.data.Dataset.from_tensor_slices(bck_train),
        tf.data.Dataset.from_tensor_slices(mth_train),
        tf.data.Dataset.from_tensor_slices(hat_train),
        tf.data.Dataset.from_tensor_slices(eyes_train),
    )
)
x_ds_train = dataset.filter(lambda _, y: tf.math.reduce_any(y == train_ids)).map(select_x).batch(1)
ds_train = tf.data.Dataset.zip((x_ds_train.unbatch(), y_ds_train)).batch(32)


y_ds_val = tf.data.Dataset.zip(
    (
        tf.data.Dataset.from_tensor_slices(bck_val),
        tf.data.Dataset.from_tensor_slices(mth_val),
        tf.data.Dataset.from_tensor_slices(hat_val),
        tf.data.Dataset.from_tensor_slices(eyes_val),
    )
)
x_ds_val = dataset.filter(lambda _, y: tf.math.reduce_any(y == validation_ids)).map(select_x).batch(1)
ds_val = tf.data.Dataset.zip((x_ds_val.unbatch(), y_ds_val)).batch(32)

In [273]:
mod2.fit(ds_train, validation_data=ds_val, epochs=10, batch_size=None)

Epoch 1/10
    219/Unknown - 29s 132ms/step - loss: 0.5521 - bck_loss: 0.1384 - mth_loss: 0.1193 - hat_loss: 0.1186 - eyes_loss: 0.1759 - bck_accuracy: 0.8204 - mth_accuracy: 0.2220 - hat_accuracy: 0.2234 - eyes_accuracy: 0.1434

2023-05-20 11:53:33.884337: I tensorflow/core/common_runtime/executor.cc:1210] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [10000]
	 [[{{node Placeholder/_0}}]]
2023-05-20 11:53:33.884530: I tensorflow/core/common_runtime/executor.cc:1210] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [10000]
	 [[{{node Placeholder/_0}}]]


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x28df43eb0>

In [245]:
mod2.summary()

Model: "abc"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_25 (InputLayer)       [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 encoder (Functional)        [(None, 256),                2201184   ['input_25[0][0]']            
                              (None, 256),                                                        
                              (None, 256)]                                                        
                                                                                                  
 concatenate_14 (Concatenat  (None, 512)                  0         ['encoder[12][0]',            
 e)                                                                  'encoder[12][1]']          