In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import math
import os
import pathlib
import random
import time
import yaml

from copy import deepcopy
from itertools import pairwise
from typing import List, NamedTuple, Optional, Tuple

import backgammon
import numpy as np
import openai
import pandas as pd
from backgammon.backgammon import STARTING_MATCH_ID, STARTING_POSITION_ID, Move, Player

from sklearn.metrics import confusion_matrix

In [3]:
eval_path = pathlib.Path("../evals/evals/registry/evals/")
eval_data_path = pathlib.Path("../evals/evals/registry/data/")

In [4]:
STARTING_MATCH_ID, STARTING_POSITION_ID

('cAgAAAAAAAAA', '4HPwATDgc/ABMA')

In [5]:
openai.api_key = os.environ.get("OPENAI_API_KEY")


In [6]:
def roll_and_play(b) -> backgammon.Backgammon:
    """Generates "naive" backgammon moves for random dice roll and plays them.

    Args:
        b (backgammon.Backgammon): a backgammon instance

    Returns:
        backgammon.Backgammon: board after moves habe been applied
    """

    b.roll()
    plays = [play for play in b.generate_plays()]

    # cannot move
    if not plays:
        return b.end_turn()

    # naive way of "evaluating" a position with as few slots as possible
    # to have fewer legal moves availabe when creating the eval data
    idx_to_slots = [
        (idx, (np.array(p.position.board_points) == 1).sum())
        for idx, p in enumerate(plays)
    ]
    idx_to_slots.sort(key=lambda x: x[1])
    chosen_play = plays[idx_to_slots[0][0]]

    moves = tuple(((x.source, x.destination) for x in chosen_play.moves))
    b.play(moves)

    return b

In [116]:
MAX_SAMPLES = 10_000

b = backgammon.Backgammon()

eval_data = []
for _ in range(MAX_SAMPLES):
    b = backgammon.Backgammon()  # starting pos

    rounds = random.choice([4, 6])
    for x in range(rounds):
        b = roll_and_play(b)

    # roll once more to get the final board state
    b.roll()

    position_id, match_id = b.encode().split(":")

    # moves which would hit the opponent (again)
    hit_moves = [
        # moves: namedtuples in df does not play nice when storing
        [tuple(m) for m in play.moves]
        for play in b.generate_plays()
        if play.position.opponent_bar > b.position.opponent_bar
    ]

    eval_data.append(
        {
            "match_id": match_id,
            "position_id": position_id,
            "encoded": b.encode(),
            "dice": b.match.dice,
            "rounds": rounds,
            "hit_moves": hit_moves,
            "can_hit": any(hit_moves),
            "player_bar": b.position.player_bar,
            "opponent_bar": b.position.opponent_bar,
        }
    )

df = pd.DataFrame(eval_data)

In [119]:
df.drop_duplicates(subset=["encoded"], inplace=True)
print(df.shape)


(9885, 9)


In [120]:
def get_illegal_blocked_move_for_roll(b, roll) -> List[Tuple[Move, ...]]:
    """Returns a list of Move tuples given an individual dice roll- all those moves are not allowed
    since they would move to a point with 2+ checkers of the opponent

    Args:
        b (backgammon.Backgammon): the board
        roll (int): dice roll

    Returns:
        List[Tuple[Move, ...]]: List if Move tuples, e.g.
            [Move(pips=5, source=23, destination=18), ... ]
    """

    df = pd.DataFrame(b.position.board_points, columns=["board_points"])
    df["board"] = list(range(0, 24))
    # more intuitive
    df = df.reindex(index=df.index[::-1])
    df["shifted_board_points"] = df["board_points"].shift(roll)
    df["illegal"] = (
        (df.board_points * df.shifted_board_points < 0)
        & (df.shifted_board_points > 0)
        & (df.board_points < -1)
    )
    df["source"] = df.board + roll
    df["destination"] = df.board
    df["pips"] = roll

    return list(
        df[df.illegal == True][["pips", "source", "destination"]].itertuples(
            name="Move", index=False
        )
    )


