In [None]:
# LOAD DATASET

import json
import os
import sys
from collections import Counter
import matplotlib.pyplot as plt

# Make the parent directory importable. os.path.abspath('') resolves to the
# current working directory, so the entry appended is "<cwd>/..".

sys.path.append(os.path.join(os.path.abspath(''), ".."))

# Switch the working directory to the parent so that the relative
# "./labeled_data/..." paths below are resolved from there.
# NOTE(review): the target is computed from the *current* working directory,
# so re-running this cell without restarting the kernel moves up one more
# level each time (and appends another sys.path entry).

os.chdir(os.path.join(os.path.abspath(''), ".."))

# Dataset name -> path of its labeled JSON sample (relative to the new cwd).
datasets = {
    "wildchat": "./labeled_data/wildchat_sample_labeled.json",
    "copilot": "./labeled_data/copilot_sample_labeled.json",
    "multiwoz": "./labeled_data/multiwoz_sample_labeled.json",
} 

# Dataset name -> parsed JSON content; starts empty and is filled when the
# files are loaded.
loaded_datasets = {}

# Collapses raw annotation labels into the categories used by the analysis.
# Every value also appears as a key (mapping to itself), so remapping labels
# that were already remapped leaves them unchanged.
label_remap = {
    "agreement": "next turn",
    "disagreement": "next turn",
    "response": "next turn",
    "next turn": "next turn",
    "followup": "followup",
    "clarification": "clarification",
    "acknowledgement": "acknowledgement",
    "display": "acknowledgement",
    "overresponse": "overresponse",
    "repeat": "repeat",
    "topic switch": "topic switch",
    "repair": "repair",
    "request": "followup",
}





In [None]:
def fix_labels(dataset):
    """Normalize the per-turn labels of every conversation in ``dataset`` in place.

    For each conversation, every message except the first has its ``"labels"``
    list replaced by the sorted, de-duplicated result of mapping each label
    through the module-level ``label_remap``. ``"next turn"`` is added when the
    remapped set is empty or contains any label listed in ``pos_grounding``.

    Conversations whose ``"labels"`` entry is missing or ``None`` are skipped
    silently. Conversations whose ``"labels"`` entry has no usable
    ``"messages"`` field are printed and skipped. Turns whose labels are
    missing, not iterable, or not present in ``label_remap`` are reported and
    left untouched.

    Args:
        dataset: iterable of conversation dicts.

    Returns:
        None. ``dataset`` is modified in place.
    """
    # "agreement", "disagreement" and "response" are kept from the original
    # list; with the current label_remap they are already collapsed into
    # "next turn" before this check runs.
    pos_grounding = {
        "acknowledgement",
        "followup",
        "agreement",
        "disagreement",
        "response",
        "overresponse",
    }

    for convo in dataset:
        convo_labels = convo.get("labels")
        if convo_labels is None:
            continue

        try:
            curr_convo = convo_labels["messages"]
        except (KeyError, TypeError):
            # Print the offending value itself. The previous code printed
            # `curr_convo`, which is unbound on the first conversation and
            # stale (the previous conversation's messages) afterwards.
            print(convo_labels)
            continue

        for turn in curr_convo[1:]:
            try:
                curr_set = {label_remap[label] for label in turn["labels"]}
            except (KeyError, TypeError):
                # Missing/None labels or a label that label_remap does not know.
                print("No labels in this turn")
                continue

            if not curr_set or curr_set & pos_grounding:
                curr_set.add("next turn")

            turn["labels"] = sorted(curr_set)

# Load every labeled sample listed in `datasets` and normalize its labels.
for dataset, fname in datasets.items():
    # Pin the encoding: without it, open() falls back to the platform's locale
    # encoding and non-ASCII conversation text can fail to decode.
    with open(fname, "r", encoding="utf-8") as f:
        loaded_datasets[dataset] = json.load(f)
    # The file is no longer needed once parsed; normalize outside the `with`.
    fix_labels(loaded_datasets[dataset])

