In [1]:
import re 
import os
import pandas as pd
import numpy as np
import json
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

import IPython.display as ipd
import torchaudio
import torchaudio.transforms as T

import evaluate
from collections import Counter
from sklearn.model_selection import train_test_split

import torch
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") # cool M2 chip GPU acceleration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # for NVIDIA GPUs
print(f"Device: {device}")


eda_df = pd.read_csv("eda_iemocap_no_utts_dataset.csv")
eda_df = eda_df[["speaker", "utt_id", "EDA"]]
filename_ids = []
speaker_M_F = []
session_numbers = []
for i, row in eda_df.iterrows():
    match = re.search(r"b'(Ses(\d+)[MF]_.+\d+.*)_([MF])", row["speaker"])
    filename_ids.append(match.group(1))
    session_numbers.append(int(match.group(2))) 
    speaker_M_F.append(match.group(3))
eda_df = eda_df.drop(columns=["speaker"])
eda_df["filename"] = filename_ids
eda_df["filename"] = eda_df["filename"].astype(str)
eda_df["session_number"] = session_numbers
eda_df["session_number"] = eda_df["session_number"].astype(int)
eda_df["speaker"] = speaker_M_F
eda_df["speaker"] = eda_df["speaker"].astype(str)
eda_df["utt_id"] = eda_df["utt_id"].astype(int)

# Access transcipt files based on filename
utt_df = []
root_dir = "IEMOCAP_full_release/"
for i in range(1, 6):
    directory = os.path.join(root_dir, f"Session{i}/dialog/transcriptions/")
    for entry in os.scandir(directory):  
        if entry.is_file() and entry.path.endswith(".txt"):  # check if it's a file
            try:
                with open(entry.path, "r") as file:
                    filename = entry.path.split("/")[-1][:-4]
                    lines = file.readlines()
                    for order, line in enumerate(lines):
                        speaker_info, utterance = line.split(":")[0], line.split(":")[1]
                        pattern = r"(F|M)(\d+)\s\[(\d+\.\d+)-(\d+\.\d+)\]"
                        match = re.search(pattern, speaker_info)
                        if match is None:
                            continue
                        speaker_f_m = match.group(1)
                        utt_id = match.group(2)
                        start = match.group(3)
                        end = match.group(4)
                        utt_df.append({"utt_id": int(utt_id), "filename": str(filename), "start": float(start), "end": float(end), "num_frames":16000*(float(end)-float(start)), "speaker": str(speaker_f_m.strip()), "utterance": utterance.strip(), "session_number": int(i), "original_order": order})
            except:
                #print(entry.path) # these are meta files with ._ prepended to text file name
                continue
utt_df = pd.DataFrame(utt_df)
# Combine the EDA and utterances together
final_df = pd.merge(eda_df, utt_df, on=["utt_id", "session_number", "filename", "speaker"])
final_df
labels = list(set(Counter(final_df["EDA"]).keys())) # there are 34 labels
labels_to_num_mapping = {}
for i, label in enumerate(labels):
    labels_to_num_mapping[label] = i


train_scripted, val_scripted, test_scripted = pd.read_csv(f"gemaps_dataset_train_{"scripted"}.csv"), pd.read_csv(f"gemaps_dataset_val_{"scripted"}.csv"), pd.read_csv(f"gemaps_dataset_test_{"scripted"}.csv")
train_improv, val_improv, test_improv = pd.read_csv(f"gemaps_dataset_train_{"improv"}.csv"), pd.read_csv(f"gemaps_dataset_val_{"improv"}.csv"), pd.read_csv(f"gemaps_dataset_test_{"improv"}.csv")

scripted_improv = "improv"
if scripted_improv == "scripted":
    train = train_scripted
    val = val_scripted
    test = test_scripted

    train_other = train_improv
    val_other = val_improv
    test_other = test_improv
else:
    train = train_improv
    val = val_improv
    test = test_improv

    train_other = train_scripted
    val_other = val_scripted
    test_other = test_scripted


train["labels"] = train["labels"].map(lambda x: labels_to_num_mapping[x])
val["labels"] = val["labels"].map(lambda x: labels_to_num_mapping[x])
test["labels"] = test["labels"].map(lambda x: labels_to_num_mapping[x])
train_features, train_labels = train.drop(["labels", "Unnamed: 0"], axis=1).values, train["labels"].values
val_features, val_labels = val.drop(["labels", "Unnamed: 0"], axis=1).values, val["labels"].values
test_features, test_labels = test.drop(["labels", "Unnamed: 0"], axis=1).values, test["labels"].values


    
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()

features_scaled_train = scaler.fit_transform(train_features)
features_tensor_train = torch.tensor(features_scaled_train, dtype=torch.float32)
labels_tensor_train = torch.tensor(train_labels, dtype=torch.long).view(-1, 1)

features_scaled_val = scaler.fit_transform(val_features)
features_tensor_val = torch.tensor(features_scaled_val, dtype=torch.float32)
labels_tensor_val = torch.tensor(val_labels, dtype=torch.long).view(-1, 1)

features_scaled_test = scaler.fit_transform(test_features)
features_tensor_test = torch.tensor(features_scaled_test, dtype=torch.float32)
labels_tensor_test = torch.tensor(test_labels, dtype=torch.long).view(-1, 1)

from torch.utils.data import DataLoader, TensorDataset
# Create TensorDatasets for both training and testing data
train_dataset = TensorDataset(features_tensor_train, labels_tensor_train)
val_dataset = TensorDataset(features_tensor_val, labels_tensor_val)
test_dataset = TensorDataset(features_tensor_test, labels_tensor_test)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


In [2]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class GemapsClassifier(nn.Module):
    def __init__(self, feature_dim=88):
        super().__init__()
        self.fc1 = nn.Linear(feature_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, len(labels_to_num_mapping.keys()))
        
    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        X = self.fc3(x)
        return x
    
model = GemapsClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0):
        """
        Args:
            patience (int): How many epochs to wait after last time validation loss improved.
                            Default: 5
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss)
            self.counter = 0

    def save_checkpoint(self, val_loss):
        """Saves model when validation loss decrease."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        self.val_loss_min = val_loss
        self.best_model_state = model.state_dict() 

def train_model(num_epochs, model, loaders, early_stopping):
    model.train()
    for epoch in range(num_epochs):
        for data, target in loaders['train']:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target.squeeze())
            loss.backward()
            optimizer.step()

        # Validation loss during training
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for data, target in loaders['test']:
                output = model(data)
                val_loss += criterion(output, target.squeeze()).item()
        
        val_loss /= len(loaders['test'].dataset)
        print(f'Epoch {epoch+1}, Validation Loss: {val_loss}')

        # Call early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break
    
    if early_stopping.best_model_state:
        model.load_state_dict(early_stopping.best_model_state)

# Initialize the EarlyStopping
early_stopping = EarlyStopping(patience=7, verbose=True)

loaders = {
    'train': train_loader,
    'test': val_loader
}
# Train the model with early stopping
train_model(50, model, loaders, early_stopping)

Epoch 1, Validation Loss: 0.047049596549707876
Validation loss decreased (inf --> 0.047050).  Saving model ...
Epoch 2, Validation Loss: 0.03454713497959306
Validation loss decreased (0.047050 --> 0.034547).  Saving model ...
Epoch 3, Validation Loss: 0.03292963086745031
Validation loss decreased (0.034547 --> 0.032930).  Saving model ...
Epoch 4, Validation Loss: 0.03269256620887197
Validation loss decreased (0.032930 --> 0.032693).  Saving model ...
Epoch 5, Validation Loss: 0.03184311035956946
Validation loss decreased (0.032693 --> 0.031843).  Saving model ...
Epoch 6, Validation Loss: 0.03185714837227258
EarlyStopping counter: 1 out of 7
Epoch 7, Validation Loss: 0.03183039249820514
Validation loss decreased (0.031843 --> 0.031830).  Saving model ...
Epoch 8, Validation Loss: 0.03199966587507684
EarlyStopping counter: 1 out of 7
Epoch 9, Validation Loss: 0.031755352793293196
Validation loss decreased (0.031830 --> 0.031755).  Saving model ...
Epoch 10, Validation Loss: 0.031816739

In [3]:
from sklearn.metrics import classification_report

