In [15]:
import os
import json
import numpy as np
import pandas as pd
from scipy.stats import wasserstein_distance
import helpers as ph
import seaborn as sns
import dataframe_image as dfi

# Blue-to-white colour stops used to shade the result tables below.
COLOR_STR = "#0A3EA4,#4874F9,#84A0F5,#F1F4FB,#FFFFFF"
palette = sns.color_palette("blend:" + COLOR_STR, n_colors=12, as_cmap=True)

# Table styles come from the helper module; this is the same object, not a copy.
styles = ph.VIS_STYLES

In [16]:
# Notebook configuration.
# Directory that holds the per-survey CSV files read below.
RESULTS_DIR = './data/distributions/'
# Selects which `<survey>_<CONTEXT>_combined.csv` files are loaded.
CONTEXT = 'default'
# Only referenced by the commented-out PNG export lines further down.
SAVEFIG = False

## Load human and LM opinion distributions

In [18]:
# Read, for every survey wave listed in ph.PEW_SURVEY_LIST, two CSV files:
#   * `<survey>_<CONTEXT>_combined.csv` -> rows collected into combined_df
#   * `<survey>_baseline.csv`           -> rows collected into human_df
# Each row is tagged with a short survey label ("ATP <wave>").
# The per-wave frames are gathered in separate lists so that the names
# combined_df / human_df only ever refer to DataFrames.
combined_frames, human_frames = [], []
for wave in ph.PEW_SURVEY_LIST:
    SURVEY_NAME = f'American_Trends_Panel_W{wave}'

    cdf = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_{CONTEXT}_combined.csv'))
    cdf['survey'] = f'ATP {wave}'
    combined_frames.append(cdf)

    hdf = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_baseline.csv'))
    hdf['survey'] = f'ATP {wave}'
    human_frames.append(hdf)

combined_df = pd.concat(combined_frames)
human_df = pd.concat(human_frames)

# Model names containing "j1-" (case-insensitive) are labelled 'AI21 Labs',
# everything else 'OpenAI'. Done column-wise instead of with a per-row apply.
# NOTE: a missing model_name is labelled 'OpenAI' (na=False) rather than raising.
is_j1_model = combined_df['model_name'].str.lower().str.contains('j1-', regex=False, na=False)
combined_df['Source'] = np.where(is_j1_model, 'AI21 Labs', 'OpenAI')

In [19]:
# Report how many distinct values the 'question' column holds across all waves.
distinct_questions = set(combined_df['question'])
print(f'# Questions: {len(distinct_questions)}')

# Questions: 1498


## Compute average representativeness across dataset

In [20]:
# Columns identifying one (model, demographic group) cell of the result tables.
KEYS = ['Source', 'model_name', 'attribute', 'group', 'group_order', 'model_order']

# Average WD over all rows sharing the same KEYS, ordered by model and then
# by group. The aggregation is named by string ('mean') rather than passing
# np.mean, which pandas flags with a FutureWarning in groupby().agg().
grouped = (
    combined_df
    .groupby(KEYS, as_index=False)
    .agg({'WD': 'mean'})
    .sort_values(by=['model_order', 'group_order'])
)
# Representativeness is reported as 1 - (mean WD).
grouped['Rep'] = 1 - grouped['WD']

  grouped = combined_df.groupby(KEYS, as_index=False).agg({'WD': np.mean}) \


### Overall representativeness

In [21]:
# Human baseline: mean WD per 'group_x' in human_df, converted to
# Rep = 1 - WD, then reduced to two rows: the mean and the minimum Rep
# across groups. Aggregations are named by string ('mean', 'min') instead of
# passing np.mean / builtin min, which pandas flags with FutureWarnings.
human_baseline = human_df.groupby(['group_x'], as_index=False).agg({'WD': 'mean'})
human_baseline['Rep'] = 1 - human_baseline['WD']
human_baseline = human_baseline.agg({'Rep': ['mean', 'min']}).reset_index()

# After reset_index() the aggregation name lives in the 'index' column:
# the 'mean' row is shown as "Avg", the other ('min') row as "Worst".
human_baseline['model_name'] = np.where(human_baseline['index'] == 'mean', 'Avg', 'Worst')
human_baseline['model_order'] = -1
human_baseline['Source'] = "Humans"

# Put the human rows in front of the models' 'Overall' rows; 'model_name' is
# renamed to an empty string so the second column level has no visible label.
g = pd.concat([human_baseline, grouped[grouped['attribute'] == 'Overall']]).rename(
    columns={'model_name': '', 'Rep': 'R'}
)

