In [15]:
import re
import dataclasses
import enum
import itertools
import numpy as np


class Sute(enum.IntEnum):
    """Card suit, encoded as a small integer.

    The four real suits take the values 0-3.  ``NONE`` (4) is the suit
    given to cards without a real suit, such as ``JOKER_CARD``.
    """

    SPADE = 0
    CLUB = 1
    HEART = 2
    DIAMOND = 3
    NONE = 4

# The four real suits, in the order of their integer values (excludes Sute.NONE).
SUTES = [Sute.SPADE, Sute.CLUB, Sute.HEART, Sute.DIAMOND]
# Suit letter used in card strings -> suit.
SUTE_DICT = {'S': Sute.SPADE, 'C': Sute.CLUB, 'H': Sute.HEART, 'D': Sute.DIAMOND}

# Rank character used in card strings -> rank number 0..12.
# '3'..'9' map to 0..6, then 'T', 'J', 'Q', 'K', 'A', '2' map to 7..12.
NUMBER_DICT = {str(i): i - 3 for i in range(3, 10)}
NUMBER_DICT.update({'T': 7, 'J': 8, 'Q': 9, 'K': 10, 'A': 11, '2': 12})

# Index returned by Card.index for every card flagged as a joker
# (the 52 regular cards use sute * 13 + number, i.e. 0..51).
JOKER_INDEX = 52


@dataclasses.dataclass(frozen=True)
class Card:
    """Immutable playing card.

    Attributes:
        sute: Suit of the card (``Sute.NONE`` for the plain joker).
        number: Rank number as produced by ``NUMBER_DICT`` (0..12); the plain
            joker ``JOKER_CARD`` uses 13.
        is_joker: True when the card is a joker, either the plain joker or a
            joker standing in for a regular card.
    """

    sute: Sute
    number: int
    is_joker: bool = False

    @property
    def index(self) -> int:
        """Return the flat card index.

        Regular cards map to ``sute * 13 + number``; any card flagged as a
        joker maps to ``JOKER_INDEX`` regardless of its suit and number.
        """
        if self.is_joker:
            return JOKER_INDEX
        return self.sute * 13 + self.number

    def replace_with_joker(self) -> "Card":
        """Return a copy of this card with ``is_joker`` set to True."""
        return dataclasses.replace(self, is_joker=True)

    @classmethod
    def from_str(cls, s: str) -> "Card":
        """Parse a card token.

        Accepted forms:
            ``"JO"``      -> the plain joker (``JOKER_CARD``).
            ``"JO(XY"...`` -> a joker substituting for card ``XY``; the two
                             characters after ``"JO("`` are parsed as a card.
            ``"XY"``      -> suit letter ``X`` (key of ``SUTE_DICT``) followed
                             by rank character ``Y`` (key of ``NUMBER_DICT``).

        Raises:
            KeyError: if the suit letter or rank character is unknown.
            IndexError: if the token has fewer than two characters.
        """
        if s == 'JO':
            return JOKER_CARD

        if s.startswith('JO('):
            # Parse the substituted card, then flag it as a joker.
            return cls.from_str(s[3:5]).replace_with_joker()

        # A stray debug print for tokens starting with 'J' used to live here;
        # such tokens still fail below with KeyError, just without the noise.
        sute = SUTE_DICT[s[0]]
        number = NUMBER_DICT[s[1]]

        return cls(sute, number)


# The plain (unsubstituted) joker: no suit, number 13 (one past the highest
# regular rank number 12). Returned by Card.from_str("JO").
JOKER_CARD = Card(Sute.NONE, 13, is_joker=True)


@dataclasses.dataclass
class ActionType:
    """Interface for one family of moves (pass, single joker, pair, stairs...).

    Each subclass maps between a list of cards and a local action index in
    ``range(num_combinations)``.  Every method here raises
    ``NotImplementedError`` and is meant to be overridden.
    """

    def is_match(self, cards: list[Card]) -> bool:
        """Return True if ``cards`` belongs to this move family."""
        raise NotImplementedError

    def cards_to_action(self, cards: list[Card]) -> int:
        """Return the local action index encoding ``cards``."""
        raise NotImplementedError

    def action_to_cards(self, action: int) -> list[Card]:
        """Return the cards encoded by the local action index ``action``."""
        raise NotImplementedError

    @property
    def num_combinations(self) -> int:
        """Number of local action indices this family occupies."""
        raise NotImplementedError


