In [None]:
%pip install -U open_spiel

In [1]:
import functools

import flax.nnx as nnx
import jax
import optax
import torch
import torch.nn as nn

from open_spiel.python.jax import rcfr as rcfr_jax
import pyspiel
from open_spiel.python.pytorch import rcfr as rcfr_pt

# Notebook-level defaults for interactive use. The example functions defined
# below load their own game and set their own batch size locally, so they do
# not read these two globals.
batch_size = 12
game = pyspiel.load_game('kuhn_poker')

Optional module pokerkit_wrapper was not importable: No module named 'pokerkit'


In [2]:
def flax_example(game_name, num_epochs, iterations):
  """Runs RCFR with flax.nnx regret models and records exploitability.

  Args:
    game_name: Game identifier passed to `pyspiel.load_game`.
    num_epochs: Number of passes over the training data each time a model is
      fitted.
    iterations: Number of `evaluate_and_update_policy` calls to perform.

  Returns:
    A list with the exploitability of the solver's average policy, measured
    at every 10th iteration (0, 10, 20, ...).
  """

  @nnx.vmap(in_axes=(None, 0), out_axes=0)
  def forward(model: nnx.Module, x: jax.Array) -> jax.Array:
    """Batched call for the flax.nnx model."""
    return model(x)

  @functools.partial(jax.jit, static_argnames=("graphdef",))
  def jax_train_step(
      graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array, y: jax.Array
  ) -> tuple:
    """Train step in pure jax.

    Rebuilds the model and optimizer from `graphdef` and `state`, applies a
    single optimizer update computed on the batch `(x, y)` and returns
    `(loss, new_state)`.
    """

    model, optimizer = nnx.merge(graphdef, state, copy=True)

    def loss_fn(model):
      y_pred = forward(model, x)
      # Hinge loss is a classification loss for +/-1 labels. The PyTorch
      # counterpart in this notebook regresses with `nn.SmoothL1Loss`; Huber
      # loss with its default delta of 1.0 is the same objective, which keeps
      # the two frameworks comparable.
      return optax.huber_loss(y_pred, y).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    state = nnx.state((model, optimizer))
    return loss, state

  game = pyspiel.load_game(game_name)

  # One regret model per player.
  models = []
  for _ in range(game.num_players()):
    models.append(
        rcfr_jax.DeepRcfrModel(
            game,
            num_hidden_layers=1,
            num_hidden_units=13,
            num_hidden_factors=8,
            use_skip_connections=True,
        )
    )

  # these parameters are fixed initially
  buffer_size = -1
  truncate_negative = False
  bootstrap = False

  if buffer_size > 0:
    solver = rcfr_jax.ReservoirRcfrSolver(
        game, models, buffer_size, truncate_negative=truncate_negative
    )
  else:
    solver = rcfr_jax.RcfrSolver(
        game, models, truncate_negative=truncate_negative, bootstrap=bootstrap
    )

  batch_size = 12
  step_size = 0.01

  def _train_fn(model: nnx.Module, data: tuple) -> None:
    """Train `model` on `data`.

    Args:
      model: The flax.nnx model to update in place.
      data: A pair `(arrays, rng)`. `arrays` holds the inputs and targets
        (iterated as `x, y`), and `rng` is the JAX key used to shuffle them.
    """
    data_, rng = data

    num_examples = len(data_[0])
    if num_examples == 0:
      # Nothing to fit; reshaping an empty array into batches would raise.
      return

    optimizer = nnx.Optimizer(
        model, optax.amsgrad(learning_rate=step_size), wrt=nnx.Param
    )
    graphdef, state = nnx.split((model, optimizer))

    # Cap the batch size at the dataset size and keep only whole batches, so
    # the reshape below is valid for any number of examples, not just exact
    # multiples of `batch_size`.
    effective_batch = min(batch_size, num_examples)
    num_batches = num_examples // effective_batch
    num_used = num_batches * effective_batch

    def _to_batches(array):
      # The same key is used for every leaf, so inputs and targets receive
      # the same permutation and stay aligned.
      shuffled = jax.random.permutation(rng, array, axis=0)
      return shuffled[:num_used].reshape(num_batches, effective_batch, -1)

    data_ = jax.tree.map(_to_batches, data_)

    for _ in range(num_epochs):
      for x, y in zip(*data_):
        _, state = jax_train_step(graphdef, state, x, y.squeeze(-1))

    # Write the trained parameters back into the caller's model object.
    nnx.update((model, optimizer), state)

  result = []
  for i in range(iterations):
    solver.evaluate_and_update_policy(_train_fn, jax.random.key(i))
    if i % 10 == 0:
      conv = pyspiel.exploitability(game, solver.average_policy())
      result.append(conv)
  return result

