In [1]:
# Add the parent directory (relative to the current working directory) to the
# import path so that the `src` package below can be resolved.
import site

site.addsitedir("..")

from pathlib import Path

# NOTE(review): later cells use `torch`, `np`, `get_models`, `Card` and
# `EASY_TASK_DEFS` without importing them by name, so they presumably arrive
# through these star imports. Consider importing them explicitly.
from src.ai.train import *
from src.game.tasks import *
from src.game.types import *

In [2]:
# Directory of the training run to inspect; `settings.pth` and
# `checkpoint.pth` are read from here by the following cells.
# NOTE(review): machine-specific absolute path — this notebook will not run
# elsewhere without editing this line.
outdir = Path("/Users/davidzheng/projects/crew-ai/outdirs/0323/run_26")

In [3]:
# Load the run's saved settings and hyper-parameters, then build the models.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# which can execute code — only point this at files you trust.
settings_dict = torch.load(outdir / "settings.pth", weights_only=False)
settings = settings_dict["settings"]
hp = settings_dict["hp"]
# get_models returns a mapping with at least "policy" and "value" entries.
models = get_models(hp, settings)
# .eval() is called on both models before they are used for inspection
# (presumably nn.Module.eval, switching off training-only behaviour — confirm).
policy_model = models["policy"].eval()
value_model = models["value"].eval()

In [4]:
# Restore the trained weights and the rollout data saved with the checkpoint.
# NOTE(review): weights_only=False unpickles arbitrary objects — trusted files only.
checkpoint = torch.load(outdir / "checkpoint.pth", weights_only=False)
policy_model.load_state_dict(checkpoint["policy_model"])
value_model.load_state_dict(checkpoint["value_model"])
td = checkpoint["td"]
# The first hp.num_train_rollouts_per_round entries of `td` are taken as the
# training split; everything after them is the validation split.
td_train = td[: hp.num_train_rollouts_per_round]
td_val = td[hp.num_train_rollouts_per_round :]

In [5]:
# Re-run the value model on the training rollouts and store both of its
# outputs (values and auxiliary-info predictions) back on `td_train`.
# This is inference only, so autograd is disabled: the stored tensors carry no
# graph, which avoids holding on to activations for the whole split.
with torch.no_grad():
    td_train["values"], td_train["aux_info_preds"] = value_model(td_train["inps"])

In [126]:
def to_card(x):
    """Render an encoded card as a display string.

    ``x`` is indexed as ``x[0]`` (rank, stored zero-based: 1 is added before
    display) and ``x[1]`` (suit index, resolved via ``settings.get_suit``).
    A suit index equal to ``settings.num_suits`` is the sentinel for "no
    signal" and yields the literal string ``"nosignal"``.
    """
    suit_idx = x[1]
    if suit_idx == settings.num_suits:
        return "nosignal"

    rank = x[0].item() + 1
    suit = settings.get_suit(suit_idx.item())
    return str(Card(rank, suit))

In [127]:
# Overall win rate on the training split (mean of the `win` flags as floats).
print(f"Win rate: {td_train['win'].float().mean():.3f}")
idxs = np.arange(len(td_train["win"]))
# Keep only the indices of rollouts that were NOT won and take the entry at
# position 1, i.e. the second losing rollout.
# NOTE(review): the `[1]` is a hand-picked example; this raises IndexError if
# the split contains fewer than two losses.
bad_idx = idxs[~td_train["win"]][1]
print(f"Bad idx: {bad_idx}")

# The single losing rollout inspected by the cells below.
bad_ex = td_train[bad_idx]

Win rate: 0.734
Bad idx: 14


In [116]:
def print_game(bad_ex):
    """Print a play-by-play transcript of a single rollout.

    Reads ``bad_ex["inps"]`` (``private``, ``valid_actions``, ``task_idxs``),
    ``bad_ex["actions"]`` and ``bad_ex["orig_probs"]`` and walks them in
    lockstep, one entry per step.

    Output layout:
      * a ``Tasks:`` header built from the first entry of ``task_idxs``,
        ordered by the second element of each pair (printed as the player);
      * a 50-character ``=`` rule whenever the step's trick value changes;
      * a 25-character ``-`` rule when a play step directly follows a
        signal step;
      * for every play step: trick, phase, player, the hand, the chosen
        action and the probability recorded for that action.

    Steps whose ``phase`` equals 1 are treated as signal steps; they update
    the separators' bookkeeping but are not printed themselves.
    """
    inps = bad_ex["inps"]
    private = inps["private"]
    valid_actions = inps["valid_actions"]
    actions = bad_ex["actions"]
    probs = bad_ex["orig_probs"]

    # Header: one "P<player>: <task>" label per task, ordered by player.
    task_pairs = sorted(inps["task_idxs"][0], key=lambda pair: pair[1])
    labels = []
    for task_idx, player in task_pairs:
        labels.append(f"P{player}: {EASY_TASK_DEFS[task_idx][0]}")
    print(f"Tasks: {', '.join(labels)}")

    last_trick = None
    last_phase = None
    for state, chosen, legal, dist in zip(private, actions, valid_actions, probs):
        trick = state["trick"]
        if trick != last_trick:
            last_trick = trick
            print("=" * 50)

        if state["phase"] == 1:
            phase = "signal"
        else:
            phase = "play"

        # Separator is emitted only on the signal -> play transition.
        entering_play = phase == "play" and last_phase == "signal"
        last_phase = phase
        if entering_play:
            print("-" * 25)

        if phase != "play":
            continue

        print(f"Trick: {trick} Phase: {phase} Player: {state['player_idx']}")

        # Hand rows whose first entry is -1 are skipped (presumably padding).
        cards = []
        for card in state["hand"]:
            if card[0] != -1:
                cards.append(to_card(card))
        print(f"Hand: {' '.join(cards)}")

        picked = to_card(legal[chosen])
        print(f"Action: {picked} Prob: {dist[chosen]:.2f}")
        print()


