## Subsampling data

This notebook follows the subsampling methodology of the paper `Transfer Learning with Real-World Nonverbal Vocalizations from Minimally Speaking Individuals`. For each participant, we restrict the data to a subset of the labels and keep at most 10 samples from each (label, session) pair; the retained classes are then downsampled to the size of the smallest one so that they are balanced. We then compute an out-of-sample (OOS), cross-validated unweighted (macro) F1 score, which should be comparable to the scores in Figure 2 of the paper. The model used here is a regularized logistic regression on features from a pre-trained HuBERT model. On average, its performance is better than that of any of the models from the paper, but this varies somewhat by participant, and it is not the best for every participant.

In [1]:
from pathlib import Path


import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import (
    RepeatedStratifiedKFold,
    cross_val_score,
)
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from skopt import BayesSearchCV
import torch
import torchaudio
from tqdm.notebook import tqdm

# Label subset evaluated for each participant in the paper, keyed by
# participant id. The order of the labels within each list is preserved.
classes = dict(
    P01=["delighted", "dysregulated", "request", "frustrated"],
    P02=["delighted", "social", "frustrated", "selftalk"],
    P03=["dysregulated", "request", "frustrated", "selftalk"],
    P05=["dysregulated", "delighted", "frustrated", "selftalk"],
    P06=["delighted", "request", "selftalk", "frustrated", "yes"],
    P08=["frustrated", "delighted", "social", "selftalk", "request"],
    P11=["delighted", "frustrated", "selftalk", "social"],
    P16=["delighted", "frustrated", "selftalk", "social"],
)

In [4]:
# Generate features with the pre-trained HuBERT model and add them
# to a dataframe.
# We end up with a dataframe `data`, with relevant columns:
# - Participant
# - session (the first two "_"-separated fields of Filename)
# - y (this is Label encoded as an integer)
# - feature_n for n in [0,1,...,767], which are HuBERT features
#
# This cell took around 10 minutes to run on my machine.
data_dir = Path("../data")
data_files = pd.read_csv(data_dir / "directory_w_train_test.csv")
bundle = torchaudio.pipelines.HUBERT_BASE
model = bundle.get_model()
t_list = []
# Inference only: no gradients are needed for feature extraction.
with torch.no_grad():
    for filename in tqdm(data_files.Filename):
        waveform, sample_rate = torchaudio.load(data_dir / "wav" / filename)
        # Resample to the rate the pre-trained bundle was built for.
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, bundle.sample_rate
        )

        features, _ = model.extract_features(waveform)
        # Keep only the first entry of the returned feature list and average
        # it over its first two axes, leaving one vector per file.
        # NOTE(review): presumably those axes are (channel-as-batch, time
        # frame), so multi-channel files are averaged over channels - confirm.
        t_list.append(features[0].mean((0, 1)))
# One row per file. Tensors created under no_grad carry no autograd
# history, so no detach() is needed.
X = torch.stack(t_list)

# Integer-encode the labels in order of first appearance. Unlike a manual
# loop over unique(), factorize marks missing labels with -1 rather than
# silently giving them the same code (0) as the first real label.
y, labels = pd.factorize(data_files.Label)

# Convert to numpy explicitly and pin both new frames to the index of
# `data_files`, so the column-wise concat cannot misalign rows.
feature_frame = pd.DataFrame(
    X.numpy(),
    columns=[f"feature_{n}" for n in range(X.shape[1])],
    index=data_files.index,
)
data = pd.concat(
    [
        data_files,
        feature_frame,
        pd.DataFrame({"y": y}, index=data_files.index),
    ],
    axis=1,
)
data["session"] = data.Filename.str.split("_").str[:2].str.join("_")

# Sanity check: every label listed in `classes` must occur for its participant.
for participant, class_list in classes.items():
    present = set(data.loc[data.Participant == participant, "Label"])
    missing = [label for label in class_list if label not in present]
    assert not missing, f"{participant} has no samples for labels: {missing}"

  0%|          | 0/7077 [00:00<?, ?it/s]

In [5]:
# For each participant, we subsample the data, optimize the
# regularization parameter C, then compute OOS unweighted f1
# score. The process is repeated N_REPEATS times (with different
# subsamples), and the median OOS score is recorded.
N_REPEATS = 7
MAX_PER_GROUP = 10  # cap on samples kept per (session, Label) pair
BASE_SEED = 0  # repeat r is seeded with BASE_SEED + r so re-runs reproduce

