In [16]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from uc_data import UCIrvineDataset
import torch.nn.functional as F
from poker_utils.model import plot_train_loss

In [None]:
# Run on the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# Train/validation splits from UCIrvineDataset; only the training batches are shuffled.
train_dataset, val_dataset = UCIrvineDataset(train=True), UCIrvineDataset(train=False)
trainloader, valloader = (
    DataLoader(train_dataset, batch_size=128, shuffle=True),
    DataLoader(val_dataset, batch_size=128, shuffle=False),
)

In [17]:
class HandSuitEncoder(nn.Module):
    """Encode a hand from a pretrained hand id embedding plus two suit embeddings.

    The hand id is looked up in a table initialised from ``preflop_embeddings``.
    Each suit id is looked up in a small learned table with 4 entries. The
    concatenation ``[hand_emb, suit1_emb, suit2_emb]`` is passed through a
    two-layer MLP.

    Args:
        preflop_embeddings: 2-D float tensor ``[num_hands, emb_dim]`` used as
            the initial hand embedding table. Any ``emb_dim`` is accepted.
        suit_dim: size of each learned suit embedding.
        output_dim: size of the encoded hand vector.
        freeze_hand_emb: if True, the pretrained hand table is not updated
            during training.
    """

    def __init__(self, preflop_embeddings, suit_dim=4, output_dim=16, freeze_hand_emb=True):
        super().__init__()
        self.suit_embedder = nn.Embedding(4, suit_dim)
        self.hand_embedder = nn.Embedding.from_pretrained(preflop_embeddings, freeze=freeze_hand_emb)
        # Take the input width from the pretrained table instead of assuming
        # 16 columns; a different-width table previously caused a shape error.
        hand_emb_dim = self.hand_embedder.embedding_dim
        self.model = nn.Sequential(
            nn.Linear(hand_emb_dim + suit_dim * 2, 32),
            nn.LeakyReLU(),
            nn.Linear(32, output_dim),
        )

    def forward(self, hand_id, suit1_id, suit2_id):
        """Return the encoded hand, shape ``[*batch, output_dim]``.

        ``hand_id``, ``suit1_id`` and ``suit2_id`` are integer index tensors of
        the same shape (presumably 1-D of length B in this notebook; verify
        against the dataset).
        """
        preflop_emb = self.hand_embedder(hand_id)
        suit1_emb = self.suit_embedder(suit1_id)
        suit2_emb = self.suit_embedder(suit2_id)
        # Concatenate on the feature (last) axis so any leading batch shape works;
        # for 1-D ids this is identical to dim=1.
        x = torch.cat([preflop_emb, suit1_emb, suit2_emb], dim=-1)
        return self.model(x)
    

