# Te Ataarangi Lesson

This notebook attempts to implement a lesson in Te Ataarangi using the Rational Speech Act framework.

In [24]:
import numpy as np
import matplotlib.pyplot as plt

In [25]:
def normalize(x):
    return x / np.sum(x, axis=0)

def safe_log(x, eps=1e-10):
    clipped_x = np.clip(x, eps, None)
    return np.where(x > 0, np.log(clipped_x), -np.inf)

def matrices_are_similar(M, N, tolerance=0.1):
    """
    Check if two matrices are similar within a certain tolerance.
    """
    return np.allclose(M, N, atol=tolerance)

In [26]:
class RationalSpeechAgent:
    def __init__(self, world_states, utterances, literal_listener_matrix, prior=None):
        self.world_states = world_states
        self.utterances = utterances
        self.literal_listener_matrix = normalize(np.array(literal_listener_matrix))
        self.prior = prior if prior is not None else np.ones(len(world_states)) / len(world_states)

    def literal_listener(self, utterance_index):
        return self.literal_listener_matrix[:, utterance_index]

    def pragmatic_speaker(self, world_state_index, alpha=1.0):
        utilities = np.array([safe_log(alpha * probability) for probability in self.literal_listener_matrix[world_state_index, :]])
        return normalize(np.exp(utilities))

    def pragmatic_listener(self, utterance_index):
        speaker_matrix = np.array([self.pragmatic_speaker(ws) for ws in range(len(self.world_states))])
        return normalize(np.dot(speaker_matrix.T, self.prior)[utterance_index])


class TeacherAgent(RationalSpeechAgent):
    """
    The main difference between the teacher and the student is that a teacher has fixed beliefs, 
    and also tracks the beliefs of the student
    """
    def __init__(self, world_states, utterances, literal_listener_matrix, student_model_matrix, prior=None):
        super().__init__(world_states, utterances, literal_listener_matrix, prior)
        self.student_model_matrix = normalize(np.array(student_model_matrix))  # Teacher's belief about the student's knowledge

    def update_student_model(self, world_state_index, student_utterance_index):
        # Set the probability for the observed utterance and world state to 1
        self.student_model_matrix[world_state_index, :] = 0
        self.student_model_matrix[world_state_index, student_utterance_index] = 1

        # Renormalize every row
        for utterance_index in range(len(self.utterances)):
            nonzeros = 1 * (self.student_model_matrix[:, utterance_index] > 0)
            self.student_model_matrix[:, utterance_index] = nonzeros / np.sum(nonzeros)

    def suggest_world_state(self):
        # Calculate entropy only for rows with all non-zero values
        entropy = lambda p: -np.sum([pi * np.log(pi) for pi in p if pi > 0])
        valid_indices = [index for index, row in enumerate(self.student_model_matrix) if not np.any(row == 0)]

        # If no valid rows exist, pick a state at random
        if not valid_indices:
            return np.random.randint(0, len(self.world_states))
        
        # Calculate entropy for each valid world state
        student_entropy_values = np.array([entropy(self.student_model_matrix[state_index, :]) if state_index in valid_indices else -np.inf for state_index in range(len(self.world_states))])
        teacher_entropy_values = np.array([entropy(self.literal_listener_matrix[state_index, :]) if state_index in valid_indices else -np.inf for state_index in range(len(self.world_states))])

        # Choose the world state with the highest entropy among the valid ones
        return np.argmin(teacher_entropy_values * student_entropy_values)


class StudentAgent(RationalSpeechAgent):
    def __init__(self, world_states, utterances, literal_listener_matrix, prior=None):
        super().__init__(world_states, utterances, literal_listener_matrix, prior)
        self.known_utterances = set()
        self.attempt_history = set()

    def update_belief(self, world_state_index, observed_utterance_index):
        # Update the probability for the observed utterance and world state
        self.literal_listener_matrix[world_state_index, :] = 0
        self.literal_listener_matrix[world_state_index, observed_utterance_index] = 1

        # Efficiently normalize the matrix
        self.normalize_matrix()

        # Update known utterances based on the updated beliefs
        self.update_known_utterances()

    def normalize_matrix(self):
        for utterance_index in range(len(self.utterances)):
            column_sum = np.sum(self.literal_listener_matrix[:, utterance_index])
            if column_sum > 0:
                self.literal_listener_matrix[:, utterance_index] /= column_sum

    def update_known_utterances(self):
        # Update known utterances by checking columns with unique non-zero entries
        for utterance_index in range(len(self.utterances)):
            if np.any(np.sum(self.literal_listener_matrix[:, utterance_index] != 0) == 1):
                self.known_utterances.add(utterance_index)

    def suggest_world_state(self):
        # Entropy-based selection with an additional constraint to avoid states already correctly answered
        entropy = lambda p: -np.sum([pi * np.log(pi) for pi in p if pi > 0])

        valid_indices = [index for index, row in enumerate(self.literal_listener_matrix)
                         if index not in self.attempt_history or not self.attempt_history[index]
                         and not np.any(row == 0) 
                         and all(utterance_index in self.known_utterances for utterance_index, val in enumerate(row) if val > 0)]

        if not valid_indices:
            return np.random.randint(0, len(self.world_states))

        entropy_values = [entropy(self.literal_listener_matrix[state_index, :]) if state_index in valid_indices else -np.inf 
                          for state_index in range(len(self.world_states))]

        return np.argmax(entropy_values)

