In [1]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
from loguru import logger
sys.path.append(os.path.abspath('..'))
from env.env_gym import GuestEnv

In [2]:
def run_env(env, actions, steps=100, seed=42):
    """Replay a fixed action sequence in ``env`` and record phonemes and rewards.

    The environment is reset with ``seed`` and then stepped at most ``steps``
    times. Step ``i`` uses ``actions[i]``; once the sequence is exhausted the
    last action is repeated. The rollout stops early as soon as the
    environment reports ``terminated`` or ``truncated``.

    Args:
        env: Environment whose ``reset(seed=...)`` returns ``(obs, info)`` and
            whose ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)`` with a
            ``'phoneme'`` entry in ``info``.
        actions: Indexable sequence of actions; must not be empty when at
            least one step is executed.
        steps: Maximum number of steps to run.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        tuple: ``(phonemes, rewards)`` as NumPy arrays with one entry per
        executed step.

    Raises:
        IndexError: If ``actions`` is an empty list and a step is attempted.
    """
    obs, info = env.reset(seed=seed)
    phonemes, rewards = [], []
    for i in range(steps):
        # Past the end of the script, keep repeating the final action.
        act = actions[i] if i < len(actions) else actions[-1]
        obs, reward, terminated, truncated, info = env.step(act)
        # Snapshot the counts: the env may hand back a buffer it mutates in
        # place. np.array(copy=True) also accepts lists/tuples, matching the
        # snapshot used by run_env_lowest_phoneme.
        phonemes.append(np.array(info['phoneme'], copy=True))
        rewards.append(reward)
        if terminated or truncated:
            break
    return np.array(phonemes), np.array(rewards)

def run_env_lowest_phoneme(env, action_kind, steps=100, seed=42):
    """Roll out an episode that always targets the lowest phoneme count.

    At every step the current phoneme counts are read from ``env.phonemes``,
    falling back to ``info['phoneme']`` when that attribute is missing or
    ``None``. The index of the smallest count becomes the target, and the
    action ``action_kind + target`` is sent to the environment. The rollout
    stops after ``steps`` steps or as soon as the environment reports
    ``terminated`` or ``truncated``.

    Args:
        env: Environment whose ``reset(seed=...)`` returns ``(obs, info)`` and
            whose ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)`` with a
            ``'phoneme'`` entry in ``info``.
        action_kind: Base action index the target index is added to. With the
            ``ACTIONS`` table in this notebook, 1 is the first "stare_at"
            action and 4 is the first "encourage" action.
        steps: Maximum number of steps to run.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        tuple: ``(phonemes, rewards)`` as NumPy arrays with one entry per
        executed step.

    Raises:
        ValueError: If neither ``env.phonemes`` nor ``info['phoneme']``
            provides phoneme counts.
    """
    obs, info = env.reset(seed=seed)
    phonemes, rewards = [], []

    for i in range(steps):
        # Prefer the env attribute, but fall back to the latest info dict when
        # the attribute is absent *or* None, so either source is sufficient.
        current_phonemes = getattr(env, "phonemes", None)
        if current_phonemes is None:
            current_phonemes = info.get("phoneme")

        if current_phonemes is None:
            raise ValueError("Environment or info must expose `phonemes` or `info['phoneme']`.")

        # Pick the entry with the lowest phoneme count (first one on ties).
        target = int(np.argmin(current_phonemes))

        # Offset the base action by the target index.
        # NOTE: `action` and `current_phonemes` must keep these names; the
        # `{name=}` f-string specifiers embed them in the log message.
        action = action_kind + target
        logger.info(f"{action=}, {current_phonemes=}, step={i}")

        obs, reward, terminated, truncated, info = env.step(action)

        # Snapshot the counts so later in-place updates by the env cannot
        # alter what was recorded for this step.
        phonemes.append(np.array(info["phoneme"], copy=True))
        rewards.append(reward)

        if terminated or truncated:
            break

    return np.array(phonemes), np.array(rewards)


