In [None]:
import os
import re
import json
from copy import deepcopy
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from key import OPENAI_KEY, LANGSMITH_KEY # Add your own keys
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

# Publish the credentials for the OpenAI / LangChain clients and enable LangSmith tracing.
os.environ.update({
    "OPENAI_API_KEY": OPENAI_KEY,
    "LANGCHAIN_TRACING_V2": "true",
    "LANGCHAIN_API_KEY": LANGSMITH_KEY,
})

In [None]:
text_dir = "../data/text/Subtask_2_train.json"
# Training annotations. Later cells iterate over it and read each entry's
# "conversation" list of utterance dicts.
# Use a context manager so the handle is closed. Read as UTF-8 (JSON's encoding)
# rather than the platform default.
with open(text_dir, encoding="utf-8") as fp:
    anno = json.load(fp)

#### All-Emotion Conversation Index
- Find the training conversations in which every emotion label occurs at least once
- Format them and ask the LLM to explain each utterance's emotion annotation
- Save the explanations
- Store the all-emotion conversations in a FAISS index

In [None]:
emotions = ["anger", "disgust", "fear", "joy", "sadness", "surprise", "neutral"]

# Indices (into `anno`) of conversations in which every label above occurs at least
# once. A subset test is used instead of counting flags, so an unexpected extra label
# can never stand in for a missing one.
required_emotions = set(emotions)
idx = [
    i
    for i, a in enumerate(anno)
    if required_emotions.issubset(utt["emotion"] for utt in a["conversation"])
]

In [None]:
def format_convo(idx, conversation, emotion=False):
    """Render a conversation as text for the LLM / embedding index.

    The first line is ``str(idx)``. It is followed by one line per utterance, of the
    form ``"<n>. <speaker>: <text>"`` with ``n`` counting from 1.

    Args:
        idx: Identifier written on the first line (the conversation's index in ``anno``).
        conversation: Sequence of utterance dicts with ``speaker`` and ``text`` keys
            (and ``emotion`` when ``emotion`` is true).
        emotion: If true, append the annotated emotion as ``" [<emotion>]"``.

    Returns:
        The newline-joined string.
    """
    rendered = [str(idx)]
    for number, utterance in enumerate(conversation, start=1):
        row = f'{number}. {utterance["speaker"]}: {utterance["text"]}'
        if emotion:
            row += f' [{utterance["emotion"]}]'
        rendered.append(row)
    return "\n".join(rendered)

In [None]:
# Prompt asking the LLM to justify each annotated emotion in a conversation.
# `{conversation}` is filled in by the prompt template. The literal JSON braces are
# doubled ({{ }}) so they survive formatting.
# The example output must itself be well-formed: commas separate every field and every
# record, so the model is not nudged towards malformed JSON.
emotion_explaination_prompt = """
There are 6 basic emotions: Anger, Disgust, Fear, Joy, Sadness, Surprise. 
The emotion of the speaker is determined by the context of the conversation. 
If the emotion is not in any category, is a mix of several categories, or is ambiguous it can be categorized as "Neutral". 

Analyze the following conversation where emotion of each utterance is annotated in square brackets at the end. 
Give reasoning behind the annotation of each utterance.

{conversation}

Output a JSON in the following format:
[{{"utterance_ID": id,
  "text" : content,
  "speaker": speaker,
  "emotion": emotion, 
  "explanation": detailed explanation}},
  ...
]
No plain text.
"""

In [None]:
model = ChatOpenAI(openai_api_key=OPENAI_KEY)
output_parser = StrOutputParser()

# Build the template under its own name so this cell is idempotent. The original
# rebound `emotion_explaination_prompt` (the raw string) to a ChatPromptTemplate, so a
# second run would hand a template, not a string, to `from_template`.
emotion_explaination_template = ChatPromptTemplate.from_template(emotion_explaination_prompt)
# prompt -> chat model -> plain string output
emotion_explaination_chain = emotion_explaination_template | model | output_parser

