## Setup Dataset

In [None]:
from src.data.setup import download_and_extract

download_and_extract(
    url="https://drive.google.com/uc?id=1STHDs5uR-bx-6beC36qtGaW_UqC1fY7s",
    zip_path="data.zip",
    extract_dir="data",
    verbose=True,
    remove_zip=True
)

## Offline-Train All Classifiers

In [None]:
from src.models.decision_tree.train import main as train_decision_tree
from src.models.hoeffding.train import main as train_hoeffding_tree
from src.models.weighted_forest.train import main as train_weighted_forest

train_decision_tree()
train_hoeffding_tree()
train_weighted_forest()

## Launch Interactive Game

Left player controls the paddle using Q (up) and A (down).

Choose the enemy using the mode variable. Choices include:
- "human": play against another player
- "pc": pc player aiming to reach ball_y
- "dt": offline-only trained decision tree
- "ht": offline-only pre-trained hoeffding tree
- "wf:" offline-only pre-trained weighted forest

In [None]:
from src.main import main

mode = "wf" # "human", "pc", "dt", "ht", "wf"
main(mode)

## Online Training - Hoeffding Tree

In [None]:
from src.training.train_online import train_decision_online
from src.training.train_online import train_hoeffding_online
from src.training.train_online import train_weighted_forest_online

train_decision_online(
    pretrained_model_path="models/dt/decision_tree_pong.pkl",
    num_episodes=20,
    max_score_per_episode=5,
    save_interval=5
)

train_hoeffding_online(
    pretrained_model_path="models/ht/hoeffding_tree_pong.pkl",
    num_episodes=20,
    max_score_per_episode=5,
    save_interval=5
)

train_weighted_forest_online(
    pretrained_model_path="models/wf/weighted_forest_pong.pkl",
    metadata_path="models/wf/weighted_forest_metadata.pkl",
    num_episodes=20,
    max_score_per_episode=5,
    save_interval=5
)

## Compare Pretrained vs Online Models

In [None]:
from src.evaluation import compare_pretrained_vs_online

results = compare_pretrained_vs_online(max_score=5)

## Visualize Online Training Metrics

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

ht_metrics = pd.read_csv("models/ht/hoeffding_online_metrics.csv")
wf_metrics = pd.read_csv("models/wf/weighted_forest_online_metrics.csv")

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

axes[0, 0].plot(ht_metrics['episode'], ht_metrics['survival_seconds'], marker='o')
axes[0, 0].set_title('Hoeffding Tree - Survival Time')
axes[0, 0].set_xlabel('Episode')
axes[0, 0].set_ylabel('Survival (seconds)')

axes[0, 1].plot(ht_metrics['episode'], ht_metrics['progressive_accuracy'], marker='o')
axes[0, 1].set_title('Hoeffding Tree - Progressive Accuracy')
axes[0, 1].set_xlabel('Episode')
axes[0, 1].set_ylabel('Accuracy')

axes[1, 0].plot(wf_metrics['episode'], wf_metrics['survival_seconds'], marker='o')
axes[1, 0].set_title('Weighted Forest - Survival Time')
axes[1, 0].set_xlabel('Episode')
axes[1, 0].set_ylabel('Survival (seconds)')

axes[1, 1].plot(wf_metrics['episode'], wf_metrics['num_cells'], marker='o')
axes[1, 1].set_title('Weighted Forest - Active Cells')
axes[1, 1].set_xlabel('Episode')
axes[1, 1].set_ylabel('Number of Cells')

plt.tight_layout()
plt.show()

## Statistical Analysis (10x5 CV)

In [None]:
import time
import numpy as np
from src.statistical_tests import run_rskf, print_results
from src.data.loader import load_training_data
from src.data.preparation import min_max_scale, convert_str_to_int, undersample
from sklearn.tree import DecisionTreeClassifier
from river.tree import HoeffdingTreeClassifier
from src.models.weighted_forest.clf import WeightedForest, euclidean_distance
from sklearn.metrics import accuracy_score

X, y = load_training_data(random_state=42)
X_np = X.to_numpy()
y_np = y.to_numpy()

X_np, _, _ = min_max_scale(X_np)
y_np, class_mapping = convert_str_to_int(y_np)
X_np, y_np = undersample(X_np, y_np, random_seed=42)

