In [67]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


In [68]:
class HandDataset(Dataset):
    """Map-style dataset of poker hands and their numeric feature vectors.

    Each item is ``(hand, features)``: ``hand`` is the row's value from the
    ``hand`` column and ``features`` is a 1-D ``float32`` tensor holding the
    requested ``columns`` in the order they were given.

    Args:
        data_path: Path or buffer accepted by ``pd.read_csv``. The CSV must
            contain a ``hand`` column plus every name in ``columns``.
        columns: Feature columns to load. If ``'hand_type'`` is among them,
            its string labels are integer-encoded via ``HAND_TYPE_CODES``.

    Raises:
        ValueError: if ``hand_type`` contains a non-missing label that is not
            a key of ``HAND_TYPE_CODES`` (it would otherwise become NaN).

    Note:
        Features are materialised once at construction; later edits to
        ``self.data`` are not reflected in ``__getitem__``.
    """

    # Integer encoding applied to the categorical ``hand_type`` column.
    HAND_TYPE_CODES = {'pair': 0, 'suited': 1, 'offsuit': 2}

    def __init__(self, data_path, columns: list):
        # Copy so later mutation of the caller's list cannot change which
        # columns (or what order) __getitem__ returns.
        self.feature_columns = list(columns)
        cols = ['hand'] + self.feature_columns
        self.data = pd.read_csv(data_path, usecols=cols)
        if 'hand_type' in self.feature_columns:
            raw = self.data['hand_type']
            encoded = raw.map(self.HAND_TYPE_CODES)
            # .map turns unrecognised labels into NaN; surface them instead of
            # letting NaN features leak into training.
            unknown = raw[encoded.isna() & raw.notna()]
            if not unknown.empty:
                raise ValueError(
                    f"Unknown hand_type value(s): {unknown.unique().tolist()}; "
                    f"expected one of {list(self.HAND_TYPE_CODES)}"
                )
            self.data['hand_type'] = encoded
        # Convert once: per-row iloc on a mixed-dtype frame yields an object
        # Series and re-casts it on every access, which is slow in a DataLoader.
        self._hands = self.data['hand'].tolist()
        self._features = self.data[self.feature_columns].to_numpy(dtype=np.float32)

    def __len__(self):
        """Number of rows (hands) in the dataset."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(hand, features)`` for row ``index`` (negative indices allowed)."""
        # torch.tensor copies, so callers can modify the result without
        # corrupting the cached feature matrix.
        return self._hands[index], torch.tensor(self._features[index], dtype=torch.float32)
        
        
        

In [69]:
# Processed per-hand equity table; features are returned in FEATURE_COLUMNS order.
DATA_PATH = 'data/processed/full_data.csv'
FEATURE_COLUMNS = ['hand_type', 'flop_equity', 'turn_equity', 'river_equity']

data = HandDataset(DATA_PATH, columns=FEATURE_COLUMNS)

In [70]:
# A seeded generator makes the shuffled batch order reproducible on
# Restart & Run All (otherwise every run draws different batches).
BATCH_SIZE = 10
SHUFFLE_SEED = 42
dataloader = DataLoader(
    data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [71]:
# Draw one batch to sanity-check collation: the hand labels followed by the
# stacked feature tensor.
first_batch = next(iter(dataloader))
first_batch

[('64o', '42s', 'J4s', 'K6o', 'J9s', '87o', 'T5o', 'T2s', '92s', '94s'),
 tensor([[2.0000, 0.2966, 0.3392, 0.3805],
         [1.0000, 0.2475, 0.3098, 0.3704],
         [1.0000, 0.4713, 0.4791, 0.4890],
         [2.0000, 0.6001, 0.5680, 0.5442],
         [1.0000, 0.5277, 0.5427, 0.5543],
         [2.0000, 0.3881, 0.4238, 0.4541],
         [2.0000, 0.4311, 0.4332, 0.4390],
         [1.0000, 0.4074, 0.4312, 0.4478],
         [1.0000, 0.3659, 0.4017, 0.4255],
         [1.0000, 0.3830, 0.4094, 0.4370]])]