## Poker infrastructure

Here we define enums `SUITS`, `RANKS` and `HANDS` to represent the card and hand types, as well as a rule-based hand type classifier which will be used later for data generation.

In [1]:
import numpy as np
from enum import Enum

In [2]:
# Suit enum: the member name is the suit word, the member value is the glyph used when printing a card.
SUITS = Enum("suit", {"clubs":"♣", "diamonds":"♦", "hearts":"♥", "spades":"♠"})
# Rank enum: the member value is the 0-based position of the rank, from "2" (0) up to "A" (12).
RANKS = Enum("rank", {"2":0, "3":1, "4":2, "5":3, "6":4, "7":5, "8":6, "9":7, "10":8, "J":9, "Q":10, "K":11, "A":12})
RANK_NAMES = [r for r in RANKS]  # list of rank enums to support value-to-name lookup

In [3]:
class Card:
    """A single playing card, identified by a rank enum member and a suit enum member."""

    def __init__(self, rank, suit):
        self.rank = rank
        self.suit = suit

    def __repr__(self):
        # The rank is shown by its member name, the suit by its member value (the glyph).
        return f"{self.rank.name}{self.suit.value}"


class Deck:
    """A deck holding one card per (rank, suit) combination, drawn from at random."""

    def __init__(self):
        self.cards_ = []
        for rank in RANKS:
            for suit in SUITS:
                self.cards_.append(Card(rank, suit))

    def draw(self):
        """Remove and return a randomly chosen card from the remaining cards."""
        index = np.random.randint(len(self.cards_))
        return self.cards_.pop(index)

    def __len__(self):
        return len(self.cards_)

In [4]:
# Demo: draw five cards from a fresh deck. `hand` stays defined at notebook scope
# and is reused by the example cells that follow.
deck = Deck()
hand = []
for i in range(5):
    hand.append(deck.draw())
print(hand)

[J♥, 5♠, A♦, A♠, 2♣]


In [5]:
def make_card(shorthand):
    """Build a `Card` from case-insensitive shorthand such as "jc" or "10s".

    The last character selects the suit (C, D, H or S); everything before it is
    looked up as a member name of `RANKS`.
    """
    assert len(shorthand) in (2, 3)
    rank_key = shorthand[:-1].upper()
    suit_key = shorthand[-1].upper()
    suit_lookup = {"C": SUITS.clubs, "D": SUITS.diamonds, "H": SUITS.hearts, "S": SUITS.spades}
    return Card(RANKS[rank_key], suit_lookup[suit_key])

def make_hand(shorthand):
    """Build a five-card hand from card shorthands.

    Args:
        shorthand: either a whitespace-separated string such as "js 8h 9d qs 10s",
            or a sequence of five individual card shorthands.

    Returns:
        A list of five `Card` objects, in the order given.
    """
    if isinstance(shorthand, str):
        # split() with no argument tolerates repeated, leading and trailing whitespace,
        # which split(" ") would turn into empty card tokens.
        shorthand = shorthand.split()
    assert len(shorthand) == 5
    return [make_card(card_shorthand) for card_shorthand in shorthand]

In [6]:
# Shorthand is case-insensitive; the rank part may be one or two characters ("j", "10").
print(make_card("jc"))
print(make_hand("js 8h 9d qs 10s"))

J♣
[J♠, 8♥, 9♦, Q♠, 10♠]


In [7]:
def is_in_pair(hand):
    """Return one bool per card: True when exactly two cards in `hand` share that card's rank."""
    flags = []
    for card in hand:
        same_rank = sum(1 for other in hand if other.rank == card.rank)
        flags.append(same_rank == 2)
    return flags

def is_in_toak(hand):
    """Return one bool per card: True when exactly three cards in `hand` share that card's rank."""
    flags = []
    for card in hand:
        same_rank = sum(1 for other in hand if other.rank == card.rank)
        flags.append(same_rank == 3)
    return flags

def is_in_foak(hand):
    """Return one bool per card: True when exactly four cards in `hand` share that card's rank."""
    flags = []
    for card in hand:
        same_rank = sum(1 for other in hand if other.rank == card.rank)
        flags.append(same_rank == 4)
    return flags

In [8]:
# Per-card pair / three-of-a-kind / four-of-a-kind flags for the notebook-level `hand`.
print(is_in_pair(hand), is_in_toak(hand), is_in_foak(hand))

[False, False, True, True, False] [False, False, False, False, False] [False, False, False, False, False]


In [9]:
def next_rank(rank):
    """Return the rank one step above `rank`, or None when `rank` is the ace."""
    if rank == RANKS.A:
        return None
    return RANK_NAMES[rank.value + 1]

def next_n_rank(rank, n):
    """Advance `rank` by `n` steps; the result is None once the ace has been passed."""
    for _ in range(n):
        if rank is None:
            # Already past the ace: further steps cannot change the result.
            break
        rank = next_rank(rank)
    return rank

def is_start_of_straight(hand):  # note that this does not consider valid straights like "5432A".
    """Return one bool per card: True when the four next-higher ranks all appear in `hand`."""
    ranks_in_hand = {card.rank for card in hand}
    flags = []
    for card in hand:
        # next_n_rank yields None past the ace, which is never a member of the set.
        followers = [next_n_rank(card.rank, step) for step in range(1, 5)]
        flags.append(all(rank in ranks_in_hand for rank in followers))
    return flags

In [10]:
print(next_rank(hand[0].rank))
# Only the lowest card of a straight is flagged (here the 10 of "10 J Q K A").
print(is_start_of_straight(make_hand("ah ks qs 10d jd")))
# The ace-low "A 2 3 4 5" is not recognised as a straight by this helper.
print(is_start_of_straight(make_hand("ah 5h 4h 3h 2h")))

rank.Q
[False, False, False, True, False]
[False, False, False, False, False]


In [11]:
def is_in_flush(hand):
    """Return one bool per card: True when exactly five cards in `hand` share that card's suit."""
    flags = []
    for card in hand:
        suited = [other for other in hand if other.suit == card.suit]
        flags.append(len(suited) == 5)
    return flags

In [12]:
# `hand` is the notebook-level example hand; a flush needs all five cards in one suit.
print(is_in_flush(hand))

[False, False, False, False, False]


In [13]:
# Hand-type enum: the member value is the class index, from 0 (straight flush) to 8 (high card).
HANDS = Enum("hand", {"straight flush":0, "four of a kind":1, "full house":2, "flush":3, "straight":4,
                      "three of a kind":5, "two pair": 6, "one pair":7, "high card":8})
HAND_NAMES = [r for r in HANDS]  # list of hand-type enums to support value-to-name lookup

In [14]:
def classify_hand(hand):
    """Classify a five-card hand with fixed poker rules.

    Args:
        hand: a list of five `Card` objects.

    Returns:
        A `(hand_type, details)` tuple. `hand_type` is a `HANDS` member. `details`
        is a dict holding `hand_type`, the hand-level booleans (`is_flush`,
        `is_straight`, `is_foak`, `is_toak`, `is_one_pair`, `is_two_pair`) and the
        per-card boolean lists (`is_in_flush`, `is_start_of_straight`,
        `is_in_foak`, `is_in_toak`, `is_in_pair`).
    """
    # Per-card flags are computed once and reused for both the decision and `details`.
    in_flush = is_in_flush(hand)
    start_of_straight = is_start_of_straight(hand)
    in_foak = is_in_foak(hand)
    in_toak = is_in_toak(hand)
    in_pair = is_in_pair(hand)

    is_flush = any(in_flush)  # can use all() here too
    is_straight = any(start_of_straight)
    is_foak = any(in_foak)
    is_toak = any(in_toak)
    # Every pair flags two cards, so halving the flagged-card count gives the pair count.
    num_pairs = sum(in_pair) // 2

    # Checks run from the strongest hand type to the weakest; the first match wins.
    if is_flush and is_straight:
        hand_type = HANDS["straight flush"]
    elif is_foak:
        hand_type = HANDS["four of a kind"]
    elif is_toak and num_pairs == 1:
        hand_type = HANDS["full house"]
    elif is_flush:
        hand_type = HANDS["flush"]
    elif is_straight:
        hand_type = HANDS["straight"]
    elif is_toak:
        hand_type = HANDS["three of a kind"]
    elif num_pairs == 2:
        hand_type = HANDS["two pair"]
    elif num_pairs == 1:
        hand_type = HANDS["one pair"]
    else:
        hand_type = HANDS["high card"]

    details = {
        "hand_type": hand_type,
        "is_flush": is_flush,
        "is_straight": is_straight,
        "is_foak": is_foak,
        "is_toak": is_toak,
        "is_one_pair": num_pairs == 1,
        "is_two_pair": num_pairs == 2,
        "is_in_flush": in_flush,
        "is_start_of_straight": start_of_straight,
        "is_in_foak": in_foak,
        "is_in_toak": in_toak,
        "is_in_pair": in_pair,
    }
    return hand_type, details

