In [75]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

In [76]:
from poker_embeddings.models.card import HandGNN

In [77]:
class HandScorer(nn.Module):
    """Map a hand graph to a single scalar score: HandGNN encoder + MLP head.

    Args:
        rank_embedding_dim, suit_embedding_dim, hidden_dim, edge_attr_dim,
        node_mlp_layers, gnn_layers, reduction: passed straight through to
            ``HandGNN``.
        final_hidden_dim: width of the first head layer; the second head
            layer uses half of it (integer division).
        out_dim: width of the last layer in ``final``; it feeds a 1-unit
            ``output_layer``.

    The attribute names ``hand_encoder`` / ``final`` / ``output_layer`` and the
    layer order inside ``final`` fix the state_dict keys (e.g.
    ``hand_encoder.*``, ``final.0.weight``, ``final.4.bias``), so keep them
    stable for checkpoint loading.
    """

    def __init__(self,
                 rank_embedding_dim=8,
                 suit_embedding_dim=8,
                 hidden_dim=16,
                 edge_attr_dim=2,
                 node_mlp_layers=2,
                 gnn_layers=2,
                 reduction='mean',
                 final_hidden_dim=64,
                 out_dim=16):
        super().__init__()
        encoder_kwargs = dict(
            rank_embedding_dim=rank_embedding_dim,
            suit_embedding_dim=suit_embedding_dim,
            hidden_dim=hidden_dim,
            edge_attr_dim=edge_attr_dim,
            node_mlp_layers=node_mlp_layers,
            gnn_layers=gnn_layers,
            reduction=reduction,
        )
        self.hand_encoder = HandGNN(**encoder_kwargs)

        # NOTE(review): the head's input width assumes HandGNN emits
        # `hidden_dim` features per graph -- confirm against HandGNN.
        half_width = final_hidden_dim // 2
        head_layers = [
            nn.Linear(hidden_dim, final_hidden_dim),
            nn.ReLU(),
            nn.Linear(final_hidden_dim, half_width),
            nn.ReLU(),
            nn.Linear(half_width, out_dim),
        ]
        self.final = nn.Sequential(*head_layers)
        # NOTE(review): no activation sits between final[-1] and this layer,
        # so the two Linear maps compose into a single affine map.
        self.output_layer = nn.Linear(out_dim, 1)

    def forward(self, data):
        """Encode ``data`` with the GNN and return the head's 1-unit output."""
        embedding = self.hand_encoder(data)
        head_features = self.final(embedding)
        return self.output_layer(head_features)


In [78]:
# Build the scorer and warm-start its encoder from the pretrained hand-rank
# checkpoint, then freeze the encoder so only the head is trainable.
HAND_RANK_CHECKPOINT = "model_weights/hand_rank_model/hand_rank_predictor225.pth"
ENCODER_PREFIX = "hand_encoder."

# map_location="cpu": the freshly built HandScorer lives on the CPU, and this
# lets a checkpoint saved from a GPU run load on a CPU-only machine.
# torch.load unpickles the file -- only load checkpoints you trust.
full_state_dict = torch.load(HAND_RANK_CHECKPOINT, map_location="cpu")
model = HandScorer()

# Keep only encoder tensors and strip the *leading* prefix; str.replace would
# also rewrite any later occurrence of "hand_encoder." inside a key.
encoder_weights = {k[len(ENCODER_PREFIX):]: v
                   for k, v in full_state_dict.items()
                   if k.startswith(ENCODER_PREFIX)}

model.hand_encoder.load_state_dict(encoder_weights)
# NOTE(review): requires_grad=False stops weight updates but does not switch
# dropout/batch-norm to eval mode; if HandGNN has such layers, call
# model.hand_encoder.eval() after any model.train().
for param in model.hand_encoder.parameters():
    param.requires_grad = False

In [79]:
# Show each parameter's trainability to confirm the encoder is frozen and
# only the head (final.*, output_layer.*) will be updated.
for param_name, tensor in model.named_parameters():
    print(f"{param_name}: requires_grad={tensor.requires_grad}")

