In [157]:
import pandas as pd
# read in files
# Read UK_df.csv as pandas dataframe
original_UK_dialogue = pd.read_csv('UK_df.csv')
original_UK_politeness = pd.read_csv('UK_direct_df.csv')
original_UK_narrator = pd.read_csv('UK_narrator_df.csv')
original_US_dialogue = pd.read_csv('US_df.csv')
original_US_politeness = pd.read_csv('US_direct_df.csv')
original_US_narrator = pd.read_csv('US_narrator_df.csv')
dataframes = [original_UK_dialogue, original_UK_politeness, original_UK_narrator, original_US_dialogue, original_US_politeness, original_US_narrator]
def elim_outliers(df):
    # dropped Unnamed: 0 column
    df.drop(columns=['Unnamed: 0'], inplace=True)
    filtered_df = df.loc[(df['response'] > 95) | (df['response'] < 5)]
    for id in df['person_id'].unique():
        if len(filtered_df[filtered_df['person_id'] == id])/len(df[df['person_id'] == id])>0.8:
            df.drop(df[df['person_id'] == id].index, inplace=True)
    df['predicate Z-score'] = df.groupby(['person_id','predicate'])['response'].transform(lambda x: (x - x.mean()) / x.std())
    # if has_intensifier = no then change 'intensifier' to 'none'
    df.loc[df['has intensifier?'] == 'no', 'intensifier'] = 'none'
    return df
for i in range(len(dataframes)):
    dataframes[i] = elim_outliers(dataframes[i])
dialogue = pd.concat([dataframes[0], dataframes[3]])
politeness = pd.concat([dataframes[1], dataframes[4]])
UK_dialogue = dataframes[0]
US_dialogue = dataframes[3]
UK_politeness = dataframes[1]
US_politeness = dataframes[4]

# end of reading in data
#-------------------------------------------------------------------------------

# compute U_soc (social Utility)
U_soc_data = politeness.groupby(['intensifier','predicate'])['predicate Z-score'].mean().to_dict()
UK_U_soc_data = UK_politeness.groupby(['intensifier','predicate'])['predicate Z-score'].mean().to_dict()
US_U_soc_data = US_politeness.groupby(['intensifier','predicate'])['predicate Z-score'].mean().to_dict()

In [158]:
from memo import memo
import jax
import jax.numpy as np
from enum import IntEnum, auto

Define constants

In [159]:
epsilon = 0.1
infty = 10000000
utterences =list(U_soc_data.keys())
state_values = np.linspace(-2.8,2.8,20)
S = np.arange(0,20,1)

Define params to iterate

In [160]:
costs = np.linspace(0,4,4)
possible_soc_terms = np.linspace(0,3,5)
possible_inf_terms = np.linspace(0,3,5)
theta_to_test = np.arange(0,20,2)

Grid search (code from demo-rsa.py)

In [161]:
class W(IntEnum):  # utterance space
    # borings
    none_boring = auto(0)
    slightly_boring = auto()
    kind_of_boring = auto()
    quite_boring = auto()
    very_boring = auto()
    extremely_boring = auto()
    # concerneds
    none_concerned = auto()
    slightly_concerned = auto()
    kind_of_concerned = auto()
    quite_concerned = auto()
    very_concerned = auto()
    extremely_concerned = auto()
    # difficults
    none_difficult = auto()
    slightly_difficult = auto()
    kind_of_difficult = auto()
    quite_difficult = auto()
    very_difficult = auto()
    extremely_difficult = auto()
    # exhausteds
    none_exhausted = auto()
    slightly_exhausted = auto()
    kind_of_exhausted = auto()
    quite_exhausted = auto()
    very_exhausted = auto()
    extremely_exhausted = auto()
    # helpfuls
    none_helpful = auto()
    slightly_helpful = auto()
    kind_of_helpful = auto()
    quite_helpful = auto()
    very_helpful = auto()
    extremely_helpful = auto()
    # impressives
    none_impressive = auto()
    slightly_impressive = auto()
    kind_of_impressive = auto()
    quite_impressive = auto()
    very_impressive = auto()
    extremely_impressive = auto()
    # understandables
    none_understandable = auto()
    slightly_understandable = auto()
    kind_of_understandable = auto()
    quite_understandable = auto()
    very_understandable = auto()
    extremely_understandable = auto()


