In [4]:
import dataclasses
import itertools

import numpy as np


@dataclasses.dataclass
class Action:
    """Base class for one family of moves in the action encoding.

    Subclasses are expected to override both members; the base versions
    only raise ``NotImplementedError``.
    """

    def action_to_cards(self, action: int) -> tuple[np.ndarray, bool]:
        """Decode ``action`` into a card grid and a joker flag.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @property
    def num_combinations(self) -> int:
        """Number of distinct actions in this family.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()


@dataclasses.dataclass
class Pair(Action):
    """Moves made of ``num_cards`` cards that share one rank.

    NOTE(review): ``action_to_cards`` is still a stub that returns an empty
    grid for every action id.
    """

    # Number of same-rank cards that form the move (e.g. 2 for a pair).
    num_cards: int

    def action_to_cards(self, action: int) -> tuple[np.ndarray, bool]:
        """Return an all-zero (4, 13) grid and ``False`` (placeholder)."""
        return np.zeros((4, 13), dtype=np.uint8), False

    def sute_combinations(self) -> list[tuple[int, ...]]:
        """List every way of choosing ``num_cards`` of the four suit rows.

        Tuples are produced by ``itertools.combinations`` and therefore come
        out in lexicographic order, each with ``num_cards`` elements.
        """
        return list(itertools.combinations(range(4), self.num_cards))

    @property
    def num_combinations(self) -> int:
        """Number of distinct moves: suit combinations times 13 ranks.

        The previous implementation returned ``4 * 13`` regardless of
        ``num_cards``, which is only right for one or three cards.
        """
        return len(self.sute_combinations()) * 13


def action_to_cards(action: int) -> tuple[np.ndarray, bool]:
    """Decode a legacy action id into a card grid and a joker flag.

    Id layout (upper bounds are exclusive):

    * ``0..51``    single card, suit-major (``suit = id // 13``)
    * ``52``       the joker on its own
    * ``53..130``  two cards of one rank, six suit pairs per rank
    * ``131..182`` three cards of one rank, four suit triples per rank
    * ``183..195`` all four cards of one rank
    * ``196..208`` all four cards of one rank plus the joker
    * ``209..260`` three-card run in one suit
    * ``261..308`` four-card run in one suit
    * ``309..352`` five-card run in one suit

    A run may start one rank below 0 or end one rank above 12; the position
    that falls outside the grid is reported through the joker flag instead.
    Ids of 353 and above select nothing.

    Args:
        action: Legacy action id.

    Returns:
        A ``(4, 13)`` uint8 grid with 1 for every selected (suit, rank) cell,
        and a bool telling whether the move uses the joker.
    """
    grid = np.zeros((4, 13), dtype=np.uint8)

    if action < 52:  # single-card move
        suit, rank = divmod(action, 13)
        grid[suit, rank] = 1
        return grid, False

    if action == 52:  # joker only
        return grid, True

    # Two- and three-card same-rank moves: (exclusive upper id, first id, size).
    for upper, base, size in ((131, 53, 2), (183, 131, 3)):
        if action < upper:
            suit_sets = list(itertools.combinations(range(4), size))
            rank, which = divmod(action - base, len(suit_sets))
            for suit in suit_sets[which]:
                grid[suit, rank] = 1
            return grid, False

    if action < 196:  # four-card move
        grid[:, action - 183] = 1
        return grid, False

    if action < 209:  # five-card move: four of a rank plus the joker
        grid[:, action - 196] = 1
        return grid, True

    uses_joker = False

    # Runs of three, four and five cards: (exclusive upper id, first id, length).
    for upper, base, length in ((261, 209, 3), (309, 261, 4), (353, 309, 5)):
        if action < upper:
            # A run of `length` has 16 - length start positions per suit
            # (13, 12, 11), the first of which begins one rank below 0.
            suit, start = divmod(action - base, 16 - length)
            for rank in range(start - 1, start - 1 + length):
                if 0 <= rank < 13:
                    grid[suit, rank] = 1
                else:
                    uses_joker = True
            break

    return grid, uses_joker