# One-row table: R per (Source, model), keeping the row order of `g`.
table = pd.pivot_table(g,
                       columns=['Source', ''],
                       values='R',
                       sort=False)
table_vis = (
    table.style
    .background_gradient(palette, axis=1)
    .set_table_styles(styles)
    .set_properties(**{"font-size": "0.75rem"})
    .format(precision=3)
)

#if SAVEFIG: table_vis.hide_index().export_png('./figures/representativeness.png')
display(table_vis)

  human_baseline = human_df.groupby(['group_x'], as_index=False).agg({'WD': np.mean})
  human_baseline = human_baseline.agg({'Rep': (np.mean, min)}).reset_index()
  human_baseline = human_baseline.agg({'Rep': (np.mean, min)}).reset_index()


Source,Humans,Humans,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Unnamed: 0_level_1,Avg,Worst,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
R,0.949,0.866,0.813,0.816,0.804,0.824,0.791,0.707,0.715,0.762,0.701


### Subgroup representativeness

In [22]:
# Replace the value of the last property of the last table style with "105%",
# keeping the property name. This edits `styles` in place.
last_style_props = styles[-1]['props']
last_prop_name = last_style_props[-1][0]
last_style_props[-1] = (last_prop_name, "105%")

In [23]:
# One table per demographic attribute (the first entry of
# ph.DEMOGRAPHIC_ATTRIBUTES is skipped): rows are the attribute's groups,
# columns are (source, model), cells are the 'Rep' values from `grouped`.
for attribute in ph.DEMOGRAPHIC_ATTRIBUTES[1:]:

    print(f'-----{attribute}----')

    # 'Source' becomes an empty label so the top column level shows no name.
    column_labels = {'model_name': 'Model', 'group': attribute, 'Source': ''}
    g = grouped.loc[grouped['attribute'] == attribute].rename(columns=column_labels)

    table = pd.pivot_table(
        g,
        index=[attribute],
        columns=['', 'Model'],
        values="Rep",
        sort=False,
    )

    # The comparison result is passed straight through as the gradient axis,
    # so the value is False (i.e. 0) unless the attribute is 'Overall'.
    gradient_axis = (attribute == 'Overall')
    table_vis = (
        table.style
        .background_gradient(palette, axis=gradient_axis)
        .set_table_styles(styles)
        .set_properties(**{"font-size": "1.3rem"})
        .format(precision=3)
    )
    #if SAVEFIG: table_vis.export_png(f'./figures/representativeness_{attribute}.png')

    display(table_vis)

-----CREGION----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
CREGION,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Northeast,0.807,0.811,0.802,0.819,0.788,0.706,0.714,0.764,0.704
Midwest,0.808,0.81,0.797,0.82,0.786,0.708,0.714,0.762,0.701
South,0.816,0.818,0.805,0.827,0.793,0.707,0.713,0.759,0.696
West,0.81,0.813,0.802,0.821,0.789,0.705,0.716,0.763,0.704


-----AGE----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
AGE,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
18-29,0.814,0.818,0.808,0.828,0.794,0.704,0.714,0.763,0.7
30-49,0.811,0.814,0.804,0.823,0.791,0.705,0.715,0.763,0.702
50-64,0.809,0.809,0.797,0.818,0.785,0.708,0.712,0.757,0.696
65+,0.791,0.792,0.779,0.8,0.77,0.704,0.708,0.752,0.699


-----SEX----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
SEX,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Male,0.813,0.814,0.802,0.826,0.79,0.706,0.712,0.762,0.697
Female,0.807,0.81,0.8,0.816,0.786,0.706,0.715,0.76,0.702


-----EDUCATION----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
EDUCATION,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Less than high school,0.827,0.828,0.812,0.835,0.8,0.71,0.715,0.749,0.685
High school graduate,0.817,0.816,0.799,0.826,0.79,0.711,0.712,0.755,0.691
"Some college, no degree",0.811,0.814,0.804,0.823,0.79,0.706,0.714,0.761,0.701
Associate's degree,0.809,0.811,0.8,0.821,0.789,0.703,0.713,0.76,0.7
College graduate/some postgrad,0.797,0.802,0.794,0.81,0.78,0.701,0.714,0.765,0.71
Postgraduate,0.788,0.794,0.789,0.8,0.774,0.695,0.713,0.766,0.717


-----CITIZEN----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
CITIZEN,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Yes,0.812,0.814,0.802,0.823,0.789,0.707,0.714,0.762,0.7
No,0.804,0.816,0.812,0.818,0.797,0.699,0.715,0.751,0.706