hand_encoder.rank_embedder.weight: requires_grad=False
hand_encoder.suit_embedder.weight: requires_grad=False
hand_encoder.node_mlp_layers.0.weight: requires_grad=False
hand_encoder.node_mlp_layers.0.bias: requires_grad=False
hand_encoder.node_mlp_layers.2.weight: requires_grad=False
hand_encoder.node_mlp_layers.2.bias: requires_grad=False
hand_encoder.card_emb_projector.weight: requires_grad=False
hand_encoder.card_emb_projector.bias: requires_grad=False
hand_encoder.gnn_layers.0.lin.weight: requires_grad=False
hand_encoder.gnn_layers.0.lin.bias: requires_grad=False
hand_encoder.gnn_layers.1.lin.weight: requires_grad=False
hand_encoder.gnn_layers.1.lin.bias: requires_grad=False
final.0.weight: requires_grad=True
final.0.bias: requires_grad=True
final.2.weight: requires_grad=True
final.2.bias: requires_grad=True
final.4.weight: requires_grad=True
final.4.bias: requires_grad=True
output_layer.weight: requires_grad=True
output_layer.bias: requires_grad=True


In [80]:
from torch.utils.data import Dataset

In [81]:
from poker_embeddings.poker_utils.datasets import UCIrvineDataset
import random

In [82]:
class PairwiseHandDataset(Dataset):
    """Pair every hand in ``base_dataset`` with a random *different* hand.

    Each item's ``y[0, 1]`` is read as its Treys score (smaller = stronger).
    ``__getitem__(i)`` returns ``(x1, x2, label)``. Here ``x1`` is item ``i``
    and ``x2`` is drawn uniformly from the other items. The label is:

        label =  1  if x1 is stronger (score1 < score2)
        label = -1  if x2 is stronger (score1 > score2)
        label =  0  if the two scores tie

    A tie target of 0 contributes zero gradient under ``nn.MarginRankingLoss``.
    Under a ``(label + 1) / 2`` conversion for BCE-style losses it maps to 0.5.

    Raises:
        ValueError: if ``base_dataset`` has fewer than two items, since no
            distinct partner can be drawn.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.length = len(base_dataset)
        # With < 2 items there is no "other" hand; sampling could never finish.
        if self.length < 2:
            raise ValueError(
                "PairwiseHandDataset needs at least 2 hands to form pairs, "
                f"got {self.length}"
            )

    def __len__(self):
        return self.length

    def __getitem__(self, idx1):
        # Normalize negative indices so the partner-exclusion below compares
        # positions rather than raw index values.
        if idx1 < 0:
            idx1 += self.length
        x1 = self.base_dataset[idx1]
        # Draw uniformly from the (length - 1) other positions by skipping
        # over idx1: O(1), no rejection-sampling retry loop.
        ix2 = random.randrange(self.length - 1)
        if ix2 >= idx1:
            ix2 += 1
        x2 = self.base_dataset[ix2]
        score1 = x1.y[0, 1].item()
        score2 = x2.y[0, 1].item()
        # Smaller is stronger in Treys. A tie must not be reported as -1,
        # which would claim x2 is the stronger hand.
        if score1 < score2:
            label = 1
        elif score1 > score2:
            label = -1
        else:
            label = 0
        return x1, x2, label

In [83]:
# Load the UC Irvine feature (X) and target (y) CSVs as DataFrames.
X, y = pd.read_csv("data/uc_irvine/X.csv"), pd.read_csv("data/uc_irvine/y.csv")

In [84]:
# Per-hand graph dataset, then wrapped so each item is (hand, other_hand, label).
base_data = UCIrvineDataset(X, y)
dataset = PairwiseHandDataset(base_dataset=base_data)

In [85]:
import torch_geometric as tg

In [86]:
# Sample one pair for inspection; the partner hand is drawn at random, so this
# differs between runs unless `random` is seeded.
hmm = dataset[0]

In [87]:
# Display the sampled (x1 graph, x2 graph, label) tuple.
hmm

(Data(x=[6], edge_index=[2, 30], edge_attr=[30, 2], y=[1, 2]),
 Data(x=[5], edge_index=[2, 20], edge_attr=[20, 2], y=[1, 2]),
 1)

In [88]:
# PyG DataLoader batches the (x1, x2, label) tuples. shuffle=True makes each
# epoch visit anchor hands in a fresh order rather than always following the
# CSV row order (partners are already random per item).
trainloader = tg.loader.DataLoader(dataset, batch_size=64, shuffle=True)

In [90]:
# Pull a single batch from the loader to inspect how it is collated.
what = next(iter(trainloader))

In [93]:
# dtype of the collated label batch (last element of the batch).
# NOTE(review): ranking losses such as nn.MarginRankingLoss are usually fed
# float targets -- cast with .float() before the loss if needed.
what[-1].dtype

torch.int64