In [None]:
%matplotlib inline

import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import umap
from sklearn.preprocessing import StandardScaler

# Global seaborn theme: dark grid background with enlarged fonts for readability.
# Style and font scale are passed in a single call because `sns.set` is a legacy
# alias of `set_theme` that resets every theme parameter it is not given.
sns.set_theme(style="darkgrid", font_scale=1.8)

# Colorblind-safe palette kept available for plots.
# NOTE(review): the scatter plot further down uses "C0".."C3", which resolve to the
# theme's default palette rather than to `colors` — confirm which palette is intended.
colors = sns.color_palette("colorblind")

In [None]:
# Which set of precomputed embeddings to analyse: 'music' or 'video'.
modality = 'video'

# Feature width of the stored embeddings for each modality.
embedding_dimensions = {
    'video': 2048, # 256 if using the fast pathway
    'music': 256,
}

# Suffix appended to the stimulus id when building embedding filenames,
# e.g. "video/embeddings/<stimulus_id>_slow.npy".
fn_suffix = {
    'music': '_backend',
    'video': '_slow',
}

# Fail fast on a typo here rather than with a bare KeyError in a later cell.
if modality not in embedding_dimensions or modality not in fn_suffix:
    raise ValueError(
        f"Unknown modality {modality!r}; expected one of {sorted(embedding_dimensions)}"
    )

## Load ground truth

In [None]:
# Ground-truth labels, indexed by stimulus_id (read directly as the index, no in-place mutation).
groundtruth_df = pd.read_csv("groundtruth.csv", index_col="stimulus_id")
# Rows are later matched to embedding files and array rows by stimulus_id,
# so it must be a unique key.
assert groundtruth_df.index.is_unique, "duplicate stimulus_id values in groundtruth.csv"

In [None]:
# Product categories to combine: every label listed in "cases_to_merge"
# is relabelled as "merged_name".
merge_cases = [
    {
        "merged_name": "Outdoor and Sports",
        "cases_to_merge": ["Outdoor and Sports", "Ride-Ons, Bikes, Scooters and Skateboards"]
    },
    {
        "merged_name": "None of the Above",
        "cases_to_merge": ["None of the Above", "Video Gaming"]
    },
    {
        "merged_name": "Toy Vehicles, Building and Construction",
        "cases_to_merge": ["Toy Vehicles (Powered and Non-Powered)", "Building and Construction"]
    }
]

# Flatten the table into {original label: merged label}; labels not listed stay as they are.
category_map = {
    original_label: merge_spec["merged_name"]
    for merge_spec in merge_cases
    for original_label in merge_spec["cases_to_merge"]
}
groundtruth_df["product_category"] = groundtruth_df["product_category"].replace(category_map)

groundtruth_df.head()

In [None]:
# Sanity check before loading: every stimulus in the ground truth must have
# an embedding file on disk for the selected modality.
missing_ids = [
    stimulus_id
    for stimulus_id in groundtruth_df.index
    if not os.path.isfile(f"{modality}/embeddings/{stimulus_id}{fn_suffix[modality]}.npy")
]
for stimulus_id in missing_ids:
    print(f"Embedding for {stimulus_id} not found")
not_found = len(missing_ids)

assert not_found == 0, f"{not_found} of {len(groundtruth_df)} embeddings are missing"

## Load embeddings

In [None]:
# Pool each stimulus' stored embedding array into a single vector by averaging
# over axis 0, producing an (n_stimuli, embedding_dim) matrix whose row order
# follows groundtruth_df.index.
embedding_dim = embedding_dimensions[modality]

all_embeddings = np.empty((groundtruth_df.shape[0], embedding_dim))

for i, stimulus_id in enumerate(groundtruth_df.index):
    embedding = np.load(f"{modality}/embeddings/{stimulus_id}{fn_suffix[modality]}.npy")
    # atleast_2d treats a lone 1-D vector as a single row; without it, mean(axis=0)
    # of a 1-D array is a scalar that would silently broadcast across the whole row.
    all_embeddings[i] = np.atleast_2d(embedding).mean(axis=0)

# An empty embedding file averages to NaN, which would break scaling and UMAP below.
assert np.isfinite(all_embeddings).all(), "non-finite values in pooled embeddings"

all_embeddings.shape

In [None]:
# Standardise every embedding feature, then project the stimuli with UMAP
# (fixed random_state so the layout is reproducible between runs).
scaled_embeddings = StandardScaler().fit_transform(all_embeddings)
reducer = umap.UMAP(n_neighbors=20, random_state=42)
embeddings = reducer.fit_transform(scaled_embeddings)

In [None]:
# Fixed colour per target group; 'No actors' has a colour but is filtered out below.
palette = {'Mixed':'C2','Girls/women':'C3','Boys/men':'C0','No actors':'C1'}

# Groups shown in the plot; the same list fixes the legend order via hue_order,
# so it no longer depends on which label happens to appear first in the data.
shown_targets = ["Girls/women", "Mixed", "Boys/men"]
mask = groundtruth_df.all_genders.isin(shown_targets)
# Rows of `embeddings` follow groundtruth_df row order, so a plain boolean
# array selects the matching points positionally.
mask_values = mask.to_numpy()

fig, ax = plt.subplots(figsize=(12, 8))
g = sns.scatterplot(
    x=embeddings[mask_values, 0],
    y=embeddings[mask_values, 1],
    hue=groundtruth_df.loc[mask, "all_genders"],
    hue_order=shown_targets,
    palette=palette,
    s=50,
    ax=ax,
)
# UMAP coordinates carry no intrinsic scale, so ticks and grid are hidden.
g.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
g.grid(False)
g.set_title(f"UMAP of {modality} embeddings")
g.legend(title="Target")
plt.show()