world_states = ['1 rākau', '2 rākau', '3 rākau', '4 rākau', '5 rākau']
utterances = ['Te rākau', 'Ngā rākau', 'He rākau']

# Instantiate teacher and student
teacher = TeacherAgent(
    world_states=world_states,
    utterances=utterances,
    literal_listener_matrix=[
        [1.0, 0.0, 0.0],  # 1 rākau
        [0.0, 0.0, 1.0],  # 2 rākau
        [0.0, 0.0, 1.0],  # 3 rākau
        [0.0, 0.0, 1.0],  # 4 rākau
        [0.0, 1.0, 0.0],  # 5 rākau
    ],
    student_model_matrix=np.full((len(world_states), len(utterances)), 1 / len(world_states))  # Teacher's initial model of the student's knowledge
)

student = StudentAgent(
    world_states=world_states,
    utterances=utterances,
    literal_listener_matrix=np.full((len(world_states), len(utterances)), 1 / len(world_states))  # Student starts with no specific knowledge
)

In [27]:
print(f'\nteacher.literal_listener_matrix:\n{teacher.literal_listener_matrix}\n')

interaction_count = 0
interaction_limit = 500
while not matrices_are_similar(teacher.student_model_matrix, teacher.literal_listener_matrix, tolerance=0.01):
    interaction_count += 1
    print(f"\nInteraction {interaction_count}:")
    
    # Teacher determines which world state to demonstrate based on where the student needs most guidance
    world_state_index = teacher.suggest_world_state()
    teacher_utterance_index = np.argmax(teacher.pragmatic_speaker(world_state_index))
    teacher_utterance = teacher.utterances[teacher_utterance_index]

    # Teacher demonstrates
    teacher.update_student_model(world_state_index, teacher_utterance_index)
    print(f"Teacher: For '{teacher.world_states[world_state_index]}', the best utterance is '{teacher_utterance}'.")

    # Student observes and updates its belief
    student.update_belief(world_state_index, teacher_utterance_index)
    
    # Student's turn to conjecture
    student_world_state_index = student.suggest_world_state()
    student_utterance_index = np.argmax(student.pragmatic_speaker(student_world_state_index))
    student_utterance = student.utterances[student_utterance_index]

    # Teacher updates its model of the student
    teacher.update_student_model(student_world_state_index, student_utterance_index)

    # Feedback (optional but useful for observation and for student's adjustment in an expanded model)
    correct_utterance_index = np.argmax(teacher.pragmatic_speaker(student_world_state_index))
    student.update_belief(student_world_state_index, correct_utterance_index)
    teacher.update_student_model(student_world_state_index, correct_utterance_index)
    if student_utterance_index == correct_utterance_index:
        print(f"Student: For '{student.world_states[student_world_state_index]}', I believe the correct utterance is '{student_utterance}'. Correct!")
    else:
        correct_utterance = teacher.utterances[correct_utterance_index]
        print(f"Student: For '{student.world_states[student_world_state_index]}', I believe the correct utterance is '{student_utterance}'. Incorrect. The correct utterance should be '{correct_utterance}'.")

    if interaction_count > interaction_limit:  # Safety break to avoid infinite loops in case of convergence issues
        print("Interaction limit reached without convergence.")
        break

    # print(f'teacher.student_model_matrix:\n{teacher.student_model_matrix}\n')
    print(f'student.literal_listener_matrix:\n{teacher.student_model_matrix}\n')

if interaction_count <= interaction_limit:
    print("\nThe student's understanding is now aligned with the teacher's knowledge.")


teacher.literal_listener_matrix:
[[1.         0.         0.        ]
 [0.         0.         0.33333333]
 [0.         0.         0.33333333]
 [0.         0.         0.33333333]
 [0.         1.         0.        ]]


Interaction 1:
Teacher: For '1 rākau', the best utterance is 'Te rākau'.
Student: For '2 rākau', I believe the correct utterance is 'Ngā rākau'. Incorrect. The correct utterance should be 'He rākau'.
student.literal_listener_matrix:
[[0.25       0.         0.        ]
 [0.         0.         0.25      ]
 [0.25       0.33333333 0.25      ]
 [0.25       0.33333333 0.25      ]
 [0.25       0.33333333 0.25      ]]


Interaction 2:
Teacher: For '5 rākau', the best utterance is 'Ngā rākau'.
Student: For '3 rākau', I believe the correct utterance is 'Ngā rākau'. Incorrect. The correct utterance should be 'He rākau'.
student.literal_listener_matrix:
[[0.5        0.         0.        ]
 [0.         0.         0.33333333]
 [0.         0.         0.33333333]
 [0.5        0.5        0