In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as tg

from poker_embeddings.poker_utils.model import plot_train_loss, benchmark_dataloader
from poker_embeddings.poker_utils.constants import DECK_DICT
from poker_embeddings.poker_utils.datasets import UCIrvineDataset

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
X = pd.read_csv("../data/uc_irvine/X.csv")
y = pd.read_csv("../data/uc_irvine/y.csv")
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.6, random_state=29, stratify=y['CLASS']
    )
X_train = X_train.reset_index(drop=True)
y_train = y_train.reset_index(drop=True)
X_val = X_val.reset_index(drop=True)
y_val = y_val.reset_index(drop=True)

In [None]:
train_dataset = UCIrvineDataset(X_train, y_train, add_random_cards=True, use_card_ids=True,
                           graph=True, normalize_x=True)
val_dataset = UCIrvineDataset(X_val, y_val, add_random_cards=True, use_card_ids=True,
                           graph=True, normalize_x=True)


In [None]:
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv


In [None]:
trainloader = tg.loader.DataLoader(
    train_dataset,
    batch_size=512,
    shuffle=True,
    num_workers=10,
    pin_memory=True
    )
valloader = tg.loader.DataLoader(
    val_dataset,
    batch_size=512,
    shuffle=False,
    num_workers=10,
    pin_memory=True
    )

In [6]:
def create_board_edges(card_ids):
    ranks = card_ids // 4
    suits = card_ids % 4
    num_cards = card_ids.size(0)
    idx_i = card_ids.unsqueeze(1).expand(num_cards, num_cards)
    idx_j = card_ids.unsqueeze(0).expand(num_cards, num_cards)

    ranks_i = ranks.unsqueeze(1).expand(num_cards, num_cards)
    ranks_j = ranks.unsqueeze(0).expand(num_cards, num_cards)

    suits_i = suits.unsqueeze(1).expand(num_cards, num_cards)
    suits_j = suits.unsqueeze(0).expand(num_cards, num_cards)

    suit_match = suits_i == suits_j
    rank_close = (torch.abs(ranks_i - ranks_j) <= 1) | (torch.abs(ranks_i - ranks_j) == 12)

    not_self = idx_i != idx_j

    edge_mask = (suit_match | rank_close) & not_self

    edge_index = edge_mask.nonzero(as_tuple=False).t().contiguous()
    return edge_index

In [18]:
from poker_embeddings.poker_utils.constants import DECK_DICT

In [20]:
create_board_edges(torch.tensor([17,13]))

tensor([[0, 1],
        [1, 0]])

In [21]:
create_board_edges(torch.tensor([17,12]))

tensor([[0, 1],
        [1, 0]])

In [19]:
DECK_DICT

{0: '2c',
 1: '2d',
 2: '2h',
 3: '2s',
 4: '3c',
 5: '3d',
 6: '3h',
 7: '3s',
 8: '4c',
 9: '4d',
 10: '4h',
 11: '4s',
 12: '5c',
 13: '5d',
 14: '5h',
 15: '5s',
 16: '6c',
 17: '6d',
 18: '6h',
 19: '6s',
 20: '7c',
 21: '7d',
 22: '7h',
 23: '7s',
 24: '8c',
 25: '8d',
 26: '8h',
 27: '8s',
 28: '9c',
 29: '9d',
 30: '9h',
 31: '9s',
 32: 'Tc',
 33: 'Td',
 34: 'Th',
 35: 'Ts',
 36: 'Jc',
 37: 'Jd',
 38: 'Jh',
 39: 'Js',
 40: 'Qc',
 41: 'Qd',
 42: 'Qh',
 43: 'Qs',
 44: 'Kc',
 45: 'Kd',
 46: 'Kh',
 47: 'Ks',
 48: 'Ac',
 49: 'Ad',
 50: 'Ah',
 51: 'As'}

In [None]:
def create_board_graph(cards_id, y):
    if use_card_ids:
        x = cards_id.unsqueeze(1)
    else:
        rank = cards_id // 4
        suit = cards_id % 4
        if normalize_x:
            rank = rank / 12.0
            suit = suit / 3.0
        x = torch.stack([rank, suit], dim=1)

    edge_index = create_board_edges(cards_id)

    data = tg.data.Data(x=x, edge_index=edge_index, y=y)
    return data

In [5]:
create_board_graph(torch.tensor([1,2,3,4]), torch.randn(4, 10), 1)

Data(x=[4, 10], edge_index=[2, 12], y=1)

In [None]:
class GraphGenerator(nn.Module):
    def __init__(self, card_embedding_dim, in_channels, out_channels):
        super(GraphGenerator, self).__init__()
        self.card_embedder = nn.Embedding(53, card_embedding_dim, padding_idx=52)

    def forward(self, board_graphs)

In [None]:
class CardGNN(torch.nn.Module):
    def __init__(self, card_embeddings=None, in_dim=2, hidden_dim=16, out_dim=16, freeze_emb=True):
        super().__init__()
        self.freeze_emb = freeze_emb
        if card_embeddings is not None:
            padding_row = torch.zeros((1, card_embeddings.size(1)))
            card_embeddings_padded = torch.cat([card_embeddings, padding_row], dim=0)
            self.card_embedder = nn.Embedding.from_pretrained(
                card_embeddings_padded, padding_idx=52, freeze=freeze_emb)

        if card_embeddings is not None:
            self.gcn1 = GCNConv(card_embeddings.size(1), hidden_dim, add_self_loops=True)
            self.use_card_emb = True
        else:
            self.gcn1 = GCNConv(in_dim, hidden_dim, add_self_loops=True)
            self.use_card_emb = False
        self.gcn2 = GCNConv(hidden_dim, out_dim, add_self_loops=True)
        self.out_dim = out_dim

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        if self.use_card_emb:
            x = self.card_embedder(x.squeeze(-1))

        x = self.gcn1(x, edge_index)
        x = torch.relu(x)
        x = self.gcn2(x, edge_index)
        return x