In [162]:
# Generate the mapping dictionary
U_soc_key_map = {w: (' '.join(w.name.split('_')[:-1]), w.name.split('_')[-1]) for w in W}
U_soc_array = np.array([U_soc_data[U_soc_key_map[W(i)]] for i in range(len(W))])

Construct measured_values which is an array of 42 arrays where the i'th entry is the values people reported for the i'th utterance

In [163]:
possible_literal_semantics = np.array(
    [
        np.concatenate([np.repeat(np.array([epsilon]),i),np.repeat(np.array([1]),20-i)])
        for i in range(20)
    ]
)
@jax.jit
def state_prior(s):
    return np.float32(np.exp(-state_values[s]**2/2))  # uninformative state prior doesn't matter that it doesn't add up to 1
@jax.jit
def U_soc(soc):
    return U_soc_array[soc]
@jax.jit
def is_costly(w):
    arr = [0, 1, 1, 1, 1, 1]*7
    return np.array(arr)[w]
@jax.jit
def L(w, s,t0,t1,t2,t3,t4,t5):  # literal likelihood L(w | s)
    intensifier_semantics = np.array([  # "hard semantics"
        possible_literal_semantics[t0],  # none
        possible_literal_semantics[t1],  # slightly 
        possible_literal_semantics[t2],  # kind of
        possible_literal_semantics[t3],  # quite
        possible_literal_semantics[t4],  # very
        possible_literal_semantics[t5],  # extremely
    ])
    return np.tile(intensifier_semantics, (7, 1))[w, s]
@memo
def L1[s: S, w: W](inf_term, soc_term, cost,t0,t1,t2,t3,t4,t5):
    listener: thinks[
        speaker: given(s in S, wpp=state_prior(s)),
        speaker: thinks[
            listener: thinks[
                speaker: given(s in S, wpp=state_prior(s)),
                speaker: chooses(w in W, wpp=L(w, s,t0,t1,t2,t3,t4,t5))
            ]
        ],
        speaker: chooses(w in W, wpp=exp( imagine[
            listener: observes [speaker.w] is w,
            listener: knows(s),
            (
                inf_term * listener[ log(Pr[speaker.s == s]) ] +  # U_inf = listener's surprisal
                soc_term * U_soc(w) - # U_soc = listener's EU
                cost*is_costly(w)  # cost of utterance
            )
        ]))
    ]
    listener: observes[speaker.w] is w
    listener: chooses(s in S, wpp=Pr[speaker.s == s])
    return Pr[listener.s == s]

In [164]:
# Create a list of JAX arrays
measured_values = []
for w in W:
    intensifier, predicate = U_soc_key_map[w]   
    raw_values = UK_dialogue[((UK_dialogue['intensifier'] == intensifier) & (UK_dialogue['predicate'] == predicate))]['predicate Z-score'].values
    # measured_values.append(np.array([int(r/0.28)+10 for r in raw_values]))
    z = [int(r/0.28)+10 for r in raw_values]
    x = [0]*len(S)
    for i in z:
        x[i] += 1
    measured_values.append(x)
measured_values = np.array(measured_values)
def compute_logloss(*params):
    thetas = params[:6]
    cost = params[6]
    inf_term = params[7]
    soc_term = params[8]
    # compute fit e.g. log_likelihood
    P_l1 = L1(inf_term=inf_term, soc_term=soc_term, cost=cost, t0 = thetas[0],t1=thetas[1], t2= thetas[2], t3 = thetas[3], t4 = thetas[4], t5 = thetas[5]) # note this should be P(s|w)
    return np.sum(np.log(P_l1)*measured_values.T)

