In [39]:
# Load embeddings
from pickle import load
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

embeds = load(open("jepa_encodings.pkl", "rb"))

# Break into categories
clothing_embed = {k:v for k, v in embeds.items() if k.endswith("c.png")}
clothing_labels = [k.split("-")[0].split("_") for k in clothing_embed.keys()]
clothing_labels = [[lab[0], lab[1], lab[2][:-1]] if lab[2].endswith("s") else lab for lab in clothing_labels]
clothing_embed = np.array(list(clothing_embed.values()))

full_body_embed = {k:v for k, v in embeds.items() if k.endswith("fb.png")}
full_body_labels = [k.split("-")[0].split("_") for k in full_body_embed.keys()]
full_body_labels = [[lab[0], lab[1], lab[2][:-1]] if lab[2].endswith("s") else lab for lab in full_body_labels]
full_body_embed = np.array(list(full_body_embed.values()))

In [40]:
# Preprocess embedings
clothing_genders = np.zeros((clothing_embed.shape[0], 2))
clothing_genders[np.array([{"mens": 0, "womens": 1}[lab[1]] for lab in clothing_labels])] = 1

clothing_type = np.zeros((clothing_embed.shape[0], 6))
clothing_type[np.array([
    {"shirt": 0, "top": 1, "sweater": 2, "pant": 3, "skirt": 4, "short": 5}[lab[2]]
    for lab in clothing_labels
])] = 1

In [41]:
# Train model
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Input((clothing_embed.shape[1] + 2 + 6,)),
    tf.keras.layers.Dense(clothing_embed.shape[1] * 8, activation="elu"),
    tf.keras.layers.Dense(clothing_embed.shape[1] * 8, activation="elu"),
    tf.keras.layers.Dense(clothing_embed.shape[1] * 8, activation="elu"),
    tf.keras.layers.Dense(full_body_embed.shape[1])
])

model.compile(
    optimizer="adam",
    loss="mse",
)

model.fit(
    np.concatenate([clothing_embed, clothing_genders, clothing_type], axis=-1),
    full_body_embed,
    batch_size=32,
    epochs=100,
    validation_split=0.1,
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)],
)   

Epoch 1/100
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 479ms/step - loss: 210.9706 - val_loss: 6.7898
Epoch 2/100
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 454ms/step - loss: 4.8367 - val_loss: 1.1446
Epoch 3/100
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 451ms/step - loss: 0.7686 - val_loss: 0.2549
Epoch 4/100
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 451ms/step - loss: 0.2406 - val_loss: 0.1516
Epoch 5/100
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 432ms/step - loss: 0.1522 - val_loss: 0.1901
Epoch 6/100
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 431ms/step - loss: 0.1762 - val_loss: 0.2133
Epoch 7/100
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 429ms/step - loss: 0.1759 - val_loss: 0.1568
Epoch 8/100
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 452ms/step - loss: 0.1369 - val_loss: 0.1281
Epoch 9/100
[1m18/18[0m [3

<keras.src.callbacks.history.History at 0x76e19aa5bd50>

In [43]:
# Test model
converted_clothing_embeds = model(np.concatenate([clothing_embed, clothing_genders, clothing_type], axis=-1))
converted_clothing_embeds = tf.expand_dims(converted_clothing_embeds, axis=0)
print(converted_clothing_embeds.dtype)

tf_full_body_embed = tf.cast(tf.expand_dims(full_body_embed.astype(np.float64), axis=1), float)

distances = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(tf_full_body_embed - converted_clothing_embeds), axis=-1))
best_samples = tf.math.top_k(distances, k=5).indices

best_samples

<dtype: 'float32'>


<tf.Tensor: shape=(640, 5), dtype=int32, numpy=
array([[319, 510, 573, 179, 351],
       [319, 351, 510, 573,  11],
       [351,  11, 510, 319, 155],
       ...,
       [510, 319, 573, 351, 179],
       [319, 510, 573, 179, 463],
       [351, 510,  11, 319, 326]], dtype=int32)>

In [45]:
converted_clothing_embeds, tf_full_body_embed

(<tf.Tensor: shape=(1, 640, 1280), dtype=float32, numpy=
 array([[[ 0.4439231 ,  0.07123362,  0.08202247, ...,  0.17721836,
           0.07446754, -0.06796871],
         [ 0.42016384,  0.16272257,  0.12194226, ...,  0.0834841 ,
           0.06801963, -0.1994041 ],
         [ 0.4250056 ,  0.06213699,  0.07934127, ...,  0.1313514 ,
           0.07301083, -0.04101493],
         ...,
         [ 0.40245378,  0.05559085,  0.12521657, ...,  0.12405013,
           0.04581612, -0.04844816],
         [ 0.4105565 ,  0.1253966 ,  0.19989783, ...,  0.10359238,
           0.01225823, -0.1439134 ],
         [ 0.40059572,  0.21454595,  0.19475797, ..., -0.02809591,
           0.05056907, -0.22802149]]], dtype=float32)>,
 <tf.Tensor: shape=(640, 1, 1280), dtype=float32, numpy=
 array([[[ 0.6672959 , -0.04892977,  0.05864802, ...,  0.14453836,
           0.04995417, -0.02062187]],
 
        [[ 1.265888  , -0.07156598,  0.11879012, ...,  0.06167518,
           0.05703071, -0.04923891]],
 
        [[ 0.04