In [6]:
import numpy as np
import random

class Node:
	def __init__(self, num_actions):
		self.regret_sum = np.zeros(num_actions)
		self.strategy = np.zeros(num_actions)
		self.strategy_sum = np.zeros(num_actions)
		self.num_actions = num_actions
		self.number = 1

	def get_strategy(self):
		normalizing_sum = 0
		for a in range(self.num_actions):
			if self.regret_sum[a] > 0:
				self.strategy[a] = self.regret_sum[a]
			else:
				self.strategy[a] = 0
			normalizing_sum += self.strategy[a]

		for a in range(self.num_actions):
			if normalizing_sum > 0:
				self.strategy[a] /= normalizing_sum
			else:
				self.strategy[a] = 1.0/self.num_actions

		return self.strategy

	def get_average_strategy(self):
		avg_strategy = np.zeros(self.num_actions)
		normalizing_sum = 0
		
		for a in range(self.num_actions):
			normalizing_sum += self.strategy_sum[a]
		for a in range(self.num_actions):
			if normalizing_sum > 0:
				avg_strategy[a] = self.strategy_sum[a] / normalizing_sum
			else:
				avg_strategy[a] = 1.0 / self.num_actions
		
		return avg_strategy



In [9]:
class KuhnCFR:
	def __init__(self, iterations, decksize):
		self.nbets = 2
		self.iterations = iterations
		self.decksize = decksize
		self.cards = np.arange(decksize)
		self.bet_options = 2 #actions
		self.nodes = {}

	def cfr_iterations_external(self):
		util = np.zeros(2)
		for t in range(1, self.iterations + 1): 
			for i in range(2): # Players
				random.shuffle(self.cards)
				util[i] += self.external_cfr(self.cards[:2], [], 2, 0, i, t)

		print('Average game value: {}'.format(util[0]/(self.iterations)))
		for i in sorted(self.nodes):
			print(i, self.nodes[i].get_average_strategy())

	def external_cfr(self, cards, history, pot, nodes_touched, traversing_player, t):
		print('THIS IS ITERATION', t)


		# who is the acting player and who is the opponent
		plays = len(history)
		acting_player = plays % 2
		opponent_player = 1 - acting_player

		# History is in a terminal state then calculate payments
		if plays >= 2:
			if history[-1] == 0 and history[-2] == 1: #bet fold
				if acting_player == traversing_player:
					return 1
				else:
					return -1
			if (history[-1] == 0 and history[-2] == 0) or (history[-1] == 1 and history[-2] == 1): #check check or bet call, go to showdown
				if acting_player == traversing_player:
					if cards[acting_player] > cards[opponent_player]:
						return 1 #profit
					else:
						return -1
				else:
					if cards[acting_player] > cards[opponent_player]:
						return -1
					else:
						return 1


		infoset = str(cards[acting_player]) + str(history) # infoset are card acting player can see and history
		if infoset not in self.nodes:
			self.nodes[infoset] = Node(self.bet_options)
		else: 
			self.nodes[infoset].number += 1

		nodes_touched += 1

		# Here is where self play is done
		if acting_player == traversing_player:
			util = np.zeros(self.bet_options) #2 actions
			node_util = 0
			strategy = self.nodes[infoset].get_strategy()
			for a in range(self.bet_options):
				next_history = history + [a]
				pot += a
				util[a] = self.external_cfr(cards, next_history, pot, nodes_touched, traversing_player, t)
				node_util += strategy[a] * util[a]

			for a in range(self.bet_options):
				regret = util[a] - node_util
				self.nodes[infoset].regret_sum[a] += regret
			return node_util

		else: #acting_player != traversing_player sample strategy
			strategy = self.nodes[infoset].get_strategy()
			util = 0

			if random.random() < strategy[0]:
				next_history = history + [0]
			else: 
				next_history = history + [1]
				pot += 1
				
			util = self.external_cfr(cards, next_history, pot, nodes_touched, traversing_player, t)

			for a in range(self.bet_options):
				self.nodes[infoset].strategy_sum[a] += strategy[a]
			return util




In [10]:
if __name__ == "__main__":
	k = KuhnCFR( 10000, 10)
	k.cfr_iterations_external()

THIS IS ITERATION 1
THIS IS ITERATION 1
THIS IS ITERATION 1
THIS IS ITERATION 1
THIS IS ITERATION 1
THIS IS ITERATION 1
THIS IS ITERATION 1
THIS IS ITERATION 1
THIS IS ITERATION 1
THIS IS ITERATION 1
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 2
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 3
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 4
THIS IS ITERATION 5
THIS IS ITERATION 5
THIS IS ITERATION 5
THIS IS ITERATION 5


In [11]:
for i in k.nodes:
    print(i,k.nodes[i].get_average_strategy())

4[] [0.00353535 0.99646465]
8[0] [0.5 0.5]
8[1] [0.00253807 0.99746193]
7[] [4.99001996e-04 9.99500998e-01]
4[0] [0. 1.]
7[0, 1] [0.5 0.5]
5[] [0. 1.]
1[0] [4.84027106e-04 9.99515973e-01]
5[0, 1] [0.5 0.5]
1[1] [0.00822846 0.99177154]
1[] [9.60614793e-04 9.99039385e-01]
1[0, 1] [0.5 0.5]
2[] [0.00306435 0.99693565]
5[0] [0.5 0.5]
2[0, 1] [0.5 0.5]
5[1] [0.00136612 0.99863388]
3[] [0.00102145 0.99897855]
7[0] [0.5 0.5]
3[0, 1] [0.5 0.5]
0[] [0. 1.]
3[0] [0.5 0.5]
0[0, 1] [0.5 0.5]
3[1] [0.0020429 0.9979571]
9[] [0.00147348 0.99852652]
9[0, 1] [0.5 0.5]
8[] [0. 1.]
8[0, 1] [0.5 0.5]
9[0] [0.5 0.5]
9[1] [0.00156413 0.99843587]
7[1] [5.000e-04 9.995e-01]
4[0, 1] [0.5 0.5]
6[] [0. 1.]
0[0] [0.5 0.5]
6[0, 1] [0.5 0.5]
0[1] [0.5 0.5]
4[1] [5.04032258e-04 9.99495968e-01]
2[1] [0.00257998 0.99742002]
2[0] [0.5 0.5]
6[0] [0.5 0.5]
6[1] [0.00101215 0.99898785]


In [5]:
for i in k.nodes:
    print(i,k.nodes[i].number)

7[] 1990
8[0] 989
8[1] 2017
3[] 2015
7[1] 1969
0[] 1965
1[0] 1024
1[1] 1997
1[] 1985
9[0] 995
9[1] 2007
8[] 2016
6[0] 1014
8[0, 1] 705
3[0, 1] 727
0[0] 1014
0[1] 2013
5[] 2020
4[] 2016
4[0, 1] 775
6[1] 2026
2[] 1981
9[] 1967
4[0] 995
7[0, 1] 699
4[1] 1929
6[] 2045
7[0] 1007
6[0, 1] 747
2[0] 999
5[0, 1] 723
2[0, 1] 650
3[0] 989
3[1] 2039
2[1] 2003
5[1] 1985
5[0] 989
9[0, 1] 692
0[0, 1] 653
1[0, 1] 656
