In [1]:
import os
import json
import numpy as np
import pandas as pd
from scipy.stats import wasserstein_distance
import helpers as ph
import seaborn as sns
import dataframe_image as dfi

# Colour stops for the table background gradient, listed in blend order.
COLOR_STR = ",".join(["#0A3EA4", "#4874F9", "#84A0F5", "#F1F4FB", "#FFFFFF"])
# Continuous colormap built from the stops above via seaborn's "blend:" spec.
palette = sns.color_palette("blend:" + COLOR_STR, 12, as_cmap=True)

# Shared table styles taken from the helpers module (same object, not a copy).
styles = ph.VIS_STYLES

In [2]:
# Notebook configuration.
RESULTS_DIR = './data/distributions/'  # directory the per-wave CSV files are read from
CONTEXT = 'default'                    # context tag embedded in the CSV file names
SAVEFIG = False                        # when True, styled tables are also exported as PNG files

## Load human and LM opinion distributions

In [3]:
# Read the per-wave "combined" and "baseline" CSVs, tag every row with its
# survey label, and stack the waves into two DataFrames.
# Per-wave frames are collected in dedicated lists so that `combined_df` and
# `human_df` are only ever bound to DataFrames.
combined_frames, human_frames = [], []
for wave in ph.PEW_SURVEY_LIST:
    SURVEY_NAME = f'American_Trends_Panel_W{wave}'
    survey_label = f'ATP {wave}'

    cdf = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_{CONTEXT}_combined.csv'))
    cdf['survey'] = survey_label
    combined_frames.append(cdf)

    hdf = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_{CONTEXT}_baseline.csv'))
    hdf['survey'] = survey_label
    human_frames.append(hdf)
combined_df, human_df = pd.concat(combined_frames), pd.concat(human_frames)

# Label each row by provider: model names containing "j1-" (case-insensitive)
# are attributed to AI21 Labs, everything else to OpenAI. Done column-wise
# instead of with a row-wise apply; the result is assigned positionally, so the
# repeated index values left by the concat are not a problem.
is_ai21 = combined_df['model_name'].str.lower().str.contains('j1-', regex=False)
combined_df['Source'] = np.where(is_ai21, 'AI21 Labs', 'OpenAI')

In [4]:
# Number of distinct values in the 'question' column across all loaded waves.
n_questions = len(set(combined_df['question']))
print('# Questions:', n_questions)

# Questions: 1498


## Compute average representativeness across dataset

In [5]:
# Columns identifying one (model, demographic group) cell of the result tables.
KEYS = ['Source', 'model_name', 'attribute', 'group', 'group_order', 'model_order']

# Mean WD per key combination, ordered by the model/group ordering columns.
# The aggregation is named by string ('mean'): passing the NumPy callable
# np.mean to agg is deprecated in recent pandas and emits a FutureWarning.
grouped = (
    combined_df
    .groupby(KEYS, as_index=False)
    .agg({'WD': 'mean'})
    .sort_values(by=['model_order', 'group_order'])
)
# Representativeness is defined here as one minus the mean WD
# (WD is presumably a Wasserstein distance - confirm against the data files).
grouped['Rep'] = 1 - grouped['WD']

  grouped = combined_df.groupby(KEYS, as_index=False).agg({'WD': np.mean}) \


### Overall representativeness

In [6]:
# Human baseline: mean WD per 'group_x' value, turned into Rep = 1 - WD.
# Aggregations are named by string; passing np.mean / builtin min to agg is
# deprecated in recent pandas and emits FutureWarnings.
human_baseline = human_df.groupby(['group_x'], as_index=False).agg({'WD': 'mean'})
human_baseline['Rep'] = 1 - human_baseline['WD']

# Collapse to two rows: the mean and the minimum Rep over the groups.
# After reset_index the 'index' column holds the labels 'mean' and 'min'.
human_baseline = human_baseline.agg({'Rep': ['mean', 'min']}).reset_index()
human_baseline['model_name'] = np.where(human_baseline['index'] == 'mean', 'Avg', 'Worst')
human_baseline['model_order'] = -1
human_baseline['Source'] = "Humans"

# Put the human rows ahead of the models' 'Overall' rows. 'model_name' is
# renamed to the empty string so that column level renders without a label.
g = (
    pd.concat([human_baseline, grouped[grouped['attribute'] == 'Overall']])
    .rename(columns={'model_name': '', 'Rep': 'R'})
)

table = pd.pivot_table(g,
                       columns=['Source', ''],
                       values='R',
                       sort=False)
table_vis = (
    table.style
    .background_gradient(palette, axis=1)
    .set_table_styles(styles)
    .set_properties(**{"font-size": "0.75rem"})
    .format(precision=3)
)

