In [None]:
import matplotlib.pyplot as plt
from typing import List
from chinese_checkers.simulation import S3SimulationCatalog, SimulationMetadata, SimulationData
from chinese_checkers.cnn import CnnEncoderExperience
import pandas as pd

from dataclasses import asdict
import random
from tqdm import tqdm
from chinese_checkers.experience import ExperienceData, S3ExperienceCatalog, ExperienceMetadata

In [None]:
# Open the simulation catalog and fetch the metadata of every dataset it lists.
sim_catalog = S3SimulationCatalog()
sim_metadata: List[SimulationMetadata] = sim_catalog.list_datasets()

# Optional dataset-name filter, currently disabled:
# name = "dql-cnn-v002-vs-bootstrap-p0-simulation"

# Show the metadata as a table, one row per dataset (rich output of the cell).
pd.DataFrame(list(map(asdict, sim_metadata)))

In [None]:
# Only simulations played with this many players are selected.
player_count = 2

# Step 1: pick the dataset metadata entries that match the criteria.
# NOTE(review): the [:1] slice means only the first catalog entry is ever
# examined — confirm this is intentional and not a debugging leftover.
matching_metadata = [
    metadata
    for metadata in sim_metadata[:1]
    if metadata.player_count == player_count
    and metadata.winning_player in ["0", "3"]
    # and metadata.name == name
]

# Step 2: load each matching dataset and flatten the results into one list.
simulations: List[SimulationData] = []
for selected in matching_metadata:
    simulations.extend(sim_catalog.load_dataset(selected))

print(f"Found {len(simulations)} datasets for player_count: {player_count}.")

In [None]:
# Catalog used to list (and later store) experience datasets.
exp_catalog = S3ExperienceCatalog()

# Historical note: v0.0.4 punishes the player for keeping pieces in the starting
# area and also rewards positions in the target area.
# NOTE(review): the filter below selects the v005 generator — confirm the note
# above still describes its reward scheme.
v005_rows = [
    asdict(entry)
    for entry in exp_catalog.list_datasets()
    if entry.generator_name == "CnnExperienceEncoder-v005"
]

# Display the matching experience datasets as a table.
pd.DataFrame(v005_rows)

In [None]:
# Version tag handed to the experience encoder used for generation below.
encoder_version = "v005"
exp_encoder = CnnEncoderExperience(encoder_version)

In [None]:
# Imports kept local so this cell can be re-run on its own; `random` and `tqdm`
# are also imported in the notebook's first cell.
import time
import random
from tqdm import tqdm

# Shuffle the simulations in place so they are processed in random order.
# NOTE(review): no seed is set, so the order differs on every run.
random.shuffle(simulations)

# Encode every simulation into experiences and add them to the experience
# catalog, reporting per-step timings in the progress bar.
# NOTE(review): re-running this cell adds the records to the catalog again.
total_experiences = 0
with tqdm(simulations, desc="Generating experiences") as pbar:
    for simulation in pbar:
        # perf_counter() is monotonic and high resolution, so the measured
        # intervals cannot be distorted by system clock adjustments the way
        # time.time() differences can.
        encode_start_time = time.perf_counter()
        experiences = exp_encoder.encode(simulation)
        encode_time = time.perf_counter() - encode_start_time

        # Time the catalog write (the "upload" step) separately from encoding.
        upload_start_time = time.perf_counter()
        exp_catalog.add_record_list(experiences)
        upload_time = time.perf_counter() - upload_start_time

        total_experiences += len(experiences)

        # Surface the latest counts and timings in the progress bar description.
        pbar.set_description(
            f"Generating experiences (Last generated: {len(experiences)}, "
            f"Total {total_experiences}, Encode time: {encode_time:.2f}s, Upload time: {upload_time:.2f}s)"
        )


In [None]:
# Sanity check: encode the first element of `simulations` and print the result.
sample_simulation = simulations[0]
print(exp_encoder.encode(sample_simulation))

