<a href="https://colab.research.google.com/github/krvicky/open_spiel/blob/main/MCCFR_for_Kuhn_poker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Install required dependencies
!pip install open-spiel  # This might take some time

# Import necessary libraries
from absl import app
from absl import flags

from open_spiel.python.algorithms import exploitability
from open_spiel.python.algorithms import external_sampling_mccfr as external_mccfr
from open_spiel.python.algorithms import outcome_sampling_mccfr as outcome_mccfr
import pyspiel


Collecting open-spiel
  Downloading open_spiel-1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: open-spiel
Successfully installed open-spiel-1.3


In [6]:
# Experiment configuration, read by the solver cells below.
game_name = "kuhn_poker"
players = 2            # number of players passed to pyspiel.load_game
sampling = "external"  # MCCFR variant: can be "external" or "outcome"
iterations = 10000     # number of MCCFR iterations to run
print_freq = 1000      # how often (in iterations) to print NashConv; the last iteration is always reported

In [10]:
# Run MCCFR with the configured sampling scheme and periodically report the
# NashConv of the solver's average policy. `exploitability.nash_conv` returns
# NashConv (the summed best-response gains over all players), not the
# exploitability (which OpenSpiel defines as NashConv / num_players), so the
# printed label names the metric that is actually computed.
game = pyspiel.load_game(game_name, {"players": players})
if sampling == "external":
  cfr_solver = external_mccfr.ExternalSamplingSolver(
      game, external_mccfr.AverageType.SIMPLE)
elif sampling == "outcome":
  cfr_solver = outcome_mccfr.OutcomeSamplingSolver(game)
else:
  # Fail loudly rather than silently falling back to outcome sampling on a typo.
  raise ValueError(
      "Unknown sampling {!r}; expected 'external' or 'outcome'".format(sampling))
for i in range(iterations):
  cfr_solver.iteration()
  # Report every `print_freq` iterations and always after the final one.
  if i % print_freq == 0 or i == iterations - 1:
    conv = exploitability.nash_conv(game, cfr_solver.average_policy())
    print("Iteration {} NashConv {}".format(i, conv))

Iteration 0 exploitability 0.7500003333326666
Iteration 1000 exploitability 0.05038884825887441
Iteration 2000 exploitability 0.01859592341482902
Iteration 3000 exploitability 0.01723649173341002
Iteration 4000 exploitability 0.03260452587287532
Iteration 5000 exploitability 0.02256637502779557
Iteration 6000 exploitability 0.016831276524625016
Iteration 7000 exploitability 0.019252454666777996
Iteration 8000 exploitability 0.005580455492516123
Iteration 9000 exploitability 0.014030758723283132
Iteration 9999 exploitability 0.013402992885763854


In [11]:
# Switch the MCCFR variant for the next run (one of: "external", "outcome").
sampling = "outcome"

In [12]:
# Re-run MCCFR with the currently selected sampling scheme and periodically
# report the NashConv of the solver's average policy. `exploitability.nash_conv`
# returns NashConv (the summed best-response gains over all players), not the
# exploitability (which OpenSpiel defines as NashConv / num_players), so the
# printed label names the metric that is actually computed.
game = pyspiel.load_game(game_name, {"players": players})
if sampling == "external":
  cfr_solver = external_mccfr.ExternalSamplingSolver(
      game, external_mccfr.AverageType.SIMPLE)
elif sampling == "outcome":
  cfr_solver = outcome_mccfr.OutcomeSamplingSolver(game)
else:
  # Fail loudly rather than silently falling back to outcome sampling on a typo.
  raise ValueError(
      "Unknown sampling {!r}; expected 'external' or 'outcome'".format(sampling))
for i in range(iterations):
  cfr_solver.iteration()
  # Report every `print_freq` iterations and always after the final one.
  if i % print_freq == 0 or i == iterations - 1:
    conv = exploitability.nash_conv(game, cfr_solver.average_policy())
    print("Iteration {} NashConv {}".format(i, conv))

Iteration 0 exploitability 0.9166666666666666
Iteration 1000 exploitability 0.13391324096162088
Iteration 2000 exploitability 0.10004644688542297
Iteration 3000 exploitability 0.09209738727970063
Iteration 4000 exploitability 0.09710811624389049
Iteration 5000 exploitability 0.0642483090504069
Iteration 6000 exploitability 0.06348347944257438
Iteration 7000 exploitability 0.05471588837821101
Iteration 8000 exploitability 0.06455581100371988
Iteration 9000 exploitability 0.038333578804467194
Iteration 9999 exploitability 0.0410902669488965