In [121]:
def get_illegal_moves(row) -> Tuple[Move, ...]:
    """Constructd a "valid" illegal move.

    Ignores moves that pass through, e.g. 7/6/5 (not make it too complicated)

    Args:
        row (row): Pandas Series

    Returns:
        Tuple[Move, ...]: The illegal Move
    """

    b = backgammon.Backgammon(position_id=row.position_id, match_id=row.match_id)
    all_plays = [play for play in b.generate_plays()]

    # ignore moves where we move the same checker twice - regardless if the slot is empty or not
    def checker_is_moving_through(moves):
        return [p[0].destination == p[1].source for p in pairwise(moves)]

    plays = [
        play for play in all_plays if not any(checker_is_moving_through(play.moves))
    ]

    illegal_moves = {
        roll: get_illegal_blocked_move_for_roll(b, roll) for roll in b.match.dice
    }

    # everygthing is valid
    if not any(illegal_moves.values()) or not (plays):
        return []

    # choose randon play to manipulate
    random_play = random.choice(plays)

    new_illegal_move = list(random_play.moves)
    for pos, move in enumerate(random_play.moves):
        # do we have an invalid blocked move for this roll?
        if not illegal_moves[move.pips]:
            # not - we keep the original
            continue
        else:
            # chose a random illegal move for the roll
            new_illegal_move[pos] = random.choice(illegal_moves[move.pips])
            break

    return [tuple(x) for x in new_illegal_move]


df["illegal_move"] = df.apply(lambda row: get_illegal_moves(row), axis=1)
df["has_illegal_move"] = df.illegal_move.astype(bool)


In [122]:
df.sample(n=5)


Unnamed: 0,match_id,position_id,encoded,dice,rounds,hit_moves,can_hit,player_bar,opponent_bar,illegal_move,has_illegal_move
9454,cIgSAAAAAAAA,cOeDATCY88QBMA,cOeDATCY88QBMA:cIgSAAAAAAAA,"(5, 4)",6,[],False,0,0,"[(5, 23, 18), (4, 9, 5)]",True
1209,cAgGAAAAAAAA,sGfDATA2zuABMA,sGfDATA2zuABMA:cAgGAAAAAAAA,"(4, 1)",6,[],False,0,0,"[(4, 23, 19), (1, 5, 4)]",True
6250,cIgFAAAAAAAA,4OeGATCY5+ABMA,4OeGATCY5+ABMA:cIgFAAAAAAAA,"(3, 1)",4,[],False,0,0,"[(3, 3, 0), (1, 3, 2)]",True
663,cIgFAAAAAAAA,uE3wATCwn8EBMA,uE3wATCwn8EBMA:cIgFAAAAAAAA,"(3, 1)",6,[],False,0,0,"[(3, 23, 20), (1, 7, 6)]",True
3166,cIgGAAAAAAAA,yOcPADDgz8EBMA,yOcPADDgz8EBMA:cIgGAAAAAAAA,"(5, 1)",4,[],False,0,0,"[(5, 23, 18), (1, 7, 6)]",True


## Sample Datasets

In [123]:
df.shape

(9885, 11)

In [124]:
n = 300

df["is_double"] = df.dice.apply(lambda x: x[0] == x[1])
can_hit_sample_df = df.groupby("can_hit").sample(n=n, random_state=1)

# avoid dupes and also coming from the bar to make it easier
illegal_n = 500
illegal_sample_df = (
    df[(~df.index.isin(can_hit_sample_df.index)) & (df.player_bar == 0)]
    .groupby("has_illegal_move")
    .sample(n=illegal_n, random_state=1)
)

# shuffle - not sure when running e.g. with --max_samples: are individual items randomly chosen or the first x only
can_hit_sample_df = can_hit_sample_df.sample(frac=1, random_state=1)
illegal_sample_df = illegal_sample_df.sample(frac=1, random_state=1)


In [125]:
# store
df.to_pickle("samples_df.p")
can_hit_sample_df.to_pickle("can_hit_sample_df.p")
illegal_sample_df.to_pickle("illegal_sample_df.p")


In [7]:
# read back
df = pd.read_pickle("samples_df.p")
can_hit_sample_df = pd.read_pickle("can_hit_sample_df.p")
illegal_sample_df = pd.read_pickle("illegal_sample_df.p")


### can hit sample

