In [145]:
import sys
from pathlib import Path

sys.path.append(str(Path.cwd().resolve().parent.parent))

In [None]:
# experiments
from experiments.run_sweep import run_sweep
from experiments.run_experiment import TrainingConfig, EvaluateConfig
from experiments.sweep_plots import plot_sweep_training, plot_sweep_evaluation
from experiments.sweep_plots_helper import taxi_training_plot_specs, taxi_evaluation_plot_specs
from experiments.notebook_helpers import generate_alpha_pairs, asymmetric_alphas


# environments
from environments.taxi_v3 import TaxiV3Config, get_taxi_v3_env

# SARSA confirmation bias agent
from agents.sarsa_td0_confirmation_bias import (
    SarsaTD0ConfirmationBiasConfig,
    SarsaTD0ConfirmationBiasAgent,
)

# metrics for training
from metrics.reward_mertrics import taxi_v3_reward_metrics_specs
from metrics.frustration_metrics import frustration_metrics_specs

# external libraries
import numpy as np

In [None]:
BASE_AGENT_CONFIG_CONF = dict(
    alpha_conf=0.2,
    alpha_disconf=0.2,
    gamma=0.99,
    epsilon=0.2,
    reward_metrics=taxi_v3_reward_metrics_specs(),
    td_error_metrics=frustration_metrics_specs(),
)


BASE_AGENT_CONFIG_POSITITY = dict(
    alpha_positive=0.2,
    alpha_negative=0.2,
    gamma=0.99,
    epsilon=0.2,
    reward_metrics=taxi_v3_reward_metrics_specs(),
    td_error_metrics=frustration_metrics_specs(),
)

NUM_TRAIN_EPISODES = 50000
NUM_EVAL_EPISODES = 5000
NUM_SEEDS = 15

## Environments

In [148]:
env_config = TaxiV3Config()

## Agents

In [None]:
## Agents
# sarsa_td0 confirmation bias agent
agent_factory = SarsaTD0ConfirmationBiasAgent
sarsa_td0_config = SarsaTD0ConfirmationBiasConfig(**BASE_AGENT_CONFIG_CONF)

## Sweep configuration

In [None]:
base_train = TrainingConfig(
    name="TaxiV3_sarsa_td0",
    num_train_episodes=NUM_TRAIN_EPISODES,
    env_kwargs={"config": env_config},
    agent_kwargs={"config": sarsa_td0_config},
)

base_eval = EvaluateConfig(
    name="TaxiV3_sarsa_td0",
    num_eval_episodes=NUM_EVAL_EPISODES,  # use >0 if you want eval outputs
    env_kwargs={"config": env_config},
    evaluation_metrics=taxi_v3_reward_metrics_specs(),
)

## Confirmation bias results

In [126]:
# different q_tables
env = get_taxi_v3_env(env_config)

In [None]:
# create different initial q_tables for the Taxi sweep
env = get_taxi_v3_env(env_config)
num_states = env.observation_space.n
num_actions = env.action_space.n

import numpy as np
import pickle
from pathlib import Path

out_dir = Path("outputs/sweeps")
out_dir.mkdir(parents=True, exist_ok=True)

# q-table setup (same as FrozenLake version)
q0 = np.zeros((num_states, num_actions), dtype=np.float64)
q_tables = [("zeros", q0)]

# confirmation-bias alpha pairs
confirmatory_pairs, balanced_pairs, disconfirmatory_pairs = generate_alpha_pairs(
    balanced_lr=[0.2], num_pairs=5, step_size=0.025
)

# use all pair groups
alpha_pairs = confirmatory_pairs + balanced_pairs + disconfirmatory_pairs
seeds = list(range(NUM_SEEDS))

all_results = []
done_pairs = []

for a_conf, a_disconf in alpha_pairs:
    # safe file name for float values
    pair_tag = f"ac_{a_conf:.3f}_ad_{a_disconf:.3f}".replace(".", "p")
    out_file = out_dir / f"taxi_conf_{pair_tag}.pkl"

    # resume behavior
    if out_file.exists():
        print(f"Skipping pair=({a_conf}, {a_disconf}) (already exists): {out_file}")
        with out_file.open("rb") as f:
            pair_results = pickle.load(f)
        all_results.extend(pair_results)
        done_pairs.append((a_conf, a_disconf))
        continue

    sweep_one_pair = {
        "agent_kwargs": [
            {
                "alpha_conf": a_conf,
                "alpha_disconf": a_disconf,
                "seed": seed,
                "initial_q_table": q_table.copy(),
                "initial_q_table_label": label,
            }
            for (label, q_table) in q_tables
            for seed in seeds
        ],
    }

    print(f"Running Taxi pair=({a_conf}, {a_disconf}) with {len(seeds)} seeds...")
    pair_results = run_sweep(
        base_training=base_train,
        base_evaluation=base_eval,
        sweep=sweep_one_pair,
        env_factory=get_taxi_v3_env,
        agent_factory=agent_factory,
    )

    # save immediately
    with out_file.open("wb") as f:
        pickle.dump(pair_results, f)

    all_results.extend(pair_results)
    done_pairs.append((a_conf, a_disconf))
    print(f"Saved pair=({a_conf}, {a_disconf}) -> {out_file}")

# optional combined file
combined_file = out_dir / "taxi_conf_all_pairs.pkl"
with combined_file.open("wb") as f:
    pickle.dump(all_results, f)

