In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# [A♠, K♥, Q♦, J♣, 10♠]

# ---- card suit and rank encoder ----
class CardEncoder(nn.Module):
    """Embed (suit, rank) card ids into one dense vector per card.

    Suit and rank each have their own embedding table; the two embeddings are
    concatenated and mixed by a single linear layer.

    Args:
        suits: size of the suit vocabulary (card_str_to_tuple emits ids 0-5).
        ranks: size of the rank vocabulary (card_str_to_tuple emits ids 0-14).
        suit_dim: width of the suit embedding.
        rank_dim: width of the rank embedding.
        output_dim: width of the returned per-card encoding.
    """

    def __init__(self, suits=6, ranks=15, suit_dim=4, rank_dim=8, output_dim=16):
        super().__init__()
        self.suit_emb = nn.Embedding(suits, suit_dim)
        self.rank_emb = nn.Embedding(ranks, rank_dim)
        self.out = nn.Linear(suit_dim + rank_dim, output_dim)

    def forward(self, card_tensor):
        """Encode cards.

        Args:
            card_tensor: long tensor of shape (..., 2) whose last axis holds
                (suit_id, rank_id). Callers in this file pass (N, 2); because
                the last axis is indexed, (batch, num_cards, 2) works directly
                too.

        Returns:
            Float tensor of shape (..., output_dim).
        """
        # Index the LAST axis (not axis 1) so any number of leading dims works.
        suit_emb = self.suit_emb(card_tensor[..., 0])
        rank_emb = self.rank_emb(card_tensor[..., 1])
        combined = torch.cat((suit_emb, rank_emb), dim=-1)
        return self.out(combined)

# ---- poker hand classifier ----
class PokerHandClassifier(nn.Module):
    """MLP baseline: concatenate the per-card encodings and classify the hand.

    Because the card encodings are flattened into one vector before the first
    linear layer, the prediction depends on card order.

    Args:
        input_dim: width of each card encoding.
        hidden_dim: width of the hidden layer.
        num_classes: number of hand categories.
        num_cards: number of cards per hand (default 5).
    """

    def __init__(self, input_dim=16, hidden_dim=32, num_classes=11, num_cards=5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim * num_cards, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, hand_tensor):
        """Return class logits of shape (batch, num_classes).

        Args:
            hand_tensor: float tensor (batch, num_cards, input_dim).
        """
        # flatten(start_dim=1) also accepts non-contiguous input, unlike .view.
        hand_flat = hand_tensor.flatten(start_dim=1)
        x = F.relu(self.fc1(hand_flat))
        return self.fc2(x)




# ---- START ----
class PokerHandClassifier_WithAttention(nn.Module):
    """Single self-attention layer over the cards, mean-pooled, then an MLP head.

    Attention lets every card encoding attend to every other card; averaging
    over the card axis afterwards makes the logits independent of card order.

    Args:
        input_dim: width of each card encoding (attention embed dim).
        hidden_dim: width of the hidden layer of the head.
        num_classes: number of hand categories.
        num_heads: attention heads; must divide input_dim.
    """

    def __init__(self, input_dim=16, hidden_dim=32, num_classes=11, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, hand_tensor):
        """Return logits (batch, num_classes) for input (batch, num_cards, input_dim)."""
        mixed = self.attn(hand_tensor, hand_tensor, hand_tensor)[0]
        summary = torch.mean(mixed, dim=1)  # average over the card axis
        hidden = F.relu(self.fc1(summary))
        return self.fc2(hidden)
# ---- END ----


# ---- START ----
class PokerHandClassifier_SelfAttention(nn.Module):
    """Stack of residual self-attention layers, mean-pooled, then an MLP head.

    Args:
        input_dim: width of each card encoding (attention embed dim).
        hidden_dim: width of the hidden layer of the head.
        num_classes: number of hand categories.
        num_heads: attention heads per layer; must divide input_dim.
        num_layers: number of stacked attention layers.
    """

    def __init__(self, input_dim=16, hidden_dim=32, num_classes=11, num_heads=4, num_layers=2):
        super().__init__()
        attn_layers = []
        for _ in range(num_layers):
            attn_layers.append(
                nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, batch_first=True)
            )
        self.layers = nn.ModuleList(attn_layers)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, hand_tensor):
        """Return logits (batch, num_classes) for input (batch, num_cards, input_dim)."""
        h = hand_tensor
        for layer in self.layers:
            # residual: add each layer's attention output to its input
            h = h + layer(h, h, h)[0]
        hidden = F.relu(self.fc1(h.mean(dim=1)))
        return self.fc2(hidden)