In [8]:
# sample a case
# points are index based (!) -> +1 to match the board
# cIgVAAAAAAAA
sample = can_hit_sample_df[can_hit_sample_df.can_hit == True].sample(1, random_state=2)
b = backgammon.Backgammon(
    position_id=sample.position_id.item(), match_id=sample.match_id.item()
)
print(b)
# which plays hit?
print(f"dice: {sample.dice.item()}")
print()

print("these plays hit an opponents checker:")
for hit_moves in sample.hit_moves:
    for rolls in hit_moves:
        s = []
        rolls.sort(key=lambda x: x[1], reverse=True)
        for roll in rolls:
            roll = Move(*roll)
            s.append(f"{roll.source +1}/{roll.destination+1}")
        print(" ".join(s))


                 Position ID: is/gATDgz4MBMA
                 Match ID   : cAgVAAAAAAAA
 +13-14-15-16-17-18------19-20-21-22-23-24-+
 | X           O    |   | O        O  O  X |
 | X           O    |   | O              X |
 |                  |   | O                |
 |                  |   | O                |
 |                  |   | O                |
v|                  |BAR|                  |
 |                  |   | 7                |
 | O           X    |   | X                |
 | O           X    |   | X                |
 | O           X    |   | X              O |
 | O           X    |   | X              O |
 +12-11-10--9--8--7-------6--5--4--3--2--1-+

dice: (2, 5)

these plays hit an opponents checker:
24/22 13/8
24/22 8/3


### is illegal play sample

In [176]:
sample = illegal_sample_df[illegal_sample_df.has_illegal_move == True].sample(
    1, random_state=1
)
b = backgammon.Backgammon(
    position_id=sample.position_id.item(), match_id=sample.match_id.item()
)
print(b)
# which plays hit?
print(f"dice: {sample.dice.item()}")
print()
print("this ia an illegal play:")

s = []
for illegal_move in sample.illegal_move:
    for roll in illegal_move:
        roll = Move(*roll)
        s.append(f"{roll.source +1}/{roll.destination+1}")
    print(" ".join(s))

                 Position ID: zD3gATDg58HBAA
                 Match ID   : cIgVAAAAAAAA
 +13-14-15-16-17-18------19-20-21-22-23-24-+
 | X              X |   | O  O     O       |
 | X              X |   | O  O     O       |
 | X                |   | O  O             |
 |                  |   | O                |
 |                  |   |                  |
v|                  |BAR|                  |
 |                  |   | 6                |
 | O           X    |   | X                |
 | O           X    |   | X                |
 | O           X    |   | X              O |
 | O           X    |   | X              O |
 +12-11-10--9--8--7-------6--5--4--3--2--1-+

dice: (3, 5)

this ia an illegal play:
6/3 6/1


In [177]:
# another example - rolling a double:
sample = illegal_sample_df[
    (illegal_sample_df.has_illegal_move == True) & (illegal_sample_df.is_double == True)
].sample(1, random_state=1)
b = backgammon.Backgammon(
    position_id=sample.position_id.item(), match_id=sample.match_id.item()
)
print(b)
# which plays hit?
print(f"dice: {sample.dice.item()}")
print()
print("this ia an illegal play:")

s = []
for illegal_move in sample.illegal_move:
    for roll in illegal_move:
        roll = Move(*roll)
        s.append(f"{roll.source +1}/{roll.destination+1}")
    print(" ".join(s))


                 Position ID: 4NvBAQPg84UBMA
                 Match ID   : cIgNAAAAAAAA
 +13-14-15-16-17-18------19-20-21-22-23-24-+
 | X           O  O |   | O              X |
 | X           O  O |   | O              X |
 |             O    |   | O                |
 |                  |   | O                |
 |                  |   | O                |
v|                  |BAR|                  |
 |             X    |   | X                |
 |             X    |   | X                |
 | O           X    |   | X                |
 | O           X    |   | X  O             |
 | O        X  X    |   | X  O             |
 +12-11-10--9--8--7-------6--5--4--3--2--1-+

dice: (3, 3)

this ia an illegal play:
8/5 6/3 6/3 13/10


## Generate promts / eval data

