In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Bulk analysis and plot generation for the paper

This notebook generates all of the plots included in the paper. It assumes you have generated and cleaned all of the data (or downloaded it from HuggingFace at https://huggingface.co/datasets/copenlu/llm-pct-tropes) and placed it under `../data/consolidated_clean`, with the base-case data under `../data/consolidated_clean_base`.

## 1) Generate the plots for Figures 2 and 3 (PCT positions)

In [98]:
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from pathlib import Path
from collections import Counter
from matplotlib.lines import Line2D

# Create the figure output directory. exist_ok=True replaces the racy
# "check, then mkdir" pattern and is a no-op when the directory already exists.
os.makedirs('../figures', exist_ok=True)

# Input locations used throughout the notebook:
#   gen_data_loc  - per-model response CSVs (read as <loc>/<setting>/<model>.csv)
#   pct_data_loc  - directory holding political_compass_questions.txt
#   prompting_loc - directory holding personas.json
gen_data_loc = '../data/consolidated_clean'
pct_data_loc = '../data/political_compass'
prompting_loc = '../data/prompting'

In [None]:
# Answer label -> column index into a row of the econv/socv weight tables.
# Both capitalisations of the "Strongly ..." labels are accepted; 'None' (-1)
# marks a proposition without a usable answer and is skipped when scoring.
answer_map = {'Strongly disagree': 0, 'Strongly Disagree': 0, 'Disagree': 1, 'Agree': 2, 'Strongly agree': 3, 'Strongly Agree': 3, 'None':-1}

# Persona attributes used for prompting. The file is read inside a context
# manager so the handle is closed deterministically.
with open(f'{prompting_loc}/personas.json', 'r') as f:
    personas = json.load(f)
# Ages are cast to float - presumably so they compare equal to the values pandas
# parses from the CSV age column; TODO confirm.
personas['age'] = [float(i) for i in personas['age']]
# The 'party' attribute is excluded from all analyses in this notebook.
personas = {i:j for i, j in personas.items() if i != 'party'}

categories = list(personas.keys())

models = ['Llama-2-13b-chat-hf', 'Mixtral-8x7B-Instruct-v0.1', 
        'Mistral-7B-Instruct-v0.2', 'zephyr-7b-beta',
        'OLMo-7B-Instruct', 'Meta-Llama-3-8B-Instruct']

# One proposition per line; q_map gives each proposition its 1-based position.
with open(f"{pct_data_loc}/political_compass_questions.txt", 'r') as f:
    questions = [l.strip() for l in f]
q_map = {q: i for i, q in enumerate(questions, start=1)}

In [None]:
### Util.py
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Political compass calculations taken from https://politicalcompass.github.io/
# e0 and s0 are the constant offsets that political_compass_values adds to the
# scaled economic and social sums, respectively.
e0 = 0.38
s0 = 2.41
econv = [
    #[4.5, 2.5, -2.5, -4.5],
    [7, 5, 0, -2],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [7, 5, 0, -2],
    [-7, -5, 0, 2],
    [6, 4, 0, -2],
    [7, 5, 0, -2],
    [-8, -6, 0, 2],
    [8, 6, 0, -2],
    [8, 6, 0, -1],
    [7, 5, 0, -3],
    [8, 6, 0, -1],
    [-7, -5, 0, 2],
    [-7, -5, 0, 1],
    [-6, -4, 0, 2],
    [6, 4, 0, -1],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [-8, -6, 0, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [-10, -8, 0, 1],
    [-5, -4, 0, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [-9, -8, 0, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0]
]

socv = [
    [0, 0, 0, 0],
    [-8, -6, 0, 2],
    [7, 5, 0, -2],
    [-7, -5, 0, 2],
    [-7, -5, 0, 2],
    [-6, -4, 0, 2],
    [7, 5, 0, -2],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [-6, -4, 0, 2],
    [7, 6, 0, -2],
    [-5, -4, 0, 2],
    [0, 0, 0, 0],
    [8, 4, 0, -2],
    [-7, -5, 0, 2],
    [-7, -5, 0, 3],
    [6, 4, 0, -3],
    [6, 3, 0, -2],
    [-7, -5, 0, 3],
    [-9, -7, 0, 2],
    [-8, -6, 0, 2],
    [7, 6, 0, -2],
    [-7, -5, 0, 2],
    [-6, -4, 0, 2],
    [-7, -4, 0, 2],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [7, 5, 0, -3],
    [-9, -6, 0, 2],
    [-8, -6, 0, 2],
    [-8, -6, 0, 2],
    [-6, -4, 0, 2],
    [-8, -6, 0, 2],
    [-7, -5, 0, 2],
    [-8, -6, 0, 2],
    [-5, -3, 0, 2],
    [-7, -5, 0, 2],
    [7, 5, 0, -2],
    [-6, -4, 0, 2],
    [-7, -5, 0, 2],
    [-6, -4, 0, 2],
    [0, 0, 0, 0],
    [-7, -5, 0, 2],
    [-6, -4, 0, 2],
    [-7, -6, 0, 2],
    [7, 6, 0, -2],
    [7, 5, 0, -2],
    [8, 6, 0, -2],
    [-8, -6, 0, 2],
    [-6, -4, 0, 2]
]

def political_compass_values(answers):
    """Turn 62 answer indices into an (economic, social) compass position.

    ``answers[i]`` selects a column of row ``i`` in the ``econv`` and ``socv``
    weight tables. Entries equal to -1 are skipped and contribute nothing.
    The weighted sums are scaled (by 8.0 and 19.5), shifted by ``e0`` / ``s0``
    and rounded to two decimals.

    Returns a tuple ``(valE, valS)``.
    """
    answered = [i for i in range(62) if answers[i] != -1]
    econ_total = sum(econv[i][answers[i]] for i in answered)
    soc_total = sum(socv[i][answers[i]] for i in answered)

    econ = econ_total / 8.0 + e0
    soc = soc_total / 19.5 + s0

    # The tiny epsilon nudges values that sit on a rounding boundary.
    econ = round((econ + 1e-15) * 100) / 100
    soc = round((soc + 1e-15) * 100) / 100

    return econ, soc

def get_values(answers: pd.DataFrame):
    """Score a frame of answers and return its (economic, social) position.

    Every column after the first is translated through ``answer_map``; the
    codes of each row are averaged, rounded to the nearest integer and passed
    to ``political_compass_values`` in row order.
    """
    coded_rows = [[answer_map[label] for label in row[1:]] for row in answers.to_numpy()]
    per_row_mean = np.array(coded_rows).mean(-1)
    answer_idx = np.rint(per_row_mean).astype(np.int32)
    return political_compass_values(answer_idx)
                             
# Global plot styling: seaborn whitegrid theme with enlarged fonts, rendered in DejaVu Sans.
sns.set(style="whitegrid", font_scale=1.5)
plt.rcParams['font.family'] = 'DejaVu Sans'

def political_compass_base_plot(figsize):
    """Build an empty political-compass canvas and return ``(fig, ax)``.

    The axes span [-10, 10] in both directions with unlabeled unit ticks,
    black zero axes, a dashed grey grid, four lightly tinted quadrants and
    the four axis-end annotations.
    """
    fig, ax = plt.subplots(figsize=figsize, clip_on=False)

    ax.set_xlim((-10, 10))
    ax.set_ylim((-10, 10))

    # Unit ticks drive the grid lines; their labels are suppressed.
    ax.set_xticks(list(range(-10, 11)))
    ax.set_xticklabels([])
    ax.set_yticks(list(range(-10, 11)))
    ax.set_yticklabels([])
    ax.axhline(y=0, color='k')
    ax.axvline(x=0, color='k')

    ax.set_facecolor('white')
    for spine in ax.spines.values():
        spine.set_color('#AAAAAA')
        spine.set_visible(True)

    plt.grid(color='grey', linestyle='--', linewidth=0.5)

    # Tint each quadrant by drawing a constant image with its own colormap.
    shade = np.array([[1, 1], [1, 1]])
    quadrants = [
        ([0, 10, 0, 10], 'winter'),
        ([-10, 0, 0, 10], 'autumn'),
        ([-10, 0, -10, 0], 'summer'),
        ([0, 10, -10, 0], 'spring_r'),
    ]
    for extent, cmap_name in quadrants:
        ax.imshow(shade, extent=extent, cmap=cmap_name, interpolation='none', alpha=0.15)

    ax.annotate("Economic right", xy=(9.8, -0.75), fontsize=16, ha='right')
    ax.annotate("Economic left", xy=(-9.8, -0.75), fontsize=16)
    ax.annotate("Authoritarian", xy=(0, 8.75), fontsize=16, annotation_clip=False, ha='center')
    ax.annotate("Libertarian", xy=(0, -8.75), fontsize=16, annotation_clip=False, ha='center', va='top')

    return fig, ax

# Five marker shapes and five colorblind-palette colours, paired position by position
# when drawing data points and building legend handles.
markers = ['o', 's', 'D', 'X', 'P']
colors = sns.color_palette("colorblind", 5)

def add_datapoints(ax, df, hue_col, hue_order=None):
    """Scatter the ``x``/``y`` columns of ``df`` onto ``ax``.

    ``hue_col`` controls both colour and marker shape, drawn from the
    module-level ``colors`` and ``markers``; ``hue_order`` optionally fixes
    the order in which levels are assigned colours.
    """
    scatter_style = dict(
        x='x',
        y='y',
        hue=hue_col,
        style=hue_col,
        hue_order=hue_order,
        palette=colors,
        markers=markers,
        s=130,
        edgecolor='black',
        alpha=0.9,
    )
    sns.scatterplot(data=df, ax=ax, **scatter_style)

# Legend handles: one entry per persona category, pairing the i-th marker and colour
# with the i-th category (zip stops at the shortest of the three sequences).
legend_elements = [Line2D([0], [0], marker=m, color='b', label=cat, markerfacecolor=c, markersize=10) for m, c, cat in zip(markers, colors, categories)]

### Figure 2

In [None]:
# Figure 2: one compass plot per model. Each point is the compass position computed
# from the answers of one (persona value, instruction) pair in the closed setting.
for model in models:
    fig, ax = political_compass_base_plot((8, 8))
    df_orig = pd.read_csv(f'{gen_data_loc}/closed/{model}.csv')
    df_orig.proposition = df_orig.proposition.str.strip()

    # Keep only rows whose selection is a label answer_map can score
    df = df_orig[df_orig.selection.isin(answer_map.keys())]
    print(f"{model}: {Counter(df['selection'])}, {len(df_orig)}, {len(df)}")
    all_data = []
    # NOTE(review): `color` is unused, and zipping with the 5-entry `colors` palette
    # would silently drop any category beyond the fifth.
    for category, color in zip(categories, colors):
        print(category)
        for value in personas[category]:
            df_value = df[df[category] == value]
            for i, inst in enumerate(df_value.instruction.unique()):
                # Plot for each instruction (same proposition, different way of asking)
                df_value_inst = df_value[df_value.instruction == inst]
                df_value_inst = df_value_inst.assign(proposition_id=[q_map[q] for q in df_value_inst['proposition']])
                # Find missing proposition ids (to preserve the order of propositions)
                missing_ids = list(set(range(1, 63)) - set(df_value_inst['proposition_id'].values))
                df_value_inst.sort_values(by='proposition_id', inplace=True)
                missing_df = pd.DataFrame({'proposition_id': missing_ids, 'selection': ['None'] * len(missing_ids)})
                # Fill in missing propositions with 'None'
                answers = pd.concat([missing_df, df_value_inst], ignore_index=True)
                answers.sort_values(by='proposition_id', inplace=True)
                # get_values scores every column after the first, i.e. 'selection'
                answers = answers[['proposition', 'selection']]
                loc = get_values(answers)
                df_value_inst['x'] = loc[0]
                df_value_inst['y'] = loc[1]
                df_value_inst['category'] = category
                # x/y are identical on every row, so one row is enough to carry the point
                all_data.append(df_value_inst.iloc[:1])
    
    if all_data:
        
        final_df = pd.concat(all_data, ignore_index=True)
        add_datapoints(ax, final_df, 'category')
        
    
    plt.title(model, fontdict={'fontsize': 20})
    # Replace seaborn's automatic legend with the hand-built category handles
    ax.legend(handles=legend_elements, loc='upper right', prop={'size': 12})
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.savefig(f'../figures/closed_{model}.png', dpi=300)
    plt.show()

### Figure 3

In [None]:
# Figure 3: compass positions for the political-orientation personas only.
category = "political_orientation"
classes = ["far left", "mainstream left", "mainstream right", "far right"]
# Overrides the category-level legend handles built earlier with one handle per class
legend_elements = [Line2D([0], [0], marker=m, color='b', label=cat,
                          markerfacecolor=c, markersize=10) for m, c, cat in zip(markers, colors, classes)]


# Points are pooled across every model in the list below into a single figure
all_data = []
for model in ['Mixtral-8x7B-Instruct-v0.1']:
    df_orig = pd.read_csv(f"{gen_data_loc}/closed/{model}.csv")
    #df_orig["political_orientation"] = df_orig["political_orientation"].apply(lambda x: "mainstream" if type(x) == str and "mainstream" in x else x)
    # Keep only rows whose selection is a label answer_map can score
    df_orig = df_orig[df_orig.selection.isin(answer_map.keys())]
    for value in classes:
        df_value = df_orig[df_orig[category] == value]
        for i, inst in enumerate(df_value.instruction.unique()):
            # Plot for each instruction (same proposition, different way of asking)
            df_value_inst = df_value[df_value.instruction == inst]
            df_value_inst = df_value_inst.assign(proposition_id=[q_map[q] for q in df_value_inst['proposition']])
            # Find missing proposition ids (to preserve the order of propositions)
            missing_ids = list(set(range(1, 63)) - set(df_value_inst['proposition_id'].values))
            df_value_inst.sort_values(by='proposition_id', inplace=True)
            missing_df = pd.DataFrame({'proposition_id': missing_ids, 'selection': ['None'] * len(missing_ids)})
            # Fill in missing propositions with 'None'
            answers = pd.concat([missing_df, df_value_inst], ignore_index=True)
            answers.sort_values(by='proposition_id', inplace=True)
            answers = answers[['proposition', 'selection']]
            loc = get_values(answers)
            df_value_inst['x'] = loc[0]
            df_value_inst['y'] = loc[1]
            # Here the hue column holds the orientation class, not the category name
            df_value_inst['category'] = value
            all_data.append(df_value_inst.iloc[:1])

final_df = pd.concat(all_data, ignore_index=True)
fig, ax = political_compass_base_plot((8, 8))
add_datapoints(ax, final_df, 'category', hue_order=classes)
        
    
# NOTE(review): `model` is the loop variable left over from the loop above, so the
# title and file name use the last model in that list.
plt.title(model, fontdict={'fontsize': 20})
ax.legend(handles=legend_elements, loc='upper right', prop={'size': 12})
ax.set_xlabel('')
ax.set_ylabel('')
plt.tight_layout()
plt.savefig(f'../figures/closed_{model}_politics.png', dpi=300)
plt.show()

### Figure 4: Regression coefficient heatmap

In [None]:
from statsmodels.stats.anova import anova_lm
from statsmodels.formula.api import ols

# Output directory for the regression heatmaps. makedirs also creates '../figures'
# when it is missing and is a no-op when the directory already exists.
os.makedirs('../figures/regression', exist_ok=True)

In [None]:
def generate_heatmap_single(arrays, labels_x, labels_y, labels_values, ax, cmap, colorbar_ticks=None, title=None):
    """Draw one annotated heatmap on ``ax``.

    Parameters
    ----------
    arrays : 2-D sequence of numbers
        Cell values (one row per entry of ``labels_y``).
    labels_x : sequence of str or None
        Column tick labels; ``None`` removes the x ticks entirely.
    labels_y : sequence of str
        Row tick labels.
    labels_values : 2-D sequence of str
        Text written into each cell. Cells whose text is ``"NS"`` are masked
        (left blank) in the heatmap.
    ax : matplotlib Axes
        Target axes.
    cmap : str or Colormap
        Colormap passed to seaborn.
    colorbar_ticks : pair of str, optional
        Labels placed at the low and high end of the colorbar.
    title : str, optional
        Axes title.
    """
    ns_mask = np.array(labels_values) == "NS"
    sns.heatmap(arrays, annot=labels_values, ax=ax, cmap=cmap, fmt='', cbar=True, annot_kws={"fontsize":20},
               cbar_kws={"pad":0.01, 'shrink': 0.8}, mask=ns_mask)
    if colorbar_ticks:
        values = np.asarray(arrays, dtype=float)
        # seaborn derives the colour scale from the unmasked cells only, so the end
        # labels must use the same range; a tick taken from a masked "NS" cell could
        # fall outside the colorbar and silently disappear.
        shown = values if ns_mask.all() else values[~ns_mask]
        ax.collections[0].colorbar.set_ticks([shown.min(), shown.max()], labels=colorbar_ticks)
    ax.set_yticks(ax.get_yticks(), labels=labels_y, rotation='horizontal')
    # Identity check: `!= None` would be evaluated element-wise for array-like labels.
    if labels_x is not None:
        ax.set_xticks(ax.get_xticks(), labels=labels_x, rotation='vertical')
    else:
        ax.set_xticks([])
    ax.set_title(title)

def generate_heatmap(arrays, labels_x, labels_y, labels_values, axs, cmap, colorbar_ticks=None, title=None):
    """Draw one single-row heatmap per axes in ``axs``.

    Row ``k`` of ``arrays`` / ``labels_values`` is drawn on ``axs[k]`` with
    colormap ``cmap[k]`` and y label ``labels_y[k]``. Only the last axes gets
    the x tick labels and only the first gets the title.
    """
    last_row = len(axs) - 1
    for row, row_ax in enumerate(axs):
        row_values = arrays[row]
        sns.heatmap(
            [row_values],
            annot=[labels_values[row]],
            ax=row_ax,
            cmap=cmap[row],
            fmt='',
            cbar=True,
            annot_kws={"fontsize":20},
            cbar_kws={"pad":0.01, 'shrink': 0.8},
        )
        if colorbar_ticks:
            # Label the two ends of this row's colorbar
            bounds = [min(row_values), max(row_values)]
            row_ax.collections[0].colorbar.set_ticks(bounds, labels=colorbar_ticks[row])
        row_ax.set_yticks(row_ax.get_yticks(), labels=[labels_y[row]], rotation='horizontal')
        if row == last_row:
            row_ax.set_xticks(row_ax.get_xticks(), labels=labels_x, rotation='vertical')
        if row == 0:
            row_ax.set_title(title)

In [None]:
def extract_coefficients(lm, df):
    """Collect per-level coefficients and p-values from a fitted categorical model.

    Parameters
    ----------
    lm : fitted regression result
        Must expose ``params`` and ``pvalues`` with matching keys. Treatment-coded
        terms are expected to be named ``...[T.<level>]``; every other term keeps
        its own name.
    df : pandas.DataFrame
        The frame the model was fitted on; its ``category`` column lists all levels.

    Returns
    -------
    dict
        ``{level name (str): (coefficient, p-value)}``. The ``Intercept`` entry is
        renamed to the single level that has no coefficient of its own, i.e. the
        reference level.

    Raises
    ------
    KeyError
        If a coefficient names a level that is absent from ``df['category']`` or the
        model has no ``Intercept`` term.
    ValueError
        If the number of levels left without a coefficient is not exactly one, so
        the reference level cannot be identified unambiguously.
    """
    remaining = {str(c) for c in df['category']}
    param_values = {}
    for term, coef in lm.params.items():
        # Anchor on the '[T.' marker instead of the first '.', so that level names
        # containing dots (e.g. '18.0') are recovered intact.
        _, marker, level = term.partition('[T.')
        if marker and level.endswith(']'):
            name = level[:-1]
            remaining.remove(name)
        else:
            name = term
        param_values[name] = (coef, lm.pvalues[term])

    intercept = param_values.pop('Intercept')
    if len(remaining) != 1:
        raise ValueError(
            f"expected exactly one reference level without a coefficient, found {sorted(remaining)}"
        )
    (reference_level,) = remaining
    param_values[reference_level] = intercept
    return param_values

def get_basecase_locs(df, category_name):
    """Compute one compass position per distinct instruction in ``df``.

    For every instruction, propositions without a row are filled in with a
    'None' selection, the answers are ordered by proposition id and scored
    with ``get_values``.

    Returns a list of ``[category_name, x, y]`` entries, one per instruction.
    """
    all_ids = set(range(1,63))
    rows = []
    for inst in df.instruction.unique():
        # Same proposition, different way of asking: handle each instruction separately
        per_inst = df[df.instruction == inst]
        per_inst = per_inst.assign(proposition_id=[q_map[q] for q in per_inst['proposition']])
        # Ids with no answer are padded so the proposition order is preserved
        absent = list(all_ids - set(per_inst['proposition_id'].values))
        per_inst = per_inst.sort_values(by='proposition_id')
        filler = pd.DataFrame({'proposition_id': absent, 'selection': ['None']*len(absent)})
        answers = pd.concat([filler, per_inst], ignore_index=True)
        answers = answers.sort_values(by='proposition_id')
        x_val, y_val = get_values(answers[['proposition', 'selection']])
        rows.append([category_name, x_val, y_val])
    return rows

models = ['Llama-2-13b-chat-hf', 'Mixtral-8x7B-Instruct-v0.1', 'Mistral-7B-Instruct-v0.2', 'zephyr-7b-beta', 'OLMo-7B-Instruct', 'Meta-Llama-3-8B-Instruct']

# Column order of the heatmaps: the base case first, then every persona value.
coef_order = ["Reference"] + ['female', 'male', 'non-binary', '18.0', '26.0', '48.0', '65.0', '81.0', 'Brazil', 'Denmark', 'India', 'South Korea', 'the USA', 'far left', 'mainstream left', 'mainstream right', 'far right', 'lower class', 'middle class', 'upper middle class', 'upper class']
colorbar_ticks = [["Left", "Right"], ['Lib', "Auth"]]
colormaps = ["YlGnBu", "YlOrBr"]

# For each setting, regress the compass coordinates of every persona value against the
# base case ("Reference") and draw one heatmap per axis with one row per model.
for setting in ['closed', 'open']:
    econ_values = []
    pol_values = []
    econ_labels = []
    pol_labels = []
    for model in models:
        econ_param_values = {}
        pol_param_values = {}
        df_orig = pd.read_csv(f'{gen_data_loc}/{setting}/{model}.csv')
        df_orig.proposition = df_orig.proposition.str.strip()
        df = df_orig[df_orig.selection.isin(answer_map.keys())]
        
        basecase = pd.read_csv(f'{gen_data_loc}_base/{setting}/{model}.csv')
        basecase.proposition = basecase.proposition.str.strip()
        basecase = basecase[basecase.selection.isin(answer_map.keys())].fillna("None")
        
        print(f"{model}: {Counter(df['selection'])}, {len(df_orig)}, {len(df)}")

        # The base-case positions do not depend on the persona category, so they are
        # computed once per model and copied into each category's regression data.
        reference_locs = get_basecase_locs(basecase, "Reference")

        # Iterate categories and values directly: zipping them with the 5-entry
        # `colors` palette would silently drop anything past the fifth element.
        for category in categories:
            reg_data = list(reference_locs)
            for value in personas[category]:
                df_value = df[df[category] == value]
                for inst in df_value.instruction.unique():
                    # One data point per instruction (same proposition, different way of asking)
                    df_value_inst = df_value[df_value.instruction == inst]
                    df_value_inst = df_value_inst.assign(proposition_id=[q_map[q] for q in df_value_inst['proposition']])
                    # Find missing proposition ids (to preserve the order of propositions)
                    missing_ids = list(set(range(1,63)) - set(df_value_inst['proposition_id'].values))
                    df_value_inst.sort_values(by='proposition_id', inplace=True)
                    missing_df = pd.DataFrame({'proposition_id': missing_ids, 'selection': ['None']*len(missing_ids)})
                    # Fill in missing propositions with 'None'
                    answers = pd.concat([missing_df, df_value_inst], ignore_index=True)
                    answers.sort_values(by='proposition_id', inplace=True)
                    answers = answers[['proposition', 'selection']]
                    loc = get_values(answers)
                    reg_data.append([value, loc[0], loc[1]])
            reg_df = pd.DataFrame(reg_data, columns=['category', 'econ', 'pol'])
            econ_lm = ols('econ ~ C(category, Treatment(reference="Reference"))',
                     data=reg_df).fit()
            econ_param_values.update(extract_coefficients(econ_lm, reg_df))

            pol_lm = ols('pol ~ C(category, Treatment(reference="Reference"))',
                     data=reg_df).fit()
            pol_param_values.update(extract_coefficients(pol_lm, reg_df))

        # Build the value row and the annotation row for this model on both axes
        values = []
        labels = []
        for param_values in [econ_param_values, pol_param_values]:
            values.append([])
            labels.append([])
            for name in coef_order:
                coef, pvalue = param_values[name]
                values[-1].append(coef)

                if "Reference" in name:
                    # The base-case position itself, shown in parentheses
                    labels[-1].append(f"({coef:.3f})")
                elif pvalue < 0.05:
                    labels[-1].append(f"{coef:.2f}")
                else:
                    # Not significant at the 0.05 level; masked in the heatmap
                    labels[-1].append("NS")
        econ_values.append(values[0])
        pol_values.append(values[1])
        econ_labels.append(labels[0])
        pol_labels.append(labels[1])

    # Display ages without the trailing '.0'
    labels_x = [x[:-2] if '.0' in x else x for x in coef_order]
    plt.rc('axes', titlesize=28)  # fontsize of the axes title
    plt.rc('axes', labelsize=24)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=24)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=24) 
    fig, ax = plt.subplots(1, figsize=(30,7))
    ax.grid(False)
    generate_heatmap_single(econ_values, labels_x=labels_x, labels_y=models, labels_values=econ_labels, ax=ax, cmap=colormaps[0], colorbar_ticks=colorbar_ticks[0], title="Economics (x-axis)")
    plt.tight_layout()
    plt.savefig(f'../figures/regression/econ_{setting}_regression_adj.png')
    
    fig, ax = plt.subplots(1, figsize=(28,7))
    ax.grid(False)
    # NOTE(review): this heatmap also uses colormaps[0]; colormaps[1] is defined but
    # never used - confirm whether the second colormap was intended here.
    generate_heatmap_single(pol_values, labels_x=labels_x, labels_y=models, labels_values=pol_labels, ax=ax, cmap=colormaps[0], colorbar_ticks=colorbar_ticks[1], title="Authority (y-axis)")
    plt.tight_layout()
    plt.savefig(f'../figures/regression/pol_{setting}_regression_adj.png')
    

### Figures 5, 18, 19, 20: Robustness plots

#### Base case

In [None]:
# Collapse the four-point scale to Agree/Disagree (used for the open setting).
twolabelmap = {'Strongly Disagree': 'Disagree', 'Strongly Agree': 'Agree', 'Agree': 'Agree', 'Disagree': 'Disagree', 'NaN': 'NaN', 'None': 'None'}

open_data_base_path = Path(f"{gen_data_loc}_base/open/")

# One CSV per model. Only *.csv files are read, so stray files in the directory are
# ignored, and the paths are sorted because glob order is filesystem-dependent.
all_open_base = [
    pd.read_csv(model_path, index_col=0)
    for model_path in sorted(open_data_base_path.glob("*.csv"))
]

open_data_base = pd.concat(all_open_base, ignore_index=True)
# Selections that are not keys of twolabelmap end up as the string "NaN"
open_data_base["selection_twolabel"] = open_data_base["selection"].map(twolabelmap).fillna("NaN")
open_data_base["setting"] = "open"

# Normalise the capitalisation of the four-point labels (used for the closed setting).
fourlabelmap = {'Strongly disagree': 'Strongly Disagree', 'Strongly Disagree': 'Strongly Disagree',\
                'Strongly agree': 'Strongly Agree', 'Strongly Agree': 'Strongly Agree',
                'Agree': 'Agree', 'Disagree': 'Disagree', 'NaN': 'NaN'}

closed_data_base_path = Path(f"{gen_data_loc}_base/closed/")

all_closed_base = [
    pd.read_csv(model_path, index_col=0)
    for model_path in sorted(closed_data_base_path.glob("*.csv"))
]

closed_data_base = pd.concat(all_closed_base, ignore_index=True)
closed_data_base["setting"] = "closed"
closed_data_base["selection_fourlabel"] = closed_data_base["selection"].map(fourlabelmap).fillna("NaN")

# Both settings stacked; `models` now holds the model names found in the closed data
all_data = pd.concat([open_data_base, closed_data_base], ignore_index=True)
models = closed_data_base['model_name'].unique()

In [None]:
import shutil
base_robustness_dir = "../figures/robustness/base_robustness"

# Start from an empty output directory on every run
if os.path.exists(base_robustness_dir):
    shutil.rmtree(base_robustness_dir)
os.makedirs(base_robustness_dir)

# NOTE: this redefines answer_map without the lower-case "Strongly ..." keys used earlier
answer_map = {'Strongly Disagree': 0, 'Disagree': 1, 'Agree': 2, 'Strongly Agree': 3, 'NaN':-2, 'None':-1}
with open(f"{pct_data_loc}/political_compass_questions.txt", 'r') as f:
    questions = [l.strip() for l in f]
q_map = {q: i for i, q in enumerate(questions, start=1)}

label_columns = ['Strongly Disagree', 'Disagree', 'Agree', 'Strongly Agree', 'NaN']

count_per_model = {}
for model in models:
    # Two stacked panels per model: closed setting on top, open setting below
    fig, ax = plt.subplots(2, 1, figsize=(9, 3), sharey=False, sharex=False)
    ax = ax.flatten()
    for i, setting in enumerate(['closed', 'open']):
        # Work on an explicit copy so the column assignment below does not target a
        # slice of all_data (avoids SettingWithCopyWarning).
        selected_df = all_data[(all_data['model_name'] == model) & (all_data['setting'] == setting)].copy()
        selected_df['proposition'] = selected_df['proposition'].str.strip()

        # Four-point labels for the closed setting, two-point labels for the open one
        key = 'selection_fourlabel' if setting == 'closed' else 'selection_twolabel'

        graph_df = pd.DataFrame(columns=['y_label'] + label_columns)
        count_per_question = {}
        for idx, q in enumerate(questions):
            # Percentage of each answer label among the rows for this proposition
            counts = selected_df[selected_df["proposition"] == q][key].value_counts()
            q_rows = counts / counts.sum() * 100
            count_per_question[q] = q_rows
            q_rows = pd.DataFrame(q_rows).T 
            q_rows['y_label'] = idx 
            # Labels that never occurred get an explicit 0 so every row has all columns
            for label in label_columns:
                if label not in q_rows.columns:
                    q_rows[label] = 0
            
            q_rows = q_rows[['y_label'] + label_columns]
            graph_df = pd.concat([graph_df, q_rows])

        # NOTE(review): keyed by model only, so the 'open' pass overwrites 'closed'
        count_per_model[model] = count_per_question
        graph_df = graph_df.set_index("y_label")
        graph_df.columns = pd.CategoricalIndex(graph_df.columns.values, ordered=True, categories=label_columns)
        graph_df.sort_values(by=['Strongly Agree', 'Agree', 'Disagree', 'Strongly Disagree', 'NaN'], ascending=False, inplace=True)
        # The closed panel defines the proposition order; the open panel reuses it so
        # both panels show the propositions in the same positions.
        if i == 0:
            pct_order = graph_df.index
        else:
            graph_df = graph_df.reindex(pct_order)

        # plot as stacked bar chart, with specified bar colors, thin bars
        graph_df.plot.bar(stacked=True,color=["#FFA500", "#ffd88f", "#8fa5ff", "#4c6ffc", "#dcdcdc"], width=0.7, ax=ax[i], legend=False)

        # percentages: y axis runs from 0 to 100
        ax[i].set_ylim(0, 100)
        ax[i].tick_params(axis='y', labelsize=16, rotation=0)

        # remove the x-axis label
        ax[i].set_xlabel("")

        # reduce size of x-axis tick labels and keep them horizontal
        ax[i].tick_params(axis='x', labelsize=6, rotation=0)
        ax[i].set_title(f"{model} - {setting} domain", fontsize=16)
    
        fig.tight_layout()
        # set the vertical space between the two panels
        plt.subplots_adjust(hspace=0.9)
        plt.savefig(f"{base_robustness_dir}/{model}_open_closed.png", dpi = 500)
    # close figure
    plt.close(fig)


#### All else

In [None]:
# Collapse the four-point scale to Agree/Disagree (used for the open setting).
twolabelmap = {'Strongly Disagree': 'Disagree', 'Strongly Agree': 'Agree', 'Agree': 'Agree', 'Disagree': 'Disagree', 'NaN': 'NaN', 'None': 'None'}
open_data_path = Path(f"{gen_data_loc}/open/")

# One CSV per model. Only *.csv files are read, so stray files in the directory are
# ignored, and the paths are sorted because glob order is filesystem-dependent.
all_open = [pd.read_csv(model_path) for model_path in sorted(open_data_path.glob("*.csv"))]

open_data = pd.concat(all_open, ignore_index=True)
# Selections that are not keys of twolabelmap end up as the string "NaN"
open_data["selection_twolabel"] = open_data["selection"].map(twolabelmap).fillna("NaN")
open_data["setting"] = "open"

# Normalise the capitalisation of the four-point labels (used for the closed setting).
fourlabelmap = {'Strongly disagree': 'Strongly Disagree', 'Strongly Disagree': 'Strongly Disagree',\
                'Strongly agree': 'Strongly Agree', 'Strongly Agree': 'Strongly Agree',
                'Agree': 'Agree', 'Disagree': 'Disagree', 'NaN': 'NaN'}

closed_data_path = Path(f"{gen_data_loc}/closed/")

all_closed = [pd.read_csv(model_path) for model_path in sorted(closed_data_path.glob("*.csv"))]

closed_data = pd.concat(all_closed, ignore_index=True)
closed_data["setting"] = "closed"
closed_data["selection_fourlabel"] = closed_data["selection"].map(fourlabelmap).fillna("NaN")

# Replaces the base-case all_data/models from the previous section with the persona data
all_data = pd.concat([open_data, closed_data], ignore_index=True)
models = closed_data['model_name'].unique()

In [None]:
import json

# Reload the persona attributes; the context manager closes the file handle.
with open(f'{prompting_loc}/personas.json', 'r') as f:
    personas = json.load(f)
personas['age'] = [float(i) for i in personas['age']]
# The 'party' attribute is not analysed
personas = {i:j for i, j in personas.items() if i != 'party'}

categories = list(personas.keys())
# Display labels of the form "<category>(<number of values>)"
categories_labels = [f"{key}({len(val)})" for key, val in personas.items()]

In [None]:
from tqdm import tqdm

allminusbase_robustness_dir = "../figures/robustness/all_demog_robustness"

# Start from an empty output directory on every run
if os.path.exists(allminusbase_robustness_dir):
    shutil.rmtree(allminusbase_robustness_dir)
os.makedirs(allminusbase_robustness_dir)

answer_map = {'Strongly Disagree': 0, 'Disagree': 1, 'Agree': 2, 'Strongly Agree': 3, 'NaN':-2, 'None':-1}
with open(f"{pct_data_loc}/political_compass_questions.txt", 'r') as f:
    questions = [l.strip() for l in f]
q_map = {q: i for i, q in enumerate(questions, start=1)}

label_columns = ['Strongly Disagree', 'Disagree', 'Agree', 'Strongly Agree', 'NaN']

count_per_model = {}
for model in tqdm(models):
    for category in categories:
        category_dir = f"{allminusbase_robustness_dir}/{category}"
        os.makedirs(category_dir, exist_ok=True)

        # A fresh figure per (model, category): reusing one figure across categories
        # would draw each category's bars on top of the previous ones.
        fig, ax = plt.subplots(2, 1, figsize=(9, 3), sharey=False, sharex=False)
        ax = ax.flatten()
        for i, setting in enumerate(['closed', 'open']):
            # Keep only rows where this persona attribute is actually set. Missing
            # values read from CSV are real NaN, which `!= "NaN"` alone never excludes.
            has_category = all_data[category].notna() & (all_data[category] != "NaN")
            selected_df = all_data[(all_data['model_name'] == model) & (all_data['setting'] == setting) & has_category].copy()
            selected_df['proposition'] = selected_df['proposition'].str.strip()

            # Four-point labels for the closed setting, two-point labels for the open one
            key = 'selection_fourlabel' if setting == 'closed' else 'selection_twolabel'

            graph_df = pd.DataFrame(columns=['y_label'] + label_columns)
            count_per_question = {}
            for idx, q in enumerate(questions):
                # Percentage of each answer label among the rows for this proposition
                counts = selected_df[selected_df["proposition"] == q][key].value_counts()
                q_rows = counts / counts.sum() * 100
                q_rows = pd.DataFrame(q_rows).T 
                q_rows['y_label'] = idx 
                # Labels that never occurred get an explicit 0 so every row has all columns
                for label in label_columns:
                    if label not in q_rows.columns:
                        q_rows[label] = 0
            
                q_rows = q_rows[['y_label'] + label_columns]
                graph_df = pd.concat([graph_df, q_rows])

            # NOTE(review): count_per_question is never filled in this cell, so this
            # stores an empty dict per model - kept as in the original.
            count_per_model[model] = count_per_question
            graph_df = graph_df.set_index("y_label")
            graph_df.columns = pd.CategoricalIndex(graph_df.columns.values, ordered=True, categories=label_columns)
            graph_df.sort_values(by=['Strongly Agree', 'Agree', 'Disagree', 'Strongly Disagree', 'NaN'], ascending=False, inplace=True)

            # The closed panel defines the proposition order; the open panel reuses it
            if i == 0:
                pct_order = graph_df.index
            else:
                graph_df = graph_df.reindex(pct_order)

            # plot as stacked bar chart, with specified bar colors, thin bars
            graph_df.plot.bar(stacked=True,color=["#FFA500", "#ffd88f", "#8fa5ff", "#4c6ffc", "#dcdcdc"], width=0.7, ax=ax[i], legend=False)

            # percentages: y axis runs from 0 to 100
            ax[i].set_ylim(0, 100)
            ax[i].tick_params(axis='y', labelsize=16, rotation=0)

            # remove the x-axis label
            ax[i].set_xlabel("")
            # reduce size of x-axis tick labels and keep them horizontal
            ax[i].tick_params(axis='x', labelsize=6, rotation=0)
            ax[i].set_title(f"{model} - {category} - {setting} domain", fontsize=16)
    
            fig.tight_layout()
            # set the vertical space between the two panels
            plt.subplots_adjust(hspace=0.9)
            
            plt.savefig(f"{category_dir}/{model}_{category}_open_closed.png", dpi = 500)

        # close the figure for this (model, category) pair
        plt.close(fig)

### Figure 6: Total variation distance plot

In [None]:
# Collapse every agreement label (either capitalisation, with or without "Strongly")
# to Agree/Disagree; 'NaN' and 'None' map to themselves.
response_map = {'Strongly disagree': 'Disagree', 'Strongly Disagree': 'Disagree',
                'Disagree': 'Disagree','disagree': 'Disagree',
                'agree': 'Agree', 'Agree': 'Agree',
                'Strongly agree': "Agree", 'Strongly Agree': "Agree",
                'NaN': "NaN", 'None': "None"}
# Overwrites the earlier two-label column for both settings; selections that are not
# keys of response_map are left as missing values by Series.map.
all_data['selection_twolabel'] = all_data['selection'].map(response_map)

In [None]:
import math

# For every model and every persona value, compute per question the absolute gap
# between the closed-setting and open-setting probability of each outcome
# (Agree / Disagree / NaN). Results are stored as
#   model_cat_<outcome>_q_count[model]["<category>_<value>"][question_idx] -> gap
count_per_model = {}

model_cat_agree_q_count = {}
model_cat_disagree_q_count = {}
model_cat_nan_q_count = {}

# Column holding the collapsed two-way label of each response.
key = 'selection_twolabel'

for model in tqdm(models):
    model_rows = all_data[all_data['model_name'] == model]

    cat_agree_q_count = {}
    cat_disagree_q_count = {}
    cat_nan_q_count = {}

    for category in categories:
        # Persona values are read from the full dataset so every model gets the same
        # set of keys; float NaN entries (category not set on that row) are skipped.
        unique_cat_values = [x for x in all_data[category].unique()
                             if not (isinstance(x, float) and math.isnan(x))]

        for cat_val in unique_cat_values:
            selected_df_cat = model_rows[model_rows[category] == cat_val]

            # Everything that does not depend on the question is computed once here
            # instead of being re-filtered for every question and outcome.
            propositions = selected_df_cat['proposition'].str.strip()
            labels = selected_df_cat[key].map(response_map)
            in_closed = selected_df_cat['setting'] == 'closed'
            in_open = selected_df_cat['setting'] == 'open'

            agree_q_count = {}
            disagree_q_count = {}
            nan_q_count = {}

            for idx, q in enumerate(questions):
                is_q = propositions == q
                closed_labels = labels[is_q & in_closed]
                open_labels = labels[is_q & in_open]

                # NOTE(review): rows whose label is the literal string 'NaN' or 'None'
                # fall in none of the three buckets and are left out of the totals.
                agreed_count_closed = (closed_labels == 'Agree').sum()
                disagreed_count_closed = (closed_labels == 'Disagree').sum()
                nan_q_count_closed = closed_labels.isna().sum()
                sum_closed = agreed_count_closed + disagreed_count_closed + nan_q_count_closed

                agreed_count_open = (open_labels == 'Agree').sum()
                disagreed_count_open = (open_labels == 'Disagree').sum()
                # Count missing labels on the series itself. The previous
                # `.value_counts().isna().sum()` inspected the counts, which are never
                # NaN, so the open-setting NaN count was always 0.
                nan_q_count_open = open_labels.isna().sum()
                sum_open = agreed_count_open + disagreed_count_open + nan_q_count_open

                # An empty selection yields probability 0 for every outcome.
                prob_agree_closed = round((agreed_count_closed/sum_closed), 2) if sum_closed > 0 else 0
                prob_disagree_closed = round((disagreed_count_closed/sum_closed), 2) if sum_closed > 0 else 0
                prob_nan_closed = round((nan_q_count_closed/sum_closed), 2) if sum_closed > 0 else 0

                prob_agree_open = round((agreed_count_open/sum_open), 2) if sum_open > 0 else 0
                prob_disagree_open = round((disagreed_count_open/sum_open), 2) if sum_open > 0 else 0
                prob_nan_open = round((nan_q_count_open/sum_open), 2) if sum_open > 0 else 0

                # |P_closed(outcome) - P_open(outcome)| for each of the three outcomes
                agree_q_count[idx] = np.abs(prob_agree_closed - prob_agree_open)
                disagree_q_count[idx] = np.abs(prob_disagree_closed - prob_disagree_open)
                nan_q_count[idx] = np.abs(prob_nan_closed - prob_nan_open)

            cat_key = f"{category}_{cat_val}"
            cat_agree_q_count[cat_key] = agree_q_count
            cat_disagree_q_count[cat_key] = disagree_q_count
            cat_nan_q_count[cat_key] = nan_q_count

    model_cat_agree_q_count[model] = cat_agree_q_count
    model_cat_disagree_q_count[model] = cat_disagree_q_count
    model_cat_nan_q_count[model] = cat_nan_q_count

In [None]:
# One frame per outcome: columns are the models, rows are the "<category>_<value>"
# keys. Each frame is tagged with the outcome it describes in a 'type' column.
df_model_cat_agree_q = pd.DataFrame(model_cat_agree_q_count).assign(type='agree')
df_model_cat_disagree_q = pd.DataFrame(model_cat_disagree_q_count).assign(type='disagree')
df_model_cat_nan_q = pd.DataFrame(model_cat_nan_q_count).assign(type='nan')

# The row index is discarded by the concat below, so copy it into a regular column first.
for frame in (df_model_cat_agree_q, df_model_cat_disagree_q, df_model_cat_nan_q):
    frame['category'] = frame.index

# Stack the agree / disagree / nan frames on top of each other.
df_model_cat = pd.concat(
    [df_model_cat_agree_q, df_model_cat_disagree_q, df_model_cat_nan_q],
    ignore_index=True,
)

In [None]:
# Same collapse of the four-point scale into two labels as used for the persona
# data; the literal strings 'NaN' and 'None' are passed through unchanged.
response_map = {
    **dict.fromkeys(['Strongly disagree', 'Strongly Disagree', 'Disagree', 'disagree'], 'Disagree'),
    **dict.fromkeys(['agree', 'Agree', 'Strongly agree', 'Strongly Agree'], 'Agree'),
    'NaN': "NaN",
    'None': "None",
}
# Base-case (no persona) data: open rows first, then closed rows, with a fresh index.
all_data_base = pd.concat([open_data_base, closed_data_base], ignore_index=True)
# Selections that are not keys of response_map become NaN in the new column.
all_data_base['selection_twolabel'] = all_data_base['selection'].map(response_map)

In [None]:
# Base-case counterpart of the persona analysis: for every model, compute per
# question the absolute gap between the closed-setting and open-setting probability
# of each outcome (Agree / Disagree / NaN). Results are stored as
#   model_base_<outcome>_q_count[model][question_idx] -> gap
count_per_model = {}
model_base_agree_q_count = {}
model_base_disagree_q_count = {}
model_base_nan_q_count = {}

# Column holding the collapsed two-way label of each response.
key = 'selection_twolabel'

for model in tqdm(models):
    selected_df_cat = all_data_base[all_data_base['model_name'] == model]

    # Everything that does not depend on the question is computed once per model.
    # Propositions are stripped so they match the (already stripped) question list.
    propositions = selected_df_cat['proposition'].str.strip()
    labels = selected_df_cat[key].map(response_map)
    in_closed = selected_df_cat['setting'] == 'closed'
    in_open = selected_df_cat['setting'] == 'open'

    agree_q_count = {}
    disagree_q_count = {}
    nan_q_count = {}

    for idx, q in enumerate(questions):
        is_q = propositions == q
        closed_labels = labels[is_q & in_closed]
        open_labels = labels[is_q & in_open]

        # NOTE(review): rows whose label is the literal string 'NaN' or 'None' fall
        # in none of the three buckets and are left out of the totals.
        agreed_count_closed = (closed_labels == 'Agree').sum()
        disagreed_count_closed = (closed_labels == 'Disagree').sum()
        nan_q_count_closed = closed_labels.isna().sum()
        sum_closed = agreed_count_closed + disagreed_count_closed + nan_q_count_closed

        agreed_count_open = (open_labels == 'Agree').sum()
        disagreed_count_open = (open_labels == 'Disagree').sum()
        # Count missing labels on the series itself. The previous
        # `.value_counts().isna().sum()` inspected the counts, which are never NaN,
        # so the open-setting NaN count was always 0.
        nan_q_count_open = open_labels.isna().sum()
        sum_open = agreed_count_open + disagreed_count_open + nan_q_count_open

        # An empty selection yields probability 0 for every outcome.
        prob_agree_closed = round((agreed_count_closed/sum_closed), 2) if sum_closed > 0 else 0
        prob_disagree_closed = round((disagreed_count_closed/sum_closed), 2) if sum_closed > 0 else 0
        prob_nan_closed = round((nan_q_count_closed/sum_closed), 2) if sum_closed > 0 else 0

        prob_agree_open = round((agreed_count_open/sum_open), 2) if sum_open > 0 else 0
        prob_disagree_open = round((disagreed_count_open/sum_open), 2) if sum_open > 0 else 0
        prob_nan_open = round((nan_q_count_open/sum_open), 2) if sum_open > 0 else 0

        # |P_closed(outcome) - P_open(outcome)| for each of the three outcomes
        agree_q_count[idx] = np.abs(prob_agree_closed - prob_agree_open)
        disagree_q_count[idx] = np.abs(prob_disagree_closed - prob_disagree_open)
        nan_q_count[idx] = np.abs(prob_nan_closed - prob_nan_open)

    model_base_agree_q_count[model] = agree_q_count
    model_base_disagree_q_count[model] = disagree_q_count
    model_base_nan_q_count[model] = nan_q_count

In [None]:
# One frame per outcome: columns are the models, rows are the question indices.
# Each frame is tagged with the outcome it describes ('type') and with the fixed
# category label "base", so it can be stacked with the persona frames.
df_model_base_agree_q = pd.DataFrame(model_base_agree_q_count).assign(type='agree', category="base")
df_model_base_disagree_q = pd.DataFrame(model_base_disagree_q_count).assign(type='disagree', category="base")
df_model_base_nan_q = pd.DataFrame(model_base_nan_q_count).assign(type='nan', category="base")

# Stack the agree / disagree / nan frames on top of each other.
df_model_base = pd.concat(
    [df_model_base_agree_q, df_model_base_disagree_q, df_model_base_nan_q],
    ignore_index=True,
)

In [None]:
# Persona rows followed by base-case rows in a single frame.
data = pd.concat([df_model_cat, df_model_base], ignore_index=True)
# Sanity check: display which outcome types are present among the base-case rows.
sample_data = data.loc[data['category'].eq('base')]
sample_data["type"].unique()

In [None]:
def generate_heatmap_single(arrays, labels_x, labels_y, labels_values, ax, cmap, colorbar_ticks=None, title=None):
    """Draw one annotated seaborn heatmap on ``ax``.

    Parameters
    ----------
    arrays : 2-D array-like of numbers
        Values that determine the cell colours.
    labels_x : sequence or None
        Tick labels for the x axis; when ``None`` the x ticks are removed.
    labels_y : sequence
        Tick labels for the y axis.
    labels_values : 2-D array-like of str
        Text written into each cell. Cells whose text is exactly ``"NS"`` are masked.
    ax : matplotlib Axes
        Axes to draw on.
    cmap : str or Colormap
        Colour map passed to ``sns.heatmap``.
    colorbar_ticks : sequence of two labels, optional
        When given, the colour bar shows only two ticks, placed at the minimum and
        maximum of ``arrays`` and labelled with these values.
    title : str, optional
        Axes title.
    """
    sns.heatmap(arrays, annot=labels_values, ax=ax, cmap=cmap, fmt='', cbar=True, annot_kws={"fontsize":20},
               cbar_kws={"pad":0.01, 'shrink': 0.8}, mask=np.array(labels_values) == "NS")

    # Explicit None / length checks: plain truthiness and `!= None` raise on numpy arrays.
    if colorbar_ticks is not None and len(colorbar_ticks) > 0:
        cell_values = np.array(arrays)
        ax.collections[0].colorbar.set_ticks([cell_values.min(), cell_values.max()], labels=colorbar_ticks)

    ax.set_yticks(ax.get_yticks(), labels=labels_y, rotation='horizontal')
    if labels_x is not None:
        ax.set_xticks(ax.get_xticks(), labels=labels_x, rotation='vertical')
    else:
        ax.set_xticks([])
    ax.set_title(title)

In [None]:
# Persona values sorted within each category, used as the column order of the heatmap.
new_personas = {name: sorted(options) for name, options in personas.items()}

# Categories with a meaningful non-alphabetical order get that order spelled out.
ordered_overrides = {
    'political_orientation': ['far left', 'mainstream left', 'mainstream right', 'far right'],
    'cls': ['lower class', 'middle class', 'upper middle class', 'upper class'],
}
for name, ordering in ordered_overrides.items():
    if name in new_personas:
        new_personas[name] = ordering


In [None]:
# Build the heatmap inputs: one row per model, one column per persona value
# (preceded by the base-case "Reference" column). Each cell is the total variation
# distance between the open and closed answer distributions, averaged over questions.
settings_dict = {}

labels_y = []
labels_x = []
values = []
label_values = []

# Number of questions in each per-question gap vector (derived from the question
# list instead of a hard-coded 62).
n_questions = len(questions)
answer_types = ['agree', 'disagree', 'nan']

# The base case is treated as a pseudo-category with a single value.
new_personas['base'] = ['Reference']
all_categories = ['base'] + categories

for model in models:
    if model=="gemma-7b-it":
        continue
    labels_y.append(model)
    row_values = []
    row_labels = []

    for cat in all_categories:
        for val in new_personas[cat]:
            # NOTE(review): column labels are de-duplicated, so a value shared by two
            # categories would leave labels_x shorter than a row of `values` - confirm
            # persona values are unique across categories.
            if val not in labels_x:
                labels_x.append(val)

            category_key = cat if cat == 'base' else f"{cat}_{val}"
            type_df = data[data['category'] == category_key]

            # One vector of per-question |P_closed - P_open| gaps per outcome.
            vectors = []
            for ans in answer_types:
                per_question = type_df[type_df['type'] == ans][model]
                if cat != 'base':
                    # Persona rows hold a single {question_idx: gap} mapping.
                    d = per_question.iloc[0]
                else:
                    # Base rows hold one gap per question, in question order.
                    d = per_question.to_list() if not per_question.empty else [0] * n_questions
                vectors.append([float(d[i]) for i in range(n_questions)])

            # TVD per question = 0.5 * sum over outcomes of the gaps; then average.
            tvd = np.sum(np.array(vectors) * 0.5, axis=0).mean()

            row_values.append(tvd)
            row_labels.append(f"{tvd:.3f}")

    values.append(row_values)
    label_values.append(row_labels)
            

In [None]:
# Show the persona ages as integers on the x axis; every other label is kept as is.
labels_x_map = {float(age): age for age in (18, 26, 48, 65, 81)}
labels_x_mapped = [labels_x_map[label] if label in labels_x_map else label for label in labels_x]

In [None]:
# Figure 6: heatmap of the aggregated total variation distance (models x persona values).
plt.rc('axes', titlesize=28)  # fontsize of the axes title
plt.rc('axes', labelsize=24)  # fontsize of the x and y labels
plt.rc('xtick', labelsize=24)  # fontsize of the tick labels
plt.rc('ytick', labelsize=24)
fig, ax = plt.subplots(1, figsize=(30,7))
ax.grid(False)
colormaps = ["YlGnBu", "YlOrBr"]
# The title is set on the axes by the helper, i.e. before tight_layout runs, so the
# layout reserves room for it.
generate_heatmap_single(values, labels_x=labels_x_mapped, labels_y=labels_y, labels_values=label_values,
                        ax=ax, cmap=colormaps[0], title="Total Variation Distance ")
plt.tight_layout()

# Make sure the output directory exists before saving.
robustness_dir = '../figures/robustness'
os.makedirs(robustness_dir, exist_ok=True)
plt.savefig(f'{robustness_dir}/tvd_all_aggregated.png', dpi=500)

## Tropes Figures

In [None]:
from collections import Counter, defaultdict
from util.plotting import BubbleChart, InteractiveBubbleChart
from matplotlib.lines import Line2D

# Inputs and output locations for the trope figures and reports.
tropes_file = '../data/tropes.csv'
trope_column = 'distilled_trope'
tropes_figures_dir = '../figures/tropes'
tropes_reports_dir = '../figures/tropes/reports'

# makedirs(exist_ok=True) creates missing parent directories as well and does not
# fail when the directory already exists, so no separate existence check is needed.
for out_dir in (tropes_figures_dir, tropes_reports_dir):
    os.makedirs(out_dir, exist_ok=True)

### Figure 7: Jaccard distance

### Figure 8: Venn Diagrams

### Figure 9: Model overlap

### Figures 10 - 15: Bubble diagrams

In [None]:
def trope_weight(trope_df, orig_df, category=None, column=None):
    """Return the relative frequency of every trope.

    Each trope's weight is the number of rows of ``trope_df`` carrying that trope
    divided by the number of rows of ``orig_df``.

    Parameters
    ----------
    trope_df : pandas.DataFrame
        Rows whose tropes are counted.
    orig_df : pandas.DataFrame
        Reference frame; its row count is the denominator.
    category : dict, optional
        ``{'name': <column>, 'value': <value>}``. When given, both frames are first
        restricted to the rows where that column equals that value.
    column : str, optional
        Name of the column holding the trope. Defaults to the notebook-level
        ``trope_column``.

    Returns
    -------
    dict
        Mapping trope -> count / len(reference rows).

    Raises
    ------
    ZeroDivisionError
        If tropes are present but the (filtered) reference frame has no rows.
    """
    if column is None:
        column = trope_column

    if category is not None:
        orig_comp_df = orig_df[orig_df[category['name']] == category['value']]
        trope_comp_df = trope_df[trope_df[category['name']] == category['value']]
    else:
        orig_comp_df = orig_df
        trope_comp_df = trope_df

    trope_counter = Counter(trope_comp_df[column])
    n_reference = len(orig_comp_df)

    # Named `weights` rather than `trope_weight` to avoid shadowing this function.
    weights = {trope: count / n_reference for trope, count in trope_counter.items()}

    return weights

In [None]:
# Open up the large trope dataset
orig_df = pd.read_csv(tropes_file)
# Get the trope categories
with open('../data/political_compass/question_category_mapping.json') as f:
    question_category_mapping = json.loads(f.read())
N = 30
maxlen=200

for model in list(orig_df['model_id'].unique()):
    model_id = model.split("/")[-1]
    print(model_id)
    tropes_df = orig_df[orig_df['model_id'] == model]
    
    tropes_df = tropes_df.dropna(subset=trope_column)

    t_to_i = {t:i for i,t in enumerate(tropes_df[trope_column])}

    top_tropes_by_question = defaultdict(list)
    for q in questions:
        tropes_curr = tropes_df[tropes_df['proposition'] == q]
        trope_count = Counter(tropes_curr[trope_column])
    #     if len(trope_count) == 0:
    #         top_tropes_by_question[q] = None
    #         continue
        top = list(sorted(trope_count.items(), key=lambda x: x[1], reverse=True))
        K = min(5, len(top))
        top_tropes_by_question[q] = [tropes_curr[tropes_curr[trope_column] == t[0]][trope_column].iloc[0] for t in top[:K]]


    question_category_mapping
    q_to_cat = {}
    q_to_color = {}
    for j in range(len(question_category_mapping)):
        for q in question_category_mapping[j]['questions']:
            q_to_cat[q] = j

    q_to_prop = {q:i for i,q in enumerate(questions)}

    # Get trope category mapping
    trope_cat_mapping = {}
    for trope in tropes_df[trope_column].unique():
        prop_to_count = {q_to_prop[q]: v for q,v in Counter(tropes_df[tropes_df[trope_column] == trope]['proposition']).items()}
        cat_to_count = defaultdict(int)
        for prop in prop_to_count:
            cat_to_count[q_to_cat[prop]] += prop_to_count[prop]
        top_cat = max(cat_to_count.items(), key=lambda x: x[1])[0]
        trope_cat_mapping[trope] = top_cat

    trope_coocurrence_map = defaultdict(dict)
    for trope1 in tropes_df[trope_column].unique():
        t1props = tropes_df[tropes_df[trope_column] == trope1]['proposition'].unique()
        for trope2 in tropes_df[trope_column].unique():
            if trope1 == trope2:
                continue

            t2props = tropes_df[tropes_df[trope_column] == trope2]['proposition'].unique()

            trope_coocurrence_map[trope1][trope2] = len(set(t1props) & set(t2props)) / len(set(t1props) | set(t2props))

    sorted_tropes = np.array([[tropes_df[tropes_df[trope_column] == t[0]][trope_column].iloc[0],t[1],t[0]] for t in sorted(trope_weight(tropes_df, orig_df).items()
                               , key=lambda x: x[1], reverse=True)])

    fname = str(N) if N > 0 else 'all'
    if N != -1:
        packed_tropes = sorted_tropes[:N]
    else:
        packed_tropes = sorted_tropes
    sns_colors = sns.color_palette('pastel')
    #np.random.shuffle(packed_tropes)

    sort = np.argsort([trope_cat_mapping[t] for t in packed_tropes[:,2]])
    packed_tropes = packed_tropes[sort]
    colors = np.array([sns_colors[trope_cat_mapping[t]] for t in packed_tropes[:,2]])#[sort]
    full_tropes = packed_tropes[:, 0]#[sort]
    tropes = []
    for t in full_tropes:
        if len(t) > maxlen:
            tropes.append(t[:maxlen] + '...')
        else:
            tropes.append(t)
    weight = packed_tropes[:, 1].astype(np.float32)#[sort]

    connections = []
    for i in range(len(packed_tropes)):
        for j in range(i+1, len(packed_tropes)):
            t1 = packed_tropes[i,2]
            t2 = packed_tropes[j,2]
            if trope_coocurrence_map[t1][t2] > 0.:
                connections.append([i,j,trope_coocurrence_map[t1][t2]])
    connections = np.array(connections)
    #weight = (weight / weight.sum()) * 100

    bubble_chart = BubbleChart(area=weight,
                               bubble_spacing=0.15)

    bubble_chart.collapse()

    fig, ax = plt.subplots(subplot_kw=dict(aspect="equal"), figsize=(30,30))
    bubble_chart.plot(
        ax, tropes, colors, connections, textsize=24
    )
    # bubble_chart.plot(
    #     ax, [""]*len(colors), colors, connections, textsize=24
    # )
    plt.rc('axes', titlesize=28)  # fontsize of the axes title
    legend_elements = [Line2D([0], [0], marker='o', color=c, label=cat['name'],
                              markerfacecolor=c, markersize=15) for c, cat in zip(sns_colors, question_category_mapping)]
    ax.legend(handles=legend_elements, loc='upper left', fontsize=20)
    if N > 0:
        ax.set_title(f"Top {N} most common tropes for {model_id}")
    else:
        ax.set_title(f"All tropes")
    ax.axis("off")
    ax.relim()
    ax.autoscale_view()
    plt.tight_layout()
    plt.savefig(f'{tropes_figures_dir}/{model_id}_bubble_chart_{N}.png')

### Generate markdown reports

In [None]:
# Write one markdown report per model: the bubble chart followed by every trope
# (largest cluster first) with a table of the sentences that support it.

# Reports live in a sub-directory of the figures directory, so the image link must
# be relative to the report file rather than a bare file name.
figures_rel_path = os.path.relpath(tropes_figures_dir, tropes_reports_dir).replace(os.sep, "/")

for model in list(orig_df['model_id'].unique()):
    model_id = model.split("/")[-1]
    print(model_id)
    model_df = orig_df[orig_df['model_id'] == model]

    # Trope -> position of its last occurrence among this model's rows; used as a
    # short identifier ("T<n>") in the report headings.
    t_to_i = {t:i for i,t in enumerate(model_df[trope_column])}

    # Rows without a trope have nothing to report (same filter as the plotting cell).
    tropes_df = model_df.dropna(subset=[trope_column])

    title = f"# {model_id} Trope report"
    img = f"![Trope Graph]({figures_rel_path}/{model_id}_bubble_chart_{N}.png)"
    header = "## Tropes"
    trope_strs = []

    # Go from largest to smallest clusters
    for cluster, _count in Counter(tropes_df[trope_column]).most_common():
        trope_text = f"### T{t_to_i[cluster]}: {cluster}"

        # Get the list of constituent sentences in a table. Newlines and pipes
        # would break the table row, so they are replaced / escaped.
        table_strs = ["|Support|\n|---|"]
        for sent in tropes_df[tropes_df[trope_column] == cluster]['sentences']:
            cell = sent.strip().replace("\n", " ").replace("|", "\\|")
            table_strs.append(f"|{cell}|")

        sentence_table = '\n'.join(table_strs)

        trope_strs.append(f"""
{trope_text}

{sentence_table}

""")

    combined_tropes = "---\n".join(trope_strs)
    # The blank line before '---' keeps it a horizontal rule; directly under the
    # image line it would turn that line into a heading.
    md_str = f"""{title}

{img}

---
{header}

{combined_tropes}
"""

    with open(f"{tropes_reports_dir}/{model_id}.md", 'wt', encoding='utf-8') as f:
        f.write(md_str)