# ---- END ----




class AttentionPooling(nn.Module):
    """Learned weighted average over the card axis.

    A small tanh MLP scores each card; the scores are softmax-normalised over
    the cards and used as weights for a weighted sum.

    Args:
        input_dim: width of each card encoding.
        hidden_dim: width of the scoring MLP's hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Pool (batch, num_cards, input_dim) down to (batch, input_dim)."""
        weights = self.attn(x).softmax(dim=1)  # (batch, num_cards, 1), sums to 1 over cards
        return (weights * x).sum(dim=1)

# ---- Alternative classifier with attention pooling ----
class PokerHandClassifier_AttentionPooling(nn.Module):
    """Classifier that pools cards with AttentionPooling, then applies an MLP head.

    Args:
        input_dim: width of each card encoding.
        hidden_dim: width of the pooling MLP and of the head's hidden layer.
        num_classes: number of hand categories.
    """

    def __init__(self, input_dim=16, hidden_dim=32, num_classes=11):
        super().__init__()
        self.pool = AttentionPooling(input_dim, hidden_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, hand_tensor):
        """Return logits (batch, num_classes) for input (batch, num_cards, input_dim)."""
        hidden = F.relu(self.fc1(self.pool(hand_tensor)))
        return self.fc2(hidden)





# --- Multihead Self-Attention Block (SAB) ---
class SAB(nn.Module):
    """Set-Transformer self-attention block (post-LayerNorm).

    The input is first projected to ``dim_out`` (identity when the widths
    already match). Two residual sub-layers follow: multi-head self-attention,
    then a 2x-wide ReLU feed-forward network. Each is followed by LayerNorm.

    Args:
        dim_in: width of the incoming set elements.
        dim_out: width of the block's output (and attention embed dim).
        num_heads: attention heads; must divide dim_out.
    """

    def __init__(self, dim_in, dim_out, num_heads=4):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=dim_out, num_heads=num_heads, batch_first=True)
        wide = dim_out * 2
        self.fc = nn.Sequential(
            nn.Linear(dim_out, wide),
            nn.ReLU(),
            nn.Linear(wide, dim_out),
        )
        self.ln1 = nn.LayerNorm(dim_out)
        self.ln2 = nn.LayerNorm(dim_out)
        if dim_in != dim_out:
            self.proj = nn.Linear(dim_in, dim_out)
        else:
            self.proj = nn.Identity()

    def forward(self, x):
        """Map (batch, set_size, dim_in) to (batch, set_size, dim_out)."""
        h = self.proj(x)  # lift to dim_out so the residual sums line up
        h = self.ln1(h + self.mha(h, h, h)[0])
        return self.ln2(h + self.fc(h))

# --- Pooling by Multihead Attention (PMA) ---
class PMA(nn.Module):
    """Pooling by multi-head attention: learned seed vectors attend over the set.

    Args:
        dim: width of the set elements and of the seeds.
        num_heads: attention heads; must divide dim.
        num_seeds: number of learned query vectors (= number of pooled outputs).
    """

    def __init__(self, dim, num_heads=4, num_seeds=1):
        super().__init__()
        self.seed_vectors = nn.Parameter(torch.randn(num_seeds, dim))
        self.mha = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

    def forward(self, x):
        """Pool (batch, set_size, dim) into (batch, num_seeds, dim)."""
        # one copy of the seed queries per batch element
        queries = self.seed_vectors[None].repeat(x.shape[0], 1, 1)
        pooled = self.mha(queries, x, x)[0]
        return pooled