In [None]:
class TransformerBoardEncoder(nn.Module):
    """Encode a variable-length, padded board of cards into a single vector.

    Card ids are embedded and a learned CLS token is prepended. Learned
    position embeddings are added, and the sequence is run through a
    Transformer encoder with padded slots masked out of attention. The CLS
    output is projected to ``board_dim``.

    Args:
        card_dim: card/position embedding size and Transformer ``d_model``.
            Must be divisible by ``num_heads``.
        board_dim: size of the returned board vector.
        num_heads: attention heads per encoder layer.
        num_layers: number of encoder layers.
        max_cards: longest board accepted. The position table has
            ``max_cards + 1`` rows (one extra for the CLS slot).
        dim_feedforward: hidden size of each encoder layer's feed-forward block.
    """

    def __init__(self, card_dim=16, board_dim=32, num_heads=4, num_layers=2,
                 max_cards=5, dim_feedforward=64):
        super().__init__()
        self.max_cards = max_cards

        self.card_embedder = nn.Embedding(52, card_dim)
        self.position_embedder = nn.Embedding(max_cards + 1, card_dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, card_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=card_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.output_proj = nn.Linear(card_dim, board_dim)

    def forward(self, board_ids):
        """Encode a batch of boards.

        Args:
            board_ids: integer tensor ``[B, L]`` with card ids in ``[0, 52)``
                and ``-1`` marking padding. Requires ``L <= max_cards``.

        Returns:
            Tensor ``[B, board_dim]``: the projected CLS output.

        Raises:
            IndexError: if ``L`` exceeds ``max_cards``.
        """
        B, L = board_ids.shape
        if L > self.max_cards:
            raise IndexError(
                f"board has {L} cards but the encoder supports at most {self.max_cards}"
            )

        mask = board_ids != -1  # True for real cards
        # Padding ids (-1) are not valid embedding indices. Swap them for 0;
        # those slots are masked out of attention below.
        safe_ids = board_ids.masked_fill(~mask, 0)
        card_embs = self.card_embedder(safe_ids)

        cls_token = self.cls_token.expand(B, 1, -1)
        x = torch.cat([cls_token, card_embs], dim=1)  # [B, L+1, card_dim]

        pos_ids = torch.arange(L + 1, device=board_ids.device)
        x = x + self.position_embedder(pos_ids)  # [L+1, card_dim] broadcasts over B

        # src_key_padding_mask uses True = ignore. The CLS column is never
        # ignored. It must be built on the input's device: a CPU-only tensor
        # here breaks torch.cat when running on CUDA.
        cls_keep = torch.zeros(B, 1, dtype=torch.bool, device=board_ids.device)
        pad_mask = torch.cat([cls_keep, ~mask], dim=1)

        out = self.transformer(x, src_key_padding_mask=pad_mask)

        cls_out = out[:, 0, :]
        return self.output_proj(cls_out)

In [19]:
class HandStrengthPredictor(nn.Module):
    """Predict 10 output logits from a hand (id + two suits) and a padded board.

    The hand is encoded by ``HandSuitEncoder`` and the board by
    ``TransformerBoardEncoder``. The two vectors are concatenated and passed
    through an MLP.

    Args:
        suit_dim: suit embedding size for the hand encoder.
        hand_encoder_output_dim: size of the encoded hand vector.
        card_dim: card embedding size / Transformer ``d_model`` for the board.
        board_dim: size of the encoded board vector.
        num_heads: attention heads in the board encoder.
        num_layers: encoder layers in the board encoder.
        freeze_hand_emb: if True, the pretrained hand embeddings are not trained.
        preflop_embeddings_path: file holding the pretrained hand-embedding
            tensor, loaded with ``torch.load(weights_only=True)``.
    """

    def __init__(self, suit_dim=4, hand_encoder_output_dim=16,
                 card_dim=16, board_dim=32, num_heads=4, num_layers=2, freeze_hand_emb=True,
                 preflop_embeddings_path="model_weights/preflop_embeddings.pt"):
        super().__init__()
        # Load onto CPU so a tensor saved from a GPU session still loads on a
        # CPU-only host; the whole model can be moved with .to(device) later.
        preflop_embeddings = torch.load(
            preflop_embeddings_path, weights_only=True, map_location="cpu"
        ).float()
        self.hand_encoder = HandSuitEncoder(
            preflop_embeddings,
            suit_dim=suit_dim,
            output_dim=hand_encoder_output_dim,
            freeze_hand_emb=freeze_hand_emb,
        )
        self.board_encoder = TransformerBoardEncoder(
            card_dim=card_dim, board_dim=board_dim, num_heads=num_heads, num_layers=num_layers
        )
        self.model = nn.Sequential(
            nn.Linear(hand_encoder_output_dim + board_dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, hand_id, suit1_id, suit2_id, board_ids):
        """Return logits of shape ``[B, 10]``.

        ``board_ids`` is ``[B, L]`` with ``-1`` padding. The hand id tensors are
        presumably 1-D of length B (verify against the dataset).
        """
        hand_encoded = self.hand_encoder(hand_id, suit1_id, suit2_id)
        board_encoded = self.board_encoder(board_ids)
        x = torch.cat([hand_encoded, board_encoded], dim=1)
        return self.model(x)

In [None]:
def train_model(model, trainloader, valloader, optimizer, device=None, epochs=50):
    """Train ``model`` with cross-entropy and record per-epoch train/val loss.

    Each batch is unpacked as ``(hand_id, suit1_id, suit2_id, board, strength)``.
    The first four tensors are passed to ``model`` and ``strength`` is the
    cross-entropy target.

    Args:
        model: module called as ``model(hand_id, suit1_id, suit2_id, board)``
            that returns class logits.
        trainloader: iterable of training batches (must support ``len``).
        valloader: iterable of validation batches (must support ``len``).
        optimizer: optimizer over ``model``'s parameters.
        device: device that batches are moved to. If ``None``, the device of the
            model's first parameter is used (CPU if the model has none).
        epochs: number of passes over ``trainloader``.

    Returns:
        ``(train_losses, val_losses)``: one mean-of-batch-losses value per
        epoch each. The value is NaN when the corresponding loader is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Print about five progress lines. max() avoids a modulo-by-zero when epochs < 5.
    log_every = max(1, epochs // 5)

    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train()
        tot_train_loss = 0.0
        for batch in trainloader:
            optimizer.zero_grad()
            hand_id, suit1_id, suit2_id, board, strength = [x.to(device) for x in batch]
            strength_pred = model(hand_id, suit1_id, suit2_id, board)
            batch_loss = F.cross_entropy(strength_pred, strength)
            batch_loss.backward()
            optimizer.step()
            tot_train_loss += batch_loss.item()
        n_train = len(trainloader)
        avg_train_loss = tot_train_loss / n_train if n_train else float("nan")
        train_losses.append(avg_train_loss)

        model.eval()
        # Use a separate accumulator: the old code added validation losses onto
        # the epoch's training total, which inflated the reported val loss.
        tot_val_loss = 0.0
        with torch.no_grad():
            for batch in valloader:
                hand_id, suit1_id, suit2_id, board, strength = [x.to(device) for x in batch]
                strength_pred = model(hand_id, suit1_id, suit2_id, board)
                batch_loss = F.cross_entropy(strength_pred, strength)
                tot_val_loss += batch_loss.item()
        n_val = len(valloader)
        avg_val_loss = tot_val_loss / n_val if n_val else float("nan")
        val_losses.append(avg_val_loss)

        if epoch % log_every == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    return train_losses, val_losses

In [21]:
# Seed the global torch RNG. This makes weight init and DataLoader shuffling
# reproducible across kernel restarts.
torch.manual_seed(0)

model = HandStrengthPredictor(
    suit_dim=8,
    hand_encoder_output_dim=16,
    card_dim=64,
    board_dim=64,
    num_heads=4,
    num_layers=4,
    freeze_hand_emb=False,
).to(device)  # weights must live on the same device the batches are moved to
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_losses, val_losses = train_model(
    model=model,
    trainloader=trainloader,
    valloader=valloader,
    optimizer=optimizer,
    device=device,
    epochs=10,
)

  output = torch._nested_tensor_from_mask(


KeyboardInterrupt: 

In [None]:
# Pass the per-epoch loss histories returned by train_model to the project
# helper plot_train_loss (imported from poker_utils.model).
plot_train_loss(train_losses, val_losses)