In [149]:
# Columns 24-26 of the stored `aux_infos` for the selected rollout, one row
# per step (presumably the targets that `aux_info_preds` estimates).
# NOTE(review): the meaning of columns 24:27 is not visible here — document it.
bad_ex["aux_infos"][:, 24:27]

tensor([[1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [1., 0., 1.]])

In [150]:
# Same column slice (24:27) of the value model's second output, which an
# earlier cell stored on `td_train` as `aux_info_preds`.
bad_ex["aux_info_preds"][:, 24:27]

tensor([[ 0.9672,  0.1922,  0.8037],
        [ 1.0851,  0.0589,  0.7104],
        [ 1.0885,  0.0754,  0.6986],
        [ 0.9788,  0.1419,  0.7599],
        [ 1.0443,  0.1916,  0.7128],
        [ 1.0315,  0.0755,  0.8035],
        [ 0.9955,  0.0969,  0.9280],
        [ 1.0548,  0.0938,  0.9243],
        [ 1.0431,  0.0649,  0.9400],
        [ 0.9964,  0.0851,  0.9425],
        [ 1.1097,  0.0143,  0.9269],
        [ 1.0833,  0.0359,  0.9427],
        [ 0.9758,  0.0877,  0.8915],
        [ 1.0943,  0.0347,  0.8907],
        [ 1.0854, -0.0168,  0.9331],
        [ 0.9942,  0.1172,  0.9263],
        [ 1.0722,  0.0160,  0.9035],
        [ 1.0758, -0.0614,  0.9196],
        [ 0.9695,  0.1577,  0.9088],
        [ 1.0677,  0.0506,  0.8470],
        [ 1.0613, -0.0177,  0.8841],
        [ 0.9766,  0.1177,  0.9001],
        [ 1.0392,  0.0472,  0.8596],
        [ 1.0472,  0.1162,  0.8310],
        [ 0.9922,  0.0264,  1.0586],
        [ 1.0048, -0.0106,  1.0702],
        [ 1.0248, -0.0514,  0.9996],
 

In [118]:
# `orig_values` stored in the checkpoint for the selected rollout
# (presumably the value estimates recorded at rollout time — confirm).
bad_ex["orig_values"]

tensor([-0.2609, -0.1687, -0.1532, -0.2397, -0.1025, -0.1593, -0.3126, -0.2512,
        -0.2514, -0.3292, -0.2066, -0.2900, -0.5479, -0.5527, -0.5545, -0.5905,
        -0.5391, -0.5710, -0.6303, -0.5882, -0.6298, -0.5822, -0.6665, -0.5740,
        -0.6190, -0.6124, -0.5678, -0.5811, -0.6703, -0.5849, -0.6066, -0.5860,
        -0.6085, -0.5988, -0.6239, -0.6206])

In [119]:
# Mean of `orig_values` over dim 0 of the training split, for comparison with
# the single rollout above (presumably dim 0 indexes rollouts, giving one
# average per step — confirm).
td_train["orig_values"].mean(dim=0)

tensor([ 0.6470,  0.6612,  0.6575,  0.6429,  0.6654,  0.6470,  0.5519,  0.5577,
         0.5546,  0.5499,  0.5373,  0.5328,  0.4633,  0.4564,  0.4515,  0.4478,
         0.4367,  0.4324,  0.3512,  0.3341,  0.3371,  0.3310,  0.3250,  0.3274,
         0.1616,  0.1663,  0.1574,  0.1603,  0.1511,  0.1401, -0.0244, -0.0210,
        -0.0332, -0.0277, -0.0314, -0.0602])

In [120]:
# Print the play-by-play transcript of the selected losing rollout.
print_game(bad_ex)

Tasks: P0: #T>=1, P1: #T>=2, P2: #T>=1
-------------------------
Trick: 0 Phase: play Player: 0
Hand: 4b 1g 3g 4y 1t 2t
Action: 1g Prob: 1.00

Trick: 0 Phase: play Player: 1
Hand: 2b 1p 2p 4p 1y 2y
Action: 1y Prob: 0.54

Trick: 0 Phase: play Player: 2
Hand: 1b 3b 2g 4g 3p 3y
Action: 2g Prob: 0.96

-------------------------
Trick: 1 Phase: play Player: 2
Hand: 1b 3b 4g 3p 3y
Action: 1b Prob: 0.97

Trick: 1 Phase: play Player: 0
Hand: 4b 3g 4y 1t 2t
Action: 4b Prob: 1.00

Trick: 1 Phase: play Player: 1
Hand: 2b 1p 2p 4p 2y
Action: 2b Prob: 1.00

-------------------------
Trick: 2 Phase: play Player: 0
Hand: 3g 4y 1t 2t
Action: 4y Prob: 0.11

Trick: 2 Phase: play Player: 1
Hand: 1p 2p 4p 2y
Action: 2y Prob: 1.00

Trick: 2 Phase: play Player: 2
Hand: 3b 4g 3p 3y
Action: 3y Prob: 1.00

-------------------------
Trick: 3 Phase: play Player: 0
Hand: 3g 1t 2t
Action: 3g Prob: 0.93

Trick: 3 Phase: play Player: 1
Hand: 1p 2p 4p
Action: 4p Prob: 0.99

Trick: 3 Phase: play Player: 2
Hand: 3b 4g 3