In [73]:
import gym
from gym import spaces
import numpy as np
import pandas as pd
from props import Card, Deck
from more_itertools import quantify
import datetime
import random

class NineEnv(gym.Env):
    metadata = {"render_modes": ["human"], "render_fps": 2}

    def __init__(self):
        self.observation_space = spaces.Tuple((
            spaces.Discrete(4),
            spaces.Discrete(9),
            spaces.Discrete(3),
            spaces.Discrete(4),
            spaces.Discrete(18),
        ))
        self.action_space = spaces.Discrete(4)

    def reset(self, seed=None):
        # We need the following line to seed self.np_random
        # super().reset()

        self.played = np.ndarray((4), dtype=Card)
        self.first_to_play = random.randint(0, 3)

        self._ord = (4 - self.first_to_play) % 4
        self._tot = 9

        self._set_players_cards() # this has to set self._jok
        self._set_calls() # this has to set self._tmc

        self.first_suit = None

        if self.first_to_play != 0: 
            self._pre_plays()
        else:
            self._btl = 1
        observation = self._get_obs()

        return observation

    def step(self, action): 
        # Map the action (element of {0,1,2,3}) to the card we play
        self._act(action)
        self._post_plays()
        self._set_players_cards() 
        self._set_calls()
        self._tot -= 1

        if not self.hand_winner:
            self._tmc += 1
        
        self.first_to_play = self.hand_winner
        self._ord = (4 - self.first_to_play) % 4

        self.played = None
        self._pre_plays()
        done = self._tot == 0
        reward = 1 if self._tmc == 0 and done else 0   # No negative rewards
        observation = self._get_obs()
        
        # No info
        return observation, reward, done, None

    def _act(self, action):
        playable = self._playable(0)
        if action == 'STRG-BEAT':
            card = self._choose_strg_beat(playable)
        elif action == 'WEAK-BEAT':
            card = self._choose_weak_beat(playable)
        elif action == 'STRG-LOSS':
            card = self._choose_strg_loss(playable)
        elif action == 'WEAK-LOSS':
            card = self._choose_weak_loss(playable)
        else:
            return None

        if self._ord == 0:
            self.first_suit = card.value
            self.played = np.ndarray((4), dtype=Card)
            self.played[0] = card
        elif self._ord == 3: 
            self.first_suit = None
            self.played[0] = card
            self.hand_winner = self._winner()
        else:
            self.played[0] = card
            

        if card.value == 13:
            for card in self._table[0]:
                if card.value == 13:
                    self._table[0].remove(card)
                    break
        else:
            self.remove_from_table(0, card.suit, card.value)
        return card

    def _get_highest(self, cards ,suit=None):
        '''Returns a joker if there is one in cards.
            Otherwise returns the highest valued card of the specified suit'''
        if not suit:
            return cards[np.argmax([card.value for card in cards])]
        else:
            acc = None;[acc := Card(x,y) for (x, y) in map(lambda card: (card.value, card.suit), cards) if (not acc) or (suit == y and x > acc.value) or x == 13]
            return acc if acc else None

    def _get_obs(self):
        return (self._ord, self._tot, self._jok, self._btl, self._tmc)

    def _set_players_cards(self):
        self._table = [list(), list(), list(), list()]
        self._deck = Deck()
        self._deck.shuffle()
        for hand in self._table:
            self._deck.deal(hand, times=9)
        self._jok = quantify(map(lambda x: x.value == 13, self._table[0]))

    def _set_calls(self):
        self.calls = np.zeros((4))
        for i in range(4):
            cur_player = (self.first_to_play + i) % 4
            want = quantify(map(lambda x: x.value > 11, self._table[cur_player]))
            if i < 3 or want == 9 - self.calls.sum():
                self.calls[cur_player] = want
            else:
                if want > 0:
                    self.calls[cur_player] = want
                else:
                    self.calls[cur_player] = 0
        self._tmc = -1 * self.calls[0]

    def _playable(self, player):
        all = self._table[player]
        firsts = list(filter(lambda x: x.suit == self.first_suit, all))
        if not firsts:
            return all
        else:
            firsts.extend(list(filter(lambda x: x.value == 13, all)))
            return firsts
    
    def _choose_strg_beat(self, playable): 
        highest = self._get_highest(playable, self.first_suit) 
        return highest if highest else self._get_highest(playable) 

    def _choose_weak_beat(self, playable): 
        beats = self._get_beats(playable) 
        if not beats:
            return self._choose_weak_loss(playable)
        lowest = Card(13, 0)
        for card in beats:
            if card.value < lowest.value:
                lowest = card
        return lowest

    def _choose_strg_loss(self, playable): 
        loses = self._get_loses(playable)
        if not loses:
            return self._choose_strg_beat(playable)
        else:
            highest = self._get_highest(playable, suit=self.first_suit)
            return highest if highest else self._get_highest(playable)
    
    def _choose_weak_loss(self, playable):
        acc = Card(13, 0); [acc := Card(card.value, card.suit) for card in playable if card.value < acc.value]
        return acc

    def _get_winning_card(self):
        if not self.played.any():
            return None
        elif quantify(map(lambda x: x and x.value == 13, self.played)): # No shadow jokers
            return Card(13,0)
        else:
            firsts = list(filter(lambda x: x and x.suit == self.first_suit, self.played))
            firsts[np.argmax(card.value for card in firsts)]

    def _get_beats(self, playable):
        cur_winner = self._get_winning_card() # returning None
        if self.first_suit == None or not cur_winner:
            return playable
        else:
            if cur_winner.value == 13:
                return list(filter(lambda x: x.value == 13, playable))
            else:
                return list(filter(lambda x: x and x.value > cur_winner.value, playable))
    
    def _get_loses(self, playable):
        if self.first_suit == None:
            return playable
        loses = []
        for card in playable:
            if card.suit == self.first_suit and card.value < max([y.value for y in list(filter(lambda x : x and x.suit == self.first_suit, playable))]):
                loses.append(card)
        return loses

    def _card_to_weight(self, card):
        weight = 0
        if card.value == 13:
            return 200
        if self.first_suit == card.suit:
            weight += 100
        return weight + card.value  

    def _winner(self):
        jok = 5
        joks = 0
        for i in range(4):
            cur = (self.first_to_play + i) % 4
            
            if self.played[cur].value == 13:
                jok = cur
                joks += 1
        if joks:
            return jok
        else:
            weights = []
            for i in range(4):
                weights.append(self._card_to_weight(self.played[i]))
            return np.argmax(weights)

    def remove_from_table(self, player, suit, value):
        for card in self._table[player]:
            if (card.suit, card.value) == (suit, value):
                self._table[player].remove(card)
                break

    def _play_rand(self, player):
        poss = self._playable(player)
        choice = poss[random.randint(0, len(poss) - 1)]
        self.remove_from_table(player,choice.value, choice.suit)
        self.played[player] = choice
        return choice

    def _pre_plays(self):
        self.played = np.ndarray((4), dtype=Card)
        for i in range(self._ord):
            cur = (self.first_to_play + i) % 4
            temp = self._play_rand(cur)
            self.played[cur] = temp
            if i == 0 and self.played[cur].value != 13:
                self.first_suit = self.played[cur].suit
        self._btl = 1 if self._get_beats(self._playable(0)) else 0

    def _post_plays(self):
        for i in range(1, 4 - self._ord):
            cur = i % 4
            self.played[cur] = self._play_rand(cur)

        self.hand_winner = self._winner()
        self.first_suit = None
        self.played = None
        self._set_players_cards()
        self._set_calls() 