In [165]:
total_param_combos = len(costs)*len(possible_soc_terms)*len(possible_inf_terms)*len(theta_to_test)**3
in_repeat = total_param_combos//len(theta_to_test)
out_repeat = 1
# first_thetas = np.repeat(np.tile(theta_to_test, in_repeat),out_repeat)
# out_repeat = out_repeat*len(theta_to_test)
# in_repeat = in_repeat//len(theta_to_test)
# second_thetas = np.repeat(np.tile(theta_to_test, in_repeat),out_repeat)
# out_repeat = out_repeat*len(theta_to_test)
# in_repeat = in_repeat//len(theta_to_test)
# third_thetas = np.repeat(np.tile(theta_to_test, in_repeat),out_repeat)
# out_repeat = out_repeat*len(theta_to_test)
# in_repeat = in_repeat//len(theta_to_test)
fourth_thetas = np.repeat(np.tile(theta_to_test, in_repeat),out_repeat)
out_repeat = out_repeat*len(theta_to_test)
in_repeat = in_repeat//len(theta_to_test)
fifth_thetas = np.repeat(np.tile(theta_to_test, in_repeat),out_repeat)
out_repeat = out_repeat*len(theta_to_test)
in_repeat = in_repeat//len(theta_to_test)
sixth_thetas = np.repeat(np.tile(theta_to_test, in_repeat),out_repeat)
out_repeat = out_repeat*len(theta_to_test)
in_repeat = in_repeat//len(costs)
seventh_costs = np.repeat(np.tile(costs, in_repeat),out_repeat)
out_repeat = out_repeat*len(costs)
in_repeat = in_repeat//len(possible_inf_terms)
eighth_infs = np.repeat(np.tile(possible_inf_terms, in_repeat),out_repeat)
out_repeat = out_repeat*len(possible_inf_terms)
in_repeat = in_repeat//len(possible_soc_terms)
ninth_socs = np.repeat(np.tile(possible_soc_terms, in_repeat),out_repeat)

In [166]:
all_output = []
for t1 in theta_to_test:
    for t2 in theta_to_test:
        for t3 in theta_to_test:
            output = jax.vmap(compute_logloss,in_axes=(None,)*3+(0,)*6)(t1,t2,t3,fourth_thetas,fifth_thetas,sixth_thetas,seventh_costs,eighth_infs,ninth_socs).block_until_ready()
            all_output.append(output)
            print(t1,t2,t3)

0 0 0
0 0 2
0 0 4
0 0 6
0 0 8
0 0 10
0 0 12
0 0 14
0 0 16
0 0 18
0 2 0
0 2 2


KeyboardInterrupt: 

In [135]:
all_output = np.array(all_output)

In [122]:
all_output = np.array(all_output)
all_output = all_output.flatten()
# save all_output to a file
np.save('UK_flattened_output_big_eps.npy', all_output)

In [136]:
# all_output = all_output.reshape(len(theta_to_test),len(theta_to_test),len(theta_to_test),len(possible_soc_terms),len(possible_inf_terms),len(costs),len(theta_to_test),len(theta_to_test),len(theta_to_test)).transpose(0,1,2,8,7,6,5,4,3)

In [156]:
# for cost_i in range(len(costs)):
#     for inf_i in range(len(possible_inf_terms)):
#         for soc_i in range(len(possible_soc_terms)):
#            print(np.argmax(UK_output[:,:,:,:,:,:,cost_i,inf_i,soc_i]),np.max(UK_output[:,:,:,:,:,:,cost_i,inf_i,soc_i]))

-3798.5068 -3798.5068
-3798.5068 -3798.5068
-3798.5068 -3798.5068
-3798.5068 -3798.5068
-3798.5068 -3798.5068
-4324.632 -3615.312
-4303.101 -3619.4585
-4290.9717 -3622.411
-4294.791 -3624.4722
-4308.218 -3625.9976
-5152.9663 -3672.52
-5112.2715 -3670.7334
-5089.3447 -3670.5178
-5075.5537 -3670.7263
-5069.44 -3671.2114
-6094.6553 -3703.5725
-6044.421 -3701.9333
-6016.3276 -3702.767
-5999.4873 -3701.9702
-5988.998 -3701.405
-7062.5957 -3753.949
-7008.8525 -3769.2195
-6979.0547 -3771.3413
-6966.913 -3769.417
-7008.6724 -3763.311
-3798.5068 -3798.5068
-3798.5068 -3798.5068
-3798.5068 -3798.5068
-3798.5068 -3798.5068
-3798.5068 -3798.5068
-4220.142 -3648.103
-4201.9697 -3659.3066
-4195.8496 -3663.6694
-4196.4883 -3663.715
-4201.1343 -3660.8994
-4923.6816 -3674.3572
-4891.9424 -3674.6611
-4881.4507 -3674.4844
-4882.4956 -3673.8408
-4890.7104 -3673.3042
-5814.414 -3704.1997
-5775.1675 -3714.152
-5762.089 -3716.3662
-5763.3945 -3712.8418
-5773.635 -3707.6987
-6767.6143 -3766.0618
-6726.537 -37