In [1]:
import re 
import os
import pandas as pd
import numpy as np
import json
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
import evaluate
from collections import Counter
from sklearn.model_selection import train_test_split

import torch
# Choose the compute device: an NVIDIA GPU via CUDA when one is visible,
# otherwise fall back to the CPU. (On Apple Silicon, torch.device("mps")
# guarded by torch.backends.mps.is_available() is the GPU alternative.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

# 1. Access the corresponding .txt, .wav, and .avi files for each EDA label.
# Extract the conversation filename, session number and speaker gender from the
# "speaker" column of the EDA dataset.
eda_df = pd.read_csv("eda_iemocap_no_utts_dataset.csv")
eda_df = eda_df[["speaker", "utt_id", "EDA"]]
# Capture groups: (1) the conversation filename, i.e. everything from "Ses" up to
# the last "_M"/"_F"; (2) the session number; (3) the trailing speaker gender M/F.
# The leading "b'" suggests the column stores stringified bytes literals
# (TODO confirm against the CSV).
# A vectorised extract replaces the per-row iterrows() loop and lets us report
# unparseable rows clearly instead of crashing on `None.group(...)`.
speaker_parts = eda_df["speaker"].astype(str).str.extract(r"b'(Ses(\d+)[MF]_.+\d+.*)_([MF])")
unparsed = speaker_parts.isna().any(axis=1)
if unparsed.any():
    example = eda_df.loc[unparsed, "speaker"].iloc[0]
    raise ValueError(
        f"Could not parse the speaker field of {int(unparsed.sum())} EDA rows (e.g. {example!r})"
    )
eda_df = eda_df.drop(columns=["speaker"])
# Assignment aligns on the index, which speaker_parts shares with eda_df.
eda_df["filename"] = speaker_parts[0].astype(str)
eda_df["session_number"] = speaker_parts[1].astype(int)
eda_df["speaker"] = speaker_parts[2].astype(str)
eda_df["utt_id"] = eda_df["utt_id"].astype(int)

# Access the transcript files of every session and collect one record per timed
# utterance. The part of each line before the first ":" is searched for
# "<F|M><utterance number> [<start>-<end>]"; lines without that pattern are skipped.
transcript_line_pattern = re.compile(r"(F|M)(\d+)\s\[(\d+\.\d+)-(\d+\.\d+)\]")
utt_df = []
root_dir = "IEMOCAP_full_release/"
for i in range(1, 6):
    directory = os.path.join(root_dir, f"Session{i}/dialog/transcriptions/")
    with os.scandir(directory) as entries:
        # Sort for a deterministic row order; raw scandir order depends on the filesystem.
        for entry in sorted(entries, key=lambda e: e.name):
            # Skip non-.txt entries and the macOS "._" metadata files that sit
            # next to each transcript (they are not real transcripts).
            if entry.name.startswith("._") or not entry.name.endswith(".txt") or not entry.is_file():
                continue
            try:
                with open(entry.path, "r", encoding="utf-8") as file:
                    lines = file.readlines()
            except UnicodeDecodeError as err:
                # Best-effort: report and skip, but never silently drop half a file.
                print(f"Skipping undecodable transcript {entry.path}: {err}")
                continue
            filename = os.path.splitext(entry.name)[0]
            for order, line in enumerate(lines):
                # partition() keeps everything after the FIRST colon, so utterances
                # that themselves contain ":" (e.g. clock times) are not truncated.
                speaker_info, sep, utterance = line.partition(":")
                if not sep:
                    continue
                match = transcript_line_pattern.search(speaker_info)
                if match is None:
                    continue
                speaker_f_m, utt_id, start, end = match.groups()
                utt_df.append({"utt_id": int(utt_id), "filename": filename, "start": float(start), "end": float(end), "speaker": speaker_f_m, "utterance": utterance.strip(), "session_number": i, "original_order": order})
utt_df = pd.DataFrame(utt_df)
if utt_df.empty:
    raise ValueError(f"No transcript utterances were parsed under {root_dir!r}")
# Combine the EDA labels and the utterances (inner join on the shared keys).
final_df = pd.merge(eda_df, utt_df, on=["utt_id", "session_number", "filename", "speaker"])
final_df  # displayed when run as a notebook cell