In [16]:
BOARD_POSITIONS = [x for x in reversed(range(1, 25))]

GPT_PLAYER_NAME = "backgammonGPT"

SYSTEM_CONTENT = (
    f"You are {GPT_PLAYER_NAME} - a backgammon engine. "
    "Answer the following True/False question with the provided backgammon board state and the provided dice rolls. "
    f"It is {GPT_PLAYER_NAME}'s turn."
)

### hit or nope?

In [17]:
eval_hit_id = "backgammon-can-hit"

json_data = []
for pos, row in can_hit_sample_df.iterrows():
    chat_gpt_player = []
    opponent_player = []

    b = backgammon.Backgammon(position_id=row.position_id, match_id=row.match_id)

    for pos, board_points in enumerate(reversed(b.position.board_points)):
        if not board_points:
            continue

        if board_points > 0:
            chat_gpt_player.append(
                f"{board_points} checker{'s'[:board_points^1]} on the {BOARD_POSITIONS[pos]} point"
            )
        else:
            opponent_player.append(
                f"{abs(board_points)} checker{'s'[:abs(board_points)^1]} on the {BOARD_POSITIONS[pos]} point"
            )

    chatgpt_setup = ", ".join(chat_gpt_player)
    opponent_setup = ", ".join(opponent_player)

    board_eval = f"The backgammon board's position id is {row.position_id} and the match id is {row.match_id}"
    question = (
        f"{GPT_PLAYER_NAME} is rolling a {row.dice[0]} and a {row.dice[1]}. "
        f"Can {GPT_PLAYER_NAME} hit one of the opponent's checkers?"
    )
    true_false_only_reminder = "Provide your reasoning step by step, and at the end, write your final answer, True or False, enclosed in square brackets."

    user_content = (
        f"{board_eval}. {GPT_PLAYER_NAME} has {chatgpt_setup}. "
        f"The opponent has {opponent_setup}. {question} {true_false_only_reminder}"
    )

    messages = [
        {"role": "system", "content": SYSTEM_CONTENT},
        {"role": "user", "content": user_content},
    ]

    json_data.append({"input": messages, "ideal": f"[{row.can_hit}]"})

# write can_hit samples
backgammon_folder = eval_data_path / "backgammon"
backgammon_folder.mkdir(parents=True, exist_ok=True)
backgammon_can_hit = backgammon_folder / f"{eval_hit_id}.jsonl"

with open(backgammon_can_hit, "w") as f:
    for entry in json_data:
        json.dump(entry, f)
        f.write("\n")

### illegal move or not?

In [18]:
few_shot_perc = 0.2

cutoff = int(few_shot_perc * illegal_sample_df.shape[0])
illegal_sample_fs_df = illegal_sample_df[0:cutoff]
illegal_sample_train_df = illegal_sample_df[cutoff:]


In [19]:
illegal_sample_df.shape, illegal_sample_train_df.shape, illegal_sample_fs_df.shape

((1000, 12), (800, 12), (200, 12))

In [20]:
illegal_sample_df.is_double.mean(), illegal_sample_train_df.is_double.mean(), illegal_sample_fs_df.is_double.mean()


(0.345, 0.35, 0.325)

In [23]:
eval_illegal_id = "backgammon-illegal-move"


