In [1]:
import pandas as pd
# Load the six survey exports: UK first, then US; within each country the
# dialogue, politeness ("direct") and narrator files, in that order.
_csv_paths = (
    'UK_df.csv',
    'UK_direct_df.csv',
    'UK_narrator_df.csv',
    'US_df.csv',
    'US_direct_df.csv',
    'US_narrator_df.csv',
)
dataframes = [pd.read_csv(path) for path in _csv_paths]
(
    original_UK_dialogue,
    original_UK_politeness,
    original_UK_narrator,
    original_US_dialogue,
    original_US_politeness,
    original_US_narrator,
) = dataframes
def elim_outliers(df):
    """Return a cleaned copy of a survey frame with per-participant z-scores.

    Steps:
      1. Drop the CSV index column ``'Unnamed: 0'`` when it is present.
      2. Remove every participant for whom more than 80% of the responses are
         extreme (``response > 95`` or ``response < 5``).
      3. Add ``'predicate Z-score'``: ``response`` standardised within each
         ``(person_id, predicate)`` group using the sample standard deviation,
         so a single-row or constant group yields NaN.
      4. Set ``intensifier`` to ``'none'`` wherever ``'has intensifier?'`` is
         ``'no'``.

    The input frame is left untouched, so the function can be re-applied to the
    same raw frame (or to its own output) without raising.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``response``, ``person_id``, ``predicate``,
        ``'has intensifier?'`` and ``intensifier`` columns.

    Returns
    -------
    pandas.DataFrame
        The filtered frame with the extra ``'predicate Z-score'`` column.
    """
    # errors='ignore' keeps the call idempotent once the index column is gone.
    cleaned = df.drop(columns=['Unnamed: 0'], errors='ignore')

    # Share of extreme answers per participant, broadcast back onto the rows.
    is_extreme = (cleaned['response'] > 95) | (cleaned['response'] < 5)
    extreme_share = (
        is_extreme.astype(float).groupby(cleaned['person_id']).transform('mean')
    )
    # ~(share > 0.8) rather than (share <= 0.8) so rows whose share is NaN are kept.
    cleaned = cleaned.loc[~(extreme_share > 0.8)].copy()

    cleaned['predicate Z-score'] = (
        cleaned.groupby(['person_id', 'predicate'])['response']
        .transform(lambda x: (x - x.mean()) / x.std())
    )
    # Bare predicates carry no intensifier; give them an explicit label.
    cleaned.loc[cleaned['has intensifier?'] == 'no', 'intensifier'] = 'none'
    return cleaned
# Clean every frame, then expose the per-country and pooled (UK + US) views.
dataframes = [elim_outliers(frame) for frame in dataframes]
UK_dialogue, UK_politeness = dataframes[0], dataframes[1]
US_dialogue, US_politeness = dataframes[3], dataframes[4]
dialogue = pd.concat([UK_dialogue, US_dialogue])
politeness = pd.concat([UK_politeness, US_politeness])

# end of reading in data
#-------------------------------------------------------------------------------

# compute U_soc (social utility): mean politeness z-score of each
# (intensifier, predicate) pair, keyed by that tuple -- pooled and per country.
_utterance_cols = ['intensifier', 'predicate']
U_soc_data = politeness.groupby(_utterance_cols)['predicate Z-score'].mean().to_dict()
UK_U_soc_data = UK_politeness.groupby(_utterance_cols)['predicate Z-score'].mean().to_dict()
US_U_soc_data = US_politeness.groupby(_utterance_cols)['predicate Z-score'].mean().to_dict()

In [2]:
from memo import memo
import jax
import jax.numpy as np
from enum import IntEnum, auto

Define constants

In [3]:
# Weight given by the literal semantics to states below an intensifier's
# threshold (states at or above the threshold get weight 1).
epsilon = 0.1
# NOTE(review): not referenced anywhere in the visible part of the notebook.
infty = 10000000
# (intensifier, predicate) tuples observed in the pooled politeness data.
utterences =list(U_soc_data.keys())
# z-score attached to each state index, in steps of 0.28 starting at -2.8.
# NOTE(review): the length of a float arange depends on rounding; state_prior
# indexes this with s from S, so it should have 20 entries -- confirm.
state_values = np.arange(-2.8,2.8,0.28)
# State indices 0..19.
S = np.arange(0,20,1)

Define params to iterate

In [4]:
# Candidate values for each free parameter of the grid search.
costs = np.arange(0,5,1)                # weight on is_costly(w)
possible_soc_terms = np.arange(0,5,1)   # weight on U_soc(w)
possible_inf_terms = np.arange(0,5,1)   # weight on the informational term
# Threshold indices tried for each intensifier; each value selects a row of
# possible_literal_semantics (every second row: 0, 2, ..., 18).
theta_to_test = np.arange(0,20,2)

Grid search (code from demo-rsa.py)

In [5]:
# Utterance space: every "<intensifier>_<predicate>" pair, numbered 0..41.
# The layout is predicate-major, i.e. value = 6 * predicate_index + intensifier_index
# (none_boring = 0, ..., extremely_boring = 5, none_concerned = 6, ...).
# is_costly() and L() depend on exactly this "6 intensifiers x 7 predicates" order.
_W_INTENSIFIERS = ('none', 'slightly', 'kind_of', 'quite', 'very', 'extremely')
_W_PREDICATES = (
    'boring',
    'concerned',
    'difficult',
    'exhausted',
    'helpful',
    'impressive',
    'understandable',
)
# The functional API with start=0 yields the same members and values as writing
# them out by hand, without relying on passing a value to auto().
W = IntEnum(
    'W',
    [
        f'{intensifier}_{predicate}'
        for predicate in _W_PREDICATES
        for intensifier in _W_INTENSIFIERS
    ],
    start=0,
)


In [6]:
# Map every utterance to its (intensifier, predicate) key in the U_soc dicts,
# e.g. W.kind_of_boring -> ('kind of', 'boring'), then lay the social utilities
# out as arrays indexed by utterance value.
def _utterance_key(w):
    """Split an enum member name into (intensifier with spaces, predicate)."""
    intensifier, predicate = w.name.rsplit('_', 1)
    return intensifier.replace('_', ' '), predicate

U_soc_key_map = {w: _utterance_key(w) for w in W}
U_soc_array = np.array([U_soc_data[U_soc_key_map[w]] for w in W])
UK_U_soc_array = np.array([UK_U_soc_data[U_soc_key_map[w]] for w in W])
US_U_soc_array = np.array([US_U_soc_data[U_soc_key_map[w]] for w in W])

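Construct `measured_values`, an array with one row per utterance (42 rows): the i-th row is a histogram, with one count per state bin, of the values people reported for the i-th utterance. (The array itself is built in the cell that follows the model definition below.)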

In [7]:
# Row t is the literal semantics for threshold index t: the first t states get
# weight epsilon and the remaining 20 - t states get weight 1.
_semantics_rows = []
for threshold in range(20):
    below = np.repeat(np.array([epsilon]), threshold)
    at_or_above = np.repeat(np.array([1]), 20 - threshold)
    _semantics_rows.append(np.concatenate([below, at_or_above]))
possible_literal_semantics = np.array(_semantics_rows)
@jax.jit
def state_prior(s):
    """Unnormalised prior weight of state index ``s``.

    Evaluates exp(-v**2 / 2) at v = state_values[s] and returns it as float32.
    The weights are not normalised to sum to 1.
    """
    value = state_values[s]
    weight = np.exp(-value**2/2)
    return np.float32(weight)
@jax.jit
def U_soc(soc):
    """Social utility of the utterance with index ``soc`` (pooled UK + US table)."""
    utilities = U_soc_array
    return utilities[soc]
@jax.jit
def is_costly(w):
    """Return 0 for a bare ('none') utterance and 1 for an intensified one.

    Utterances come in seven groups of six, and the first member of each group
    is the one without an intensifier.
    """
    one_predicate = np.array([0, 1, 1, 1, 1, 1])
    cost_flags = np.tile(one_predicate, 7)
    return cost_flags[w]
@jax.jit
def L(w, s,t0,t1,t2,t3,t4,t5):  # literal likelihood L(w | s)
    """Literal weight of utterance ``w`` in state ``s``.

    ``t0``..``t5`` are threshold indices into ``possible_literal_semantics``,
    one per intensifier in the order none, slightly, kind of, quite, very,
    extremely. The six selected rows are repeated for the seven predicates so
    that the table can be indexed directly by utterance value.
    """
    thresholds = (t0, t1, t2, t3, t4, t5)
    # "hard semantics": one row of state weights per intensifier
    per_intensifier = np.stack([possible_literal_semantics[t] for t in thresholds])
    per_utterance = np.tile(per_intensifier, (7, 1))
    return per_utterance[w, s]
# Pragmatic listener: distribution over states s after observing utterance w.
# The bracketed axes [s: S, w: W] index the returned table; compute_logloss
# multiplies it element-wise with measured_values.T.
# Parameters:
#   inf_term, soc_term, cost -- weights on the informational term, on the
#                               social utility U_soc(w) and on is_costly(w)
#   t0..t5                   -- threshold indices forwarded to L(), one per
#                               intensifier
@memo
def L1[s: S, w: W](inf_term, soc_term, cost,t0,t1,t2,t3,t4,t5):
    listener: thinks[
        speaker: given(s in S, wpp=state_prior(s)),
        speaker: chooses(w in W, wpp=
            # The speaker scores w by imagining a listener who knows w and
            # picks a state with weight L(w, s, ...).
            imagine[
                listener: knows(w),
                listener: chooses(s in S, wpp=L(w, s,t0,t1,t2,t3,t4,t5)) ,
                exp(inf_term * log(Pr[listener.s == s]) +  # informational term: log-prob that the imagined listener picks s
                soc_term * U_soc(w) - # social term: weighted social utility of w
                cost*is_costly(w)) # cost term: charged only for intensified utterances
            ]
        )
    ]
    # Condition on the observed utterance and read off the posterior over s.
    listener: observes[speaker.w] is w
    listener: chooses(s in S, wpp=Pr[speaker.s == s])
    return Pr[listener.s == s]

In [16]:
# Histogram of the UK dialogue z-scores for every utterance:
# measured_values[w][k] is the number of responses to utterance w whose z-score
# falls into state bin k, where bin k covers [-2.8 + 0.28*k, -2.8 + 0.28*(k+1)).
measured_values = []
for w in W:
    intensifier, predicate = U_soc_key_map[w]
    is_utterance = (UK_dialogue['intensifier'] == intensifier) & (UK_dialogue['predicate'] == predicate)
    # NaN z-scores (single-row or constant groups) cannot be binned; drop them.
    raw_values = UK_dialogue.loc[is_utterance, 'predicate Z-score'].dropna().values
    x = [0]*len(S)
    if intensifier != 'none': # make measured_values all zero if intensifier is none
        for r in raw_values:
            # Floor (not truncate toward zero) so negative z-scores land in the
            # bin whose lower edge is below them, matching the state_values grid.
            k = int(r // 0.28) + 10
            # Clip to the edge bins: an unclipped negative index would silently
            # wrap around to the top bins and a large one would raise IndexError.
            k = min(max(k, 0), len(S) - 1)
            x[k] += 1
    measured_values.append(x)
measured_values = np.array(measured_values)
def compute_logloss(*params):
    """Count-weighted log-likelihood of ``measured_values`` under ``L1``.

    Despite the name this is a log-likelihood (larger is better), not a loss.

    Parameters (positional): ``params[0:6]`` are the six threshold indices
    t0..t5, ``params[6]`` is the cost weight, ``params[7]`` the informational
    weight and ``params[8]`` the social weight.

    Returns the sum over all (state, utterance) cells of count * log P_l1.
    Cells with a zero count are skipped, so a zero or NaN probability in a cell
    that carries no data (e.g. the zeroed 'none' rows) cannot turn the total
    into NaN via 0 * log(p).
    """
    thetas = params[:6]
    cost = params[6]
    inf_term = params[7]
    soc_term = params[8]
    # compute fit e.g. log_likelihood
    P_l1 = L1(inf_term=inf_term, soc_term=soc_term, cost=cost, t0 = thetas[0],t1=thetas[1], t2= thetas[2], t3 = thetas[3], t4 = thetas[4], t5 = thetas[5]) # note this should be P(s|w)
    counts = measured_values.T
    # Replace unobserved cells by probability 1 (log = 0) before taking the log.
    observed_p = np.where(counts > 0, P_l1, 1.0)
    return np.sum(counts * np.log(observed_p))

[ 0 22 22 22 22 22  0 22 22 22 22 22  0 22 22 22 22 22  0 22 22 22 22 22
  0 22 22 22 22 22  0 22  0 22 22 22  0 22 22 22 22 22]


In [190]:
# Flattened Cartesian grid over the six parameters that are vmapped in the
# search loop (the first three thresholds are swept in plain Python loops).
# Every array has total_param_combos entries; fourth_thetas varies fastest and
# ninth_socs slowest.
total_param_combos = len(costs)*len(possible_soc_terms)*len(possible_inf_terms)*len(theta_to_test)**3
# With indexing='ij' and C-order ravel the last meshgrid argument varies
# fastest, so the arguments are listed from slowest to fastest.
(
    ninth_socs,
    eighth_infs,
    seventh_costs,
    sixth_thetas,
    fifth_thetas,
    fourth_thetas,
) = (
    axis.ravel()
    for axis in np.meshgrid(
        possible_soc_terms,
        possible_inf_terms,
        costs,
        theta_to_test,
        theta_to_test,
        theta_to_test,
        indexing='ij',
    )
)
assert fourth_thetas.shape == (total_param_combos,)

In [191]:
# Sweep the first three thresholds in Python and vmap over the six flattened
# grid arrays. Note that the loop variables t1, t2, t3 are passed as the first
# three positional arguments, i.e. they become thetas[0], thetas[1], thetas[2].
batched_logloss = jax.vmap(compute_logloss, in_axes=(None,)*3 + (0,)*6)
all_output = []
for t1 in theta_to_test:
    for t2 in theta_to_test:
        for t3 in theta_to_test:
            output = batched_logloss(
                t1, t2, t3,
                fourth_thetas, fifth_thetas, sixth_thetas,
                seventh_costs, eighth_infs, ninth_socs,
            ).block_until_ready()
            all_output.append(output)
            print(t1,t2,t3)

0 0 0
0 0 2
0 0 4
0 0 6
0 0 8
0 0 10
0 0 12
0 0 14
0 0 16
0 0 18
0 2 0
0 2 2
0 2 4
0 2 6
0 2 8
0 2 10
0 2 12
0 2 14
0 2 16
0 2 18
0 4 0
0 4 2
0 4 4
0 4 6
0 4 8
0 4 10
0 4 12
0 4 14
0 4 16
0 4 18
0 6 0
0 6 2
0 6 4
0 6 6
0 6 8
0 6 10
0 6 12
0 6 14
0 6 16
0 6 18
0 8 0
0 8 2
0 8 4
0 8 6
0 8 8
0 8 10
0 8 12
0 8 14
0 8 16
0 8 18
0 10 0
0 10 2
0 10 4
0 10 6
0 10 8
0 10 10
0 10 12
0 10 14
0 10 16
0 10 18
0 12 0
0 12 2
0 12 4
0 12 6
0 12 8
0 12 10
0 12 12
0 12 14
0 12 16
0 12 18
0 14 0
0 14 2
0 14 4
0 14 6
0 14 8
0 14 10
0 14 12
0 14 14
0 14 16
0 14 18
0 16 0
0 16 2
0 16 4
0 16 6
0 16 8
0 16 10
0 16 12
0 16 14
0 16 16
0 16 18
0 18 0
0 18 2
0 18 4
0 18 6
0 18 8
0 18 10
0 18 12
0 18 14
0 18 16
0 18 18
2 0 0
2 0 2
2 0 4
2 0 6
2 0 8
2 0 10
2 0 12
2 0 14
2 0 16
2 0 18
2 2 0
2 2 2
2 2 4
2 2 6
2 2 8
2 2 10
2 2 12
2 2 14
2 2 16
2 2 18
2 4 0
2 4 2
2 4 4
2 4 6
2 4 8
2 4 10
2 4 12
2 4 14
2 4 16
2 4 18
2 6 0
2 6 2
2 6 4
2 6 6
2 6 8
2 6 10
2 6 12
2 6 14
2 6 16
2 6 18
2 8 0
2 8 2
2 8 4
2 8 6
2 8 8
2 8 10
2 8

In [192]:
# Stack the per-(t1, t2, t3) score vectors, flatten them into one long vector,
# write it to disk and release the memory.
np.save('UK_flattened_output_big_eps.npy', np.array(all_output).flatten())
del all_output

Do the same for US

In [195]:
# Histogram of the US dialogue z-scores for every utterance:
# measured_values[w][k] is the number of responses to utterance w whose z-score
# falls into state bin k, where bin k covers [-2.8 + 0.28*k, -2.8 + 0.28*(k+1)).
measured_values = []
for w in W:
    intensifier, predicate = U_soc_key_map[w]
    is_utterance = (US_dialogue['intensifier'] == intensifier) & (US_dialogue['predicate'] == predicate)
    # NaN z-scores (single-row or constant groups) cannot be binned; drop them.
    raw_values = US_dialogue.loc[is_utterance, 'predicate Z-score'].dropna().values
    x = [0]*len(S)
    if intensifier != 'none': # make measured_values all zero if intensifier is none
        for r in raw_values:
            # Floor (not truncate toward zero) so negative z-scores land in the
            # bin whose lower edge is below them, matching the state_values grid.
            k = int(r // 0.28) + 10
            # Clip to the edge bins: an unclipped negative index would silently
            # wrap around to the top bins and a large one would raise IndexError.
            k = min(max(k, 0), len(S) - 1)
            x[k] += 1
    measured_values.append(x)
measured_values = np.array(measured_values)
def compute_logloss(*params):
    """Count-weighted log-likelihood of ``measured_values`` under ``L1``.

    Despite the name this is a log-likelihood (larger is better), not a loss.

    Parameters (positional): ``params[0:6]`` are the six threshold indices
    t0..t5, ``params[6]`` is the cost weight, ``params[7]`` the informational
    weight and ``params[8]`` the social weight.

    Returns the sum over all (state, utterance) cells of count * log P_l1.
    Cells with a zero count are skipped, so a zero or NaN probability in a cell
    that carries no data (e.g. the zeroed 'none' rows) cannot turn the total
    into NaN via 0 * log(p).
    """
    thetas = params[:6]
    cost = params[6]
    inf_term = params[7]
    soc_term = params[8]
    # compute fit e.g. log_likelihood
    P_l1 = L1(inf_term=inf_term, soc_term=soc_term, cost=cost, t0 = thetas[0],t1=thetas[1], t2= thetas[2], t3 = thetas[3], t4 = thetas[4], t5 = thetas[5]) # note this should be P(s|w)
    counts = measured_values.T
    # Replace unobserved cells by probability 1 (log = 0) before taking the log.
    observed_p = np.where(counts > 0, P_l1, 1.0)
    return np.sum(counts * np.log(observed_p))

In [196]:
# Sweep the first three thresholds in Python and vmap over the six flattened
# grid arrays, scoring against whatever measured_values compute_logloss
# currently closes over; then flatten the scores and write them to disk.
batched_logloss = jax.vmap(compute_logloss, in_axes=(None,)*3 + (0,)*6)
all_output = []
for t1 in theta_to_test:
    for t2 in theta_to_test:
        for t3 in theta_to_test:
            output = batched_logloss(
                t1, t2, t3,
                fourth_thetas, fifth_thetas, sixth_thetas,
                seventh_costs, eighth_infs, ninth_socs,
            ).block_until_ready()
            all_output.append(output)
            print(t1,t2,t3)
# save all_output to a file
np.save('US_flattened_output_big_eps.npy', np.array(all_output).flatten())
del all_output

0 0 0
0 0 2
0 0 4
0 0 6
0 0 8
0 0 10
0 0 12
0 0 14
0 0 16
0 0 18
0 2 0
0 2 2
0 2 4
0 2 6
0 2 8
0 2 10
0 2 12
0 2 14
0 2 16
0 2 18
0 4 0
0 4 2
0 4 4
0 4 6
0 4 8
0 4 10
0 4 12
0 4 14
0 4 16
0 4 18
0 6 0
0 6 2
0 6 4
0 6 6
0 6 8
0 6 10
0 6 12
0 6 14
0 6 16
0 6 18
0 8 0
0 8 2
0 8 4
0 8 6
0 8 8
0 8 10
0 8 12
0 8 14
0 8 16
0 8 18
0 10 0
0 10 2
0 10 4
0 10 6
0 10 8
0 10 10
0 10 12
0 10 14
0 10 16
0 10 18
0 12 0
0 12 2
0 12 4
0 12 6
0 12 8
0 12 10
0 12 12
0 12 14
0 12 16
0 12 18
0 14 0
0 14 2
0 14 4
0 14 6
0 14 8
0 14 10
0 14 12
0 14 14
0 14 16
0 14 18
0 16 0
0 16 2
0 16 4
0 16 6
0 16 8
0 16 10
0 16 12
0 16 14
0 16 16
0 16 18
0 18 0
0 18 2
0 18 4
0 18 6
0 18 8
0 18 10
0 18 12
0 18 14
0 18 16
0 18 18
2 0 0
2 0 2
2 0 4
2 0 6
2 0 8
2 0 10
2 0 12
2 0 14
2 0 16
2 0 18
2 2 0
2 2 2
2 2 4
2 2 6
2 2 8
2 2 10
2 2 12
2 2 14
2 2 16
2 2 18
2 4 0
2 4 2
2 4 4
2 4 6
2 4 8
2 4 10
2 4 12
2 4 14
2 4 16
2 4 18
2 6 0
2 6 2
2 6 4
2 6 6
2 6 8
2 6 10
2 6 12
2 6 14
2 6 16
2 6 18
2 8 0
2 8 2
2 8 4
2 8 6
2 8 8
2 8 10
2 8

In [197]:
# Smaller below-threshold weight for the "small_eps" runs; the next cell
# rebuilds possible_literal_semantics and the model functions with it.
epsilon = 0.01

In [198]:
# Rebuild the literal-semantics table with the current epsilon. Row t is the
# semantics for threshold index t: the first t states get weight epsilon and
# the remaining 20 - t states get weight 1.
_semantics_rows = []
for threshold in range(20):
    below = np.repeat(np.array([epsilon]), threshold)
    at_or_above = np.repeat(np.array([1]), 20 - threshold)
    _semantics_rows.append(np.concatenate([below, at_or_above]))
possible_literal_semantics = np.array(_semantics_rows)
@jax.jit
def state_prior(s):
    """Unnormalised prior weight of state index ``s``.

    Evaluates exp(-v**2 / 2) at v = state_values[s] and returns it as float32.
    The weights are not normalised to sum to 1.
    """
    value = state_values[s]
    weight = np.exp(-value**2/2)
    return np.float32(weight)
@jax.jit
def U_soc(soc):
    """Social utility of the utterance with index ``soc`` (pooled UK + US table)."""
    utilities = U_soc_array
    return utilities[soc]
@jax.jit
def is_costly(w):
    """Return 0 for a bare ('none') utterance and 1 for an intensified one.

    Utterances come in seven groups of six, and the first member of each group
    is the one without an intensifier.
    """
    one_predicate = np.array([0, 1, 1, 1, 1, 1])
    cost_flags = np.tile(one_predicate, 7)
    return cost_flags[w]
@jax.jit
def L(w, s,t0,t1,t2,t3,t4,t5):  # literal likelihood L(w | s)
    """Literal weight of utterance ``w`` in state ``s``.

    ``t0``..``t5`` are threshold indices into ``possible_literal_semantics``,
    one per intensifier in the order none, slightly, kind of, quite, very,
    extremely. The six selected rows are repeated for the seven predicates so
    that the table can be indexed directly by utterance value.
    """
    thresholds = (t0, t1, t2, t3, t4, t5)
    # "hard semantics": one row of state weights per intensifier
    per_intensifier = np.stack([possible_literal_semantics[t] for t in thresholds])
    per_utterance = np.tile(per_intensifier, (7, 1))
    return per_utterance[w, s]
# Pragmatic listener with a nested speaker model: distribution over states s
# after observing utterance w. The bracketed axes [s: S, w: W] index the
# returned table; compute_logloss multiplies it element-wise with
# measured_values.T.
# Parameters:
#   inf_term, soc_term, cost -- weights on the informational term, on the
#                               social utility U_soc(w) and on is_costly(w)
#   t0..t5                   -- threshold indices forwarded to L(), one per
#                               intensifier
@memo
def L1[s: S, w: W](inf_term, soc_term, cost,t0,t1,t2,t3,t4,t5):
    listener: thinks[
        speaker: given(s in S, wpp=state_prior(s)),
        # Innermost level: the speaker's model of a listener who assumes a
        # literal speaker choosing w with weight L(w, s, ...).
        speaker: thinks[
            listener: thinks[
                speaker: given(s in S, wpp=state_prior(s)),
                speaker: chooses(w in W, wpp=L(w, s,t0,t1,t2,t3,t4,t5))
            ]
        ],
        # The speaker picks w with weight exp(utility), where the utility is
        # evaluated by imagining that listener observing w.
        speaker: chooses(w in W, wpp=exp( imagine[
            listener: observes [speaker.w] is w,
            listener: knows(s),
            (
                inf_term * listener[ log(Pr[speaker.s == s]) ] +  # U_inf = listener's surprisal
                soc_term * U_soc(w) - # U_soc = listener's EU
                cost*is_costly(w)  # cost of utterance
            )
        ]))
    ]
    # Condition on the observed utterance and read off the posterior over s.
    listener: observes[speaker.w] is w
    listener: chooses(s in S, wpp=Pr[speaker.s == s])
    return Pr[listener.s == s]

In [199]:
# Histogram of the UK dialogue z-scores for every utterance:
# measured_values[w][k] is the number of responses to utterance w whose z-score
# falls into state bin k, where bin k covers [-2.8 + 0.28*k, -2.8 + 0.28*(k+1)).
measured_values = []
for w in W:
    intensifier, predicate = U_soc_key_map[w]
    is_utterance = (UK_dialogue['intensifier'] == intensifier) & (UK_dialogue['predicate'] == predicate)
    # NaN z-scores (single-row or constant groups) cannot be binned; drop them.
    raw_values = UK_dialogue.loc[is_utterance, 'predicate Z-score'].dropna().values
    x = [0]*len(S)
    if intensifier != 'none': # make measured_values all zero if intensifier is none
        for r in raw_values:
            # Floor (not truncate toward zero) so negative z-scores land in the
            # bin whose lower edge is below them, matching the state_values grid.
            k = int(r // 0.28) + 10
            # Clip to the edge bins: an unclipped negative index would silently
            # wrap around to the top bins and a large one would raise IndexError.
            k = min(max(k, 0), len(S) - 1)
            x[k] += 1
    measured_values.append(x)
measured_values = np.array(measured_values)
def compute_logloss(*params):
    """Count-weighted log-likelihood of ``measured_values`` under ``L1``.

    Despite the name this is a log-likelihood (larger is better), not a loss.

    Parameters (positional): ``params[0:6]`` are the six threshold indices
    t0..t5, ``params[6]`` is the cost weight, ``params[7]`` the informational
    weight and ``params[8]`` the social weight.

    Returns the sum over all (state, utterance) cells of count * log P_l1.
    Cells with a zero count are skipped, so a zero or NaN probability in a cell
    that carries no data (e.g. the zeroed 'none' rows) cannot turn the total
    into NaN via 0 * log(p).
    """
    thetas = params[:6]
    cost = params[6]
    inf_term = params[7]
    soc_term = params[8]
    # compute fit e.g. log_likelihood
    P_l1 = L1(inf_term=inf_term, soc_term=soc_term, cost=cost, t0 = thetas[0],t1=thetas[1], t2= thetas[2], t3 = thetas[3], t4 = thetas[4], t5 = thetas[5]) # note this should be P(s|w)
    counts = measured_values.T
    # Replace unobserved cells by probability 1 (log = 0) before taking the log.
    observed_p = np.where(counts > 0, P_l1, 1.0)
    return np.sum(counts * np.log(observed_p))
# Sweep the first three thresholds in Python and vmap over the six flattened
# grid arrays, scoring against whatever measured_values compute_logloss
# currently closes over; then flatten the scores and write them to disk.
batched_logloss = jax.vmap(compute_logloss, in_axes=(None,)*3 + (0,)*6)
all_output = []
for t1 in theta_to_test:
    for t2 in theta_to_test:
        for t3 in theta_to_test:
            output = batched_logloss(
                t1, t2, t3,
                fourth_thetas, fifth_thetas, sixth_thetas,
                seventh_costs, eighth_infs, ninth_socs,
            ).block_until_ready()
            all_output.append(output)
            print(t1,t2,t3)
# save all_output to a file
np.save('UK_flattened_output_small_eps.npy', np.array(all_output).flatten())
del all_output

In [201]:
# Histogram of the US dialogue z-scores for every utterance:
# measured_values[w][k] is the number of responses to utterance w whose z-score
# falls into state bin k, where bin k covers [-2.8 + 0.28*k, -2.8 + 0.28*(k+1)).
measured_values = []
for w in W:
    intensifier, predicate = U_soc_key_map[w]
    is_utterance = (US_dialogue['intensifier'] == intensifier) & (US_dialogue['predicate'] == predicate)
    # NaN z-scores (single-row or constant groups) cannot be binned; drop them.
    raw_values = US_dialogue.loc[is_utterance, 'predicate Z-score'].dropna().values
    x = [0]*len(S)
    if intensifier != 'none': # make measured_values all zero if intensifier is none
        for r in raw_values:
            # Floor (not truncate toward zero) so negative z-scores land in the
            # bin whose lower edge is below them, matching the state_values grid.
            k = int(r // 0.28) + 10
            # Clip to the edge bins: an unclipped negative index would silently
            # wrap around to the top bins and a large one would raise IndexError.
            k = min(max(k, 0), len(S) - 1)
            x[k] += 1
    measured_values.append(x)
measured_values = np.array(measured_values)
def compute_logloss(*params):
    """Count-weighted log-likelihood of ``measured_values`` under ``L1``.

    Despite the name this is a log-likelihood (larger is better), not a loss.

    Parameters (positional): ``params[0:6]`` are the six threshold indices
    t0..t5, ``params[6]`` is the cost weight, ``params[7]`` the informational
    weight and ``params[8]`` the social weight.

    Returns the sum over all (state, utterance) cells of count * log P_l1.
    Cells with a zero count are skipped, so a zero or NaN probability in a cell
    that carries no data (e.g. the zeroed 'none' rows) cannot turn the total
    into NaN via 0 * log(p).
    """
    thetas = params[:6]
    cost = params[6]
    inf_term = params[7]
    soc_term = params[8]
    # compute fit e.g. log_likelihood
    P_l1 = L1(inf_term=inf_term, soc_term=soc_term, cost=cost, t0 = thetas[0],t1=thetas[1], t2= thetas[2], t3 = thetas[3], t4 = thetas[4], t5 = thetas[5]) # note this should be P(s|w)
    counts = measured_values.T
    # Replace unobserved cells by probability 1 (log = 0) before taking the log.
    observed_p = np.where(counts > 0, P_l1, 1.0)
    return np.sum(counts * np.log(observed_p))
# Sweep the first three thresholds in Python and vmap over the six flattened
# grid arrays, scoring against whatever measured_values compute_logloss
# currently closes over; then flatten the scores and write them to disk.
batched_logloss = jax.vmap(compute_logloss, in_axes=(None,)*3 + (0,)*6)
all_output = []
for t1 in theta_to_test:
    for t2 in theta_to_test:
        for t3 in theta_to_test:
            output = batched_logloss(
                t1, t2, t3,
                fourth_thetas, fifth_thetas, sixth_thetas,
                seventh_costs, eighth_infs, ninth_socs,
            ).block_until_ready()
            all_output.append(output)
            print(t1,t2,t3)
# save all_output to a file
np.save('US_flattened_output_small_eps.npy', np.array(all_output).flatten())
del all_output

Letting U_soc be country specific

In [None]:
@jax.jit
def U_soc(soc):
    """Social utility of the utterance with index ``soc`` (UK-only table)."""
    utilities = UK_U_soc_array
    return utilities[soc]
# Pragmatic listener with a nested speaker model: distribution over states s
# after observing utterance w. The bracketed axes [s: S, w: W] index the
# returned table. The body matches the earlier nested-speaker L1.
# NOTE(review): presumably re-declared so that the compiled model picks up the
# U_soc defined just above -- confirm.
@memo
def L1[s: S, w: W](inf_term, soc_term, cost,t0,t1,t2,t3,t4,t5):
    listener: thinks[
        speaker: given(s in S, wpp=state_prior(s)),
        # Innermost level: the speaker's model of a listener who assumes a
        # literal speaker choosing w with weight L(w, s, ...).
        speaker: thinks[
            listener: thinks[
                speaker: given(s in S, wpp=state_prior(s)),
                speaker: chooses(w in W, wpp=L(w, s,t0,t1,t2,t3,t4,t5))
            ]
        ],
        # The speaker picks w with weight exp(utility), where the utility is
        # evaluated by imagining that listener observing w.
        speaker: chooses(w in W, wpp=exp( imagine[
            listener: observes [speaker.w] is w,
            listener: knows(s),
            (
                inf_term * listener[ log(Pr[speaker.s == s]) ] +  # U_inf = listener's surprisal
                soc_term * U_soc(w) - # U_soc = listener's EU
                cost*is_costly(w)  # cost of utterance
            )
        ]))
    ]
    # Condition on the observed utterance and read off the posterior over s.
    listener: observes[speaker.w] is w
    listener: chooses(s in S, wpp=Pr[speaker.s == s])
    return Pr[listener.s == s]

In [None]:
# Histogram of the UK dialogue z-scores for every utterance:
# measured_values[w][k] is the number of responses to utterance w whose z-score
# falls into state bin k, where bin k covers [-2.8 + 0.28*k, -2.8 + 0.28*(k+1)).
measured_values = []
for w in W:
    intensifier, predicate = U_soc_key_map[w]
    is_utterance = (UK_dialogue['intensifier'] == intensifier) & (UK_dialogue['predicate'] == predicate)
    # NaN z-scores (single-row or constant groups) cannot be binned; drop them.
    raw_values = UK_dialogue.loc[is_utterance, 'predicate Z-score'].dropna().values
    x = [0]*len(S)
    if intensifier != 'none': # make measured_values all zero if intensifier is none
        for r in raw_values:
            # Floor (not truncate toward zero) so negative z-scores land in the
            # bin whose lower edge is below them, matching the state_values grid.
            k = int(r // 0.28) + 10
            # Clip to the edge bins: an unclipped negative index would silently
            # wrap around to the top bins and a large one would raise IndexError.
            k = min(max(k, 0), len(S) - 1)
            x[k] += 1
    measured_values.append(x)
measured_values = np.array(measured_values)
def compute_logloss(*params):
    """Count-weighted log-likelihood of ``measured_values`` under ``L1``.

    Despite the name this is a log-likelihood (larger is better), not a loss.

    Parameters (positional): ``params[0:6]`` are the six threshold indices
    t0..t5, ``params[6]`` is the cost weight, ``params[7]`` the informational
    weight and ``params[8]`` the social weight.

    Returns the sum over all (state, utterance) cells of count * log P_l1.
    Cells with a zero count are skipped, so a zero or NaN probability in a cell
    that carries no data (e.g. the zeroed 'none' rows) cannot turn the total
    into NaN via 0 * log(p).
    """
    thetas = params[:6]
    cost = params[6]
    inf_term = params[7]
    soc_term = params[8]
    # compute fit e.g. log_likelihood
    P_l1 = L1(inf_term=inf_term, soc_term=soc_term, cost=cost, t0 = thetas[0],t1=thetas[1], t2= thetas[2], t3 = thetas[3], t4 = thetas[4], t5 = thetas[5]) # note this should be P(s|w)
    counts = measured_values.T
    # Replace unobserved cells by probability 1 (log = 0) before taking the log.
    observed_p = np.where(counts > 0, P_l1, 1.0)
    return np.sum(counts * np.log(observed_p))
# Sweep the first three thresholds in Python and vmap over the six flattened
# grid arrays, scoring against whatever measured_values compute_logloss
# currently closes over; then flatten the scores and write them to disk.
batched_logloss = jax.vmap(compute_logloss, in_axes=(None,)*3 + (0,)*6)
all_output = []
for t1 in theta_to_test:
    for t2 in theta_to_test:
        for t3 in theta_to_test:
            output = batched_logloss(
                t1, t2, t3,
                fourth_thetas, fifth_thetas, sixth_thetas,
                seventh_costs, eighth_infs, ninth_socs,
            ).block_until_ready()
            all_output.append(output)
            print(t1,t2,t3)
# save all_output to a file
np.save('specific_U_soc_UK_flattened_output_small_eps.npy', np.array(all_output).flatten())
del all_output

In [None]:
@jax.jit
def U_soc(soc):
    """Social utility of the utterance with index ``soc`` (US-only table)."""
    utilities = US_U_soc_array
    return utilities[soc]
# Pragmatic listener with a nested speaker model: distribution over states s
# after observing utterance w. The bracketed axes [s: S, w: W] index the
# returned table. The body matches the earlier nested-speaker L1.
# NOTE(review): presumably re-declared so that the compiled model picks up the
# U_soc defined just above -- confirm.
@memo
def L1[s: S, w: W](inf_term, soc_term, cost,t0,t1,t2,t3,t4,t5):
    listener: thinks[
        speaker: given(s in S, wpp=state_prior(s)),
        # Innermost level: the speaker's model of a listener who assumes a
        # literal speaker choosing w with weight L(w, s, ...).
        speaker: thinks[
            listener: thinks[
                speaker: given(s in S, wpp=state_prior(s)),
                speaker: chooses(w in W, wpp=L(w, s,t0,t1,t2,t3,t4,t5))
            ]
        ],
        # The speaker picks w with weight exp(utility), where the utility is
        # evaluated by imagining that listener observing w.
        speaker: chooses(w in W, wpp=exp( imagine[
            listener: observes [speaker.w] is w,
            listener: knows(s),
            (
                inf_term * listener[ log(Pr[speaker.s == s]) ] +  # U_inf = listener's surprisal
                soc_term * U_soc(w) - # U_soc = listener's EU
                cost*is_costly(w)  # cost of utterance
            )
        ]))
    ]
    # Condition on the observed utterance and read off the posterior over s.
    listener: observes[speaker.w] is w
    listener: chooses(s in S, wpp=Pr[speaker.s == s])
    return Pr[listener.s == s]

In [None]:
# Histogram of the US dialogue z-scores for every utterance:
# measured_values[w][k] is the number of responses to utterance w whose z-score
# falls into state bin k, where bin k covers [-2.8 + 0.28*k, -2.8 + 0.28*(k+1)).
measured_values = []
for w in W:
    intensifier, predicate = U_soc_key_map[w]
    is_utterance = (US_dialogue['intensifier'] == intensifier) & (US_dialogue['predicate'] == predicate)
    # NaN z-scores (single-row or constant groups) cannot be binned; drop them.
    raw_values = US_dialogue.loc[is_utterance, 'predicate Z-score'].dropna().values
    x = [0]*len(S)
    if intensifier != 'none': # make measured_values all zero if intensifier is none
        for r in raw_values:
            # Floor (not truncate toward zero) so negative z-scores land in the
            # bin whose lower edge is below them, matching the state_values grid.
            k = int(r // 0.28) + 10
            # Clip to the edge bins: an unclipped negative index would silently
            # wrap around to the top bins and a large one would raise IndexError.
            k = min(max(k, 0), len(S) - 1)
            x[k] += 1
    measured_values.append(x)
measured_values = np.array(measured_values)
def compute_logloss(*params):
    """Count-weighted log-likelihood of ``measured_values`` under ``L1``.

    Despite the name this is a log-likelihood (larger is better), not a loss.

    Parameters (positional): ``params[0:6]`` are the six threshold indices
    t0..t5, ``params[6]`` is the cost weight, ``params[7]`` the informational
    weight and ``params[8]`` the social weight.

    Returns the sum over all (state, utterance) cells of count * log P_l1.
    Cells with a zero count are skipped, so a zero or NaN probability in a cell
    that carries no data (e.g. the zeroed 'none' rows) cannot turn the total
    into NaN via 0 * log(p).
    """
    thetas = params[:6]
    cost = params[6]
    inf_term = params[7]
    soc_term = params[8]
    # compute fit e.g. log_likelihood
    P_l1 = L1(inf_term=inf_term, soc_term=soc_term, cost=cost, t0 = thetas[0],t1=thetas[1], t2= thetas[2], t3 = thetas[3], t4 = thetas[4], t5 = thetas[5]) # note this should be P(s|w)
    counts = measured_values.T
    # Replace unobserved cells by probability 1 (log = 0) before taking the log.
    observed_p = np.where(counts > 0, P_l1, 1.0)
    return np.sum(counts * np.log(observed_p))
# Sweep the first three thresholds in Python and vmap over the six flattened
# grid arrays, scoring against whatever measured_values compute_logloss
# currently closes over; then flatten the scores and write them to disk.
batched_logloss = jax.vmap(compute_logloss, in_axes=(None,)*3 + (0,)*6)
all_output = []
for t1 in theta_to_test:
    for t2 in theta_to_test:
        for t3 in theta_to_test:
            output = batched_logloss(
                t1, t2, t3,
                fourth_thetas, fifth_thetas, sixth_thetas,
                seventh_costs, eighth_infs, ninth_socs,
            ).block_until_ready()
            all_output.append(output)
            print(t1,t2,t3)
# save all_output to a file
np.save('specific_U_soc_US_flattened_output_small_eps.npy', np.array(all_output).flatten())
del all_output