In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import scipy.stats as stats
import researchpy
import matplotlib
import ptitprince as pt


In [None]:
def compute_metrics(df, col1, col2, templates = None, DESC_VERBOSE=False, filter_n=None, GROUPBY_PERSONS = True, RETURN = False):
    ''' Compare two groups of sequentiality scores with paired t-tests.

    Assumption col1 > col2 for the test metric. Ensure you pass col1 as imagined/biography vs col2 as recalled/autobiography.
    For every template, the row-wise difference template.format(col1) - template.format(col2) is stored as
    template.format('seq_diff') and the two columns are compared with scipy's ttest_rel and researchpy's paired ttest.
    If GROUPBY_PERSONS is True and the DataFrame has a 'person' column, all compared columns are additionally
    averaged per person, and the averaged differences are tested against 0 (one-sample t-test) and with a paired t-test.
    The input DataFrame is never modified; all derived columns are added to an internal copy.

    Parameters:
        df (pd.DataFrame): DataFrame containing the columns to compare.
        col1 (str): Name of the first group (e.g., imagined/biography).
        col2 (str): Name of the second group (e.g., recalled/autobiography).
        templates (list): List of string templates to automate computing difference for different parts of sequentiality (seq, topic, context). The df must have columns corresponding to all the templates after col1 and col2 are applied to the templates.
        For example, if col1 is 'story_imagined' and col2 is 'story_recalled', the templates could be ['{}_seq', '{}_topic_seq', '{}_context_seq'] which would require the DataFrame to have columns like 'story_imagined_seq', 'story_recalled_seq', etc. necessarily.
        Defaults to ['{}_seq', '{}_topic_seq', '{}_context_seq'].

        DESC_VERBOSE (bool): If True, prints count/mean/std of every compared column.
        filter_n (int): Minimum number of entries a person must have to be kept; persons with fewer than filter_n rows are dropped. Only applies if the DataFrame has a 'person' column.
        GROUPBY_PERSONS (bool): If True, additionally runs the tests on per-person averages. Only applicable if the DataFrame has a 'person' column.
        RETURN (bool): If True, returns the per-person averaged DataFrame.
    Returns:
        None if RETURN is False or no per-person grouping was performed. Otherwise a DataFrame with a 'person' column and,
        for every template, the per-person means of template.format('seq_diff'), template.format(col1) and template.format(col2).
    Example:
    >>> compute_metrics(df, 'story_imagined', 'story_recalled', templates=['{}_seq', '{}_topic_seq', '{}_context_seq'], DESC_VERBOSE=True, filter_n=20, GROUPBY_PERSONS=True)
    '''
    has_person = 'person' in df.columns

    # Work on a copy so the seq_diff columns added below never leak into the caller's DataFrame.
    df = df.copy()

    # Removing personalities with less than filter_n entries (i.e. keep persons with >= filter_n rows).
    if filter_n:
        if has_person:
            df = df.groupby('person').filter(lambda group: len(group) >= filter_n)
        else:
            print("No 'person' column found; filter_n is ignored.")

    if templates is None:
        templates = ['{}_seq', '{}_topic_seq', '{}_context_seq']
    print(templates)

    # Resolve the concrete column names once per template: (difference, group 1, group 2).
    column_sets = [(template.format('seq_diff'), template.format(col1), template.format(col2)) for template in templates]

    # Descriptive statistics for all the templates across groups
    if DESC_VERBOSE:
        for _, a_col, b_col in column_sets:
            print('-------------------------------------')
            print(f'{a_col} description :\n', df[a_col].describe()[['count', 'mean', 'std']])
            print(f'{b_col} description :\n', df[b_col].describe()[['count', 'mean', 'std']])
            print('-------------------------------------')

    # Row-level difference and paired t-tests for every template
    for diff_col, a_col, b_col in column_sets:
        df[diff_col] = df[a_col] - df[b_col]
        print(f"{a_col} and {b_col} ttest : {stats.ttest_rel(df[a_col], df[b_col])}")
        _, res = researchpy.ttest(df[a_col], df[b_col], paired=True)
        display(res)
    print('-------------------------------------')

    if not GROUPBY_PERSONS or not column_sets:
        return None
    if not has_person:
        print("No 'person' column found; skipping personality-averaged t-tests.")
        return None

    # Average every compared column per person in a single groupby.
    # dict.fromkeys de-duplicates column names (e.g. repeated templates) while preserving order.
    value_cols = list(dict.fromkeys(col for cols in column_sets for col in cols))
    df_grouped = df.groupby('person')[value_cols].mean().reset_index()

    for diff_col, a_col, b_col in column_sets:
        print(f"Personality averaged {diff_col} t-test : {stats.ttest_1samp(df_grouped[diff_col], 0)}")
        _, res = researchpy.ttest(df_grouped[a_col], df_grouped[b_col], paired=True)
        display(res)
        print('-------------------------------------')

    # Return only after every template has been tested, so the result covers all of them.
    return df_grouped if RETURN else None


