# Contents
* [Intro](#Intro)
* [Imports and config](#Imports-and-config)
* [Load data](#Load-data)
* [Preprocess](#Preprocess)
* [Train test split](#Train-test-split)
* [Minimally Random Convolutional Kernel Transform](#Minimally-Random-Convolutional-Kernel-Transform)
  * [Ternary results](#Ternary-results)
  * [Binary Results](#Binary-Results)
* [Discussion](#Discussion)

## Intro

This notebook evaluates the MINIROCKET classification algorithm on the scaled mel spectrograms extracted from short-duration samples. Both the ternary case and the three binary (one-vs-rest) cases are considered. MINIROCKET outperformed the majority-class dummy baselines in every case except positive/non-positive, possibly because of class imbalance.

## Imports and config

In [1]:
# Extensions
%load_ext lab_black
%load_ext nb_black
%load_ext autotime

In [2]:
# Core
import numpy as np
import pandas as pd

# metrics
from sklearn.metrics import classification_report, confusion_matrix

# util
from tqdm import tqdm
import swifter

# display outputs w/o print calls
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

# suppress warnings
import warnings

warnings.filterwarnings("ignore")

time: 4.33 s


In [3]:
from tsai.all import *

computer_setup()

os             : Windows-10-10.0.22000-SP0
python         : 3.8.12
tsai           : 0.2.23
fastai         : 2.5.2
fastcore       : 1.3.26
torch          : 1.9.1+cpu
n_cpus         : 8
device         : cpu
time: 7.76 s


In [4]:
SEED = 2021

# Location of parquet
PARQUET_DF_FOLDER = "../5.0-mic-extract_spectrograms_and_MFCCs_short"

# Location where this notebook will output
DATA_OUT_FOLDER = "."

# The preprocessed data from the Unified Multilingual Dataset of Emotional Human utterances
WAV_DIRECTORY = (
    "../../../unified_multilingual_dataset_of_emotional_human_utterances/data/preprocessed"
)

time: 8 ms


## Load data

In [5]:
short_df = pd.read_parquet(f"{PARQUET_DF_FOLDER}/short_plus.parquet")
short_df.head(1)

| column | value (index 0) |
|---|---|
| file | 01788+BAUM1+BAUM1.s028+f+hap+1+tur+tr-tr.wav |
| duration | 0.387 |
| source | BAUM1 |
| speaker_id | BAUM1.s028 |
| speaker_gender | f |
| emo | hap |
| valence | 1 |
| lang1 | tur |
| lang2 | tr-tr |
| neg | 0 |
| neu | 0 |
| pos | 1 |
| length | short |
| padded | [0.0, 0.0, 0.0, 0.0, ...] |
| mfcc | [[-680.11646, -680.11646, -673.7514, -377.4224, ...], [0.0, 0.0, 8.79389, 66.162895, ...], [0.0, 0.0, 8.264061, 9.75589, ...], ...] |
| melspec_db | [[-80.0, -80.0, -80.0, -80.0, ...], [-80.0, -80.0, -80.0, -80.0, ...], [-80.0, -80.0, -80.0, -80.0, ...], ...] |


time: 283 ms


## Preprocess

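MINIROCKET needs each channel in its own column rather than one column of nested arrays. The cell below reshapes every spectrogram so that each of its 16 frames becomes a separate column (`0` to `15`) whose cells hold that frame's mel-band values. With this layout the frames act as channels and each series runs along the mel-frequency axis.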

In [6]:
X = short_df[["speaker_id", "neg", "neu", "pos", "valence"]].merge(
    pd.concat(
        short_df.melspec_db.swifter.apply(
            lambda row: pd.concat(
                [pd.DataFrame([array.tolist()]) for array in row]
            ).swifter.apply(lambda _: [_.values], axis=0)
        ).tolist()
    ).set_index(short_df.index),
    left_index=True,
    right_index=True,
)


time: 1min 19s
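The nested `swifter` cell above took over a minute. As a hedged alternative, the same per-frame layout can be built by stacking the spectrograms into one NumPy array and slicing along the frame axis. This is a sketch under the assumption that every spectrogram has the same `(n_mels, n_frames)` shape; `spectrograms_to_nested` is a hypothetical helper, not part of `tsai` or `swifter`.

```python
import numpy as np
import pandas as pd


def spectrograms_to_nested(spectrograms, index=None):
    """Hypothetical helper: reshape (n_mels, n_frames) spectrograms into one
    column per frame, where cell (i, j) holds the n_mels values of frame j
    of spectrogram i (the layout produced by the swifter cell above)."""
    # np.vstack accepts lists of lists as well as object arrays of 1-D arrays
    stacked = np.stack([np.vstack(s).astype(float) for s in spectrograms])
    n_samples, _, n_frames = stacked.shape
    columns = {}
    for j in range(n_frames):
        # build a 1-D object array so pandas stores one array per cell
        cells = np.empty(n_samples, dtype=object)
        for i in range(n_samples):
            cells[i] = stacked[i, :, j]
        columns[j] = cells
    return pd.DataFrame(columns, index=index)


# Toy usage with random "spectrograms" of 5 mel bands x 16 frames
rng = np.random.default_rng(0)
specs = [rng.normal(size=(5, 16)) for _ in range(3)]
nested = spectrograms_to_nested(specs)
print(nested.shape)
```

On the real data, `spectrograms_to_nested(short_df.melspec_db, index=short_df.index)` should reproduce the array columns of `X`, though this has not been checked against the notebook's parquet file.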





In [7]:
X.head(1)

| column | value (index 0) |
|---|---|
| speaker_id | BAUM1.s028 |
| neg | 0 |
| neu | 0 |
| pos | 1 |
| valence | 1 |
| 0 | [-80.0, -80.0, -80.0, -80.0, ...] |
| 1 | [-80.0, -80.0, -80.0, -80.0, ...] |
| 2 | [-80.0, -80.0, -80.0, -80.0, ...] |
| 3 | [-80.0, -80.0, -80.0, -80.0, -77.2548599243164, ...] |
| 4 | [-80.0, -80.0, -80.0, -80.0, -59.851402282714844, ...] |
| ... | ... |
| 6 | [-78.28079986572266, -76.94842529296875, -65.29120635986328, -29.926082611083984, ...] |
| 7 | [-78.36134338378906, -75.62396240234375, -44.2369499206543, -13.30752944946289, ...] |
| 8 | [-77.20024108886719, -73.44332885742188, -34.33359146118164, -13.760824203491211, ...] |
| 9 | [-80.0, -62.475318908691406, -31.309871673583984, -19.257774353027344, ...] |
| 10 | [-80.0, -59.695613861083984, -37.83037185668945, -27.894817352294922, ...] |
| 11 | [-80.0, -63.31919860839844, -55.60090637207031, -47.515010833740234, ...] |
| 12 | [-80.0, -68.97306823730469, -74.79756164550781, -57.81293869018555, ...] |
| 13 | [-80.0, -69.8300552368164, -76.61576080322266, -67.32546997070312, ...] |
| 14 | [-80.0, -71.20323181152344, -80.0, -80.0, ...] |
| 15 | [-80.0, -74.88162231445312, -80.0, -61.71727752685547, ...] |


time: 47 ms


## Train test split

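The custom split holds out a random 30% of speakers, so no speaker contributes utterances to both the training and test sets. This prevents speaker-specific characteristics from leaking into the evaluation.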

In [8]:
short_speakers = (
    pd.DataFrame(np.unique(X.speaker_id)).sample(frac=0.30, random_state=SEED)[0].values
)

criterion = X.speaker_id.isin(short_speakers)

drop_columns = ["speaker_id", "neg", "neu", "pos"]
X_test = (_ := X.loc[criterion].drop(columns=drop_columns)).drop(columns="valence")
y_test = _.valence
X_train = (_ := X.loc[~criterion].drop(columns=drop_columns)).drop(columns="valence")
y_train = _.valence

len(X) == len(y_test) + len(y_train)
print(f"{len(y_test)} in test, {len(y_train)} in train")

True

190 in test, 290 in train
time: 28.7 ms
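The same speaker-disjoint idea can be expressed with scikit-learn's `GroupShuffleSplit`, where `test_size` is a proportion of *groups* (speakers) rather than rows. This is a swapped-in technique shown on a made-up toy frame; it will not reproduce the exact speakers sampled by the `pd.DataFrame.sample` call above, but the disjointness check carries over directly to `X.speaker_id`.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical toy data: 10 speakers with 4 utterances each
df = pd.DataFrame(
    {
        "speaker_id": np.repeat([f"spk{i}" for i in range(10)], 4),
        "valence": np.tile(["-1", "0", "1", "0"], 10),
    }
)

# Hold out ~30% of speakers (groups), not ~30% of rows
splitter = GroupShuffleSplit(n_splits=1, test_size=0.30, random_state=2021)
train_idx, test_idx = next(splitter.split(df, df.valence, groups=df.speaker_id))

train_speakers = set(df.speaker_id.iloc[train_idx])
test_speakers = set(df.speaker_id.iloc[test_idx])

# No speaker may appear on both sides of the split
assert train_speakers.isdisjoint(test_speakers)
print(f"{len(test_speakers)} test speakers, {len(test_idx)} test rows")
```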


## Minimally Random Convolutional Kernel Transform

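MiniRocket (Dempster, Schmidt and Webb) was [published in August 2021](https://doi.org/10.1145/3447548.3467231) at KDD '21, claiming state-of-the-art accuracy on the UCR time series classification benchmark at a fraction of the computational cost of its predecessor, ROCKET.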
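To give a feel for what the transform computes, here is a heavily simplified sketch, not `tsai`'s implementation: MiniRocket convolves each series with a fixed set of 84 length-9 kernels whose weights are three 2s and six -1s, then pools each convolution output into a single feature, the proportion of positive values (PPV) after subtracting a bias taken from a quantile of the output on a training example. The sketch assumes a single dilation, one bias per kernel at the median, and a univariate series; the real method uses many dilations and biases, yielding roughly 10,000 features, and mixes channels for multivariate input. All function names are hypothetical.

```python
import itertools

import numpy as np


def minirocket_kernels():
    """The 84 length-9 kernels: three weights of 2 and six of -1 (each sums to 0)."""
    kernels = []
    for positions in itertools.combinations(range(9), 3):
        kernel = np.full(9, -1.0)
        kernel[list(positions)] = 2.0
        kernels.append(kernel)
    return np.stack(kernels)


def dilated_conv(x, kernel, dilation=1):
    """Zero-padded, same-length dilated cross-correlation of 1-D x with a length-9 kernel."""
    n = len(x)
    padded = np.pad(np.asarray(x, dtype=float), 4 * dilation)
    out = np.zeros(n)
    for k, w in enumerate(kernel):
        out += w * padded[k * dilation : k * dilation + n]
    return out


def fit_biases(reference, kernels, dilation=1, quantile=0.5):
    """One bias per kernel, from a quantile of its output on a reference (training) series."""
    return np.array(
        [np.quantile(dilated_conv(reference, k, dilation), quantile) for k in kernels]
    )


def ppv_features(x, kernels, biases, dilation=1):
    """Proportion of positive values of each convolution output minus its bias."""
    return np.array(
        [np.mean(dilated_conv(x, k, dilation) > b) for k, b in zip(kernels, biases)]
    )


rng = np.random.default_rng(2021)
kernels = minirocket_kernels()
biases = fit_biases(rng.normal(size=128), kernels, dilation=2)
features = ppv_features(rng.normal(size=128), kernels, biases, dilation=2)
print(kernels.shape, features.shape)  # (84, 9) (84,)
```

A linear classifier (ridge regression in the paper's default setup) is then trained on these features, which is why fitting is fast.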

In [9]:
model = MiniRocketClassifier(random_state=SEED)

time: 3.28 ms


### Ternary results

In [10]:
fitted_minirocket = model.fit(X_train, y_train)

time: 1.11 s


How well would a dummy classifier do?

In [11]:
counts = y_test.value_counts()
len_test = len(y_test)
for valence in ("-1", "0", "1"):
    print(
        f"{(_ := counts[valence])} samples of valence {valence}: {(100 * _)/len_test:.2f}% of {len_test}"
    )

66 samples of valence -1: 34.74% of 190
85 samples of valence 0: 44.74% of 190
39 samples of valence 1: 20.53% of 190
time: 12 ms
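The cell above uses the majority-class share of the *test* labels as the dummy score. A caveat worth noting: scikit-learn's `DummyClassifier(strategy="most_frequent")` learns the majority class from the *training* labels, so the two agree only when train and test share the same majority class. The sketch below reproduces the test counts printed above (66 / 85 / 39); the training labels are made up for illustration.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score

# Made-up training labels (290 rows); test labels mirror the counts above
y_train_demo = np.array(["-1"] * 100 + ["0"] * 130 + ["1"] * 60)
y_test_demo = np.array(["-1"] * 66 + ["0"] * 85 + ["1"] * 39)

dummy = DummyClassifier(strategy="most_frequent")
# Features are ignored by this strategy, so a column of zeros suffices
dummy.fit(np.zeros((len(y_train_demo), 1)), y_train_demo)
pred = dummy.predict(np.zeros((len(y_test_demo), 1)))

baseline = accuracy_score(y_test_demo, pred)
print(f"dummy accuracy: {baseline:.4f}")
```

With the notebook's real data, one would fit on `y_train` and score on `y_test` instead.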


How well did MINIROCKET do?

In [12]:
print(
    confusion_matrix(
        y_test,
        _ := fitted_minirocket.predict(X_test),
        labels=["-1", "0", "1"],
    ),
    classification_report(y_test, _),
)

[[35 22  9]
 [15 57 13]
 [12 17 10]]               precision    recall  f1-score   support

          -1       0.56      0.53      0.55        66
           0       0.59      0.67      0.63        85
           1       0.31      0.26      0.28        39

    accuracy                           0.54       190
   macro avg       0.49      0.49      0.49       190
weighted avg       0.53      0.54      0.53       190

time: 575 ms


Test accuracy of 54% (102 of 190 correct) exceeds the majority-class proportion (44.7%) by about 9 percentage points.

### Binary Results

Next, we will repeat the above analysis with the binary cases. First, we need to set up the data.

In [13]:
from collections import namedtuple

OvrSet = namedtuple("OvrSet", "name, y_test, y_train")
binary_valence = [
    OvrSet(
        name=valence,
        y_test=X.loc[criterion][valence],
        y_train=X.loc[~criterion][valence],
    )
    for valence in ("neg", "neu", "pos")
]

time: 27 ms
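The three binary targets are one-vs-rest recodings of the ternary `valence` label. As a sketch, assuming `valence` holds the strings `"-1"`, `"0"` and `"1"` as in the cells above, they could be derived directly; `one_vs_rest_labels` is a hypothetical helper.

```python
import pandas as pd

# Assumed mapping from binary target name to ternary valence code
OVR_CODES = {"neg": "-1", "neu": "0", "pos": "1"}


def one_vs_rest_labels(valence: pd.Series) -> pd.DataFrame:
    """Hypothetical helper: one 0/1 indicator column per valence class."""
    return pd.DataFrame(
        {name: (valence == code).astype(int) for name, code in OVR_CODES.items()}
    )


demo = one_vs_rest_labels(pd.Series(["-1", "0", "1", "0"]))
print(demo)
```

On the notebook's data, `one_vs_rest_labels(X.valence)` should match `X[["neg", "neu", "pos"]]` up to integer dtype.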


How does MINIROCKET do in comparison to dummy classifiers in the binary cases?

In [14]:
for labels in binary_valence:
    # separate name so the ternary y_test defined earlier is not overwritten
    y_true = labels.y_test
    percent = (100 * y_true.sum()) / len(y_true)
    print(
        f"majority classification percentage for {labels.name} valence: {percent if percent > 50 else 100 - percent:.3f}"
    )
    print(
        confusion_matrix(
            y_true,
            _ := model.fit(X_train, labels.y_train).predict(X_test),
        ),
        classification_report(y_true, _),
    )

majority classification percentage for neg valence: 65.263
[[115   9]
 [ 38  28]]               precision    recall  f1-score   support

           0       0.75      0.93      0.83       124
           1       0.76      0.42      0.54        66

    accuracy                           0.75       190
   macro avg       0.75      0.68      0.69       190
weighted avg       0.75      0.75      0.73       190

majority classification percentage for neu valence: 55.263
[[85 20]
 [44 41]]               precision    recall  f1-score   support

           0       0.66      0.81      0.73       105
           1       0.67      0.48      0.56        85

    accuracy                           0.66       190
   macro avg       0.67      0.65      0.64       190
weighted avg       0.66      0.66      0.65       190

majority classification percentage for pos valence: 79.474
[[142   9]
 [ 34   5]]               precision    recall  f1-score   support

           0       0.81      0.94      0.87      

In the negative/non-negative case, MINIROCKET's accuracy of 75.3% (143 of 190) beat the dummy score of 65.3% by about 10.0 percentage points.

In the neutral/non-neutral case, MINIROCKET's accuracy of 66.3% (126 of 190) beat the dummy score of 55.3% by about 11.1 percentage points.

In the positive/non-positive case, MINIROCKET's accuracy of 77.4% (147 of 190, from the confusion matrix) fell short of the dummy score of 79.5% by about 2.1 percentage points.
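The accuracies and baselines quoted above can be recomputed from the printed confusion matrices alone: accuracy is the trace divided by the total, and the majority baseline is the largest row sum (true-class count) divided by the total. A minimal sketch, with `accuracy_vs_majority` as a hypothetical helper:

```python
import numpy as np


def accuracy_vs_majority(cm):
    """Return (accuracy, majority-class baseline, lift in percentage points)
    for a confusion matrix whose rows are the true classes."""
    cm = np.asarray(cm)
    total = cm.sum()
    accuracy = np.trace(cm) / total
    majority = cm.sum(axis=1).max() / total
    return accuracy, majority, 100 * (accuracy - majority)


# Confusion matrices as printed in the cells above
matrices = {
    "ternary": [[35, 22, 9], [15, 57, 13], [12, 17, 10]],
    "neg": [[115, 9], [38, 28]],
    "neu": [[85, 20], [44, 41]],
    "pos": [[142, 9], [34, 5]],
}
results = {name: accuracy_vs_majority(cm) for name, cm in matrices.items()}
for name, (acc, base, lift) in results.items():
    print(f"{name}: accuracy {acc:.1%}, majority {base:.1%}, lift {lift:+.1f} pp")
```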

## Discussion

In this notebook, we tested MINIROCKET on the spectrograms of the short set. Both the ternary and binary cases were considered. The MINIROCKET classifier was able to outperform the dummy classifier in all cases except the positive/non-positive binary case.

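Class imbalance was most severe in the positive/non-positive case (only 39 of the 190 test samples, 20.5%, are positive), which may have contributed to the poor performance there: the classifier recovered just 5 of those 39 positive samples.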
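Under this kind of imbalance, plain accuracy rewards predicting the majority class. One option, not used in the notebook, is balanced accuracy (the mean of per-class recall), for which any majority-class dummy scores exactly 0.5. The sketch rebuilds label pairs from the positive/non-positive confusion matrix printed above and scores both MINIROCKET and the dummy; `labels_from_confusion` is a hypothetical helper.

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score


def labels_from_confusion(cm):
    """Hypothetical helper: rebuild (y_true, y_pred) from a confusion matrix
    whose rows are true classes and columns are predicted classes."""
    y_true, y_pred = [], []
    for t, row in enumerate(cm):
        for p, count in enumerate(row):
            y_true += [t] * count
            y_pred += [p] * count
    return np.array(y_true), np.array(y_pred)


# Positive/non-positive confusion matrix reported above
y_true, y_pred = labels_from_confusion([[142, 9], [34, 5]])
ba_model = balanced_accuracy_score(y_true, y_pred)
ba_dummy = balanced_accuracy_score(y_true, np.zeros_like(y_true))
print(f"MINIROCKET: {ba_model:.3f}, majority dummy: {ba_dummy:.3f}")
```

By this measure MINIROCKET edges out the dummy in the positive/non-positive case, although only slightly.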

MINIROCKET may have potential, especially if its binary classifiers are ensembled for one-vs-rest classification. On the other hand, although the preprocessing for `tsai` only needs to be computed once (and can be sped up significantly with `swifter`), storing two versions of the spectrogram arrays may be cumbersome compared with other methods.
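A one-vs-rest ensemble trains one binary "class vs. rest" model per class and predicts the class whose model gives the highest score. The sketch below shows this with scikit-learn's `OneVsRestClassifier` around a `LogisticRegression` on synthetic data. Substituting `MiniRocketClassifier` as the base estimator is an untested assumption: it depends on `tsai`'s wrapper being cloneable by scikit-learn and exposing `decision_function` or `predict_proba`.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Synthetic 3-class data standing in for the neg/neu/pos problem
X_demo, y_demo = make_classification(
    n_samples=300, n_features=10, n_informative=5, n_classes=3, random_state=2021
)

ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X_demo, y_demo)
scores = ovr.decision_function(X_demo)  # one column per binary model
pred = ovr.predict(X_demo)  # the class whose binary model scores highest
print(len(ovr.estimators_), scores.shape)
```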

[^top](#Contents)