In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as tg
from torch_geometric.nn import GINEConv
from poker_embeddings.poker_utils.model import plot_train_loss, benchmark_dataloader
from poker_embeddings.poker_utils.constants import DECK_DICT
from poker_embeddings.poker_utils.datasets import UCIrvineDataset

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
X = pd.read_csv("../data/uc_irvine/X.csv")
y = pd.read_csv("../data/uc_irvine/y.csv")

# Stratified split on CLASS. Note test_size=0.6: 60% of rows go to validation,
# 40% to training. Every part gets a fresh 0..n-1 index.
X_train, X_val, y_train, y_val = (
    part.reset_index(drop=True)
    for part in train_test_split(
        X, y, test_size=0.6, random_state=29, stratify=y['CLASS']
    )
)

# Both splits are built with identical dataset options.
dataset_kwargs = dict(add_random_cards=True, use_card_ids=True,
                      graph=True, normalize_x=True)
train_dataset = UCIrvineDataset(X_train, y_train, **dataset_kwargs)
val_dataset = UCIrvineDataset(X_val, y_val, **dataset_kwargs)

In [None]:
# Sweep batch size x worker count once on the training set (graph batches);
# presumably used to pick the DataLoader settings in the next cell.
benchmark_dataloader(
    train_dataset,
    batch_sizes=[128, 256, 512],
    num_workers_list=[4, 8, 10],
    num_runs=1,
    graph=True,
)

Dataset size: 410004 samples


In [None]:
# Settings shared by both loaders. batch_size / num_workers come from values
# covered by the benchmark_dataloader sweep above.
# Pinned host memory only helps when batches are copied to a CUDA device.
# On a CPU-only run it is useless, and PyTorch warns about it.
LOADER_KWARGS = dict(
    batch_size=512,
    num_workers=10,
    pin_memory=device.type == "cuda",
)
trainloader = tg.loader.DataLoader(train_dataset, shuffle=True, **LOADER_KWARGS)
# Validation order is kept fixed so per-epoch metrics are comparable.
valloader = tg.loader.DataLoader(val_dataset, shuffle=False, **LOADER_KWARGS)

In [66]:

class CardGNN(nn.Module):
    """Graph classifier over card nodes producing 10 class logits per graph.

    Pipeline: card-id embedding -> linear projection -> GINEConv -> ReLU ->
    GINEConv -> linear -> per-graph mean pooling -> linear to 10 logits.

    Args:
        card_emb_dim: size of the learned card embedding (53 ids; id 52 is the
            padding index).
        hidden_dim: feature width used inside the GINE layers.
        out_dim: per-node feature width before pooling.
        edge_attr_dim: number of edge-feature columns in ``data.edge_attr``.
        share_mlp: if True, both GINE layers reuse a single MLP (tied weights).
            This reproduces the earlier version of this class. The default
            (False) gives each layer its own MLP.
    """

    def __init__(self, card_emb_dim=16, hidden_dim=16, out_dim=16, edge_attr_dim=2,
                 share_mlp=False):
        super().__init__()
        self.card_embedder = nn.Embedding(53, card_emb_dim, padding_idx=52)

        mlp1 = self._make_node_mlp(hidden_dim)
        # Previously one MLP instance was passed to both convolutions, which
        # silently tied their weights. Build a second one unless tying is requested.
        mlp2 = mlp1 if share_mlp else self._make_node_mlp(hidden_dim)
        # Alias of gine1's MLP, registered in the original position. This keeps
        # state_dict keys (``node_mlp.*``) and parameter order compatible with
        # checkpoints saved from the tied-weight version.
        self.node_mlp = mlp1

        self.card_emb_projector = nn.Linear(card_emb_dim, hidden_dim)
        self.gine1 = GINEConv(nn=mlp1, edge_dim=edge_attr_dim)
        self.gine2 = GINEConv(nn=mlp2, edge_dim=edge_attr_dim)
        self.final = nn.Linear(hidden_dim, out_dim)
        self.output_layer = nn.Linear(out_dim, 10)

    @staticmethod
    def _make_node_mlp(hidden_dim):
        """Build the GINE update MLP: Linear -> ReLU -> Linear, hidden_dim -> hidden_dim."""
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, data):
        """Return ``(num_graphs, 10)`` logits for a PyG (batched) graph.

        ``data.x`` is looked up in the card embedding, so it must hold integer
        ids in [0, 52]. NOTE(review): assumes ``data.x`` is 1-D (one id per
        node); a ``(num_nodes, 1)`` tensor would add an extra axis — confirm
        against UCIrvineDataset.
        """
        edge_index, edge_attr = data.edge_index, data.edge_attr

        x = self.card_emb_projector(self.card_embedder(data.x))
        x = F.relu(self.gine1(x, edge_index, edge_attr))
        x = self.gine2(x, edge_index, edge_attr)
        x = self.final(x)

        # Same result as scatter(..., reduce='mean') over data.batch, but it
        # also accepts batch=None (a single unbatched graph).
        graphs_pooled = tg.nn.global_mean_pool(x, data.batch)
        return self.output_layer(graphs_pooled)