def raincloud_plot(df, col1, col2, col1_name, col2_name, title, ignore_y_ticks = False, save_dir='../data/plots'):
    ''' Draw a vertical raincloud plot comparing two columns, save it as an SVG and show it.

    Parameters:
        df (pd.DataFrame): DataFrame containing col1 and col2 (e.g. the grouped output of compute_metrics).
        col1 (str): Column drawn first, using the first palette color (#0cc0df).
        col2 (str): Column drawn second, using the second palette color (#ff3131).
        col1_name (str): x-tick label for col1.
        col2_name (str): x-tick label for col2.
        title (str): File name (without extension) of the saved figure; it is not drawn on the plot.
        ignore_y_ticks (bool): If True, keep matplotlib's default y ticks instead of the fixed ticks at -0.5, 0.5, ..., 4.5.
        save_dir (str): Directory the SVG is written to; created if missing. Defaults to '../data/plots'.
    '''
    import os

    # Plotting fails without this alias: pandas 2.0 removed DataFrame.iteritems, which the plotting stack used by
    # ptitprince presumably still calls. Note this patches pandas globally for the whole session.
    pd.DataFrame.iteritems = pd.DataFrame.items

    palette = [matplotlib.colors.hex2color('#0cc0df'), matplotlib.colors.hex2color('#ff3131')]
    pt.RainCloud(data=df[[col1, col2]], bw=0.05, cut=0, orient='v', palette=palette)

    plt.ylabel('Sequentiality', fontsize=16)
    plt.xlabel('Type', fontsize=16)
    if not ignore_y_ticks:
        plt.yticks(np.arange(-0.5, 5.5, 1), fontsize=16)
    plt.xticks([0, 1], [col1_name, col2_name], fontsize=16)

    # savefig does not create missing directories, so make sure the target exists first.
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, f'{title}.svg'), bbox_inches='tight')
    plt.show()

The `compute_metrics` function above computes the sequentiality difference for each constituent measure and reports it with paired t-tests and various descriptive statistics. To use it, you need a DataFrame with one column per group and measure (seq, topic and/or context) that you want to compare across the two groups; the personality-averaged tests additionally require a `person` column. Here is an example:

In [None]:
# Combined sequentiality scores for the imagined/recalled stories
# (the file name indicates they were scored with Meta-Llama-3.1-8B-Instruct-AWQ-INT4).
score_dir = '../data/scores/runV3/hcV3-stories_combined_scores_Meta-Llama-3.1-8B-Instruct-AWQ-INT4.csv'
df = pd.read_csv(score_dir)

# Inspect the available columns; the expected output is shown below.
print(df.columns)
# >>> Index(['story_imagined', 'story_recalled', 'story_imagined_seq', 'story_imagined_topic_seq', 'story_imagined_context_seq', 'story_recalled_seq', 'story_recalled_topic_seq', 'story_recalled_context_seq'], dtype='object')

# Paired t-tests between imagined and recalled stories for each sequentiality measure;
# RETURN=True hands back the per-person averages so they can be plotted.
df_grouped = compute_metrics(
    df,
    'story_imagined',
    'story_recalled',
    templates=['{}_seq', '{}_topic_seq', '{}_context_seq'],
    DESC_VERBOSE=True,
    RETURN=True,
)

We have also included `raincloud_plot` to show the two groups side by side as a raincloud plot (the figure is also saved as an SVG named after the `title` argument). Example (continued):

In [None]:
# Raincloud plot of the per-person imagined vs recalled sequentiality scores,
# saved under the file name 'Story Sequentiality'.
raincloud_plot(
    df_grouped,
    col1='story_imagined_seq',
    col2='story_recalled_seq',
    col1_name='Story Imagined',
    col2_name='Story Recalled',
    title='Story Sequentiality',
)