In [152]:
def card_index_to_str(index: int) -> str:
    """Render a flat card index as a short label.

    The index is split into ``suit = index // 13`` and ``rank = index % 13``.
    Suit 4 (which index 52 falls into) is labelled ``'JOKER'``; suits 0-3 are
    labelled with the letters a-d followed by the rank.

    Args:
        index: Flat card index.

    Returns:
        ``'JOKER'`` or a string such as ``'b 7'``.
    """
    suit_letters = 'abcd'
    suit, rank = divmod(index, 13)

    if suit == 4:
        return 'JOKER'

    return f'{suit_letters[suit]} {rank}'


def card_indices_to_str(indices: np.ndarray) -> list[str]:
    """Label every index below 53 in ``indices``; larger values are dropped.

    Args:
        indices: Iterable of flat card indices (values of 53 and above are
            skipped).

    Returns:
        One label per kept index, in input order.
    """
    labels = []

    for card in indices:
        if card < 53:
            labels.append(card_index_to_str(card))

    return labels


def legacy_action_to_action(legacy_action: int, hand_cards: np.ndarray, max_num_cards: int) -> tuple[np.ndarray, np.ndarray]:
    """Convert a legacy action id into a card-index token and update the hand.

    The legacy id is decoded into a 53-long selection (52 grid cells plus a
    joker slot). If exactly one selected card is absent from the hand, the
    move does not already use the joker, and the joker is in hand, the joker
    is substituted for the absent card.

    Args:
        legacy_action: Legacy action id understood by ``action_to_cards``.
        hand_cards: 53-long 0/1 vector of the cards currently held
            (index 52 is the joker slot).
        max_num_cards: Token length passed on to ``cards_to_action``.

    Returns:
        The token produced by ``cards_to_action`` and the hand with the
        played cards removed.

    Raises:
        AssertionError: if the hand cannot cover the move. Raised explicitly
            rather than via ``assert`` so the check still runs under
            ``python -O``.
    """
    grid, is_include_joker = action_to_cards(legacy_action)
    cards = np.concatenate([grid.reshape(-1), [is_include_joker]])

    # Selected cards that the hand does not contain.
    num_missing_cards = np.sum((1 - hand_cards) * cards)

    # More than one missing card can never be covered; a single missing card
    # needs a joker that is in hand and not already part of the move.
    if num_missing_cards > 1 or (
        num_missing_cards == 1 and (is_include_joker or hand_cards[52] == 0)
    ):
        hand_str = card_indices_to_str(np.where(hand_cards)[0])
        move_str = card_indices_to_str(np.where(cards)[0])
        raise AssertionError(f'[missing] hand: {hand_str}, move: {move_str}')

    if num_missing_cards == 1:
        # Drop the card that is not held and play the joker in its place.
        cards = cards * hand_cards
        cards[52] = 1

    hand_cards = hand_cards * (1 - cards)

    return cards_to_action(cards, max_num_cards), hand_cards


def cards_to_action(cards: np.ndarray, max_num_cards: int) -> np.ndarray:
    """Turn a 0/1 card vector into a padded token of card indices.

    Args:
        cards: Vector whose entries equal to 1 mark the selected cards.
        max_num_cards: Token length handed to ``card_indices_to_action``.

    Returns:
        The result of ``card_indices_to_action`` for the selected positions.
    """
    selected = cards == 1

    return card_indices_to_action(
        indices=np.where(selected)[0],
        max_num_cards=max_num_cards,
    )


def card_indices_to_action(indices: np.ndarray, max_num_cards: int) -> np.ndarray:
    """Pad a list of card indices to ``max_num_cards`` entries.

    Slot ``k`` of the padding is filled with the value ``53 + k``, so each
    padded position carries its own filler id. Inputs that already have
    ``max_num_cards`` or more entries are returned unpadded (and untruncated).

    Args:
        indices: Card indices; any array-like is accepted, including a plain
            list such as ``[]``.
        max_num_cards: Desired token length.

    Returns:
        A new ``uint16`` array.
    """
    # Callers pass plain lists as well as arrays; normalize first so that
    # `.astype` below also works when no padding is appended.
    indices = np.asarray(indices)
    num_given = len(indices)

    if num_given < max_num_cards:
        padding = 53 + np.arange(num_given, max_num_cards)
        indices = np.concatenate([indices, padding])

    return indices.astype(np.uint16)

