In [1]:
import os
import json
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnableLambda
from key import OPENAI_KEY, LANGSMITH_KEY # Import your own keys
from copy import deepcopy
import time
from tqdm import tqdm

# Credentials come from the local `key` module; the LANGCHAIN_* variables
# switch on LangChain (LangSmith) tracing for every chain run below.
os.environ.update({
    "OPENAI_API_KEY": OPENAI_KEY,
    "LANGCHAIN_TRACING_V2": "true",
    "LANGCHAIN_API_KEY": LANGSMITH_KEY,
})

In [None]:
def load_json(path):
    """Read and parse a UTF-8 JSON file, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


text_dir = "../data/text/Subtask_2_test.json"
anno = load_json(text_dir)                                  # test conversations (annotation skeleton)
video_captions = load_json("eval_proc_out.json")            # per-conversation scene text, keyed by str(conversation_ID)
cause_windows = load_json("cause_windows.json")             # few-shot example windows, keyed by index document id
emotion_labels = load_json("emotion_eval_labelled.json")    # predicted emotions, keyed by str(conversation_ID)

#### Post-processing the emotion-labelled data (labels outside the six target emotions become neutral)

In [None]:
emotions = ["anger", "joy", "sadness", "surprise", "disgust", "fear"]


def normalize_emotion(labels: list, i: int) -> str:
    """Return the lower-cased predicted emotion for utterance ``i``, or "neutral".

    Falls back to "neutral" when the prediction is missing or malformed, or when
    it is not one of ``emotions``.
    """
    try:
        emo = labels[i]["emotion"].strip().lower()
    except (IndexError, KeyError, TypeError, AttributeError):
        # Prediction list too short, entry lacks "emotion", or value is not a string.
        return "neutral"
    return emo if emo in emotions else "neutral"


for a in anno:
    # Reset the pairs; they are populated by the cause-extraction loop later on.
    a["emotion-cause_pairs"] = []
    emo_labels = emotion_labels[str(a["conversation_ID"])]["conversation"]
    for i, utt in enumerate(a["conversation"]):
        utt["emotion"] = normalize_emotion(emo_labels, i)

#### Loading the cause-window indexes and window functions

In [None]:
# `embeddings` was referenced but never defined, so this cell only worked via
# hidden kernel state.
# NOTE(review): this must be the same embedding model that built the
# cause_windows indexes — TODO confirm the OpenAIEmbeddings() default matches.
embeddings = OpenAIEmbeddings()

# Distinct name so the main loop's per-utterance position string does not clobber it.
window_positions = ["beg", "mid", "end"]

# One FAISS index per (emotion, window position), loaded from cause_windows/<emo>/<pos>.
# NOTE(review): recent langchain_community releases refuse to unpickle the
# docstore unless allow_dangerous_deserialization=True is passed here.
db_dict = {
    emo: {p: FAISS.load_local(f"cause_windows/{emo}/{p}", embeddings) for p in window_positions}
    for emo in emotions
}

In [None]:
def get_window_beg(convo: list, size:int = 3) -> list:
    """Independent (deep) copy of the first ``size`` utterances of ``convo``."""
    head = convo[:size]
    return deepcopy(head)

def get_window_end(convo: list, size:int = 6) -> list:
    """Deep copy of the last ``size`` utterances of ``convo``.

    ``size <= 0`` yields an empty window; a bare ``convo[-0:]`` would instead
    return the whole conversation. A ``size`` larger than the conversation
    returns all of it.
    """
    start = max(0, len(convo) - size)
    return deepcopy(convo[start:])

def get_window_mid(convo:list, idx:int, prev_size:int = 5, next_size:int = 2) -> list:
    """Deep copy of the utterances around position ``idx``.

    The window covers up to ``prev_size`` utterances before ``idx``, ``idx``
    itself, and up to ``next_size`` utterances after it. It is clipped at the
    start of the conversation.
    """
    lo = idx - prev_size
    if lo < 0:
        lo = 0
    hi = idx + next_size + 1
    return deepcopy(convo[lo:hi])

def format_window(window: tuple, label:bool = False) -> "tuple[str, int]":
    """Render a context window as a numbered transcript for an LLM prompt.

    Args:
        window: ``(idx, utterances)``. ``idx`` is the ``utterance_ID`` of the
            target utterance. ``utterances`` is the list of utterance dicts in
            the window; each needs ``utterance_ID``, ``speaker`` and ``text``,
            and the target also needs ``emotion``.
        label: when True, tag the target line with ``[emotion]`` and append a
            question asking for the causal utterances of that emotion.

    Returns:
        ``(text, offset)``. Lines in ``text`` are numbered from 1, and
        ``offset = idx - line_number_of_target``. The caller adds ``offset`` to
        a window line number to recover an ``utterance_ID``; this assumes the
        IDs inside the window are consecutive.
        If no utterance matches ``idx``, ``offset`` is 0. With ``label`` set,
        the question then renders emotion/speaker as ``None`` and refers to
        utterance ``idx``.
    """
    idx, utterances = window
    target_line = idx  # left as idx (offset 0) if the target is not in the window
    emo = None
    speaker = None
    lines = []
    for line_no, utt in enumerate(utterances, start=1):
        line = f'{line_no}. {utt["speaker"]}: {utt["text"]}'
        if utt["utterance_ID"] == idx:
            emo = utt["emotion"]
            speaker = utt["speaker"]
            target_line = line_no
            if label:
                line += f' [{emo}]'
        lines.append(line + "\n")

    out_str = "".join(lines)
    if label:
        out_str += f"\nWhat are the causal utterances that trigger the emotion of {emo} in {speaker} in utterance {target_line}?"

    return out_str, idx - target_line

#### Prompts and RAG Pipeline

In [None]:
# Shared task definition used by both prompts below, so the two stay in sync.
# Contains no {placeholders}, so it is safe to concatenate into templates.
cause_definition_text = """
You are an expert in analyzing conversations to extract the causes of emotions
in particular utterances by speakers. You give definite confident answers only.

