<a href="https://colab.research.google.com/github/cuixianze/reinforce-learning/blob/main/kuhn_poker_CFR_itr_ipynb%EC%9D%98_%EC%82%AC%EB%B3%B8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
from collections import defaultdict

class KuhnPokerCFR:
    """Chance-sampled counterfactual regret minimization for Kuhn poker.

    Each training iteration shuffles the three cards, gives the first card to
    player 0 and the second to player 1, and walks the whole betting tree for
    that deal while accumulating regrets and strategy weights.

    Cards are compared with the ordinary string ``>`` operator, so with the
    labels ``'A' < 'B' < 'C'`` the card ``'C'`` is the strongest one.

    Information sets are keyed as ``"<card>:<history>"`` where the history is
    the string of actions taken so far (``'p'`` = pass/check/fold,
    ``'b'`` = bet/call).
    """

    # Histories that end by a fold, mapped to the player who wins the pot.
    # "bp":  player 0 bets, player 1 folds  -> player 0 wins.
    # "pbp": player 0 checks, player 1 bets, player 0 folds -> player 1 wins.
    _FOLD_WINNER = {"bp": 0, "pbp": 1}

    # Histories that end in a showdown, mapped to the amount won or lost.
    _SHOWDOWN_STAKE = {"pp": 1, "bb": 2, "pbb": 2}

    _TERMINAL_HISTORIES = ("pp", "pbp", "pbb", "bb", "bp")

    def __init__(self):
        self.cards = ['A', 'B', 'C']
        # Cumulative counterfactual regret per info set, indexed [pass, bet].
        self.regret_sum = defaultdict(lambda: [0.0, 0.0])
        # Reach-weighted sum of the strategies played, indexed [pass, bet].
        self.strategy_sum = defaultdict(lambda: [0.0, 0.0])
        self.actions = ['p', 'b']  # p = pass/check, b = bet

    def get_strategy(self, info_set, realization_weight):
        """Return the regret-matching strategy for ``info_set``.

        Actions are played proportionally to their positive cumulative
        regret; when no action has positive regret the strategy is uniform.

        Side effect: adds ``realization_weight * strategy`` to
        ``self.strategy_sum[info_set]`` so the average strategy can be
        recovered later.

        Args:
            info_set: Information-set key (``"<card>:<history>"``).
            realization_weight: Reach probability of the acting player.

        Returns:
            A two-element list ``[P(pass), P(bet)]``.
        """
        positive_regrets = [max(r, 0.0) for r in self.regret_sum[info_set]]
        normalizing_sum = sum(positive_regrets)

        if normalizing_sum > 0:
            strategy = [r / normalizing_sum for r in positive_regrets]
        else:
            strategy = [0.5, 0.5]

        # Accumulate the average-strategy numerator.
        totals = self.strategy_sum[info_set]
        for i, prob in enumerate(strategy):
            totals[i] += realization_weight * prob

        return strategy

    def get_action(self, strategy):
        """Sample an action index (0 = pass, 1 = bet) from ``strategy``."""
        r = random.random()
        return 0 if r < strategy[0] else 1

    def cfr(self, history, cards, p0, p1, player):
        """Run one CFR traversal below ``history`` and update regrets.

        Args:
            history: Actions taken so far, e.g. ``""``, ``"p"``, ``"pb"``.
            cards: Dealt cards; ``cards[0]`` is player 0's, ``cards[1]`` is
                player 1's.
            p0: Reach probability contributed by player 0.
            p1: Reach probability contributed by player 1.
            player: Index (0 or 1) of the player to act at this node.

        Returns:
            The expected utility of this node from ``player``'s perspective.
        """
        if self.is_terminal(history):
            return self.get_payoff(history, cards, player)

        info_set = cards[player] + ":" + history
        if player == 0:
            own_reach, opponent_reach = p0, p1
        else:
            own_reach, opponent_reach = p1, p0

        strategy = self.get_strategy(info_set, own_reach)
        util = [0.0, 0.0]
        node_util = 0.0

        for a, action in enumerate(self.actions):
            next_history = history + action
            # The child value is from the opponent's perspective, so negate.
            if player == 0:
                util[a] = -self.cfr(next_history, cards, p0 * strategy[a], p1, 1)
            else:
                util[a] = -self.cfr(next_history, cards, p0, p1 * strategy[a], 0)
            node_util += strategy[a] * util[a]

        # Counterfactual regret is weighted by the opponent's reach only.
        regrets = self.regret_sum[info_set]
        for a in range(2):
            regrets[a] += opponent_reach * (util[a] - node_util)

        return node_util

    def is_terminal(self, history):
        """Return True when ``history`` ends the hand."""
        return history in self._TERMINAL_HISTORIES

    def get_payoff(self, history, cards, player):
        """Return the payoff of a terminal ``history`` for ``player``.

        Fold histories (``"bp"``, ``"pbp"``) award one chip to the bettor
        regardless of the cards. Showdown histories compare the two cards:
        one chip is at stake after ``"pp"`` and two after ``"bb"``/``"pbb"``.
        Non-terminal histories yield 0.
        """
        if history in self._FOLD_WINNER:
            # A fold never reaches showdown: the bettor takes the ante.
            return 1 if player == self._FOLD_WINNER[history] else -1

        stake = self._SHOWDOWN_STAKE.get(history)
        if stake is None:
            return 0

        opponent = 1 - player
        return stake if cards[player] > cards[opponent] else -stake

    def train(self, iterations):
        """Run ``iterations`` CFR traversals, each on a freshly shuffled deal."""
        for _ in range(iterations):
            cards = self.cards[:]
            random.shuffle(cards)
            self.cfr("", cards, 1.0, 1.0, 0)

    def get_average_strategy(self):
        """Return ``{info_set: [P(pass), P(bet)]}`` averaged over training.

        Info sets whose accumulated weight is zero fall back to uniform.
        """
        average_strategy = {}
        for info_set, totals in self.strategy_sum.items():
            normalizing_sum = sum(totals)
            if normalizing_sum > 0:
                average_strategy[info_set] = [x / normalizing_sum for x in totals]
            else:
                average_strategy[info_set] = [0.5, 0.5]
        return average_strategy