@dataclasses.dataclass
class ActionTypePass(ActionType):
    """Move family for passing: an empty card list, encoded as the single index 0."""

    @property
    def num_combinations(self) -> int:
        """A pass occupies exactly one action index."""
        return 1

    def is_match(self, cards: list[Card]) -> bool:
        """Return True when no cards are played."""
        num_played = len(cards)
        return num_played == 0

    def cards_to_action(self, cards: list[Card]) -> int:
        """Return 0, the only local index of this family."""
        return 0

    def action_to_cards(self, action: int) -> list[Card]:
        """Return an empty card list regardless of ``action``."""
        return []


@dataclasses.dataclass
class ActionTypeSingleJoker(ActionType):
    """Move family for playing one joker on its own, encoded as the single index 0."""

    @property
    def num_combinations(self) -> int:
        """A lone joker occupies exactly one action index."""
        return 1

    def is_match(self, cards: list[Card]) -> bool:
        """Return whether ``cards`` is exactly one card flagged as a joker."""
        if len(cards) != 1:
            return False
        only_card = cards[0]
        return only_card.is_joker

    def cards_to_action(self, cards: list[Card]) -> int:
        """Return 0, the only local index of this family."""
        return 0

    def action_to_cards(self, action: int) -> list[Card]:
        """Return ``[JOKER_CARD]`` regardless of ``action``."""
        return [JOKER_CARD]


@dataclasses.dataclass
class ActionTypePair(ActionType):
    """Move family for ``num_cards`` cards that all share one rank number.

    The local action index is ``suit_combination_index * 13 + number``, where
    the suit combination is looked up in ``sute_combinations``.
    """

    num_cards: int

    @property
    def sute_combinations(self) -> list[frozenset[Sute]]:
        """All suit sets a group of ``num_cards`` cards can use.

        Up to four suits are chosen from ``SUTES``; when ``num_cards`` exceeds
        four, ``Sute.NONE`` is added to every set.
        """
        num_real_suits = min(self.num_cards, 4)
        extra = (Sute.NONE,) if self.num_cards > 4 else ()

        return [
            frozenset(combo + extra)
            for combo in itertools.combinations(SUTES, num_real_suits)
        ]

    @property
    def num_combinations(self) -> int:
        """Number of suit sets times the 13 possible numbers."""
        return 13 * len(self.sute_combinations)

    def sutes_to_combination_index(self, sutes: list[Sute]) -> int:
        """Return the position of the suit set ``sutes`` in ``sute_combinations``.

        Raises ``ValueError`` (from ``list.index``) when the set is not listed.
        """
        wanted = frozenset(sutes)
        return self.sute_combinations.index(wanted)

    def is_match(self, cards: list[Card]) -> bool:
        """Return True when ``cards`` has ``num_cards`` cards of one number."""
        if len(cards) != self.num_cards:
            return False

        # Every number equals the first one exactly when at most one
        # distinct number occurs.
        distinct_numbers = {card.number for card in cards}
        return len(distinct_numbers) <= 1

    def cards_to_action(self, cards: list[Card]) -> int:
        """Encode ``cards`` as ``suit_combination_index * 13 + number``."""
        combination_index = self.sutes_to_combination_index(
            [card.sute for card in cards]
        )
        shared_number = cards[0].number

        return shared_number + 13 * combination_index

    def action_to_cards(self, action: int) -> list[Card]:
        """Decode a local action index into one card per suit of its suit set."""
        combination_index, number = divmod(action, 13)
        chosen_sutes = self.sute_combinations[combination_index]

        cards = []
        for sute in chosen_sutes:
            cards.append(Card(sute, number))
        return cards