In [3]:
def pytorch_example(game_name, num_epochs, iterations):
  """Runs RCFR with PyTorch regret models and records exploitability.

  Args:
    game_name: Game identifier passed to `pyspiel.load_game`.
    num_epochs: Number of passes over the training data each time a model is
      fitted.
    iterations: Number of `evaluate_and_update_policy` calls to perform.

  Returns:
    A list with the exploitability of the solver's average policy, measured
    at every 10th iteration (0, 10, 20, ...).
  """
  game = pyspiel.load_game(game_name)

  # One regret model per player.
  models = [
      rcfr_pt.DeepRcfrModel(
          game,
          num_hidden_layers=1,
          num_hidden_units=13,
          num_hidden_factors=8,
          use_skip_connections=True,
      )
      for _ in range(game.num_players())
  ]

  # Solver configuration; a non-positive buffer size selects the plain solver.
  buffer_size = -1
  truncate_negative = False
  bootstrap = False
  if not buffer_size > 0:
    solver = rcfr_pt.RcfrSolver(
        game, models, truncate_negative=truncate_negative, bootstrap=bootstrap
    )
  else:
    solver = rcfr_pt.ReservoirRcfrSolver(
        game, models, buffer_size, truncate_negative=truncate_negative
    )

  def _train_fn(model, data):
    """Fits `model` to `data` for `num_epochs` shuffled passes."""
    loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)
    criterion = nn.SmoothL1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)

    for _ in range(num_epochs):
      for features, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(features), targets)
        loss.backward()
        optimizer.step()

  result = []
  for i in range(iterations):
    solver.evaluate_and_update_policy(_train_fn)
    if i % 10 != 0:
      continue
    # Measure the average policy on every 10th iteration only.
    result.append(pyspiel.exploitability(game, solver.average_policy()))
  return result

In [None]:
# Kuhn poker, 20 training epochs per RCFR iteration.
# Each of the 10 repetitions runs the flax.nnx example and then the PyTorch
# example, keeping one list of exploitability values per run and framework.
num_epochs, iterations = 20, 100
flax_rcfr, pytorch_rcfr = [], []
for _ in range(10):
  flax_curve = flax_example('kuhn_poker', num_epochs, iterations)
  flax_rcfr.append(flax_curve)
  pytorch_curve = pytorch_example('kuhn_poker', num_epochs, iterations)
  pytorch_rcfr.append(pytorch_curve)

In [None]:
import matplotlib.pyplot as plt

# The example functions record exploitability on every 10th iteration
# (0, 10, 20, ...), so the x-axis is rebuilt from that interval and from the
# number of recorded points instead of a hard-coded index range.
eval_interval = 10

# Average over the repeated runs. Summing would inflate the plotted value by
# the number of runs.
flax_exploitability = [sum(tfe) / len(tfe) for tfe in zip(*flax_rcfr)]
pt_exploitability = [sum(pte) / len(pte) for pte in zip(*pytorch_rcfr)]

flax_x = [eval_interval * k for k in range(len(flax_exploitability))]
pt_x = [eval_interval * k for k in range(len(pt_exploitability))]
x = flax_x  # kept under its original name for any later cell that reads it