# Usage example: train the solver, then print the average strategy of every
# information set in sorted key order.
cfr = KuhnPokerCFR()
cfr.train(5000000)
average_strategy = cfr.get_average_strategy()
for info_set in sorted(average_strategy):
    strategy = average_strategy[info_set]
    print(f"{info_set}: Pass={strategy[0]:.2f}, Bet={strategy[1]:.2f}")



A:: Pass=0.67, Bet=0.33
A:b: Pass=1.00, Bet=0.00
A:p: Pass=1.00, Bet=0.00
A:pb: Pass=1.00, Bet=0.00
B:: Pass=1.00, Bet=0.00
B:b: Pass=0.67, Bet=0.33
B:p: Pass=1.00, Bet=0.00
B:pb: Pass=1.00, Bet=0.00
C:: Pass=0.00, Bet=1.00
C:b: Pass=0.00, Bet=1.00
C:p: Pass=0.00, Bet=1.00
C:pb: Pass=0.03, Bet=0.97


In [None]:

# Dump the cumulative regrets [pass, bet] of every information set, by key.
for info_set, regrets in sorted(cfr.regret_sum.items()):
    print((info_set, regrets))

('A:', [-36.757123933155306, 541.6200040579884])
('A:b', [0.25, -832581.75])
('A:p', [0.125, -5.375])
('A:pb', [0.25, -833765.9668498458])
('B:', [1.0, -833121.25])
('B:b', [540.0780066497488, -1549.492268943936])
('B:p', [1.04071918343522, -2.044592621447986])
('B:pb', [0.25, -833123.25])
('C:', [-277449.53823086387, 1.340345538370328])
('C:b', [-833880.3005133941, 0.75])
('C:p', [-0.125, 0.125])
('C:pb', [-8.369474974302193, 0.25])


In [None]:
# Dump the accumulated strategy weights [pass, bet] of every information set.
for info_set, weights in sorted(cfr.strategy_sum.items()):
    print((info_set, weights))



('A:', [1111243.826691325, 555414.1733086652])
('A:b', [1665757.5, 0.5])
('A:p', [1665757.5, 0.5])
('A:pb', [1111243.576691325, 0.25])
('B:', [1666292.625, 2.375])
('B:b', [1112746.6276579536, 554612.3723420489])
('B:p', [1667343.66367518, 15.336324820115394])
('B:pb', [1666292.125, 0.5])
('C:', [8.835311804883206, 1667038.164688195])
('C:b', [0.5, 1666882.5])
('C:p', [0.5, 1666882.5])
('C:pb', [0.25, 8.585311804883206])


In [None]:
# Sanity check of how the card labels compare as strings: they sort as
# A, B, C and "A" > "B" is False, so 'A' is the weakest label under `>`.
card_labels = ["A", "B", "C"]
print(sorted(card_labels))
print("A" > "B")

['A', 'B', 'C']
False
