In [1]:
%matplotlib inline
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader, TensorDataset

from utils import MultipleRegression, load_parameters

In [2]:
# Choose which embedding extractor to evaluate.
modality = 'music'  # 'music', 'speech', or 'video'
which = 'openl3'    # 'mfcc', 'msd' or 'openl3' for music, 'slow_fast' for video, 'hubert' for speech
voice = True        # False -> embeddings are read from the '..._novoice' directory instead

# Filename suffix appended to "<stimulus_id>" for every (modality, extractor) pair.
fn_suffix = dict(
    music=dict(
        mfcc='',
        msd='_backend',
        openl3='_music',  # '_music' or '_env'
    ),
    video=dict(
        slow_fast='_slow',  # '_slow' or '_fast'
    ),
    speech=dict(
        hubert='_wave_encoder',  # '_wave_encoder' or '_transformer'
    ),
)

# Embedding width per extractor; the slow_fast and hubert widths depend on the
# suffix selected above.
embedding_dimensions = dict(
    video=dict(
        slow_fast=256 if fn_suffix['video']['slow_fast'] != '_slow' else 2048,
    ),
    music=dict(mfcc=60, msd=256, openl3=512),
    speech=dict(
        hubert=512 if fn_suffix['speech']['hubert'] != '_transformer' else 1024,
    ),
)

## Load ground truth

In [3]:
# Ground-truth annotations, one row per stimulus, indexed by stimulus_id
# (indexing at read time instead of mutating the frame in place afterwards).
groundtruth_df = pd.read_csv("groundtruth_merged.csv", index_col="stimulus_id")
groundtruth_df.head()

Unnamed: 0_level_0,product_category,filming_location,target,interaction,voice_type,voice_age,voice_gender,voice_exagg,asian,black,...,description,upload_date,duration,view_count,categories,tags,like_count,requested_subtitles,download,error_logs
stimulus_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ndzo2ZIWfiQ,High-tech Interactive Playmates and Robotics,Non-specific,Girls/women,They do not interact with each other or there ...,BOTH spoken and sung,Adults (including young adults),Feminine,"No, all voices are normal-sounding",No,No,...,,20120907,20,3047,['Autos & Vehicles'],['122975'],5,"{'en': {'ext': 'vtt', 'url': 'https://www.yout...",True,
yRUiwRKk6QM,High-tech Interactive Playmates and Robotics,Indoors,Mixed,They do not interact with each other or there ...,Spoken,Adults (including young adults),Masculine,Yes a masculine voice is gender exaggerated,No,Yes,...,Now you can train like a Jedi! Use the power o...,20170920,34,25861,['Entertainment'],"['Smyths Toys', 'Toys', '(Industry)kids', 'Sta...",46,,True,
3ysC1-foJT4,"Apparel, Fashion, Accessories, Cosmetics, Cost...",Indoors,Girls/women,They are working or playing together in a coop...,Sung,Adults (including young adults),Feminine,Yes a feminine voice is gender exaggerated,No,No,...,Just add water to make endless crystal creatio...,20180802,30,2545,['Entertainment'],"['Smyths Toys', 'Toys (Industry)', 'kids', 'sm...",19,,True,
cYszuGaptkk,"Action Figures, Battling Toys and Toy Weapons",Non-specific,Mixed,They are working or playing together in a coop...,Spoken,Adults (including young adults),Feminine,"No, all voices are normal-sounding",No,Yes,...,Create your own exciting dino rescue missions ...,20201028,21,2035695,['Entertainment'],"['Smyths Toys', 'Toys (Industry)', 'kids', 'sm...",51,"{'en': {'ext': 'vtt', 'url': 'https://www.yout...",True,
2LZjLBipdfI,Dolls,Indoors,Girls/women,They do not interact with each other or there ...,BOTH spoken and sung,Adults (including young adults),Feminine,Yes a feminine voice is gender exaggerated,No,No,...,Fabio & Fabia's Hair salon is the most loved p...,20201106,15,1834,['Entertainment'],"['Smyths Toys', 'Toys (Industry)', 'kids', 'sm...",10,,True,


In [4]:
# Mid-level perceptual ratings used as regression targets, indexed by stimulus_id;
# the "target" column is dropped.
mid_level_features = (
    pd.read_csv("mid_level_features.csv", index_col="stimulus_id")
    .drop(columns=["target"])
)
mid_level_features.head()

Unnamed: 0_level_0,Electric/Acoustic,Distorted/Clear,Many Instruments/Few Instruments,Loud/Soft,Heavy/Light,High pitch/Low pitch,Wide pitch variation/Narrow pitch variation,Punchy/Smooth,Harmonious/Disharmonious,Clear melody/No melody,Repetitive/Non-repetitive,Complex rhythm/Simple rhythm,Fast tempo/Slow tempo,Dense/Sparse,Strong beat/Weak beat
stimulus_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
-BaTPbE0Gdo,3.666667,6.5,4.333333,5.833333,6.666667,2.666667,4.666667,6.0,1.5,2.166667,4.333333,5.666667,3.5,4.166667,4.166667
-KKsNKY4V8k,3.0,5.166667,2.5,2.5,5.0,2.5,3.666667,4.0,2.666667,2.5,2.666667,5.0,2.5,3.166667,3.166667
-Mqc2csT3ZM,2.666667,3.5,4.666667,4.666667,5.333333,2.666667,4.666667,3.5,3.5,3.333333,2.833333,5.333333,3.333333,3.666667,4.0
-NEHGAMiA2I,4.833333,6.166667,5.5,5.666667,6.166667,3.666667,4.5,5.5,2.333333,3.666667,2.833333,4.833333,4.333333,4.5,4.333333
-SEKfzdaIK0,3.666667,3.166667,5.333333,3.5,2.833333,4.166667,4.5,2.166667,3.0,3.0,2.666667,4.833333,2.666667,4.0,3.0