In [155]:
def legacy_asr_to_tokens(a: np.ndarray, s: np.ndarray) -> np.ndarray:
    """Convert one game's legacy actions and states into a token table.

    Each output row is ``[c0, c1, c2, c3, c4, c5, who]``: six card-index slots
    (padded by ``card_indices_to_action``) followed by an id in column 6. The
    four state chunks of a step use ids 0-3; the initial hand rows and the
    rows built from ``a`` use id 4.

    The function also prints a trace of everything it decodes.

    Args:
        a: Legacy action ids, indexed by step. The value 353 is written as an
            all-padding token (presumably "pass" -- TODO confirm).
        s: Per-step state vectors. Each row is read as four consecutive
            57-wide chunks (53 card flags + 4 feature values); the last 53
            entries of row 0 are read as the initial hand.

    Returns:
        A ``(s.shape[0] * 5 + 2, 7)`` uint8 array; rows that were never
        written stay zero.
    """
    tokens = np.zeros((s.shape[0] * 5 + 2, 7), dtype=np.uint8)

    # Initial hand: the trailing 53 entries of the first state row.
    hand_cards = s[0, -53:]
    init_card_indices = np.where(hand_cards == 1)[0]

    print(card_indices_to_str(init_card_indices))

    # The hand is spread over two rows: the first five cards, then the rest.
    # The second row only fits if at most six cards remain.
    tokens[0, :6] = card_indices_to_action(init_card_indices[:5], max_num_cards=6)
    tokens[1, :6] = card_indices_to_action(init_card_indices[5:], max_num_cards=6)

    tokens[:2, 6] = 4

    # Last non-duplicate token seen so far; starts as the all-padding token.
    prev_action = card_indices_to_action(indices=[], max_num_cards=6)

    # Next row to write (rows 0 and 1 hold the hand).
    t = 2

    for i in range(s.shape[0]):
        # An all-zero state row ends the recorded game.
        if np.all(s[i] == 0):
            break

        s_i = s[i]

        for player in range(4):
            # Peel off the next 57-wide chunk: 53 card flags, then 4 features.
            cards = s_i[:53]
            feature = s_i[53:57]
            s_i = s_i[57:]

            # Skip empty chunks as long as no row beyond the hand was written.
            if t == 2 and np.all(cards == 0):
                continue

            tokens[t, :6] = cards_to_action(cards, max_num_cards=6)

            # Printed before the duplicate check below, so the trace shows the
            # chunk's cards even when the stored token becomes all padding.
            print(card_indices_to_str(tokens[t, :6]), player, feature)

            # A chunk that repeats the previous token is stored as all padding.
            if np.all(prev_action == tokens[t, :6]):
                tokens[t, :6] = card_indices_to_action(indices=[], max_num_cards=6)
            else:
                # Basic slice, i.e. a view of row t; that row is not rewritten
                # after `t` advances.
                prev_action = tokens[t, :6]

            tokens[t, 6] = player
            t += 1

        if a[i] == 353:
            tokens[t, :6] = card_indices_to_action(indices=[], max_num_cards=6)
        else:
            # Also removes the played cards from the tracked hand.
            tokens[t, :6], hand_cards = legacy_action_to_action(a[i], hand_cards, max_num_cards=6)
            prev_action = tokens[t, :6]

        print(card_indices_to_str(tokens[t, :6]), 4)

        tokens[t, 6] = 4
        t += 1

    return tokens


In [2]:
import numpy as np

# Load the compressed legacy dataset (path is relative to the notebook's
# working directory). The archive handle is kept in `npz_RSA_sample`.
npz_RSA_sample = np.load('datasets/rsa_500k_compressed.npz')