-----MARITAL----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
MARITAL,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Married,0.807,0.81,0.799,0.819,0.785,0.706,0.712,0.76,0.699
Divorced,0.809,0.809,0.796,0.817,0.785,0.709,0.714,0.76,0.696
Separated,0.808,0.814,0.801,0.818,0.786,0.705,0.714,0.753,0.694
Widowed,0.799,0.8,0.785,0.807,0.777,0.706,0.71,0.751,0.694
Never been married,0.815,0.819,0.808,0.828,0.795,0.708,0.716,0.766,0.7


-----RELIG----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
RELIG,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Protestant,0.811,0.81,0.797,0.82,0.787,0.707,0.713,0.754,0.694
Roman Catholic,0.812,0.816,0.806,0.823,0.792,0.706,0.714,0.759,0.702
Mormon,0.788,0.789,0.777,0.802,0.769,0.697,0.707,0.751,0.696
Orthodox,0.769,0.773,0.762,0.781,0.752,0.685,0.698,0.73,0.693
Jewish,0.791,0.792,0.785,0.8,0.772,0.697,0.708,0.757,0.707
Muslim,0.784,0.794,0.788,0.792,0.774,0.682,0.705,0.728,0.697
Buddhist,0.77,0.782,0.777,0.783,0.764,0.681,0.702,0.746,0.709
Hindu,0.777,0.796,0.794,0.789,0.775,0.682,0.702,0.727,0.707
Atheist,0.772,0.774,0.771,0.784,0.759,0.687,0.707,0.765,0.714
Agnostic,0.782,0.785,0.781,0.794,0.767,0.696,0.714,0.77,0.717


-----RELIGATTEND----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
RELIGATTEND,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
More than once a week,0.807,0.807,0.793,0.816,0.784,0.701,0.708,0.749,0.69
Once a week,0.809,0.811,0.798,0.819,0.787,0.703,0.712,0.753,0.696
Once or twice a month,0.814,0.818,0.807,0.825,0.795,0.704,0.713,0.756,0.699
A few times a year,0.812,0.817,0.809,0.824,0.794,0.707,0.716,0.759,0.705
Seldom,0.809,0.811,0.8,0.821,0.787,0.707,0.716,0.762,0.703
Never,0.804,0.806,0.795,0.816,0.782,0.705,0.713,0.767,0.701


-----POLPARTY----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
POLPARTY,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Republican,0.797,0.791,0.776,0.805,0.769,0.705,0.704,0.742,0.68
Democrat,0.792,0.8,0.795,0.804,0.781,0.696,0.714,0.762,0.719
Independent,0.809,0.812,0.801,0.821,0.788,0.706,0.715,0.763,0.701
Other,0.82,0.82,0.804,0.832,0.793,0.709,0.716,0.764,0.693


-----INCOME----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
INCOME,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
"Less than $30,000",0.824,0.828,0.813,0.833,0.801,0.709,0.717,0.757,0.693
"$30,000-$50,000",0.811,0.814,0.802,0.822,0.789,0.708,0.713,0.758,0.698
"$50,000-$75,000",0.804,0.807,0.796,0.816,0.784,0.705,0.713,0.762,0.703
"$75,000-$100,000",0.799,0.8,0.791,0.811,0.78,0.703,0.711,0.761,0.705
"$100,000 or more",0.794,0.797,0.79,0.807,0.777,0.698,0.71,0.764,0.708


-----POLIDEOLOGY----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
POLIDEOLOGY,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Very conservative,0.805,0.797,0.778,0.811,0.771,0.702,0.698,0.733,0.662
Conservative,0.8,0.796,0.78,0.81,0.773,0.707,0.707,0.747,0.684
Moderate,0.809,0.814,0.804,0.822,0.791,0.706,0.717,0.763,0.706
Liberal,0.786,0.792,0.788,0.799,0.774,0.696,0.716,0.767,0.721
Very liberal,0.78,0.785,0.782,0.791,0.768,0.688,0.709,0.76,0.712


-----RACE----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
RACE,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
White,0.806,0.807,0.794,0.817,0.783,0.707,0.712,0.762,0.699
Black,0.813,0.82,0.812,0.823,0.796,0.7,0.714,0.753,0.702
Asian,0.806,0.814,0.806,0.819,0.792,0.697,0.715,0.755,0.708
Hispanic,0.812,0.82,0.81,0.824,0.797,0.703,0.717,0.755,0.706
Other,0.798,0.801,0.783,0.807,0.773,0.696,0.7,0.739,0.681
