In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Data and result locations. The NAS root defaults to the previously hard-coded
# mount point but can be overridden with the AIDA_NAS_DIR environment variable,
# so the notebook can run on machines where the share is mounted elsewhere.
nas_dir = os.environ.get("AIDA_NAS_DIR", "/home/jaejoong/cocoanlab02")
dat_dir = os.path.join(nas_dir, "projects/AIDA/dataset")
result_dir = os.path.join(nas_dir, "projects/AIDA/results")

In [None]:
# Raw label table (read by later cells for "Participant", "split", "age", "gender",
# "Depression_severity") and the table listing the "Test 1" participants ("Participant_ID").
# NOTE(review): "lables" is a file path, not a comment typo — presumably it matches
# the file on disk; verify before renaming.
labels_csv = os.path.join(dat_dir, "detailed_lables.csv")
test_1_csv = os.path.join(dat_dir, "full_test_split.csv")
all_orig_df, test_1_df = (pd.read_csv(csv_path) for csv_path in (labels_csv, test_1_csv))

In [None]:
# Assign every participant to an evaluation split. Later rules override earlier ones:
# everyone starts as "Training", IDs listed in test_1_df become "Test 1", and any
# Participant ID >= 600 becomes "Test 2" (even if it is also listed in test_1_df).
all_df = all_orig_df.copy()
in_test_1 = all_df["Participant"].isin(test_1_df["Participant_ID"])
in_test_2 = all_df["Participant"] >= 600
all_df["split"] = "Training"
all_df.loc[in_test_1, "split"] = "Test 1"
all_df.loc[in_test_2, "split"] = "Test 2"

# "split_edaic" relabels "Test 2" as "Test 2 E-DAIC" and blanks out Participant >= 600
# rows whose ORIGINAL "split" is not "test". The original value must be read from
# all_orig_df because all_df["split"] has just been overwritten above.
all_df["split_edaic"] = all_df["split"].replace({"Test 2": "Test 2 E-DAIC"})
all_df.loc[in_test_2 & (all_orig_df["split"] != "test"), "split_edaic"] = ""

In [None]:
def _load_dialogues(participant):
    """Read one participant's transcript file and render it as plain-text dialogue.

    Consecutive rows spoken by the same speaker are merged into one turn (their
    texts joined with ". "), and every turn is terminated with ".".

    Args:
        participant: value from the "Participant" column; the transcript is read
            from dat_dir/all_transcripts/<participant>_TRANSCRIPT.csv.

    Returns:
        (dialogue, dialogue_P, has_ellie) where
            dialogue   -- all turns as "<speaker>: <text>.", one turn per line;
            dialogue_P -- the text of the "Participant" turns only, one per line;
            has_ellie  -- True if any turn's speaker contains "Ellie".
    """
    ts_file = os.path.join(dat_dir, "all_transcripts", str(participant) + "_TRANSCRIPT.csv")
    # sep=None with the python engine sniffs the delimiter file by file.
    ts_df = pd.read_csv(ts_file, sep=None, engine="python")
    # Drop "<synch>" marker rows and rows lacking a speaker or text. A NaN speaker
    # would render its line as NaN and make "\n".join below raise TypeError.
    ts_df = ts_df[ts_df["value"] != "<synch>"].dropna(subset=["speaker", "value"]).copy()
    ts_df["value"] = ts_df["value"].astype(str)
    # A new turn starts whenever the speaker differs from the previous row's speaker.
    turn_id = (ts_df["speaker"] != ts_df["speaker"].shift()).cumsum()
    turns = ts_df.groupby(turn_id).agg({
        "speaker": "first",
        "value": ". ".join,
    }).reset_index(drop=True)

    dialogue = "\n".join(turns["speaker"] + ": " + turns["value"] + ".")
    participant_text = turns.loc[turns["speaker"] == "Participant", "value"]
    dialogue_P = "\n".join(participant_text + ".")
    # Look for the agent in the speaker field only; searching the whole rendered
    # dialogue would also match a participant who merely says "Ellie".
    has_ellie = bool(turns["speaker"].astype(str).str.contains("Ellie", regex=False).any())
    return dialogue, dialogue_P, has_ellie


all_dialogue, all_dialogue_P, has_ellie = [], [], []
# Iterate the column directly (no iterrows) so IDs keep their own dtype.
for participant in all_df["Participant"]:
    dialogue, dialogue_P, ellie_present = _load_dialogues(participant)
    all_dialogue.append(dialogue)
    all_dialogue_P.append(dialogue_P)
    has_ellie.append(ellie_present)

all_df["Dialogue"] = all_dialogue
all_df["Dialogue_P"] = all_dialogue_P

print("No virtual agent (Ellie)'s transcriptions")
wh_ellie = pd.Series(has_ellie, index=all_df.index, dtype=bool)
display(all_df.loc[~wh_ellie, :])
# Keep only transcripts that contain the virtual agent; build a new frame instead of inplace mutation.
all_df = all_df[wh_ellie].reset_index(drop=True)

In [None]:
# Keep each per-split table preview short.
pd.set_option("display.max_rows", 6)
for split_name in ("Training", "Test 1", "Test 2"):
    split_df = all_df.loc[all_df["split"] == split_name]
    age_mean = split_df["age"].mean()
    age_std = split_df["age"].std()
    n_female = (split_df["gender"] == "female").sum()
    print(f"{split_name}: Mean = {age_mean:.1f}, Std = {age_std:.1f}, Female {n_female}")
    display(split_df)

In [None]:
# PHQ-8 score ("Depression_severity") per split: an outline boxplot with every
# participant's score overlaid as a swarm of points, saved as a PDF.
split_order = ["Training", "Test 1", "Test 2"]
shared_kwargs = dict(x="split", y="Depression_severity", order=split_order,
                     hue="split", hue_order=split_order)

fig = plt.figure(figsize=(9, 7))
# The style is set after the figure exists but before the seaborn calls below create the axes.
sns.set_style("ticks", {"font.sans-serif": "Helvetica", "xtick.bottom": False})
sns.boxplot(all_df, **shared_kwargs, palette=["#FFB3B5", "#CFC982", "#76D9B1"],
            showfliers=False, fill=False, linewidth=3, width=0.5)
sns.swarmplot(all_df, **shared_kwargs, palette=["#E16A86", "#AA9000", "#00AA5A"], size=5)
ax = plt.gca()
sns.despine()

ax.set_ylim([-2, 25])
ax.set_xlabel(None)
ax.set_ylabel('PHQ-8 score', fontsize=14)
ax.set_xticks([0, 1, 2])
ax.set_xticklabels(["Training set", "Test set 1", "Test set 2"], fontsize=14)
for tick_label in ax.get_yticklabels():
    tick_label.set_fontsize(14)
ax.tick_params(axis='both', length=10)

fig.savefig(os.path.join(result_dir, "AIDA_PHQ-8_score_box.pdf"), bbox_inches='tight')
plt.show()

In [None]:
# Persist the final table (split labels plus rendered dialogues, Ellie-less transcripts removed).
all_df_csv = os.path.join(result_dir, "AIDA_all_df.csv")
all_df.to_csv(all_df_csv, index=False)