In [17]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
from sklearn.model_selection import train_test_split
from trump_dataset import TrumpDataset
from torch.utils.data import DataLoader

In [12]:


path_to_data = Path('C:\\Users\\niede\\PycharmProjects\\jass-kit-py\\')
# Load CSV
data = pd.read_csv(path_to_data / '2018_10_18_trump.csv')

# Column names
cards = [
    'DA','DK','DQ','DJ','D10','D9','D8','D7','D6',
    'HA','HK','HQ','HJ','H10','H9','H8','H7','H6',
    'SA','SK','SQ','SJ','S10','S9','S8','S7','S6',
    'CA','CK','CQ','CJ','C10','C9','C8','C7','C6'
]

forehand = ['FH']
user = ['user']
trump = ['trump']

# Assign column names
data.columns = cards + forehand + user + trump

# Drop the user column
data = data.drop(columns=['user'])

# Convert trump column to category
data.trump = data.trump.astype('category')

# Convert card columns and forehand column to boolean
data[cards + forehand] = data[cards + forehand].astype(bool)

# Rename trump categories (merge 6 and 10 as PUSH)
data.trump = data.trump.cat.rename_categories({
    0: 'DIAMONDS',
    1: 'HEARTS',
    2: 'SPADES',
    3: 'CLUBS',
    4: 'OBE_ABE',
    5: 'UNE_UFE',
    6: 'PUSH',
    10: 'PUSH'
})

data.head()


Unnamed: 0,DA,DK,DQ,DJ,D10,D9,D8,D7,D6,HA,...,CK,CQ,CJ,C10,C9,C8,C7,C6,FH,trump
0,False,False,False,False,False,False,False,False,True,True,...,False,False,True,False,False,False,True,False,False,UNE_UFE
1,True,False,False,True,False,False,False,False,False,False,...,False,True,False,False,False,False,True,True,False,PUSH
2,False,False,False,False,False,False,False,False,False,True,...,False,False,False,True,True,False,False,False,False,UNE_UFE
3,False,True,False,False,False,False,False,False,True,True,...,False,False,True,False,False,False,False,False,True,OBE_ABE
4,False,False,True,False,False,False,True,False,True,False,...,False,False,False,False,False,False,False,False,True,UNE_UFE


In [13]:
# Verify that each hand contains exactly 9 cards
card_counts = data[cards].sum(axis=1)

# Check if all rows have exactly 9 True values
all_nine = (card_counts == 9).all()

print("All hands contain exactly 9 cards:", all_nine)

# Optionally show rows that do not meet this condition
invalid_rows = data.loc[card_counts != 9]
print("Number of invalid hands:", len(invalid_rows))
if len(invalid_rows) > 0:
    display(invalid_rows)


All hands contain exactly 9 cards: True
Number of invalid hands: 0


In [14]:
# 1) Extract features and labels
X = data[cards + forehand].astype(int).values    # convert bool → int for NN
y = data.trump.cat.codes.values                  # categorical → numeric 0..N

# Optional: store mapping for inference later
label_mapping = dict(enumerate(data.trump.cat.categories))
print("Label mapping:", label_mapping)

# 2) Train (70 percent) + temp split (30 percent)
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.30, random_state=42, stratify=y
)

# 3) Validation (15 percent) + Test (15 percent)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.50, random_state=42, stratify=y_temp
)

# Show shapes
print("Train:", X_train.shape, y_train.shape)
print("Val:", X_val.shape, y_val.shape)
print("Test:", X_test.shape, y_test.shape)

Label mapping: {0: 'DIAMONDS', 1: 'HEARTS', 2: 'SPADES', 3: 'CLUBS', 4: 'OBE_ABE', 5: 'UNE_UFE', 6: 'PUSH'}
Train: (251876, 37) (251876,)
Val: (53974, 37) (53974,)
Test: (53974, 37) (53974,)


In [16]:
train_ds = TrumpDataset(X_train, y_train)
val_ds   = TrumpDataset(X_val, y_val)
test_ds  = TrumpDataset(X_test, y_test)


In [18]:
# Create DataLoaders
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=128, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=128, shuffle=False)