# --- Full Set Transformer Classifier ---
class PokerHandClassifier_SetTransformer(nn.Module):
    """Set Transformer classifier: two SAB blocks, single-seed PMA, MLP head.

    Args:
        input_dim: width of each card encoding.
        dim_hidden: width used by the SAB/PMA blocks and the head.
        num_heads: attention heads; must divide dim_hidden.
        num_classes: number of hand categories.
    """

    def __init__(self, input_dim=16, dim_hidden=32, num_heads=4, num_classes=11):
        super().__init__()
        self.sab1 = SAB(input_dim, dim_hidden, num_heads)
        self.sab2 = SAB(dim_hidden, dim_hidden, num_heads)
        self.pma = PMA(dim_hidden, num_heads, num_seeds=1)
        self.fc = nn.Sequential(
            nn.Linear(dim_hidden, dim_hidden),
            nn.ReLU(),
            nn.Linear(dim_hidden, num_classes),
        )

    def forward(self, hand_tensor):
        """Return logits (batch, num_classes) for input (batch, num_cards, input_dim)."""
        encoded = self.sab2(self.sab1(hand_tensor))
        summary = self.pma(encoded)[:, 0]  # the single seed's output: (batch, dim_hidden)
        return self.fc(summary)









# ---- card string to card tuple converter ----
def card_str_to_tuple(card_str):
    """Convert a card string to a (suit_id, rank_id) tuple.

    Accepted inputs are "none", "unknown", or "<suit>_<rank>", where suit is
    clubs/diamonds/hearts/spades and rank is 2-9, X (ten), J, Q, K or A.
    "none" maps to (0, 0) and "unknown" to (1, 1); real suits use ids 2-5 and
    real ranks ids 2-14 (ace high).

    Raises:
        KeyError: unrecognised suit or rank.
        ValueError: the string does not split into exactly two parts on "_".
    """
    suit_ids = {"none": 0, "unknown": 1, "clubs": 2, "diamonds": 3, "hearts": 4, "spades": 5}
    rank_ids = {
        "none": 0, "unknown": 1,
        "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9,
        "X": 10, "J": 11, "Q": 12, "K": 13, "A": 14,
    }
    if card_str in ("none", "unknown"):
        # placeholder cards use the same token for suit and rank
        suit_name = rank_name = card_str
    else:
        suit_name, rank_name = card_str.split('_')
    return suit_ids[suit_name], rank_ids[rank_name]

# ---- poker hand classifier (heuristic) ----
def classify_poker_hand(card_tuples):
    """Label a hand of (suit_id, rank_id) tuples with its best poker category.

    Placeholder cards (id 0 = "none", id 1 = "unknown", as produced by
    card_str_to_tuple) are ignored, so partial hands are supported: a straight
    or flush needs five real cards, while pairs/trips/quads only need the
    matching cards.

    Args:
        card_tuples: iterable of (suit_id, rank_id) pairs.

    Returns:
        One of "nothing", "high card", "one pair", "two pair",
        "three of a kind", "straight", "flush", "full house",
        "four of a kind", "straight flush", "royal flush".
    """
    card_tuples = list(card_tuples)  # iterated twice below; also accept generators
    ranks = [rank for _, rank in card_tuples if rank > 1]
    suits = [suit for suit, _ in card_tuples if suit > 1]
    counts = [ranks.count(rank) for rank in set(ranks)]
    sorted_ranks = sorted(ranks)

    # A flush needs five real cards of one suit; without the length check a
    # single visible card (or two suited cards) would be labelled a flush.
    is_flush = len(suits) >= 5 and len(set(suits)) == 1

    # Straight: five consecutive distinct ranks. The ace (14) also plays low in
    # A-2-3-4-5; it sorts last, so that "wheel" is matched explicitly.
    is_straight = len(sorted_ranks) >= 5 and (
        sorted_ranks == list(range(sorted_ranks[0], sorted_ranks[0] + 5))
        or sorted_ranks == [2, 3, 4, 5, 14]
    )

    if is_straight and is_flush:
        # Only 10-J-Q-K-A is royal; an A-2-3-4-5 straight flush also holds an
        # ace but is an ordinary straight flush.
        return "royal flush" if sorted_ranks[0] == 10 else "straight flush"
    if 4 in counts:
        return "four of a kind"
    if 3 in counts and 2 in counts:
        return "full house"
    if is_flush:
        return "flush"
    if is_straight:
        return "straight"
    if 3 in counts:
        return "three of a kind"
    if counts.count(2) >= 2:
        return "two pair"
    if 2 in counts:
        return "one pair"
    if ranks:
        return "high card"
    return "nothing"