@dataclasses.dataclass
class ActionTypeStairs(ActionType):
    """Move family for ``num_cards`` same-suit cards with consecutive numbers.

    The local action index is ``sute * num_number_combinations + (first + 1)``
    where ``first`` is the number of the first position of the run.  ``first``
    ranges from -1 (a joker placed below the lowest rank) up to the value that
    puts the last position at number 13 (a joker above the highest rank).
    """

    num_cards: int

    @staticmethod
    def _is_plain_joker(card: Card) -> bool:
        """Return True for a joker that carries no suit (e.g. ``JOKER_CARD``)."""
        return card.sute == Sute.NONE and card.is_joker

    def is_match(self, cards: list[Card]) -> bool:
        """Return True when ``cards`` is accepted as a run of ``num_cards`` cards.

        A same-suit list must have numbers increasing by one from the first
        card.  A mixed-suit list is accepted whenever its first or last card
        is a suitless joker.
        NOTE(review): in that joker case the remaining cards are not checked
        for suit or order; behavior kept as-is.
        """
        if len(cards) != self.num_cards:
            return False

        if any(card.sute != cards[0].sute for card in cards):
            return self._is_plain_joker(cards[0]) or self._is_plain_joker(cards[-1])

        return all(
            card.number == cards[0].number + i for i, card in enumerate(cards)
        )

    def cards_to_action(self, cards: list[Card]) -> int:
        """Encode a run as ``sute * num_number_combinations + first_number + 1``.

        The suit and starting number are taken from the first card that is not
        a suitless joker.  Reading them from a leading suitless joker
        (sute 4, number 13) would give an index outside
        ``range(num_combinations)``.
        """
        anchor_pos, anchor = next(
            (
                (pos, card)
                for pos, card in enumerate(cards)
                if not self._is_plain_joker(card)
            ),
            (0, cards[0]),  # no regular card at all: fall back to the first card
        )

        sute_index = anchor.sute
        # Number at position 0 of the run, shifted by one so that -1 maps to 0.
        first_number = anchor.number - anchor_pos + 1

        return sute_index * self.num_number_combinations + first_number

    def action_to_cards(self, action: int) -> list[Card]:
        """Decode a local action index into the cards of the run.

        Positions whose number falls outside 0..12 are returned as cards with
        ``is_joker=True``.
        """
        sute_index = action // self.num_number_combinations
        first_number = action % self.num_number_combinations - 1

        sute = SUTES[sute_index]
        numbers = [first_number + i for i in range(self.num_cards)]

        return [Card(sute, number, is_joker=not (0 <= number < 13)) for number in numbers]

    @property
    def num_number_combinations(self) -> int:
        """Number of possible starting numbers (``13 - num_cards + 3``)."""
        return 13 - self.num_cards + 3

    @property
    def num_combinations(self) -> int:
        """Four suits times the number of starting numbers."""
        return 4 * self.num_number_combinations


@dataclasses.dataclass
class ActionConverter:
    """Maps card lists to global action indices and back.

    The global index space is the concatenation of the local index ranges of
    ``action_types``, in list order.

    Attributes:
        action_types: Move families; earlier entries take precedence when
            several match the same cards.
    """

    action_types: list[ActionType]

    def cards_to_action(self, cards: list[Card]) -> int:
        """Return the global action index for ``cards``.

        The first action type whose ``is_match`` accepts ``cards`` is used.

        Raises:
            ValueError: if no action type matches.
        """
        base = 0

        for action_type in self.action_types:
            if action_type.is_match(cards):
                return base + action_type.cards_to_action(cards)

            base += action_type.num_combinations

        raise ValueError(f"No action type matches {cards}")

    def action_to_cards(self, action: int) -> list[Card]:
        """Return the cards encoded by the global action index ``action``.

        Raises:
            ValueError: if ``action`` is negative or not below the total
                number of combinations.  The message reports the value that
                was passed in.
        """
        if action < 0:
            # Without this guard a negative index would fall into the first
            # action type's range and be decoded as if it were valid.
            raise ValueError(f"Invalid action: {action}")

        remaining = action
        for action_type in self.action_types:
            size = action_type.num_combinations
            if remaining < size:
                return action_type.action_to_cards(remaining)
            remaining -= size

        raise ValueError(f"Invalid action: {action}")


# Global action encoding shared by the whole file.
# Order matters: ActionConverter.cards_to_action uses the first matching type,
# so the pass is action 0 and a lone joker is claimed by ActionTypeSingleJoker
# before ActionTypePair(1) is tried.
ACTION_CONVERTER = ActionConverter([
    ActionTypePass(),
    ActionTypeSingleJoker(),
    ActionTypePair(1),
    ActionTypePair(2),
    ActionTypePair(3),
    ActionTypePair(4),
    ActionTypePair(5),
    ActionTypeStairs(3),
    ActionTypeStairs(4),
    ActionTypeStairs(5),
    ActionTypeStairs(6),
    ActionTypeStairs(7),
])