def get_json_data(df, is_few_shot=False):
    json_data = []
    for pos, row in df.iterrows():
        chat_gpt_player = []
        opponent_player = []
        b = backgammon.Backgammon(position_id=row.position_id, match_id=row.match_id)

        for pos, board_points in enumerate(reversed(b.position.board_points)):
            if not board_points:
                continue

            if board_points > 0:
                chat_gpt_player.append(
                    f"{board_points} checker{'s'[:board_points^1]} on the {BOARD_POSITIONS[pos]} point"
                )
            else:
                opponent_player.append(
                    f"{abs(board_points)} checker{'s'[:abs(board_points)^1]} on the {BOARD_POSITIONS[pos]} point"
                )

        chatgpt_setup = ", ".join(chat_gpt_player)
        opponent_setup = ", ".join(opponent_player)

        board_eval = f"The backgammon board's position id is {row.position_id} and the match id is {row.match_id}"

        moves_strings = []
        if row.has_illegal_move:
            moves = row.illegal_move
            moves = [Move(*x) for x in moves]
        else:
            plays = [play for play in b.generate_plays()]
            random_play = random.choice(plays)
            moves = list(random_play.moves)

        moves.sort(key=lambda x: x.source, reverse=True)
        for roll in moves:
            moves_strings.append(f"{roll.source +1}/{roll.destination+1}")

        illegal_move_question = f"Is {' '.join(moves_strings)} an illegal play?"

        question = f"{GPT_PLAYER_NAME} is rolling a {row.dice[0]} and a {row.dice[1]}. {illegal_move_question}"

        true_false_only_reminder = "Provide your reasoning step by step, and at the end, write your final answer, True or False, enclosed in square brackets."

        user_content = (
            f"{board_eval}. {GPT_PLAYER_NAME} has {chatgpt_setup}. "
            f"The opponent has {opponent_setup}. {question} {true_false_only_reminder}"
        )

        if is_few_shot:
            messages = [
                {"role": "system", "content": user_content, "name": "example_user"},
                {
                    "role": "system",
                    "content": f"[{row.has_illegal_move}]",
                    "name": "example_assistant",
                },
            ]
        else:
            messages = [
                {"role": "system", "content": SYSTEM_CONTENT},
                {"role": "user", "content": user_content},
            ]

        # TODO: check how to phrase this
        # negate has_illegal_move - so we can ask if it is a legal play
        if is_few_shot:
            json_data.append({"sample": messages})
        else:
            json_data.append({"input": messages, "ideal": f"[{row.has_illegal_move}]"})

    return json_data

In [24]:
# full set
json_data = get_json_data(illegal_sample_df)

# write can_hit samples
backgammon_folder = eval_data_path / "backgammon"
backgammon_folder.mkdir(parents=True, exist_ok=True)
backgammon_illegal_move = backgammon_folder / f"{eval_illegal_id}.jsonl"

with open(backgammon_illegal_move, "w") as f:
    for entry in json_data:
        json.dump(entry, f)
        f.write("\n")

In [25]:
# few_shot version
eval_illegal_fs_few_shot_id = "backgammon-illegal-fs-few_shot-move"
eval_illegal_fs_samples_id = "backgammon-illegal-fs-samples-move"

# have a unique "train set"
json_data = get_json_data(illegal_sample_train_df)

# write can_hit samples
backgammon_folder = eval_data_path / "backgammon"
backgammon_folder.mkdir(parents=True, exist_ok=True)
backgammon_illegal_move = backgammon_folder / f"{eval_illegal_fs_samples_id}.jsonl"

with open(backgammon_illegal_move, "w") as f:
    for entry in json_data:
        json.dump(entry, f)
        f.write("\n")


json_fs_data = get_json_data(illegal_sample_fs_df, is_few_shot=True)

# write can_hit samples
backgammon_folder = eval_data_path / "backgammon"
backgammon_folder.mkdir(parents=True, exist_ok=True)
backgammon_illegal_move = backgammon_folder / f"{eval_illegal_fs_few_shot_id}.jsonl"

with open(backgammon_illegal_move, "w") as f:
    for entry in json_fs_data:
        json.dump(entry, f)
        f.write("\n")

In [204]:
registry_yaml = {}

# can hit registry
registry_id = f"{eval_hit_id}.match.dev.v0"
registry_yaml[eval_hit_id] = {"id": registry_id, "metrics": ["accuracy"]}
registry_yaml[registry_id] = {
    "class": "evals.elsuite.basic.match:Match",
    "args": {"samples_jsonl": f"backgammon/{eval_hit_id}.jsonl"},
}

# illegal move reg
registry_illegal_id = f"{eval_illegal_id}.match.dev.v0"
registry_yaml[eval_illegal_id] = {"id": registry_illegal_id, "metrics": ["accuracy"]}
registry_yaml[registry_illegal_id] = {
    "class": "evals.elsuite.basic.includes:Includes",
    "args": {"samples_jsonl": f"backgammon/{eval_illegal_id}.jsonl"},
}

In [205]:
# illegal move reg - few shot version
# TODO: investigate this further - skipping the few shot example for now