In [None]:
# One labelled transcript per all-emotion conversation, wrapped as prompt inputs.
emo_convos = [
    format_convo(conv_idx, anno[conv_idx]["conversation"], emotion=True)
    for conv_idx in idx
]
batch = [dict(conversation=transcript) for transcript in emo_convos]

In [None]:
# Run the chain over every prompt input, with at most five requests in flight at once.
outs = emotion_explaination_chain.batch(
    batch,
    config=dict(max_concurrency=5),
)

In [None]:
# Helper functions to fix a faulty json output from GPT
def strip(s):
    """Trim ``s`` and delete every character that is not an ASCII letter or digit.

    Used to normalise JSON keys (``"utterance_ID"`` -> ``utteranceID``) and
    single-token values such as IDs and emotion labels.
    """
    return re.sub('[^0-9a-zA-Z]+', '', s.strip())


def _clean_value(raw):
    """Trim a raw value, drop trailing JSON punctuation (``,`` ``}`` ``]``) and surrounding quotes."""
    return raw.strip().rstrip(",}]").strip().strip('"')


def fix_json(s):
    """Best-effort parse of a malformed JSON list of per-utterance explanations.

    The text is scanned line by line. Each line is expected to hold one
    ``"key": value`` pair, as laid out in the explanation prompt, and each
    ``utterance_ID`` line starts a new record.

    Args:
        s: Raw model output that ``json.loads`` rejected.

    Returns:
        A list of dicts, one per ``utterance_ID`` seen, holding the fields found.
        ``utterance_ID`` and ``emotion`` keep only their alphanumeric characters.
        Other values, such as ``text``, ``speaker`` and ``explanation``, are unquoted
        but otherwise kept verbatim, including any colons. Returns ``[]`` if no
        fields are found.
    """
    emo_json = []
    cur_dict = {}
    for line in s.split("\n"):
        # Lines without a colon ("[", "]", "{", "}", "...", code fences) carry no field.
        if ":" not in line:
            continue
        # Split on the first colon only, so values that contain colons stay intact.
        raw_key, raw_value = line.strip().split(":", 1)
        key = strip(raw_key)
        if not key:
            continue
        if key == "utteranceID":
            # A new utterance begins; flush the record collected so far.
            if cur_dict:
                emo_json.append(cur_dict)
                cur_dict = {}
            cur_dict["utterance_ID"] = strip(raw_value)
        elif key == "emotion":
            cur_dict[key] = strip(raw_value)
        else:
            cur_dict[key] = _clean_value(raw_value)
    # The final record has no following utterance_ID line to flush it.
    if cur_dict:
        emo_json.append(cur_dict)
    return emo_json

In [None]:
# Conversation index (into `anno`) -> parsed list of per-utterance explanations.
emo_explain_dict = {}

# The chain was run once per conversation in `idx`, so the outputs must line up 1:1.
assert len(outs) == len(idx), f"expected {len(idx)} outputs, got {len(outs)}"

for conv_idx, raw_output in zip(idx, outs):
    try:
        parsed = json.loads(raw_output)
    except json.JSONDecodeError:
        # Only malformed JSON falls back to the line-based repair parser. Other errors
        # (and KeyboardInterrupt) are no longer silently swallowed.
        parsed = fix_json(raw_output)
    emo_explain_dict[conv_idx] = parsed

In [None]:
# Persist the explanations. The context manager guarantees the file is flushed and closed.
with open("emotion_explainations.json", "w", encoding="utf-8") as fp:
    json.dump(emo_explain_dict, fp)

##### Creating FAISS Index

In [None]:
embeddings = OpenAIEmbeddings()
# Unlabelled transcripts (headed by the conversation index) of every all-emotion conversation.
all_emo_convos = [
    format_convo(conv_idx, anno[conv_idx]["conversation"], emotion=False)
    for conv_idx in idx
]
db = FAISS.from_texts(all_emo_convos, embeddings)
db.save_local("all_emotion_index")

#### Cause Windows Index
- Create a window around each emotional utterance, based on where it falls in the conversation (beginning, middle, end)
- For RAG, create one FAISS index per window type for each of the 6 emotions (18 indices in total)
- Save the cause windows to a JSON file