# Arrays stored under the keys 'R_np', 'S_np' and 'A_np'.
npz_R: np.ndarray = npz_RSA_sample['R_np']
npz_S: np.ndarray = npz_RSA_sample['S_np']
npz_A: np.ndarray = npz_RSA_sample['A_np']

# Quick look at dtypes and shapes before using the arrays.
print(f"{npz_A.dtype}, {npz_S.dtype}, {npz_R.dtype}")
print(f'npz_R:{npz_R.shape}, npz_S:{npz_S.shape}, npz_A:{npz_A.shape}')

int64, uint8, float32
npz_R:(500000, 1), npz_S:(500000, 30, 281), npz_A:(500000, 30)


In [13]:
def cards_to_str(cards: np.ndarray) -> list[str]:
    """Label every card marked with 1 in a flat 0/1 card vector.

    Indices 0-51 are split into ``suit = index // 13`` (letters a-d) and
    ``rank = index % 13``. Index 52, the joker slot of a 53-long vector, is
    labelled ``'JOKER'``; it previously raised ``IndexError`` because there is
    no fifth suit letter.

    Args:
        cards: Flat vector whose entries equal to 1 mark the cards to label.

    Returns:
        One label per marked card, in index order.
    """
    sutes = ['a', 'b', 'c', 'd']
    strs = []

    for index in np.where(cards == 1)[0]:
        if index == 52:
            strs.append('JOKER')
            continue

        sute, number = divmod(index, 13)
        strs.append(f'{sutes[sute]} {number}')

    return strs


def test(a: np.ndarray, s: np.ndarray) -> None:
    """Print a step-by-step replay of one recorded game.

    For every step up to the first all-zero state row, this prints the four
    57-wide state chunks (53 card flags + 4 feature values each), the cards
    decoded from the legacy action, and the hand after removing those cards.
    Nothing is returned.

    NOTE(review): the hand is updated by plain removal of the decoded cards;
    a joker played in place of a card that is not in hand is not accounted
    for here.

    Args:
        a: Legacy action ids of one game.
        s: State rows of the same game.
    """
    # a.shape: [30]
    # s.shape: [30, 281]

    # Initial hand: the trailing 53 entries of the first state row.
    hand_cards = s[0, -53:]
    print("hand", cards_to_str(hand_cards))

    for i in range(s.shape[0]):
        # An all-zero state row ends the recorded game.
        if np.all(s[i] == 0):
            break

        print()

        s_i = s[i]

        for player in range(4):
            # Peel off the next chunk: 53 card flags, then 4 features.
            cards = s_i[:53]
            feature = s_i[53:57]
            s_i = s_i[57:]

            print(f"s_{player}", cards_to_str(cards), feature)

        # Decode the action and append the joker flag as slot 52.
        cards, is_include_joker = action_to_cards(a[i])
        cards = np.concatenate([cards.reshape(-1), [is_include_joker]])

        hand_cards = hand_cards * (1 - cards)

        print("a", cards_to_str(cards))
        print("hand", cards_to_str(hand_cards))

# Replay the game stored at row 3 and print its decoded states and actions.
index = 3
test(npz_A[index], npz_S[index])

hand ['a 1', 'a 4', 'a 6', 'a 11', 'b 1', 'b 5', 'b 10', 'b 12', 'c 1', 'c 12', 'd 4']

s_0 ['b 0', 'c 0'] [0 0 0 0]
s_1 ['b 0', 'c 0'] [0 0 0 0]
s_2 ['a 3', 'c 3'] [0 0 0 0]
s_3 ['b 6', 'c 6'] [0 0 0 0]
a []
hand ['a 1', 'a 4', 'a 6', 'a 11', 'b 1', 'b 5', 'b 10', 'b 12', 'c 1', 'c 12', 'd 4']

