# Building Q-Table for Choosing Cards

In [1]:
from card_choice_gym import CardChoiceEnv
import numpy as np
import random
import pandas as pd
import csv

In [2]:
def LearnJoker(q_in=np.zeros((4, 9, 3, 4, 19, 4)), alpha_in=0.01, epsilon_in=0.5, gamma_in=0.95, episodes_in=100):
  acts = ['STRG-BEAT', 'STRG-LOSS', 'WEAK-BEAT','WEAK-LOSS']
  env = CardChoiceEnv()

  alpha, gamma, epsilon = alpha_in, gamma_in, epsilon_in
  q = q_in
  
  wins = []
  good_calls = np.zeros((7,))

  for i in range(episodes_in): 
    done = False
    s = env.reset()
    s0, s1, s2, s3, s4 = s
    while True: 
      if np.random.random() < epsilon:
      # choose random action
        act_num = random.randint(0, 3)
      else:
        # greedy
        act_num = np.argmax(q[s0, s1, s2, s3, s4])
      
      action = acts[act_num]

      s_, r, done, _ = env.step(action)

      s_0, s_1, s_2, s_3, s_4 = s_
      td_target = r + gamma * np.argmax(q[s_0, s_1, s_2, s_3, s_4])
      td_error = td_target - q[s0, s1, s2, s3, s4, act_num]
      s = s_

      q[s0, s1, s2, s3, s4, act_num] += alpha * td_error
      if done:
        if r > 0:
          wins.append(i)
          good_calls = np.vstack((good_calls, [env.call_state]))

        break
  return wins, good_calls, q

In [3]:
eps = 100000
wins, calls, q = LearnJoker(epsilon_in=0.1, episodes_in=eps)
len(wins)

208

In [11]:
eps = 10000
wins, calls, q = LearnJoker(epsilon_in=0.01, episodes_in=eps, q_in = q)
len(calls)

1450

In [4]:
with open('../data/calls.csv', 'a+') as out:
    csv_out = csv.writer(out)
    # csv_out.writerow(['call','order', 'already', 'jokers', "aces", "kings", "queens"])
    for row in calls:
        csv_out.writerow(row)

In [9]:
df = pd.read_csv('../data/calls.csv', dtype=int)
df

Unnamed: 0,call,order,already,jokers,aces,kings,queens
0,0,0,0,0,0,0,0
1,2,3,6,0,0,3,0
2,1,2,4,1,1,0,2
3,1,3,7,0,1,1,2
4,1,2,5,1,0,1,1
...,...,...,...,...,...,...,...
16551,1,1,5,0,2,0,1
16552,2,3,6,1,1,1,2
16553,2,0,2,0,2,1,1
16554,0,3,6,0,1,0,0


In [10]:
df[df.call == 9].queens.describe()

count    0.0
mean     NaN
std      NaN
min      NaN
25%      NaN
50%      NaN
75%      NaN
max      NaN
Name: queens, dtype: float64

In [10]:
def save_q_table(table):
    file = '../models/q-table.npy'
    with open(file, "wb"):
        np.save(file, table, allow_pickle=True)

In [11]:
save_q_table(q)