Description of emotional causes:
- Each utterance always has a reason of why it was said and why it had a particular emotion.
- A cause is an utterance that comes before or after the particular utterance in question that best explains to be the reason behind the particular emotion.
- The emotional utterance itself can be a cause of itself if its content ALSO best explains the reason for the particular emotion.
- Sometimes the cause can be beyond the context of the conversation thus an utterance might have no cause within conversation
- There can be multiple causes for an utterance.
"""

# Asks the model to justify an already-annotated example window ({conversation}).
explaination_text = cause_definition_text + """
Here's a conversation:
{conversation}

Analyze and justify the above annotation concisely.
"""

# Few-shot cause-extraction prompt: three explained examples, scene context, target window.
cause_text = cause_definition_text + """
Here are some examples of how to recognize causes:
Example 1:
{example_1}

Example 2:
{example_2}

Example 3:
{example_3}

Now, please recognize the causes in the following conversation. Here's the context for the whole conversation:
{scene}

Conversation:
{window}
"""

# Doubled braces are literal in the template. The example uses double-quoted
# keys because json.loads rejects single-quoted JSON, and the model tends to
# copy the format it is shown.
json_text = '{prompt}\n\nReformat the text to JSON as {{"causes": [list of causal utterance numbers]}}. No plain text.'

In [None]:
# `model` and `output_parser` were used but never defined in the notebook; the
# cell only worked through hidden kernel state.
# TODO(review): confirm the chat model / temperature used for the reported
# results — ChatOpenAI() falls back to the library's default model.
model = ChatOpenAI()
output_parser = StrOutputParser()

explaination_prompt = ChatPromptTemplate.from_template(explaination_text)
cause_prompt = ChatPromptTemplate.from_template(cause_text)
json_prompt = ChatPromptTemplate.from_template(json_text)

# Justifies an annotated example window; its output becomes a few-shot example.
explaination_pipeline = explaination_prompt | model | output_parser
# Free-text cause analysis of the target window.
cause_chain = cause_prompt | model | output_parser
# Runs cause_chain, then asks the model to reformat that answer as JSON.
cause_json_chain = (
    {"prompt": cause_chain}
    | json_prompt
    | model
    | output_parser
)

#### Getting causes for all conversations

In [None]:
# Set RESUME = True to continue from results saved by a previous (partial) run
# instead of starting from scratch.
RESUME = False

if RESUME:
    with open("cause_eval_labelled.json", encoding="utf-8") as f:
        cause_eval_labelled = json.load(f)
else:
    cause_eval_labelled = {}

In [None]:
step = 10  # write partial results to disk every `step` conversations


def _select_window(convo: list, i: int) -> tuple:
    """Return ``(position, window)`` for utterance ``i`` of ``convo``.

    ``position`` ("beg"/"mid"/"end") selects which FAISS index partition in
    ``db_dict`` to query.
    """
    if i == 0:
        return "beg", get_window_beg(convo)
    if i == len(convo) - 1:
        return "end", get_window_end(convo)
    return "mid", get_window_mid(convo, i)


def _parse_causes(raw: str) -> list:
    """Extract the ``"causes"`` list from the model's JSON reply.

    Only the span from the first ``{`` to the last ``}`` is parsed. This way a
    markdown code fence or stray text around the object does not make
    ``json.loads`` fail.
    """
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"no JSON object in model output: {raw!r}")
    return json.loads(raw[start:end + 1])["causes"]


for k, a in enumerate(anno):
    conv_idx = a["conversation_ID"]
    convo = a["conversation"]
    pairs = a["emotion-cause_pairs"]
    for i, utt in enumerate(convo):
        emo = utt["emotion"]
        utt_idx = utt["utterance_ID"]
        name = f"dia{conv_idx}utt{utt_idx}"
        if emo == "neutral":
            continue
        if emo not in emotions:
            print("  [{}/{}] Failed to process utterance {} due to invalid emotion {}".format(i+1, len(convo), name, emo))
            continue
        if name in cause_eval_labelled:
            # Already labelled (resumed run or re-executed cell). The pairs list was
            # reset during emotion post-processing, so restore the cached pairs.
            # Pairs already present are skipped, so re-running does not duplicate them.
            for pair in cause_eval_labelled[name]:
                if pair not in pairs:
                    pairs.append(pair)
            continue

        window_pos, window = _select_window(convo, i)

        try:
            window_str, _ = format_window((utt_idx, window), False)
            window_str_labelled, diff = format_window((utt_idx, window), True)
            # NOTE(review): [1:4] drops the nearest neighbour, presumably because it
            # is the query window itself; confirm that holds for the test split.
            hits = db_dict[emo][window_pos].similarity_search(window_str)[1:4]
            # The first line of each indexed document is its key into cause_windows.
            example_keys = [hit.page_content.split("\n")[0].strip() for hit in hits]
            example_windows = [cause_windows[key] for key in example_keys]
            explainations = explaination_pipeline.batch(
                [{"conversation": w} for w in example_windows],
                config={"max_concurrency": 3},
            )
            examples = [w + "\n" + exp for w, exp in zip(example_windows, explainations)]
            if len(examples) < 3:
                raise ValueError(f"only {len(examples)} retrieved examples, need 3")
            out = cause_json_chain.invoke({
                "window": window_str_labelled,
                "scene": video_captions[str(conv_idx)],
                "example_1": examples[0],
                "example_2": examples[1],
                "example_3": examples[2],
            })
            # Map the model's 1-based window line numbers back to utterance IDs.
            ecp = [[f"{utt_idx}_{emo}", str(int(c) + diff)] for c in _parse_causes(out)]
            pairs.extend(ecp)
            cause_eval_labelled[name] = ecp

            print("  [{}/{}] Processed utterance {}".format(i+1, len(convo), name))

        except Exception as error:
            # Best-effort: log and continue so one bad retrieval/LLM reply doesn't stop the run.
            print("  [{}/{}] Failed to process utterance {} due to {}".format(i+1, len(convo), name, error))

    print("[{}/{}] Processed Conv {}".format(k+1, len(anno), conv_idx))
    if k % step == 0:
        print("json dump...")
        with open("cause_eval_labelled.json", "w", encoding="utf-8") as f:
            json.dump(cause_eval_labelled, f)

In [None]:
# Final checkpoint of the per-utterance cache and the annotated conversations.
with open("cause_eval_labelled.json", "w", encoding="utf-8") as f:
    json.dump(cause_eval_labelled, f)
with open("cur_anno.json", "w", encoding="utf-8") as f:
    json.dump(anno, f)

#### Post-processing causes to add self-causes

In [None]:
# Every non-neutral utterance is also recorded as a cause of itself. The pair
# is appended after the LLM-predicted pairs unless it is already present.
for conv in anno:
    pairs = conv["emotion-cause_pairs"]
    for utt in conv["conversation"]:
        emo = utt["emotion"]
        if emo == "neutral":
            continue
        uid = utt["utterance_ID"]
        self_pair = [f"{uid}_{emo}", str(uid)]
        if self_pair not in pairs:
            pairs.append(self_pair)

with open("cur_anno_same_added.json", "w", encoding="utf-8") as f:
    json.dump(anno, f)