s_0 [] [1 0 0 0]
s_1 ['d 6', 'd 7', 'd 8'] [0 0 0 0]
s_2 ['d 6', 'd 7', 'd 8'] [0 0 0 0]
s_3 ['d 6', 'd 7', 'd 8'] [0 0 0 0]
a []
hand ['a 1', 'a 4', 'a 6', 'a 11', 'b 1', 'b 5', 'b 10', 'b 12', 'c 1', 'c 12', 'd 4']

s_0 [] [1 0 0 0]
s_1 ['b 2'] [0 0 0 0]
s_2 ['c 8'] [0 0 0 0]
s_3 ['c 9'] [0 0 0 1]
a ['c 12']
hand ['a 1', 'a 4', 'a 6', 'a 11', 'b 1', 'b 5', 'b 10', 'b 12', 'c 1', 'd 4']

s_0 ['c 12'] [0 0 0 1]
s_1 ['c 12'] [0 0 0 1]
s_2 ['c 12'] [0 0 0 1]
s_3 ['c 12'] [0 0 0 1]
a []
hand ['a 1', 'a 4', 'a 6', 'a 11', 'b 1', 'b 5', 'b 10', 'b 12', 'c 1', 'd 4']

s_0 ['c 12'] [0 0 0 1]
s_1 ['c 12'] [0 0 0 1]
s_2 ['c 12'] [0 0 0 1]
s_3 [] [1 0 0 0]
a ['a 1', 'b 1', 'c 1']
hand ['a 4', 'a 6', 'a 1

In [158]:
# Tokenize the game stored at row 10; the call also prints a decoding trace.
index = 10
tokens = legacy_asr_to_tokens(npz_A[index], npz_S[index])

['a 11', 'a 12', 'b 4', 'c 0', 'c 1', 'c 7', 'c 11', 'c 12', 'd 3', 'JOKER']
['b 4'] 4
['a 7'] 0 [0 0 0 0]
['a 7'] 1 [0 0 0 0]
['d 8'] 2 [0 0 0 0]
['d 9'] 3 [0 0 0 1]
[] 4
['d 9'] 0 [0 0 0 1]
['d 9'] 1 [0 0 0 1]
[] 2 [1 0 0 0]
['b 3'] 3 [0 0 0 0]
['c 7'] 4
['c 7'] 0 [0 0 0 0]
['c 7'] 1 [0 0 0 0]
['c 7'] 2 [0 0 0 0]
['c 8'] 3 [0 0 0 1]
['c 12'] 4
['c 7'] 0 [0 0 0 0]
['c 8'] 1 [0 0 0 1]
['c 12'] 2 [0 0 0 1]
['c 12'] 3 [0 0 0 1]
[] 4
['c 8'] 0 [0 0 0 1]
['c 12'] 1 [0 0 0 1]
['c 12'] 2 [0 0 0 1]
[] 3 [1 0 0 0]
['d 3'] 4
['a 4'] 0 [0 0 0 0]
['b 6'] 1 [0 0 0 0]
['b 9'] 2 [0 0 0 1]
['b 9'] 3 [0 0 0 1]
[] 4
['a 1'] 0 [0 0 0 0]
['b 2'] 1 [0 0 0 0]
['b 10'] 2 [0 0 0 1]
['b 10'] 3 [0 0 0 1]
[] 4
[] 0 [1 0 0 0]
['c 2', 'd 2'] 1 [0 0 0 0]
[] 2 [1 0 0 0]
['a 2'] 3 [0 0 0 0]
['a 12'] 4
['a 12'] 0 [0 0 0 1]
['a 12'] 1 [0 0 0 1]
['a 12'] 2 [0 0 0 1]
['a 12'] 3 [0 0 0 1]
[] 4
['a 12'] 0 [0 0 0 1]
['a 12'] 1 [0 0 0 1]
['a 12'] 2 [0 0 0 1]
[] 3 [1 0 0 0]
['a 11', 'c 11'] 4
['a 11', 'c 11'] 0 [0 0 0 0]
['a

In [88]:
# Show the raw legacy action ids recorded for the game at row 1.
npz_A[1]

array([130, 353, 243, 353,  10,  52,   0, 175,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0])