In [None]:
@dataclasses.dataclass(frozen=True)
class PlayerInfo:
    """One player's entry from the header line of a replay."""

    # Player name exactly as captured by PLAYER_INFO_REGEX (e.g. "<digits>.<word>").
    name: str
    # Rank from the previous game; parse_player_info stores the parsed value minus one.
    previous_rank: int
    # Cards in the player's starting hand.
    hands: list[Card]


# Matches one player entry of the form "<digits>.<word> (<digits>):{<hand>}".
# Groups: 1 = name, 2 = rank number, 3 = text between the braces.
PLAYER_INFO_REGEX = re.compile(r'(\d+\.\w+)\s*\((\d+)\)\s*:\{([^}]*)\}')


def parse_player_info(s: str) -> list[PlayerInfo]:
    """Extract every player entry found in ``s``.

    Each match of ``PLAYER_INFO_REGEX`` yields one ``PlayerInfo``.  The rank
    is stored minus one, and only space-separated hand tokens that are exactly
    two characters long are parsed as cards; other tokens are skipped.
    """
    players = []

    for entry in PLAYER_INFO_REGEX.finditer(s):
        name, rank, hands_str = entry.groups()

        hands = []
        for token in hands_str.split(" "):
            if len(token) != 2:
                continue
            hands.append(Card.from_str(token))

        players.append(PlayerInfo(name, int(rank) - 1, hands))

    return players


@dataclasses.dataclass(frozen=True)
class Move:
    """The cards played in one turn (an empty list means a pass)."""

    cards: list[Card]

    def get_indices(self) -> list[int]:
        """Return ``Card.index`` of every played card, in order."""
        indices = []
        for card in self.cards:
            indices.append(card.index)
        return indices

    def get_action(self) -> int:
        """Return the global action index assigned by ``ACTION_CONVERTER``."""
        converter = ACTION_CONVERTER
        return converter.cards_to_action(self.cards)


def parse_move(line: str) -> tuple[str, Move]:
    """Parse one move line into ``(player_name, Move)``.

    The line is split on single spaces.  The player name is the first field
    with its last three characters removed, and the move starts at the third
    field: either the literal ``"Pass"`` or a list of card tokens.
    NOTE(review): the exact log layout is assumed from this slicing; confirm
    against a sample file.
    """
    fields = line.split(" ")
    player_name = fields[0][:-3]
    card_fields = fields[2:]

    if card_fields[0] == "Pass":
        return player_name, Move([])

    # Consecutive spaces produce empty fields; skip them.
    played = [Card.from_str(field) for field in card_fields if field]
    return player_name, Move(played)


@dataclasses.dataclass(frozen=True)
class BeforeState:
    """Table state recorded before a move, as parsed by parser_before_state."""

    # True when the first field of the state line is "1".
    flash_before: bool
    # Cards parsed from the second field (empty list when it holds no card tokens).
    on_stage_before: list[Card]
    # True when the third field of the state line is "1".
    lock_before: bool
    # True when the fourth field of the state line is "1".
    rev_before: bool


def parser_before_state(line: str) -> BeforeState:
    """Parse a comma-separated ``key:value`` state line into a ``BeforeState``.

    Only the value part of each field is used, by position: flash flag, cards
    on stage, lock flag, reverse flag.  A flag is True when its value, with
    spaces removed, equals ``"1"``.
    """
    values = [field.split(":")[1] for field in line.split(",")]

    def is_set(raw: str) -> bool:
        return raw.replace(" ", "") == "1"

    stage_cards = []
    for token in values[1].split(" "):
        # Tokens shorter than two characters cannot be cards.
        if len(token) >= 2:
            stage_cards.append(Card.from_str(token))

    flash = is_set(values[0])
    lock = is_set(values[2])
    rev = is_set(values[3])

    return BeforeState(
        flash_before=flash,
        on_stage_before=stage_cards,
        lock_before=lock,
        rev_before=rev,
    )