def evaluate_model(model, loader, criterion):
    model.eval() 
    total_loss = 0
    preds_eval = []
    labels_eval = []
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            _, predictions = torch.max(output, 1)
            preds_eval.extend(list(predictions))
            labels_eval.extend(list(target.squeeze()))
            loss = criterion(output, target.squeeze())
            total_loss += loss.item()
    avg_loss = total_loss / len(loader.dataset)
    print(f'Average Validation Loss: {avg_loss:.4f}')
    print(f"Classification report:")
    print(classification_report(labels_eval, preds_eval))
    return avg_loss

evaluate_model(model, test_loader, criterion)

Average Validation Loss: 0.0381
Classification report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        24
           1       0.00      0.00      0.00         1
           2       0.00      0.00      0.00         2
           3       0.06      0.07      0.06        15
           4       0.00      0.00      0.00         3
           6       0.00      0.00      0.00         1
           7       0.00      0.00      0.00         1
           9       0.27      0.22      0.24        18
          14       0.00      0.00      0.00         2
          17       0.00      0.00      0.00         3
          18       0.00      0.00      0.00         5
          19       0.00      0.00      0.00        11
          20       0.42      0.82      0.56       114
          23       0.00      0.00      0.00         1
          24       0.00      0.00      0.00         1
          27       0.00      0.00      0.00         1
          28       0.37   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


0.03810112646647862

In [4]:
test_other["labels"] = test_other["labels"].map(lambda x: labels_to_num_mapping[x])
test_features_other, test_labels_other = test_other.drop(["labels", "Unnamed: 0"], axis=1).values, test_other["labels"].values

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()

features_scaled_test_other = scaler.fit_transform(test_features_other)
features_tensor_test_other = torch.tensor(features_scaled_test_other, dtype=torch.float32)
labels_tensor_test_other = torch.tensor(test_labels_other, dtype=torch.long).view(-1, 1)

from torch.utils.data import DataLoader, TensorDataset

test_dataset_other = TensorDataset(features_tensor_test_other, labels_tensor_test_other)
test_loader_other = DataLoader(test_dataset_other, batch_size=64, shuffle=False)

evaluate_model(model, test_loader_other, criterion)

Average Validation Loss: 0.0393
Classification report:
              precision    recall  f1-score   support

           0       0.25      0.05      0.08        62
           1       0.00      0.00      0.00        21
           3       0.12      0.05      0.07        42
           6       0.00      0.00      0.00         5
           9       0.14      0.57      0.22         7
          12       0.00      0.00      0.00        16
          14       0.00      0.00      0.00         1
          17       0.00      0.00      0.00         1
          19       0.17      0.04      0.07        23
          20       0.48      0.85      0.61       212
          23       0.00      0.00      0.00         1
          24       0.00      0.00      0.00         1
          27       0.00      0.00      0.00         5
          28       0.10      0.04      0.06        68
          31       0.00      0.00      0.00         4
          33       0.00      0.00      0.00         1

    accuracy             

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


0.03934157985322019

# Feature Extraction (GeMaps)

In [None]:
# Feature Extraction (GeMaps)
# # 1. Access the corresponding .txt, .wav, and .avi files for each EDA label
# # Extract the conversation filename and speaker informationfrom the dataset
# eda_df = pd.read_csv("eda_iemocap_no_utts_dataset.csv")
# eda_df = eda_df[["speaker", "utt_id", "EDA"]]
# filename_ids = []
# speaker_M_F = []
# session_numbers = []
# for i, row in eda_df.iterrows():
#     match = re.search(r"b'(Ses(\d+)[MF]_.+\d+.*)_([MF])", row["speaker"])
#     filename_ids.append(match.group(1))
#     session_numbers.append(int(match.group(2))) 
#     speaker_M_F.append(match.group(3))
# eda_df = eda_df.drop(columns=["speaker"])
# eda_df["filename"] = filename_ids
# eda_df["filename"] = eda_df["filename"].astype(str)
# eda_df["session_number"] = session_numbers
# eda_df["session_number"] = eda_df["session_number"].astype(int)
# eda_df["speaker"] = speaker_M_F
# eda_df["speaker"] = eda_df["speaker"].astype(str)
# eda_df["utt_id"] = eda_df["utt_id"].astype(int)

# # Access transcipt files based on filename
# utt_df = []
# root_dir = "IEMOCAP_full_release/"
# for i in range(1, 6):
#     directory = os.path.join(root_dir, f"Session{i}/dialog/transcriptions/")
#     for entry in os.scandir(directory):  
#         if entry.is_file() and entry.path.endswith(".txt"):  # check if it's a file
#             try:
#                 with open(entry.path, "r") as file:
#                     filename = entry.path.split("/")[-1][:-4]
#                     lines = file.readlines()
#                     for order, line in enumerate(lines):
#                         speaker_info, utterance = line.split(":")[0], line.split(":")[1]
#                         pattern = r"(F|M)(\d+)\s\[(\d+\.\d+)-(\d+\.\d+)\]"
#                         match = re.search(pattern, speaker_info)
#                         if match is None:
#                             continue
#                         speaker_f_m = match.group(1)
#                         utt_id = match.group(2)
#                         start = match.group(3)
#                         end = match.group(4)
#                         utt_df.append({"utt_id": int(utt_id), "filename": str(filename), "start": float(start), "end": float(end), "speaker": str(speaker_f_m.strip()), "utterance": utterance.strip(), "session_number": int(i), "original_order": order})
#             except:
#                 #print(entry.path) # these are meta files with ._ prepended to text file name
#                 continue
# utt_df = pd.DataFrame(utt_df)
# # Combine the EDA and utterances together
# final_df = pd.merge(eda_df, utt_df, on=["utt_id", "session_number", "filename", "speaker"])
# final_df

# def scripted_splits():    
#     # session 1
#     df_scripted_session_1_script_1_M = final_df[(final_df["filename"].str.contains("M_script01")) & (final_df["session_number"] == 1)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_1_script_1_F = final_df[(final_df["filename"].str.contains("F_script01")) & (final_df["session_number"] == 1)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_1_script_2_M = final_df[(final_df["filename"].str.contains("M_script02")) & (final_df["session_number"] == 1)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_1_script_2_F = final_df[(final_df["filename"].str.contains("F_script02")) & (final_df["session_number"] == 1)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_1_script_3_M = final_df[(final_df["filename"].str.contains("M_script03")) & (final_df["session_number"] == 1)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_1_script_3_F = final_df[(final_df["filename"].str.contains("F_script03")) & (final_df["session_number"] == 1)].sort_values(by=['filename', 'original_order'])

#     # session 2
#     df_scripted_session_2_script_1_M = final_df[(final_df["filename"].str.contains("M_script01")) & (final_df["session_number"] == 2)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_2_script_1_F = final_df[(final_df["filename"].str.contains("F_script01")) & (final_df["session_number"] == 2)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_2_script_2_M = final_df[(final_df["filename"].str.contains("M_script02")) & (final_df["session_number"] == 2)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_2_script_2_F = final_df[(final_df["filename"].str.contains("F_script02")) & (final_df["session_number"] == 2)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_2_script_3_M = final_df[(final_df["filename"].str.contains("M_script03")) & (final_df["session_number"] == 2)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_2_script_3_F = final_df[(final_df["filename"].str.contains("F_script03")) & (final_df["session_number"] == 2)].sort_values(by=['filename', 'original_order'])

#     # session 3
#     df_scripted_session_3_script_1_M = final_df[(final_df["filename"].str.contains("M_script01")) & (final_df["session_number"] == 3)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_3_script_1_F = final_df[(final_df["filename"].str.contains("F_script01")) & (final_df["session_number"] == 3)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_3_script_2_M = final_df[(final_df["filename"].str.contains("M_script02")) & (final_df["session_number"] == 3)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_3_script_2_F = final_df[(final_df["filename"].str.contains("F_script02")) & (final_df["session_number"] == 3)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_3_script_3_M = final_df[(final_df["filename"].str.contains("M_script03")) & (final_df["session_number"] == 3)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_3_script_3_F = final_df[(final_df["filename"].str.contains("F_script03")) & (final_df["session_number"] == 3)].sort_values(by=['filename', 'original_order'])