In [None]:
def get_counts(dataset_name):
    """Count, per speaker role, how many turns carry (or lack) each label.

    Every message after the first in each conversation of
    ``loaded_datasets[dataset_name]`` contributes 1 to ``label`` for each
    distinct label it has and 1 to ``"NOT " + label`` for each value of
    ``label_remap`` it does not have. Turns are tallied separately for the
    ``"user"`` and ``"assistant"`` roles; other roles are ignored.

    Conversations whose ``"labels"`` entry is missing or ``None`` are skipped
    silently. Conversations without a usable ``"messages"`` field are printed
    and skipped. Turns missing ``"labels"`` or ``"role"`` are reported and
    skipped.

    Args:
        dataset_name: key into the module-level ``loaded_datasets``.

    Returns:
        ``(user_counter, assistant_counter)``, two ``collections.Counter``s.
    """
    # Loop-invariant: the full set of remapped label names.
    all_labels = set(label_remap.values())

    user_counter = Counter()
    assistant_counter = Counter()

    for convo in loaded_datasets[dataset_name]:
        convo_labels = convo.get("labels")
        if convo_labels is None:
            continue

        try:
            curr_convo = convo_labels["messages"]
        except (KeyError, TypeError):
            # Print the offending value itself. The previous code printed
            # `curr_convo`, which is unbound on the first conversation and
            # stale afterwards.
            print(convo_labels)
            continue

        for turn in curr_convo[1:]:
            try:
                # Counter(None) is an empty Counter, so a turn whose labels
                # are None is tallied as "NOT <label>" for every label.
                present = Counter(turn["labels"])
                role = turn["role"]
            except (KeyError, TypeError):
                print("No labels in this turn")
                continue

            # Presence/absence indicator: every entry counts exactly once per
            # turn, regardless of how often a label was repeated.
            indicator = dict.fromkeys(present, 1)
            for label in all_labels:
                if label not in present:
                    indicator["NOT " + label] = 1

            if role == "user":
                user_counter.update(indicator)
            elif role == "assistant":
                assistant_counter.update(indicator)

    return user_counter, assistant_counter
    


In [None]:
# Dataset key -> title drawn on that dataset's subplot (two lines of text).
remap_datasets = {
    "wildchat": "WildChat\n(Human-AI)",
    "copilot": "CoPilot\n(Human-AI)",
    "multiwoz": "MultiWOZ\n(Human-Human)"
}

# Internal label name -> tick label shown on the x-axis of the bar charts.
remap_keys = {
    "next turn": "Continue",
    "followup": "Follow-up",
    "clarification": "Clarification",
    "acknowledgement": "Ack.",
    "overresponse": "Overresponse",
    "repeat": "Repeat",
    "topic switch": "Topic Switch",
    "repair": "Repair"
}

def normalize_counters(counter):
    """Turn label / "NOT label" tallies into the fraction of turns with each label.

    For every key that does not start with ``"NOT "``, the result holds
    ``counter[key] / (counter[key] + counter["NOT " + key])``.

    A missing ``"NOT "`` entry is treated as 0, so plain dicts are accepted
    as well as ``collections.Counter``. A label whose two tallies sum to
    zero maps to ``0.0`` instead of raising ``ZeroDivisionError``.

    Args:
        counter: mapping of label (and ``"NOT " + label``) to a count.

    Returns:
        dict mapping each non-"NOT" label to a float fraction.
    """
    frac_counter = {}

    for key, present in counter.items():
        if key.startswith("NOT "):
            continue

        # .get() rather than indexing: only Counter returns 0 for a missing
        # key, and a label present in every turn has no "NOT" entry.
        absent = counter.get("NOT " + key, 0)
        total = present + absent

        frac_counter[key] = present / total if total else 0.0

    return frac_counter

