In [1]:
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

from fused_encoder import create_fused_encoder
from triplet_generator import prepare_triplets

In [None]:
# Image / audio directories handed to prepare_triplets() for the training split.
TRAIN_IMG_DIR = "train-Copy1"
TRAIN_AUDIO_DIR = "audiotrain"
# Image / audio directories handed to prepare_triplets() for the validation split.
VAL_IMG_DIR = "val-Copy1"
VAL_AUDIO_DIR = "audio_val(2)-Copy1"
# Number of triplets per batch yielded by generator(); also used to derive the
# per-epoch step counts passed to fit().
BATCH_SIZE = 8
# Maximum number of epochs passed to fit(); the EarlyStopping callback may end training sooner.
EPOCHS = 50
# Learning rate given to the Adam optimizer at compile time.
LEARNING_RATE = 1e-4
# Margin added to (positive distance - negative distance) inside triplet_loss().
MARGIN = 0.5
# Directory receiving the checkpoint, TensorBoard logs, the saved encoder and the loss plot.
OUTPUT_DIR = "model_outputs_strong"


In [3]:
def triplet_loss(y_true, y_pred):
    """Hinge-style triplet loss over concatenated embeddings.

    Args:
        y_true: Ignored; present only because Keras calls losses as
            ``loss(y_true, y_pred)``.
        y_pred: Tensor that is split into three equal parts along axis 1,
            interpreted in order as anchor, positive and negative embeddings.

    Returns:
        Scalar tensor: the batch mean of
        ``max(d(anchor, positive) - d(anchor, negative) + MARGIN, 0)``,
        where ``d`` is the squared Euclidean distance summed over the last axis.
    """
    anchor, positive, negative = tf.split(y_pred, num_or_size_splits=3, axis=1)

    def _squared_distance(other):
        # Squared L2 distance between the anchor and `other`, per sample.
        return tf.reduce_sum(tf.square(anchor - other), axis=-1)

    hinge = tf.maximum(
        _squared_distance(positive) - _squared_distance(negative) + MARGIN,
        0.0,
    )
    return tf.reduce_mean(hinge)

def create_triplet_model(encoder):
    """Wrap ``encoder`` in a three-branch (anchor / positive / negative) model.

    Args:
        encoder: Callable Keras model/layer invoked as ``encoder([image, audio])``.
            The same instance is applied to all three branches, so its weights
            are shared between them.

    Returns:
        A ``tf.keras.Model`` whose inputs are, in order,
        ``[a_img, a_audio, p_img, p_audio, n_img, n_audio]`` and whose output
        is the three branch embeddings concatenated along axis 1
        (anchor, then positive, then negative).
    """
    image_shape = (224, 224, 3)
    audio_shape = (None, 1024)

    # One (image, audio) input pair per branch, created in anchor/positive/negative order.
    branch_inputs = []
    for _ in range(3):
        image_in = tf.keras.Input(shape=image_shape)
        audio_in = tf.keras.Input(shape=audio_shape)
        branch_inputs.append((image_in, audio_in))

    embeddings = [encoder([image_in, audio_in]) for image_in, audio_in in branch_inputs]
    merged = tf.keras.layers.Concatenate(axis=1)(embeddings)

    # Flatten the pairs back into the six-element input list the data generator feeds.
    flat_inputs = [tensor for pair in branch_inputs for tensor in pair]
    return tf.keras.Model(inputs=flat_inputs, outputs=merged)

# Make sure the output folder exists before callbacks or save calls write into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Point TF-Hub at a cache folder under the current working directory so that it
# can be wiped below if loading fails.
os.environ['TFHUB_CACHE_DIR'] = os.path.join(os.getcwd(), 'tfhub_cache')
os.makedirs(os.environ['TFHUB_CACHE_DIR'], exist_ok=True)

YAMNET_URL = 'https://tfhub.dev/google/yamnet/1'

try:
    yamnet_model = hub.load(YAMNET_URL)