if SAVEFIG:
    # Styler.hide_index() was removed in pandas 2.0; hide(axis='index') is the
    # replacement and, like the old call, returns the same Styler.
    table_vis.hide(axis='index').export_png('./figures/representativeness.png')
display(table_vis)

  human_baseline = human_df.groupby(['group_x'], as_index=False).agg({'WD': np.mean})
  human_baseline = human_baseline.agg({'Rep': (np.mean, min)}).reset_index()
  human_baseline = human_baseline.agg({'Rep': (np.mean, min)}).reset_index()


Source,Humans,Humans,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Unnamed: 0_level_1,Avg,Worst,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
R,0.949,0.867,0.814,0.817,0.806,0.826,0.793,0.71,0.717,0.764,0.703


### Subgroup representativeness

In [7]:
# Set the value of the last property of the last style entry to "105%",
# keeping that property's name. NOTE: `styles` is the same object as
# ph.VIS_STYLES, so this edit is made in place on the helpers' list.
last_style_props = styles[-1]['props']
last_prop_name = last_style_props[-1][0]
last_style_props[-1] = (last_prop_name, "105%")

In [8]:
# Render one styled table per demographic attribute: rows are the attribute's
# groups, columns are (source, model) pairs, values are the 'Rep' scores.
# The first entry of ph.DEMOGRAPHIC_ATTRIBUTES is skipped
# (presumably 'Overall', which has its own table - TODO confirm).
for attribute in ph.DEMOGRAPHIC_ATTRIBUTES[1:]:

    print(f'-----{attribute}----')

    # 'Source' is renamed to the empty string so that column level has no label.
    column_renames = {'model_name': 'Model', 'group': attribute, 'Source': ''}
    g = grouped.loc[grouped['attribute'] == attribute].rename(columns=column_renames)

    table = pd.pivot_table(g, index=[attribute], columns=['', 'Model'], values="Rep", sort=False)

    # Gradient runs along axis 1 only for 'Overall'; otherwise the flag is
    # False, which pandas treats as axis 0.
    gradient_axis = (attribute == 'Overall')
    table_vis = table.style.background_gradient(palette, axis=gradient_axis)
    table_vis = table_vis.set_table_styles(styles)
    table_vis = table_vis.set_properties(**{"font-size": "1.3rem"})
    table_vis = table_vis.format(precision=3)

    if SAVEFIG:
        table_vis.export_png(f'./figures/representativeness_{attribute}.png')

    display(table_vis)

-----CREGION----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
CREGION,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Northeast,0.809,0.812,0.803,0.821,0.789,0.708,0.716,0.766,0.707
Midwest,0.81,0.811,0.799,0.822,0.788,0.71,0.717,0.764,0.703
South,0.818,0.82,0.807,0.828,0.795,0.71,0.715,0.761,0.698
West,0.811,0.815,0.804,0.823,0.791,0.707,0.718,0.765,0.706


-----AGE----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
AGE,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
18-29,0.816,0.819,0.809,0.829,0.796,0.707,0.716,0.765,0.703
30-49,0.812,0.816,0.806,0.824,0.792,0.707,0.717,0.765,0.705
50-64,0.811,0.811,0.799,0.819,0.787,0.71,0.714,0.759,0.699
65+,0.793,0.794,0.781,0.802,0.772,0.706,0.71,0.754,0.701


-----SEX----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
SEX,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Male,0.814,0.816,0.803,0.827,0.792,0.708,0.715,0.764,0.7
Female,0.808,0.811,0.801,0.818,0.788,0.708,0.717,0.761,0.704


-----EDUCATION----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
EDUCATION,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Less than high school,0.829,0.83,0.813,0.837,0.802,0.712,0.717,0.751,0.688
High school graduate,0.818,0.817,0.801,0.827,0.792,0.713,0.715,0.757,0.693
"Some college, no degree",0.813,0.815,0.805,0.824,0.792,0.708,0.717,0.763,0.703
Associate's degree,0.811,0.812,0.801,0.822,0.79,0.705,0.715,0.762,0.702
College graduate/some postgrad,0.798,0.803,0.795,0.811,0.782,0.704,0.716,0.767,0.713
Postgraduate,0.789,0.795,0.79,0.802,0.776,0.697,0.715,0.768,0.719


-----CITIZEN----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
CITIZEN,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Yes,0.813,0.815,0.804,0.824,0.791,0.709,0.716,0.764,0.703
No,0.806,0.818,0.813,0.819,0.799,0.702,0.718,0.753,0.708