models = {
    'DecisionTree': 'DecisionTree', 
    'HoeffdingTree': 'HoeffdingTree', 
    'WeightedForest': 'WeightedForest'
}

inference_times = {model: [] for model in models}

def train_eval_fn(model_name, X_train, y_train, X_test, y_test):
    preds = []
    
    if model_name == 'DecisionTree':
        clf = DecisionTreeClassifier(max_depth=20, min_samples_split=10, random_state=42)
        clf.fit(X_train, y_train)
        start_time = time.time()
        preds = clf.predict(X_test)
        end_time = time.time()

    elif model_name == 'HoeffdingTree':
        clf = HoeffdingTreeClassifier()
        for x, y_label in zip(X_train, y_train):
            clf.learn_one(dict(enumerate(x)), y_label)
        start_time = time.time()
        preds = [clf.predict_one(dict(enumerate(x))) for x in X_test]
        end_time = time.time()

    elif model_name == 'WeightedForest':
        clf = WeightedForest(
            X_train.shape[1], len(np.unique(y_train)), euclidean_distance,
            accuracy_goal=0.65, random_seed=42
        )
        clf.fit(X_train, y_train, epochs=3)
        start_time = time.time()
        preds = clf.predict(X_test)
        end_time = time.time()
        
        preds = preds.astype(int)

    inference_times[model_name].append(end_time - start_time)

    return accuracy_score(y_test, preds)

output = run_rskf(train_eval_fn, models, X_np, y_np, n_repeats=10, n_splits=5, random_state=42)

print_results(output)

print("\nInference Time Results (Seconds per Fold)")
print(f"{'Model':<20} | {'Mean Time':<12} | {'Std Dev':<10}")
print("-" * 46)

for model_name in models:
    times = inference_times[model_name]
    mean_time = np.mean(times)
    std_time = np.std(times)
    print(f"{model_name:<20} | {mean_time:.6f}s   | ± {std_time:.6f}s")

## Game Simulation Evaluation (10 Games x 5 Models)

In [None]:
import numpy as np
import pandas as pd
from src.models.model_loader import PongAIPlayer
from src.evaluation import evaluate_model
from src.statistical_tests import friedman_test, wilcoxon_posthoc

NUM_GAMES = 10
MAX_SCORE = 5

MODEL_CONFIGS = {
    'DT_pretrained': ('models/dt/decision_tree_pong.pkl', 'models/dt/decision_tree_metadata.pkl'),
    'HT_pretrained': ('models/ht/hoeffding_tree_pong.pkl', 'models/ht/hoeffding_tree_metadata.pkl'),
    'HT_online': ('models/ht/hoeffding_tree_online.pkl', 'models/ht/hoeffding_tree_metadata.pkl'),
    'WF_pretrained': ('models/wf/weighted_forest_pong.pkl', 'models/wf/weighted_forest_metadata.pkl'),
    'WF_online': ('models/wf/weighted_forest_online.pkl', 'models/wf/weighted_forest_metadata.pkl'),
}

all_results = []

for model_name, (model_path, metadata_path) in MODEL_CONFIGS.items():
    print(f"\nEvaluating {model_name}...")
    
    try:
        ai = PongAIPlayer(model_path, metadata_path)
    except Exception as e:
        print(f"  Error loading model: {e}")
        continue
    
    for game_num in range(1, NUM_GAMES + 1):
        print(f"  Game {game_num}/{NUM_GAMES}", end="\r")
        
        result = evaluate_model(ai, model_name, MAX_SCORE)
        
        goal_diff = result.final_ai_score - result.final_pc_score
        
        all_results.append({
            'model': model_name,
            'game': game_num,
            'survival_time': result.survival_time_seconds,
            'returns': result.total_hits,
            'goal_diff': goal_diff,
            'ai_score': result.final_ai_score,
            'pc_score': result.final_pc_score
        })
    
    print(f"  Completed {NUM_GAMES} games for {model_name}")

df = pd.DataFrame(all_results)
csv_path = 'models/game_simulation_results.csv'
df.to_csv(csv_path, index=False)
print(f"\nResults saved to {csv_path}")


In [None]:
print("\nSummary Statistics (mean ± std) ###\n")
print(f"{'Model':<15} | {'Survival Time (s)':<20} | {'Returns':<18} | {'Goal Diff':<15}")
print("-" * 75)