def render_figure(target_keys, fname, ylim):
    """Plot per-dataset user vs. assistant label fractions and save the figure.

    One subplot row is drawn for each entry of the module-level ``datasets``.
    Each row shows, for every label in ``target_keys``, a pair of bars with the
    fraction of user and assistant turns carrying that label (as computed by
    ``get_counts`` and ``normalize_counters``). The fractions for each dataset
    are also printed.

    A red palette is used when ``"repair"`` is among ``target_keys``, a green
    one otherwise; the shared y-axis label is only drawn for the non-repair
    figure.

    Args:
        target_keys: label names to plot; each must be a key of ``remap_keys``.
        fname: path the figure is saved to.
        ylim: upper limit of every subplot's y-axis.

    Returns:
        None.
    """
    n_rows = len(datasets)

    # squeeze=False always returns a 2-D array of axes, so the code also works
    # when `datasets` holds a single entry (the row count used to be fixed at 3).
    fig, axes = plt.subplots(n_rows, 1, figsize=(5, 5), squeeze=False)
    axes = axes[:, 0]

    is_repair_figure = "repair" in target_keys
    if is_repair_figure:
        colors = {"user": "#8b0000", "assistant": "#f08080"}
    else:
        colors = {"user": "#008000", "assistant": "#90ee90"}

    bar_width = 0.25
    pos = list(range(len(target_keys)))

    # The y-axis label goes on the middle row only (row 1 for three datasets).
    ylabel_row = n_rows // 2

    for row, dataset in enumerate(datasets):
        ax = axes[row]

        user_counter, assistant_counter = get_counts(dataset)

        frac_counters = {
            "user": normalize_counters(user_counter),
            "assistant": normalize_counters(assistant_counter)
        }

        print(frac_counters)

        # Two side-by-side bars per label: user first, assistant shifted right.
        for col, sender in enumerate(["user", "assistant"]):
            vals = [frac_counters[sender].get(key, 0) for key in target_keys]
            ax.bar(
                [i + bar_width * col for i in pos],
                vals,
                bar_width,
                label=sender,
                color=colors[sender],
            )

        # Centre each tick between its pair of bars.
        ax.set_xticks([i + (bar_width / 2) for i in pos])
        ax.set_xticklabels([remap_keys[key] for key in target_keys])

        # y=0.58 places the title inside the axes area rather than above it.
        ax.set_title(remap_datasets[dataset], y=0.58)

        for label in ax.get_xticklabels():
            label.set_fontsize(13)

        for label in ax.get_yticklabels():
            label.set_fontsize(13)

        # NOTE(review): the plotted values are fractions in [0, 1], not
        # percentages; the label text is kept as in the original figure.
        if row == ylabel_row and not is_repair_figure:
            ax.set_ylabel("Percentage of Turns", fontsize=14)

        ax.set_ylim(0, ylim)

        # Only the first row carries the legend.
        if row == 0:
            ax.legend(loc="upper right")

        # Hide the top and right borders of the plot.
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    fig.tight_layout()
    fig.savefig(fname)


In [None]:
# Figure for the labels plotted with the green palette.
render_figure(
    target_keys=["next turn", "followup", "acknowledgement", "overresponse"],
    fname="pos_grounding.png",
    ylim=1.35,
)

In [None]:
# Figure for the repeat / clarification / repair labels (red palette, since "repair" is included).
render_figure(
    target_keys=["repeat", "clarification", "repair"],
    fname="neg_grounding.png",
    ylim=0.21,
)

In [None]:
def convert_sequence(dataset):
    """Encode each labeled conversation as a string of per-turn outcomes.

    Every sequence starts with ``"B"`` and ends with ``"E"``. In between, each
    non-assistant message after the first contributes one character: ``"F"``
    when more of its labels map to ``"F"`` than to ``"S"`` in ``seq_remap``,
    otherwise ``"S"`` (ties and turns with only ``"topic switch"`` count as
    ``"S"``).

    Conversations whose ``"labels"`` entry is missing or ``None`` are skipped
    silently. Conversations without a usable ``"messages"`` field are printed
    and skipped. Turns missing ``"role"``/``"labels"`` or carrying a label not
    in ``seq_remap`` are reported and contribute no character.

    Args:
        dataset: iterable of conversation dicts.

    Returns:
        list of str, one sequence per usable conversation.
    """
    seq_remap = {
        "next turn": "S",
        "followup": "S",
        "acknowledgement": "S",
        "overresponse": "S",
        "repeat": "F",
        "clarification": "F",
        "repair": "F",
        "topic switch": "N"
    }

    seqs = []

    for convo in dataset:
        convo_labels = convo.get("labels")
        if convo_labels is None:
            continue

        try:
            curr_convo = convo_labels["messages"]
        except (KeyError, TypeError):
            # Print the offending value itself. The previous code printed
            # `curr_convo`, which is unbound on the first conversation and
            # stale afterwards.
            print(convo_labels)
            continue

        seq = "B"

        for turn in curr_convo[1:]:
            try:
                # Only non-assistant turns contribute to the sequence.
                if turn["role"] == "assistant":
                    continue

                tally = Counter(seq_remap[label] for label in turn["labels"])
            except (KeyError, TypeError):
                print("No labels in this turn")
                continue

            # Strictly more failure labels than success labels -> failure.
            seq += "F" if tally["F"] > tally["S"] else "S"

        seq += "E"
        seqs.append(seq)

    return seqs
     