In [None]:
def get_window_beg(convo: list, size: int = 3) -> list:
    """Return a deep copy of the first ``size`` utterances of ``convo``."""
    return deepcopy(convo[:size])

def get_window_end(convo: list, size: int = 6) -> list:
    """Return a deep copy of the last ``size`` utterances of ``convo`` (``[]`` when ``size`` is 0)."""
    # Slice from an explicit start index: `convo[-0:]` would return the whole list for size=0.
    return deepcopy(convo[max(len(convo) - size, 0):])

def get_window_mid(convo: list, idx: int, prev_size: int = 5, next_size: int = 2) -> list:
    """Return a deep copy of the window around position ``idx``.

    The window holds up to ``prev_size`` utterances before ``idx``, the utterance
    itself, and up to ``next_size`` after it, clipped at both ends of ``convo``.
    """
    return deepcopy(convo[max(0, idx - prev_size) : (idx + 1) + next_size])

In [None]:
emotions = ["anger", "joy", "sadness", "surprise", "disgust", "fear"]
# emotion -> window position -> list of (utterance_ID, window, causes, video stem) tuples
index_dict = {emo: {"beg": [], "mid": [], "end": []} for emo in emotions}

for a in anno:
    convo = a["conversation"]
    last_pos = len(convo) - 1
    for i, utt in enumerate(convo):
        # Only non-neutral utterances get a cause window.
        if utt["emotion"] == "neutral":
            continue
        # Choose the window type from where the utterance sits in the conversation.
        if i == 0:
            slot, window = "beg", get_window_beg(convo)
        elif i == last_pos:
            slot, window = "end", get_window_end(convo)
        else:
            slot, window = "mid", get_window_mid(convo, i)
        index_dict[utt["emotion"]][slot].append(
            (utt["utterance_ID"], window, utt["causes"], utt["video_name"].split(".")[0])
        )

In [None]:
def format_window(window: tuple, label: bool = False) -> str:
    """Render a cause-window tuple as numbered, newline-terminated lines.

    Args:
        window: ``(utterance_ID, utterances, causes, name)``. ``utterances`` is a list
            of dicts with ``utterance_ID``, ``speaker`` and ``text`` keys (and
            ``emotion`` for the target utterance). ``causes`` is not used here.
        label: If false, the first line is ``name``. If true, there is no name line,
            and the utterance whose ``utterance_ID`` equals the tuple's first element
            is suffixed with ``" [<emotion>]"``.

    Returns:
        One ``"<n>. <speaker>: <text>"`` line per utterance (n counted from 1), each
        ending in a newline.
    """
    target_id, utterances, _causes, name = window
    rows = [] if label else [f"{name}"]
    for number, utt in enumerate(utterances, start=1):
        row = f'{number}. {utt["speaker"]}: {utt["text"]}'
        if label and target_id == utt["utterance_ID"]:
            row += f' [{utt["emotion"]}]'
        rows.append(row)
    return "".join(row + "\n" for row in rows)

In [None]:
# One FAISS index per (emotion, window position), saved under cause_windows/<emo>/<pos>.
for emo in emotions:
    for pos in ["beg", "mid", "end"]:
        window_strings = [format_window(window) for window in index_dict[emo][pos]]
        if not window_strings:
            # FAISS.from_texts needs at least one text to build an index; report the gap
            # instead of aborting the whole loop.
            print(f"Skipping cause_windows/{emo}/{pos}: no windows")
            continue
        db = FAISS.from_texts(window_strings, embeddings)
        db.save_local(f"cause_windows/{emo}/{pos}")

In [None]:
# Video stem -> labelled window text, for every emotion and window position.
# Built with a comprehension so no loop variables leak. The original unpacked into a
# global `idx`, clobbering the conversation-index list used by earlier cells.
# As before, a later duplicate stem overwrites an earlier one.
cause_windows = {
    window[3]: format_window(window, True)
    for emo in emotions
    for pos in ["beg", "mid", "end"]
    for window in index_dict[emo][pos]
}

with open("cause_windows.json", "w", encoding="utf-8") as fp:
    json.dump(cause_windows, fp)