summary_stats = df.groupby('model').agg({
    'survival_time': ['mean', 'std'],
    'returns': ['mean', 'std'],
    'goal_diff': ['mean', 'std']
}).round(2)

for model in MODEL_CONFIGS.keys():
    if model in summary_stats.index:
        st_mean, st_std = summary_stats.loc[model, ('survival_time', 'mean')], summary_stats.loc[model, ('survival_time', 'std')]
        ret_mean, ret_std = summary_stats.loc[model, ('returns', 'mean')], summary_stats.loc[model, ('returns', 'std')]
        gd_mean, gd_std = summary_stats.loc[model, ('goal_diff', 'mean')], summary_stats.loc[model, ('goal_diff', 'std')]
        print(f"{model:<15} | {st_mean:>6.2f} ± {st_std:<10.2f} | {ret_mean:>5.2f} ± {ret_std:<8.2f} | {gd_mean:>5.2f} ± {gd_std:<5.2f}")

# Prepare data for statistical tests
metrics = ['survival_time', 'returns', 'goal_diff']
metric_labels = {'survival_time': 'Survival Time', 'returns': 'Returns', 'goal_diff': 'Goal Difference'}

for metric in metrics:
    print(f"\n{metric_labels[metric]} Statistical Tests\n")
    
    # Create results dict for Friedman test
    results_dict = {}
    for model in MODEL_CONFIGS.keys():
        model_data = df[df['model'] == model][metric].values
        if len(model_data) == NUM_GAMES:
            results_dict[model] = model_data.tolist()
    
    if len(results_dict) < 3:
        print("Not enough models for Friedman test (need at least 3)")
        continue
    
    # Friedman test
    try:
        friedman_stat, friedman_p = friedman_test(results_dict)
        print(f"Friedman Test: statistic={friedman_stat:.4f}, p-value={friedman_p:.6f}")
        
        # Post-hoc if significant
        if friedman_p < 0.05:
            print("\nPost-hoc Pairwise Comparisons (Wilcoxon + Hommel):")
            posthoc = wilcoxon_posthoc(results_dict)
            for pair, result in posthoc.items():
                sig = "*" if result['significant'] else ""
                print(f"  {pair}: p={result['p_corrected']:.4f} {sig}")
        else:
            print("  No significant differences found (p >= 0.05)")
    except Exception as e:
        print(f"Error in statistical test: {e}")


## Hyperparameter tuning (Grid search)

In [14]:
import numpy as np
from src.models.weighted_forest.test_parameter import perform_data_prep, create_trainer_and_player, test_online

## Prepare data
X, y, scaler, class_mapping = perform_data_prep(random_state=42)

## Start gridsearch
para1_list = [3, 5, 7]                                  ## Number of start cells
para2_list = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]     ## Siginificance difference threshold