In [82]:
def LearnJoker(q_in=np.ndarray((4, 9, 3, 4, 18, 4))):
  acts = ['STRG-BEAT', 'STRG-LOSS', 'WEAK-BEAT','WEAK-LOSS']
  env = NineEnv()

  alpha = 0.01
  gamma = .93

  epsilon = 1
  q = q_in

  wins = []

  for i in range(100000): 
    if not i % 5000:
      epsilon -= 0.05
    done = False
    s = env.reset()
    s0, s1, s2, s3, s4 = s
    s1 -= 1
    s4 += 8
    s4 = int(s4)
    while True: 
      if np.random.random() < epsilon:
      # choose random action
        act_num = random.randint(0, 3)
      else:
        # greedy
        act_num = np.argmax(q[s0, s1, s2, s3, s4])
      
      action = acts[act_num]

      s_, r, done, _ = env.step(action)

      s_0, s_1, s_2, s_3, s_4 = s_
      s_1 -= 1
      s_4 += 8
      s_4 = int(s_4)
    
      td_target = r + gamma * np.argmax(q[s_0, s_1, s_2, s_3, s_4])
      td_error = td_target - q[s0, s1, s2, s3, s4, act_num]
      s = s_

      q[s0, s1, s2, s3, s4, act_num] += alpha * td_error
      if done:
        if r > 0:
          wins.append(i)
        break
  return wins, q

