In [7]:
import pyspiel
import json
import os

# Load your game
game = pyspiel.load_game("patriks_coin_game")

# Wrap the game to make it sequential
sequential_game = pyspiel.convert_to_turn_based(game)

# Initialize CFR Solver
cfr_solver = pyspiel.CFRSolver(sequential_game)
num_iterations = 1000

# Initialize a dictionary to store action probabilities over iterations
action_prob_history = {f"Player 0 chose: {i}": [] for i in range(6)}  # Assuming 6 choices: 0-5

for i in range(1, num_iterations + 1):
    cfr_solver.evaluate_and_update_policy()

    # Get the current average policy
    current_policy = cfr_solver.average_policy()

    # Traverse the initial state to get action probabilities
    initial_state = sequential_game.new_initial_state()
    player = initial_state.current_player()
    action_probs = current_policy.action_probabilities(initial_state)

    # Store the probabilities
    for action, prob in action_probs.items():
        action_name = sequential_game.action_to_string(player, action)
        if action_name in action_prob_history:
            action_prob_history[action_name].append(prob)
        else:
            # Handle unexpected actions if any
            action_prob_history[action_name] = [prob]

    if i % 100 == 0 or i == 1:
        print(f"Iteration {i} completed")

# Get the average policy after CFR
average_policy = cfr_solver.average_policy()

Iteration 1 completed
Iteration 100 completed
Iteration 200 completed
Iteration 300 completed
Iteration 400 completed
Iteration 500 completed
Iteration 600 completed
Iteration 700 completed
Iteration 800 completed
Iteration 900 completed
Iteration 1000 completed


In [8]:
action_prob_history

{'Player 0 chose: 0': [],
 'Player 0 chose: 1': [],
 'Player 0 chose: 2': [],
 'Player 0 chose: 3': [],
 'Player 0 chose: 4': [],
 'Player 0 chose: 5': [],
 'Action(id=5, player=0)': [0.16666666666666669,
  0.08333333333333333,
  0.13682049396335108,
  0.1583167649460296,
  0.1550366020819471,
  0.13979194759552724,
  0.12255269605318297,
  0.1072336090465351,
  0.0953187635969201,
  0.08578688723722808,
  0.07798807930657098,
  0.07148907269769007,
  0.06598991325940622,
  0.06127634802659149,
  0.057191258158152056,
  0.05361680452326755,
  0.050462874845428295,
  0.047659381798460056,
  0.04515099328275163,
  0.042893443618614045,
  0.04085089868439433,
  0.038994039653285496,
  0.03729864662488178,
  0.03574453634884504,
  0.03431475489489124,
  0.032994956629703115,
  0.03177292119897337,
  0.03063817401329575,
  0.029581685254216585,
  0.028595629079076035,
  0.027673189431363904,
  0.026808402261633783,
  0.025996026435523668,
  0.025231437422714147,
  0.0245105392106366,
  0.02

In [13]:
!pip install dm-haiku

Collecting dm-haiku
  Downloading dm_haiku-0.0.12-py3-none-any.whl.metadata (19 kB)
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl.metadata (8.9 kB)
Collecting tabulate>=0.8.9 (from dm-haiku)
  Downloading tabulate-0.9.0-py3-none-any.whl.metadata (34 kB)
Collecting flax>=0.7.1 (from dm-haiku)
  Downloading flax-0.9.0-py3-none-any.whl.metadata (11 kB)
Collecting msgpack (from flax>=0.7.1->dm-haiku)
  Downloading msgpack-1.1.0-cp310-cp310-macosx_11_0_arm64.whl.metadata (8.4 kB)
Collecting orbax-checkpoint (from flax>=0.7.1->dm-haiku)
  Downloading orbax_checkpoint-0.7.0-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorstore (from flax>=0.7.1->dm-haiku)
  Downloading tensorstore-0.1.66-cp310-cp310-macosx_11_0_arm64.whl.metadata (3.0 kB)
Collecting rich>=11.1 (from flax>=0.7.1->dm-haiku)
  Downloading rich-13.9.2-py3-none-any.whl.metadata (18 kB)
Collecting markdown-it-py>=2.2.0 (from rich>=11.1->flax>=0.7.1->dm-haiku)
  Using cached markdown_it_py-3.0.0-

In [10]:
!pip install chex
!pip install haiku
!pip install jax
!pip install optax

Collecting chex
  Downloading chex-0.1.87-py3-none-any.whl.metadata (17 kB)
Collecting jax>=0.4.27 (from chex)
  Downloading jax-0.4.34-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib>=0.4.27 (from chex)
  Downloading jaxlib-0.4.34-cp310-cp310-macosx_11_0_arm64.whl.metadata (983 bytes)
Collecting toolz>=0.9.0 (from chex)
  Downloading toolz-1.0.0-py3-none-any.whl.metadata (5.1 kB)
Collecting ml-dtypes>=0.2.0 (from jax>=0.4.27->chex)
  Downloading ml_dtypes-0.5.0-cp310-cp310-macosx_10_9_universal2.whl.metadata (21 kB)
Collecting opt-einsum (from jax>=0.4.27->chex)
  Downloading opt_einsum-3.4.0-py3-none-any.whl.metadata (6.3 kB)
Downloading chex-0.1.87-py3-none-any.whl (99 kB)
Downloading jax-0.4.34-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxlib-0.4.34-cp310-cp310-macosx_11_0_arm64.whl (67.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.6/67

In [14]:
import enum
import functools
from typing import Any, Callable, Sequence, Tuple

import chex
import haiku as hk
import jax
from jax import lax
from jax import numpy as jnp
from jax import tree_util as tree
import numpy as np
import optax

from open_spiel.python import policy as policy_lib
import pyspiel