In [91]:
from jax import random
from jax import numpy as jnp
import numpy as np
from datetime import datetime
import csv
import h5py
import os
import pandas as pd

In [None]:
class zole:
    """Random-play simulator for a single three-player Zole deal.

    Card encoding (index 0-25). Within a suit a lower index is a stronger card,
    which is why trick winners are found with a minimum:
        0-13   trumps: queens (0-3), jacks (4-7), then diamond A, 10, K, 9, 8, 7 (8-13)
        14-17  clubs  A, 10, K, 9
        18-21  spades A, 10, K, 9
        22-25  hearts A, 10, K, 9

    State:
        unplayed_cards  (4, 26) 0/1 matrix; rows 0-2 are the players' hands,
                        row 3 holds the two undealt "pocket" cards.
        begin_state     unplayed_cards as it was right after the deal.
        big             shape-(1,) array with the seat of the "big" player, who is
                        also credited with the pocket cards.
        leaders         initial leader followed by the winner of every trick,
                        so leaders[1:] are the trick takers.
        cards_played    one [seat0, seat1, seat2] card list per trick.
    """

    N_PLAYERS = 3
    TRICKS_PER_GAME = 8
    TOTAL_POINTS = 120
    WIN_THRESHOLD = 60  # big must score strictly more than this to win

    def __init__(self, seed: "int | None" = None) -> None:
        """Shuffle, deal and pick the big player.

        Args:
            seed: PRNG seed. When omitted, a fresh time-based seed is drawn on
                every call. It is computed here rather than in the signature,
                where it would be evaluated only once at class-definition time.
        """
        if seed is None:
            # Microsecond resolution so games created in a tight loop still differ;
            # reduced to 31 bits to keep it a small, safe PRNGKey seed.
            seed = int(datetime.now().timestamp() * 1_000_000) % 2**31
        self.key = random.PRNGKey(seed)

        # Deal: 8 cards to each of the three players, the last 2 to the pocket (row 3).
        self.key, subkey = random.split(self.key)
        card_indices = random.permutation(subkey, 26)
        player_card_indices = jnp.split(card_indices, [8, 16, 24])
        self.unplayed_cards = jnp.zeros((4, 26), dtype=int)
        for player, cards in enumerate(player_card_indices):
            self.unplayed_cards = self.unplayed_cards.at[player, cards].set(1)

        # randint's maxval is exclusive: N_PLAYERS (3) lets every seat become big.
        self.key, subkey = random.split(self.key)
        self.big = random.randint(subkey, (1,), 0, self.N_PLAYERS)

        # Suit boundaries: [0,14) trump, [14,18) clubs, [18,22) spades, [22,26) hearts.
        self.suit_splits = jnp.array([0, 14, 18, 22, 26])
        # Card points; they sum to 120 over the whole deck.
        self.points = {0:3, 1:3, 2:3, 3:3, 4:2, 5:2, 6:2, 7:2, 8:11, 9:10, 10:4, 11:0, 12:0, 13:0,
                       14:11, 15:10, 16:4, 17:0, 18:11, 19:10, 20:4, 21:0, 22:11, 23:10, 24:4, 25:0}
        self.leader = 0
        self.leaders = [0]
        self.begin_state = self.unplayed_cards
        self.cards_played = []
        self.big_points = 0
        self.small_points = 0
        self.big_win = False

    def _suit_of(self, card: int) -> int:
        """Suit index of a card: 0 trump, 1 clubs, 2 spades, 3 hearts."""
        return int(jnp.searchsorted(self.suit_splits, card, side="right")) - 1

    @staticmethod
    def _trick_winner(cards_played, suits_played, suit_led: int) -> int:
        """Seat index that wins a completed trick.

        Only trumps and cards of the led suit can win. Among those, the lowest
        card index wins, because every trump index (0-13) is below every
        non-trump index and each suit is ordered strongest-first.
        """
        contenders = [seat for seat, suit in enumerate(suits_played)
                      if suit == 0 or suit == suit_led]
        return min(contenders, key=lambda seat: int(cards_played[seat]))

    def play_hand(self) -> None:
        """Play one trick with uniformly random legal cards and score it.

        The current leader may play any card in hand. Each following seat must
        follow the led suit when it can, otherwise it may play any card. The
        trick winner becomes the next leader. The trick's points go to
        big_points if the winner is the big player, else to small_points.
        """
        cards_played = [None] * self.N_PLAYERS
        suits_played = [None] * self.N_PLAYERS

        # Lead: any card still in the leader's hand.
        self.key, subkey = random.split(self.key)
        lead_hand = self.unplayed_cards[self.leader, :]
        lead_card = int(random.choice(subkey, jnp.where(lead_hand == 1)[0]))
        suit_led = self._suit_of(lead_card)
        cards_played[self.leader] = lead_card
        suits_played[self.leader] = suit_led
        self.unplayed_cards = self.unplayed_cards.at[self.leader, lead_card].set(0)

        start = int(self.suit_splits[suit_led])
        stop = int(self.suit_splits[suit_led + 1])
        for offset in range(1, self.N_PLAYERS):
            seat = (self.leader + offset) % self.N_PLAYERS
            hand = self.unplayed_cards[seat, :]
            # Follow suit if possible (indices shifted back to absolute card ids) ...
            playable = jnp.where(hand[start:stop] == 1)[0] + start
            if playable.shape[0] == 0:
                # ... otherwise any remaining card is legal.
                playable = jnp.where(hand == 1)[0]
            self.key, subkey = random.split(self.key)
            card = int(random.choice(subkey, playable))
            cards_played[seat] = card
            suits_played[seat] = self._suit_of(card)
            self.unplayed_cards = self.unplayed_cards.at[seat, card].set(0)

        winner = self._trick_winner(cards_played, suits_played, suit_led)
        trick_points = sum(self.points[card] for card in cards_played)
        self.leader = winner
        self.leaders.append(winner)
        self.cards_played.append(cards_played)
        if winner == int(self.big[0]):
            self.big_points += trick_points
        else:
            self.small_points += trick_points

    def play_game(self) -> None:
        """Play all tricks, credit the pocket cards to big and decide the outcome.

        Raises:
            AssertionError: if the awarded points do not add up to 120.
        """
        for _ in range(self.TRICKS_PER_GAME):
            self.play_hand()
        pocket_cards = jnp.where(self.begin_state[3, :])[0]
        self.big_points += sum(self.points[int(card)] for card in pocket_cards)
        assert self.big_points + self.small_points == self.TOTAL_POINTS
        self.big_win = self.big_points > self.WIN_THRESHOLD

    def write_game(self) -> None:
        """Print the deal, trick takers, cards played, big seat and outcome."""
        print(f"Initial State: \n{self.begin_state}")
        print(f"Takers: {jnp.array(self.leaders[1:])}")
        print(f"Cards Played: \n{jnp.array(self.cards_played)}")
        print(f"Big: {self.big}")
        print(f"Big win: {self.big_win}")
    