print(f"Done pairs: {len(done_pairs)} / {len(alpha_pairs)}")
print(f"Total runs collected: {len(all_results)}")
# print(f"Combined saved to: {combined_file}")

## Positivity bias results

In [157]:
# SARSA positivty bias agent
from agents.sarsa_td0_positivity_bias import (
    SarsaTD0PositivityBiasConfig,
    SarsaTD0PositivityBiasAgent,
)

In [None]:
## Agents
# sarsa_td0 agent
agent_factory = SarsaTD0PositivityBiasAgent
sarsa_td0_config = SarsaTD0PositivityBiasConfig(**BASE_AGENT_CONFIG_POSITITY)

In [None]:
base_train = TrainingConfig(
    name="TaxiV3_sarsa_td0",
    num_train_episodes=NUM_TRAIN_EPISODES,
    env_kwargs={"config": env_config},
    agent_kwargs={"config": sarsa_td0_config},
)

base_eval = EvaluateConfig(
    name="TaxiV3_sarsa_td0",
    num_eval_episodes=NUM_TRAIN_EPISODES,
    env_kwargs={"config": env_config},
    evaluation_metrics=taxi_v3_reward_metrics_specs(),
)

In [161]:
# different q_tables
env = get_taxi_v3_env(env_config)
num_states = env.observation_space.n
num_actions = env.action_space.n
q0 = np.zeros((num_states, num_actions), dtype=np.float64)

In [155]:
# mean alpha_0 = 0.2
a_0 = 0.2
ratios = [0.8, 0.9, 1.0, 1.1, 1.2]

seeds = list(range())

sweep = {
    "agent_kwargs": [
        {
            **asymmetric_alphas(a_0, r),  # gives alpha_positive / alpha_negative
            "seed": seed,
            "initial_q_table": q0,
            "initial_q_table_label": "zeros",
        }
        for r in ratios
        for seed in seeds
    ],
}

TypeError: range expected at least 1 argument, got 0

In [None]:
import numpy as np
import pickle
from pathlib import Path

# Taxi setup
env = get_taxi_v3_env(env_config)
num_states = env.observation_space.n
num_actions = env.action_space.n
q0 = np.zeros((num_states, num_actions), dtype=np.float64)

out_dir = Path("outputs/sweeps")
out_dir.mkdir(parents=True, exist_ok=True)

a_0 = 0.2
ratios = [4.0, 2.0, 1.5, 1.0, 0.75, 0.5]
seeds = list(range(NUM_SEEDS))

all_results = []
done_ratios = []

for r in ratios:
    ratio_tag = str(r).replace(".", "p")
    out_file = out_dir / f"taxi_conf_ratio_{ratio_tag}.pkl"

    # Skip already completed ratio runs (resume behavior)
    if out_file.exists():
        print(f"Skipping ratio={r} (already exists): {out_file}")
        with out_file.open("rb") as f:
            ratio_results = pickle.load(f)
        all_results.extend(ratio_results)
        done_ratios.append(r)
        continue

    sweep_one_ratio = {
        "agent_kwargs": [
            {
                **asymmetric_alphas(
                    a_0, r
                ),  # expected to produce alpha_conf / alpha_disconf
                "seed": seed,
                "initial_q_table": q0.copy(),
                "initial_q_table_label": "zeros",
            }
            for seed in seeds
        ],
    }

    print(f"Running Taxi ratio={r} with {len(seeds)} seeds...")
    ratio_results = run_sweep(
        base_training=base_train,
        base_evaluation=base_eval,
        sweep=sweep_one_ratio,
        env_factory=get_taxi_v3_env,
        agent_factory=agent_factory,
    )

    # Save immediately after each ratio
    with out_file.open("wb") as f:
        pickle.dump(ratio_results, f)

    all_results.extend(ratio_results)
    done_ratios.append(r)
    print(f"Saved ratio={r} -> {out_file}")

# Optional combined file
combined_file = out_dir / "taxi_conf_all_ratios.pkl"
with combined_file.open("wb") as f:
    pickle.dump(all_results, f)

print(f"Done ratios: {done_ratios}")
print(f"Total runs collected: {len(all_results)}")
print(f"Combined saved to: {combined_file}")

Running Taxi ratio=0.5 with 15 seeds...
Saved ratio=0.5 -> outputs/sweeps/taxi_pos_ratio_0.5.pkl
Running Taxi ratio=0.75 with 15 seeds...
Saved ratio=0.75 -> outputs/sweeps/taxi_pos_ratio_0.75.pkl
Running Taxi ratio=1.0 with 15 seeds...
Saved ratio=1.0 -> outputs/sweeps/taxi_pos_ratio_1.0.pkl
Running Taxi ratio=1.5 with 15 seeds...
Saved ratio=1.5 -> outputs/sweeps/taxi_pos_ratio_1.5.pkl
Running Taxi ratio=2.0 with 15 seeds...
Saved ratio=2.0 -> outputs/sweeps/taxi_pos_ratio_2.0.pkl
Running Taxi ratio=4.0 with 15 seeds...
Saved ratio=4.0 -> outputs/sweeps/taxi_pos_ratio_4.0.pkl
Done ratios: [0.5, 0.75, 1.0, 1.5, 2.0, 4.0]
Total runs collected: 90
Combined saved to: outputs/sweeps/taxi_pos_all_ratios.pkl