#     # session 4
#     df_scripted_session_4_script_1_M = final_df[(final_df["filename"].str.contains("M_script01")) & (final_df["session_number"] == 4)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_4_script_1_F = final_df[(final_df["filename"].str.contains("F_script01")) & (final_df["session_number"] == 4)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_4_script_2_M = final_df[(final_df["filename"].str.contains("M_script02")) & (final_df["session_number"] == 4)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_4_script_2_F = final_df[(final_df["filename"].str.contains("F_script02")) & (final_df["session_number"] == 4)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_4_script_3_M = final_df[(final_df["filename"].str.contains("M_script03")) & (final_df["session_number"] == 4)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_4_script_3_F = final_df[(final_df["filename"].str.contains("F_script03")) & (final_df["session_number"] == 4)].sort_values(by=['filename', 'original_order'])

#     # session 5
#     df_scripted_session_5_script_1_M = final_df[(final_df["filename"].str.contains("M_script01")) & (final_df["session_number"] == 5)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_5_script_1_F = final_df[(final_df["filename"].str.contains("F_script01")) & (final_df["session_number"] == 5)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_5_script_2_M = final_df[(final_df["filename"].str.contains("M_script02")) & (final_df["session_number"] == 5)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_5_script_2_F = final_df[(final_df["filename"].str.contains("F_script02")) & (final_df["session_number"] == 5)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_5_script_3_M = final_df[(final_df["filename"].str.contains("M_script03")) & (final_df["session_number"] == 5)].sort_values(by=['filename', 'original_order'])
#     df_scripted_session_5_script_3_F = final_df[(final_df["filename"].str.contains("F_script03")) & (final_df["session_number"] == 5)].sort_values(by=['filename', 'original_order'])

#     # need to this for each script across sections because although they are the same script, the lines are not memorized perfectly and so there are some length differences
#     min_script_1 = min([len(df_scripted_session_1_script_1_F), len(df_scripted_session_1_script_1_M),\
#                         len(df_scripted_session_2_script_1_F), len(df_scripted_session_2_script_1_M),\
#                         len(df_scripted_session_3_script_1_F), len(df_scripted_session_3_script_1_M),\
#                         len(df_scripted_session_4_script_1_F), len(df_scripted_session_4_script_1_M),\
#                         len(df_scripted_session_5_script_1_F), len(df_scripted_session_5_script_1_M)])

#     min_script_2 = min([len(df_scripted_session_1_script_2_F), len(df_scripted_session_1_script_2_M),\
#                         len(df_scripted_session_2_script_2_F), len(df_scripted_session_2_script_2_M),\
#                         len(df_scripted_session_3_script_2_F), len(df_scripted_session_3_script_2_M),\
#                         len(df_scripted_session_4_script_2_F), len(df_scripted_session_4_script_2_M),\
#                         len(df_scripted_session_5_script_2_F), len(df_scripted_session_5_script_2_M)])

#     min_script_3 = min([len(df_scripted_session_1_script_3_F), len(df_scripted_session_1_script_3_M),\
#                         len(df_scripted_session_2_script_3_F), len(df_scripted_session_2_script_3_M),\
#                         len(df_scripted_session_3_script_3_F), len(df_scripted_session_3_script_3_M),\
#                         len(df_scripted_session_4_script_3_F), len(df_scripted_session_4_script_3_M),\
#                         len(df_scripted_session_5_script_3_F), len(df_scripted_session_5_script_3_M)])

#     train_script_1 = int(min_script_1*0.8)
#     val_script_1 = (min_script_1-train_script_1)//2

#     train_script_2 = int(min_script_2*0.8)
#     val_script_2 = (min_script_2-train_script_2)//2

#     train_script_3 = int(min_script_3*0.8)
#     val_script_3 = (min_script_3-train_script_3)//2

#     #df_scripted_session_1_script_1_F.sample(frac=1) # going to ignore this for now and not shuffle within each script because that means merging sessions... to confusing and not perfect matching
#     df_scripted_train = pd.concat([df_scripted_session_1_script_1_F[:train_script_1],
#                                 df_scripted_session_1_script_1_M[:train_script_1],
#                                 df_scripted_session_2_script_1_F[:train_script_1], 
#                                 df_scripted_session_2_script_1_M[:train_script_1], 
#                                 df_scripted_session_3_script_1_F[:train_script_1],
#                                 df_scripted_session_3_script_1_M[:train_script_1],
#                                 df_scripted_session_4_script_1_F[:train_script_1],
#                                 df_scripted_session_4_script_1_M[:train_script_1],
#                                 df_scripted_session_5_script_1_F[:train_script_1],
#                                 df_scripted_session_5_script_1_M[:train_script_1],
#                                 df_scripted_session_1_script_2_F[:train_script_2],
#                                 df_scripted_session_1_script_2_M[:train_script_2],
#                                 df_scripted_session_2_script_2_F[:train_script_2], 
#                                 df_scripted_session_2_script_2_M[:train_script_2], 
#                                 df_scripted_session_3_script_2_F[:train_script_2],
#                                 df_scripted_session_3_script_2_M[:train_script_2],
#                                 df_scripted_session_4_script_2_F[:train_script_2],
#                                 df_scripted_session_4_script_2_M[:train_script_2],
#                                 df_scripted_session_5_script_2_F[:train_script_2],
#                                 df_scripted_session_5_script_2_M[:train_script_2],
#                                 df_scripted_session_1_script_3_F[:train_script_3],
#                                 df_scripted_session_1_script_3_M[:train_script_3],
#                                 df_scripted_session_2_script_3_F[:train_script_3], 
#                                 df_scripted_session_2_script_3_M[:train_script_3], 
#                                 df_scripted_session_3_script_3_F[:train_script_3],
#                                 df_scripted_session_3_script_3_M[:train_script_3],
#                                 df_scripted_session_4_script_3_F[:train_script_3],
#                                 df_scripted_session_4_script_3_M[:train_script_3],
#                                 df_scripted_session_5_script_3_F[:train_script_3],
#                                 df_scripted_session_5_script_3_M[:train_script_3]])

#     df_scripted_val = pd.concat([df_scripted_session_1_script_1_F[train_script_1: train_script_1+val_script_1],
#                                 df_scripted_session_1_script_1_M[train_script_1: train_script_1+val_script_1],
#                                 df_scripted_session_2_script_1_F[train_script_1: train_script_1+val_script_1], 
#                                 df_scripted_session_2_script_1_M[train_script_1: train_script_1+val_script_1], 
#                                 df_scripted_session_3_script_1_F[train_script_1: train_script_1+val_script_1],
#                                 df_scripted_session_3_script_1_M[train_script_1: train_script_1+val_script_1],
#                                 df_scripted_session_4_script_1_F[train_script_1: train_script_1+val_script_1],
#                                 df_scripted_session_4_script_1_M[train_script_1: train_script_1+val_script_1],
#                                 df_scripted_session_5_script_1_F[train_script_1: train_script_1+val_script_1],
#                                 df_scripted_session_5_script_1_M[train_script_1: train_script_1+val_script_1],
#                                 df_scripted_session_1_script_2_F[train_script_2: train_script_2+val_script_2],
#                                 df_scripted_session_1_script_2_M[train_script_2: train_script_2+val_script_2],
#                                 df_scripted_session_2_script_2_F[train_script_2: train_script_2+val_script_2], 
#                                 df_scripted_session_2_script_2_M[train_script_2: train_script_2+val_script_2], 
#                                 df_scripted_session_3_script_2_F[train_script_2: train_script_2+val_script_2],
#                                 df_scripted_session_3_script_2_M[train_script_2: train_script_2+val_script_2],
#                                 df_scripted_session_4_script_2_F[train_script_2: train_script_2+val_script_2],
#                                 df_scripted_session_4_script_2_M[train_script_2: train_script_2+val_script_2],
#                                 df_scripted_session_5_script_2_F[train_script_2: train_script_2+val_script_2],
#                                 df_scripted_session_5_script_2_M[train_script_2: train_script_2+val_script_2],
#                                 df_scripted_session_1_script_3_F[train_script_3: train_script_3+val_script_3],
#                                 df_scripted_session_1_script_3_M[train_script_3: train_script_3+val_script_3],
#                                 df_scripted_session_2_script_3_F[train_script_3: train_script_3+val_script_3], 
#                                 df_scripted_session_2_script_3_M[train_script_3: train_script_3+val_script_3], 
#                                 df_scripted_session_3_script_3_F[train_script_3: train_script_3+val_script_3],
#                                 df_scripted_session_3_script_3_M[train_script_3: train_script_3+val_script_3],
#                                 df_scripted_session_4_script_3_F[train_script_3: train_script_3+val_script_3],
#                                 df_scripted_session_4_script_3_M[train_script_3: train_script_3+val_script_3],
#                                 df_scripted_session_5_script_3_F[train_script_3: train_script_3+val_script_3],
#                                 df_scripted_session_5_script_3_M[train_script_3: train_script_3+val_script_3]])