fig, ax = plt.subplots()
ax.plot(flax_x, flax_exploitability, label="flax.nnx")
ax.plot(pt_x, pt_exploitability, label="pytorch")
ax.set_xlabel("RCFR iteration")
ax.set_ylabel("Exploitability (mean over runs)")
ax.set_title(f"kuhn_poker, {num_epochs} epochs per iteration")
ax.legend()

plt.show()

In [None]:
# Kuhn poker, 200 training epochs per RCFR iteration.
# Each of the 10 repetitions runs the flax.nnx example and then the PyTorch
# example, keeping one list of exploitability values per run and framework.
num_epochs, iterations = 200, 100
flax_rcfr, pytorch_rcfr = [], []
for _ in range(10):
  flax_curve = flax_example('kuhn_poker', num_epochs, iterations)
  flax_rcfr.append(flax_curve)
  pytorch_curve = pytorch_example('kuhn_poker', num_epochs, iterations)
  pytorch_rcfr.append(pytorch_curve)

In [None]:
import matplotlib.pyplot as plt

# The example functions record exploitability on every 10th iteration
# (0, 10, 20, ...), so the x-axis is rebuilt from that interval and from the
# number of recorded points instead of a hard-coded index range.
eval_interval = 10

# Average over the repeated runs. Summing would inflate the plotted value by
# the number of runs.
flax_exploitability = [sum(tfe) / len(tfe) for tfe in zip(*flax_rcfr)]
pt_exploitability = [sum(pte) / len(pte) for pte in zip(*pytorch_rcfr)]

flax_x = [eval_interval * k for k in range(len(flax_exploitability))]
pt_x = [eval_interval * k for k in range(len(pt_exploitability))]
x = flax_x  # kept under its original name for any later cell that reads it

fig, ax = plt.subplots()
# Same legend label as the first plot so the figures are directly comparable.
ax.plot(flax_x, flax_exploitability, label="flax.nnx")
ax.plot(pt_x, pt_exploitability, label="pytorch")
ax.set_xlabel("RCFR iteration")
ax.set_ylabel("Exploitability (mean over runs)")
ax.set_title(f"kuhn_poker, {num_epochs} epochs per iteration")
ax.legend()

plt.show()

In [None]:
# Leduc poker, 20 training epochs per RCFR iteration.
# Each of the 10 repetitions runs the flax.nnx example and then the PyTorch
# example, keeping one list of exploitability values per run and framework.
num_epochs, iterations = 20, 100
flax_rcfr, pytorch_rcfr = [], []
for _ in range(10):
  flax_curve = flax_example('leduc_poker', num_epochs, iterations)
  flax_rcfr.append(flax_curve)
  pytorch_curve = pytorch_example('leduc_poker', num_epochs, iterations)
  pytorch_rcfr.append(pytorch_curve)

In [None]:
import matplotlib.pyplot as plt

# The example functions record exploitability on every 10th iteration
# (0, 10, 20, ...), so the x-axis is rebuilt from that interval and from the
# number of recorded points instead of a hard-coded index range.
eval_interval = 10

# Average over the repeated runs. Summing would inflate the plotted value by
# the number of runs.
flax_exploitability = [sum(tfe) / len(tfe) for tfe in zip(*flax_rcfr)]
pt_exploitability = [sum(pte) / len(pte) for pte in zip(*pytorch_rcfr)]

flax_x = [eval_interval * k for k in range(len(flax_exploitability))]
pt_x = [eval_interval * k for k in range(len(pt_exploitability))]
x = flax_x  # kept under its original name for any later cell that reads it

fig, ax = plt.subplots()
# Same legend label as the first plot so the figures are directly comparable.
ax.plot(flax_x, flax_exploitability, label="flax.nnx")
ax.plot(pt_x, pt_exploitability, label="pytorch")
ax.set_xlabel("RCFR iteration")
ax.set_ylabel("Exploitability (mean over runs)")
ax.set_title(f"leduc_poker, {num_epochs} epochs per iteration")
ax.legend()

plt.show()