In [None]:
def _run_epoch(model, loader, class_weights, device=None, optimizer=None):
    """Make one pass over ``loader`` and return ``(avg_loss, accuracy)``.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch. Without one it runs in eval mode with gradients disabled.

    ``avg_loss`` is the mean of the per-batch class-weighted cross-entropy
    losses. ``accuracy`` is correct argmax predictions over all samples seen.
    Both are NaN when the loader yields nothing, instead of raising
    ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, n_batches, n_correct, n_seen = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for batch_data in loader:
            if device is not None:
                batch_data = batch_data.to(device)
            if training:
                optimizer.zero_grad()

            logits = model(batch_data)
            batch_loss = F.cross_entropy(logits, batch_data.y, weight=class_weights)

            if training:
                batch_loss.backward()
                optimizer.step()

            loss_sum += batch_loss.item()
            n_batches += 1
            n_correct += (logits.argmax(dim=1) == batch_data.y).sum().item()
            n_seen += batch_data.y.size(0)

    avg_loss = loss_sum / n_batches if n_batches else float("nan")
    accuracy = n_correct / n_seen if n_seen else float("nan")
    return avg_loss, accuracy


def train_model(model, trainloader, optimizer, scheduler=None, device=None,
                valloader=None, epochs=50, leftoff=0, save=True,
                class_weights=None,
                class_weights_path="../model_weights/class_weights.pt",
                checkpoint_every=5,
                checkpoint_prefix="../model_weights/hand_strength_predictor"):
    """Train ``model`` with class-weighted cross-entropy, optionally validating each epoch.

    Args:
        model: module mapping a batch to ``(batch, num_classes)`` logits.
        trainloader: iterable of batches exposing ``.y`` and ``.to(device)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler, stepped once at the end of each epoch.
        device: where batches and class weights are moved. None leaves them
            where they are.
        valloader: optional validation iterable, evaluated after every epoch.
        epochs: number of epochs to run in this call.
        leftoff: epoch offset added to checkpoint file names when resuming.
        save: whether to write periodic checkpoints.
        class_weights: per-class loss-weight tensor. When None it is loaded
            from ``class_weights_path`` (the original behaviour).
        class_weights_path: file passed to ``torch.load`` when
            ``class_weights`` is None.
        checkpoint_every: save ``model.state_dict()`` whenever the local epoch
            count (within this call) is a multiple of this value.
        checkpoint_prefix: checkpoint path prefix. The global epoch number
            ``leftoff + epoch`` and ``.pth`` are appended.

    Returns:
        dict with per-epoch ``train_loss`` and ``train_accuracy`` lists, plus
        ``val_loss`` and ``val_accuracy`` when ``valloader`` is given.
    """
    if class_weights is None:
        class_weights = torch.load(class_weights_path, weights_only=True)
    if device is not None:
        class_weights = class_weights.to(device)

    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, trainloader, class_weights,
                                           device=device, optimizer=optimizer)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        msg = (f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, "
               f"Train Acc: {train_acc:.4f}")

        if valloader is not None:
            val_loss, val_acc = _run_epoch(model, valloader, class_weights, device=device)
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)
            msg += f", Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        print(msg)

        # Cadence counts epochs within this call; the file name uses the
        # global epoch (offset by `leftoff`).
        if save and checkpoint_every > 0 and (epoch + 1) % checkpoint_every == 0:
            torch.save(model.state_dict(), f"{checkpoint_prefix}{leftoff+epoch+1}.pth")

        if scheduler is not None:
            scheduler.step()

    if valloader is not None:
        return {"train_loss": train_losses,
                "val_loss": val_losses,
                "train_accuracy": train_accuracies,
                "val_accuracy": val_accuracies}
    return {"train_loss": train_losses, "train_accuracy": train_accuracies}

In [None]:
# Build the GNN with its default hyperparameters, then move it to `device`
# (nn.Module.to returns the same module).
model = CardGNN()
model = model.to(device)

In [None]:
# Adam at lr=1e-4. StepLR halves the learning rate every 50 scheduler steps,
# and train_model steps the scheduler once per epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=50)

In [None]:
# Single-epoch run without validation.
# NOTE: train_model writes a checkpoint only every 5th epoch, so with epochs=1
# save=True produces no file. Save model.state_dict() manually if needed.
res = train_model(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    trainloader=trainloader,
    valloader=None,
    device=device,
    epochs=1,
    leftoff=0,
    save=True,
)