result = np.zeros(shape=(len(para1_list), len(para2_list), 2))
for idx_para1, para1 in enumerate(para1_list):
    for idx_para2, para2 in enumerate(para2_list):
        trainer, ai_player = create_trainer_and_player(X, y, class_mapping=class_mapping, scaler=scaler, epochs=3, num_start_cells=para1, similarity_threshold=para2, random_state=42)
        result[idx_para1, idx_para2] = np.array(test_online(trainer=trainer, ai_player=ai_player))    ## Survival time, hits

        np.save("data/hyperparameter.npy", result)




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 149.6300000000005, 'avg_reward': 1.496300000000005, 'num_cells': 52, 'accuracy': 0.22}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 149.61000000000053, 'avg_reward': 0.7480500000000027, 'num_cells': 51, 'accuracy': 0.11}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 355.9370105596397, 'avg_reward': 1.1864567018654657, 'num_cells': 51, 'accuracy': 0.18333333333333332}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 405.7970105596399, 'avg_reward': 1.0144925263990998, 'num_cells': 51, 'accuracy': 0.155}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 555.5947076416755, 'avg_reward': 1.111189415283351, 'num_cells': 51, 'accuracy': 0.17}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 625.9447076416759, 'avg_reward': 1.0432411794027932, 'num_cells': 51, 'accuracy': 0.15833333333333333}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 632.644707641676, 'avg




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 142.91000000000037, 'avg_reward': 1.4291000000000036, 'num_cells': 45, 'accuracy': 0.21}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 142.91000000000037, 'avg_reward': 0.7145500000000018, 'num_cells': 44, 'accuracy': 0.105}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 315.50484450045985, 'avg_reward': 1.051682815001533, 'num_cells': 44, 'accuracy': 0.15666666666666668}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 405.96937800183053, 'avg_reward': 1.0149234450045763, 'num_cells': 44, 'accuracy': 0.1525}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 519.9493780018311, 'avg_reward': 1.039898756003662, 'num_cells': 44, 'accuracy': 0.156}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 594.5293780018314, 'avg_reward': 0.990882296669719, 'num_cells': 44, 'accuracy': 0.14833333333333334}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 742.8993780018322




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 131.57000000000036, 'avg_reward': 1.3157000000000036, 'num_cells': 47, 'accuracy': 0.19}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 131.57000000000036, 'avg_reward': 0.6578500000000018, 'num_cells': 46, 'accuracy': 0.095}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 325.6595789714437, 'avg_reward': 1.0855319299048123, 'num_cells': 46, 'accuracy': 0.16333333333333333}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 410.299578971444, 'avg_reward': 1.0257489474286101, 'num_cells': 46, 'accuracy': 0.1525}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 410.299578971444, 'avg_reward': 0.8205991579428881, 'num_cells': 46, 'accuracy': 0.122}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 537.8095789714446, 'avg_reward': 0.896349298285741, 'num_cells': 46, 'accuracy': 0.13333333333333333}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 651.4595789714451, 




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 129.4800000000006, 'avg_reward': 1.294800000000006, 'num_cells': 47, 'accuracy': 0.19}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 129.4800000000006, 'avg_reward': 0.647400000000003, 'num_cells': 46, 'accuracy': 0.095}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 311.93000000000137, 'avg_reward': 1.0397666666666712, 'num_cells': 46, 'accuracy': 0.15333333333333332}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 413.0000000000018, 'avg_reward': 1.0325000000000046, 'num_cells': 46, 'accuracy': 0.1525}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 550.0300000000028, 'avg_reward': 1.1000600000000056, 'num_cells': 46, 'accuracy': 0.164}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 684.2200000000034, 'avg_reward': 1.1403666666666723, 'num_cells': 46, 'accuracy': 0.17}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 846.5700000000041, 'avg_reward': 1




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 131.57000000000036, 'avg_reward': 1.3157000000000036, 'num_cells': 47, 'accuracy': 0.19}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 131.57000000000036, 'avg_reward': 0.6578500000000018, 'num_cells': 46, 'accuracy': 0.095}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 329.09799351879474, 'avg_reward': 1.0969933117293158, 'num_cells': 46, 'accuracy': 0.16333333333333333}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 388.91796111277273, 'avg_reward': 0.9722949027819319, 'num_cells': 46, 'accuracy': 0.1525}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 530.3379611127734, 'avg_reward': 1.0606759222255469, 'num_cells': 46, 'accuracy': 0.164}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 645.5691700298667, 'avg_reward': 1.0759486167164445, 'num_cells': 46, 'accuracy': 0.16833333333333333}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 712.9291700298




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 130.73000000000044, 'avg_reward': 1.3073000000000043, 'num_cells': 47, 'accuracy': 0.19}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 130.73000000000044, 'avg_reward': 0.6536500000000022, 'num_cells': 46, 'accuracy': 0.095}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 340.03000000000134, 'avg_reward': 1.1334333333333377, 'num_cells': 46, 'accuracy': 0.16666666666666666}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 420.95000000000164, 'avg_reward': 1.052375000000004, 'num_cells': 46, 'accuracy': 0.155}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 602.2817449656072, 'avg_reward': 1.2045634899312145, 'num_cells': 46, 'accuracy': 0.178}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 712.9017449656078, 'avg_reward': 1.1881695749426797, 'num_cells': 46, 'accuracy': 0.175}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 712.9017449656078, 'avg_reward




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 130.73000000000044, 'avg_reward': 1.3073000000000043, 'num_cells': 47, 'accuracy': 0.19}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 130.73000000000044, 'avg_reward': 0.6536500000000022, 'num_cells': 46, 'accuracy': 0.095}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 330.7779935187938, 'avg_reward': 1.1025933117293125, 'num_cells': 46, 'accuracy': 0.16333333333333333}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 443.51999999999845, 'avg_reward': 1.108799999999996, 'num_cells': 46, 'accuracy': 0.165}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 551.3199999999988, 'avg_reward': 1.1026399999999976, 'num_cells': 46, 'accuracy': 0.164}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 726.4199999999998, 'avg_reward': 1.2106999999999997, 'num_cells': 46, 'accuracy': 0.18}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 780.0700000000002, 'avg_reward':




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 151.31000000000088, 'avg_reward': 1.5131000000000088, 'num_cells': 36, 'accuracy': 0.23}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 151.31000000000088, 'avg_reward': 0.7565500000000044, 'num_cells': 36, 'accuracy': 0.115}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 348.0000000000016, 'avg_reward': 1.1600000000000052, 'num_cells': 36, 'accuracy': 0.17333333333333334}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 415.10000000000184, 'avg_reward': 1.0377500000000046, 'num_cells': 36, 'accuracy': 0.155}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 585.6900000000023, 'avg_reward': 1.1713800000000048, 'num_cells': 36, 'accuracy': 0.178}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 693.0700000000029, 'avg_reward': 1.1551166666666715, 'num_cells': 36, 'accuracy': 0.175}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 885.2000000000035, 'avg_reward




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 119.41000000000017, 'avg_reward': 1.1941000000000017, 'num_cells': 40, 'accuracy': 0.24}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 119.41000000000017, 'avg_reward': 0.5970500000000009, 'num_cells': 40, 'accuracy': 0.12}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 308.93000000000086, 'avg_reward': 1.0297666666666696, 'num_cells': 40, 'accuracy': 0.17333333333333334}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 578.5000000000013, 'avg_reward': 1.4462500000000031, 'num_cells': 40, 'accuracy': 0.23}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 774.8300000000024, 'avg_reward': 1.549660000000005, 'num_cells': 40, 'accuracy': 0.242}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 943.9200000000028, 'avg_reward': 1.5732000000000046, 'num_cells': 39, 'accuracy': 0.24333333333333335}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 1038.2000000000032,




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 150.0600000000002, 'avg_reward': 1.500600000000002, 'num_cells': 40, 'accuracy': 0.23}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 150.0600000000002, 'avg_reward': 0.750300000000001, 'num_cells': 40, 'accuracy': 0.115}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 338.3200000000011, 'avg_reward': 1.1277333333333368, 'num_cells': 40, 'accuracy': 0.17}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 597.3100000000013, 'avg_reward': 1.4932750000000032, 'num_cells': 40, 'accuracy': 0.2225}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 597.3100000000013, 'avg_reward': 1.1946200000000027, 'num_cells': 40, 'accuracy': 0.178}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 651.7800000000013, 'avg_reward': 1.0863000000000023, 'num_cells': 39, 'accuracy': 0.16166666666666665}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 800.6200000000019, 'avg_reward': 1.




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 157.19999999999996, 'avg_reward': 1.5719999999999996, 'num_cells': 40, 'accuracy': 0.23}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 157.19999999999996, 'avg_reward': 0.7859999999999998, 'num_cells': 40, 'accuracy': 0.115}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 386.2200000000008, 'avg_reward': 1.2874000000000028, 'num_cells': 40, 'accuracy': 0.19}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 446.18000000000103, 'avg_reward': 1.1154500000000025, 'num_cells': 40, 'accuracy': 0.165}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 632.3100000000013, 'avg_reward': 1.2646200000000025, 'num_cells': 40, 'accuracy': 0.188}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 726.6700000000018, 'avg_reward': 1.2111166666666697, 'num_cells': 39, 'accuracy': 0.18}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 928.4600000000032, 'avg_reward': 1.32637142857




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 155.5099999999998, 'avg_reward': 1.555099999999998, 'num_cells': 40, 'accuracy': 0.23}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 155.5099999999998, 'avg_reward': 0.777549999999999, 'num_cells': 40, 'accuracy': 0.115}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 351.3600000000006, 'avg_reward': 1.171200000000002, 'num_cells': 40, 'accuracy': 0.17333333333333334}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 418.88000000000085, 'avg_reward': 1.0472000000000021, 'num_cells': 40, 'accuracy': 0.155}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 591.9800000000013, 'avg_reward': 1.1839600000000026, 'num_cells': 40, 'accuracy': 0.176}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 700.6200000000017, 'avg_reward': 1.1677000000000028, 'num_cells': 39, 'accuracy': 0.17333333333333334}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 897.7900000000022, 'a




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 156.77000000000024, 'avg_reward': 1.5677000000000023, 'num_cells': 40, 'accuracy': 0.23}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 156.77000000000024, 'avg_reward': 0.7838500000000012, 'num_cells': 40, 'accuracy': 0.115}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 359.7600000000011, 'avg_reward': 1.1992000000000038, 'num_cells': 40, 'accuracy': 0.17666666666666667}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 420.5600000000012, 'avg_reward': 1.051400000000003, 'num_cells': 40, 'accuracy': 0.155}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 608.7900000000012, 'avg_reward': 1.2175800000000023, 'num_cells': 40, 'accuracy': 0.18}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 702.7300000000017, 'avg_reward': 1.1712166666666695, 'num_cells': 39, 'accuracy': 0.17333333333333334}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 904.1000000000021, 




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 157.60999999999925, 'avg_reward': 1.5760999999999925, 'num_cells': 40, 'accuracy': 0.23}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 157.60999999999925, 'avg_reward': 0.7880499999999963, 'num_cells': 40, 'accuracy': 0.115}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 352.20000000000016, 'avg_reward': 1.1740000000000006, 'num_cells': 40, 'accuracy': 0.17333333333333334}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 419.3000000000003, 'avg_reward': 1.0482500000000008, 'num_cells': 40, 'accuracy': 0.155}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 585.2799999999979, 'avg_reward': 1.1705599999999958, 'num_cells': 40, 'accuracy': 0.18}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 687.1999999999983, 'avg_reward': 1.1453333333333306, 'num_cells': 39, 'accuracy': 0.175}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 884.3699999999983, 'avg_reward'




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 144.57000000000016, 'avg_reward': 1.4457000000000015, 'num_cells': 52, 'accuracy': 0.21}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 144.57000000000016, 'avg_reward': 0.7228500000000008, 'num_cells': 51, 'accuracy': 0.105}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 259.06000000000023, 'avg_reward': 0.8635333333333342, 'num_cells': 51, 'accuracy': 0.12666666666666668}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 393.7100000000003, 'avg_reward': 0.9842750000000008, 'num_cells': 51, 'accuracy': 0.145}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 582.1000000000008, 'avg_reward': 1.1642000000000017, 'num_cells': 51, 'accuracy': 0.172}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 758.7000000000014, 'avg_reward': 1.2645000000000024, 'num_cells': 51, 'accuracy': 0.18666666666666668}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 844.438255034397




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 162.66999999999993, 'avg_reward': 1.6266999999999994, 'num_cells': 54, 'accuracy': 0.24}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 290.98000000000013, 'avg_reward': 1.4549000000000007, 'num_cells': 53, 'accuracy': 0.215}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 486.9000000000004, 'avg_reward': 1.6230000000000013, 'num_cells': 53, 'accuracy': 0.24}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 493.61000000000035, 'avg_reward': 1.2340250000000008, 'num_cells': 53, 'accuracy': 0.1825}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 682.2700000000012, 'avg_reward': 1.3645400000000025, 'num_cells': 53, 'accuracy': 0.202}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 770.3600000000016, 'avg_reward': 1.283933333333336, 'num_cells': 53, 'accuracy': 0.19}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 953.5700000000022, 'avg_reward': 1.36224285714




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 148.39000000000007, 'avg_reward': 1.4839000000000007, 'num_cells': 52, 'accuracy': 0.22}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 276.2800000000003, 'avg_reward': 1.3814000000000015, 'num_cells': 52, 'accuracy': 0.205}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 472.20000000000084, 'avg_reward': 1.5740000000000027, 'num_cells': 52, 'accuracy': 0.23333333333333334}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 580.4000000000012, 'avg_reward': 1.4510000000000032, 'num_cells': 52, 'accuracy': 0.215}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 792.1100000000022, 'avg_reward': 1.5842200000000044, 'num_cells': 52, 'accuracy': 0.234}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 792.1100000000022, 'avg_reward': 1.320183333333337, 'num_cells': 52, 'accuracy': 0.195}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 792.1100000000022, 'avg_reward':




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 145.85000000000002, 'avg_reward': 1.4585000000000001, 'num_cells': 52, 'accuracy': 0.2}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 273.73000000000025, 'avg_reward': 1.3686500000000013, 'num_cells': 52, 'accuracy': 0.195}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 462.0900000000009, 'avg_reward': 1.540300000000003, 'num_cells': 52, 'accuracy': 0.22333333333333333}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 569.8700000000013, 'avg_reward': 1.4246750000000032, 'num_cells': 52, 'accuracy': 0.2075}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 781.6200000000022, 'avg_reward': 1.5632400000000044, 'num_cells': 52, 'accuracy': 0.228}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 781.6200000000022, 'avg_reward': 1.3027000000000035, 'num_cells': 52, 'accuracy': 0.19}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 781.6200000000022, 'avg_reward': 




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 147.97, 'avg_reward': 1.4797, 'num_cells': 52, 'accuracy': 0.22}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 275.44000000000017, 'avg_reward': 1.3772000000000009, 'num_cells': 52, 'accuracy': 0.205}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 471.35000000000065, 'avg_reward': 1.5711666666666688, 'num_cells': 52, 'accuracy': 0.23333333333333334}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 578.7200000000009, 'avg_reward': 1.4468000000000023, 'num_cells': 52, 'accuracy': 0.215}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 792.9300000000018, 'avg_reward': 1.5858600000000036, 'num_cells': 52, 'accuracy': 0.234}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 792.9300000000018, 'avg_reward': 1.321550000000003, 'num_cells': 52, 'accuracy': 0.195}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 792.9300000000018, 'avg_reward': 1.1327571428571455, 'n




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 142.9199999999999, 'avg_reward': 1.429199999999999, 'num_cells': 52, 'accuracy': 0.21}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 270.3900000000001, 'avg_reward': 1.3519500000000004, 'num_cells': 52, 'accuracy': 0.2}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 465.8800000000009, 'avg_reward': 1.5529333333333364, 'num_cells': 52, 'accuracy': 0.23}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 486.040000000001, 'avg_reward': 1.2151000000000025, 'num_cells': 52, 'accuracy': 0.18}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 668.8300000000019, 'avg_reward': 1.3376600000000036, 'num_cells': 52, 'accuracy': 0.198}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 756.0900000000022, 'avg_reward': 1.2601500000000037, 'num_cells': 52, 'accuracy': 0.18666666666666668}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 938.0400000000027, 'avg_reward': 1.3400