def card_indices_to_action(indices: np.ndarray, max_num_cards: int) -> np.ndarray:
    """Pad a sequence of card indices to ``max_num_cards`` slots.

    Slot ``k`` that is not covered by ``indices`` is filled with ``53 + k``,
    so every empty slot gets a distinct value above the card indices 0..52.

    Args:
        indices: Card indices; any integer sequence (list or array) is accepted.
        max_num_cards: Number of slots in the result.

    Returns:
        An integer ndarray.  Its length is ``max_num_cards`` when padding was
        needed; an input that is already that long or longer is returned
        unpadded (as an array).
    """
    # Always work on an integer array: an empty Python list would otherwise
    # become float64 in np.concatenate, and a full-length list would be
    # returned as a plain list.
    indices = np.asarray(indices, dtype=np.int64)

    num_given = len(indices)
    if num_given < max_num_cards:
        padding = 53 + np.arange(num_given, max_num_cards)
        indices = np.concatenate([indices, padding])

    return indices


# Number of rows allocated for the per-turn arrays of one game in convert_replay.
MAX_TURNS = 300
# Number of card slots stored per move.
MAX_NUM_CARDS_IN_MOVE = 7

# Width of one state row: the card slots plus the three columns named in StateIndex.
S_DIM = MAX_NUM_CARDS_IN_MOVE + 3

class StateIndex(enum.IntEnum):
    """Column positions of the non-card fields in a state row.

    Columns ``0 .. MAX_NUM_CARDS_IN_MOVE - 1`` hold the card slots; the
    members below name the columns that follow them.
    """

    PLAYER_ID = MAX_NUM_CARDS_IN_MOVE
    IS_FLASH = MAX_NUM_CARDS_IN_MOVE + 1
    TURN = MAX_NUM_CARDS_IN_MOVE + 2


def _remove_all_pass_rounds(
    s: np.ndarray, a: np.ndarray, num_turns: int, round_len: int
) -> int:
    """Delete every run of ``round_len`` consecutive passes, in place.

    Only the first ``num_turns`` rows are real turns; the rest is zero padding
    and is never treated as passes.  After a run is removed the later rows are
    shifted up, their TURN column is reduced by ``round_len`` so turn numbers
    stay contiguous, and the freed tail is zeroed.

    Returns the number of real turns that remain.
    """
    i = 0
    while i + round_len <= num_turns:
        # Action 0 is the pass (first entry of ACTION_CONVERTER).
        if np.any(a[i:i + round_len] != 0):
            i += 1
            continue

        a[i:-round_len] = a[i + round_len:].copy()
        s[i:-round_len] = s[i + round_len:].copy()
        a[-round_len:] = 0
        s[-round_len:] = 0

        num_turns -= round_len
        s[i:num_turns, StateIndex.TURN] -= round_len
        # Do not advance i: the rows shifted into place may start another run.

    return num_turns