# ---- poker hand label to index mapping ----
def poker_hand_label_to_index(label):
    """Return the class index (0-10) for a poker hand label.

    Labels are numbered from weakest ("nothing" = 0) to strongest
    ("royal flush" = 10). Raises KeyError for an unrecognised label.
    """
    ranking = (
        "nothing",
        "high card",
        "one pair",
        "two pair",
        "three of a kind",
        "straight",
        "flush",
        "full house",
        "four of a kind",
        "straight flush",
        "royal flush",
    )
    index_of = {name: position for position, name in enumerate(ranking)}
    return index_of[label]

def reverse_poker_hand_index(index):
    """Return the poker hand label for a class index 0-10.

    Inverse of poker_hand_label_to_index. Raises KeyError for an index outside
    0-10 (negative indices are NOT wrapped around).
    """
    ranking = (
        "nothing",
        "high card",
        "one pair",
        "two pair",
        "three of a kind",
        "straight",
        "flush",
        "full house",
        "four of a kind",
        "straight flush",
        "royal flush",
    )
    label_of = dict(enumerate(ranking))
    return label_of[index]

# ---- convert card strings to tensor ----
def card_strings_to_tensor(card_strings):
    """Convert card strings to a long tensor of shape (len(card_strings), 2).

    Each row is the (suit_id, rank_id) pair returned by card_str_to_tuple.
    """
    return torch.tensor(list(map(card_str_to_tuple, card_strings)), dtype=torch.long)

if __name__ == "__main__":
    # ---- training set ----
    # One hand per category, plus hands padded with "none" / "unknown" cards.
    card_strings = [
        ["hearts_A", "diamonds_K", "clubs_Q", "spades_J", "none"], # high card
        ["hearts_A", "hearts_K", "hearts_Q", "hearts_J", "hearts_X"], # royal flush
        ["spades_9", "spades_8", "spades_7", "spades_6", "spades_5"], # straight flush
        ["diamonds_3", "clubs_3", "hearts_3", "spades_3", "hearts_9"], # four of a kind
        ["hearts_4", "spades_4", "clubs_4", "diamonds_9", "hearts_9"], # full house
        ["clubs_2", "clubs_5", "clubs_9", "clubs_J", "clubs_K"], # flush
        ["hearts_6", "spades_5", "diamonds_4", "clubs_3", "hearts_2"], # straight
        ["spades_7", "hearts_7", "clubs_7", "diamonds_2", "hearts_5"], # three of a kind
        ["diamonds_8", "clubs_8", "hearts_4", "spades_4", "hearts_K"], # two pair
        ["clubs_J", "spades_J", "hearts_3", "diamonds_6", "clubs_9"], # one pair
        ["hearts_5", "spades_X", "clubs_7", "diamonds_4", "hearts_2"], # high card
        ["hearts_5", "spades_5", "none", "none", "none"], # one pair with 3 empty cards
        ["clubs_9", "diamonds_9", "hearts_4", "spades_4", "unknown"], # two pair and 1 unknown card
        ["none", "none", "none", "none", "none"], # empty hand
    ]

    # Stack the hands into one (num_hands, 5, 2) tensor of (suit, rank) ids.
    card_tensors = torch.stack([card_strings_to_tensor(hand) for hand in card_strings])
    # Label each hand with the rule-based classifier's class index.
    labels = torch.tensor(
        [
            poker_hand_label_to_index(
                classify_poker_hand([card_str_to_tuple(cs) for cs in hand])
            )
            for hand in card_strings
        ],
        dtype=torch.long,
    )
    print("Card Tensors:\n", card_tensors)
    print("Labels:\n", labels)

    encoder = CardEncoder()
    # ---- choose classifier ----
    classifier = PokerHandClassifier()

    # Untrained forward pass: encode all cards at once, then regroup per hand.
    encoded_cards = encoder(card_tensors.view(-1, 2))
    encoded_hands = encoded_cards.view(card_tensors.size(0), 5, -1)
    outputs = classifier(encoded_hands)
    print("Classifier Outputs:\n", outputs)