In [15]:
# One example per hand type, ordered from straight flush down to high card.
print(classify_hand(make_hand("jc 10c 9c 8c 7c"))[0])
print(classify_hand(make_hand("5c 5d 5h 5s 2d"))[0])
print(classify_hand(make_hand("6s 6h 6d ks kh"))[0])
print(classify_hand(make_hand("jh 9h 8h 4h 3h"))[0])
print(classify_hand(make_hand("10d 9s 8h 7d 6c"))[0])
print(classify_hand(make_hand("qc qs qh 9h 2s"))[0])
print(classify_hand(make_hand("jh js 3c 3s 2h"))[0])
print(classify_hand(make_hand("10s 10h 8s 7h 4c"))[0])
print(classify_hand(make_hand("kd qd 7s 4s 3h"))[0])

hand.straight flush
hand.four of a kind
hand.full house
hand.flush
hand.straight
hand.three of a kind
hand.two pair
hand.one pair
hand.high card


In [16]:
def random_hand():
    """Draw five distinct cards from a freshly built deck."""
    fresh_deck = Deck()
    cards = []
    while len(cards) < 5:
        cards.append(fresh_deck.draw())
    return cards

# Spot-check the classifier on ten random hands (this rebinds the notebook-level `hand`).
for i in range(10):
    hand = random_hand()
    print(f"{hand}: {classify_hand(hand)}")