_ = """
registry_illegal_fs_id = f"{eval_illegal_fs_samples_id}.match.dev.v0"
registry_yaml[eval_illegal_fs_samples_id] = {
    "id": registry_illegal_fs_id,
    "metrics": ["accuracy"],
}
registry_yaml[registry_illegal_fs_id] = {
    "class": "evals.elsuite.basic.includes:Includes",
    "args": {
        "few_shot_jsonl": f"backgammon/{eval_illegal_fs_few_shot_id}.jsonl",
        "num_few_shot": 4,
        "samples_jsonl": f"backgammon/{eval_illegal_fs_samples_id}.jsonl",
    },
}
"""

In [188]:
with open(eval_path / "backgammon.yaml", "w") as f:
    yaml.dump(registry_yaml, f)


In [26]:
#%%capture
!oaieval gpt-3.5-turbo backgammon-can-hit --record_path logs/bg_can_hit_4.log --no-cache --max_samples 30
!oaieval gpt-3.5-turbo backgammon-illegal-move --record_path logs/bg_illegal_move_4.log --max_samples 30
#!oaieval gpt-4 backgammon-illegal-fs-samples-move --record_path logs/bg_illegal_move_4_fs.log --max_samples 300


[2023-06-15 12:23:04,352] [registry.py:262] Loading registry from /Users/bakebrain/src/evals/evals/registry/evals
[2023-06-15 12:23:04,677] [registry.py:262] Loading registry from /Users/bakebrain/.evals/evals
[2023-06-15 12:23:04,681] [oaieval.py:138] [1;35mRun started: 230615102304BJCCKFQ5[0m
[2023-06-15 12:23:04,695] [data.py:83] Fetching backgammon/backgammon-can-hit.jsonl
[2023-06-15 12:23:04,698] [eval.py:33] Evaluating 30 samples
[2023-06-15 12:23:04,710] [eval.py:139] Running in threaded mode with 10 threads!
100%|███████████████████████████████████████████| 30/30 [00:38<00:00,  1.30s/it]
[2023-06-15 12:23:43,665] [record.py:341] Final report: {'accuracy': 0.7, 'boostrap_std': 0.08421873636878886}. Logged to logs/bg_can_hit_4.log
[2023-06-15 12:23:43,665] [oaieval.py:177] Final report:
[2023-06-15 12:23:43,665] [oaieval.py:179] accuracy: 0.7
[2023-06-15 12:23:43,665] [oaieval.py:179] boostrap_std: 0.08421873636878886
[2023-06-15 12:23:43,669] [record.py:330] Logged 60 rows of

## eval


### can we hit?

In [148]:
events = "logs/bg_can_hit_4.log"

with open(events, "r") as f:
    events_df = pd.read_json(f, lines=True)

matches_df = events_df[events_df.type == "match"].reset_index(drop=True)
matches_df = matches_df.join(pd.json_normalize(matches_df.data))

expected_strs = matches_df.expected.values
expected = expected_strs == "[True]"
correct = matches_df.correct.values

In [149]:
correct.mean()


0.7

In [159]:
# is this data in the df actually
y = np.where(correct, expected, np.logical_not(expected))


In [158]:
(y == expected).mean() == correct.mean()


True

In [155]:
pd.crosstab(expected, y, rownames=["actual"], colnames=["pred"], margins=True)

pred,False,True,All
actual,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
False,12,3,15
True,6,9,15
All,18,12,30


### is it an illegal play?

In [162]:
events = "logs/bg_illegal_move_4.log"

with open(events, "r") as f:
    events_df = pd.read_json(f, lines=True)

matches_df = events_df[events_df.type == "match"].reset_index(drop=True)
matches_df = matches_df.join(pd.json_normalize(matches_df.data))

expected_strs = matches_df.expected.values
expected = expected_strs == "[True]"
correct = matches_df.correct.values

In [163]:
correct.mean()


0.43333333333333335

In [164]:
y = np.where(correct, expected, np.logical_not(expected))


In [165]:
pd.crosstab(expected, y, rownames=["actual"], colnames=["pred"], margins=True)

pred,False,True,All
actual,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
False,10,6,16
True,11,3,14
All,21,9,30
