### CFR (Counterfactual Regret Minimization) -- Beer-Quiche Game

In [1]:
import numpy as np
from typing import List, Dict
import random
import sys

In [8]:
# Action labels for each player. An action's position in its list is the index
# used for that action in the strategy and regret arrays.
# Player 1's moves -- presumably Beer / Quiche, going by the game's name.
p1_actions = ['B', 'Q']
# Player 2's responses; the meaning of the letters is not spelled out in this file.
p2_actions = ['b', 'd']

class information_set():
    """Regret-matching state for a single decision point (information set).

    Tracks the cumulative regret of every action and the reach-weighted sum of
    all strategies handed out so far; the latter yields the average strategy.
    """

    def __init__(self, num_actions=2):
        """Create an information set with all accumulators at zero.

        Args:
            num_actions: Number of actions available at this decision point.
                Defaults to 2, the action count of both players in this game.
        """
        self.num_actions = num_actions
        self.cumulative_regrets = np.zeros(shape=num_actions)
        self.strategy_sum = np.zeros(shape=num_actions)

    def normalize(self, strategy):
        """Return ``strategy`` rescaled so that its entries sum to one.

        The input is never modified: a new float array is returned. (An
        in-place ``/=`` would overwrite the caller's array and fails outright
        on integer arrays or plain lists.)

        Args:
            strategy: Sequence of non-negative action weights.

        Returns:
            A new array of probabilities. When the weights do not sum to a
            positive number, the uniform distribution over ``num_actions``
            actions is returned instead.
        """
        total = np.sum(strategy)
        if total > 0:
            return np.asarray(strategy, dtype=float) / total
        return np.ones(self.num_actions) / self.num_actions

    def get_strategy(self, reach_probability):
        """Return the regret-matching strategy and record it for averaging.

        Action probabilities are proportional to the positive part of the
        cumulative regrets (uniform when no regret is positive).

        Args:
            reach_probability: Weight applied to this strategy when it is
                added to ``strategy_sum``.

        Returns:
            Array of action probabilities for the current iteration.
        """
        positive_regrets = np.maximum(self.cumulative_regrets, 0)
        strategy = self.normalize(positive_regrets)

        # Side effect: accumulate the weighted strategy so that
        # get_average_strategy() can later report the time-averaged policy.
        self.strategy_sum += reach_probability * strategy

        return strategy

    def get_average_strategy(self):
        """Return the normalized accumulated strategy as a new array.

        ``strategy_sum`` itself is left unchanged.
        """
        return self.normalize(self.strategy_sum)

In [45]:
class BeerQuiche():
    """Static rules of the game: which histories end it and what they pay.

    A history is the string of action letters played so far; the game is over
    once it is one of the four two-letter combinations listed below.
    """

    # Every complete play: one upper-case letter followed by one lower-case letter.
    _TERMINAL_HISTORIES = ('Bb', 'Bd', 'Qb', 'Qd')

    # Payoff tables per player-1 type. Histories missing from a table
    # (notably 'Qd') receive the default passed to .get() in get_payoff.
    _TOUGH_PAYOFFS = {'Bb': 2, 'Bd': 1, 'Qb': 1}
    _WEAK_PAYOFFS = {'Bb': -2, 'Bd': 0, 'Qb': -1}

    @staticmethod
    def is_terminal(history):
        """Return True when ``history`` is a complete play of the game."""
        return history in BeerQuiche._TERMINAL_HISTORIES

    @staticmethod
    def get_payoff(history, p1_type):
        """Return the integer payoff of ``history`` given player 1's type.

        Args:
            history: Action string, e.g. ``'Bd'``.
            p1_type: ``'T'`` selects the tough table; any other value selects
                the weak table.

        Returns:
            The table entry for ``history``; a history not listed in the
            selected table yields 0 for the tough type and 2 otherwise.
            NOTE(review): presumably this is player 1's payoff -- confirm.
        """
        if p1_type == 'T':
            return BeerQuiche._TOUGH_PAYOFFS.get(history, 0)
        return BeerQuiche._WEAK_PAYOFFS.get(history, 2)