[9♦, 2♥, 9♥, A♥, 5♣]: (<hand.one pair: 7>, {'hand_type': <hand.one pair: 7>, 'is_flush': False, 'is_straight': False, 'is_foak': False, 'is_toak': False, 'is_one_pair': True, 'is_two_pair': False, 'is_in_flush': [False, False, False, False, False], 'is_start_of_straight': [False, False, False, False, False], 'is_in_foak': [False, False, False, False, False], 'is_in_toak': [False, False, False, False, False], 'is_in_pair': [True, False, True, False, False]})
[9♣, 9♠, 8♥, K♦, 2♠]: (<hand.one pair: 7>, {'hand_type': <hand.one pair: 7>, 'is_flush': False, 'is_straight': False, 'is_foak': False, 'is_toak': False, 'is_one_pair': True, 'is_two_pair': False, 'is_in_flush': [False, False, False, False, False], 'is_start_of_straight': [False, False, False, False, False], 'is_in_foak': [False, False, False, False, False], 'is_in_toak': [False, False, False, False, False], 'is_in_pair': [True, True, False, False, False]})
[9♣, 9♥, 3♥, 2♥, 8♦]: (<hand.one pair: 7>, {'hand_type': <hand.one pair: 7>,

## Dataset

Using the rule-based hand classifier `classify_hand()` above, we generate two randomly-drawn datasets `ds10k` (10,000 hands) and `ds1M` (1,000,000 hands), as well as two balanced ("mined") datasets `mds10k` (10,000 hands) and `mds1M` (1,000,000 hands), in which rare hand types are deliberately oversampled so that every hand type is well represented.

You can download the generated datasets here:
- ds10k: [Google Drive](https://drive.google.com/file/d/1G1s8hDv951WbdYpHaWZVmCLFeRqK_w1y/view?usp=sharing)
- ds1M: [Google Drive](https://drive.google.com/file/d/1PaUi7eU-VoHTtuWSywDK3Y9VDb4oQIny/view?usp=sharing)
- mds10k: [Google Drive](https://drive.google.com/file/d/1zwpkT4iVL2q3XhCq1K-liMfMHmw4hUr1/view?usp=sharing)
- mds1M: [Google Drive](https://drive.google.com/file/d/1np8Y_SSMt1PrJRCRYf3UQgaQ7a1Xb6j5/view?usp=sharing)

In [17]:
from tqdm import tqdm
import torch

In [18]:
def make_sample(hand):
    """Return a dataset sample: the hand itself plus every label from `classify_hand`."""
    _, details = classify_hand(hand)
    sample = {"hand": hand}
    sample.update(details)
    return sample

In [19]:
def generate_dataset(num=10000):
    """Generate `num` labelled samples from randomly drawn hands, with a progress bar."""
    samples = []
    with tqdm(range(num)) as progress:
        for _ in progress:
            samples.append(make_sample(random_hand()))
    return samples

In [20]:
# Smoke test: generate and print a tiny five-hand dataset.
ds = generate_dataset(5)
print(ds)

100%|██████████| 5/5 [00:00<00:00, 3010.12it/s]

[{'hand': [7♠, 8♠, Q♦, K♠, 2♠], 'hand_type': <hand.high card: 8>, 'is_flush': False, 'is_straight': False, 'is_foak': False, 'is_toak': False, 'is_one_pair': False, 'is_two_pair': False, 'is_in_flush': [False, False, False, False, False], 'is_start_of_straight': [False, False, False, False, False], 'is_in_foak': [False, False, False, False, False], 'is_in_toak': [False, False, False, False, False], 'is_in_pair': [False, False, False, False, False]}, {'hand': [A♠, 6♠, J♠, K♥, K♠], 'hand_type': <hand.one pair: 7>, 'is_flush': False, 'is_straight': False, 'is_foak': False, 'is_toak': False, 'is_one_pair': True, 'is_two_pair': False, 'is_in_flush': [False, False, False, False, False], 'is_start_of_straight': [False, False, False, False, False], 'is_in_foak': [False, False, False, False, False], 'is_in_toak': [False, False, False, False, False], 'is_in_pair': [False, False, False, True, True]}, {'hand': [K♥, 9♠, 2♠, 7♣, 10♥], 'hand_type': <hand.high card: 8>, 'is_flush': False, 'is_straight




In [21]:
# 10,000 randomly drawn hands.
ds10k = generate_dataset(10000)

100%|██████████| 10000/10000 [00:02<00:00, 3924.04it/s]


In [22]:
# 1,000,000 randomly drawn hands (took several minutes in the recorded run).
ds1M = generate_dataset(1000000)

100%|██████████| 1000000/1000000 [04:24<00:00, 3785.44it/s]


In [30]:
# Hand-type counts in ds10k; np.unique returns the type names in sorted order.
np.unique([d["hand_type"].name for d in ds10k], return_counts=True)

(array(['flush', 'four of a kind', 'full house', 'high card', 'one pair',
        'straight', 'straight flush', 'three of a kind', 'two pair'],
       dtype='<U15'),
 array([  12,    3,   19, 5096, 4150,   41,    1,  197,  481]))

In [31]:
# Hand-type counts in ds1M; np.unique returns the type names in sorted order.
np.unique([d["hand_type"].name for d in ds1M], return_counts=True)

(array(['flush', 'four of a kind', 'full house', 'high card', 'one pair',
        'straight', 'straight flush', 'three of a kind', 'two pair'],
       dtype='<U15'),
 array([  2002,    247,   1441, 502266, 421813,   3587,     16,  21124,
         47504]))

In [32]:
def print_dataset_distribution(dataset, sort_by_count=False):
    """Print the percentage of samples of each hand type found in `dataset`.

    Rows are ordered by ascending count when `sort_by_count` is true, otherwise by
    the hand type's `HANDS` value.
    """
    names = [sample["hand_type"].name for sample in dataset]
    unique_names, occurrences = np.unique(names, return_counts=True)

    if sort_by_count:
        def sort_key(pair):
            return pair[1]
    else:
        def sort_key(pair):
            return HANDS[pair[0]].value

    total = len(dataset)
    for name, count in sorted(zip(unique_names, occurrences), key=sort_key):
        print(f"{count / total * 100:.4f}%: {name}")

In [33]:
# Hand-type frequencies in the randomly drawn 1M-hand dataset.
print_dataset_distribution(ds1M)

0.0016%: straight flush
0.0247%: four of a kind
0.1441%: full house
0.2002%: flush
0.3587%: straight
2.1124%: three of a kind
4.7504%: two pair
42.1813%: one pair
50.2266%: high card


In [34]:
def is_hand_valid(hand):
    """Return True when `hand` has exactly five cards and no two share both rank and suit."""
    if len(hand) != 5:
        return False
    for position, card in enumerate(hand):
        # Compare each card only against the cards that come after it.
        for later in hand[position + 1:]:
            if (card.rank, card.suit) == (later.rank, later.suit):
                return False
    return True

In [35]:
# The first hand repeats the 4 of clubs and is rejected; the second has no repeated card.
print(is_hand_valid(make_hand("4c 4c 5c 6c 10c")))
print(is_hand_valid(make_hand("3c 4c 5c 6c 10c")))

False
True


In [36]:
def permute_hand(hand):
    """Return a new list holding the cards of `hand` in a random order."""
    order = np.random.permutation(len(hand))
    shuffled = []
    for index in order:
        shuffled.append(hand[index])
    return shuffled

In [37]:
def generate_random_hand_with_at_least_type(hand_type):
    """Build a random, valid five-card hand that contains the pattern of `hand_type`.

    Args:
        hand_type: the `HANDS` member naming the pattern to construct.

    Returns:
        A list of five `Card` objects in random order.

    Raises:
        ValueError: if `hand_type` is not one of the hand types handled below.
    """
    # Note that this could generate a higher type than hand_type, albeit not likely in general.
    # e.g. when hand_type=HANDS["straight"], it could potentially generate a straight flush.
    if hand_type == HANDS["straight flush"]:
        # A straight needs four ranks above its lowest card, so J, Q, K and A cannot start one.
        start_rank = np.random.choice([r for r in RANKS if r not in [RANKS.J, RANKS.Q, RANKS.K, RANKS.A]])
        suit = np.random.choice(SUITS)
        hand = [Card(next_n_rank(start_rank, i), suit) for i in range(5)]
    elif hand_type == HANDS["four of a kind"]:
        ranks = np.random.choice(RANKS, 2, replace=False)
        hand = [Card(ranks[0], suit) for suit in SUITS]
        hand.append(Card(ranks[1], np.random.choice(SUITS)))
    elif hand_type == HANDS["full house"]:
        ranks = np.random.choice(RANKS, 2, replace=False)
        suits3 = np.random.choice(SUITS, 3, replace=False)
        suits2 = np.random.choice(SUITS, 2, replace=False)
        hand = [Card(ranks[0], s) for s in suits3]
        hand.extend([Card(ranks[1], s) for s in suits2])
    elif hand_type == HANDS["flush"]:
        suit = np.random.choice(SUITS)
        ranks = np.random.choice(RANKS, 5, replace=False)
        hand = [Card(r, suit) for r in ranks]
    elif hand_type == HANDS["straight"]:
        start_rank = np.random.choice([r for r in RANKS if r not in [RANKS.J, RANKS.Q, RANKS.K, RANKS.A]])
        hand = [Card(next_n_rank(start_rank, i), np.random.choice(SUITS)) for i in range(5)]
    elif hand_type == HANDS["three of a kind"]:
        ranks = np.random.choice(RANKS, 3, replace=False)
        suits3 = np.random.choice(SUITS, 3, replace=False)
        hand = [Card(ranks[0], s) for s in suits3]
        hand.extend([Card(r, np.random.choice(SUITS)) for r in ranks[1:]])
    elif hand_type == HANDS["two pair"]:
        ranks = np.random.choice(RANKS, 3, replace=False)
        suits1 = np.random.choice(SUITS, 2, replace=False)
        suits2 = np.random.choice(SUITS, 2, replace=False)
        hand = [Card(ranks[0], s) for s in suits1]
        hand.extend([Card(ranks[1], s) for s in suits2])
        hand.append(Card(ranks[2], np.random.choice(SUITS)))
    elif hand_type == HANDS["one pair"]:
        ranks = np.random.choice(RANKS, 4, replace=False)
        suits = np.random.choice(SUITS, 2, replace=False)
        hand = [Card(ranks[0], s) for s in suits]
        hand.extend([Card(r, np.random.choice(SUITS)) for r in ranks[1:]])
    elif hand_type == HANDS["high card"]:
        ranks = np.random.choice(RANKS, 5, replace=False)
        hand = [Card(r, np.random.choice(SUITS)) for r in ranks]
    else:
        # Without this branch an unknown type would surface as an UnboundLocalError on `hand`.
        raise ValueError(f"unsupported hand type: {hand_type!r}")
    hand = permute_hand(hand)
    assert is_hand_valid(hand)
    return hand

In [38]:
# Generate and print one example hand for every hand type.
for hand_type in HANDS:
    print(generate_random_hand_with_at_least_type(hand_type), hand_type)

[4♦, 6♦, 7♦, 3♦, 5♦] hand.straight flush
[6♣, 6♦, 6♠, 6♥, 5♣] hand.four of a kind
[4♥, 2♥, 4♦, 2♣, 4♠] hand.full house
[J♥, Q♥, 5♥, 4♥, 7♥] hand.flush
[2♦, 5♦, 4♠, 6♦, 3♠] hand.straight
[5♠, Q♠, 5♥, 5♦, A♦] hand.three of a kind
[2♥, 3♦, 10♣, 10♦, 3♠] hand.two pair
[J♦, A♦, 8♥, Q♣, 8♣] hand.one pair
[2♣, 4♣, 8♥, J♠, 6♣] hand.high card


In [39]:
def generate_mined_dataset(num=10000):
    """Generate `num` labelled samples with rare hand types deliberately oversampled.

    For each sample a target is drawn: 1% straight flush, 10% for each of the other
    eight hand types, and 19% `None`, which means a fully random hand. Targeted
    hands come from `generate_random_hand_with_at_least_type`.

    Returns:
        A list of sample dicts as produced by `make_sample`.
    """
    dataset = []
    type_probs = {HANDS["straight flush"]:0.01, HANDS["four of a kind"]:0.1, HANDS["full house"]:0.1,
                  HANDS["flush"]:0.1, HANDS["straight"]:0.1, HANDS["three of a kind"]:0.1,
                  HANDS["two pair"]:0.1, HANDS["one pair"]:0.1, HANDS["high card"]:0.1, None:0.19}
    # Hoisted out of the loop: the candidates and their probabilities never change.
    targets = list(type_probs.keys())
    probs = list(type_probs.values())
    with tqdm(range(num)) as t:
        for _ in t:
            target_type = np.random.choice(targets, p=probs)
            if target_type is None:
                hand = random_hand()
            else:
                hand = generate_random_hand_with_at_least_type(target_type)
            dataset.append(make_sample(hand))
    return dataset

In [40]:
# 10,000-hand mined dataset.
mds10k = generate_mined_dataset(10000)

100%|██████████| 10000/10000 [00:03<00:00, 2888.62it/s]


In [42]:
# 1,000,000-hand mined dataset.
mds1M = generate_mined_dataset(1000000)

100%|██████████| 1000000/1000000 [06:03<00:00, 2754.16it/s]


In [47]:
def serialize_dataset(dataset):
    """Return a copy of `dataset` with cards and hand types replaced by plain names.

    Each card becomes a `(rank name, suit name)` tuple and `hand_type` becomes its
    member name; all other entries are copied unchanged.
    """
    serialized = []
    for sample in dataset:
        record = dict(sample)
        record["hand"] = [(card.rank.name, card.suit.name) for card in sample["hand"]]
        record["hand_type"] = sample["hand_type"].name
        serialized.append(record)
    return serialized

def deserialize_dataset(dataset):
    """Inverse of `serialize_dataset`: rebuild `Card` objects and `HANDS` members from names."""
    restored = []
    for sample in dataset:
        record = dict(sample)
        record["hand"] = [Card(RANKS[rank_name], SUITS[suit_name]) for rank_name, suit_name in sample["hand"]]
        record["hand_type"] = HANDS[sample["hand_type"]]
        restored.append(record)
    return restored

In [48]:
# Round-trip check: original samples, their serialized form, then deserialized again.
print(ds10k[:2])
print(serialize_dataset(ds10k[:2]))
print(deserialize_dataset(serialize_dataset(ds10k[:2])))

[{'hand': [K♦, Q♦, 8♠, Q♠, A♠], 'hand_type': <hand.one pair: 7>, 'is_flush': False, 'is_straight': False, 'is_foak': False, 'is_toak': False, 'is_one_pair': True, 'is_two_pair': False, 'is_in_flush': [False, False, False, False, False], 'is_start_of_straight': [False, False, False, False, False], 'is_in_foak': [False, False, False, False, False], 'is_in_toak': [False, False, False, False, False], 'is_in_pair': [False, True, False, True, False]}, {'hand': [Q♥, 7♥, 5♣, 10♣, 4♠], 'hand_type': <hand.high card: 8>, 'is_flush': False, 'is_straight': False, 'is_foak': False, 'is_toak': False, 'is_one_pair': False, 'is_two_pair': False, 'is_in_flush': [False, False, False, False, False], 'is_start_of_straight': [False, False, False, False, False], 'is_in_foak': [False, False, False, False, False], 'is_in_toak': [False, False, False, False, False], 'is_in_pair': [False, False, False, False, False]}]
[{'hand': [('K', 'diamonds'), ('Q', 'diamonds'), ('8', 'spades'), ('Q', 'spades'), ('A', 'spades

In [542]:
# Save each dataset in its serialized form (cards and hand types as plain names).
torch.save(serialize_dataset(ds10k), "/tmp/ds10k.pt")
torch.save(serialize_dataset(ds1M), "/tmp/ds1M.pt")
torch.save(serialize_dataset(mds10k), "/tmp/mds10k.pt")
torch.save(serialize_dataset(mds1M), "/tmp/mds1M.pt")

In [44]:
# Load the saved datasets back and rebuild the Card / enum objects.
ds10k = deserialize_dataset(torch.load("/tmp/ds10k.pt"))
ds1M = deserialize_dataset(torch.load("/tmp/ds1M.pt"))
mds10k = deserialize_dataset(torch.load("/tmp/mds10k.pt"))
mds1M = deserialize_dataset(torch.load("/tmp/mds1M.pt"))

In [49]:
# Hand-type frequencies in the mined 1M-hand dataset.
print_dataset_distribution(mds1M)

1.0953%: straight flush
9.9918%: four of a kind
10.0277%: full house
9.9835%: flush
10.0899%: straight
10.4068%: three of a kind
10.9183%: two pair
18.0493%: one pair
19.4374%: high card


## Model

The model is structured as follows:
- A single-layer Transformer encoder layer with 8 heads (the default `num_heads` of `MultiHeadedAttentionBackbone`)
- Five classifiers building on top of the Transformer encoding that output per-card binary classifications for the following intermediate labels:
  - `is_in_pair`
  - `is_in_toak`
  - `is_in_foak`
  - `is_start_of_straight`
  - `is_in_flush`
- A final classifier that takes the outputs of the five classifiers, thresholded at zero into ±1 flags, and classifies the hand into one of the 9 types.


In [50]:
import torch
import torch.nn as nn

# Use the GPU when one is available; otherwise fall back to the CPU so the notebook still runs.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [51]:
NC = 17  # 17 = len(RANKS) + len(SUITS)

def encode_hand(hand):
    """Encode a hand as a [5, NC] tensor: per card, a rank one-hot followed by a suit one-hot."""
    suit_offsets = {SUITS["clubs"]: 0, SUITS["diamonds"]: 1, SUITS["hearts"]: 2, SUITS["spades"]: 3}
    rank_count = len(RANKS)
    encoded = torch.zeros([5, NC])
    for row, card in enumerate(hand):
        encoded[row, card.rank.value] = 1
        # Suit columns start right after the rank columns.
        encoded[row, rank_count + suit_offsets[card.suit]] = 1
    return encoded  # [5, NC]

In [52]:
class MultiHeadedAttentionBackbone(nn.Module):
    """Multi-head attention over the five encoded cards, followed by a two-layer MLP.

    Args:
        num_heads: number of attention heads; must divide both `kdim` and `vdim`.
        feature_size: size of the per-card output feature.
        kdim: total key/query projection size, split across the heads.
        vdim: total value projection size, split across the heads.
        hidden_size: hidden width of the MLP applied after attention.

    Raises:
        ValueError: if `kdim` or `vdim` is not divisible by `num_heads`.
    """

    def __init__(self, num_heads=8, feature_size=32, kdim=32, vdim=32, hidden_size=32):
        super(MultiHeadedAttentionBackbone, self).__init__()
        # forward() splits kdim/vdim evenly across heads with view(); an uneven split
        # would either fail there or silently mix elements from different hands.
        if kdim % num_heads != 0 or vdim % num_heads != 0:
            raise ValueError(
                f"kdim ({kdim}) and vdim ({vdim}) must both be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.feature_size = feature_size
        self.kdim, self.vdim = kdim, vdim
        self.proj_key = nn.Linear(NC, kdim)
        self.proj_query = nn.Linear(NC, kdim)
        self.proj_value = nn.Linear(NC, vdim)
        self.linear1 = nn.Linear(vdim, hidden_size)
        self.linear2 = nn.Linear(hidden_size, feature_size)
        self.act = nn.ReLU()

    def forward(self, x):  # [B, 5, NC] -> [B, 5, feature_size]
        x = x.view(-1, NC)  # [B * 5, NC]
        key, query, value = self.proj_key(x), self.proj_query(x), self.proj_value(x)  # [B * 5, kdim/vdim]
        key = key.view(-1, 5, self.num_heads, self.kdim // self.num_heads).transpose(1, 2)  # [B, nh, 5, kdim / nh]
        query = query.view(-1, 5, self.num_heads, self.kdim // self.num_heads).transpose(1, 2)  # [B, nh, 5, kdim / nh]
        value = value.view(-1, 5, self.num_heads, self.vdim // self.num_heads).transpose(1, 2)  # [B, nh, 5, vdim / nh]
        # Scores are raw query-key dot products; no scaling or softmax is applied here.
        attention = torch.matmul(query, key.transpose(-1, -2))  # [B, nh, 5, 5]
        x = torch.matmul(attention, value)  # [B, nh, 5, vdim / nh]
        x = x.transpose(1, 2).reshape(-1, 5, self.vdim)  # [B, 5, vdim]
        return self.linear2(self.act(self.linear1(x)))

In [53]:
class BinaryPredictor(nn.Module):
    """Two-layer MLP head that maps each feature vector to a single logit."""

    def __init__(self, feature_size=32, hidden_size=16):
        super().__init__()
        self.feature_size, self.hidden_size = feature_size, hidden_size
        self.linear1 = nn.Linear(feature_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, 1)
        self.act = nn.ReLU()

    def forward(self, x):  # [B * 5, feature_size] -> [B * 5]
        hidden = self.act(self.linear1(x))
        logits = self.linear2(hidden)
        # Drop the trailing size-1 dimension left by the final linear layer.
        return logits.squeeze(dim=-1)

In [54]:
class TypePredictor(nn.Module):
    """Two-layer MLP that maps the 5 x 5 grid of per-card flags to 9 hand-type logits."""

    def __init__(self, hidden_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear1 = nn.Linear(25, hidden_size)
        self.linear2 = nn.Linear(hidden_size, 9)
        self.act = nn.ReLU()

    def forward(self, x):  # [B, 5, 5] -> [B, 9]
        flat = x.view(-1, 25)  # [B, 25]
        hidden = self.act(self.linear1(flat))
        return self.linear2(hidden)  # [B, 9]

In [55]:
class HandClassifier(nn.Module):
    """Backbone, five per-card binary heads and a hand-type head, with their combined loss.

    `loss_weights` maps each of "pair", "toak", "foak", "sofs", "flush" and "type"
    to the weight of that loss term in the total; every weight defaults to 1.0.
    """

    def __init__(self, backbone=MultiHeadedAttentionBackbone,
                 binary_predictor=BinaryPredictor, type_predictor=TypePredictor,
                 binary_loss=nn.BCEWithLogitsLoss, type_loss=nn.CrossEntropyLoss, loss_weights=None):
        super().__init__()
        if loss_weights is None:
            loss_weights = {k: 1.0 for k in ["pair", "toak", "foak", "sofs", "flush", "type"]}
        self.loss_weights = loss_weights

        self.backbone = backbone()
        self.c = self.backbone.feature_size
        self.pred_pair = binary_predictor(feature_size=self.c)
        self.pred_toak = binary_predictor(feature_size=self.c)
        self.pred_foak = binary_predictor(feature_size=self.c)
        self.pred_sofs = binary_predictor(feature_size=self.c)
        self.pred_flush = binary_predictor(feature_size=self.c)
        self.pred_type = type_predictor()
        self.binary_loss = binary_loss()
        self.type_loss = type_loss()

    def forward(self, batch):
        """Run the model on a list of sample dicts.

        Returns:
            (preds, loss, losses): per-head logits keyed by head name, the weighted
            total loss, and the individual loss terms keyed by head name.
        """
        batch_size = len(batch)
        inputs = torch.stack([encode_hand(sample["hand"]) for sample in batch]).to(DEVICE)  # [B, 5, NC]

        # Every sample entry except the hand and its type becomes a float label tensor.
        labels = {}
        for key in batch[0].keys():
            if key == "hand" or key == "hand_type":
                continue
            per_sample = [torch.tensor(sample[key], dtype=torch.float) for sample in batch]
            labels[key] = torch.stack(per_sample, dim=0).to(DEVICE)
        labels["hand_type"] = torch.stack(
            [torch.tensor(sample["hand_type"].value) for sample in batch]).to(DEVICE)  # [B]

        features = self.backbone(inputs).view([batch_size * 5, -1])  # [B * 5, C]

        # (head name, head module, label key), in the column order fed to the type head.
        heads = [
            ("pair", self.pred_pair, "is_in_pair"),
            ("toak", self.pred_toak, "is_in_toak"),
            ("foak", self.pred_foak, "is_in_foak"),
            ("sofs", self.pred_sofs, "is_start_of_straight"),
            ("flush", self.pred_flush, "is_in_flush"),
        ]
        preds = {}
        for name, head, _ in heads:
            preds[name] = head(features).view([batch_size, 5])  # [B, 5]

        # The type head sees hard +1 / -1 flags (logit > 0) rather than the raw logits.
        stacked = torch.stack([preds[name] for name, _, _ in heads], dim=-1)  # [B, 5, 5]
        preds["type"] = self.pred_type(torch.where(stacked > 0.0, 1.0, -1.0))  # [B, 9] for logits of 9 types

        losses = {}
        for name, _, label_key in heads:
            losses[name] = self.binary_loss(preds[name].flatten(), labels[label_key].flatten())  # [B * 5], [B * 5]
        losses["type"] = self.type_loss(preds["type"], labels["hand_type"])  # [B, 9], [B]
        loss = sum([losses[k] * self.loss_weights[k] for k in losses])

        return preds, loss, losses

## Training

Training using the balanced dataset `mds1M` for 10 epochs, evaluating with `mds10k` after each epoch.

The trained model parameters can be downloaded here:
- trained_model.pt: [Google Drive](https://drive.google.com/file/d/1ogqoOWLlPfH7XtrdVCRcgjS8bBv1Fl4K/view?usp=sharing)

In [56]:
from timeit import default_timer as timer
from torch.utils.data import DataLoader

In [57]:
def collate_losses(losses_history, loss_history):
    """Turn per-batch loss records into one numpy array per loss name, plus a "total" array."""
    collated = {}
    for key in losses_history[0]:
        collated[key] = np.array([entry[key] for entry in losses_history])
    collated["total"] = np.array(loss_history)
    return collated

In [58]:
def collate_data(batch):
    """Identity collate function: hand the batch to the model unchanged."""
    return batch

In [59]:
def train_epoch(model, optimizer, dataset, batch_size=16):
    """Train `model` for one pass over `dataset` and return the collated per-batch losses."""
    model.train()
    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_data)
    per_head_history, total_history = [], []

    with tqdm(loader) as progress:
        for batch in progress:
            _, loss, losses = model(batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Store plain floats so the recorded history does not hold on to tensors.
            per_head_history.append({name: value.item() for name, value in losses.items()})
            total_history.append(loss.item())

    return collate_losses(per_head_history, total_history)

In [60]:
def evaluate(model, dataset, batch_size=16):
    """Compute losses on `dataset` without updating the model.

    Args:
        model: called on each batch; expected to return `(preds, loss, losses)`.
        dataset: the samples to evaluate on.
        batch_size: evaluation batch size (defaults to 16).

    Returns:
        The collated per-batch losses, as produced by `collate_losses`.
    """
    model.eval()
    losses_history = []
    loss_history = []
    val_dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_data)

    with torch.no_grad():
        with tqdm(val_dataloader) as t:
            # Iterate the tqdm wrapper (not the raw loader) so the progress bar advances.
            for batch in t:
                _, loss, losses = model(batch)
                losses_history.append({k:v.item() for k, v in losses.items()})
                loss_history.append(loss.item())

    return collate_losses(losses_history, loss_history)

In [61]:
# Build the classifier with its default components and an Adam optimizer over all its parameters.
model = HandClassifier().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [62]:
# Number of full passes over the training set.
EPOCHS = 10

In [63]:
# Each epoch: train on mds1M, evaluate on mds10k, and report timings plus mean losses.
for epoch in range(EPOCHS):
    t0 = timer()
    train_losses = train_epoch(model, optimizer, mds1M, batch_size=256)
    t1 = timer()
    eval_losses = evaluate(model, mds10k)
    t2 = timer()
    print("Epoch {}: train loss: {}, time: {:.3f}s; eval time:{:.3f}s".format(
        epoch, train_losses["total"].mean(), t1 - t0, t2 - t1))
    print("Eval losses:")
    for k, v in eval_losses.items():
        print(f"  {k}: {v.mean()}")

# Save the trained weights once all epochs have finished.
torch.save(model.state_dict(), "/tmp/model.pt")  # full_model_rc1.pt

100%|██████████| 3907/3907 [01:05<00:00, 59.82it/s]
  0%|          | 0/625 [00:01<?, ?it/s]
  0%|          | 7/3907 [00:00<01:04, 60.10it/s]

Epoch 0: train loss: 3.2327789754430625, time: 65.315s; eval time:1.354s
Eval losses:
  pair: 0.17626495567560196
  toak: 0.14021759887039661
  foak: 0.010860708715835493
  sofs: 0.09654889646172524
  flush: 0.030597994036832825
  type: 1.0309969388008118
  total: 1.4854870942115783


100%|██████████| 3907/3907 [01:05<00:00, 59.35it/s]
  0%|          | 0/625 [00:01<?, ?it/s]
  0%|          | 7/3907 [00:00<01:04, 60.73it/s]

Epoch 1: train loss: 0.5758888616048072, time: 65.832s; eval time:1.377s
Eval losses:
  pair: 0.0008964573390097939
  toak: 0.0002534146185711279
  foak: 2.9855095651456053e-06
  sofs: 0.07696698193103076
  flush: 0.0035200736076715176
  type: 0.2275260203510523
  total: 0.30916593297123907


100%|██████████| 3907/3907 [01:05<00:00, 59.55it/s]
  0%|          | 0/625 [00:01<?, ?it/s]
  0%|          | 7/3907 [00:00<01:04, 60.35it/s]

Epoch 2: train loss: 0.2899755928174352, time: 65.610s; eval time:1.356s
Eval losses:
  pair: 0.0003478460978536987
  toak: 0.0004278624134446801
  foak: 3.9139058614122744e-05
  sofs: 0.03892876878529787
  flush: 0.0061080940746409285
  type: 0.21606109404861926
  total: 0.26191280455887317


100%|██████████| 3907/3907 [01:07<00:00, 58.21it/s]
  0%|          | 0/625 [00:01<?, ?it/s]
  0%|          | 6/3907 [00:00<01:06, 58.62it/s]

Epoch 3: train loss: 0.18898900623841147, time: 67.126s; eval time:1.440s
Eval losses:
  pair: 0.00031238914864568413
  toak: 0.00023700149800325612
  foak: 2.7471535365964785e-05
  sofs: 0.017522205374660007
  flush: 0.001887308512669182
  type: 0.07129649670124054
  total: 0.09128287282735109


100%|██████████| 3907/3907 [01:06<00:00, 58.86it/s]
  0%|          | 0/625 [00:01<?, ?it/s]
  0%|          | 6/3907 [00:00<01:11, 54.63it/s]

Epoch 4: train loss: 0.04674918251505419, time: 66.382s; eval time:1.380s
Eval losses:
  pair: 0.0001826443112885954
  toak: 0.00010746970938275951
  foak: 7.742229638501285e-06
  sofs: 0.008484614562388743
  flush: 0.0008494839529497935
  type: 0.012590069139935076
  total: 0.022222023960389196


100%|██████████| 3907/3907 [01:05<00:00, 59.22it/s]
  0%|          | 0/625 [00:01<?, ?it/s]
  0%|          | 6/3907 [00:00<01:05, 59.33it/s]

Epoch 5: train loss: 0.013606977452315906, time: 65.973s; eval time:1.386s
Eval losses:
  pair: 9.928752578484818e-05
  toak: 4.522226110883736e-05
  foak: 3.301297513674939e-06
  sofs: 0.005914230295699545
  flush: 0.0004717644658461559
  type: 0.003794745466738823
  total: 0.010328551337413956


100%|██████████| 3907/3907 [01:05<00:00, 59.38it/s]
  0%|          | 0/625 [00:01<?, ?it/s]
  0%|          | 6/3907 [00:00<01:05, 59.81it/s]

Epoch 6: train loss: 0.008455526402799615, time: 65.803s; eval time:1.326s
Eval losses:
  pair: 6.155378279867847e-05
  toak: 2.7467006290693162e-05
  foak: 1.5026298570425922e-06
  sofs: 0.004648981631765865
  flush: 0.0002879352529573243
  type: 0.001678982823208571
  total: 0.006706423136372177


100%|██████████| 3907/3907 [01:05<00:00, 59.21it/s]
  0%|          | 0/625 [00:01<?, ?it/s]
  0%|          | 6/3907 [00:00<01:06, 58.63it/s]

Epoch 7: train loss: 0.006176574672331558, time: 65.994s; eval time:1.365s
Eval losses:
  pair: 4.0927081141554565e-05
  toak: 1.7984007946623847e-05
  foak: 1.3296369718275968e-06
  sofs: 0.0034598687237328376
  flush: 0.0001748335982263473
  type: 0.00017881288347396094
  total: 0.0038737559535948094


100%|██████████| 3907/3907 [01:06<00:00, 58.79it/s]
  0%|          | 0/625 [00:01<?, ?it/s]
  0%|          | 6/3907 [00:00<01:07, 58.20it/s]

Epoch 8: train loss: 0.0049157091768614386, time: 66.461s; eval time:1.389s
Eval losses:
  pair: 5.302976056605502e-05
  toak: 1.538497289417009e-05
  foak: 1.4529387347703703e-06
  sofs: 0.0029448354291554024
  flush: 0.00011595185719859629
  type: 0.00011014274632542395
  total: 0.003240797678726449


100%|██████████| 3907/3907 [01:06<00:00, 58.92it/s]
  0%|          | 0/625 [00:01<?, ?it/s]

Epoch 9: train loss: 0.0040780217000469235, time: 66.319s; eval time:1.363s
Eval losses:
  pair: 5.4882844384809456e-05
  toak: 1.2831331505153187e-05
  foak: 1.2235363135744136e-06
  sofs: 0.0026504645480689326
  flush: 8.036554988575659e-05
  type: 0.0008702795249988313
  total: 0.0036700473625056475





## Evaluation

Detailed evaluation of the `HandClassifier` model, including accuracies in the intermediate outputs and per-class accuracy in the hand type classification.

In [64]:
# map_location lets a checkpoint saved on another device (e.g. a GPU) load onto DEVICE.
model.load_state_dict(torch.load("/tmp/model.pt", map_location=DEVICE))
model = model.eval()

In [65]:
# Inspect the raw logits and the decoded predictions for the first ten evaluation hands.
batch = mds10k[:10]
with torch.no_grad():
    pred, loss, losses = model(batch)
    for i in range(len(batch)):
        print("hand: ", batch[i]["hand"])
        print("pred: ", {k:v[i].cpu() for k, v, in pred.items()})
        print("pred results:")
        for k in ["pair", "toak", "foak", "sofs", "flush"]:
            # The binary heads output logits; sigmoid turns them into probabilities.
            print("  {}: {}".format(k, torch.nn.functional.sigmoid(pred[k][i]).cpu().numpy()))
        # The predicted hand type is the class with the largest logit.
        print("  hand type: {}".format(HAND_NAMES[torch.argmax(pred["type"][i], dim=-1)]))

hand:  [J♥, 9♠, 2♦, J♣, J♦]
pred:  {'pair': tensor([ -55.0838,  -68.2008, -160.5719,  -51.9155,  -54.4946]), 'toak': tensor([  27.3604, -149.6196, -205.8250,   25.1986,   20.6010]), 'foak': tensor([ -79.0714, -345.9962, -392.5587,  -72.1998,  -53.8967]), 'sofs': tensor([-239.7527,  -49.4829, -131.1804, -230.2635, -224.4551]), 'flush': tensor([-143.4251, -111.5795,  -76.0997, -134.4922, -121.8825]), 'type': tensor([-62.4672, -15.0344,  -9.3527, -20.4145, -16.3644,  13.5573, -89.3692,
        -30.0189,  -4.0962])}
pred results:
  pair: [1.1951634e-24 2.4031069e-30 0.0000000e+00 2.8404978e-23 2.1541788e-24]
  toak: [1. 0. 0. 1. 1.]
  foak: [4.5678571e-35 0.0000000e+00 0.0000000e+00 4.4057775e-32 3.9170865e-24]
  sofs: [0.0000000e+00 3.2347707e-22 0.0000000e+00 0.0000000e+00 0.0000000e+00]
  flush: [0.000000e+00 0.000000e+00 8.919007e-34 0.000000e+00 0.000000e+00]
  hand type: hand.three of a kind
hand:  [6♦, J♣, 7♠, 3♣, 10♥]
pred:  {'pair': tensor([-100.8473,  -95.2242,  -94.8169,  -95.78



In [66]:
# Confusion matrix: [[TP, FN], [FP, TN]]
def binary_confusion(preds, labels):
    """Count the binary confusion matrix of thresholded predictions.

    A prediction is treated as positive when its value is strictly greater
    than 0.0 and as negative when it is less than or equal to 0.0.

    Args:
        preds: tensor of raw prediction scores.
        labels: tensor of ground-truth flags, broadcastable against ``preds``;
            non-zero / True marks a positive label.

    Returns:
        A 2x2 integer ``np.ndarray`` laid out as ``[[TP, FN], [FP, TN]]``.
    """
    predicted_pos = preds > 0.0
    predicted_neg = preds <= 0.0
    actual_neg = torch.logical_not(labels)

    def count(prediction_mask, label_mask):
        # .item() yields a plain Python number, so building the NumPy array does not
        # depend on implicit tensor-to-NumPy coercion (which needs CPU-resident tensors).
        return int(torch.logical_and(prediction_mask, label_mask).sum().item())

    return np.array([[count(predicted_pos, labels), count(predicted_neg, labels)],
                     [count(predicted_pos, actual_neg), count(predicted_neg, actual_neg)]], dtype=int)

def type_confusion(preds, labels, num_classes=9):
    """Count the multi-class confusion matrix of arg-max predictions.

    Args:
        preds: tensor of per-class scores; the class axis is the last one.
        labels: tensor of integer class indices, one per prediction.
        num_classes: side length of the returned matrix (defaults to the
            nine classes used by the rest of this notebook).

    Returns:
        A ``num_classes`` x ``num_classes`` integer ``np.ndarray`` where entry
        ``[label, prediction]`` is the number of samples with that combination.
    """
    confusion = np.zeros([num_classes, num_classes], dtype=int)
    predicted = torch.argmax(preds, dim=-1).flatten()
    actual = torch.flatten(labels)
    # Each distinct (prediction, label) row appears once, together with its multiplicity.
    pairs, counts = torch.unique(torch.stack([predicted, actual], dim=1), dim=0, return_counts=True)
    # Convert explicitly so the NumPy array is indexed with NumPy arrays, not tensors.
    pairs, counts = pairs.cpu().numpy(), counts.cpu().numpy()
    confusion[pairs[:, 1], pairs[:, 0]] = counts
    return confusion

def _percent(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage, or NaN for a zero denominator."""
    if denominator == 0:
        return float("nan")
    return float(numerator) / float(denominator) * 100


def eval_accuracy(model, dataset, batch_size=16):
    """Evaluate ``model`` on ``dataset`` and print accuracy statistics.

    For each binary head (pair, toak, foak, sofs, flush) a 2x2 confusion matrix
    ``[[TP, FN], [FP, TN]]`` is accumulated and precision, recall, accuracy and the
    share of positive labels are printed. For the hand type head a 9x9 confusion
    matrix indexed ``[label, prediction]`` is accumulated and the overall accuracy,
    plus the accuracy and label share of every entry of ``HANDS``, are printed.
    Ratios whose denominator is zero are printed as ``nan``.

    The model is run under ``torch.no_grad()``; its train/eval mode is left as the
    caller set it.

    Args:
        model: callable returning ``(pred, loss, losses)`` for a batch, where
            ``pred`` maps head names to score tensors.
        dataset: dataset of per-hand dicts, batched with ``collate_data``.
        batch_size: number of hands per evaluation batch.

    Returns:
        None. Results are printed.
    """
    # Prediction head name -> key of the matching ground-truth flags in each sample.
    binary_heads = {"pair": "is_in_pair", "toak": "is_in_toak", "foak": "is_in_foak",
                    "sofs": "is_start_of_straight", "flush": "is_in_flush"}
    confusion = {head: np.zeros([2, 2], dtype=int) for head in binary_heads}
    confusion["type"] = np.zeros([9, 9], dtype=int)

    val_dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_data)
    with torch.no_grad():
        with tqdm(val_dataloader) as t:
            for batch in t:
                pred = model(batch)[0]  # the loss values are not needed for evaluation
                for head, label_key in binary_heads.items():
                    flags = torch.stack([torch.tensor(h[label_key], dtype=torch.bool) for h in batch],
                                        dim=0).to(DEVICE)  # [B, 5]
                    confusion[head] += binary_confusion(pred[head], flags)
                hand_types = torch.stack(
                    [torch.tensor(h["hand_type"].value) for h in batch]).to(DEVICE)  # [B]
                confusion["type"] += type_confusion(pred["type"], hand_types)

    for k, c in confusion.items():
        total = c.sum()
        if k != "type":
            true_pos, false_neg, false_pos, true_neg = c[0, 0], c[0, 1], c[1, 0], c[1, 1]
            print("{}: precision {:.2f}% recall {:.2f}% accuracy {:.2f}% positive labels {:.2f}%".format(
                k, _percent(true_pos, true_pos + false_pos), _percent(true_pos, true_pos + false_neg),
                _percent(true_pos + true_neg, total), _percent(true_pos + false_neg, total)))
        else:
            print("{}: accuracy {:.2f}%".format(k, _percent(np.diag(c).sum(), total)))
            for i, hand_type in enumerate(HANDS):
                row_total = c[i].sum()  # number of samples labelled with this hand type
                print("    {:15s}: accuracy {:.2f}% labels {:.2f}%".format(
                    hand_type.name, _percent(c[i, i], row_total), _percent(row_total, total)))


In [67]:
# Print the confusion-matrix based metrics of the trained model on the mds10k dataset.
eval_accuracy(model, mds10k)

100%|██████████| 625/625 [00:03<00:00, 157.62it/s]

pair: precision 100.00% recall 100.00% accuracy 100.00% positive labels 19.78%
toak: precision 100.00% recall 100.00% accuracy 100.00% positive labels 12.18%
foak: precision 100.00% recall 100.00% accuracy 100.00% positive labels 7.99%
sofs: precision 96.91% recall 99.91% accuracy 99.93% positive labels 2.20%
flush: precision 99.97% recall 100.00% accuracy 100.00% positive labels 11.62%
type: accuracy 99.99%
    straight flush : accuracy 100.00% labels 1.07%
    four of a kind : accuracy 100.00% labels 9.99%
    full house     : accuracy 100.00% labels 9.69%
    flush          : accuracy 100.00% labels 10.55%
    straight       : accuracy 99.90% labels 9.92%
    three of a kind: accuracy 100.00% labels 10.61%
    two pair       : accuracy 100.00% labels 10.86%
    one pair       : accuracy 100.00% labels 18.04%
    high card      : accuracy 100.00% labels 19.27%





## Hand-crafted model

A `HandClassifier` model built with hand-crafted hyper-parameters and network parameters.
- Transformer encoder multi-headed attention hyper-parameters:
  - Number of heads: 6
  - Key/query size: 13 per head (78 in total)
  - Value size: 1 per head (6 in total)
  - Feedforward hidden layer size: 6
- Binary classifier hyper-parameters:
  - Hidden layer size: 4
- Type classifier hyper-parameters:
  - Hidden layer size: 25

The model parameters can be downloaded here:
- hand_crafted_model.pt: [Google Drive](https://drive.google.com/file/d/1o03VEvMGQcrmLWff_EnX1ytwpYLBeX3O/view?usp=sharing)

In [68]:
import copy

In [69]:
# Build a HandClassifier with the hand-crafted hyper-parameters listed above:
#   - backbone: 6 attention heads, kdim=78 (13 per head), vdim=6 (1 per head),
#     feedforward hidden size 6
#   - binary predictors: hidden size 4
#   - type predictor: hidden size 25
# The sub-modules are passed as factories (lambdas), not as instances.
hmodel = HandClassifier(
    backbone=lambda: MultiHeadedAttentionBackbone(num_heads=6, feature_size=6, kdim=78, vdim=6, hidden_size=6), 
    binary_predictor=lambda feature_size: BinaryPredictor(feature_size, hidden_size=4),
    type_predictor=lambda: TypePredictor(hidden_size=25))
# Bare expression: the cell displays the module structure.
hmodel

HandClassifier(
  (backbone): MultiHeadedAttentionBackbone(
    (proj_key): Linear(in_features=17, out_features=78, bias=True)
    (proj_query): Linear(in_features=17, out_features=78, bias=True)
    (proj_value): Linear(in_features=17, out_features=6, bias=True)
    (linear1): Linear(in_features=6, out_features=6, bias=True)
    (linear2): Linear(in_features=6, out_features=6, bias=True)
    (act): ReLU()
  )
  (pred_pair): BinaryPredictor(
    (linear1): Linear(in_features=6, out_features=4, bias=True)
    (linear2): Linear(in_features=4, out_features=1, bias=True)
    (act): ReLU()
  )
  (pred_toak): BinaryPredictor(
    (linear1): Linear(in_features=6, out_features=4, bias=True)
    (linear2): Linear(in_features=4, out_features=1, bias=True)
    (act): ReLU()
  )
  (pred_foak): BinaryPredictor(
    (linear1): Linear(in_features=6, out_features=4, bias=True)
    (linear2): Linear(in_features=4, out_features=1, bias=True)
    (act): ReLU()
  )
  (pred_sofs): BinaryPredictor(
    (lin

In [70]:
# Backbone weights
# attention output channels:
#  - count of cards of the same rank in the hand
#  - count of cards of the same suit in the hand
#  - count of cards one rank higher in the hand
#  - count of cards two ranks higher in the hand
#  - count of cards three ranks higher in the hand
#  - count of cards four ranks higher in the hand
# Layout used below: input columns 0-12 carry the rank and 13-16 the suit; each of the
# six heads owns 13 consecutive rows of the 78-row key/query projections.

# Keys: head 1 exposes the suit, every other head exposes the unshifted rank.
key_w = torch.zeros(78, 17)
key_w[13:17, 13:] = torch.eye(4, 4)  # picks out the suit
for head in (0, 2, 3, 4, 5):
    key_w[13 * head:13 * head + 13, :13] = torch.eye(13, 13)  # picks out the rank
hmodel.backbone.proj_key.weight = nn.Parameter(key_w)
hmodel.backbone.proj_key.bias = nn.Parameter(torch.zeros(78))

# Queries: heads 0 and 1 look for the same rank / suit, heads 2-5 look for the
# rank shifted up by 1-4 (a sub-diagonal moves the rank one-hot to a higher index).
query_w = torch.zeros(78, 17)
query_w[:13, :13] = torch.eye(13, 13)  # picks out the rank
query_w[13:17, 13:] = torch.eye(4, 4)  # picks out the suit
for shift in range(1, 5):
    first_row = 13 * (shift + 1)
    query_w[first_row:first_row + 13, :13] = torch.diag(torch.ones(13 - shift), diagonal=-shift)  # rank + shift
hmodel.backbone.proj_query.weight = nn.Parameter(query_w)
hmodel.backbone.proj_query.bias = nn.Parameter(torch.zeros(78))

# Values: zero weight and unit bias, i.e. a constant 1 per head whatever the input is.
hmodel.backbone.proj_value.weight = nn.Parameter(torch.zeros(6, 17))
hmodel.backbone.proj_value.bias = nn.Parameter(torch.ones(6))

# trivial feedforward: both layers are the identity map with zero bias
for layer in (hmodel.backbone.linear1, hmodel.backbone.linear2):
    layer.weight = nn.Parameter(torch.eye(6, 6))
    layer.bias = nn.Parameter(torch.zeros(6))

In [71]:
# pred_pair weights
# Hidden layer: unit 0 has weight -1 / bias 2 and unit 1 has weight 1 / bias -2, both on
# backbone feature 0 (the same-rank count, per the backbone cell); units 2 and 3 stay zero.
hidden_w = torch.zeros(4, 6)
hidden_w[:2, 0] = torch.tensor([-1.0, 1.0])
hidden_b = torch.tensor([2.0, -2.0, 0.0, 0.0])
hmodel.pred_pair.linear1.weight = nn.Parameter(hidden_w)
hmodel.pred_pair.linear1.bias = nn.Parameter(hidden_b)

# Output layer: 1 - 2 * (h0 + h1). Assuming BinaryPredictor applies
# linear1 -> ReLU -> linear2, this is positive only when feature 0 equals 2.
out_w = torch.tensor([[-2.0, -2.0, 0.0, 0.0]])
out_b = torch.tensor([1.0])
hmodel.pred_pair.linear2.weight = nn.Parameter(out_w)
hmodel.pred_pair.linear2.bias = nn.Parameter(out_b)

In [72]:
# pred_toak weights
# Hidden layer: unit 0 has weight -1 / bias 3 and unit 1 has weight 1 / bias -3, both on
# backbone feature 0 (the same-rank count, per the backbone cell); units 2 and 3 stay zero.
hidden_w = torch.zeros(4, 6)
hidden_w[:2, 0] = torch.tensor([-1.0, 1.0])
hidden_b = torch.tensor([3.0, -3.0, 0.0, 0.0])
hmodel.pred_toak.linear1.weight = nn.Parameter(hidden_w)
hmodel.pred_toak.linear1.bias = nn.Parameter(hidden_b)

# Output layer: 1 - 2 * (h0 + h1). Assuming BinaryPredictor applies
# linear1 -> ReLU -> linear2, this is positive only when feature 0 equals 3.
out_w = torch.tensor([[-2.0, -2.0, 0.0, 0.0]])
out_b = torch.tensor([1.0])
hmodel.pred_toak.linear2.weight = nn.Parameter(out_w)
hmodel.pred_toak.linear2.bias = nn.Parameter(out_b)

In [73]:
# pred_foak weights
# Hidden layer: unit 0 has weight -1 / bias 4 and unit 1 has weight 1 / bias -4, both on
# backbone feature 0 (the same-rank count, per the backbone cell); units 2 and 3 stay zero.
hidden_w = torch.zeros(4, 6)
hidden_w[:2, 0] = torch.tensor([-1.0, 1.0])
hidden_b = torch.tensor([4.0, -4.0, 0.0, 0.0])
hmodel.pred_foak.linear1.weight = nn.Parameter(hidden_w)
hmodel.pred_foak.linear1.bias = nn.Parameter(hidden_b)

# Output layer: 1 - 2 * (h0 + h1). Assuming BinaryPredictor applies
# linear1 -> ReLU -> linear2, this is positive only when feature 0 equals 4.
out_w = torch.tensor([[-2.0, -2.0, 0.0, 0.0]])
out_b = torch.tensor([1.0])
hmodel.pred_foak.linear2.weight = nn.Parameter(out_w)
hmodel.pred_foak.linear2.bias = nn.Parameter(out_b)

In [74]:
# pred_sofs weights
# Hidden unit j has weight -1 on backbone feature j + 2 and bias 1 (j = 0..3); features 2-5
# are the "cards one to four ranks higher" counts listed in the backbone cell.
# Rows 2-5 of a 6x6 identity, negated, give exactly that 4x6 weight matrix.
hmodel.pred_sofs.linear1.weight = nn.Parameter(-torch.eye(6)[2:])
hmodel.pred_sofs.linear1.bias = nn.Parameter(torch.ones(4))

# Output layer: 1 - 2 * (h0 + h1 + h2 + h3). Assuming BinaryPredictor applies
# linear1 -> ReLU -> linear2, this is positive only when every hidden unit is zero.
out_w = torch.full((1, 4), -2.0)
out_b = torch.tensor([1.0])
hmodel.pred_sofs.linear2.weight = nn.Parameter(out_w)
hmodel.pred_sofs.linear2.bias = nn.Parameter(out_b)

In [75]:
# pred_flush weights
# Hidden layer: unit 0 has weight -1 / bias 5 and unit 1 has weight 1 / bias -5, both on
# backbone feature 1 (the same-suit count, per the backbone cell); units 2 and 3 stay zero.
hidden_w = torch.zeros(4, 6)
hidden_w[:2, 1] = torch.tensor([-1.0, 1.0])
hidden_b = torch.tensor([5.0, -5.0, 0.0, 0.0])
hmodel.pred_flush.linear1.weight = nn.Parameter(hidden_w)
hmodel.pred_flush.linear1.bias = nn.Parameter(hidden_b)

# Output layer: 1 - 2 * (h0 + h1). Assuming BinaryPredictor applies
# linear1 -> ReLU -> linear2, this is positive only when feature 1 equals 5.
out_w = torch.tensor([[-2.0, -2.0, 0.0, 0.0]])
out_b = torch.tensor([1.0])
hmodel.pred_flush.linear2.weight = nn.Parameter(out_w)
hmodel.pred_flush.linear2.bias = nn.Parameter(out_b)

In [76]:
# pred_type weights
# First layer: identity over the 25 inputs.
hmodel.pred_type.linear1.weight = nn.Parameter(torch.eye(25, 25))
hmodel.pred_type.linear1.bias = nn.Parameter(torch.zeros(25))

# Second layer, one row per hand type. Each rule is (row, {offset: weight}, bias); a weight
# is written to inputs offset, offset + 5, offset + 10, ... (every fifth input).
# Going by the hand types they feed, offsets 0-4 look like the pair, toak, foak, sofs and
# flush signals of each card -- TODO confirm against TypePredictor's input layout.
type_rules = [
    (0, {3: 2, 4: 2}, -10),  # straight flush
    (1, {2: 1}, -3),         # four of a kind
    (2, {0: 2, 1: 2}, -8),   # full house
    (3, {4: 1}, -4),         # flush
    (4, {3: 1}, 0),          # straight
    (5, {1: 1}, -2),         # three of a kind
    (6, {0: 3}, -8),         # two pair
    (7, {0: 1}, -1),         # one pair
    (8, {}, 0.5),            # high card: constant score only
]
type_w = torch.zeros(9, 25)
type_b = torch.zeros(9)
for row, offset_weights, bias in type_rules:
    for offset, weight in offset_weights.items():
        type_w[row, offset::5] = weight
    type_b[row] = bias
hmodel.pred_type.linear2.weight = nn.Parameter(type_w)
hmodel.pred_type.linear2.bias = nn.Parameter(type_b)

In [77]:
# Move the hand-crafted model to DEVICE and print its metrics on mds10k,
# the same dataset the trained model was evaluated on above.
hmodel = hmodel.to(DEVICE)
eval_accuracy(hmodel, mds10k)

100%|██████████| 625/625 [00:03<00:00, 157.16it/s]

pair: precision 100.00% recall 100.00% accuracy 100.00% positive labels 19.78%
toak: precision 100.00% recall 100.00% accuracy 100.00% positive labels 12.18%
foak: precision 100.00% recall 100.00% accuracy 100.00% positive labels 7.99%
sofs: precision 100.00% recall 100.00% accuracy 100.00% positive labels 2.20%
flush: precision 100.00% recall 100.00% accuracy 100.00% positive labels 11.62%
type: accuracy 100.00%
    straight flush : accuracy 100.00% labels 1.07%
    four of a kind : accuracy 100.00% labels 9.99%
    full house     : accuracy 100.00% labels 9.69%
    flush          : accuracy 100.00% labels 10.55%
    straight       : accuracy 100.00% labels 9.92%
    three of a kind: accuracy 100.00% labels 10.61%
    two pair       : accuracy 100.00% labels 10.86%
    one pair       : accuracy 100.00% labels 18.04%
    high card      : accuracy 100.00% labels 19.27%



