In [86]:
import json
import gzip
import glob
import numpy as np
import pandas as pd
from pathlib import Path
from cards import INDEX_TO_CARD, CARD_TO_INDEX, CARD_LIST

# project modules will be added later
# from src.jass_parser import extract_moves


In [87]:
import importlib
import state_encoder
import card_dataset
import cards
importlib.reload(state_encoder)
importlib.reload(card_dataset)
importlib.reload(cards)


import state_encoder
import card_dataset
import cards


In [88]:
data_dir = Path("../data2")
data_files = sorted(data_dir.glob("*.txt"))
len(data_files)

2

In [89]:
def load_game_file(path):
    with open(path, "r", encoding="utf8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)

games = []
for f in data_files:
    games.extend(load_game_file(f))

len(games)

200000

In [90]:
ALL_CARDS = [
    s+r
    for s in ["S","H","D","C"]
    for r in ["6","7","8","9","10","J","Q","K","A"]
]

In [91]:
suits = {"S":0, "H":1, "D":2, "C":3}
ranks = {"6":0, "7":1, "8":2, "9":3, "10":4, "J":5, "Q":6, "K":7, "A":8}

CARDS = [
    'DA','DK','DQ','DJ','D10','D9','D8','D7','D6',
    'HA','HK','HQ','HJ','H10','H9','H8','H7','H6',
    'SA','SK','SQ','SJ','S10','S9','S8','S7','S6',
    'CA','CK','CQ','CJ','C10','C9','C8','C7','C6'
]
CARD_TO_IDX = {c: i for i, c in enumerate(CARDS)}
N_CARDS = len(CARDS)

def rank(card):
    return ranks[card[1:]]

def suit(card):
    return {"D":0, "H":1, "S":2, "C":3}[card[0]]

def valid_moves(hand, table, trump):
    if not table:
        return hand[:]

    lead_suit = suit(table[0])
    same_suit = [c for c in hand if suit(c) == lead_suit]

    if same_suit:
        return same_suit

    return hand[:]

def extract_samples_from_game(game_json):
    samples = []

    g = game_json["game"]
    trump = g["trump"]
    forehand = g["forehand"]
    tricks = g["tricks"]

    # First reconstruct who owned which card using trick["first"]
    ownership = {}
    for trick in tricks:
        first = trick["first"]
        for i, card in enumerate(trick["cards"]):
            player = (first + i) % 4
            ownership[card] = player

    # Build hands per player
    hands = {p: [] for p in range(4)}
    for card, p in ownership.items():
        hands[p].append(card)
    for p in range(4):
        hands[p].sort()

    # Now simulate the game to build samples
    played_cards = []
    hands_sim = {p: list(hands[p]) for p in range(4)}

    for trick_idx, trick in enumerate(tricks):
        table = []
        first = trick["first"]

        for play_idx, card in enumerate(trick["cards"]):
            player = (first + play_idx) % 4

            vm = valid_moves(hands_sim[player], table, trump)

            samples.append({
                "trump": trump,
                "forehand": forehand,
                "trick": trick_idx,
                "player": player,
                "table": list(table),
                "hand": list(hands_sim[player]),
                "valid_moves": list(vm),
                "action": card,
                "played_cards": list(played_cards),
            })

            hands_sim[player].remove(card)
            table.append(card)
            played_cards.append(card)

    return samples

In [92]:
all_samples = []
for g in games:
    all_samples.extend(extract_samples_from_game(g))

len(all_samples)

7200000

In [93]:
import numpy as np
import torch

N_TRUMP = 7    # DIAMONDS HEARTS SPADES CLUBS OBE_ABE UNE_UFE PUSH
N_PLAYERS = 4
N_TRICKS = 9

# Mapping trump int to index 0 to 6
TRUMP_IDX = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 10:6}

FEATURE_SIZE = N_CARDS * 3 + N_TRUMP + N_PLAYERS + N_TRICKS
FEATURE_SIZE


128

In [94]:
def encode_sample(sample):
    trump = sample["trump"]
    player = sample["player"]
    trick = sample["trick"]
    hand = sample["hand"]
    table = sample["table"]
    played = sample["played_cards"]
    valid_moves = sample["valid_moves"]
    action = sample["action"]

    x = np.zeros(FEATURE_SIZE, dtype=np.float32)
    offset = 0

    # hand
    for c in hand:
        x[offset + CARD_TO_IDX[c]] = 1.0
    offset += N_CARDS

    # table cards in current trick
    for c in table:
        x[offset + CARD_TO_IDX[c]] = 1.0
    offset += N_CARDS

    # already played cards
    for c in played:
        x[offset + CARD_TO_IDX[c]] = 1.0
    offset += N_CARDS

    # trump one hot
    t_idx = TRUMP_IDX.get(trump, 6)
    x[offset + t_idx] = 1.0
    offset += N_TRUMP

    # player one hot
    x[offset + player] = 1.0
    offset += N_PLAYERS

    # trick index one hot
    x[offset + trick] = 1.0

    # action index
    y = CARD_TO_IDX[action]

    # legal move mask
    legal_mask = np.zeros(N_CARDS, dtype=np.float32)
    for c in valid_moves:
        legal_mask[CARD_TO_IDX[c]] = 1.0

    # convert to tensors
    state = torch.from_numpy(x)
    action_idx = torch.tensor(y, dtype=torch.long)
    legal_mask_t = torch.from_numpy(legal_mask)

    return state, action_idx, legal_mask_t


