In [1]:
%load_ext autoreload
%autoreload 2
import tichu_rustipy as tr
import numpy as np
from IPython.display import display, HTML
import pickle

def display_colored_hand(hand_str):
    # Convert ANSI escape codes to HTML
    hand_str = (hand_str
        .replace('\x1b[31m', '<span style="color: red">')
        .replace('\x1b[32m', '<span style="color: green">')
        .replace('\x1b[33m', '<span style="color: yellow">')
        .replace('\x1b[34m', '<span style="color: dodgerblue ">')
        .replace('\x1b[0m', '</span>')
    )
    display(HTML(hand_str))
def save_dict(dictionary, filename):
    with open(filename, 'wb') as file:
        pickle.dump(dictionary, file)

# Loading the dictionary from a file
def load_dict(filename):
    with open(filename, 'rb') as file:
        return pickle.load(file)

In [2]:
db = tr.BSWSimple("../tichu_rust/bsw_filtered.db")
display_colored_hand(tr.print_hand(db.get_round(0)[0].first_14))
db.len()

400881

## Exchange db

In [3]:
#exch_db = tr.bulk_transform_db_into_np56_array(db)
#np.save("exch_model/db_as_np_filtered", exch_db)
#exch_db = np.load("exch_model/db_as_np.npy")
exch_db = np.load("exch_model/db_as_np_filtered.npy")