In [None]:
# Selection criteria for the experience datasets to analyse.
# NOTE(review): the generation cells above use the "v005" encoder while this
# selects "CnnExperienceEncoder-v006" — confirm the intended generator.
generator_name = "CnnExperienceEncoder-v006"
player_count = 2
current_player = "0"
board_size = 4

# `ExperienceCatalog` is never imported in this notebook, so instantiating it
# raised NameError on a fresh kernel; `S3ExperienceCatalog` is the catalog class
# that is actually imported and used for experiences elsewhere in the notebook.
catalog = S3ExperienceCatalog()
dataset_metadata: List[ExperienceMetadata] = catalog.list_datasets()
experiences: List[ExperienceData] = []
progress_bar = tqdm(dataset_metadata, desc="Loading datasets")

for metadata in progress_bar:
    # Show which dataset is currently being inspected in the progress bar postfix.
    progress_bar.set_postfix(name=metadata.name, generator=metadata.generator_name)

    # player_count and board_size are compared against their string forms here.
    if (
        metadata.player_count == str(player_count)
        and metadata.current_player == current_player
        and metadata.generator_name == generator_name
        and metadata.board_size == str(board_size)
    ):
        # Flatten every matching dataset into the single `experiences` list.
        experiences.extend(catalog.load_dataset(metadata))

print(f"Found {len(experiences)} datasets for player_count: {player_count}, and current_player: {current_player}.")

In [None]:
class ExperienceAnalysis:
    """Summaries and plots over a list of experiences.

    Each experience is expected to expose ``metadata.winning_player`` and, for
    plotting, ``data.reward.item()``.
    """

    # Experiences being analysed (set in __init__).
    experiences: List[ExperienceData]

    def __init__(self, experiences: List[ExperienceData]):
        """Store the experiences to analyse.

        Args:
            experiences: Experiences to summarise; the list is kept by
                reference, not copied.
        """
        self.experiences: List[ExperienceData] = experiences

    def print_winner_counts(self):
        """Print how many experiences have winning player "0" and how many "3"."""
        # Count with generators instead of building throwaway lists.
        p0_win_count = sum(1 for e in self.experiences if e.metadata.winning_player == "0")
        p3_win_count = sum(1 for e in self.experiences if e.metadata.winning_player == "3")
        print(f"p0_win_count {p0_win_count}, p3_win_count {p3_win_count}")

    def check_feature_overlap(self, encoder_name=None):
        """Plot overlaid reward histograms for games won by player "0" vs "3".

        The figure is saved to ``win-loss-rewards-<encoder_name>.png`` in the
        working directory and then shown.

        Args:
            encoder_name: Label used in the title and output file name. When
                omitted, the notebook-level ``generator_name`` is used, which
                matches the previous behaviour.
        """
        if encoder_name is None:
            # Backward-compatible fallback to the notebook global.
            encoder_name = generator_name

        # Split rewards by which player won the game the experience came from.
        # NOTE(review): labelling "0" as winning and "3" as losing presumes the
        # experiences are from player 0's perspective — confirm.
        win_rewards = [e.data.reward.item() for e in self.experiences if e.metadata.winning_player == "0"]
        loss_rewards = [e.data.reward.item() for e in self.experiences if e.metadata.winning_player == "3"]

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.hist(win_rewards, bins=60, alpha=0.6, color='blue', edgecolor='black', label="Winning Game: Moves Rewards")
        ax.hist(loss_rewards, bins=60, alpha=0.6, color='green', edgecolor='black', label="Losing Game: Move Rewards")

        ax.set_title(f"Overlayed reward distributions for winning and losing move rewards - Encoder {encoder_name}")
        ax.set_xlabel("Reward")
        ax.set_ylabel("Frequency")
        ax.legend()
        fig.savefig(f'win-loss-rewards-{encoder_name}.png')
        plt.show()

# Run the analysis on the loaded `experiences`: print the per-winner counts,
# then plot and save the win/loss reward histograms.
an = ExperienceAnalysis(experiences)
an.print_winner_counts()
an.check_feature_overlap()