#     df_scripted_test = pd.concat([df_scripted_session_1_script_1_F[train_script_1+val_script_1:],
#                                 df_scripted_session_1_script_1_M[train_script_1+val_script_1:],
#                                 df_scripted_session_2_script_1_F[train_script_1+val_script_1:], 
#                                 df_scripted_session_2_script_1_M[train_script_1+val_script_1:], 
#                                 df_scripted_session_3_script_1_F[train_script_1+val_script_1:],
#                                 df_scripted_session_3_script_1_M[train_script_1+val_script_1:],
#                                 df_scripted_session_4_script_1_F[train_script_1+val_script_1:],
#                                 df_scripted_session_5_script_1_F[train_script_1+val_script_1:],
#                                 df_scripted_session_5_script_1_M[train_script_1+val_script_1:],
#                                 df_scripted_session_1_script_2_F[train_script_2+val_script_2:],
#                                 df_scripted_session_1_script_2_M[train_script_2+val_script_2:],
#                                 df_scripted_session_2_script_2_F[train_script_2+val_script_2:], 
#                                 df_scripted_session_2_script_2_M[train_script_2+val_script_2:], 
#                                 df_scripted_session_3_script_2_F[train_script_2+val_script_2:],
#                                 df_scripted_session_3_script_2_M[train_script_2+val_script_2:],
#                                 df_scripted_session_4_script_2_F[train_script_2+val_script_2:],
#                                 df_scripted_session_4_script_2_M[train_script_2+val_script_2:],
#                                 df_scripted_session_5_script_2_F[train_script_2+val_script_2:],
#                                 df_scripted_session_5_script_2_M[train_script_2+val_script_2:],
#                                 df_scripted_session_1_script_3_F[train_script_3+val_script_3:],
#                                 df_scripted_session_1_script_3_M[train_script_3+val_script_3:],
#                                 df_scripted_session_2_script_3_F[train_script_3+val_script_3:], 
#                                 df_scripted_session_2_script_3_M[train_script_3+val_script_3:], 
#                                 df_scripted_session_3_script_3_F[train_script_3+val_script_3:],
#                                 df_scripted_session_3_script_3_M[train_script_3+val_script_3:],
#                                 df_scripted_session_4_script_3_F[train_script_3+val_script_3:],
#                                 df_scripted_session_4_script_3_M[train_script_3+val_script_3:],
#                                 df_scripted_session_5_script_3_F[train_script_3+val_script_3:],
#                                 df_scripted_session_5_script_3_M[train_script_3+val_script_3:]])

#     # need to split separately across the different sessions since they all have the same scripts
#     # df_session_1_script_1 = pd.merge(df_scripted_session_1_script_1_F, df_scripted_session_1_script_1_M, on=["utt_id", "session_number", "improv_script_id", "speaker"])
#     # df_session_2_script_1 = pd.merge(df_scripted_session_2_script_1_F, df_scripted_session_2_script_1_M, on=["utt_id", "session_number", "improv_script_id", "speaker"])
#     # df_session_3_script_1 = pd.merge(df_scripted_session_3_script_1_F, df_scripted_session_3_script_1_M, on=["utt_id", "session_number", "improv_script_id", "speaker"])
#     # df_session_4_script_1 = pd.merge(df_scripted_session_4_script_1_F, df_scripted_session_4_script_1_M, on=["utt_id", "session_number", "improv_script_id", "speaker"])
#     # df_session_5_script_1 = pd.merge(df_scripted_session_5_script_1_F, df_scripted_session_5_script_1_M, on=["utt_id", "session_number", "improv_script_id", "speaker"])

#     return df_scripted_train, df_scripted_val, df_scripted_test

# def impro_splits():    
#     # session 1
#     df_impro_session_1_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_1_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])

#     # session 2
#     df_impro_session_2_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_2_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])

#     # session 3
#     df_impro_session_3_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_3_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])

#     # session 4
#     df_impro_session_4_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_4_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])

#     # session 5
#     df_impro_session_5_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
#     df_impro_session_5_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])

#     # need to this for each impro across sections because although they are the same impro, the lines are not memorized perfectly and so there are some length differences
#     min_impro_1 = min([len(df_impro_session_1_impro_1_F), len(df_impro_session_1_impro_1_M),\
#                         len(df_impro_session_2_impro_1_F), len(df_impro_session_2_impro_1_M),\
#                         len(df_impro_session_3_impro_1_F), len(df_impro_session_3_impro_1_M),\
#                         len(df_impro_session_4_impro_1_F), len(df_impro_session_4_impro_1_M),\
#                         len(df_impro_session_5_impro_1_F), len(df_impro_session_5_impro_1_M)])

#     min_impro_2 = min([len(df_impro_session_1_impro_2_F), len(df_impro_session_1_impro_2_M),\
#                         len(df_impro_session_2_impro_2_F), len(df_impro_session_2_impro_2_M),\
#                         len(df_impro_session_3_impro_2_F), len(df_impro_session_3_impro_2_M),\
#                         len(df_impro_session_4_impro_2_F), len(df_impro_session_4_impro_2_M),\
#                         len(df_impro_session_5_impro_2_F), len(df_impro_session_5_impro_2_M)])

#     min_impro_3 = min([len(df_impro_session_1_impro_3_F), len(df_impro_session_1_impro_3_M),\
#                         len(df_impro_session_2_impro_3_F), len(df_impro_session_2_impro_3_M),\
#                         len(df_impro_session_3_impro_3_F), len(df_impro_session_3_impro_3_M),\
#                         len(df_impro_session_4_impro_3_F), len(df_impro_session_4_impro_3_M),\
#                         len(df_impro_session_5_impro_3_F), len(df_impro_session_5_impro_3_M)])

#     min_impro_4 = min([len(df_impro_session_1_impro_4_F), len(df_impro_session_1_impro_4_M),\
#                         len(df_impro_session_2_impro_4_F), len(df_impro_session_2_impro_4_M),\
#                         len(df_impro_session_3_impro_4_F), len(df_impro_session_3_impro_4_M),\
#                         len(df_impro_session_4_impro_4_F), len(df_impro_session_4_impro_4_M),\
#                         len(df_impro_session_5_impro_4_F), len(df_impro_session_5_impro_4_M)])

#     min_impro_5 = min([len(df_impro_session_1_impro_5_F), len(df_impro_session_1_impro_5_M),\
#                             len(df_impro_session_2_impro_5_F), len(df_impro_session_2_impro_5_M),\
#                             len(df_impro_session_3_impro_5_F), len(df_impro_session_3_impro_5_M),\
#                             len(df_impro_session_4_impro_5_F), len(df_impro_session_4_impro_5_M),\
#                             len(df_impro_session_5_impro_5_F), len(df_impro_session_5_impro_5_M)])

#     min_impro_6 = min([len(df_impro_session_1_impro_6_F), len(df_impro_session_1_impro_6_M),\
#                             len(df_impro_session_2_impro_6_F), len(df_impro_session_2_impro_6_M),\
#                             len(df_impro_session_3_impro_6_F), len(df_impro_session_3_impro_6_M),\
#                             len(df_impro_session_4_impro_6_F), len(df_impro_session_4_impro_6_M),\
#                             len(df_impro_session_5_impro_6_F), len(df_impro_session_5_impro_6_M)])

#     min_impro_7 = min([len(df_impro_session_1_impro_7_F), len(df_impro_session_1_impro_7_M),\
#                             len(df_impro_session_2_impro_7_F), len(df_impro_session_2_impro_7_M),\
#                             len(df_impro_session_3_impro_7_F), len(df_impro_session_3_impro_7_M),\
#                             len(df_impro_session_4_impro_7_F), len(df_impro_session_4_impro_7_M),\
#                             len(df_impro_session_5_impro_7_F), len(df_impro_session_5_impro_7_M)])