-----MARITAL----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
MARITAL,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Married,0.809,0.812,0.801,0.821,0.787,0.708,0.715,0.762,0.702
Divorced,0.811,0.811,0.798,0.819,0.787,0.711,0.717,0.762,0.699
Separated,0.81,0.816,0.803,0.819,0.788,0.707,0.716,0.755,0.697
Widowed,0.801,0.801,0.786,0.809,0.779,0.709,0.712,0.753,0.697
Never been married,0.816,0.821,0.809,0.829,0.796,0.71,0.718,0.768,0.702


-----RELIG----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
RELIG,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Protestant,0.812,0.812,0.798,0.822,0.789,0.71,0.715,0.756,0.697
Roman Catholic,0.813,0.818,0.807,0.825,0.794,0.709,0.717,0.761,0.704
Mormon,0.79,0.791,0.778,0.803,0.771,0.699,0.709,0.753,0.699
Orthodox,0.771,0.775,0.764,0.783,0.754,0.687,0.701,0.732,0.696
Jewish,0.793,0.794,0.787,0.802,0.774,0.699,0.71,0.759,0.709
Muslim,0.786,0.795,0.79,0.794,0.775,0.684,0.707,0.731,0.7
Buddhist,0.772,0.784,0.778,0.785,0.766,0.683,0.704,0.748,0.712
Hindu,0.778,0.798,0.795,0.791,0.777,0.684,0.705,0.729,0.709
Atheist,0.774,0.776,0.773,0.786,0.761,0.69,0.709,0.767,0.716
Agnostic,0.783,0.787,0.782,0.796,0.768,0.698,0.717,0.772,0.719


-----RELIGATTEND----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
RELIGATTEND,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
More than once a week,0.809,0.808,0.794,0.818,0.785,0.704,0.711,0.751,0.692
Once a week,0.81,0.812,0.8,0.82,0.789,0.706,0.714,0.755,0.699
Once or twice a month,0.816,0.819,0.809,0.826,0.797,0.707,0.716,0.758,0.702
A few times a year,0.813,0.818,0.81,0.825,0.796,0.71,0.718,0.761,0.708
Seldom,0.811,0.813,0.802,0.822,0.789,0.709,0.718,0.764,0.705
Never,0.806,0.807,0.797,0.817,0.784,0.708,0.715,0.769,0.703


-----POLPARTY----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
POLPARTY,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Republican,0.798,0.793,0.778,0.807,0.77,0.707,0.706,0.744,0.682
Democrat,0.794,0.801,0.797,0.805,0.783,0.698,0.717,0.764,0.721
Independent,0.811,0.813,0.802,0.823,0.79,0.709,0.717,0.765,0.704
Other,0.822,0.821,0.805,0.833,0.795,0.711,0.718,0.766,0.695


-----INCOME----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
INCOME,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
"Less than $30,000",0.826,0.829,0.815,0.835,0.803,0.711,0.719,0.759,0.695
"$30,000-$50,000",0.813,0.816,0.804,0.823,0.791,0.71,0.716,0.76,0.701
"$50,000-$75,000",0.806,0.808,0.797,0.818,0.785,0.707,0.715,0.764,0.705
"$75,000-$100,000",0.801,0.802,0.793,0.813,0.782,0.705,0.714,0.763,0.707
"$100,000 or more",0.796,0.799,0.792,0.809,0.779,0.701,0.712,0.766,0.711


-----POLIDEOLOGY----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
POLIDEOLOGY,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Very conservative,0.806,0.798,0.78,0.813,0.773,0.704,0.7,0.735,0.665
Conservative,0.802,0.798,0.782,0.811,0.774,0.709,0.709,0.749,0.686
Moderate,0.811,0.815,0.805,0.823,0.793,0.709,0.719,0.764,0.708
Liberal,0.788,0.794,0.79,0.8,0.776,0.698,0.718,0.768,0.723
Very liberal,0.782,0.787,0.784,0.793,0.77,0.691,0.711,0.762,0.714


-----RACE----


Unnamed: 0_level_0,AI21 Labs,AI21 Labs,AI21 Labs,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI,OpenAI
Model,j1-grande,j1-jumbo,j1-grande-v2-beta,ada,davinci,text-ada-001,text-davinci-001,text-davinci-002,text-davinci-003
RACE,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
White,0.807,0.808,0.796,0.818,0.785,0.71,0.715,0.763,0.701
Black,0.814,0.822,0.813,0.824,0.798,0.703,0.717,0.755,0.704
Asian,0.808,0.815,0.807,0.82,0.794,0.699,0.717,0.757,0.71
Hispanic,0.814,0.821,0.812,0.826,0.799,0.705,0.719,0.757,0.708
Other,0.8,0.803,0.785,0.808,0.775,0.699,0.703,0.742,0.683