[Frame 100] Metrics: {'total_updates': 100, 'total_reward': 153.84999999999994, 'avg_reward': 1.5384999999999993, 'num_cells': 52, 'accuracy': 0.23}
[Frame 200] Metrics: {'total_updates': 200, 'total_reward': 283.0100000000001, 'avg_reward': 1.4150500000000006, 'num_cells': 52, 'accuracy': 0.21}
[Frame 300] Metrics: {'total_updates': 300, 'total_reward': 478.92000000000064, 'avg_reward': 1.596400000000002, 'num_cells': 52, 'accuracy': 0.23666666666666666}
[Frame 400] Metrics: {'total_updates': 400, 'total_reward': 498.65000000000066, 'avg_reward': 1.2466250000000016, 'num_cells': 52, 'accuracy': 0.185}
[Frame 500] Metrics: {'total_updates': 500, 'total_reward': 676.7600000000014, 'avg_reward': 1.3535200000000027, 'num_cells': 52, 'accuracy': 0.2}
[Frame 600] Metrics: {'total_updates': 600, 'total_reward': 676.7600000000014, 'avg_reward': 1.1279333333333357, 'num_cells': 52, 'accuracy': 0.16666666666666666}
[Frame 700] Metrics: {'total_updates': 700, 'total_reward': 676.7600000000014, '

In [None]:
## Generate latex table
results = np.load("data/hyperparameter.npy")

s = " & ".join(map(str, para1_list))+ "\n"

for i2 in range(len(para2_list)):
    s += f"\\textbf" + '{' + f"{str(para2_list[i2])}" + '}' + f" & {" & ".join(map(str, map(int, results[:,i2,1].tolist())))} & {results[:,i2,1].mean():.2f}\\\\\n"
    s += "\\hline\n"

s += "\\textbf{Avg hits}" + f" & {" & ".join(['%.2f' % result for result in results[:,:,1].mean(axis=1).tolist()])} & {results[:,:,1].mean():.2f}\\\\\n"
s += "\\hline\n"

print(s)

ValueError: unsupported format character ':' (0x3a) at index 1