#     min_impro_8 = min([len(df_impro_session_2_impro_8_F), len(df_impro_session_2_impro_8_M),\
#                             len(df_impro_session_3_impro_8_F), len(df_impro_session_3_impro_8_M),\
#                             len(df_impro_session_4_impro_8_F), len(df_impro_session_4_impro_8_M),\
#                             len(df_impro_session_5_impro_8_F), len(df_impro_session_5_impro_8_M)])

#     train_impro_1 = int(min_impro_1*0.8)
#     val_impro_1 = (min_impro_1-train_impro_1)//2

#     train_impro_2 = int(min_impro_2*0.8)
#     val_impro_2 = (min_impro_2-train_impro_2)//2

#     train_impro_3 = int(min_impro_3*0.8)
#     val_impro_3 = (min_impro_3-train_impro_3)//2

#     train_impro_4 = int(min_impro_4*0.8)
#     val_impro_4 = (min_impro_4-train_impro_4)//2

#     train_impro_5 = int(min_impro_5*0.8)
#     val_impro_5 = (min_impro_5-train_impro_5)//2

#     train_impro_6 = int(min_impro_6*0.8)
#     val_impro_6 = (min_impro_6-train_impro_6)//2

#     train_impro_7 = int(min_impro_7*0.8)
#     val_impro_7 = (min_impro_7-train_impro_7)//2

#     train_impro_8 = int(min_impro_8*0.8)
#     val_impro_8 = (min_impro_8-train_impro_8)//2