In [None]:
# Calculate labels and mapping from incoming card tuples to index
incoming_card_combination_to_label_num = {}
incoming_card_labels = np.zeros(len(exch_db), dtype=np.uint16)
i = 0
while i < len(exch_db):
    prh_round = db.get_round(i//4)
    for j in range(4):
        incoming_card_combo = tr.prh_to_incoming_cards(prh_round[j])
        if not incoming_card_combo in incoming_card_combination_to_label_num:
            incoming_card_combination_to_label_num[incoming_card_combo] = len(incoming_card_combination_to_label_num)
        incoming_card_labels[i+j] = incoming_card_combination_to_label_num[incoming_card_combo]
    i += 4
label_num_to_incoming_card_combination = {value: key for key,value in incoming_card_combination_to_label_num.items()}
# Save
np.save("exch_model/incoming_card_labels", incoming_card_labels)
save_dict(incoming_card_combination_to_label_num, "exch_model/incoming_card_combination_to_label_num.pkl")
save_dict(label_num_to_incoming_card_combination, "exch_model/label_num_to_incoming_card_combination.pkl")

In [4]:
exch_labels = np.load("exch_model/incoming_card_labels.npy")
incoming_card_combination_to_label_num = load_dict("exch_model/incoming_card_combination_to_label_num.pkl")
label_num_to_incoming_card_combination = load_dict("exch_model/label_num_to_incoming_card_combination.pkl")

In [5]:
len(incoming_card_combination_to_label_num)

2389

## Hand strength db

In [6]:
#hand_strength_db =  tr.bulk_transform_db_into_np90_array(db)
#np.save("hand_strength_model/db_as_np_filtered", hand_strength_db)

hand_strength_db = np.load("hand_strength_model/db_as_np_filtered.npy")

In [None]:
# Calculate labels and mapping from incoming card tuples to index
hand_strength_labels = np.zeros(len(hand_strength_db), dtype=np.int16)
i = 0
while i < len(hand_strength_db):
    prh_round = db.get_round(i//4)
    for j in range(4):
        hand_strength_labels[i+j] =  prh_round[j].round_score_relative_gain_gt_as_t()
    i += 4
# Save
np.save("hand_strength_model/labels_filtered", hand_strength_labels)

In [7]:
hand_strength_labels = np.load("hand_strength_model/labels_filtered.npy")


## Card Game Neural Network Architecture
### Input Processing

+ Input shape: [N_samples, 56] (binary representation of hands)
+ Split into:
  - Regular cards [4, 13] (4 colors × 13 values)
  - Special cards [4]



### Regular Cards Path
#### First Layer: ColorInvariantConv -> explained below

8 types of filters (4 filters each = 32 total):

+ Street detection:
  - (4, 5) -> [1, 9] × 4 = 36 features
  - (4, 6) -> [1, 8] × 4 = 32 features
  - (4, 7) -> [1, 7] × 4 = 28 features
+ Single color patterns:
  - (1, 5) -> [4, 9] × 4 = 144 features
+ Pair street patterns:
  - (4, 2) -> [1, 12] × 4 = 48 features
  - (4, 3) -> [1, 11] × 4 = 44 features
  - (4, 4) -> [1, 10] × 4 = 40 features
+ Value patterns:
  - (4, 1) -> [1, 13] × 4 = 52 features

Total features from regular cards: 424

### Special Cards Path

+ Simple dense layer: 4 -> 16 features

### Two Architecture Options
#### Option 1: Direct Flatten

1. Flatten all ColorInvariantConv outputs
2. Concatenate with special cards features
3. Total features: 424 + 16 = 440
4. Dense layers: 256 -> 128
5. Output layer: [N, 2389]

#### Option 2: Separate Processing

1. Process each filter type through additional Conv1d (16 features each)
2. 8 parallel paths of length 16,  more features
3. Concatenate with special cards features (16)
4. Total features: Not sure, a lot
5. Dense layers: 256 -> 128
6. Output layer: [N, 2389]

### Key Features

+ Color invariance through ColorInvariantConv in first layer
+ Game-specific filter sizes capturing relevant patterns
+ Separate processing of special cards
+ Direct modeling of joint probability distribution over 2389 valid combinations
+ No padding in convolutions to preserve pattern semantics
<img src="./model-comparison.svg" />

In [8]:
import torch
import torch
from exch_model.model import CardNet  # Make sure the model architecture is imported
from hand_strength_model.model import HandStrengthNet

exch_model = CardNet()
exch_model = exch_model.cuda()
state_dict = torch.load('exch_model/best_model.pt')
exch_model.load_state_dict(state_dict)

hand_strength_model = HandStrengthNet()
hand_strength_model = hand_strength_model.cuda()
state_dict = torch.load('hand_strength_model/best_model.pt')
hand_strength_model.load_state_dict(state_dict)


  state_dict = torch.load('exch_model/best_model.pt')
  state_dict = torch.load('hand_strength_model/best_model.pt')


<All keys matched successfully>

In [9]:
from exch_model.model import predict as exch_predict
probs = exch_predict(exch_model, exch_db[4*167+1:4*167+20])
ind = np.argpartition(probs, -4, axis=1)[:, -4:]
print(ind)

[[128 133 246   7]
 [ 38   7  87  64]
 [ 87  22   7  64]
 [ 87  57  48  64]
 [128  22   7  64]
 [ 22 128   7  64]
 [  3  22   7  64]
 [ 87  22   7  64]
 [128  22   7  64]
 [ 64 133 128   7]
 [  3  87   7  64]
 [ 38  87   7  64]
 [  3  57   7  64]
 [128  22   7  64]
 [ 64   7 153 191]
 [  7  48 191 153]
 [ 47  22   7  64]
 [ 20 100   3  87]
 [128  57   7  64]]


In [10]:
display_colored_hand(tr.print_hand(db.get_round(1)[3].final_14))

In [11]:
from hand_strength_model.model import predict as score_predict
scores = score_predict(hand_strength_model, hand_strength_db[:20])
print(scores)

[  92.42789   -97.90347    62.338425 -146.95308    66.089836  -24.292103
  -18.120998  -62.53246    93.01433   -15.131215  -33.71086   -16.614683
  -59.7088     18.452246  -93.82926   135.00415    93.58096    50.12252
  -20.62323    45.161583]


In [12]:
def exchange_bot(hand):
    hand_as_np = tr.transform_into_np56_array(hand)
    out_combos = tr.get_legal_outgoing_card_combinations(hand)
    print(f"Trying {len(out_combos)} many different exchange possibilities!")
    ingoing_probabilities = exch_predict(exch_model, hand_as_np)[0, :]
    relevant_ingoing_indices = []
    for idx, probability in enumerate(ingoing_probabilities):
        if probability < 0.004:
            continue
        if not tr.get_legal_incoming_card_combinations(hand, label_num_to_incoming_card_combination[idx]):
            continue
        relevant_ingoing_indices.append(idx)
    print(f"Found {len(relevant_ingoing_indices)} relevant ingoing possibilities!")
    prob_sum = np.sum(ingoing_probabilities[relevant_ingoing_indices])

    model_call_on = []
    indices_to_info = []
    batch = np.zeros((64, 90), dtype=np.uint8)

    for out_hand, out_partner in out_combos:
        stripped_hand = hand ^ out_hand
        for ingoing_idx in relevant_ingoing_indices:
            in_prob = ingoing_probabilities[ingoing_idx] / prob_sum #Reweight such that it adds up to 1.
            in_partner = label_num_to_incoming_card_combination[ingoing_idx][2]
            in_hands = tr.get_legal_incoming_card_combinations(hand, label_num_to_incoming_card_combination[ingoing_idx])
            if tr.could_get_street_bomb(stripped_hand, out_hand, label_num_to_incoming_card_combination[ingoing_idx]):
                tr.prepare_batch_np90_array(stripped_hand, in_partner, out_partner, in_hands, batch)
                to_add = batch[:len(in_hands)].copy()
                in_prob /= len(in_hands)
            else:
                first_in = in_hands[0]
                to_add = tr.transform_into_np90_array(stripped_hand ^ first_in, in_partner, out_partner)
            for i in range(len(to_add)):
                model_call_on.append(to_add[i, :])
                indices_to_info.append((out_hand, out_partner, in_prob))
    scores = score_predict(hand_strength_model, np.array(model_call_on))
    combo_to_val = {}
    for i in range(len(indices_to_info)):
        out_hand, out_partner, in_prob = indices_to_info[i]
        if (out_hand, out_partner) not in combo_to_val:
            combo_to_val[(out_hand, out_partner)] = 0.
        combo_to_val[(out_hand, out_partner)] = combo_to_val[(out_hand, out_partner)] + in_prob * scores[i]
    return combo_to_val

In [13]:
def get_top_n(dictionary, n=10):
    # Sort dictionary items by value in descending order and get top n
    sorted_items = sorted(dictionary.items(), key=lambda x: x[1], reverse=True)[:n]
    return sorted_items

def format_exchange_possibilites(top_possibilities):
    for idx, ((out_hand, out_partner), value) in enumerate(top_possibilities):
        print(f"Option {idx+1} has expected score {value}")
        def format_out_partner():
            if out_partner == 14:
                return "🐦"
            elif out_partner == 15:
                return "🐉"
            elif out_partner == 16:
                return "1"
            elif out_partner == 0:
                return "↺"
            else:
                return ["2", "3", "4", "5", "6", "7", "8", "9", "T", "J", "Q", "K", "A"][out_partner-1]
        print(f"Give partner: {format_out_partner()}")
        print("Give all: ", end="")
        display_colored_hand(tr.print_hand(out_hand))
        print("-"*50)

In [14]:
round_num = 2
player_num = 0
hand = db.get_round(round_num)[player_num].first_14
display_colored_hand(tr.print_hand(hand))
combo_to_val = exchange_bot(hand)
top_ten = get_top_n(combo_to_val, n=10)
format_exchange_possibilites(top_ten)

Trying 1092 many different exchange possibilities!
Found 141 relevant ingoing possibilities!
Option 1 has expected score 89.45057678222656
Give partner: J
Give all: 

--------------------------------------------------
Option 2 has expected score 88.43009185791016
Give partner: J
Give all: 

--------------------------------------------------
Option 3 has expected score 83.92390441894531
Give partner: J
Give all: 

--------------------------------------------------
Option 4 has expected score 83.38031005859375
Give partner: 7
Give all: 

--------------------------------------------------
Option 5 has expected score 83.22826385498047
Give partner: J
Give all: 

--------------------------------------------------
Option 6 has expected score 79.143310546875
Give partner: 6
Give all: 

--------------------------------------------------
Option 7 has expected score 76.09078979492188
Give partner: J
Give all: 

--------------------------------------------------
Option 8 has expected score 73.41964721679688
Give partner: 2
Give all: 

--------------------------------------------------
Option 9 has expected score 70.98426055908203
Give partner: J
Give all: 

--------------------------------------------------
Option 10 has expected score 70.495849609375
Give partner: 7
Give all: 

--------------------------------------------------


In [None]:
for x in tr.get_legal_incoming_card_combinations(db.get_round(167)[1].first_14, label_num_to_incoming_card_combination[7]):
    display_colored_hand(tr.print_hand(x))

In [None]:
display_colored_hand(tr.print_hand(db.get_round(167)[1].first_14))
print(label_num_to_incoming_card_combination[64])
print(label_num_to_incoming_card_combination[100])
print(label_num_to_incoming_card_combination[246])
print(label_num_to_incoming_card_combination[7])