In [20]:
# from pgx.bridge_bidding import download_dds_results
# download_dds_results()

In [21]:
from functools import partial

from chex import PRNGKey
import chex

import jax
import jax.numpy as jnp

import wandb

from pgx import State
from pgx.bridge_bidding import BID_OFFSET_NUM, PASS_ACTION_NUM
from run_tournament import argmax_reverse, dds_policy, make_bzero_policy
from run_tournament import compute_elo_ratings

import bridge_env as env
from type_aliases import Done, Reward

from pprint import pprint
import pickle


def get_player_reward_for_bid(state, bid, player):
    chex.assert_shape(state._init_rng, [])
    state, _ = env.reset(state._init_rng)
    we_play_first = state.current_player == player
    action = bid + BID_OFFSET_NUM
    state, obs, rew, done = env.step(state, jnp.where(we_play_first, action, PASS_ACTION_NUM))
    state, obs, rew, done = env.step(state, jnp.where(we_play_first, PASS_ACTION_NUM, action))
    state, obs, rew, done = env.step(state, PASS_ACTION_NUM)
    state, obs, rew, done = env.step(state, PASS_ACTION_NUM)
    reward_we_play_first = state.rewards[player]
    state, obs, rew, done = env.step(state, PASS_ACTION_NUM)
    reward_we_play_second = state.rewards[player]
    
    return jnp.where(we_play_first, reward_we_play_first, reward_we_play_second)
        

def loop(args):
    _, subkey = jax.random.split(args[1])
    state, _ = env.reset(subkey)

    max_bid_0 = argmax_reverse(jax.vmap(get_player_reward_for_bid, in_axes=[None, 0, None])(state, jnp.arange(35), 0))
    max_bid_1 = argmax_reverse(jax.vmap(get_player_reward_for_bid, in_axes=[None, 0, None])(state, jnp.arange(35), 1))

    chex.assert_shape(max_bid_0, [])
    winnable = jnp.less(max_bid_0, max_bid_1)

    return (winnable, subkey)


def is_not_winnable_game(args):
    return jnp.logical_not(args[0])

def get_winnable_game(rng: PRNGKey):
    a = jax.lax.while_loop(is_not_winnable_game, loop, (False, rng))

    return env.reset(a[1])


def evaluate_pvp2(rng: PRNGKey, policy1, policy2, batch_size: int):
    def single_move(state: State, rng: PRNGKey) -> tuple[State, tuple[Reward, Done]]:
        rng0a, rng0b, rng1a, rng1b = jax.random.split(rng, 4)

        action_mask = state.legal_action_mask

        logits0 = policy1(rng0a, state)
        logits0_masked = jnp.where(action_mask, logits0, -1e9)
        action0 = jax.random.categorical(rng0b, logits0_masked)

        logits1 = policy2(rng1a, state)
        logits1_masked = jnp.where(action_mask, logits1, -1e9)
        action1 = jax.random.categorical(rng1b, logits1_masked)

        action = jnp.where(state.current_player == 0, action0, action1)

        new_state, new_observation, new_reward, new_done = jax.vmap(env.step)(
            state, action
        )
        return new_state, (new_state.rewards, new_done)

    rng, subkey = jax.random.split(rng)
    state, observation = jax.vmap(get_winnable_game)(jax.random.split(subkey, batch_size))
    first = state
    _, out = jax.lax.scan(single_move, first, jax.random.split(rng, env.max_steps))
    rewards, done = out
    chex.assert_shape(rewards, [env.max_steps, batch_size, 2])
    chex.assert_shape(done, [env.max_steps, batch_size])
    net_rewards = rewards[:, :, 0].sum(axis=0)
    episode_done = done.any(axis=0)
    return net_rewards, episode_done





In [22]:
player_names = [
    "dds",
    "bzero",
]

num_players = len(player_names)

policies = [
    dds_policy,
    make_bzero_policy(),
]

eval_funcs = [
    [
        jax.jit(
            partial(
                evaluate_pvp2, policy1=policies[p0], policy2=policies[p1], batch_size=64
            )
        )
        for p1 in range(num_players)
    ]
    for p0 in range(num_players)
]


In [23]:

# rng = jax.random.key(1)

# # eval_funcs[0][1](rng)

# rng, subkey = jax.random.split(rng)
# sks = jax.random.split(subkey, 4)