In [48]:
class cfr_trainer():
    """CFR trainer for the Beer-Quiche game with a sampled player-1 type.

    Information sets are created lazily and stored in ``infoset_map``.
    Player 1's sets are keyed by type plus history (``'T'`` / ``'W'`` at the
    root); player 2's sets are keyed by the history alone (``'B'`` / ``'Q'``),
    so player 2 never conditions on player 1's type.
    """

    def __init__(self):
        # Maps an information-set key to its regret / strategy accumulators.
        self.infoset_map: Dict[str, information_set] = {}

    def get_information_set(self, type_and_history):
        """Return the information set stored under a key, creating it if absent.

        Args:
            type_and_history: Key identifying the information set.

        Returns:
            The ``information_set`` instance registered under that key.
        """
        if type_and_history not in self.infoset_map:
            self.infoset_map[type_and_history] = information_set()

        return self.infoset_map[type_and_history]

    def cfr(self, p1_type, history, reach_probabilities, active_player):
        """Run one recursive CFR pass from ``history`` and update regrets.

        Args:
            p1_type: Player 1's type for this iteration, ``'T'`` or ``'W'``.
            history: String of actions played so far.
            reach_probabilities: Length-2 array; entry ``i`` is the product of
                player ``i``'s own action probabilities along ``history``.
            active_player: Index (0 or 1) of the player to act.

        Returns:
            The payoff from ``BeerQuiche.get_payoff`` at a terminal history;
            otherwise the strategy-weighted value of this node for
            ``active_player``, obtained by negating the children's values.
        """
        if BeerQuiche.is_terminal(history):
            return BeerQuiche.get_payoff(history, p1_type)

        # Player 0 sees its own type; player 1 only sees the actions played.
        if active_player == 0:
            info_set = self.get_information_set(p1_type + history)
            actions = p1_actions
        else:
            info_set = self.get_information_set(history)
            actions = p2_actions

        # Also adds the current strategy to the info set's running average.
        strategy = info_set.get_strategy(reach_probabilities[active_player])

        op = (active_player + 1) % 2
        counterfactual_values = np.zeros(len(actions))
        for ix, action in enumerate(actions):
            new_reach_probabilities = reach_probabilities.copy()
            new_reach_probabilities[active_player] *= strategy[ix]

            # The child's value is from the opponent's point of view, so flip
            # the sign to express it for the player acting here.
            counterfactual_values[ix] = -self.cfr(
                p1_type, history + action, new_reach_probabilities, op
            )

        node_value = counterfactual_values.dot(strategy)

        # Regret of each action relative to the node value, weighted by the
        # opponent's probability of reaching this node.
        info_set.cumulative_regrets += reach_probabilities[op] * (
            counterfactual_values - node_value
        )

        return node_value

    def train(self, num_iterations, tough_probability=1/3):
        """Run CFR for a number of iterations, sampling player 1's type each time.

        Args:
            num_iterations: Number of sampled iterations to run.
            tough_probability: Probability that player 1 is of type ``'T'``
                in an iteration; the type is ``'W'`` otherwise. Defaults to
                1/3.

        Returns:
            The sum (not the average) of the root values returned by ``cfr``
            over all iterations.

        Raises:
            ValueError: If ``tough_probability`` lies outside ``[0, 1]``.
        """
        if not 0 <= tough_probability <= 1:
            raise ValueError(
                f"tough_probability must be in [0, 1], got {tough_probability}"
            )

        util = 0
        types = [1, 2]
        type_weights = [tough_probability, 1 - tough_probability]
        for _ in range(num_iterations):
            p1_t = np.random.choice(types, p=type_weights)
            p1_type = 'T' if p1_t == 1 else 'W'
            # Each iteration starts at the empty history with both players'
            # reach probabilities at one and player 0 to act.
            util += self.cfr(p1_type, '', np.ones(2), 0)

        return util

In [49]:
# Train a fresh trainer for one million iterations; `util` keeps the value
# returned by train().
num_itrs = 1_000_000
trainer = cfr_trainer()
util = trainer.train(num_itrs)

In [50]:
# Report the average strategy of every information set, shortest key first
# (the sort is stable, so equal-length keys keep their insertion order).
for name in sorted(trainer.infoset_map, key=len):
    info_set = trainer.infoset_map[name]
    print(f"{name}: {info_set.get_average_strategy()}")

W: [0.24937692 0.75062308]
B: [0.50115336 0.49884664]
Q: [9.999995e-01 5.000000e-07]
T: [9.99998499e-01 1.50133018e-06]