#     #df_impro_session_1_impro_1_F.sample(frac=1) # going to ignore this for now and not shuffle within each impro because that means merging sessions... to confusing and not perfect matching
#     df_impro_train = pd.concat([df_impro_session_1_impro_1_F[:train_impro_1],
#                                 df_impro_session_1_impro_1_M[:train_impro_1],
#                                 df_impro_session_2_impro_1_F[:train_impro_1], 
#                                 df_impro_session_2_impro_1_M[:train_impro_1], 
#                                 df_impro_session_3_impro_1_F[:train_impro_1],
#                                 df_impro_session_3_impro_1_M[:train_impro_1],
#                                 df_impro_session_4_impro_1_F[:train_impro_1],
#                                 df_impro_session_4_impro_1_M[:train_impro_1],
#                                 df_impro_session_5_impro_1_F[:train_impro_1],
#                                 df_impro_session_5_impro_1_M[:train_impro_1],
#                                 df_impro_session_1_impro_2_F[:train_impro_2],
#                                 df_impro_session_1_impro_2_M[:train_impro_2],
#                                 df_impro_session_2_impro_2_F[:train_impro_2], 
#                                 df_impro_session_2_impro_2_M[:train_impro_2], 
#                                 df_impro_session_3_impro_2_F[:train_impro_2],
#                                 df_impro_session_3_impro_2_M[:train_impro_2],
#                                 df_impro_session_4_impro_2_F[:train_impro_2],
#                                 df_impro_session_4_impro_2_M[:train_impro_2],
#                                 df_impro_session_5_impro_2_F[:train_impro_2],
#                                 df_impro_session_5_impro_2_M[:train_impro_2],
#                                 df_impro_session_1_impro_3_F[:train_impro_3],
#                                 df_impro_session_1_impro_3_M[:train_impro_3],
#                                 df_impro_session_2_impro_3_F[:train_impro_3], 
#                                 df_impro_session_2_impro_3_M[:train_impro_3], 
#                                 df_impro_session_3_impro_3_F[:train_impro_3],
#                                 df_impro_session_3_impro_3_M[:train_impro_3],
#                                 df_impro_session_4_impro_3_F[:train_impro_3],
#                                 df_impro_session_4_impro_3_M[:train_impro_3],
#                                 df_impro_session_5_impro_3_F[:train_impro_3],
#                                 df_impro_session_5_impro_3_M[:train_impro_3],
#                                 df_impro_session_1_impro_4_F[:train_impro_4],
#                                 df_impro_session_1_impro_4_M[:train_impro_4],
#                                 df_impro_session_2_impro_4_F[:train_impro_4], 
#                                 df_impro_session_2_impro_4_M[:train_impro_4], 
#                                 df_impro_session_3_impro_4_F[:train_impro_4],
#                                 df_impro_session_3_impro_4_M[:train_impro_4],
#                                 df_impro_session_4_impro_4_F[:train_impro_4],
#                                 df_impro_session_4_impro_4_M[:train_impro_4],
#                                 df_impro_session_5_impro_4_F[:train_impro_4],
#                                 df_impro_session_5_impro_4_M[:train_impro_4],
#                                 df_impro_session_1_impro_5_F[:train_impro_5],
#                                 df_impro_session_1_impro_5_M[:train_impro_5],
#                                 df_impro_session_2_impro_5_F[:train_impro_5], 
#                                 df_impro_session_2_impro_5_M[:train_impro_5], 
#                                 df_impro_session_3_impro_5_F[:train_impro_5],
#                                 df_impro_session_3_impro_5_M[:train_impro_5],
#                                 df_impro_session_4_impro_5_F[:train_impro_5],
#                                 df_impro_session_4_impro_5_M[:train_impro_5],
#                                 df_impro_session_5_impro_5_F[:train_impro_5],
#                                 df_impro_session_5_impro_5_M[:train_impro_5],
#                                 df_impro_session_1_impro_6_F[:train_impro_6],
#                                 df_impro_session_1_impro_6_M[:train_impro_6],
#                                 df_impro_session_2_impro_6_F[:train_impro_6], 
#                                 df_impro_session_2_impro_6_M[:train_impro_6], 
#                                 df_impro_session_3_impro_6_F[:train_impro_6],
#                                 df_impro_session_3_impro_6_M[:train_impro_6],
#                                 df_impro_session_4_impro_6_F[:train_impro_6],
#                                 df_impro_session_4_impro_6_M[:train_impro_6],
#                                 df_impro_session_5_impro_6_F[:train_impro_6],
#                                 df_impro_session_5_impro_6_M[:train_impro_6],
#                                 df_impro_session_1_impro_7_F[:train_impro_7],
#                                 df_impro_session_1_impro_7_M[:train_impro_7],
#                                 df_impro_session_2_impro_7_F[:train_impro_7], 
#                                 df_impro_session_2_impro_7_M[:train_impro_7], 
#                                 df_impro_session_3_impro_7_F[:train_impro_7],
#                                 df_impro_session_3_impro_7_M[:train_impro_7],
#                                 df_impro_session_4_impro_7_F[:train_impro_7],
#                                 df_impro_session_4_impro_7_M[:train_impro_7],
#                                 df_impro_session_5_impro_7_F[:train_impro_7],
#                                 df_impro_session_5_impro_7_M[:train_impro_7],
#                                 df_impro_session_2_impro_8_F[:train_impro_8], 
#                                 df_impro_session_2_impro_8_M[:train_impro_8], 
#                                 df_impro_session_3_impro_8_F[:train_impro_8],
#                                 df_impro_session_3_impro_8_M[:train_impro_8],
#                                 df_impro_session_4_impro_8_F[:train_impro_8],
#                                 df_impro_session_4_impro_8_M[:train_impro_8],
#                                 df_impro_session_5_impro_8_F[:train_impro_8],
#                                 df_impro_session_5_impro_8_M[:train_impro_8]])  
#     df_impro_val = pd.concat([df_impro_session_1_impro_1_F[train_impro_1: train_impro_1+val_impro_1],
#                                 df_impro_session_1_impro_1_M[train_impro_1: train_impro_1+val_impro_1],
#                                 df_impro_session_2_impro_1_F[train_impro_1: train_impro_1+val_impro_1], 
#                                 df_impro_session_2_impro_1_M[train_impro_1: train_impro_1+val_impro_1], 
#                                 df_impro_session_3_impro_1_F[train_impro_1: train_impro_1+val_impro_1],
#                                 df_impro_session_3_impro_1_M[train_impro_1: train_impro_1+val_impro_1],
#                                 df_impro_session_4_impro_1_F[train_impro_1: train_impro_1+val_impro_1],
#                                 df_impro_session_4_impro_1_M[train_impro_1: train_impro_1+val_impro_1],
#                                 df_impro_session_5_impro_1_F[train_impro_1: train_impro_1+val_impro_1],
#                                 df_impro_session_5_impro_1_M[train_impro_1: train_impro_1+val_impro_1],
#                                 df_impro_session_1_impro_2_F[train_impro_2: train_impro_2+val_impro_2],
#                                 df_impro_session_1_impro_2_M[train_impro_2: train_impro_2+val_impro_2],
#                                 df_impro_session_2_impro_2_F[train_impro_2: train_impro_2+val_impro_2], 
#                                 df_impro_session_2_impro_2_M[train_impro_2: train_impro_2+val_impro_2], 
#                                 df_impro_session_3_impro_2_F[train_impro_2: train_impro_2+val_impro_2],
#                                 df_impro_session_3_impro_2_M[train_impro_2: train_impro_2+val_impro_2],
#                                 df_impro_session_4_impro_2_F[train_impro_2: train_impro_2+val_impro_2],
#                                 df_impro_session_4_impro_2_M[train_impro_2: train_impro_2+val_impro_2],
#                                 df_impro_session_5_impro_2_F[train_impro_2: train_impro_2+val_impro_2],
#                                 df_impro_session_5_impro_2_M[train_impro_2: train_impro_2+val_impro_2],
#                                 df_impro_session_1_impro_3_F[train_impro_3: train_impro_3+val_impro_3],
#                                 df_impro_session_1_impro_3_M[train_impro_3: train_impro_3+val_impro_3],
#                                 df_impro_session_2_impro_3_F[train_impro_3: train_impro_3+val_impro_3], 
#                                 df_impro_session_2_impro_3_M[train_impro_3: train_impro_3+val_impro_3], 
#                                 df_impro_session_3_impro_3_F[train_impro_3: train_impro_3+val_impro_3],
#                                 df_impro_session_3_impro_3_M[train_impro_3: train_impro_3+val_impro_3],
#                                 df_impro_session_4_impro_3_F[train_impro_3: train_impro_3+val_impro_3],
#                                 df_impro_session_4_impro_3_M[train_impro_3: train_impro_3+val_impro_3],
#                                 df_impro_session_5_impro_3_F[train_impro_3: train_impro_3+val_impro_3],
#                                 df_impro_session_5_impro_3_M[train_impro_3: train_impro_3+val_impro_3],
#                                 df_impro_session_1_impro_4_F[train_impro_4: train_impro_4+val_impro_4],
#                                 df_impro_session_1_impro_4_M[train_impro_4: train_impro_4+val_impro_4],
#                                 df_impro_session_2_impro_4_F[train_impro_4: train_impro_4+val_impro_4], 
#                                 df_impro_session_2_impro_4_M[train_impro_4: train_impro_4+val_impro_4], 
#                                 df_impro_session_3_impro_4_F[train_impro_4: train_impro_4+val_impro_4],
#                                 df_impro_session_3_impro_4_M[train_impro_4: train_impro_4+val_impro_4],
#                                 df_impro_session_4_impro_4_F[train_impro_4: train_impro_4+val_impro_4],
#                                 df_impro_session_4_impro_4_M[train_impro_4: train_impro_4+val_impro_4],
#                                 df_impro_session_5_impro_4_F[train_impro_4: train_impro_4+val_impro_4],
#                                 df_impro_session_5_impro_4_M[train_impro_4: train_impro_4+val_impro_4],
#                                 df_impro_session_1_impro_5_F[train_impro_5: train_impro_5+val_impro_5],
#                                 df_impro_session_1_impro_5_M[train_impro_5: train_impro_5+val_impro_5],
#                                 df_impro_session_2_impro_5_F[train_impro_5: train_impro_5+val_impro_5], 
#                                 df_impro_session_2_impro_5_M[train_impro_5: train_impro_5+val_impro_5], 
#                                 df_impro_session_3_impro_5_F[train_impro_5: train_impro_5+val_impro_5],
#                                 df_impro_session_3_impro_5_M[train_impro_5: train_impro_5+val_impro_5],
#                                 df_impro_session_4_impro_5_F[train_impro_5: train_impro_5+val_impro_5],
#                                 df_impro_session_4_impro_5_M[train_impro_5: train_impro_5+val_impro_5],
#                                 df_impro_session_5_impro_5_F[train_impro_5: train_impro_5+val_impro_5],
#                                 df_impro_session_5_impro_5_M[train_impro_5: train_impro_5+val_impro_5],
#                                 df_impro_session_1_impro_6_F[train_impro_6: train_impro_6+val_impro_6],
#                                 df_impro_session_1_impro_6_M[train_impro_6: train_impro_6+val_impro_6],
#                                 df_impro_session_2_impro_6_F[train_impro_6: train_impro_6+val_impro_6], 
#                                 df_impro_session_2_impro_6_M[train_impro_6: train_impro_6+val_impro_6], 
#                                 df_impro_session_3_impro_6_F[train_impro_6: train_impro_6+val_impro_6],
#                                 df_impro_session_3_impro_6_M[train_impro_6: train_impro_6+val_impro_6],
#                                 df_impro_session_4_impro_6_F[train_impro_6: train_impro_6+val_impro_6],
#                                 df_impro_session_4_impro_6_M[train_impro_6: train_impro_6+val_impro_6],
#                                 df_impro_session_5_impro_6_F[train_impro_6: train_impro_6+val_impro_6],
#                                 df_impro_session_5_impro_6_M[train_impro_6: train_impro_6+val_impro_6],
#                                 df_impro_session_1_impro_7_F[train_impro_7: train_impro_7+val_impro_7],
#                                 df_impro_session_1_impro_7_M[train_impro_7: train_impro_7+val_impro_7],
#                                 df_impro_session_2_impro_7_F[train_impro_7: train_impro_7+val_impro_7], 
#                                 df_impro_session_2_impro_7_M[train_impro_7: train_impro_7+val_impro_7], 
#                                 df_impro_session_3_impro_7_F[train_impro_7: train_impro_7+val_impro_7],
#                                 df_impro_session_3_impro_7_M[train_impro_7: train_impro_7+val_impro_7],
#                                 df_impro_session_4_impro_7_F[train_impro_7: train_impro_7+val_impro_7],
#                                 df_impro_session_4_impro_7_M[train_impro_7: train_impro_7+val_impro_7],
#                                 df_impro_session_5_impro_7_F[train_impro_7: train_impro_7+val_impro_7],
#                                 df_impro_session_5_impro_7_M[train_impro_7: train_impro_7+val_impro_7],
#                                 df_impro_session_2_impro_8_F[train_impro_8: train_impro_8+val_impro_8], 
#                                 df_impro_session_2_impro_8_M[train_impro_8: train_impro_8+val_impro_8], 
#                                 df_impro_session_3_impro_8_F[train_impro_8: train_impro_8+val_impro_8],
#                                 df_impro_session_3_impro_8_M[train_impro_8: train_impro_8+val_impro_8],
#                                 df_impro_session_4_impro_8_F[train_impro_8: train_impro_8+val_impro_8],
#                                 df_impro_session_4_impro_8_M[train_impro_8: train_impro_8+val_impro_8],
#                                 df_impro_session_5_impro_8_F[train_impro_8: train_impro_8+val_impro_8],
#                                 df_impro_session_5_impro_8_M[train_impro_8: train_impro_8+val_impro_8]]) 
#     df_impro_test = pd.concat([df_impro_session_1_impro_1_F[train_impro_1+val_impro_1:],
#                                 df_impro_session_1_impro_1_M[train_impro_1+val_impro_1:],
#                                 df_impro_session_2_impro_1_F[train_impro_1+val_impro_1:], 
#                                 df_impro_session_2_impro_1_M[train_impro_1+val_impro_1:], 
#                                 df_impro_session_3_impro_1_F[train_impro_1+val_impro_1:],
#                                 df_impro_session_3_impro_1_M[train_impro_1+val_impro_1:],
#                                 df_impro_session_4_impro_1_F[train_impro_1+val_impro_1:],
#                                 df_impro_session_5_impro_1_F[train_impro_1+val_impro_1:],
#                                 df_impro_session_5_impro_1_M[train_impro_1+val_impro_1:],
#                                 df_impro_session_1_impro_2_F[train_impro_2+val_impro_2:],
#                                 df_impro_session_1_impro_2_M[train_impro_2+val_impro_2:],
#                                 df_impro_session_2_impro_2_F[train_impro_2+val_impro_2:], 
#                                 df_impro_session_2_impro_2_M[train_impro_2+val_impro_2:], 
#                                 df_impro_session_3_impro_2_F[train_impro_2+val_impro_2:],
#                                 df_impro_session_3_impro_2_M[train_impro_2+val_impro_2:],
#                                 df_impro_session_4_impro_2_F[train_impro_2+val_impro_2:],
#                                 df_impro_session_4_impro_2_M[train_impro_2+val_impro_2:],
#                                 df_impro_session_5_impro_2_F[train_impro_2+val_impro_2:],
#                                 df_impro_session_5_impro_2_M[train_impro_2+val_impro_2:],
#                                 df_impro_session_1_impro_3_F[train_impro_3+val_impro_3:],
#                                 df_impro_session_1_impro_3_M[train_impro_3+val_impro_3:],
#                                 df_impro_session_2_impro_3_F[train_impro_3+val_impro_3:], 
#                                 df_impro_session_2_impro_3_M[train_impro_3+val_impro_3:], 
#                                 df_impro_session_3_impro_3_F[train_impro_3+val_impro_3:],
#                                 df_impro_session_3_impro_3_M[train_impro_3+val_impro_3:],
#                                 df_impro_session_4_impro_3_F[train_impro_3+val_impro_3:],
#                                 df_impro_session_4_impro_3_M[train_impro_3+val_impro_3:],
#                                 df_impro_session_5_impro_3_F[train_impro_3+val_impro_3:],
#                                 df_impro_session_5_impro_3_M[train_impro_3+val_impro_3:],
#                                 df_impro_session_1_impro_4_F[train_impro_4+val_impro_4:],
#                                 df_impro_session_1_impro_4_M[train_impro_4+val_impro_4:],
#                                 df_impro_session_2_impro_4_F[train_impro_4+val_impro_4:], 
#                                 df_impro_session_2_impro_4_M[train_impro_4+val_impro_4:], 
#                                 df_impro_session_3_impro_4_F[train_impro_4+val_impro_4:],
#                                 df_impro_session_3_impro_4_M[train_impro_4+val_impro_4:],
#                                 df_impro_session_4_impro_4_F[train_impro_4+val_impro_4:],
#                                 df_impro_session_4_impro_4_M[train_impro_4+val_impro_4:],
#                                 df_impro_session_5_impro_4_F[train_impro_4+val_impro_4:],
#                                 df_impro_session_5_impro_4_M[train_impro_4+val_impro_4:],
#                                 df_impro_session_1_impro_5_F[train_impro_5+val_impro_5:],
#                                 df_impro_session_1_impro_5_M[train_impro_5+val_impro_5:],
#                                 df_impro_session_2_impro_5_F[train_impro_5+val_impro_5:], 
#                                 df_impro_session_2_impro_5_M[train_impro_5+val_impro_5:], 
#                                 df_impro_session_3_impro_5_F[train_impro_5+val_impro_5:],
#                                 df_impro_session_3_impro_5_M[train_impro_5+val_impro_5:],
#                                 df_impro_session_4_impro_5_F[train_impro_5+val_impro_5:],
#                                 df_impro_session_4_impro_5_M[train_impro_5+val_impro_5:],
#                                 df_impro_session_5_impro_5_F[train_impro_5+val_impro_5:],
#                                 df_impro_session_5_impro_5_M[train_impro_5+val_impro_5:],
#                                 df_impro_session_1_impro_6_F[train_impro_6+val_impro_6:],
#                                 df_impro_session_1_impro_6_M[train_impro_6+val_impro_6:],
#                                 df_impro_session_2_impro_6_F[train_impro_6+val_impro_6:], 
#                                 df_impro_session_2_impro_6_M[train_impro_6+val_impro_6:], 
#                                 df_impro_session_3_impro_6_F[train_impro_6+val_impro_6:],
#                                 df_impro_session_3_impro_6_M[train_impro_6+val_impro_6:],
#                                 df_impro_session_4_impro_6_F[train_impro_6+val_impro_6:],
#                                 df_impro_session_4_impro_6_M[train_impro_6+val_impro_6:],
#                                 df_impro_session_5_impro_6_F[train_impro_6+val_impro_6:],
#                                 df_impro_session_5_impro_6_M[train_impro_6+val_impro_6:],
#                                 df_impro_session_1_impro_7_F[train_impro_7+val_impro_7:],
#                                 df_impro_session_1_impro_7_M[train_impro_7+val_impro_7:],
#                                 df_impro_session_2_impro_7_F[train_impro_7+val_impro_7:], 
#                                 df_impro_session_2_impro_7_M[train_impro_7+val_impro_7:], 
#                                 df_impro_session_3_impro_7_F[train_impro_7+val_impro_7:],
#                                 df_impro_session_3_impro_7_M[train_impro_7+val_impro_7:],
#                                 df_impro_session_4_impro_7_F[train_impro_7+val_impro_7:],
#                                 df_impro_session_4_impro_7_M[train_impro_7+val_impro_7:],
#                                 df_impro_session_5_impro_7_F[train_impro_7+val_impro_7:],
#                                 df_impro_session_5_impro_7_M[train_impro_7+val_impro_7:],
#                                 df_impro_session_2_impro_8_F[train_impro_8+val_impro_8:], 
#                                 df_impro_session_2_impro_8_M[train_impro_8+val_impro_8:], 
#                                 df_impro_session_3_impro_8_F[train_impro_8+val_impro_8:],
#                                 df_impro_session_3_impro_8_M[train_impro_8+val_impro_8:],
#                                 df_impro_session_4_impro_8_F[train_impro_8+val_impro_8:],
#                                 df_impro_session_4_impro_8_M[train_impro_8+val_impro_8:],
#                                 df_impro_session_5_impro_8_F[train_impro_8+val_impro_8:],
#                                 df_impro_session_5_impro_8_M[train_impro_8+val_impro_8:]]) 
#     return df_impro_train, df_impro_val, df_impro_test