In [None]:
def compute_likelihoods(sequences, prefix):
    """Estimate the distribution of the character that follows ``prefix``.

    Only sequences that start with ``prefix`` and are strictly longer than it
    are considered. A sequence equal to the prefix has no next character and
    is ignored rather than raising ``IndexError``.

    Args:
        sequences: iterable of strings.
        prefix: string the sequences must start with.

    Returns:
        ``(counts, likelihoods)``: a ``collections.Counter`` of next-character
        occurrences, and a dict mapping each next character to its relative
        frequency. Both are empty when no sequence qualifies.
    """
    offset = len(prefix)

    # Character immediately after the prefix, for every qualifying sequence.
    next_chars = [
        seq[offset]
        for seq in sequences
        if seq.startswith(prefix) and len(seq) > offset
    ]

    next_char_counter = Counter(next_chars)
    total = len(next_chars)

    # With no qualifying sequences the comprehension is empty, so `total`
    # never divides anything when it is zero.
    likelihoods = {key: count / total for key, count in next_char_counter.items()}

    return next_char_counter, likelihoods


In [None]:
# plot compounding effects
from collections import defaultdict

def render_compounding_effects(dataset_name, compound_char="S"):
    """Draw four bar charts of next-turn likelihoods after 0-3 repeated outcomes.

    Panel ``i`` shows the likelihood of each next character ("E", "F", "S")
    among sequences of ``loaded_datasets[dataset_name]`` that start with
    ``"B"`` followed by ``i`` copies of ``compound_char``.

    Args:
        dataset_name: key into the module-level ``loaded_datasets``.
        compound_char: outcome character repeated in the prefix.

    Returns:
        None.
    """
    outcome_names = {
        "S": "Success",
        "F": "Failure",
        "E": "Leave"
    }
    outcome_colors = {
        "S": "#008000",
        "F": "#8b0000",
        "E": "gray"
    }
    # Bars are drawn in this order: Leave, Failure, Success.
    outcomes = ["E", "F", "S"]

    sequences = convert_sequence(loaded_datasets[dataset_name])

    fig, panels = plt.subplots(1, 4, figsize=(10, 3))

    for streak, panel in enumerate(panels):
        _, probs = compute_likelihoods(sequences, "B" + compound_char * streak)

        for outcome in outcomes:
            panel.bar(
                outcome_names[outcome],
                probs.get(outcome, 0),
                color=outcome_colors[outcome],
            )

        panel.set_ylim(0, 0.7)

        # Only the left-most panel carries the y-axis label.
        if streak == 0:
            panel.set_ylabel("Likelihood")
        panel.set_xlabel("Outcome")

    plt.tight_layout()



def render_compounding_effects_line(dataset_name, compound_char="S"):
    """Plot next-turn likelihoods against the length of a failure/success streak.

    Two panels are drawn for ``loaded_datasets[dataset_name]``: the left one
    conditions on 0-4 repeated ``"F"`` turns at the start of a conversation,
    the right one on 0-4 repeated ``"S"`` turns. Each panel has one line per
    next outcome (Success, Failure, Leave).

    Args:
        dataset_name: key into the module-level ``loaded_datasets``.
        compound_char: accepted for backward compatibility but not used; both
            the "F" and the "S" panel are always drawn.

    Returns:
        None.
    """
    colors = {
        "S": "#008000",
        "F": "#8b0000",
        "E": "gray"
    }

    char_remap = {
        "S": "Success",
        "F": "Failure",
        "E": "Leave"
    }

    sequences = convert_sequence(loaded_datasets[dataset_name])

    fig, ax = plt.subplots(1, 2, figsize=(10, 3))

    # x-axis: number of repeated streak characters after the "B" marker.
    streak_lengths = list(range(5))

    # (axes, streak character, x-axis label, title) for each panel.
    panels = [
        (ax[0], "F", "Repeated Failed Turns", "Compounding Failure"),
        (ax[1], "S", "Repeated Successful Turns", "Compounding Success"),
    ]

    for panel, streak_char, xlabel, title in panels:
        # One likelihood table per streak length, shared by all three lines
        # instead of being recomputed for every outcome.
        probs_by_length = [
            compute_likelihoods(sequences, "B" + streak_char * n)[1]
            for n in streak_lengths
        ]

        for char in ["S", "F", "E"]:
            panel.plot(
                streak_lengths,
                [probs.get(char, 0) for probs in probs_by_length],
                label=char_remap[char],
                color=colors[char],
            )

        panel.set_xlabel(xlabel)
        panel.set_ylabel("Likelihoods of Next Turn")
        panel.set_title(title)
        panel.set_ylim(0, 0.7)

    # Both panels use the same lines, so one legend on the left panel is enough.
    ax[0].legend()


# Line plots of compounding failure / success for the WildChat sample.
render_compounding_effects_line(dataset_name="wildchat", compound_char="S")