def plot_results(phonemes, rewards, title):
    """Show two figures for one rollout: phoneme traces, then the reward trace.

    Args:
        phonemes: Array indexed as ``phonemes[:, a]``; one line labelled
            ``Agent a`` is drawn for every index ``a`` along axis 1.
        rewards: Sequence of rewards; its length defines the x-axis.
        title: Prefix used for both figure titles.
    """
    x_steps = np.arange(len(rewards))

    # Figure 1: one curve per column of `phonemes`.
    _, phoneme_ax = plt.subplots(figsize=(9, 4))
    for agent_idx in range(phonemes.shape[1]):
        phoneme_ax.plot(x_steps, phonemes[:, agent_idx], label=f'Agent {agent_idx}')
    phoneme_ax.set_xlabel('Step')
    phoneme_ax.set_ylabel('Phonemes')
    phoneme_ax.set_title(title + ' - Phonemes')
    phoneme_ax.legend()
    phoneme_ax.grid(True)
    plt.show()

    # Figure 2: reward per step (labelled, but no legend is drawn).
    _, reward_ax = plt.subplots(figsize=(9, 4))
    reward_ax.plot(x_steps, rewards, label='Reward')
    reward_ax.set_xlabel('Step')
    reward_ax.set_ylabel('Reward')
    reward_ax.set_title(title + ' - Reward')
    reward_ax.grid(True)
    plt.show()

# Human-readable name for each discrete action index; the index of a name is
# its position in the tuple below (0 = "wait", 1-3 = stare_at, 4-6 = encourage).
ACTIONS = dict(enumerate((
    "wait",
    "stare_at 0",
    "stare_at 1",
    "stare_at 2",
    "encourage 0",
    "encourage 1",
    "encourage 2",
)))

In [3]:
# Instantiate the environment used by the cells below: episodes are capped
# at 100 steps and reward shaping is switched off.
env = GuestEnv(
    max_steps=100,
    reward_shaping=False,
)

In [None]:
# Trace how `env.energy` evolves when the same action is repeated every step.
# Reset with a fixed seed (the same default the rollout helpers in this
# notebook use) so the logged trace is reproducible across runs.
env.reset(seed=42)
for i in range(100):  # 100 matches the max_steps the env was created with
    # Action 4 is "encourage 0" in the ACTIONS table.
    obs, reward, terminated, truncated, info = env.step(4)
    eng = env.energy
    logger.info(f"eng={[f'{x:.5f}' for x in eng]} - {i=}")
    # Stop when the episode is over rather than stepping a finished env.
    if terminated or truncated:
        break

[32m2025-10-21 09:51:59[0m | [34mINFO[0m | [1mself.energy=array([0.0386978 , 0.27555138, 0.44343917])[0m
[32m2025-10-21 09:51:59[0m | [34mINFO[0m | [1mbuff={'amount': 0.9, 'remaining': 10}[0m
eng=['0.94070', '0.28555', '0.47144'] - i=0
[32m2025-10-21 09:51:59[0m | [34mINFO[0m | [1mbuff={'amount': 0.9, 'remaining': 10}[0m
eng=['0.92000', '0.29555', '0.49944'] - i=1
[32m2025-10-21 09:51:59[0m | [34mINFO[0m | [1mbuff={'amount': 0.9, 'remaining': 10}[0m
eng=['0.92000', '0.30555', '0.52744'] - i=2
[32m2025-10-21 09:51:59[0m | [34mINFO[0m | [1mbuff={'amount': 0.9, 'remaining': 10}[0m
eng=['0.92000', '0.31555', '0.55544'] - i=3
[32m2025-10-21 09:51:59[0m | [34mINFO[0m | [1mbuff={'amount': 0.9, 'remaining': 10}[0m
eng=['0.92000', '0.32555', '0.58344'] - i=4
[32m2025-10-21 09:51:59[0m | [34mINFO[0m | [1mbuff={'amount': 0.9, 'remaining': 10}[0m
eng=['0.92000', '0.33555', '0.61144'] - i=5
[32m2025-10-21 09:51:59[0m | [34mINFO[0m | [1mbuff={'amount': 