Card Tensors:
 tensor([[[ 4, 14],
         [ 3, 13],
         [ 2, 12],
         [ 5, 11],
         [ 0,  0]],

        [[ 4, 14],
         [ 4, 13],
         [ 4, 12],
         [ 4, 11],
         [ 4, 10]],

        [[ 5,  9],
         [ 5,  8],
         [ 5,  7],
         [ 5,  6],
         [ 5,  5]],

        [[ 3,  3],
         [ 2,  3],
         [ 4,  3],
         [ 5,  3],
         [ 4,  9]],

        [[ 4,  4],
         [ 5,  4],
         [ 2,  4],
         [ 3,  9],
         [ 4,  9]],

        [[ 2,  2],
         [ 2,  5],
         [ 2,  9],
         [ 2, 11],
         [ 2, 13]],

        [[ 4,  6],
         [ 5,  5],
         [ 3,  4],
         [ 2,  3],
         [ 4,  2]],

        [[ 5,  7],
         [ 4,  7],
         [ 2,  7],
         [ 3,  2],
         [ 4,  5]],

        [[ 3,  8],
         [ 2,  8],
         [ 4,  4],
         [ 5,  4],
         [ 4, 13]],

        [[ 2, 11],
         [ 5, 11],
         [ 4,  3],
         [ 3,  6],
         [ 2,  9]],

        [[ 4,  

In [None]:
# ---- train loop ----
def train_loop(model, data_loader, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch_cards, batch_labels in data_loader:
            optimizer.zero_grad()
            encoded_cards = encoder(batch_cards.view(-1, 2))
            encoded_hands = encoded_cards.view(batch_cards.size(0), 5, -1)
            outputs = model(encoded_hands)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(data_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

# Example of setting up a data loader and training
from torch.utils.data import DataLoader, TensorDataset
dataset = TensorDataset(card_tensors, labels)
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)
criterion = nn.CrossEntropyLoss()
# train_loop runs every batch through `encoder` before the classifier, so the
# encoder's embeddings must be optimized too; with only classifier.parameters()
# they stay at their random initialization (and their grads are never zeroed).
optimizer = torch.optim.Adam(
    list(classifier.parameters()) + list(encoder.parameters()), lr=0.001
)
train_loop(classifier, data_loader, criterion, optimizer, num_epochs=100)


In [5]:
# ---- print tensor hands, expected output, and classifier outputs ----
for idx, hand in enumerate(card_tensors):
    hand_tensor = hand.unsqueeze(0)  # add a batch dim of 1: (1, 5, 2)
    expected_label = int(labels[idx])
    encoded_cards = encoder(hand_tensor.view(-1, 2))
    encoded_hand = encoded_cards.view(1, 5, -1)
    outputs = classifier(encoded_hand)
    predicted_label = outputs.argmax(dim=1).item()
    print(f"Hand Tensor: {hand_tensor}")
    print(f"Expected Label: [{expected_label}] {reverse_poker_hand_index(expected_label)}, Predicted Label: [{predicted_label}] {reverse_poker_hand_index(predicted_label)}")
    print(f"Classifier Outputs: {outputs}\n")

Hand Tensor: tensor([[[ 4, 14],
         [ 3, 13],
         [ 2, 12],
         [ 5, 11],
         [ 0,  0]]])
Expected Label: [1] high card, Predicted Label: [1] high card
Classifier Outputs: tensor([[-2.5744,  4.0819, -3.4891, -3.8336, -3.7617, -3.5863, -4.1415, -5.0651,
         -1.9476, -5.8475,  0.4197]], grad_fn=<AddmmBackward0>)

Hand Tensor: tensor([[[ 4, 14],
         [ 4, 13],
         [ 4, 12],
         [ 4, 11],
         [ 4, 10]]])
Expected Label: [10] royal flush, Predicted Label: [10] royal flush
Classifier Outputs: tensor([[-0.8615, -0.0085, -2.1830, -1.9596, -3.2149, -3.1581, -5.0357, -5.0554,
         -2.6598, -4.8811,  3.2399]], grad_fn=<AddmmBackward0>)

Hand Tensor: tensor([[[5, 9],
         [5, 8],
         [5, 7],
         [5, 6],
         [5, 5]]])
Expected Label: [9] straight flush, Predicted Label: [9] straight flush
Classifier Outputs: tensor([[-3.8701, -2.2778, -0.2952, -1.5848, -0.2915, -3.1198, -3.6149, -1.9124,
         -2.1781,  3.7862, -3.7009]], grad_fn

In [None]:
# ---- function that outputs the softmax probability for each class >= 1.00%, formatted to .00% and sorted largest to smallest % ----
def get_class_probabilities(model, hand_tensor, card_encoder=None, min_prob=0.01):
    """Return the model's class probabilities for one hand, largest first.

    Only classes with probability >= ``min_prob`` (default 1%) are kept, and
    each value is formatted as a percentage string with two decimals.

    Args:
        model: hand classifier taking (1, num_cards, encoding_dim) input.
        hand_tensor: long tensor with one hand's (suit, rank) id pairs,
            e.g. shape (1, 5, 2) as built by the callers in this notebook.
        card_encoder: card encoder to use; defaults to the notebook-level
            ``encoder``.
        min_prob: probability threshold below which classes are dropped.

    Returns:
        dict mapping hand label -> formatted percentage (e.g. "97.53%"),
        ordered by descending probability.

    The model is run in eval mode without gradient tracking; its previous
    train/eval mode is restored before returning.
    """
    if card_encoder is None:
        card_encoder = encoder  # notebook-level global, kept for backward compatibility
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            encoded_cards = card_encoder(hand_tensor.view(-1, 2))
            # one hand: (1, num_cards, encoding_dim)
            encoded_hand = encoded_cards.view(1, -1, encoded_cards.size(-1))
            outputs = model(encoded_hand)
            probabilities = F.softmax(outputs, dim=1).squeeze(0)
    finally:
        model.train(was_training)  # don't leave the caller's model in eval mode
    significant = [
        (reverse_poker_hand_index(i), p)
        for i, p in enumerate(probabilities.tolist())
        if p >= min_prob
    ]
    significant.sort(key=lambda item: item[1], reverse=True)  # stable: ties keep class order
    return {label: f"{p*100:.2f}%" for label, p in significant}

# ---- test to see if the model can correctly predict each royal flush hand ----
royal_flush_hands = [
    ["hearts_A", "hearts_K", "hearts_Q", "hearts_J", "hearts_X"],
    ["diamonds_A", "diamonds_K", "diamonds_Q", "diamonds_J", "diamonds_X"],
    ["clubs_A", "clubs_K", "clubs_Q", "clubs_J", "clubs_X"],
    ["spades_A", "spades_K", "spades_Q", "spades_J", "spades_X"],
]

for hand in royal_flush_hands:
    hand_tensor = card_strings_to_tensor(hand).unsqueeze(0)  # (1, 5, 2)
    # get_class_probabilities encodes the hand and runs the classifier itself
    # (under no_grad), so no separate forward pass is needed here.
    probabilities = get_class_probabilities(classifier, hand_tensor)
    print(f"\nhand: {hand}")
    print(f"class probabilities: {probabilities}")


In [None]:
# Probe: A-K-Q-J of mixed suits plus one empty slot. Despite the variable name
# (reused from the previous cell) this is not a royal flush; classify_poker_hand
# labels it "high card". It holds the same cards as the first training hand,
# with the empty slot moved to the front, so it also shows whether the chosen
# classifier is sensitive to card order.
royal_flush_hands = [
    ["none", "hearts_A", "diamonds_K", "clubs_Q", "spades_J"],
]

for hand in royal_flush_hands:
    hand_tensor = card_strings_to_tensor(hand).unsqueeze(0)  # (1, 5, 2)
    # get_class_probabilities does its own encoding and forward pass.
    probabilities = get_class_probabilities(classifier, hand_tensor)
    print(f"\nhand: {hand}")
    print(f"class probabilities: {probabilities}")