# df_train_improv, df_val_improv, df_test_improv = impro_splits()
# df_train_scripted, df_val_scripted, df_test_scripted = scripted_splits()

# import audb
# import audiofile
# import opensmile

# labels = list(set(Counter(final_df["EDA"]).keys())) # there are 34 labels
# labels_to_num_mapping = {}
# for i, label in enumerate(labels):
#     labels_to_num_mapping[label] = i
   
# improv_scripted = "scripted"
# if improv_scripted == "improv":
#     df_train = df_train_improv
#     df_val = df_val_improv
#     df_test = df_test_improv

#     df_train_other = df_train_scripted
#     df_val_other = df_val_scripted
#     df_test_other = df_test_scripted
# else:
#     df_train = df_train_scripted
#     df_val = df_val_scripted
#     df_test = df_test_scripted

#     df_train_other = df_train_improv
#     df_val_other = df_val_improv
#     df_test_other = df_test_improv

# smile = opensmile.Smile(
#     feature_set=opensmile.FeatureSet.eGeMAPSv02,
#     feature_level=opensmile.FeatureLevel.Functionals,
# )

# def get_audio_gemap_features(df_split):
#     audio = []
#     for i, row in df_split.iterrows():
#         file_path = f"IEMOCAP_full_release/Session{row['session_number']}/dialog/wav/{row['filename']}.wav"
#         waveform, sample_rate = torchaudio.load(f"IEMOCAP_full_release/Session{row['session_number']}/dialog/wav/{row['filename']}.wav")
#         start, end = row["start"], row["end"]
#         # assert sample_rate == 16000
#         frame_start = round(start*sample_rate)
#         frame_end = round(end*sample_rate)
#         clipped_waveform = waveform[:,frame_start:frame_end].squeeze().cpu().detach().numpy()
#         features = smile.process_signal(clipped_waveform, sample_rate)
#         audio.append(features)
#     return audio

# train_audio, train_labels = get_audio_gemap_features(df_train), list(df_train["EDA"])
# val_audio, val_labels = get_audio_gemap_features(df_val), list(df_val["EDA"])
# test_audio, test_labels = get_audio_gemap_features(df_test), list(df_test["EDA"])

# type(train_audio)

# train_audio = pd.DataFrame(np.array(train_audio).squeeze(), columns=train_audio[0].columns)
# val_audio = pd.DataFrame(np.array(val_audio).squeeze(), columns=val_audio[0].columns)
# test_audio = pd.DataFrame(np.array(test_audio).squeeze(), columns=test_audio[0].columns)

# train_audio.to_csv("gemaps_dataset_train_scripted.csv")
# test_audio.to_csv("gemaps_dataset_val_scripted.csv")
# val_audio.to_csv("gemaps_dataset_test_scripted.csv")