# fast_argmax = jax.jit(lambda state, player: argmax_reverse(jax.vmap(get_player_reward_for_bid, in_axes=[None, 0, None])(state, jnp.arange(35), player)))
# fast_env_reset = jax.jit(env.reset)

# winnable = False
# key = sks[1]
# while not winnable:
#     state, _ = fast_env_reset(key)

#     max_bid_0 = fast_argmax(state, 0)
#     max_bid_1 = fast_argmax(state, 1)

#     chex.assert_shape(max_bid_0, [])
#     winnable = jnp.greater(max_bid_0, max_bid_1)

#     print(winnable)
#     print(max_bid_0, max_bid_1)

#     if winnable:
#         break

#     _, key = jax.random.split(key)



In [24]:

rng = jax.random.key(1)

game_history = []


wandb.init(project="bridge-elo")

try:
    while True:
        for p0 in range(num_players - 1):
            for p1 in range(p0 + 1, num_players):
                rng, subkey = jax.random.split(rng)
                results, dones = eval_funcs[p0][p1](subkey)

                for result, done in zip(results, dones):
                    if done:
                        game_history.append([p0, p1, result.astype(jnp.int32).item()])

                elo_ratings = compute_elo_ratings(player_names, game_history)

                logs = elo_ratings
                logs["num_games"] = len(game_history)

                winrate = jnp.mean(jnp.array(game_history)[:, 2] < 0)
                logs["winrate"] = winrate

                pprint(logs)
                wandb.log(logs)

        with open(f"game_history-{len(game_history)}.pkl", "wb") as f:
            pickle.dump(game_history, f)
except KeyboardInterrupt:
    pass
finally:
    wandb.finish()
    with open(f"game_history-{len(game_history)}.pkl", "wb") as f:
        pickle.dump(game_history, f)



0,1
bzero,▁▂▆▅▇▇▇▇▇▇▇▇█▇█▇█▇█▇▇█▇█▆▆▇▇▇▇▆▇▇▇█▇█▇▆▇
dds,█▇▃▄▂▂▂▂▂▂▂▂▁▂▁▂▁▂▁▂▂▁▂▁▃▃▂▂▂▂▃▂▂▂▁▂▁▂▃▂
num_games,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
winrate,█▇▃▃▂▂▂▂▂▂▂▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂

0,1
bzero,1006.03994
dds,993.96006
num_games,3264.0
winrate,0.48039


{'bzero': 955.69468085046,
 'dds': 1044.30531914954,
 'num_games': 64,
 'winrate': Array(0.375, dtype=float32)}
{'bzero': 970.0639427652791,
 'dds': 1029.9360572347205,
 'num_games': 128,
 'winrate': Array(0.40625, dtype=float32)}
{'bzero': 1001.7099530977341,
 'dds': 998.2900469022659,
 'num_games': 192,
 'winrate': Array(0.484375, dtype=float32)}
{'bzero': 999.5072161941602,
 'dds': 1000.4927838058398,
 'num_games': 256,
 'winrate': Array(0.4921875, dtype=float32)}
{'bzero': 997.0036269049077,
 'dds': 1002.9963730950922,
 'num_games': 320,
 'winrate': Array(0.503125, dtype=float32)}
{'bzero': 1004.9411729722302,
 'dds': 995.0588270277698,
 'num_games': 384,
 'winrate': Array(0.5078125, dtype=float32)}
{'bzero': 1004.2324957311404,
 'dds': 995.7675042688595,
 'num_games': 448,
 'winrate': Array(0.5044643, dtype=float32)}
{'bzero': 1004.6232208907436,
 'dds': 995.3767791092564,
 'num_games': 512,
 'winrate': Array(0.5, dtype=float32)}
{'bzero': 1007.7761859400385,
 'dds': 992.223814059



0,1
bzero,▁▆▇▇▇██▇██▇█▇█▆▇▇▇▇▇▇▆▇▆▆▆▆█▇▆▇▆▇▇▅▇▆▇▇▆
dds,█▃▂▂▂▁▁▂▁▁▂▁▂▁▃▂▂▂▂▂▂▃▂▃▃▃▃▁▂▃▂▃▂▂▄▂▃▂▂▃
num_games,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
winrate,▁▆▇▇▇▇█▇███████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇

0,1
bzero,998.13162
dds,1001.86838
num_games,6464.0
winrate,0.50139