# Feature columns are looked up once instead of being rebuilt every repeat.
feature_cols = [col for col in data.columns if col.startswith("feature_")]

for participant, class_list in classes.items():
    scores = []
    data_one = data.loc[data.Participant == participant]
    for repeat in tqdm(range(N_REPEATS)):
        seed = BASE_SEED + repeat

        # Shuffle, then keep the first MAX_PER_GROUP rows of every
        # (session, Label) pair: a uniform random subsample of each pair
        # that keeps the original column dtypes and does not depend on
        # how groupby.apply treats the grouping columns.
        trimmed = (
            data_one.sample(frac=1, random_state=seed)
            .groupby(["session", "Label"], sort=False)
            .head(MAX_PER_GROUP)
        )
        # Per-label counts for this participant's labels; raises KeyError
        # if one of them is absent from the subsample.
        vc = trimmed.groupby("Label").size().loc[class_list]
        # Balance the classes by downsampling each to the rarest one.
        training_files = (
            trimmed.loc[trimmed.Label.isin(vc.index)]
            .groupby("Label")
            .sample(n=int(vc.min()), random_state=seed)
        )
        if repeat == 0:
            print(participant)
            print(training_files.Label.value_counts())

        X_train = training_files[feature_cols]
        y_train = training_files.y.astype("category")

        # There are 768 generated features, which is a lot
        # relative to how many training data there are. So we
        # will need regularization. Using sk-optimize to optimize
        # strength of regularization parameter
        est = make_pipeline(
            StandardScaler(),
            LogisticRegression(
                max_iter=10**6,
            ),
        )
        opt = BayesSearchCV(
            est,
            {
                "logisticregression__C": (3e-4, 2e-2, "log-uniform"),
            },
            n_iter=20,
            cv=RepeatedStratifiedKFold(
                n_splits=10, n_repeats=3, random_state=12345
            ),
            scoring="f1_macro",
            random_state=seed,
        )
        opt.fit(X_train, y_train)

        # Using the optimal parameter, compute cross-validated unweighted f1
        # score. We use a different random seed here to reduce the risk of
        # the value of C being overfitted to the data (though nested CV would
        # be better).
        score = cross_val_score(
            opt.best_estimator_,
            X_train,
            y_train,
            cv=RepeatedStratifiedKFold(
                n_splits=5, n_repeats=3, random_state=123456
            ),
            scoring="f1_macro",
        ).mean()
        scores.append(score)
    # NOTE: the C printed here is the one selected in the final repeat only;
    # the score is the median over all repeats.
    print(opt.best_params_["logisticregression__C"], np.median(scores))

  0%|          | 0/7 [00:00<?, ?it/s]

P01
Label
delighted       31
dysregulated    31
frustrated      31
request         31
Name: count, dtype: int64
0.02 0.6700527522586346


  0%|          | 0/7 [00:00<?, ?it/s]

P02
Label
delighted     34
frustrated    34
selftalk      34
social        34
Name: count, dtype: int64




0.0009156611590927516 0.46147773250714424


  0%|          | 0/7 [00:00<?, ?it/s]

P03
Label
dysregulated    25
frustrated      25
request         25
selftalk        25
Name: count, dtype: int64




0.0016408175909896968 0.578073408073408


  0%|          | 0/7 [00:00<?, ?it/s]

P05
Label
delighted       62
dysregulated    62
frustrated      62
selftalk        62
Name: count, dtype: int64
0.008172883609112969 0.5954266957449839


  0%|          | 0/7 [00:00<?, ?it/s]

P06
Label
delighted     22
frustrated    22
request       22
selftalk      22
yes           22
Name: count, dtype: int64
0.0018158896927669855 0.3341370851370852


  0%|          | 0/7 [00:00<?, ?it/s]

P08
Label
delighted     33
frustrated    33
request       33
selftalk      33
social        33
Name: count, dtype: int64
0.0032471823189565265 0.3480805685542528


  0%|          | 0/7 [00:00<?, ?it/s]

P11
Label
delighted     27
frustrated    27
selftalk      27
social        27
Name: count, dtype: int64
0.005576803128249259 0.4351826495944143


  0%|          | 0/7 [00:00<?, ?it/s]

P16
Label
delighted     47
frustrated    47
selftalk      47
social        47
Name: count, dtype: int64
0.011364841381726321 0.6774631822911538