In [95]:
sample = all_samples[0]
state, action_idx, legal_mask = encode_sample(sample)

state.shape, action_idx, legal_mask.shape, legal_mask.sum()


(torch.Size([128]), tensor(9), torch.Size([36]), tensor(9.))

In [96]:
torch.save(all_samples, "all_samples_raw.pt")

In [97]:
len(all_samples)

7200000

In [98]:
from card_dataset import CardPlayingDataset

dataset = CardPlayingDataset(all_samples)
len(dataset)

7200000

In [99]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=512, shuffle=True)

In [100]:
cleaned_samples = []
illegal_count = 0
total_count = 0

for s in all_samples:
    total_count += 1

    # Skip states where the chosen action is not legal
    if s["action"] not in s["valid_moves"]:
        illegal_count += 1
        continue

    # Also skip if there are no legal moves at all (should not happen but just in case)
    if not s["valid_moves"]:
        illegal_count += 1
        continue

    cleaned_samples.append(s)

print(f"Total samples:   {total_count}")
print(f"Illegal samples: {illegal_count}")
print(f"Kept samples:    {len(cleaned_samples)} "
      f"({len(cleaned_samples)/total_count*100:.2f} percent)")

all_samples = cleaned_samples


Total samples:   7200000
Illegal samples: 376620
Kept samples:    6823380 (94.77 percent)


In [101]:
import random

# Shuffle for randomness
random.shuffle(all_samples)

n = len(all_samples)
val_size = n // 10  # 10 percent validation
train_samples = all_samples[val_size:]
val_samples   = all_samples[:val_size]

print("Train:", len(train_samples), "Validation:", len(val_samples))

Train: 6141042 Validation: 682338


In [102]:
from card_model import CardPolicyNetwork
from torch.utils.data import DataLoader
from card_dataset import CardPlayingDataset
import torch
import torch.nn as nn
from tqdm import tqdm


train_dataset = CardPlayingDataset(train_samples)
val_dataset   = CardPlayingDataset(val_samples)

train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=1024, shuffle=False)

In [103]:
def train_one_epoch(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    total_samples = 0

    for X, legal_mask, y in train_loader:
        X = X.to(device)
        legal_mask = legal_mask.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        logits = model(X)

        # mask illegal moves
        masked_logits = logits + (legal_mask == 0) * -1e9

        loss = criterion(masked_logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * X.size(0)
        total_samples += X.size(0)

    return total_loss / total_samples

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for X, legal_mask, y in loader:
            X = X.to(device)
            legal_mask = legal_mask.to(device)
            y = y.to(device)

            logits = model(X)
            masked_logits = logits + (legal_mask == 0) * -1e9

            preds = masked_logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    return correct / total


In [104]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CardPolicyNetwork().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=3, factor=0.5
)

num_epochs = 10
best_val_acc = 0.0
best_state_dict = None

train_history = []
val_history = []

for epoch in range(num_epochs):

    train_loss = train_one_epoch(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device
    )

    val_acc = evaluate(
        model=model,
        loader=val_loader,
        device=device
    )

    train_history.append(train_loss)
    val_history.append(val_acc)

    print(f"Epoch {epoch+1:02d}/{num_epochs}  "
          f"Train Loss: {train_loss:.4f}  "
          f"Val Acc: {val_acc:.4f}")

    # Step scheduler based on validation accuracy
    scheduler.step(val_acc)

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_state_dict = model.state_dict()

print("\nBest Validation Accuracy:", best_val_acc)


Epoch 01/10  Train Loss: 0.6957  Val Acc: 0.7150
Epoch 02/10  Train Loss: 0.6514  Val Acc: 0.7221
Epoch 03/10  Train Loss: 0.6398  Val Acc: 0.7243
Epoch 04/10  Train Loss: 0.6331  Val Acc: 0.7257
Epoch 05/10  Train Loss: 0.6283  Val Acc: 0.7265
Epoch 06/10  Train Loss: 0.6245  Val Acc: 0.7267
Epoch 07/10  Train Loss: 0.6214  Val Acc: 0.7268
Epoch 08/10  Train Loss: 0.6186  Val Acc: 0.7261
Epoch 09/10  Train Loss: 0.6163  Val Acc: 0.7259
Epoch 10/10  Train Loss: 0.6142  Val Acc: 0.7254

Best Validation Accuracy: 0.726815742344703


In [105]:
# Load best model parameters back into the model
model.load_state_dict(best_state_dict)

# Save model weights
torch.save(model.state_dict(), "card_model.pt")
print("Saved model → card_model.pt")

# Save card index mapping (for inference)
import pickle
with open("card_labels.pkl", "wb") as f:
    pickle.dump({
        "CARD_LIST": CARD_LIST,
        "CARD_TO_INDEX": CARD_TO_INDEX,
        "INDEX_TO_CARD": INDEX_TO_CARD
    }, f)

print("Saved card_model.pt and card_labels.pkl")

Saved model → card_model.pt
Saved card_model.pt and card_labels.pkl