def convert_replay(
    lines: list[str], offset: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
    """Convert one game of a replay log into arrays.

    Args:
        lines: All lines of the log file.
        offset: Index of the game's player-info line.

    Returns:
        ``(s_init, s_hands, s, a, r, offset)`` where
        ``s_init`` is ``(players, 2)``: seat order relative to the first
        player to move, and previous rank;
        ``s_hands`` is ``(players, 12, S_DIM)``: one row per starting-hand
        card slot;
        ``s`` is ``(MAX_TURNS, S_DIM)``: one state row per turn;
        ``a`` is ``(MAX_TURNS,)``: global action index per turn;
        ``r`` is ``(players,)``: reward per player, highest for the player
        listed first in the result section;
        ``offset`` is the index of the first line after the game's turns.

    NOTE(review): assumes each turn record spans 10 lines (header, move line,
    state line, ...) and that the result section lists one player per line
    with the name as the last space-separated field; confirm against the
    log format.
    """
    players = parse_player_info(lines[offset])
    num_players = len(players)
    player_id_dict = {player.name: i for i, player in enumerate(players)}

    s_init = np.zeros((num_players, 2), dtype=np.uint8)
    s_hands = np.zeros((num_players, 12, S_DIM), dtype=np.uint8)

    for i, player in enumerate(players):
        s_init[i, 1] = player.previous_rank

        for j in range(s_hands.shape[1]):
            # One hand card per row; rows beyond the hand size are all padding.
            cards = [player.hands[j].index] if j < len(player.hands) else []

            s_hands[i, j, :MAX_NUM_CARDS_IN_MOVE] = card_indices_to_action(cards, max_num_cards=MAX_NUM_CARDS_IN_MOVE)

        s_hands[i, :, StateIndex.PLAYER_ID] = i

    def next_player_id(player_id: int) -> int:
        return (player_id + 1) % num_players

    offset += 1
    turn = 0

    s = np.zeros((MAX_TURNS, S_DIM), dtype=np.uint8)
    a = np.zeros((MAX_TURNS), dtype=np.uint16)

    prev_player_id = None

    while lines[offset].startswith("Turn"):
        name, move = parse_move(lines[offset + 1])
        before_state = parser_before_state(lines[offset + 2])

        player_id = player_id_dict[name]

        # Column 0 is still all zero only before the first move: record each
        # player's seat position relative to the player who moves first.
        if np.all(s_init[:, 0] == 0):
            s_init[:, 0] = (np.arange(num_players) - player_id) % num_players

        # Players skipped between two logged moves get an explicit pass row.
        while (prev_player_id is not None) and (next_player_id(prev_player_id) != player_id):
            prev_player_id = next_player_id(prev_player_id)

            s[turn, :MAX_NUM_CARDS_IN_MOVE] = card_indices_to_action([], max_num_cards=MAX_NUM_CARDS_IN_MOVE)
            s[turn, StateIndex.PLAYER_ID] = prev_player_id
            s[turn, StateIndex.TURN] = turn + 1
            a[turn] = Move([]).get_action()
            turn += 1

        # parser_before_state always returns a list, so an empty stage has to
        # be detected by emptiness (an `is None` test can never be true).
        # An empty stage before this move marks the previous row as a flash.
        if turn > 0 and not before_state.on_stage_before:
            s[turn - 1, StateIndex.IS_FLASH] = 1

        s[turn, :MAX_NUM_CARDS_IN_MOVE] = card_indices_to_action(move.get_indices(), max_num_cards=MAX_NUM_CARDS_IN_MOVE)
        s[turn, StateIndex.PLAYER_ID] = player_id
        s[turn, StateIndex.TURN] = turn + 1
        a[turn] = move.get_action()
        turn += 1

        prev_player_id = player_id
        offset += 10

    r = np.zeros(num_players, dtype=np.uint8)

    for place in range(num_players):
        tokens = lines[offset + 1 + place].split(" ")
        name = tokens[-1]
        player_id = player_id_dict[name]
        r[player_id] = num_players - 1 - place

    # A full round in which every player passes carries no information; drop
    # such rounds and renumber the TURN column (not PLAYER_ID) accordingly.
    _remove_all_pass_rounds(s, a, turn, num_players)

    return s_init, s_hands, s, a, r, offset


# Read the whole replay log and convert every game it contains.
with open("./datasets/random_data2-39.txt", "r") as f:
    lines = f.read().splitlines()

# Per-game results, one list entry per converted game.
s_init = []
s_hands = []
s_moves = []
a = []
r = []

offset = 0
num_lines = len(lines)

while True:
    # Skip ahead to the next line starting with "original", which is the
    # line convert_replay expects at `offset`.  Checking the bound in the
    # loop condition also covers an empty file, which used to raise IndexError.
    while offset < num_lines and not lines[offset].startswith("original"):
        offset += 1

    if offset >= num_lines:
        break

    s_init_i, s_hands_i, s_moves_i, a_i, r_i, offset = convert_replay(lines, offset)

    s_init.append(s_init_i)
    s_hands.append(s_hands_i)
    s_moves.append(s_moves_i)
    a.append(a_i)
    r.append(r_i)

    # Progress: number of games converted so far.
    print(len(s_init))


In [17]:
# Join the per-game arrays along their first axis (games are laid end to end,
# no separate game dimension is kept) and write them to one compressed archive.
s_init, s_hands, s_moves, a, r = (
    np.concatenate(per_game, axis=0)
    for per_game in (s_init, s_hands, s_moves, a, r)
)

np.savez_compressed(
    "./datasets/d_20000.npz",
    s_init=s_init,
    s_hands=s_hands,
    s_moves=s_moves,
    a=a,
    r=r,
)
