In [1]:
%load_ext autoreload
%autoreload 2
import tichu_rustipy as tr
import numpy as np
from IPython.display import display, HTML
import pickle

def display_colored_hand(hand_str):
    # Convert ANSI escape codes to HTML
    hand_str = (hand_str
        .replace('\x1b[31m', '<span style="color: red">')
        .replace('\x1b[32m', '<span style="color: green">')
        .replace('\x1b[33m', '<span style="color: yellow">')
        .replace('\x1b[34m', '<span style="color: dodgerblue ">')
        .replace('\x1b[0m', '</span>')
    )
    display(HTML(hand_str))
def save_dict(dictionary, filename):
    with open(filename, 'wb') as file:
        pickle.dump(dictionary, file)

# Loading the dictionary from a file
def load_dict(filename):
    with open(filename, 'rb') as file:
        return pickle.load(file)

In [2]:
db = tr.BSWSimple("../tichu_rust/bsw_filtered.db")
display_colored_hand(tr.print_hand(db.get_round(0)[0].first_14))
db.len()

400881

In [3]:
#db_as_np = tr.bulk_transform_db_into_np56_array(db)
#np.save("db_as_np_filtered", db_as_np)
#db_as_np = np.load("db_as_np.npy")
db_as_np = np.load("db_as_np_filtered.npy")

## Later, when using, need to reshape: db_as_np[0].reshape(4, 14)
#print(db_as_np.shape)

In [None]:
# Calculate labels and mapping from incoming card tuples to index
incoming_card_combination_to_label_num = {}
incoming_card_labels = np.zeros(len(db_as_np), dtype=np.uint16)
i = 0
while i < len(db_as_np):
    prh_round = db.get_round(i//4)
    for j in range(4):
        incoming_card_combo = tr.prh_to_incoming_cards(prh_round[j])
        if not incoming_card_combo in incoming_card_combination_to_label_num:
            incoming_card_combination_to_label_num[incoming_card_combo] = len(incoming_card_combination_to_label_num)
        incoming_card_labels[i+j] = incoming_card_combination_to_label_num[incoming_card_combo]
    i += 4
label_num_to_incoming_card_combination = {value: key for key,value in incoming_card_combination_to_label_num.items()}
# Save
np.save("incoming_card_labels_filtered", incoming_card_labels)
save_dict(incoming_card_combination_to_label_num, "incoming_card_combination_to_label_num_filtered.pkl")
save_dict(label_num_to_incoming_card_combination, "label_num_to_incoming_card_combination_filtered.pkl")

In [30]:
incoming_card_labels = np.load("incoming_card_labels.npy")
incoming_card_combination_to_label_num = load_dict("incoming_card_combination_to_label_num.pkl")
label_num_to_incoming_card_combination = load_dict("label_num_to_incoming_card_combination.pkl")

In [4]:
incoming_card_labels = np.load("incoming_card_labels_filtered.npy")
incoming_card_combination_to_label_num = load_dict("incoming_card_combination_to_label_num_filtered.pkl")
label_num_to_incoming_card_combination = load_dict("label_num_to_incoming_card_combination_filtered.pkl")

In [31]:
len(incoming_card_combination_to_label_num)

2389

## Card Game Neural Network Architecture
### Input Processing

+ Input shape: [N_samples, 56] (binary representation of hands)
+ Split into:
  - Regular cards [4, 13] (4 colors × 13 values)
  - Special cards [4]



### Regular Cards Path
#### First Layer: ColorInvariantConv -> explained below

8 types of filters (4 filters each = 32 total):

+ Street detection:
  - (4, 5) -> [1, 9] × 4 = 36 features
  - (4, 6) -> [1, 8] × 4 = 32 features
  - (4, 7) -> [1, 7] × 4 = 28 features
+ Single color patterns:
  - (1, 5) -> [4, 9] × 4 = 144 features
+ Pair street patterns:
  - (4, 2) -> [1, 12] × 4 = 48 features
  - (4, 3) -> [1, 11] × 4 = 44 features
  - (4, 4) -> [1, 10] × 4 = 40 features
+ Value patterns:
  - (4, 1) -> [1, 13] × 4 = 52 features

Total features from regular cards: 424

### Special Cards Path

+ Simple dense layer: 4 -> 16 features

### Two Architecture Options
#### Option 1: Direct Flatten

1. Flatten all ColorInvariantConv outputs
2. Concatenate with special cards features
3. Total features: 424 + 16 = 440
4. Dense layers: 256 -> 128
5. Output layer: [N, 2389]

#### Option 2: Separate Processing

1. Process each filter type through additional Conv1d (16 features each)
2. 8 parallel paths of length 16,  more features
3. Concatenate with special cards features (16)
4. Total features: Not sure, a lot
5. Dense layers: 256 -> 128
6. Output layer: [N, 2389]

### Key Features

+ Color invariance through ColorInvariantConv in first layer
+ Game-specific filter sizes capturing relevant patterns
+ Separate processing of special cards
+ Direct modeling of joint probability distribution over 2389 valid combinations
+ No padding in convolutions to preserve pattern semantics
<img src="./model-comparison.svg" />

In [44]:
import torch
import torch
from model import CardNet  # Make sure the model architecture is imported

best_model = CardNet()
best_model = best_model.cuda()
state_dict = torch.load('best_model.pt')
best_model.load_state_dict(state_dict)

  state_dict = torch.load('best_model.pt')


<All keys matched successfully>

In [45]:
from model import predict
probs = predict(best_model, db_as_np[4*167+1:4*167+20])
ind = np.argpartition(probs, -4, axis=1)[:, -4:]
print(ind)

[[128 133 246   7]
 [ 38   7  87  64]
 [ 87  22   7  64]
 [ 87  57  48  64]
 [128  22   7  64]
 [ 22 128   7  64]
 [  3  22   7  64]
 [ 87  22   7  64]
 [128  22   7  64]
 [ 64 133 128   7]
 [  3  87   7  64]
 [ 38  87   7  64]
 [  3  57   7  64]
 [128  22   7  64]
 [ 64   7 153 191]
 [  7  48 191 153]
 [ 47  22   7  64]
 [ 20 100   3  87]
 [128  57   7  64]]


In [9]:
db_as_np[4*167+1:4*167+2].reshape(-1, 4, 14)

array([[[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1]]], dtype=uint8)

In [46]:
probs[0, 64]

np.float32(0.0010674623)

In [None]:
for idx in range(db.len()):
    for p_id in range(4):
        if tr.print_hand(db.get_round(idx)[p_id].first_14).count('A') == 4:
            print(idx, p_id)

In [38]:
display_colored_hand(tr.print_hand(db.get_round(167)[1].first_14))
print(label_num_to_incoming_card_combination[64])
print(label_num_to_incoming_card_combination[100])
print(label_num_to_incoming_card_combination[246])
print(label_num_to_incoming_card_combination[7])

(1, 2, 13)
(1, 3, 12)
(0, 1, 12)
(1, 2, 12)
