In [None]:
%matplotlib inline

import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import umap
from sklearn.preprocessing import StandardScaler

# Global plotting look: seaborn dark-grid theme with every font scaled 1.8x.
# (A single set_theme call: the legacy `sns.set` alias resets all defaults, so
# calling it after set_theme would make the earlier call redundant.)
sns.set_theme(style="darkgrid", font_scale=1.8)

# Colour-blind-friendly palette, kept available for later cells.
colors = sns.color_palette("colorblind")

In [None]:
# Which pre-computed embeddings to analyse.
modality = 'music'  # 'music', 'video' or 'speech'
which = 'openl3'    # 'mfcc', 'msd' or 'openl3' for music; 'slow_fast' for video; 'hubert' for speech

# File-name suffix of the saved .npy embedding for each (modality, model) pair.
fn_suffix = {
    'music': {
        'mfcc': '',
        'msd': '_backend',
        'openl3': '_music',  # '_music' or '_env'
    },
    'video': {
        'slow_fast': '_slow',  # '_slow' or '_fast'
    },
    'speech': {
        'hubert': '_transformer',  # '_wave_encoder' or '_transformer'
    },
}

# Embedding width per (modality, model). For models with several variants the
# width depends on the selected suffix, so look the suffix up by model name:
# fn_suffix[modality] is itself a dict and never equals a plain string.
embedding_dimensions = {
    'video': {
        'slow_fast': 2048 if fn_suffix['video']['slow_fast'] == '_slow' else 256,
    },
    'music': {
        'mfcc': 60,
        'msd': 256,
        'openl3': 512,
    },
    'speech': {
        'hubert': 1024 if fn_suffix['speech']['hubert'] == '_transformer' else 512,
    },
}

# Fail fast on an invalid selection instead of a KeyError several cells later.
assert modality in fn_suffix, f"unknown modality {modality!r}"
assert which in fn_suffix[modality], f"{which!r} is not available for modality {modality!r}"

## Load ground truth

In [None]:
# Ground-truth table, indexed by stimulus id so it can be matched to the
# per-stimulus embedding files below. Built in one chained expression (no
# in-place mutation), so re-running the cell always yields the same frame.
groundtruth_df = (
    pd.read_csv("groundtruth_merged.csv")
    .set_index("stimulus_id")
)
groundtruth_df.head()

In [None]:
# Every ground-truth stimulus needs a saved embedding for the selected
# modality/model: report each missing file, then stop if any are absent.
missing_ids = [
    sid
    for sid in groundtruth_df.index
    if not os.path.exists(f"{modality}/embeddings_{which}/{sid}{fn_suffix[modality][which]}.npy")
]
for sid in missing_ids:
    print(f"Embedding for {sid} not found")
not_found = len(missing_ids)

assert not_found == 0

## Load embeddings

In [None]:
embedding_dim = embedding_dimensions[modality][which]

# One row per stimulus, in the same order as groundtruth_df.index, so rows of
# all_embeddings line up with the ground-truth rows used in later cells.
all_embeddings = np.empty((groundtruth_df.shape[0], embedding_dim))

for i, stimulus_id in enumerate(groundtruth_df.index):
    path = f"{modality}/embeddings_{which}/{stimulus_id}{fn_suffix[modality][which]}.npy"
    embedding = np.load(path)
    # Average over the leading axis (presumably frames/time steps -- confirm) to
    # get a single vector per stimulus. atleast_2d makes an already-pooled 1-D
    # vector a single row, so it is kept as-is instead of being reduced to a
    # scalar that NumPy would silently broadcast across the whole row.
    pooled = np.atleast_2d(embedding).mean(axis=0)
    if pooled.shape != (embedding_dim,):
        raise ValueError(
            f"{path}: pooled embedding has shape {pooled.shape}, expected ({embedding_dim},)"
        )
    all_embeddings[i] = pooled

all_embeddings.shape

In [None]:
# Standardise each embedding dimension (zero mean, unit variance), then project
# the stimuli to 2-D with UMAP. random_state seeds UMAP so the layout is
# reproducible across runs. `embeddings` holds the final (n_stimuli, 2) result.
scaled_embeddings = StandardScaler().fit_transform(all_embeddings)
reducer = umap.UMAP(n_neighbors=20, random_state=42)
embeddings = reducer.fit_transform(scaled_embeddings)

In [None]:
# Fixed colour per ground-truth `target` label so colours stay stable across runs.
palette = {'Mixed': 'C2', 'Girls/women': 'C3', 'Boys/men': 'C0', 'No actors': 'C1'}

# Plot only stimuli whose target is one of these three labels.
mask = groundtruth_df.target.isin(["Girls/women", "Mixed", "Boys/men"])
row_sel = mask.to_numpy()  # plain boolean array for positional NumPy indexing

fig, ax = plt.subplots(figsize=(12, 8))
g = sns.scatterplot(
    x=embeddings[row_sel, 0],
    y=embeddings[row_sel, 1],
    hue=groundtruth_df.loc[mask, "target"],
    palette=palette,
    s=50,
    ax=ax,
)
# Strip tick marks, tick labels and the grid for a clean scatter.
g.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
g.grid(False)
g.legend(title="Target")
plt.show()