def scripted_splits(df=None, train_frac=0.8):
    """Split the scripted IEMOCAP dialogs into train / validation / test frames.

    Rows are grouped by script (01-03), session (1-5) and the gender tag that
    precedes "_scriptNN" in the filename (F or M). Each group is ordered by
    ``filename`` then ``original_order`` and cut at the same positions:

    * train: the first ``int(shortest * train_frac)`` rows,
    * validation: the next ``(shortest - n_train) // 2`` rows,
    * test: every remaining row, so longer groups contribute their extra rows
      to the test set.

    ``shortest`` is the smallest group size for a script across all ten of its
    (session, gender) groups. Actors do not reproduce a script line-for-line, so
    the same script can yield slightly different utterance counts. Groups are
    not shuffled, so each split follows the dialog's transcript order.

    Args:
        df: Merged EDA/utterance frame with ``filename``, ``session_number`` and
            ``original_order`` columns. Defaults to the module-level ``final_df``.
        train_frac: Fraction of the shortest group used for training.

    Returns:
        Tuple ``(df_scripted_train, df_scripted_val, df_scripted_test)``.
    """
    if df is None:
        df = final_df
    sort_keys = ["filename", "original_order"]
    train_parts, val_parts, test_parts = [], [], []
    for script in (1, 2, 3):
        groups = []
        for session in range(1, 6):
            in_session = df["session_number"] == session
            for gender in ("F", "M"):
                in_script = df["filename"].str.contains(f"{gender}_script{script:02d}")
                groups.append(df[in_script & in_session].sort_values(by=sort_keys))
        shortest = min(len(group) for group in groups)
        n_train = int(shortest * train_frac)
        n_val = (shortest - n_train) // 2
        # Every group feeds every split; building the lists in a loop guarantees
        # no (session, script, gender) group is dropped from any of them.
        for group in groups:
            train_parts.append(group.iloc[:n_train])
            val_parts.append(group.iloc[n_train:n_train + n_val])
            test_parts.append(group.iloc[n_train + n_val:])
    return pd.concat(train_parts), pd.concat(val_parts), pd.concat(test_parts)

