In [6]:
import sys  
sys.path.insert(1, 'C:/Users/davand/OneDrive/Documents/Programming/MAAIF/pymdp/')

import numpy as np
import jax
import jax.numpy as jnp
import copy
from pymdp.multi_agent.kuhn_poker import KuhnPokerEnv
from pymdp.pdo_agents.agent_gradient import PDOAgentGradient
from pymdp.pdo_agents.full_policy import TabularSoftmaxPolicy
from typing import Tuple, Dict, List



class PDOKuhnPokerAgent(PDOAgentGradient):
    def __init__(self, A: np.ndarray, B: np.ndarray, learning_rate: float = 0.01, beta: float = 1.0, time_horizon: int = 2):
        self.learning_rate = learning_rate
        self.beta = beta
        self.time_horizon = time_horizon
        # self.env = KuhnPokerEnv()
        
        # Use the provided A and B matrices
        self.A = A
        self.B = B
        self.C = self.initialize_preference_dist()
        
        super().__init__(A=self.A, B=self.B, time_horizon=self.time_horizon, beta=self.beta)
        
        self.policy = self.initialize_policy()

    def initialize_preference_dist(self):
        # preference for agent is the reward, i.e., negative of the pot if they lost, and positive if they won
        # Define the preference matrix C
        # The shape should be (num_observations, num_observations, num_actions)
        num_observations = self.A.shape[0]
        num_actions = self.B.shape[-1]
    
        # list of two vectors: one of length num_cards, one of length num_actions, need third obs for pot observations - reward modality, may need extra level - Null level - predict 100% prob for this outcome until end of game


        # Initialize C with zeros
        C = np.zeros((num_observations, num_observations, num_actions))
        
        # Define preferences based on game outcomes
        # Winning states have positive preference, losing states have negative preference
        # The magnitude can be set to the pot size (1 for small pot, 2 for big pot)
        
        # Winning states (getting the pot)
        C[-2, :, :] = 1  # Small pot win
        C[-1, :, :] = 2  # Big pot win
        
        # Losing states (losing the pot)
        C[:, -2, :] = -1  # Small pot loss
        C[:, -1, :] = -2  # Big pot loss
        
        return C

    def initialize_policy(self):
        # Initialize the policy using TabularSoftmaxPolicy
        observation_sequences = self.generate_consistent_observation_seqs()
        return TabularSoftmaxPolicy(action_counts=self.num_controls, 
                                    observation_sequences=observation_sequences)

    def update_policy(self, observation: Dict, action: str, reward: float):
        # Update the policy using gradient descent
        grad = jax.grad(self.G)(self.policy)
        self.policy.table -= self.learning_rate * grad.table

    def select_action(self, observation: Dict) -> str:
        obs_seq = self.observation_to_sequence(observation)
        action_probs = self.policy.policy_for_observations(obs_seq)
        action = jax.random.choice(jax.random.PRNGKey(0), 2, p=action_probs)
        return 'check' if action == 0 else 'bet'

    def observation_to_sequence(self, observation: Dict) -> Tuple:
        # Convert the Kuhn Poker observation to a sequence compatible with the policy
        card = observation['card']
        history = observation['history']
        
        # Map card to integer (Jack: 0, Queen: 1, King: 2)
        card_mapping = {'J': 0, 'Q': 1, 'K': 2}
        card_int = card_mapping[card]
        
        # Convert history to a sequence of integers
        # 'c' for check: 0, 'b' for bet: 1
        history_sequence = tuple(0 if action == 'c' else 1 for action in history)
        
        # Combine card and history into a single sequence
        return (card_int,) + history_sequence


In [7]:
def play_pdo_kuhn_poker(num_episodes: int = 10000):
    env = KuhnPokerEnv()
    A, B = initialize_kuhn_poker_matrices()
    agent1 = PDOKuhnPokerAgent(A=A, B=B)
    agent2 = PDOKuhnPokerAgent(A=A, B=B)

    for episode in range(num_episodes):
        observation = env.reset()
        done = False
        
        while not done:
            if env.turn == 0 or env.turn == 2:
                action = agent1.select_action(observation)
            else:
                action = agent2.select_action(observation)
            
            next_observation, reward, done, _ = env.step(action)
            
            if done:
                if env.turn == 0:
                    agent1.update_policy(observation, action, reward)
                    agent2.update_policy(next_observation, 'fold' if reward > 0 else 'call', -reward)
                else:
                    agent2.update_policy(observation, action, reward)
                    agent1.update_policy(next_observation, 'fold' if reward < 0 else 'call', -reward)
            else:
                if env.turn == 1:  # Agent 1 just acted
                    agent1.update_policy(observation, action, 0)
                else:  # Agent 2 just acted
                    agent2.update_policy(observation, action, 0)
            
            observation = next_observation

        if episode % 1000 == 0:
            print(f"Episode {episode}")
            print("Agent 1 policy:", agent1.policy.table)
            print("Agent 2 policy:", agent2.policy.table)
            print()

    print("Final policies:")
    print("Agent 1:", agent1.policy.table)
    print("Agent 2:", agent2.policy.table)


def initialize_kuhn_poker_matrices():
    # Define the state space
    # States: (player1_card, player2_card, betting_round, last_action)
    # Cards: Jack (0), Queen (1), King (2)
    # Betting rounds: 0 (initial), 1 (after first action)
    # Last action: None (-1), Check (0), Bet (1)
    num_states = 3 * 3 * 2 * 3  # 54 states

    # Define the observation space
    # Observations: (own_card, betting_round, last_action)
    num_obs = 3 * 2 * 3  # 18 observations

    # Initialize A matrix (observation likelihood)
    A = np.zeros((num_obs, num_states))
    for s in range(num_states):
        p1_card, p2_card, betting_round, last_action = np.unravel_index(s, (3, 3, 2, 3))
        for o in range(num_obs):
            obs_card, obs_betting_round, obs_last_action = np.unravel_index(o, (3, 2, 3))
            if obs_card == p1_card and obs_betting_round == betting_round and obs_last_action == last_action:
                A[o, s] = 1

    # Normalize A matrix
    A /= A.sum(axis=0, keepdims=True)

    # Initialize B matrix (transition probabilities)
    num_actions = 2  # Check/Call (0) or Bet/Raise (1)
    B = np.zeros((num_states, num_states, num_actions))

    for s in range(num_states):
        p1_card, p2_card, betting_round, last_action = np.unravel_index(s, (3, 3, 2, 3))
        
        if betting_round == 0:
            for a in range(num_actions):
                next_s = np.ravel_multi_index((p1_card, p2_card, 1, a), (3, 3, 2, 3))
                B[next_s, s, a] = 1
        elif betting_round == 1 and last_action != -1:
            # Game ends, stay in the same state
            B[s, s, :] = 1
    
    print(A)
    print(B)

    return A, B

In [8]:
play_pdo_kuhn_poker()

test
test
[[1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 

TypeError: BranchingAgent.__init__() missing 1 required positional argument: 'env'