In [None]:
import numpy as np
import matplotlib.pyplot as plt
from operator import add

# Print arrays with 3 decimals and without scientific notation so the
# probability tables printed below stay readable.
np.set_printoptions(precision=3, suppress=True)

In [None]:
class RSAModel():
    """Vanilla Rational Speech Act (RSA) model with one level of recursion.

    Computes the literal listener L0(o|m), the pragmatic speaker S1(m|o) and
    the pragmatic listener L1(o|m) from a truth table of literal meanings.

    Args:
        objects: Labels of the objects/states o (n_obj of them).
        messages: Labels of the messages m (n_mes of them).
        truth_table: Array-like of shape (n_mes, n_obj); entry [m, o] is the
            literal meaning Lit(m, o) (1 if m is true of o, 0 otherwise).
        alpha: Speaker optimality; multiplies the utility inside exp() in S1.
        prior_o: Prior P(o) over objects, one value per object.
        cost_function: Zero-argument callable returning the costs C(m), one
            value per message.

    Raises:
        ValueError: If ``truth_table`` is not n_mes x n_obj, or ``prior_o``
            is not a flat sequence with one entry per object.
    """
    def __init__(self, objects, messages, truth_table, alpha, prior_o, cost_function):
        if np.shape(truth_table) != (len(messages), len(objects)):
            raise ValueError(
                "Truth matrix must be m x n, where m is number of messages "
                "and n is the number of objects"
            )
        # Fail early instead of with an obscure broadcasting error in L0/L1.
        if np.shape(prior_o) != (len(objects),):
            raise ValueError("Length of prior_o must equal the number of objects")
        self.objects = objects
        self.messages = messages
        self.truth_table = truth_table # Literal meaning L(m,o)
        self.alpha = alpha
        self.prior_o = np.array(prior_o)
        self.cost_function = cost_function

        self.n_obj = len(self.objects)
        self.n_mes = len(self.messages)

    def normalize(self, arr, axis=1):
        """Return a copy of 2-D ``arr`` scaled so each row (axis=1) or
        column (axis=0) sums to ~1.

        A small epsilon is added to every sum so all-zero rows/columns stay
        zero instead of dividing by zero.

        Raises:
            ValueError: If ``axis`` is neither 0 nor 1.
        """
        if axis not in (0, 1):
            raise ValueError("Axis must be 0 or 1 for normalization.")
        epsilon = 1e-10 # To prevent division by zero
        return arr / (arr.sum(axis=axis, keepdims=True) + epsilon)

    # Literal Listener L0: P_L0(o|m) ∝ P(o) * Lit(m,o)
    def L0(self):
        """Return P_L0(o|m) as an (n_mes, n_obj) array; each row sums to ~1."""
        # Broadcasting P(o) across the message rows replaces an explicit tile.
        unnorm_L0 = np.asarray(self.truth_table) * self.prior_o[np.newaxis, :]
        return self.normalize(unnorm_L0, axis=1)

    # Pragmatic Speaker S1: P_S1(m|o) ∝ exp(alpha * (log P_L0(o|m) - C(m)))
    def S1(self):
        """Return P_S1(m|o) as an (n_obj, n_mes) array; each row sums to ~1.

        Utility is U(m;o) = log P_L0(o|m) - C(m), and the speaker
        soft-maximizes alpha * U over messages for each object.
        """
        epsilon = 1e-10  # keeps log() finite where L0 is exactly 0

        log_L0_o_given_m = np.log(self.L0() + epsilon) # (n_mes, n_obj)
        message_costs = np.array(self.cost_function())[:, np.newaxis] # (n_mes, 1)
        utility = log_L0_o_given_m - message_costs # (n_mes, n_obj)

        scores = self.alpha * utility.T # (n_obj, n_mes)
        # Stable softmax: the per-row shift cancels in the normalization but
        # prevents exp() from overflowing (e.g. negative costs) or from
        # underflowing a whole row to ~0 (e.g. large alpha), which would
        # otherwise leave that row un-normalized.
        scores -= scores.max(axis=1, keepdims=True)
        return self.normalize(np.exp(scores), axis=1)

    # Pragmatic Listener L1: P_L1(o|m) ∝ P_S1(m|o) * P(o)
    def L1(self):
        """Return P_L1(o|m) as an (n_mes, n_obj) array; each row sums to ~1."""
        S1_m_given_o = self.S1() # (n_obj, n_mes)
        # Element (o, m) of the product is P_S1(m|o) P(o); transpose so each
        # row is a message, then normalize over objects.
        numerator_L1 = (S1_m_given_o * self.prior_o[:, np.newaxis]).T # (n_mes, n_obj)
        return self.normalize(numerator_L1, axis=1)

    def plot(self, filename=None, title_suffix=''):
        """Draw one bar chart of L1(object | message) per message.

        Args:
            filename: Optional path; when given, the figure is also saved.
            title_suffix: Text appended to every subplot title.

        Raises:
            TypeError: If ``filename`` is given but is not a string.
        """
        if filename is not None and not isinstance(filename, str):
            raise TypeError("If a filename is provided, it must be a string")

        L1_probs = self.L1() # P_L1(o|m), shape (n_mes, n_obj)
        num_messages = self.n_mes

        # squeeze=False keeps axes 2-D even when there is a single message.
        fig, axes = plt.subplots(num_messages, 1, figsize=(8, num_messages * 3.5), squeeze=False)

        for i, message in enumerate(self.messages):
            ax = axes[i, 0]
            ax.bar(self.objects, L1_probs[i, :])
            ax.set_title(f'L1 P(object | message "{message}") {title_suffix}')
            ax.set_xlabel('Objects/States')
            ax.set_ylabel('Probability')
            ax.set_ylim([0, 1])
            ax.tick_params(axis='x', rotation=45) # Rotate x-axis labels for better readability

        plt.tight_layout()
        if filename:
            plt.savefig(filename, bbox_inches='tight')
        plt.show()

In [None]:
class RSAExtended(RSAModel):
    """RSA model whose pragmatic speaker also weights messages by a prior P(m).

    Speaker: P_S1(m|o) ∝ exp(alpha * (log P_L0(o|m) - C(m))) * P(m)

    Takes the same arguments as RSAModel, plus:
        prior_m: Non-negative weights for P(m), one per message. They are
            normalized to sum to 1 (uniform if they are all zero).

    Raises:
        ValueError: As RSAModel, or if ``prior_m`` has the wrong length or
            contains negative values.
    """
    def __init__(self, objects, messages, truth_table, alpha, prior_o, prior_m, cost_function):
        super().__init__(objects, messages, truth_table, alpha, prior_o, cost_function)
        # The setter validates prior_m and stores both the raw values
        # (_raw_prior_m) and the normalized distribution (_prior_m).
        self.prior_m = prior_m

    # Overrides S1 to include P(m)
    def S1(self):
        """Return P_S1(m|o) as an (n_obj, n_mes) array; each row sums to ~1."""
        epsilon = 1e-10  # keeps log() finite where L0 is exactly 0

        log_L0_o_given_m = np.log(self.L0() + epsilon) # (n_mes, n_obj)
        message_costs = np.array(self.cost_function())[:, np.newaxis] # (n_mes, 1)
        utility_with_cost = log_L0_o_given_m - message_costs # (n_mes, n_obj)

        # Work in log space: log S1(m|o) = alpha * U(m;o) + log P(m) + const.
        # Messages with P(m) == 0 get -inf and hence probability exactly 0.
        with np.errstate(divide='ignore'):
            log_prior_m = np.log(self._prior_m) # (n_mes,)
        scores = self.alpha * utility_with_cost.T + log_prior_m # (n_obj, n_mes)
        # Stable softmax: the per-row shift cancels in the normalization but
        # avoids overflow in exp() and whole rows underflowing to ~0. The row
        # max is finite because at least one P(m) is positive.
        scores -= scores.max(axis=1, keepdims=True)
        return self.normalize(np.exp(scores), axis=1)

    @property
    def prior_m(self):
        """Normalized message prior P(m), shape (n_mes,)."""
        return self._prior_m

    @prior_m.setter
    def prior_m(self, prior_m_values):
        if len(prior_m_values) != self.n_mes:
            raise ValueError('Length of prior_m must equal the number of messages')
        raw = np.array(prior_m_values)
        # Negative weights do not define a distribution (and would yield NaN
        # in the log-space speaker).
        if np.any(raw < 0):
            raise ValueError('prior_m values must be non-negative')
        self._raw_prior_m = raw
        sum_priors = np.sum(raw)
        if sum_priors == 0: # Avoid division by zero if all priors are zero
            self._prior_m = np.ones(self.n_mes) / self.n_mes # Uniform if sum is zero
        else:
            self._prior_m = raw / sum_priors

# Scenario: Interpreting "hope-wh" Utterances

This section models a scenario where a speaker expresses a preference about an uncertain situation 'S'.
- **Objects/States (o):** The speaker's internal state of desire.
  - `o1 (desire_positive_S)`: Speaker desires a positive outcome for S.
  - `o2 (uncertain_about_S)`: Speaker is uncertain/information-seeking about S.
- **Messages (m):**
  - `m1 ("hope that S_good")`: "I hope that S turns out good." (Standard, explicitly positive preference)
  - `m2 ("wonder what S")`: "I wonder what S will be." (Standard, information-seeking)
  - `m3 ("hope what S")`: "I hope what S will be." (Target marked/L2 utterance)

The listener (L1) infers the speaker's state (o) given their utterance (m).

In [None]:
# States o the listener reasons about, and the messages m a speaker can use.
objects_s = ['desire_positive_S', 'uncertain_about_S']
messages_s = ['hope_that_S_good', 'wonder_what_S', 'hope_what_S']

# Literal meanings: truth_table_s[m, o] == 1 iff message m is compatible with
# state o. Each message is compatible with exactly one state, listed here in
# message order. "hope_what_S" is assumed, if uttered, to express a desire for
# a positive outcome (a semantic/pragmatic assumption about its intended use).
truth_table_s = np.array([
    [int(state == compatible_state) for state in objects_s]
    for compatible_state in (
        'desire_positive_S',  # hope_that_S_good
        'uncertain_about_S',  # wonder_what_S
        'desire_positive_S',  # hope_what_S
    )
])

# P(o): uniform prior over the speaker's underlying state.
prior_o_s = np.full(len(objects_s), 1 / len(objects_s))
# Speaker optimality: larger values make S1 choose high-utility messages more deterministically.
alpha_s = 3.0

## Native Speaker Model

Assumptions for the native speaker model:
- The marked utterance `m3 ("hope what S")` has a higher cost (reflecting its lower grammaticality, or native speakers' dispreference for it).
- The prior probability of a native speaker uttering `m3` is very low.

In [None]:
def cost_function_native_s():
    """Return message costs C(m) for the native-speaker model.

    Order follows ["hope_that_S_good", "wonder_what_S", "hope_what_S"]; only
    the marked utterance "hope_what_S" is penalized.
    """
    costs = np.zeros(3)
    costs[-1] = 2.0  # marked m3 is costly for a native speaker
    return costs

# P(m) for the native speaker: the marked utterance m3 is rarely produced.
prior_m_native_s = np.array([0.49, 0.49, 0.02])

RSA_native_s_model = RSAExtended(
    objects_s, messages_s, truth_table_s, alpha_s,
    prior_o_s, prior_m_native_s, cost_function_native_s,
)

# Report every layer of the model: L0 and L1 rows are messages, S1 rows are objects.
print("--- Native Speaker Model ---")
print("L0(object | message):")
print(f"{objects_s}")
L0_native = RSA_native_s_model.L0()
for msg, row in zip(messages_s, L0_native):
    print(f"{msg}: {row}")

print("\nS1(message | object):")
print(f"          {messages_s}")
S1_native = RSA_native_s_model.S1()
for obj, row in zip(objects_s, S1_native):
    print(f"{obj}: {row}")

print("\nL1(object | message):")
print(f"          {objects_s}")
L1_native_s = RSA_native_s_model.L1()
for msg, row in zip(messages_s, L1_native_s):
    print(f"{msg}: {row}")

RSA_native_s_model.plot(filename='RSA_hope_wh_native.png', title_suffix='(Native Model)')

## Non-Native Speaker Model (Listener modeling an L2 Speaker)

Assumptions for the non-native speaker model (or a listener modeling an L2 speaker):
- The marked utterance `m3 ("hope what S")` may have a lower perceived cost or penalty compared to the native model.
- The prior probability of the L2 speaker uttering `m3` might be higher than for a native speaker (e.g., due to L1 transfer, overgeneralization of L2 rules, or less sensitivity to markedness).

In [None]:
def cost_function_nonnative_s():
    """Return message costs C(m) for the non-native (L2) speaker model.

    Order follows ["hope_that_S_good", "wonder_what_S", "hope_what_S"]; the
    marked utterance "hope_what_S" carries only a small penalty.
    """
    costs = np.zeros(3)
    costs[-1] = 0.5  # m3 is cheaper for an L2 speaker than for a native one
    return costs

# P(m) for the L2 speaker model: m3 is produced far more often than natively.
prior_m_nonnative_s = np.array([0.4, 0.4, 0.2])

RSA_nonnative_s_model = RSAExtended(
    objects_s, messages_s, truth_table_s, alpha_s,
    prior_o_s, prior_m_nonnative_s, cost_function_nonnative_s,
)

# Report every layer of the model: L0 and L1 rows are messages, S1 rows are objects.
print("--- Non-Native Speaker Model ---")
print("L0(object | message):")
print(f"{objects_s}")
L0_nonnative = RSA_nonnative_s_model.L0()
for msg, row in zip(messages_s, L0_nonnative):
    print(f"{msg}: {row}")

print("\nS1(message | object):")
print(f"          {messages_s}")
S1_nonnative = RSA_nonnative_s_model.S1()
for obj, row in zip(objects_s, S1_nonnative):
    print(f"{obj}: {row}")

print("\nL1(object | message):")
print(f"          {objects_s}")
L1_nonnative_s = RSA_nonnative_s_model.L1()
for msg, row in zip(messages_s, L1_nonnative_s):
    print(f"{msg}: {row}")

RSA_nonnative_s_model.plot(filename='RSA_hope_wh_nonnative.png', title_suffix='(Non-Native Model)')

## Discussion of Results

By comparing the `L1(object | message)` probabilities, particularly for `m3 ("hope what S")`, between the native and non-native speaker models, we can observe how the listener's interpretation shifts based on their model of the speaker.

Specifically, we are interested in `L1('desire_positive_S' | "hope what S")`.
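- In the **Native Model**, one might expect this probability to be low if the high cost and low prior of `m3` make it an unlikely choice, even if the speaker desires a positive outcome. The listener might be confused or assign low probability to any specific state given such a marked utterance.
- In the **Non-Native Model**, if `m3` is less costly and has a higher prior, the listener might be more willing to infer `o1 ('desire_positive_S')` from it. This would happen if `m3`, despite being marked, becomes a more viable signal for conveying this state under the listener's assumptions about an L2 speaker's linguistic system.

The plots visualize these L1 listener probabilities. If `L1('desire_positive_S' | "hope what S")` is higher in the Non-Native Model than in the Native Model, it would suggest that modeling the speaker as non-native makes the marked utterance a more informative signal of their desire for a positive outcome.

**Caveat for this particular setup:** the truth table makes `m3` literally true only of `o1`, so `L0('uncertain_about_S' | m3)` is essentially 0 and a speaker in state `o2` almost never chooses `m3`. In addition, the cost `C(m3)` and prior `P(m3)` are the same for both states, so they largely cancel when L1 normalizes over states. As a result, `L1('desire_positive_S' | "hope what S")` comes out close to 1 in *both* models. The cost and prior manipulation shows up mainly in the speaker layer: `S1("hope what S" | 'desire_positive_S')` is roughly 1e-4 in the Native Model versus roughly 0.1 in the Non-Native Model. For the listener's interpretation of `m3` itself to differ between the models, `m3` would need non-zero literal compatibility with `o2`, or costs/priors that depend on the state.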

**Note:**
* The parameters used (costs, priors, alpha) are illustrative and would ideally be informed by empirical data or more detailed linguistic theory.
* The specific implementation of the RSA model (especially the utility function U(m;o) and S1 definition) can vary. This notebook uses a common approach.
* The definition of the `truth_table`, particularly for the marked utterance `m3`, inherently includes pragmatic assumptions about its intended meaning if used.