def impro_splits():    
    # session 1
    df_impro_session_1_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])
    df_impro_session_1_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 1)].sort_values(by=['filename', 'original_order'])

    # session 2
    df_impro_session_2_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])
    df_impro_session_2_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 2)].sort_values(by=['filename', 'original_order'])

    # session 3
    df_impro_session_3_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])
    df_impro_session_3_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 3)].sort_values(by=['filename', 'original_order'])

    # session 4
    df_impro_session_4_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])
    df_impro_session_4_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 4)].sort_values(by=['filename', 'original_order'])

    # session 5
    df_impro_session_5_impro_1_M = final_df[(final_df['filename'].str.contains('M_impro01')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_1_F = final_df[(final_df['filename'].str.contains('F_impro01')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_2_M = final_df[(final_df['filename'].str.contains('M_impro02')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_2_F = final_df[(final_df['filename'].str.contains('F_impro02')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_3_M = final_df[(final_df['filename'].str.contains('M_impro03')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_3_F = final_df[(final_df['filename'].str.contains('F_impro03')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_4_M = final_df[(final_df['filename'].str.contains('M_impro04')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_4_F = final_df[(final_df['filename'].str.contains('F_impro04')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_5_M = final_df[(final_df['filename'].str.contains('M_impro05')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_5_F = final_df[(final_df['filename'].str.contains('F_impro05')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_6_M = final_df[(final_df['filename'].str.contains('M_impro06')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_6_F = final_df[(final_df['filename'].str.contains('F_impro06')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_7_M = final_df[(final_df['filename'].str.contains('M_impro07')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_7_F = final_df[(final_df['filename'].str.contains('F_impro07')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_8_M = final_df[(final_df['filename'].str.contains('M_impro08')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])
    df_impro_session_5_impro_8_F = final_df[(final_df['filename'].str.contains('F_impro08')) & (final_df['session_number'] == 5)].sort_values(by=['filename', 'original_order'])

    # Need to do this for each impro across sessions: although it is the same impro, the lines are not reproduced exactly, so lengths differ between sessions; use the minimum length to size the splits
    min_impro_1 = min([len(df_impro_session_1_impro_1_F), len(df_impro_session_1_impro_1_M),\
                        len(df_impro_session_2_impro_1_F), len(df_impro_session_2_impro_1_M),\
                        len(df_impro_session_3_impro_1_F), len(df_impro_session_3_impro_1_M),\
                        len(df_impro_session_4_impro_1_F), len(df_impro_session_4_impro_1_M),\
                        len(df_impro_session_5_impro_1_F), len(df_impro_session_5_impro_1_M)])

    min_impro_2 = min([len(df_impro_session_1_impro_2_F), len(df_impro_session_1_impro_2_M),\
                        len(df_impro_session_2_impro_2_F), len(df_impro_session_2_impro_2_M),\
                        len(df_impro_session_3_impro_2_F), len(df_impro_session_3_impro_2_M),\
                        len(df_impro_session_4_impro_2_F), len(df_impro_session_4_impro_2_M),\
                        len(df_impro_session_5_impro_2_F), len(df_impro_session_5_impro_2_M)])

    min_impro_3 = min([len(df_impro_session_1_impro_3_F), len(df_impro_session_1_impro_3_M),\
                        len(df_impro_session_2_impro_3_F), len(df_impro_session_2_impro_3_M),\
                        len(df_impro_session_3_impro_3_F), len(df_impro_session_3_impro_3_M),\
                        len(df_impro_session_4_impro_3_F), len(df_impro_session_4_impro_3_M),\
                        len(df_impro_session_5_impro_3_F), len(df_impro_session_5_impro_3_M)])

    min_impro_4 = min([len(df_impro_session_1_impro_4_F), len(df_impro_session_1_impro_4_M),\
                        len(df_impro_session_2_impro_4_F), len(df_impro_session_2_impro_4_M),\
                        len(df_impro_session_3_impro_4_F), len(df_impro_session_3_impro_4_M),\
                        len(df_impro_session_4_impro_4_F), len(df_impro_session_4_impro_4_M),\
                        len(df_impro_session_5_impro_4_F), len(df_impro_session_5_impro_4_M)])

    min_impro_5 = min([len(df_impro_session_1_impro_5_F), len(df_impro_session_1_impro_5_M),\
                            len(df_impro_session_2_impro_5_F), len(df_impro_session_2_impro_5_M),\
                            len(df_impro_session_3_impro_5_F), len(df_impro_session_3_impro_5_M),\
                            len(df_impro_session_4_impro_5_F), len(df_impro_session_4_impro_5_M),\
                            len(df_impro_session_5_impro_5_F), len(df_impro_session_5_impro_5_M)])

    min_impro_6 = min([len(df_impro_session_1_impro_6_F), len(df_impro_session_1_impro_6_M),\
                            len(df_impro_session_2_impro_6_F), len(df_impro_session_2_impro_6_M),\
                            len(df_impro_session_3_impro_6_F), len(df_impro_session_3_impro_6_M),\
                            len(df_impro_session_4_impro_6_F), len(df_impro_session_4_impro_6_M),\
                            len(df_impro_session_5_impro_6_F), len(df_impro_session_5_impro_6_M)])

    min_impro_7 = min([len(df_impro_session_1_impro_7_F), len(df_impro_session_1_impro_7_M),\
                            len(df_impro_session_2_impro_7_F), len(df_impro_session_2_impro_7_M),\
                            len(df_impro_session_3_impro_7_F), len(df_impro_session_3_impro_7_M),\
                            len(df_impro_session_4_impro_7_F), len(df_impro_session_4_impro_7_M),\
                            len(df_impro_session_5_impro_7_F), len(df_impro_session_5_impro_7_M)])

    min_impro_8 = min([len(df_impro_session_2_impro_8_F), len(df_impro_session_2_impro_8_M),\
                            len(df_impro_session_3_impro_8_F), len(df_impro_session_3_impro_8_M),\
                            len(df_impro_session_4_impro_8_F), len(df_impro_session_4_impro_8_M),\
                            len(df_impro_session_5_impro_8_F), len(df_impro_session_5_impro_8_M)])

    train_impro_1 = int(min_impro_1*0.8)
    val_impro_1 = (min_impro_1-train_impro_1)//2

    train_impro_2 = int(min_impro_2*0.8)
    val_impro_2 = (min_impro_2-train_impro_2)//2

    train_impro_3 = int(min_impro_3*0.8)
    val_impro_3 = (min_impro_3-train_impro_3)//2

    train_impro_4 = int(min_impro_4*0.8)
    val_impro_4 = (min_impro_4-train_impro_4)//2

    train_impro_5 = int(min_impro_5*0.8)
    val_impro_5 = (min_impro_5-train_impro_5)//2

    train_impro_6 = int(min_impro_6*0.8)
    val_impro_6 = (min_impro_6-train_impro_6)//2

    train_impro_7 = int(min_impro_7*0.8)
    val_impro_7 = (min_impro_7-train_impro_7)//2

    train_impro_8 = int(min_impro_8*0.8)
    val_impro_8 = (min_impro_8-train_impro_8)//2


    #df_impro_session_1_impro_1_F.sample(frac=1) # ignoring this for now: not shuffling within each impro, because that would mean merging sessions... too confusing, and the matching across sessions is not perfect
    df_impro_train = pd.concat([df_impro_session_1_impro_1_F[:train_impro_1],
                                df_impro_session_1_impro_1_M[:train_impro_1],
                                df_impro_session_2_impro_1_F[:train_impro_1], 
                                df_impro_session_2_impro_1_M[:train_impro_1], 
                                df_impro_session_3_impro_1_F[:train_impro_1],
                                df_impro_session_3_impro_1_M[:train_impro_1],
                                df_impro_session_4_impro_1_F[:train_impro_1],
                                df_impro_session_4_impro_1_M[:train_impro_1],
                                df_impro_session_5_impro_1_F[:train_impro_1],
                                df_impro_session_5_impro_1_M[:train_impro_1],
                                df_impro_session_1_impro_2_F[:train_impro_2],
                                df_impro_session_1_impro_2_M[:train_impro_2],
                                df_impro_session_2_impro_2_F[:train_impro_2], 
                                df_impro_session_2_impro_2_M[:train_impro_2], 
                                df_impro_session_3_impro_2_F[:train_impro_2],
                                df_impro_session_3_impro_2_M[:train_impro_2],
                                df_impro_session_4_impro_2_F[:train_impro_2],
                                df_impro_session_4_impro_2_M[:train_impro_2],
                                df_impro_session_5_impro_2_F[:train_impro_2],
                                df_impro_session_5_impro_2_M[:train_impro_2],
                                df_impro_session_1_impro_3_F[:train_impro_3],
                                df_impro_session_1_impro_3_M[:train_impro_3],
                                df_impro_session_2_impro_3_F[:train_impro_3], 
                                df_impro_session_2_impro_3_M[:train_impro_3], 
                                df_impro_session_3_impro_3_F[:train_impro_3],
                                df_impro_session_3_impro_3_M[:train_impro_3],
                                df_impro_session_4_impro_3_F[:train_impro_3],
                                df_impro_session_4_impro_3_M[:train_impro_3],
                                df_impro_session_5_impro_3_F[:train_impro_3],
                                df_impro_session_5_impro_3_M[:train_impro_3],
                                df_impro_session_1_impro_4_F[:train_impro_4],
                                df_impro_session_1_impro_4_M[:train_impro_4],
                                df_impro_session_2_impro_4_F[:train_impro_4], 
                                df_impro_session_2_impro_4_M[:train_impro_4], 
                                df_impro_session_3_impro_4_F[:train_impro_4],
                                df_impro_session_3_impro_4_M[:train_impro_4],
                                df_impro_session_4_impro_4_F[:train_impro_4],
                                df_impro_session_4_impro_4_M[:train_impro_4],
                                df_impro_session_5_impro_4_F[:train_impro_4],
                                df_impro_session_5_impro_4_M[:train_impro_4],
                                df_impro_session_1_impro_5_F[:train_impro_5],
                                df_impro_session_1_impro_5_M[:train_impro_5],
                                df_impro_session_2_impro_5_F[:train_impro_5], 
                                df_impro_session_2_impro_5_M[:train_impro_5], 
                                df_impro_session_3_impro_5_F[:train_impro_5],
                                df_impro_session_3_impro_5_M[:train_impro_5],
                                df_impro_session_4_impro_5_F[:train_impro_5],
                                df_impro_session_4_impro_5_M[:train_impro_5],
                                df_impro_session_5_impro_5_F[:train_impro_5],
                                df_impro_session_5_impro_5_M[:train_impro_5],
                                df_impro_session_1_impro_6_F[:train_impro_6],
                                df_impro_session_1_impro_6_M[:train_impro_6],
                                df_impro_session_2_impro_6_F[:train_impro_6], 
                                df_impro_session_2_impro_6_M[:train_impro_6], 
                                df_impro_session_3_impro_6_F[:train_impro_6],
                                df_impro_session_3_impro_6_M[:train_impro_6],
                                df_impro_session_4_impro_6_F[:train_impro_6],
                                df_impro_session_4_impro_6_M[:train_impro_6],
                                df_impro_session_5_impro_6_F[:train_impro_6],
                                df_impro_session_5_impro_6_M[:train_impro_6],
                                df_impro_session_1_impro_7_F[:train_impro_7],
                                df_impro_session_1_impro_7_M[:train_impro_7],
                                df_impro_session_2_impro_7_F[:train_impro_7], 
                                df_impro_session_2_impro_7_M[:train_impro_7], 
                                df_impro_session_3_impro_7_F[:train_impro_7],
                                df_impro_session_3_impro_7_M[:train_impro_7],
                                df_impro_session_4_impro_7_F[:train_impro_7],
                                df_impro_session_4_impro_7_M[:train_impro_7],
                                df_impro_session_5_impro_7_F[:train_impro_7],
                                df_impro_session_5_impro_7_M[:train_impro_7],
                                df_impro_session_2_impro_8_F[:train_impro_8], 
                                df_impro_session_2_impro_8_M[:train_impro_8], 
                                df_impro_session_3_impro_8_F[:train_impro_8],
                                df_impro_session_3_impro_8_M[:train_impro_8],
                                df_impro_session_4_impro_8_F[:train_impro_8],
                                df_impro_session_4_impro_8_M[:train_impro_8],
                                df_impro_session_5_impro_8_F[:train_impro_8],
                                df_impro_session_5_impro_8_M[:train_impro_8]])  
    df_impro_val = pd.concat([df_impro_session_1_impro_1_F[train_impro_1: train_impro_1+val_impro_1],
                                df_impro_session_1_impro_1_M[train_impro_1: train_impro_1+val_impro_1],
                                df_impro_session_2_impro_1_F[train_impro_1: train_impro_1+val_impro_1], 
                                df_impro_session_2_impro_1_M[train_impro_1: train_impro_1+val_impro_1], 
                                df_impro_session_3_impro_1_F[train_impro_1: train_impro_1+val_impro_1],
                                df_impro_session_3_impro_1_M[train_impro_1: train_impro_1+val_impro_1],
                                df_impro_session_4_impro_1_F[train_impro_1: train_impro_1+val_impro_1],
                                df_impro_session_4_impro_1_M[train_impro_1: train_impro_1+val_impro_1],
                                df_impro_session_5_impro_1_F[train_impro_1: train_impro_1+val_impro_1],
                                df_impro_session_5_impro_1_M[train_impro_1: train_impro_1+val_impro_1],
                                df_impro_session_1_impro_2_F[train_impro_2: train_impro_2+val_impro_2],
                                df_impro_session_1_impro_2_M[train_impro_2: train_impro_2+val_impro_2],
                                df_impro_session_2_impro_2_F[train_impro_2: train_impro_2+val_impro_2], 
                                df_impro_session_2_impro_2_M[train_impro_2: train_impro_2+val_impro_2], 
                                df_impro_session_3_impro_2_F[train_impro_2: train_impro_2+val_impro_2],
                                df_impro_session_3_impro_2_M[train_impro_2: train_impro_2+val_impro_2],
                                df_impro_session_4_impro_2_F[train_impro_2: train_impro_2+val_impro_2],
                                df_impro_session_4_impro_2_M[train_impro_2: train_impro_2+val_impro_2],
                                df_impro_session_5_impro_2_F[train_impro_2: train_impro_2+val_impro_2],
                                df_impro_session_5_impro_2_M[train_impro_2: train_impro_2+val_impro_2],
                                df_impro_session_1_impro_3_F[train_impro_3: train_impro_3+val_impro_3],
                                df_impro_session_1_impro_3_M[train_impro_3: train_impro_3+val_impro_3],
                                df_impro_session_2_impro_3_F[train_impro_3: train_impro_3+val_impro_3], 
                                df_impro_session_2_impro_3_M[train_impro_3: train_impro_3+val_impro_3], 
                                df_impro_session_3_impro_3_F[train_impro_3: train_impro_3+val_impro_3],
                                df_impro_session_3_impro_3_M[train_impro_3: train_impro_3+val_impro_3],
                                df_impro_session_4_impro_3_F[train_impro_3: train_impro_3+val_impro_3],
                                df_impro_session_4_impro_3_M[train_impro_3: train_impro_3+val_impro_3],
                                df_impro_session_5_impro_3_F[train_impro_3: train_impro_3+val_impro_3],
                                df_impro_session_5_impro_3_M[train_impro_3: train_impro_3+val_impro_3],
                                df_impro_session_1_impro_4_F[train_impro_4: train_impro_4+val_impro_4],
                                df_impro_session_1_impro_4_M[train_impro_4: train_impro_4+val_impro_4],
                                df_impro_session_2_impro_4_F[train_impro_4: train_impro_4+val_impro_4], 
                                df_impro_session_2_impro_4_M[train_impro_4: train_impro_4+val_impro_4], 
                                df_impro_session_3_impro_4_F[train_impro_4: train_impro_4+val_impro_4],
                                df_impro_session_3_impro_4_M[train_impro_4: train_impro_4+val_impro_4],
                                df_impro_session_4_impro_4_F[train_impro_4: train_impro_4+val_impro_4],
                                df_impro_session_4_impro_4_M[train_impro_4: train_impro_4+val_impro_4],
                                df_impro_session_5_impro_4_F[train_impro_4: train_impro_4+val_impro_4],
                                df_impro_session_5_impro_4_M[train_impro_4: train_impro_4+val_impro_4],
                                df_impro_session_1_impro_5_F[train_impro_5: train_impro_5+val_impro_5],
                                df_impro_session_1_impro_5_M[train_impro_5: train_impro_5+val_impro_5],
                                df_impro_session_2_impro_5_F[train_impro_5: train_impro_5+val_impro_5], 
                                df_impro_session_2_impro_5_M[train_impro_5: train_impro_5+val_impro_5], 
                                df_impro_session_3_impro_5_F[train_impro_5: train_impro_5+val_impro_5],
                                df_impro_session_3_impro_5_M[train_impro_5: train_impro_5+val_impro_5],
                                df_impro_session_4_impro_5_F[train_impro_5: train_impro_5+val_impro_5],
                                df_impro_session_4_impro_5_M[train_impro_5: train_impro_5+val_impro_5],
                                df_impro_session_5_impro_5_F[train_impro_5: train_impro_5+val_impro_5],
                                df_impro_session_5_impro_5_M[train_impro_5: train_impro_5+val_impro_5],
                                df_impro_session_1_impro_6_F[train_impro_6: train_impro_6+val_impro_6],
                                df_impro_session_1_impro_6_M[train_impro_6: train_impro_6+val_impro_6],
                                df_impro_session_2_impro_6_F[train_impro_6: train_impro_6+val_impro_6], 
                                df_impro_session_2_impro_6_M[train_impro_6: train_impro_6+val_impro_6], 
                                df_impro_session_3_impro_6_F[train_impro_6: train_impro_6+val_impro_6],
                                df_impro_session_3_impro_6_M[train_impro_6: train_impro_6+val_impro_6],
                                df_impro_session_4_impro_6_F[train_impro_6: train_impro_6+val_impro_6],
                                df_impro_session_4_impro_6_M[train_impro_6: train_impro_6+val_impro_6],
                                df_impro_session_5_impro_6_F[train_impro_6: train_impro_6+val_impro_6],
                                df_impro_session_5_impro_6_M[train_impro_6: train_impro_6+val_impro_6],
                                df_impro_session_1_impro_7_F[train_impro_7: train_impro_7+val_impro_7],
                                df_impro_session_1_impro_7_M[train_impro_7: train_impro_7+val_impro_7],
                                df_impro_session_2_impro_7_F[train_impro_7: train_impro_7+val_impro_7], 
                                df_impro_session_2_impro_7_M[train_impro_7: train_impro_7+val_impro_7], 
                                df_impro_session_3_impro_7_F[train_impro_7: train_impro_7+val_impro_7],
                                df_impro_session_3_impro_7_M[train_impro_7: train_impro_7+val_impro_7],
                                df_impro_session_4_impro_7_F[train_impro_7: train_impro_7+val_impro_7],
                                df_impro_session_4_impro_7_M[train_impro_7: train_impro_7+val_impro_7],
                                df_impro_session_5_impro_7_F[train_impro_7: train_impro_7+val_impro_7],
                                df_impro_session_5_impro_7_M[train_impro_7: train_impro_7+val_impro_7],
                                df_impro_session_2_impro_8_F[train_impro_8: train_impro_8+val_impro_8], 
                                df_impro_session_2_impro_8_M[train_impro_8: train_impro_8+val_impro_8], 
                                df_impro_session_3_impro_8_F[train_impro_8: train_impro_8+val_impro_8],
                                df_impro_session_3_impro_8_M[train_impro_8: train_impro_8+val_impro_8],
                                df_impro_session_4_impro_8_F[train_impro_8: train_impro_8+val_impro_8],
                                df_impro_session_4_impro_8_M[train_impro_8: train_impro_8+val_impro_8],
                                df_impro_session_5_impro_8_F[train_impro_8: train_impro_8+val_impro_8],
                                df_impro_session_5_impro_8_M[train_impro_8: train_impro_8+val_impro_8]]) 
    df_impro_test = pd.concat([df_impro_session_1_impro_1_F[train_impro_1+val_impro_1:],
                                df_impro_session_1_impro_1_M[train_impro_1+val_impro_1:],
                                df_impro_session_2_impro_1_F[train_impro_1+val_impro_1:], 
                                df_impro_session_2_impro_1_M[train_impro_1+val_impro_1:], 
                                df_impro_session_3_impro_1_F[train_impro_1+val_impro_1:],
                                df_impro_session_3_impro_1_M[train_impro_1+val_impro_1:],
                                df_impro_session_4_impro_1_F[train_impro_1+val_impro_1:],
                                df_impro_session_5_impro_1_F[train_impro_1+val_impro_1:],
                                df_impro_session_5_impro_1_M[train_impro_1+val_impro_1:],
                                df_impro_session_1_impro_2_F[train_impro_2+val_impro_2:],
                                df_impro_session_1_impro_2_M[train_impro_2+val_impro_2:],
                                df_impro_session_2_impro_2_F[train_impro_2+val_impro_2:], 
                                df_impro_session_2_impro_2_M[train_impro_2+val_impro_2:], 
                                df_impro_session_3_impro_2_F[train_impro_2+val_impro_2:],
                                df_impro_session_3_impro_2_M[train_impro_2+val_impro_2:],
                                df_impro_session_4_impro_2_F[train_impro_2+val_impro_2:],
                                df_impro_session_4_impro_2_M[train_impro_2+val_impro_2:],
                                df_impro_session_5_impro_2_F[train_impro_2+val_impro_2:],
                                df_impro_session_5_impro_2_M[train_impro_2+val_impro_2:],
                                df_impro_session_1_impro_3_F[train_impro_3+val_impro_3:],
                                df_impro_session_1_impro_3_M[train_impro_3+val_impro_3:],
                                df_impro_session_2_impro_3_F[train_impro_3+val_impro_3:], 
                                df_impro_session_2_impro_3_M[train_impro_3+val_impro_3:], 
                                df_impro_session_3_impro_3_F[train_impro_3+val_impro_3:],
                                df_impro_session_3_impro_3_M[train_impro_3+val_impro_3:],
                                df_impro_session_4_impro_3_F[train_impro_3+val_impro_3:],
                                df_impro_session_4_impro_3_M[train_impro_3+val_impro_3:],
                                df_impro_session_5_impro_3_F[train_impro_3+val_impro_3:],
                                df_impro_session_5_impro_3_M[train_impro_3+val_impro_3:],
                                df_impro_session_1_impro_4_F[train_impro_4+val_impro_4:],
                                df_impro_session_1_impro_4_M[train_impro_4+val_impro_4:],
                                df_impro_session_2_impro_4_F[train_impro_4+val_impro_4:], 
                                df_impro_session_2_impro_4_M[train_impro_4+val_impro_4:], 
                                df_impro_session_3_impro_4_F[train_impro_4+val_impro_4:],
                                df_impro_session_3_impro_4_M[train_impro_4+val_impro_4:],
                                df_impro_session_4_impro_4_F[train_impro_4+val_impro_4:],
                                df_impro_session_4_impro_4_M[train_impro_4+val_impro_4:],
                                df_impro_session_5_impro_4_F[train_impro_4+val_impro_4:],
                                df_impro_session_5_impro_4_M[train_impro_4+val_impro_4:],
                                df_impro_session_1_impro_5_F[train_impro_5+val_impro_5:],
                                df_impro_session_1_impro_5_M[train_impro_5+val_impro_5:],
                                df_impro_session_2_impro_5_F[train_impro_5+val_impro_5:], 
                                df_impro_session_2_impro_5_M[train_impro_5+val_impro_5:], 
                                df_impro_session_3_impro_5_F[train_impro_5+val_impro_5:],
                                df_impro_session_3_impro_5_M[train_impro_5+val_impro_5:],
                                df_impro_session_4_impro_5_F[train_impro_5+val_impro_5:],
                                df_impro_session_4_impro_5_M[train_impro_5+val_impro_5:],
                                df_impro_session_5_impro_5_F[train_impro_5+val_impro_5:],
                                df_impro_session_5_impro_5_M[train_impro_5+val_impro_5:],
                                df_impro_session_1_impro_6_F[train_impro_6+val_impro_6:],
                                df_impro_session_1_impro_6_M[train_impro_6+val_impro_6:],
                                df_impro_session_2_impro_6_F[train_impro_6+val_impro_6:], 
                                df_impro_session_2_impro_6_M[train_impro_6+val_impro_6:], 
                                df_impro_session_3_impro_6_F[train_impro_6+val_impro_6:],
                                df_impro_session_3_impro_6_M[train_impro_6+val_impro_6:],
                                df_impro_session_4_impro_6_F[train_impro_6+val_impro_6:],
                                df_impro_session_4_impro_6_M[train_impro_6+val_impro_6:],
                                df_impro_session_5_impro_6_F[train_impro_6+val_impro_6:],
                                df_impro_session_5_impro_6_M[train_impro_6+val_impro_6:],
                                df_impro_session_1_impro_7_F[train_impro_7+val_impro_7:],
                                df_impro_session_1_impro_7_M[train_impro_7+val_impro_7:],
                                df_impro_session_2_impro_7_F[train_impro_7+val_impro_7:], 
                                df_impro_session_2_impro_7_M[train_impro_7+val_impro_7:], 
                                df_impro_session_3_impro_7_F[train_impro_7+val_impro_7:],
                                df_impro_session_3_impro_7_M[train_impro_7+val_impro_7:],
                                df_impro_session_4_impro_7_F[train_impro_7+val_impro_7:],
                                df_impro_session_4_impro_7_M[train_impro_7+val_impro_7:],
                                df_impro_session_5_impro_7_F[train_impro_7+val_impro_7:],
                                df_impro_session_5_impro_7_M[train_impro_7+val_impro_7:],
                                df_impro_session_2_impro_8_F[train_impro_8+val_impro_8:], 
                                df_impro_session_2_impro_8_M[train_impro_8+val_impro_8:], 
                                df_impro_session_3_impro_8_F[train_impro_8+val_impro_8:],
                                df_impro_session_3_impro_8_M[train_impro_8+val_impro_8:],
                                df_impro_session_4_impro_8_F[train_impro_8+val_impro_8:],
                                df_impro_session_4_impro_8_M[train_impro_8+val_impro_8:],
                                df_impro_session_5_impro_8_F[train_impro_8+val_impro_8:],
                                df_impro_session_5_impro_8_M[train_impro_8+val_impro_8:]]) 
    return df_impro_train, df_impro_val, df_impro_test

# Materialise the (train, val, test) DataFrame triples returned by impro_splits()
# and scripted_splits().
# NOTE(review): in impro_splits, the impro_1 entries of the test concatenation list
# df_impro_session_4_impro_1_F but not df_impro_session_4_impro_1_M, unlike every other
# session/improvisation pair. That speaker's held-out utterances may therefore be
# missing from df_test_improv. Confirm whether this omission is intentional.
df_train_improv, df_val_improv, df_test_improv = impro_splits()
df_train_scripted, df_val_scripted, df_test_scripted = scripted_splits()

  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


In [2]:
model_card = "FacebookAI/roberta-large"
improv_scripted = "improv"  # which subset to train on; the other subset is kept for cross evaluation
output_dir = f"./results_text/roberta-large_{improv_scripted}/"

# Label vocabulary, sorted so that label ids are reproducible. Iterating a set of
# strings follows hash order, which changes between interpreter sessions
# (PYTHONHASHSEED). An unsorted mapping could then differ between the session that
# trained a checkpoint and the one that decodes its predictions.
# (Original note: there are 34 labels.)
labels = sorted(set(final_df["EDA"]))
labels_to_num_mapping = {label: i for i, label in enumerate(labels)}

class DialogActDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with dialog-act labels.

    Args:
        encodings: mapping from feature name (e.g. ``input_ids``) to a per-example
            sequence, such as the output of a tokenizer called on a list of texts.
        labels: label values, parallel to the per-example sequences in ``encodings``.
        label_mapping: optional dict from label value to integer class id. When it is
            omitted, the module-level ``labels_to_num_mapping`` is looked up each time
            an item is fetched (the original behaviour).
    """

    def __init__(self, encodings, labels, label_mapping=None):
        self.encodings = encodings
        self.labels = labels
        self.label_mapping = label_mapping

    def __getitem__(self, idx):
        # Fall back to the notebook-global mapping so existing callers keep working.
        mapping = self.label_mapping if self.label_mapping is not None else labels_to_num_mapping
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(mapping[self.labels[idx]])
        return item

    def __len__(self):
        return len(self.labels)
    

# Train/val/test on the subset named by improv_scripted; the remaining subset is kept
# as the "*_other" splits so the model can also be scored on it.
_improv_dfs = (df_train_improv, df_val_improv, df_test_improv)
_scripted_dfs = (df_train_scripted, df_val_scripted, df_test_scripted)
if improv_scripted == "improv":
    _primary_dfs, _other_dfs = _improv_dfs, _scripted_dfs
else:
    _primary_dfs, _other_dfs = _scripted_dfs, _improv_dfs
df_train, df_val, df_test = _primary_dfs
df_train_other, df_val_other, df_test_other = _other_dfs


def _texts_and_labels(frame):
    """Return (lower-cased ``utterance`` strings, ``EDA`` labels) of *frame* as two lists."""
    return [utt.lower() for utt in frame["utterance"]], list(frame["EDA"])


train_texts, train_labels = _texts_and_labels(df_train)
val_texts, val_labels = _texts_and_labels(df_val)
test_texts, test_labels = _texts_and_labels(df_test)

train_texts_other, train_labels_other = _texts_and_labels(df_train_other)
val_texts_other, val_labels_other = _texts_and_labels(df_val_other)
test_texts_other, test_labels_other = _texts_and_labels(df_test_other)

tokenizer = AutoTokenizer.from_pretrained(model_card)

# Tokenize without padding. DataCollatorWithPadding (below) pads each batch to the
# longest sequence in that batch. Padding the whole split here to its global maximum
# would make every batch that long and waste compute.
train_encodings = tokenizer(train_texts, truncation=True)
val_encodings = tokenizer(val_texts, truncation=True)
test_encodings = tokenizer(test_texts, truncation=True)

train_dataset = DialogActDataset(train_encodings, train_labels)
val_dataset = DialogActDataset(val_encodings, val_labels)
test_dataset = DialogActDataset(test_encodings, test_labels)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Record the label <-> id mapping in the model config. It is then saved with every
# checkpoint and can be recovered via model.config.id2label when reloading.
model = AutoModelForSequenceClassification.from_pretrained(
    model_card,
    num_labels=len(labels),
    id2label={i: label for label, i in labels_to_num_mapping.items()},
    label2id=dict(labels_to_num_mapping),
)

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Accuracy callback for the Trainer.

    ``eval_pred`` unpacks into (logits, reference label ids). The predicted class is
    the arg-max over the last axis of the logits, scored with the module-level
    ``metric``.
    """
    logits, references = eval_pred
    predicted_ids = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predicted_ids, references=references)


# Fine-tuning configuration.
# NOTE: no eval_strategy is passed, so TrainingArguments keeps its default ("no").
# val_dataset and compute_metrics are therefore not exercised during trainer.train();
# pass eval_strategy="epoch" to score the validation set after each epoch.
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  trainer = Trainer(


Step,Training Loss


In [None]:
checkpoint = "checkpoint-335/"

from sklearn.metrics import accuracy_score, classification_report

# Reverse lookup (class id -> label) used to turn model predictions back into labels.
# NOTE(review): this must be the same label -> id mapping that was in effect when the
# checkpoint was trained. Confirm it is built deterministically across sessions.
num_to_label_mapping = {idx: label for label, idx in labels_to_num_mapping.items()}

model = AutoModelForSequenceClassification.from_pretrained(os.path.join(output_dir, checkpoint))

def evaluate_texts(texts, labels, batch_size=32):
    """Predict a label for each text with the module-level model and score the result.

    Texts are tokenized and run through ``model`` in mini-batches of ``batch_size``.
    Predicted class ids are mapped back to labels via ``num_to_label_mapping``.

    Args:
        texts: sequence of input strings (sliceable, e.g. a list).
        labels: gold labels, parallel to ``texts``.
        batch_size: number of texts per forward pass.

    Returns:
        A tuple ``(report, predictions)``. ``report`` is sklearn's
        ``classification_report`` as a dict; ``predictions`` is the list of predicted
        labels in input order.
    """
    predictions = []
    for start in range(0, len(texts), batch_size):
        # Batch the forward passes instead of pushing the whole split through the model
        # at once. Truncate like the training-time tokenization; otherwise long texts can
        # exceed the model's maximum input length.
        inputs = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(model.device)
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        predictions.extend(num_to_label_mapping[int(i)] for i in predicted_ids)
    report = classification_report(y_true=labels, y_pred=predictions, output_dict=True)
    return report, predictions

def _report_split(texts, gold_labels, report_name, preds_name=None):
    """Evaluate one split and write its results under ``output_dir/checkpoint``.

    The classification report is saved as CSV under ``report_name``. When
    ``preds_name`` is given, the gold labels and predictions are also dumped as JSON
    to that file. Returns the predicted labels.
    """
    checkpoint_dir = os.path.join(output_dir, checkpoint)
    report, preds = evaluate_texts(texts, gold_labels)
    pd.DataFrame(report).to_csv(os.path.join(checkpoint_dir, report_name))
    if preds_name is not None:
        with open(os.path.join(checkpoint_dir, preds_name), "w", encoding="utf-8") as f:
            json.dump({"labels": gold_labels, "preds": preds}, f)
    return preds


# Splits of the subset the model was trained on.
_report_split(train_texts, train_labels, "train_classification_report.csv")
test_preds = _report_split(
    test_texts, test_labels, "test_classification_report.csv", "test_preds.json"
)
_report_split(val_texts, val_labels, "val_classification_report.csv")

# Splits of the other subset (cross evaluation).
_report_split(train_texts_other, train_labels_other, "train_classification_report_other.csv")
test_preds_other = _report_split(
    test_texts_other,
    test_labels_other,
    "test_classification_report_other.csv",
    "test_preds_other.json",
)
_report_split(val_texts_other, val_labels_other, "val_classification_report_other.csv")