except Exception as e:
    print(f"Error loading YAMNet: {e}")
    # Clear cache and try again. shutil.rmtree removes the whole tree in one
    # call (including symlinked entries that a manual os.walk/os.rmdir sweep
    # cannot delete); a failure to delete still raises instead of being hidden.
    cache_dir = os.environ['TFHUB_CACHE_DIR']
    if os.path.isdir(cache_dir):
        shutil.rmtree(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    # A second failure propagates to the caller: there is nothing left to retry.
    yamnet_model = hub.load(YAMNET_URL)


In [None]:
# Build the triplet collections for both splits; each call receives the split's
# image directory, audio directory and the loaded YAMNet model.
# NOTE(review): prepare_triplets lives in triplet_generator and is not visible here —
# presumably it returns a sequence of 6-field triplets (generator() unpacks six
# fields per item); confirm against that module.
print("Preparing training triplets...")
train_triplets = prepare_triplets(TRAIN_IMG_DIR, TRAIN_AUDIO_DIR, yamnet_model)
print("Preparing validation triplets...")
val_triplets = prepare_triplets(VAL_IMG_DIR, VAL_AUDIO_DIR, yamnet_model)


In [8]:
def _load_image(path):
    """Load one image as a float32 tensor of shape (224, 224, 3).

    The sentinel string ``'empty_image'`` yields an all-zero image instead of a
    file read. Any other value is read from disk, decoded to 3 channels, cast to
    float32, divided by 255 and resized to 224x224.
    """
    if path == 'empty_image':
        return tf.zeros((224, 224, 3))
    # expand_animations=False keeps GIFs 3-D (first frame only) so every image
    # has the same rank and can be stacked into a batch.
    pixels = tf.io.decode_image(tf.io.read_file(path), channels=3, expand_animations=False)
    return tf.image.resize(tf.cast(pixels, tf.float32) / 255.0, (224, 224))


def generator(triplets, batch_size):
    """Endlessly yield ``(inputs, dummy_targets)`` batches of triplets.

    Args:
        triplets: Sized, indexable collection whose items unpack into six fields:
            ``(a_img, a_audio, p_img, p_audio, n_img, n_audio)``. Image fields are
            file paths or the sentinel ``'empty_image'``; audio fields are passed
            to ``tf.stack`` as-is, so within a batch they must share a shape.
            The collection is not modified.
        batch_size: Positive number of triplets per batch. The last batch of each
            pass may be smaller.

    Yields:
        A tuple ``(inputs, targets)`` where ``inputs`` is the list
        ``[a_img, a_audio, p_img, p_audio, n_img, n_audio]`` of stacked tensors
        and ``targets`` is a zero vector with one entry per triplet in the batch
        (the loss ignores it).

    Raises:
        ValueError: On the first ``next()`` if ``batch_size`` is not positive or
            ``triplets`` is empty; either case would otherwise loop forever
            without yielding anything.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    num_triplets = len(triplets)
    if num_triplets == 0:
        raise ValueError("generator() received no triplets; nothing to yield")

    while True:
        # Shuffle an index permutation rather than the caller's collection so the
        # original ordering (and any other consumer of it) is left untouched.
        order = np.random.permutation(num_triplets)
        for start in range(0, num_triplets, batch_size):
            batch = [triplets[idx] for idx in order[start:start + batch_size]]
            a_img, a_audio, p_img, p_audio, n_img, n_audio = zip(*batch)
            a_img = [_load_image(path) for path in a_img]
            p_img = [_load_image(path) for path in p_img]
            n_img = [_load_image(path) for path in n_img]
            yield [tf.stack(a_img), tf.stack(a_audio),
                   tf.stack(p_img), tf.stack(p_audio),
                   tf.stack(n_img), tf.stack(n_audio)], tf.zeros((len(batch),))



In [9]:
# Build the fused encoder (defined in the fused_encoder module, not shown here)
# and wrap it in the three-branch triplet model; the same encoder instance is
# reused for the anchor, positive and negative branches.
encoder = create_fused_encoder()
triplet_model = create_triplet_model(encoder)


In [None]:
# The targets yielded by the generator are dummies; triplet_loss only reads y_pred.
triplet_model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE), loss=triplet_loss)

callbacks = [
    # Stop after 10 epochs without improvement and roll back to the best weights.
    tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
    # Keep only the best-scoring checkpoint of the full triplet model.
    tf.keras.callbacks.ModelCheckpoint(os.path.join(OUTPUT_DIR, "triplet_model"), save_best_only=True),
    tf.keras.callbacks.TensorBoard(log_dir=os.path.join(OUTPUT_DIR, "logs"))
]

# Ceiling division: generator() emits a trailing partial batch on every pass, so
# one pass is ceil(n / BATCH_SIZE) batches. Floor division would (a) give 0 steps
# when a split has fewer than BATCH_SIZE triplets and (b) let the leftover batch
# spill into the next epoch, so each validation run would see a different,
# shifted subset instead of the whole validation set exactly once.
steps_per_epoch = (len(train_triplets) + BATCH_SIZE - 1) // BATCH_SIZE
val_steps = (len(val_triplets) + BATCH_SIZE - 1) // BATCH_SIZE

history = triplet_model.fit(
    generator(train_triplets, BATCH_SIZE),
    validation_data=generator(val_triplets, BATCH_SIZE),
    steps_per_epoch=steps_per_epoch,
    validation_steps=val_steps,
    epochs=EPOCHS,
    callbacks=callbacks
)

# Persist the shared encoder on its own; it is the part reused outside the triplet wrapper.
encoder.save(os.path.join(OUTPUT_DIR, "fused_encoder_triplet"))



In [None]:
# Plot the per-epoch training and validation loss recorded by fit() and save the
# figure next to the other training artifacts.
fig, ax = plt.subplots(figsize=(10, 5))
for series_key in ('loss', 'val_loss'):
    ax.plot(history.history[series_key])
ax.set_title("Triplet Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend(["Train", "Val"])
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, "triplet_training_history.png"))