# train_audio_other, train_labels_other = get_audio_gemap_features(df_train_other), list(df_train_other["EDA"])
# val_audio_other, val_labels_other = get_audio_gemap_features(df_val_other), list(df_val_other["EDA"])
# test_audio_other, test_labels_other = get_audio_gemap_features(df_test_other), list(df_test_other["EDA"])

# train_audio_other = pd.DataFrame(np.array(train_audio_other).squeeze(), columns=train_audio_other[0].columns)
# val_audio_other = pd.DataFrame(np.array(val_audio_other).squeeze(), columns=val_audio_other[0].columns)
# test_audio_other = pd.DataFrame(np.array(test_audio_other).squeeze(), columns=test_audio_other[0].columns)


# train_audio_other["labels"] = train_labels_other
# test_audio_other["labels"] = test_labels_other
# val_audio_other["labels"] = val_labels_other

# train_audio_other.to_csv("gemaps_dataset_train_improv.csv")
# test_audio_other.to_csv("gemaps_dataset_val_improv.csv")
# val_audio_other.to_csv("gemaps_dataset_test_improv.csv")

# Aside

In [None]:
# model_card = "facebook/wav2vec2-base"
# improv_scripted = "improv"
# output_dir = f"./results_audio/wav2vec2_base_{improv_scripted}/"

labels = list(set(Counter(final_df["EDA"]).keys())) # there are 34 labels
labels_to_num_mapping = {}
for i, label in enumerate(labels):
    labels_to_num_mapping[label] = i

class DialogActDataset(torch.utils.data.Dataset):
    def __init__(self, feature_extractor_checkpoint, filenames, starts, ends, labels):
        # self.encodings = encodings
        self.feature_extractor_checkpoint = feature_extractor_checkpoint
        self.filenames = filenames
        self.starts = starts
        self.ends = ends
        self.labels = labels

    def __getitem__(self, idx):
        waveform, sample_rate = torchaudio.load(self.filenames[idx])
        assert sample_rate == 16000
        start, end = self.starts[idx], self.ends[idx]
        frame_start = round(start*sample_rate)
        frame_end = round(end*sample_rate)
        audio_feature_tensor = self.feature_extractor_checkpoint(waveform[:,frame_start:frame_end].squeeze(), max_length=16000, sampling_rate=sample_rate, return_tensors="pt", truncation=True)
        
        # item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # label = torch.tensor(labels_to_num_mapping[self.labels[idx]])
   
        item = {key: torch.tensor(val[idx]) for key, val in audio_feature_tensor.items()}
        item['labels'] = torch.tensor(labels_to_num_mapping[self.labels[idx]])
        # return audio_feature_tensor, label     
        return item

    def __len__(self):
        return len(self.labels)
    

if improv_scripted == "improv":
    df_train = df_train_improv
    df_val = df_val_improv
    df_test = df_test_improv

    df_train_other = df_train_scripted
    df_val_other = df_val_scripted
    df_test_other = df_test_scripted
else:
    df_train = df_train_scripted
    df_val = df_val_scripted
    df_test = df_test_scripted

    df_train_other = df_train_improv
    df_val_other = df_val_improv
    df_test_other = df_test_improv

feature_extractor = AutoFeatureExtractor.from_pretrained(model_card)
model = AutoModelForAudioClassification.from_pretrained(model_card)

def get_audio(df_split):
    audio = []
    for i, row in df_split.iterrows():
        file_path = f"IEMOCAP_full_release/Session{row['session_number']}/dialog/wav/{row['filename']}.wav"
        waveform, sample_rate = torchaudio.load(f"IEMOCAP_full_release/Session{row['session_number']}/dialog/wav/{row['filename']}.wav")
        start, end = row["start"], row["end"]
        frame_start = round(start*sample_rate)
        frame_end = round(end*sample_rate)
        
        smile.process_signal(signal, sampling_rate)
        audio.append({"file_path": file_path, "start": start, "end": end})
    return pd.DataFrame(audio)

train_audio, train_labels = get_audio(df_train), list(df_train["EDA"])
val_audio, val_labels = get_audio(df_val), list(df_val["EDA"])
test_audio, test_labels = get_audio(df_test), list(df_test["EDA"])

smile = opensmile.Smile(
    feature_set=opensmile.FeatureSet.eGeMAPSv02,
    feature_level=opensmile.FeatureLevel.Functionals,
)
len(smile.feature_names)

smile.process_signal(signal,sampling_rate)
# train_audio_other, train_labels_other = get_audio(df_train_other), list(df_train_other["EDA"])
# val_audio_other, val_labels_other = get_audio(df_val_other), list(df_val_other["EDA"])
# test_audio_other, test_labels_other = get_audio(df_test_other), list(df_test_other["EDA"])

# train_dataset = DialogActDataset(feature_extractor_checkpoint=feature_extractor, filenames=train_audio["file_path"], starts=train_audio["start"], ends=train_audio["end"], labels=train_labels)
# val_dataset = DialogActDataset(feature_extractor_checkpoint=feature_extractor, filenames=val_audio["file_path"], starts=val_audio["start"], ends=val_audio["end"], labels=val_labels)
# test_dataset = DialogActDataset(feature_extractor_checkpoint=feature_extractor, filenames=test_audio["file_path"], starts=test_audio["start"], ends=test_audio["end"], labels=test_labels)

# # # tokenizer = AutoTokenizer.from_pretrained(model_card)

# # # train_encodings = tokenizer(train_texts, truncation=True, padding=True)
# # # val_encodings = tokenizer(val_texts, truncation=True, padding=True)
# # # test_encodings = tokenizer(test_texts, truncation=True, padding=True)

# # # train_dataset = DialogActDataset(train_encodings, train_labels)
# # # val_dataset = DialogActDataset(val_encodings, val_labels)
# # # test_dataset = DialogActDataset(test_encodings, test_labels)

# # data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# # # model = AutoModelForSequenceClassification.from_pretrained(model_card, num_labels=len(labels))

# metric = evaluate.load("accuracy")

# def compute_metrics(eval_pred):
#     logits, labels = eval_pred
#     predictions = np.argmax(logits, axis=-1)
#     return metric.compute(predictions=predictions, references=labels)


# training_args = TrainingArguments(
#     output_dir=output_dir,
#     learning_rate=2e-5,
#     per_device_train_batch_size=1,
#     per_device_eval_batch_size=1,
#     num_train_epochs=5,
#     weight_decay=0.01#,
#     # eval_strategy="epoch"
# )

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=val_dataset,
#     # tokenizer=tokenizer,
#     # data_collator=data_collator,
#     compute_metrics=compute_metrics,
# )

# trainer.train()

In [None]:
train_audio, train_labels = [torchaudio.load(f"IEMOCAP_full_release/Session{utt['session_number']}/dialog/wav/{utt['filename']}.wav")[0] for utt in list(df_train)], list(df_train["EDA"])
def get_audio(df_split):
    audio = []
    for i, row in df_split.iterrows():
        waveform, sample_rate = torchaudio.load(f"IEMOCAP_full_release/Session{row['session_number']}/dialog/wav/{row['filename']}.wav")
        start, end = row["start"], row["end"]
        assert sample_rate == 16000
        frame_start = round(start*sample_rate)
        frame_end = round(end*sample_rate)
        audio.append(waveform[:,frame_start:frame_end].squeeze())



In [None]:
import IPython.display as ipd
import torchaudio
import torchaudio.transforms as T

session_number = 1
filename = "Ses01M_impro07"
audio_file = f"IEMOCAP_full_release/Session{session_number}/dialog/wav/{filename}.wav"
# ipd.Audio(audio_file) # load and play the ../peer-mediation-script-audio-files/Jacob_1.wav file

waveform, sample_rate = torchaudio.load(f"IEMOCAP_full_release/Session{session_number}/dialog/wav/{filename}.wav") # torchaudio.load() takes in an audio file and outputs to variables named here "waveform" and "sample_rate"
assert sample_rate == 16000

ipd.Audio(waveform.squeeze().numpy(), rate=sample_rate) # load a local WAV file


In [None]:
start = round(7.6300*sample_rate)
end = round(8.5700*sample_rate)

ipd.Audio(waveform[:,start:end].squeeze().numpy(), rate=sample_rate) # load a local WAV file