In [83]:
pct, q = LearnJoker()
pct

[0,
 2,
 3,
 6,
 15,
 17,
 18,
 19,
 22,
 28,
 39,
 42,
 43,
 44,
 46,
 48,
 50,
 52,
 54,
 57,
 58,
 61,
 66,
 70,
 72,
 77,
 79,
 84,
 87,
 89,
 99,
 102,
 103,
 104,
 107,
 114,
 115,
 118,
 120,
 121,
 132,
 148,
 158,
 164,
 172,
 174,
 179,
 185,
 186,
 193,
 195,
 198,
 199,
 208,
 212,
 215,
 216,
 220,
 226,
 238,
 241,
 242,
 247,
 248,
 255,
 256,
 257,
 259,
 264,
 266,
 273,
 277,
 280,
 282,
 286,
 297,
 303,
 307,
 308,
 311,
 316,
 318,
 319,
 320,
 321,
 322,
 324,
 325,
 329,
 333,
 334,
 338,
 339,
 340,
 343,
 345,
 349,
 357,
 362,
 363,
 369,
 382,
 387,
 390,
 394,
 396,
 398,
 414,
 418,
 425,
 427,
 432,
 439,
 440,
 441,
 448,
 451,
 452,
 456,
 469,
 470,
 479,
 480,
 482,
 483,
 487,
 497,
 498,
 499,
 500,
 507,
 511,
 515,
 523,
 553,
 559,
 563,
 564,
 568,
 575,
 580,
 581,
 596,
 602,
 608,
 611,
 616,
 630,
 633,
 635,
 641,
 643,
 647,
 651,
 652,
 653,
 655,
 657,
 658,
 672,
 678,
 682,
 683,
 700,
 705,
 706,
 710,
 713,
 714,
 715,
 716,
 719,
 72

In [84]:
for i in range(100): print(quantify([x < 1000 * i for x in pct]))

0
229
470
689
916
1125
1353
1593
1830
2047
2271
2490
2719
2935
3180
3405
3615
3856
4065
4280
4520
4715
4936
5177
5399
5619
5863
6085
6319
6532
6734
6988
7244
7481
7711
7923
8165
8371
8593
8801
9017
9230
9447
9649
9859
10071
10267
10475
10701
10908
11136
11359
11590
11822
12047
12286
12497
12738
12977
13218
13439
13665
13887
14125
14332
14537
14741
14934
15145
15350
15553
15743
15951
16160
16372
16593
16786
16992
17206
17399
17597
17797
18009
18225
18442
18646
18854
19061
19278
19470
19661
19845
20042
20222
20402
20588
20765
20949
21139
21319
