In [6]:
import sys  
sys.path.insert(1, 'C:/Users/davand/OneDrive/Documents/Programming/MAAIF/pymdp/')

import numpy as np
import jax
import jax.numpy as jnp
import copy
from pymdp.multi_agent.kuhn_poker import KuhnPokerEnv
from pymdp.pdo_agents.agent_gradient import PDOAgentGradient
from pymdp.pdo_agents.full_policy import TabularSoftmaxPolicy
from typing import Tuple, Dict, List



class PDOKuhnPokerAgent(PDOAgentGradient):
    """Policy-gradient PDO agent for one seat of two-player Kuhn poker.

    The agent is built from environment-supplied likelihood (``A``) and
    transition (``B``) models, a reward-based preference distribution ``C``
    (see ``initialize_preference_dist``) and a ``TabularSoftmaxPolicy`` over the
    observation sequences returned by ``generate_consistent_observation_seqs``.

    NOTE(review): the captured output of this notebook ends with
    ``TypeError: BranchingAgent.__init__() missing 1 required positional
    argument: 'env'``. That is consistent with the ``super().__init__`` call
    below omitting an ``env`` argument the base class requires -- confirm the
    base-class signature and pass the environment through.
    """

    def __init__(
        self,
        A: np.ndarray,
        B: np.ndarray,
        player_idx: int,
        learning_rate: float = 0.01,
        beta: float = 1.0,
        time_horizon: int = 2,
    ):
        self.learning_rate = learning_rate
        self.beta = beta
        self.time_horizon = time_horizon

        # Use the provided A and B matrices.
        self.A = A
        self.B = B
        # C is built BEFORE super().__init__ runs, so num_cards / num_actions /
        # num_rewards must already be resolvable on the instance here. They are
        # not assigned anywhere in this class -- presumably inherited; verify.
        self.C = self.initialize_preference_dist(player_idx)

        super().__init__(A=self.A, B=self.B, time_horizon=self.time_horizon, beta=self.beta)

        # Assigned after super().__init__, so this replaces any policy the base
        # class may have set up.
        self.policy = self.initialize_policy()

    def initialize_preference_dist(self, player_idx: int):
        """Build the preference distribution ``C`` over reward levels.

        Each reward level gets a fixed preference value, identical for every
        (card, action) pair:

            level 0 -> -2 (big loss)    level 3 -> +1 (small win)
            level 1 -> -1 (small loss)  level 4 -> +2 (big win)
            level 2 ->  0 (neutral)     levels >= 5 -> 0

        The values are then softmax-normalised over the reward axis.

        Args:
            player_idx: seat index. Currently unused -- both seats receive the
                same preferences.

        Returns:
            Array of shape ``(num_cards, num_actions, num_rewards)`` whose last
            axis sums to 1.
        """
        # Preference value per reward level, indexed by level.
        level_values = (-2.0, -1.0, 0.0, 1.0, 2.0)

        # Levels past the table keep the default preference of 0.0.
        reward_prefs = np.zeros(self.num_rewards)
        n_known = min(self.num_rewards, len(level_values))
        reward_prefs[:n_known] = level_values[:n_known]

        # Broadcast the per-level values over every (card, action) pair.
        C = np.zeros((self.num_cards, self.num_actions, self.num_rewards))
        C[...] = reward_prefs

        # Normalise over the reward axis.
        return jax.nn.softmax(C, axis=-1)

    def initialize_policy(self):
        """Create a ``TabularSoftmaxPolicy`` over this agent's consistent observation sequences.

        ``self.num_controls`` is passed as the policy's ``action_counts``.
        """
        observation_sequences = self.generate_consistent_observation_seqs()
        return TabularSoftmaxPolicy(action_counts=self.num_controls,
                                    observation_sequences=observation_sequences)



In [7]:
def play_pdo_kuhn_poker(num_episodes: int = 10000, log_every: int = 1000):
    """Train two PDO agents against each other by self-play in Kuhn poker.

    ``agent1`` plays seat 0 and acts when ``env.turn`` is 0 or 2; ``agent2``
    plays seat 1 and acts otherwise. Mid-hand, the agent that just acted is
    updated with reward 0. At the end of a hand, the agent that made the final
    move is credited ``reward`` and the other agent ``-reward``.

    Args:
        num_episodes: number of hands to play.
        log_every: print both policy tables every ``log_every`` episodes
            (starting at episode 0). A value <= 0 disables periodic logging.
            The default of 1000 matches the previous hard-coded interval.
    """
    env = KuhnPokerEnv()
    # Each agent gets its own deep copy of the environment's model arrays,
    # presumably so neither agent shares or mutates the env's (or the other
    # agent's) arrays.
    A1 = copy.deepcopy(env.get_likelihood_dist(0))
    B1 = copy.deepcopy(env.get_transition_dist(0))
    A2 = copy.deepcopy(env.get_likelihood_dist(1))
    B2 = copy.deepcopy(env.get_transition_dist(1))
    agent1 = PDOKuhnPokerAgent(A=A1, B=B1, player_idx=0)
    agent2 = PDOKuhnPokerAgent(A=A2, B=B2, player_idx=1)

    for episode in range(num_episodes):
        observation = env.reset()
        done = False

        while not done:
            actor = agent1 if env.turn in (0, 2) else agent2
            action = actor.select_action(observation)

            next_observation, reward, done, _ = env.step(action)

            if done:
                # Post-step env.turn == 0 is treated as "agent1 made the final
                # move". NOTE(review): confirm against KuhnPokerEnv.step.
                if env.turn == 0:
                    last_mover, other = agent1, agent2
                else:
                    last_mover, other = agent2, agent1
                # `reward` is credited to the last mover and negated for the
                # other agent, whose action is recorded as 'fold' when the last
                # mover won. Both seats now use the same rule; the agent2 branch
                # previously used `reward < 0`, contradicting the agent1 branch.
                last_mover.update_policy(observation, action, reward)
                other.update_policy(next_observation, 'fold' if reward > 0 else 'call', -reward)
            else:
                # Mid-hand: post-step env.turn == 1 means agent1 just acted.
                mover = agent1 if env.turn == 1 else agent2
                mover.update_policy(observation, action, 0)

            observation = next_observation

        if log_every > 0 and episode % log_every == 0:
            print(f"Episode {episode}")
            print("Agent 1 policy:", agent1.policy.table)
            print("Agent 2 policy:", agent2.policy.table)
            print()

    print("Final policies:")
    print("Agent 1:", agent1.policy.table)
    print("Agent 2:", agent2.policy.table)

In [8]:
# Run the self-play training loop with the default episode count.
play_pdo_kuhn_poker(num_episodes=10000)

test
test
[[1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 

TypeError: BranchingAgent.__init__() missing 1 required positional argument: 'env'