In [183]:
import h5py
import numpy as np

def save_game_to_hdf5(file: str, begin_state, leader, cards_played, big, big_win):
    """Append one game as a new row of each dataset in an HDF5 file.

    The file is opened in append mode. On first use each dataset is created
    with an unlimited axis 0; afterwards it is grown by exactly one row.

    Args:
        file: Path of the HDF5 file.
        begin_state: (4, 26) deal matrix, stored in 'begin_states'.
        leader: Scalar seat index, stored in 'leaders'.
        cards_played: (8, 3) cards per trick, stored in 'cards_played'.
        big: Big player's seat, as a scalar or a 1-element array; stored in 'big'.
        big_win: Whether big won, stored in 'big_win'.
    """
    # Each value is shaped as a single leading-axis row (dict order = creation order).
    rows = {
        'begin_states': np.expand_dims(begin_state, axis=0),
        'leaders': np.array([leader]),
        'cards_played': np.expand_dims(cards_played, axis=0),
        # Accept a plain scalar as well as the 1-element array zole produces.
        'big': np.asarray(big).reshape(1),
        'big_win': np.array([big_win]),
    }

    with h5py.File(file, 'a') as f:
        for name, row in rows.items():
            if name in f:
                dataset = f[name]
                dataset.resize(dataset.shape[0] + 1, axis=0)
                dataset[-1] = row[0]
            else:
                # Only axis 0 is unlimited; later rows must match this first row's shape.
                f.create_dataset(name, data=row, maxshape=(None,) + row.shape[1:], chunks=True)


In [199]:
# Simulate one game with an explicit time-based seed, show it, and append it to game.h5.
game = zole(seed=int(datetime.now().timestamp()))
game.play_game()
game.write_game()
# NOTE(review): only the final trick winner (game.leader) is stored under 'leaders';
# the per-trick takers shown by write_game (game.leaders[1:]) are not persisted.
save_game_to_hdf5(
    "game.h5",
    begin_state=game.begin_state,
    leader=game.leader,
    cards_played=game.cards_played,
    big=game.big,
    big_win=game.big_win,
)

Initial State: 
[[0 0 0 1 0 1 0 0 1 0 0 0 0 0 0 0 1 0 0 1 1 0 0 0 1 1]
 [1 1 0 0 0 0 0 1 0 0 0 0 1 0 1 1 0 0 1 0 0 0 0 1 0 0]
 [0 0 1 0 1 0 1 0 0 1 0 1 0 1 0 0 0 1 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0]]
Takers: [2 1 2 2 1 1 2 2]
Cards Played: 
[[ 8  7  4]
 [ 5  1  9]
 [ 3 12  2]
 [25 23 22]
 [19  0  6]
 [16 15 17]
 [20 18 11]
 [24 14 13]]
Big: [0]
Big win: False


In [200]:
# Pass the seed explicitly so every run of this cell deals a fresh game,
# independent of how zole's default seed is evaluated.
game = zole(int(datetime.now().timestamp()))
game.play_game()
game.write_game()
# NOTE(review): only the final trick winner (game.leader) is stored under 'leaders'.
save_game_to_hdf5("game.h5", game.begin_state, game.leader, game.cards_played, game.big, game.big_win)

Initial State: 
[[0 1 0 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 1 0 1 0 1 0]
 [1 0 0 1 0 0 0 0 0 0 0 0 1 0 0 1 0 1 0 1 0 1 0 1 0 0]
 [0 0 1 0 0 0 0 1 1 1 0 1 0 1 0 0 0 0 1 0 0 0 0 0 0 1]
 [0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
Takers: [2 0 2 2 1 2 1 2]
Cards Played: 
[[16 15  2]
 [ 1  3  7]
 [20 21 18]
 [10 12  9]
 [ 6  0  8]
 [14 17 11]
 [24 23 25]
 [22 19 13]]
Big: [1]
Big win: False


In [201]:
def load_game_from_hdf5(file: str):
    """Read every stored game from an HDF5 file written by save_game_to_hdf5.

    Returns:
        Tuple of numpy arrays (begin_states, leaders, cards_played, big, big_win),
        each with one row per saved game along axis 0.
    """
    dataset_names = ('begin_states', 'leaders', 'cards_played', 'big', 'big_win')
    with h5py.File(file, 'r') as f:
        # The tuple is fully materialised before the file is closed.
        return tuple(f[name][:] for name in dataset_names)

(begin_states, leaders, cards_played,
 big, big_win) = load_game_from_hdf5('game.h5')

print(begin_states.shape)

(5, 4, 26)