In [5]:
# Verify that every ground-truth stimulus has an embedding file in the SAME
# directory the loading cell reads from. The `_novoice` variant is selected by
# the `voice` flag; previously this check always looked in the voiced directory.
embedding_dir = f"{modality}/embeddings_{which}{'' if voice else '_novoice'}"

missing = [
    stimulus_id
    for stimulus_id in groundtruth_df.index
    if not os.path.exists(f"{embedding_dir}/{stimulus_id}{fn_suffix[modality][which]}.npy")
]
for stimulus_id in missing:
    print(f"Embedding for {stimulus_id} not found")

not_found = len(missing)
assert not_found == 0, f"{not_found} embedding file(s) missing from {embedding_dir}/"

## Load embeddings

In [6]:
# Build the design matrix X (one pooled embedding per stimulus) and the target
# matrix y (that stimulus's mid-level ratings), row-aligned in groundtruth_df order.
embedding_dim = embedding_dimensions[modality][which]
embedding_dir = f"{modality}/embeddings_{which}{'' if voice else '_novoice'}"

# Fail early with a clear message instead of a KeyError midway through the loop.
missing_features = groundtruth_df.index.difference(mid_level_features.index)
assert missing_features.empty, (
    f"No mid-level features for {len(missing_features)} stimuli, "
    f"e.g. {list(missing_features[:5])}"
)

# Both matrices get exactly one row per ground-truth stimulus. Previously y was
# sized by mid_level_features.shape[0], which leaves uninitialised rows (or
# overflows) whenever the two tables have different lengths.
n_stimuli = len(groundtruth_df)
X = np.empty((n_stimuli, embedding_dim))
y = np.empty((n_stimuli, mid_level_features.shape[1]))

for i, stimulus_id in enumerate(groundtruth_df.index):
    embedding = np.load(f"{embedding_dir}/{stimulus_id}{fn_suffix[modality][which]}.npy")
    # presumably (n_frames, embedding_dim): mean-pool over the first axis to get
    # one fixed-length vector per clip — TODO confirm per extractor
    X[i] = embedding.mean(axis=0)
    y[i] = mid_level_features.loc[stimulus_id].values

X.shape

(606, 512)

In [7]:
# Hold out 10% of all stimuli as the test set, then 10% of the remaining 90% as
# the validation set. Both splits are seeded, so they are reproducible.
X_rest, X_test, y_rest, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_rest, y_rest, test_size=0.1, random_state=42)

In [8]:
# Train the multi-output linear probe on the train split, select the best
# checkpoint by validation loss, and score it on the held-out test split.

# The data splits are already seeded via random_state=42. Also seed the
# torch/numpy/python RNGs, so weight initialisation and DataLoader shuffling
# are reproducible across runs.
seed_everything(42)

train_dataset = TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train).float())
val_dataset = TensorDataset(torch.from_numpy(X_val).float(), torch.from_numpy(y_val).float())

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

model = MultipleRegression(input_dim=X_train.shape[1], n_regressions=y_train.shape[1])

# Keep the best checkpoint by 'val_loss'; stop after 20 epochs without improvement.
checkpoint_callback = ModelCheckpoint(monitor='val_loss')
# NOTE(review): the captured log shows this running as 4-process DDP. Unless
# MultipleRegression logs val_loss with sync_dist=True, the monitored value may
# only reflect rank 0's shard of the validation set. Consider devices=1 for a
# dataset this small. TODO confirm.
trainer = Trainer(max_epochs=100, callbacks=[checkpoint_callback,
                                            EarlyStopping(monitor='val_loss', patience=20)])
trainer.fit(model, train_loader, val_loader)

# load_from_checkpoint is a classmethod that constructs the model itself, so there
# is no need to build a throwaway instance first. Map explicitly to CPU so the
# model matches the CPU tensors fed to it below.
model = MultipleRegression.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    map_location="cpu",
    input_dim=X_train.shape[1],
    n_regressions=y_train.shape[1],
)

model.eval()
with torch.no_grad():  # outputs carry no autograd graph, so .numpy() is safe without .detach()
    y_pred = model(torch.from_numpy(X_test).float()).numpy()

# One R^2 per target column (mid-level feature), in column order.
r2_values = r2_score(y_test, y_pred, multioutput='raw_values')


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name    | Type        | Params
----------------------------------------
0 | bn      | BatchNorm1d | 1.0 K 
1 | linear  | Linear      | 131 K 
2

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [9]:
features_names = mid_level_features.columns

# Report the test-set R^2 of every regression target next to its mid-level
# feature name (targets are in the same column order as mid_level_features).
for idx in range(len(r2_values)):
    feature, score = features_names[idx], r2_values[idx]
    print(f'{feature}: {score:.2f}')

Electric/Acoustic: 0.62
Distorted/Clear: 0.36
Many Instruments/Few Instruments: 0.07
Loud/Soft: 0.50
Heavy/Light: 0.57
High pitch/Low pitch: 0.12
Wide pitch variation/Narrow pitch variation: -0.06
Punchy/Smooth: 0.53
Harmonious/Disharmonious: 0.04
Clear melody/No melody: 0.03
Repetitive/Non-repetitive: 0.05
Complex rhythm/Simple rhythm: -0.09
Fast tempo/Slow tempo: 0.25
Dense/Sparse: 0.39
Strong beat/Weak beat: 0.38
