#TEST CODE


In [None]:
from __future__ import annotations
import itertools
from typing import List, Dict, Tuple, Set

SUITS = list("hdcs")  # Hearts, Diamonds, Clubs, Spades
RANKS = list("23456789TJQKA")  # Deuce (lowest) through Ace (highest)

class Card:
    """
    An immutable-by-convention playing card made of a rank and a suit.

    Attributes:
        rank (str): One of RANKS, from '2' (lowest) to 'A' (highest).
        suit (str): One of SUITS: 'h', 'd', 'c' or 's'.
    """
    __slots__ = ('rank', 'suit')  # no per-instance __dict__, saves memory

    def __init__(self, rank: str, suit: str) -> None:
        # Validate the rank first, then the suit, so the first bad field is reported.
        if rank not in RANKS:
            raise ValueError(f"Invalid rank: {rank}. Must be one of {RANKS}")
        if suit not in SUITS:
            raise ValueError(f"Invalid suit: {suit}. Must be one of {SUITS}")
        self.rank, self.suit = rank, suit

    def __repr__(self) -> str:
        # Compact two-character form such as 'Ah' or 'Td'.
        return self.rank + self.suit

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for non-Card comparisons.
        if isinstance(other, Card):
            return (self.rank, self.suit) == (other.rank, other.suit)
        return NotImplemented

    def __hash__(self) -> int:
        # Must agree with __eq__: equal cards hash to the same value.
        return hash((self.rank, self.suit))

    def get_numeric_rank(self) -> int:
        """Return the rank as an int from 2 ('2') to 14 ('A')."""
        return 2 + RANKS.index(self.rank)

def create_deck() -> List[Card]:
    """
    Build a fresh, ordered 52-card deck.

    Cards are ordered by rank ('2' first), and within each rank by the
    order of SUITS: 2h, 2d, 2c, 2s, 3h, ..., As.
    """
    return [Card(r, s) for r, s in itertools.product(RANKS, SUITS)]

def get_hand_bucket(hand: List[Card]) -> str:
    """
    Determines the preflop bucket for a given two-card hand.

    Args:
        hand (List[Card]): A list containing exactly two Card objects.

    Returns:
        str: The name of the preflop bucket: a pocket pair such as 'AA',
        or the higher rank, the lower rank and 's' (suited) / 'o' (offsuit),
        e.g. 'AKs' or 'KQo'.

    Raises:
        ValueError: If hand doesn't contain exactly two cards, or if both
            entries are the same card (rank and suit), which is not a real
            holding.
    """
    if len(hand) != 2:
        raise ValueError("Hand must consist of exactly two cards.")

    # Order by rank, highest first, so ('K', 'A') and ('A', 'K') share a name.
    high, low = sorted(hand, key=lambda card: RANKS.index(card.rank), reverse=True)

    # A duplicated card would otherwise be silently reported as a pocket pair.
    if high.rank == low.rank and high.suit == low.suit:
        raise ValueError(f"Hand contains the same card twice: {high!r}.")

    if high.rank == low.rank:
        return f"{high.rank}{low.rank}"  # Pocket pair
    suffix = "s" if high.suit == low.suit else "o"  # Suited / offsuit
    return f"{high.rank}{low.rank}{suffix}"

def main() -> None:
    """
    Demonstrate preflop bucketing.

    Enumerates every two-card starting hand, groups the hands by
    get_hand_bucket name, and prints the totals plus per-bucket combo
    counts for a few sample buckets.
    """
    print("Poker Preflop Bucketing Implementation")
    print("=" * 40)

    # Every unordered pair of distinct cards is one starting hand.
    all_possible_hands = list(itertools.combinations(create_deck(), 2))
    print(f"Total number of unique 2-card hands: {len(all_possible_hands)}")

    # Group hands by bucket name, preserving first-seen order.
    preflop_buckets: Dict[str, List[Tuple[Card, Card]]] = {}
    for hand in all_possible_hands:
        name = get_hand_bucket(list(hand))
        if name not in preflop_buckets:
            preflop_buckets[name] = []
        preflop_buckets[name].append(hand)

    total_buckets = len(preflop_buckets)
    print(f"Total number of preflop buckets generated: {total_buckets}")
    # 13 pairs + 78 suited + 78 offsuit combinations of ranks = 169.
    verdict = (
        "Successfully generated all 169 preflop buckets.\n"
        if total_buckets == 169
        else f"Error: Expected 169 buckets, but got {total_buckets}.\n"
    )
    print(verdict)

    for bucket in ('AA', 'AKs', 'AKo'):
        hands = preflop_buckets[bucket]
        print(f"Bucket '{bucket}': Contains {len(hands)} combinations.")
        preview = [f"{c1}{c2}" for c1, c2 in hands[:6]]
        print(f"   Hands: {preview}...")

    print("\n--- Verifying Hand Counts ---")
    checks = (
        ("Pocket pairs", 6, 'AA'),
        ("Suited hands", 4, 'AKs'),
        ("Offsuit hands", 12, 'AKo'),
    )
    for label, expected, bucket in checks:
        print(f"{label} should have {expected} combos. '{bucket}' has: {len(preflop_buckets[bucket])}")

# Run the bucketing demo only when executed directly (script or notebook
# cell), not when this module is imported.
if __name__ == "__main__":
    main()

Poker Preflop Bucketing Implementation
Total number of unique 2-card hands: 1326
Total number of preflop buckets generated: 169
Successfully generated all 169 preflop buckets.

Bucket 'AA': Contains 6 combinations.
   Hands: ['AhAd', 'AhAc', 'AhAs', 'AdAc', 'AdAs', 'AcAs']...
Bucket 'AKs': Contains 4 combinations.
   Hands: ['KhAh', 'KdAd', 'KcAc', 'KsAs']...
Bucket 'AKo': Contains 12 combinations.
   Hands: ['KhAd', 'KhAc', 'KhAs', 'KdAh', 'KdAc', 'KdAs']...

--- Verifying Hand Counts ---
Pocket pairs should have 6 combos. 'AA' has: 6
Suited hands should have 4 combos. 'AKs' has: 4
Offsuit hands should have 12 combos. 'AKo' has: 12


In [None]:
import random
import time
from phevaluator import evaluate_cards # Removed card_to_string and string_to_card

# --- Basic Game Definitions ---
SUITS = list("hdcs")
RANKS = list("23456789TJQKA")
# String cards, rank then suit: '2h', '2d', '2c', '2s', '3h', ..., 'As'.
DECK = [rank + suit for rank in RANKS for suit in SUITS]
# Rank character -> 0-based index ('2' -> 0, ..., 'A' -> 12).
RANK_MAP = dict(zip(RANKS, range(len(RANKS))))

class PotentialAwareCalculator:
    """
    Calculates a multi-dimensional feature vector for a poker hand.

    Features cover current equity, flush/straight draw potential, made-hand
    improvement potential, board texture, backdoor potential and a sampled
    "nut" strength (the "potential-aware" methodology).

    Cards are two-character strings, rank then suit (e.g. 'Ah', 'Td'), as in
    the module-level DECK. Several features call phevaluator's
    evaluate_cards on hand + board, which needs 5-7 cards, so feature
    vectors expect a flop, turn or river board.
    """

    # phevaluator hand ranks run from 1 (royal flush) to 7462 (worst high
    # card); LOWER is STRONGER. Each constant is the weakest (largest) rank
    # of its category.
    STRAIGHT_FLUSH_MAX = 10
    FOUR_OF_A_KIND_MAX = 166
    FULL_HOUSE_MAX = 322
    FLUSH_MAX = 1599
    STRAIGHT_MAX = 1609
    THREE_OF_A_KIND_MAX = 2467
    TWO_PAIR_MAX = 3325
    ONE_PAIR_MAX = 6185

    def __init__(self):
        # Maps (sorted hole cards, sorted board) -> feature vector.
        self.cache = {}
        print("PotentialAwareCalculator initialized. Cache is currently empty.")

    # --- Shared helpers --------------------------------------------------

    @staticmethod
    def _cards_to_come(board):
        """Return the number of board cards still to be dealt: 2 on the flop, 1 on the turn, else 0."""
        return {3: 2, 4: 1}.get(len(board), 0)

    @staticmethod
    def _improve_probability(outs, cards_in_deck, cards_to_come):
        """
        Probability that at least one of `outs` cards arrives.

        Args:
            outs: Number of unseen cards that complete the target hand.
            cards_in_deck: Number of unseen cards being drawn from.
            cards_to_come: Cards still to be dealt (2 on the flop, 1 on the turn).

        Returns:
            A float in [0, 1]; 0.0 when there are no outs or nothing to come.
        """
        if outs <= 0 or cards_to_come <= 0:
            return 0.0
        p_miss = 1.0
        for drawn in range(cards_to_come):
            unseen = cards_in_deck - drawn
            if unseen <= 0:
                break
            p_miss *= max(unseen - outs, 0) / unseen
        return 1.0 - p_miss

    @staticmethod
    def _rank_counts(cards):
        """Return {rank character: number of cards of that rank} for `cards`."""
        counts = {}
        for card in cards:
            counts[card[0]] = counts.get(card[0], 0) + 1
        return counts

    # --- Features ----------------------------------------------------------

    def _calculate_current_equity(self, hand, board, remaining_deck, num_sims=500):
        """
        Monte Carlo estimate of the hand's equity against one random opponent.

        Each simulation deals the opponent two cards and completes the board
        from `remaining_deck`; ties count as half a win.

        Returns:
            Fraction of completed simulations won, or 0.0 if none completed.
        """
        wins = 0.0
        completed = 0
        cards_to_draw = 2 + (5 - len(board))
        for _ in range(num_sims):
            try:
                samples = random.sample(remaining_deck, cards_to_draw)
                opp_hand, board_runout = samples[:2], samples[2:]
                final_board = board + board_runout
                player_rank = evaluate_cards(*hand, *final_board)
                opp_rank = evaluate_cards(*opp_hand, *final_board)
            except ValueError:
                # random.sample raises ValueError when the deck is too small.
                continue
            completed += 1
            if player_rank < opp_rank:  # lower phevaluator rank = stronger hand
                wins += 1
            elif player_rank == opp_rank:
                wins += 0.5
        # Divide by completed runs so skipped simulations don't bias the estimate down.
        return wins / completed if completed else 0.0

    def _has_straight_in_ranks(self, rank_set):
        """
        Return True if a set of rank indices (0='2' ... 12='A') contains five
        consecutive ranks, including the Ace-low wheel A-2-3-4-5.
        """
        if len(rank_set) < 5:
            return False

        # Ace-low straight (A, 2, 3, 4, 5) -> indices {12, 0, 1, 2, 3}.
        if all(x in rank_set for x in (12, 0, 1, 2, 3)):
            return True

        # Distinct values: a window of five spanning exactly 4 is a straight.
        sorted_ranks = sorted(rank_set, reverse=True)
        for i in range(len(sorted_ranks) - 4):
            if sorted_ranks[i] - sorted_ranks[i + 4] == 4:
                return True

        return False

    def _calculate_straight_outs(self, hand_and_board):
        """
        Count the unseen cards that would give `hand_and_board` a straight.

        A rank absent from hand_and_board has all four of its cards unseen,
        so every completing rank contributes four outs. Returns 0 when there
        are fewer than 3 cards, or when a straight is already made (every
        card would otherwise trivially count as an out).
        """
        if len(hand_and_board) < 3:
            return 0  # Not enough cards for meaningful straight draws

        rank_set = {RANK_MAP[c[0]] for c in hand_and_board}
        if self._has_straight_in_ranks(rank_set):
            return 0

        completing_ranks = [
            rank_idx
            for rank_idx in range(len(RANKS))
            if rank_idx not in rank_set and self._has_straight_in_ranks(rank_set | {rank_idx})
        ]
        return 4 * len(completing_ranks)

    def _calculate_potential(self, hand, board, remaining_deck):
        """
        Probability of completing a flush draw and a straight draw by the river.

        Returns:
            (flush_potential, straight_potential), each in [0, 1]. Both are
            0.0 on the river and for boards that are not a flop or turn.
        """
        cards_to_come = self._cards_to_come(board)
        if cards_to_come == 0:
            return 0.0, 0.0

        hand_and_board = hand + board
        n_unseen = len(remaining_deck)

        # Flush draw: exactly four cards of one suit among hand + board.
        suit_counts = {s: sum(1 for c in hand_and_board if c[1] == s) for s in SUITS}
        flush_draw_suit = next((s for s, count in suit_counts.items() if count == 4), None)
        flush_potential = 0.0
        if flush_draw_suit is not None:
            flush_outs = 13 - suit_counts[flush_draw_suit]
            flush_potential = self._improve_probability(flush_outs, n_unseen, cards_to_come)

        # Straight draw: only relevant while the hand is weaker than a straight
        # (phevaluator ranks 1..STRAIGHT_MAX are straights or better).
        straight_potential = 0.0
        if evaluate_cards(*hand_and_board) > self.STRAIGHT_MAX:
            num_straight_outs = self._calculate_straight_outs(hand_and_board)
            straight_potential = self._improve_probability(num_straight_outs, n_unseen, cards_to_come)

        return flush_potential, straight_potential

    def _calculate_made_hand_potential(self, hand, board, remaining_deck):
        """
        Probability that a made pair / two pair / trips improves by the river.

        Outs counted per category:
          * one pair - 2 cards make trips, plus 3 for each hole card whose rank
            is unpaired (two pair). This is 2 for an unimproved pocket pair
            and 5 for a typical top pair.
          * two pair - 2 cards for each paired rank (full house).
          * trips - 1 card for quads plus 3 for each other rank present
            (full house).
        High-card hands and straights or better get 0 outs.

        Returns:
            A float in [0, 1]; 0.0 when no board cards remain to be dealt.
        """
        cards_to_come = self._cards_to_come(board)
        if cards_to_come == 0:
            return 0.0

        hand_and_board = hand + board
        player_rank = evaluate_cards(*hand_and_board)
        rank_counts = self._rank_counts(hand_and_board)

        if self.TWO_PAIR_MAX < player_rank <= self.ONE_PAIR_MAX:  # one pair
            unpaired_hole_ranks = {c[0] for c in hand if rank_counts[c[0]] == 1}
            outs = 2 + 3 * len(unpaired_hole_ranks)
        elif self.THREE_OF_A_KIND_MAX < player_rank <= self.TWO_PAIR_MAX:  # two pair
            outs = 2 * sum(1 for n in rank_counts.values() if n == 2)
        elif self.STRAIGHT_MAX < player_rank <= self.THREE_OF_A_KIND_MAX:  # trips
            outs = 1 + 3 * sum(1 for n in rank_counts.values() if n == 1)
        else:
            outs = 0

        return self._improve_probability(outs, len(remaining_deck), cards_to_come)

    def _calculate_backdoor_potential(self, hand, board):
        """
        Probability that the turn AND river together complete a flush or straight.

        Only defined on the flop (3 board cards); returns 0.0 otherwise. A
        two-card runout counts when both cards together give hand + board a
        straight or five cards of one suit, neither card does so on its own,
        and no such hand is already made on the flop.
        """
        if len(board) != 3:
            return 0.0

        known = hand + board
        base_ranks = {RANK_MAP[c[0]] for c in known}
        base_suits = {s: sum(1 for c in known if c[1] == s) for s in SUITS}

        def completes(extra):
            # True if adding `extra` cards yields a straight or a 5-card flush.
            ranks = base_ranks | {RANK_MAP[c[0]] for c in extra}
            if self._has_straight_in_ranks(ranks):
                return True
            return any(base_suits[s] + sum(1 for c in extra if c[1] == s) >= 5 for s in SUITS)

        if completes([]):
            return 0.0  # already made: nothing "backdoor" about it

        unseen = [c for c in DECK if c not in known]
        total_runouts = len(unseen) * (len(unseen) - 1) // 2
        if total_runouts == 0:
            return 0.0

        # Cards that hit alone are direct draws, not backdoor draws.
        single_misses = [c for c in unseen if not completes([c])]
        hits = sum(
            1
            for i, first in enumerate(single_misses)
            for second in single_misses[i + 1:]
            if completes([first, second])
        )
        return hits / total_runouts

    def _calculate_board_connectivity(self, board):
        """
        Alternative connectivity score in [0, 1] based on adjacent-rank pairs.

        Consecutive ranks add credit, gaps larger than 3 ranks subtract 0.1
        per skipped rank, and an Ace together with a Deuce adds half a credit
        for wheel potential. Returns 0.0 for boards with fewer than 3 cards.
        Not used by calculate_feature_vector, which uses
        _calculate_board_texture instead.
        """
        if len(board) < 3:
            return 0.0

        board_ranks = sorted([RANK_MAP[c[0]] for c in board])

        gaps = 0
        consecutive_pairs = 0

        for i in range(len(board_ranks) - 1):
            gap = board_ranks[i + 1] - board_ranks[i]
            if gap == 1:
                consecutive_pairs += 1
            elif gap > 3:  # Large gaps reduce connectivity more
                gaps += gap - 1

        # Wheel connectivity (A-2-3, A-2-4, etc.): A and 2 present.
        if 12 in board_ranks and 0 in board_ranks:
            consecutive_pairs += 0.5

        # Normalize: every adjacent pair consecutive = fully connected.
        max_consecutive = len(board) - 1
        connectivity = consecutive_pairs / max_consecutive if max_consecutive > 0 else 0

        connectivity = max(0, connectivity - gaps * 0.1)

        return min(1.0, connectivity)

    def _is_nut_hand(self, hand, board):
        """
        Sampled relative strength in [0, 1]: 1.0 means never beaten.

        Straight flushes score 1.0 and quads 0.9 without sampling. Otherwise
        100 random opponent holdings are evaluated on the current board, and
        the result is 1 - (fraction of samples that beat us).
        """
        hand_and_board = hand + board
        player_rank = evaluate_cards(*hand_and_board)

        # Opponent holdings are sampled from the cards we cannot see.
        remaining_deck = [card for card in DECK if card not in hand_and_board]

        if player_rank <= self.STRAIGHT_FLUSH_MAX:
            return 1.0
        if player_rank <= self.FOUR_OF_A_KIND_MAX:
            return 0.9  # Small chance an opponent holds better quads

        beaten_count = 0
        total_samples = 100  # Limit for performance

        for _ in range(total_samples):
            try:
                opp_hand = random.sample(remaining_deck, 2)
                opp_rank = evaluate_cards(*opp_hand, *board)
                if opp_rank < player_rank:  # Opponent wins
                    beaten_count += 1
            except ValueError:
                continue

        return max(0.0, 1.0 - (beaten_count / total_samples))

    def _calculate_board_texture(self, board):
        """
        Calculates metrics for board suitedness and connectivity.

        Returns:
            (board_suitedness, board_connectivity). Suitedness is 1.0 for
            a single suit (monotone), 0.5 for two suits, 0.0 otherwise.
            Connectivity is 1 - span / 12, where span is the highest minus
            the lowest rank index: 0.0 when both a Deuce and an Ace are
            present (Ace counts high only), 1.0 when all cards share a rank.
            An empty board returns (0.0, 0.0).
        """
        if not board:
            return 0.0, 0.0

        board_suits = {card[1] for card in board}
        if len(board_suits) == 1:
            board_suitedness = 1.0  # Monotone
        elif len(board_suits) == 2:
            board_suitedness = 0.5  # Two-tone
        else:
            board_suitedness = 0.0  # Rainbow

        # Span ranges from 0 (one rank) to 12 ('2' and 'A'); e.g. 8-7-6 spans
        # 2 and scores 1 - 2/12.
        board_ranks_idx = sorted(RANK_MAP[c[0]] for c in board)
        max_span = board_ranks_idx[-1] - board_ranks_idx[0]
        board_connectivity = 1.0 - max_span / 12.0

        return board_suitedness, board_connectivity

    def calculate_feature_vector(self, hand, board):
        """
        Build the 8-element feature vector for `hand` on `board`.

        Results are cached per (sorted hole cards, sorted board), so a repeated
        query returns the stored vector instead of a fresh Monte Carlo
        estimate. A copy is returned so callers cannot corrupt the cache.

        Returns:
            [equity, flush_pot, straight_pot, made_hand_pot, board_suit,
             board_conn, backdoor_pot, nut_strength]
        """
        cache_key = (tuple(sorted(hand)), tuple(sorted(board)))
        cached = self.cache.get(cache_key)
        if cached is not None:
            return list(cached)

        known_cards = hand + board
        remaining_deck = [card for card in DECK if card not in known_cards]

        equity = self._calculate_current_equity(hand, board, remaining_deck)
        flush_pot, straight_pot = self._calculate_potential(hand, board, remaining_deck)
        made_hand_pot = self._calculate_made_hand_potential(hand, board, remaining_deck)
        board_suit, board_conn = self._calculate_board_texture(board)
        backdoor_pot = self._calculate_backdoor_potential(hand, board)
        nut_strength = self._is_nut_hand(hand, board)

        feature_vector = [
            equity,           # Current winning probability
            flush_pot,        # Flush potential
            straight_pot,     # Straight potential
            made_hand_pot,    # Made hand improvement potential
            board_suit,       # Board suitedness
            board_conn,       # Board connectivity
            backdoor_pot,     # Backdoor potential
            nut_strength,     # Relative hand strength
        ]

        self.cache[cache_key] = feature_vector
        return list(feature_vector)

    def _categorize_straight_draw(self, hand_and_board, num_outs):
        """Optional: categorize the type of straight draw by its out count."""
        if num_outs == 8:
            return "open_ended"  # Very strong draw
        elif num_outs == 4:
            return "gutshot"     # Moderate draw
        elif num_outs > 8:
            return "wrap"        # Multiple straight possibilities
        else:
            return "weak"        # Unusual or weak draw




In [None]:
!pip install phevaluator

Collecting phevaluator
  Downloading phevaluator-0.5.3.1-py3-none-any.whl.metadata (3.9 kB)
Downloading phevaluator-0.5.3.1-py3-none-any.whl (3.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.7/3.7 MB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: phevaluator
Successfully installed phevaluator-0.5.3.1


In [None]:
def _calculate_potential(self, hand, board, remaining_deck):
    """
    Return (flush_potential, straight_potential) for `hand` on `board`.

    Each value is the chance of hitting at least one out by the river: over
    two cards on the flop, over one card on the turn, and 0.0 on the river.
    A flush draw means exactly four cards of one suit among hand + board.
    Straight potential is only computed while phevaluator ranks the current
    hand above 1609 (lower is stronger; 1..1609 are straights or better).
    """
    board_size = len(board)
    if board_size == 5:
        return 0.0, 0.0  # River: nothing left to draw

    cards = hand + board
    deck_size = len(remaining_deck)

    def hit_probability(outs):
        # Chance that the remaining street(s) deliver at least one out.
        if board_size == 3:
            miss_turn = (deck_size - outs) / deck_size
            miss_river = (deck_size - 1 - outs) / (deck_size - 1)
            return 1 - (miss_turn * miss_river)
        if board_size == 4:
            return outs / deck_size
        return 0.0

    counts_by_suit = {suit: sum(1 for card in cards if card[1] == suit) for suit in SUITS}
    draw_suit = next((suit for suit, n in counts_by_suit.items() if n == 4), None)
    flush_potential = hit_probability(13 - counts_by_suit[draw_suit]) if draw_suit else 0.0

    straight_potential = 0.0
    if evaluate_cards(*cards) > 1609:
        straight_outs = self._calculate_straight_outs(cards)
        if straight_outs > 0:
            straight_potential = hit_probability(straight_outs)

    return flush_potential, straight_potential

def _categorize_straight_draw(self, hand_and_board, num_outs):
    """Label a straight draw from its out count; hand_and_board is not inspected."""
    exact_labels = {8: "open_ended", 4: "gutshot"}  # two-way draw / inside draw
    if num_outs in exact_labels:
        return exact_labels[num_outs]
    # More than 8 outs means several straights are possible (a wrap).
    return "wrap" if num_outs > 8 else "weak"

In [None]:
def _calculate_made_hand_potential(self, hand, board, remaining_deck):
    """
    Probability that a made pair / two pair / trips improves by the river.

    phevaluator ranks are lower-is-stronger; the category bounds used are
    trips 1610-2467, two pair 2468-3325 and one pair 3326-6185.

    Outs counted per category:
      * one pair - 2 cards make trips, plus 3 for each hole card whose rank is
        unpaired (two pair): 2 for an unimproved pocket pair, 5 for top pair.
      * two pair - 2 cards for each paired rank (full house).
      * trips - 1 card for quads plus 3 for each other rank present.
    Other hands get 0 outs.

    Returns:
        A float in [0, 1]; 0.0 when the board is not a flop or turn.
    """
    if len(board) == 3:
        cards_to_come = 2
    elif len(board) == 4:
        cards_to_come = 1
    else:
        return 0.0  # No next street to improve on

    hand_and_board = hand + board
    player_rank = evaluate_cards(*hand_and_board)

    rank_counts = {}
    for card in hand_and_board:
        rank_counts[card[0]] = rank_counts.get(card[0], 0) + 1

    if 3326 <= player_rank <= 6185:  # one pair
        unpaired_hole_ranks = {c[0] for c in hand if rank_counts[c[0]] == 1}
        outs = 2 + 3 * len(unpaired_hole_ranks)
    elif 2468 <= player_rank <= 3325:  # two pair
        outs = 2 * sum(1 for n in rank_counts.values() if n == 2)
    elif 1610 <= player_rank <= 2467:  # trips
        outs = 1 + 3 * sum(1 for n in rank_counts.values() if n == 1)
    else:
        outs = 0

    if outs == 0:
        return 0.0

    # P(at least one out) = 1 - P(every remaining card misses).
    deck_size = len(remaining_deck)
    p_miss = 1.0
    for drawn in range(cards_to_come):
        unseen = deck_size - drawn
        if unseen <= 0:
            return 0.0
        p_miss *= max(unseen - outs, 0) / unseen
    return 1.0 - p_miss




In [None]:
def test_straight_calculation():
    """Check _calculate_straight_outs on an open-ended draw and a gutshot."""
    calc = PotentialAwareCalculator()

    cases = [
        # (hole cards, flop, expected outs)
        (['9h', '8c'], ['7d', '6s', '2h'], 8),  # 9-8-7-6: any 5 or T completes
        (['9h', '8c'], ['6d', '5s', '2h'], 4),  # 9-8-_-6-5: only a 7 completes
    ]
    for hole, flop, expected in cases:
        got = calc._calculate_straight_outs(hole + flop)
        assert got == expected, f"Expected {expected} outs, got {got}"

    print("All straight calculation tests passed!")

In [None]:
# Run the straight-out checks defined above: prints a confirmation on success,
# raises AssertionError on the first mismatch.
test_straight_calculation()

PotentialAwareCalculator initialized. Cache is currently empty.
All straight calculation tests passed!


In [None]:
import numpy as np
import random
import time
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import pickle
from collections import defaultdict

# Assuming your PotentialAwareCalculator is imported
# from your_module import PotentialAwareCalculator

class FixedAbstractionManager:
    """
    Fixed version of AbstractionManager with proper data generation
    and robust error handling.

    One K-Means model is trained per post-flop round ('flop', 'turn', 'river')
    on feature vectors from ``PotentialAwareCalculator``. A (hand, board) pair
    is mapped to the global bucket id ``cluster + cluster_offsets[round]``.
    """

    # Historical spacing between the bucket ranges of consecutive rounds
    # (flop 0.., turn 20.., river 40..). It is widened automatically when a
    # round has more clusters than this, so the ranges never overlap.
    _MIN_ROUND_STRIDE = 20
    _BOARD_SIZE_BY_ROUND = {'flop': 3, 'turn': 4, 'river': 5}
    _ROUND_BY_BOARD_SIZE = {3: 'flop', 4: 'turn', 5: 'river'}
    # Known-good deal returned when random generation keeps failing. The board
    # is sliced to the requested size, so it must hold a full river.
    _FALLBACK_HAND = ('As', 'Kh')
    _FALLBACK_BOARD = ('Qc', '7d', '2s', '9h', '4c')

    def __init__(self, n_clusters_per_round=20):
        """Create the feature calculator and an empty model registry.

        Args:
            n_clusters_per_round: Number of K-Means clusters (buckets) per round.
        """
        self.calculator = PotentialAwareCalculator()
        self.n_clusters = n_clusters_per_round
        self.kmeans_models = {}
        self.cluster_offsets = self._compute_cluster_offsets(n_clusters_per_round)

        # Card definitions
        self.SUITS = ['h', 'd', 'c', 's']
        self.RANKS = ['2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K', 'A']
        self.DECK = [r + s for r in self.RANKS for s in self.SUITS]

        print("Fixed AbstractionManager initialized.")

    @classmethod
    def _compute_cluster_offsets(cls, n_clusters):
        """Return per-round bucket offsets whose id ranges cannot overlap.

        For ``n_clusters <= 20`` this is the original layout
        ``{'flop': 0, 'turn': 20, 'river': 40}``. Larger values widen the stride
        to ``n_clusters``, so e.g. flop cluster 25 no longer aliases turn cluster 5.
        """
        stride = max(cls._MIN_ROUND_STRIDE, int(n_clusters))
        return {'flop': 0, 'turn': stride, 'river': 2 * stride}

    def _generate_valid_hand_board(self, num_board_cards):
        """Generate a random hand/board combination the calculator accepts.

        Up to 100 random deals are tried. A deal is accepted when the calculator
        returns a non-empty list of non-NaN numbers for it. If every attempt
        fails, a fixed fallback deal with the requested board size is returned.

        Args:
            num_board_cards: Number of community cards to deal (3, 4 or 5).

        Returns:
            ``(hand, board)`` as lists of card strings such as ``'As'``.
        """
        max_attempts = 100

        for _ in range(max_attempts):
            try:
                # Sampling without replacement guarantees distinct cards.
                cards = random.sample(self.DECK, 2 + num_board_cards)
                hand, board = cards[:2], cards[2:]

                # Test that the calculator can handle this combination
                test_vector = self.calculator.calculate_feature_vector(hand, board)
            except Exception:
                continue  # Try again with different cards

            if (isinstance(test_vector, list) and
                    len(test_vector) > 0 and
                    all(isinstance(x, (int, float)) and not np.isnan(x) for x in test_vector)):
                return hand, board

        # Fallback: a known good combination of the requested size.
        return list(self._FALLBACK_HAND), list(self._FALLBACK_BOARD[:num_board_cards])

    def _generate_data_for_round(self, n_samples, round_name):
        """Generate valid feature vectors for clustering.

        Makes at most ``2 * n_samples`` attempts and stops as soon as
        ``n_samples`` vectors that pass ``_validate_feature_vector`` have been
        collected.

        Args:
            n_samples: Target number of vectors.
            round_name: 'flop', 'turn' or 'river'.

        Returns:
            ``np.ndarray`` with one row per collected feature vector.

        Raises:
            KeyError: If ``round_name`` is not a known round.
            ValueError: If fewer than ``self.n_clusters`` vectors were collected.
        """
        print(f"--- Generating {n_samples} feature vectors for {round_name} ---")

        feature_vectors = []
        num_board_cards = self._BOARD_SIZE_BY_ROUND[round_name]
        failed_generations = 0

        start_time = time.time()

        for i in range(n_samples * 2):  # Generate extra to account for failures
            if len(feature_vectors) >= n_samples:
                break

            if (i + 1) % 1000 == 0:
                print(f"  ...attempting {i+1}, collected {len(feature_vectors)}/{n_samples}...")

            try:
                hand, board = self._generate_valid_hand_board(num_board_cards)
                vector = self.calculator.calculate_feature_vector(hand, board)
                is_valid = self._validate_feature_vector(vector)
            except Exception:
                # Best-effort sampling: a failed deal is counted, not fatal.
                failed_generations += 1
                continue

            if is_valid:
                feature_vectors.append(vector)
            else:
                failed_generations += 1

        duration = time.time() - start_time

        if len(feature_vectors) < self.n_clusters:
            raise ValueError(f"Only generated {len(feature_vectors)} valid vectors, "
                             f"need at least {self.n_clusters} for clustering")

        print(f"Generated {len(feature_vectors)} valid vectors "
              f"({failed_generations} failed) in {duration:.2f}s")

        return np.array(feature_vectors)

    def _validate_feature_vector(self, vector):
        """Return True if *vector* is a non-empty sequence of finite numbers in [0, 1].

        Plain Python numbers and numpy scalar types are both accepted.
        """
        if not isinstance(vector, (list, tuple, np.ndarray)) or len(vector) == 0:
            return False

        for val in vector:
            if not isinstance(val, (int, float, np.integer, np.floating)):
                return False
            if not np.isfinite(val):  # rejects NaN and +/-inf
                return False
            # Every feature is a probability or ratio, so it must lie in [0, 1].
            if val < 0 or val > 1:
                return False

        return True

    def _remove_duplicate_vectors(self, vectors):
        """Remove duplicate feature vectors to improve clustering.

        Vectors equal after rounding to 6 decimals count as duplicates. The
        first occurrence is kept and the original order is preserved.
        """
        unique_vectors = []
        seen = set()

        for vector in vectors:
            vector_tuple = tuple(np.round(vector, 6))  # Round to avoid floating point noise
            if vector_tuple not in seen:
                seen.add(vector_tuple)
                unique_vectors.append(vector)

        print(f"Removed {len(vectors) - len(unique_vectors)} duplicate vectors")
        return np.array(unique_vectors)

    def train_all_postflop_models(self, n_samples_per_round=10000):
        """Train one K-Means model per post-flop round.

        For each round the method generates feature vectors, drops duplicates
        and fits K-Means. It then reports the silhouette score when that score
        is defined. If data generation or fitting fails, a fallback model
        trained on random data is stored instead, so ``get_postflop_bucket``
        keeps working.
        """
        print("\n=== STARTING K-MEANS MODEL TRAINING FOR ALL ROUNDS ===")

        for round_name in ['flop', 'turn', 'river']:
            try:
                print(f"\n--- Training {round_name} model ---")

                data = self._generate_data_for_round(n_samples_per_round, round_name)
                data = self._remove_duplicate_vectors(data)

                if len(data) < self.n_clusters:
                    print(f"Warning: Only {len(data)} unique vectors for {self.n_clusters} clusters")
                    # K-Means cannot have more clusters than samples.
                    actual_clusters = len(data)
                else:
                    actual_clusters = self.n_clusters

                # Reshape data if needed (should be 2D)
                if data.ndim == 1:
                    data = data.reshape(-1, 1)

                print(f"Training K-Means on {data.shape[0]} vectors with {data.shape[1]} features")
                print(f"Using {actual_clusters} clusters")

                kmeans = KMeans(
                    n_clusters=actual_clusters,
                    random_state=42,
                    n_init=10,
                    max_iter=300
                )
                kmeans.fit(data)
                self.kmeans_models[round_name] = kmeans

            except Exception as e:
                print(f"Error training {round_name} model: {e}")
                # Create a dummy model to prevent crashes
                self._create_fallback_model(round_name)
                continue

            # Evaluate clustering quality. silhouette_score needs
            # 2 <= n_labels <= n_samples - 1. Evaluation is best-effort: a
            # failure here must not replace the model that was just trained.
            if 1 < actual_clusters < len(data):
                try:
                    labels = kmeans.predict(data)
                    silhouette = silhouette_score(data, labels)
                    print(f"Silhouette Score: {silhouette:.3f}")
                except Exception as e:
                    print(f"Could not compute silhouette score for {round_name}: {e}")

            print(f"{round_name} model trained successfully!")

    def _create_fallback_model(self, round_name):
        """Store a K-Means model fitted on uniform random 8-feature data.

        The model only keeps ``get_postflop_bucket`` returning ids in range.
        Its clusters carry no hand-strength meaning.
        """
        print(f"Creating fallback model for {round_name}")

        dummy_data = np.random.rand(self.n_clusters * 10, 8)  # 8 features

        kmeans = KMeans(n_clusters=self.n_clusters, random_state=42)
        kmeans.fit(dummy_data)

        self.kmeans_models[round_name] = kmeans

    def get_postflop_bucket(self, hand, board):
        """Map a hand and a 3-5 card board to a global bucket id.

        Args:
            hand: List of 2 card strings.
            board: List of 3, 4 or 5 card strings.

        Returns:
            ``int`` in ``[cluster_offsets[round], cluster_offsets[round] + n_clusters - 1]``.
            On any error the first bucket of the round is returned, or of the
            flop when the round cannot be determined. Errors are printed, not raised.
        """
        try:
            if not isinstance(hand, list) or len(hand) != 2:
                raise ValueError("Hand must be a list of 2 cards")
            if not isinstance(board, list) or len(board) not in [3, 4, 5]:
                raise ValueError("Board must have 3, 4, or 5 cards")

            round_name = self._ROUND_BY_BOARD_SIZE[len(board)]

            if round_name not in self.kmeans_models:
                raise RuntimeError(f"No trained model for {round_name}")

            feature_vector = self.calculator.calculate_feature_vector(hand, board)

            if not self._validate_feature_vector(feature_vector):
                print(f"Warning: Invalid feature vector {feature_vector}, using fallback")
                return self.cluster_offsets[round_name]  # Return first bucket as fallback

            model = self.kmeans_models[round_name]
            vector_array = np.array([feature_vector])  # Ensure 2D shape
            predicted_cluster = model.predict(vector_array)[0]

            bucket_id = predicted_cluster + self.cluster_offsets[round_name]

            # Clamp into this round's range in case the model has more clusters.
            min_bucket = self.cluster_offsets[round_name]
            max_bucket = min_bucket + self.n_clusters - 1
            bucket_id = max(min_bucket, min(max_bucket, bucket_id))

            return int(bucket_id)

        except Exception as e:
            print(f"Error in get_postflop_bucket: {e}")
            # `board` may be the very value that failed validation (e.g. None),
            # so size it defensively instead of letting the handler crash.
            try:
                board_size = len(board)
            except TypeError:
                board_size = None
            round_name = self._ROUND_BY_BOARD_SIZE.get(board_size, 'flop')
            return self.cluster_offsets[round_name]

    def save_models(self, filepath):
        """Pickle the trained models, cluster count and offsets to *filepath*.

        Errors are printed, not raised.
        """
        try:
            model_data = {
                'kmeans_models': self.kmeans_models,
                'n_clusters': self.n_clusters,
                'cluster_offsets': self.cluster_offsets
            }

            with open(filepath, 'wb') as f:
                pickle.dump(model_data, f)

            print(f"Models saved to {filepath}")

        except Exception as e:
            print(f"Error saving models: {e}")

    def load_models(self, filepath):
        """Load models previously written by ``save_models``.

        SECURITY: ``pickle.load`` can execute arbitrary code, so only load
        files from a trusted source. Errors are printed, not raised.
        """
        try:
            with open(filepath, 'rb') as f:
                model_data = pickle.load(f)

            self.kmeans_models = model_data['kmeans_models']
            self.n_clusters = model_data['n_clusters']
            self.cluster_offsets = model_data['cluster_offsets']

            print(f"Models loaded from {filepath}")

        except Exception as e:
            print(f"Error loading models: {e}")

    def validate_all_models(self, test_samples=100):
        """Bucket *test_samples* random deals per round and report the in-range rate."""
        print("\n=== Validating All Models ===")

        for round_name in ['flop', 'turn', 'river']:
            if round_name not in self.kmeans_models:
                print(f"No model found for {round_name}")
                continue

            num_board_cards = self._BOARD_SIZE_BY_ROUND[round_name]
            min_bucket = self.cluster_offsets[round_name]
            max_bucket = min_bucket + self.n_clusters - 1
            success_count = 0

            print(f"Testing {round_name} model...")

            for _ in range(test_samples):
                try:
                    hand, board = self._generate_valid_hand_board(num_board_cards)
                    bucket_id = self.get_postflop_bucket(hand, board)
                except Exception:
                    continue  # A failed sample simply does not count as a success.

                if min_bucket <= bucket_id <= max_bucket:
                    success_count += 1

            # Guard against test_samples == 0.
            success_rate = success_count / test_samples if test_samples > 0 else 0.0
            print(f"{round_name}: {success_count}/{test_samples} successful ({success_rate:.1%})")


# Smoke test: train small models, validate them, bucket a few spots, save to disk.
if __name__ == "__main__":
    print("=== Testing Fixed Abstraction Manager ===")

    try:
        # Deliberately small settings so the smoke test finishes quickly.
        manager = FixedAbstractionManager(n_clusters_per_round=10)
        manager.train_all_postflop_models(n_samples_per_round=5000)
        manager.validate_all_models(test_samples=50)

        print("\n=== Testing Individual Predictions ===")
        sample_spots = (
            (['As', 'Kh'], ['Qc', '7d', '2s']),               # flop
            (['9h', '8c'], ['7d', '6s', '2h', 'Tc']),         # turn
            (['Ah', '2h'], ['3h', '4h', '5c', 'Kd', '9s']),   # river
        )
        for hole, community in sample_spots:
            # One bad spot is reported and skipped; it does not abort the run.
            try:
                bucket = manager.get_postflop_bucket(hole, community)
                print(f"Hand {hole} + Board {community} -> Bucket {bucket}")
            except Exception as err:
                print(f"Error with {hole} + {community}: {err}")

        manager.save_models("test_models.pkl")
        print("\nAll tests completed successfully!")

    except Exception as err:
        print(f"Critical error: {err}")
        import traceback
        traceback.print_exc()

=== Testing Fixed Abstraction Manager ===
Optimized PotentialAwareCalculator initialized.
Fixed AbstractionManager initialized.

=== STARTING K-MEANS MODEL TRAINING FOR ALL ROUNDS ===

--- Training flop model ---
--- Generating 5000 feature vectors for flop ---
  ...attempting 1000, collected 999/5000...
  ...attempting 2000, collected 1999/5000...
  ...attempting 3000, collected 2999/5000...
  ...attempting 4000, collected 3999/5000...
  ...attempting 5000, collected 4999/5000...
Generated 5000 valid vectors (0 failed) in 10.29s
Removed 12 duplicate vectors
Training K-Means on 4988 vectors with 8 features
Using 10 clusters
Silhouette Score: 0.272
flop model trained successfully!

--- Training turn model ---
--- Generating 5000 feature vectors for turn ---
  ...attempting 1000, collected 999/5000...
  ...attempting 2000, collected 1999/5000...
  ...attempting 3000, collected 2999/5000...
  ...attempting 4000, collected 3999/5000...
  ...attempting 5000, collected 4999/5000...
Generated

In [None]:
import random
import time
from phevaluator import evaluate_cards

# --- Basic Game Definitions ---
SUITS = ['h', 'd', 'c', 's']
RANKS = list('23456789TJQKA')
# Rank-major ordering: 2h, 2d, 2c, 2s, 3h, ..., As
DECK = [f"{rank}{suit}" for rank in RANKS for suit in SUITS]
# Rank character -> 0-based strength index ('2' -> 0, ..., 'A' -> 12).
RANK_MAP = dict(zip(RANKS, range(len(RANKS))))

class PotentialAwareCalculator:
    """
    Optimized version with reduced computational overhead for clustering.

    ``calculate_feature_vector`` returns 8 features, each clipped to [0, 1]:
    equity, flush potential, straight potential, made-hand potential, board
    suitedness, board connectivity, backdoor potential and relative strength.
    Hand ranks come from phevaluator's ``evaluate_cards``, where a lower rank
    is a stronger hand.
    """

    # Inclusive upper bounds of phevaluator's hand-rank categories.
    _MAX_STRAIGHT_FLUSH = 10
    _MAX_FOUR_OF_A_KIND = 166
    _MAX_FULL_HOUSE = 322
    _MAX_FLUSH = 1599
    _MAX_STRAIGHT = 1609
    _MAX_THREE_OF_A_KIND = 2467
    _MAX_TWO_PAIR = 3325
    _MAX_ONE_PAIR = 6185

    def __init__(self):
        self.cache = {}
        print("Optimized PotentialAwareCalculator initialized.")

    def _calculate_current_equity(self, hand, board, remaining_deck, num_sims=100):  # Reduced from 500
        """Monte-Carlo estimate of the win probability against one random hand.

        Each simulation deals an opponent hand plus the rest of the board from
        *remaining_deck*. Ties count as half a win. Simulations whose
        evaluation raises are skipped and excluded from the denominator.

        Returns:
            Equity in [0, 1], or 0.0 when no simulation could run.
        """
        cards_to_draw = 2 + (5 - len(board))

        # Early return if not enough cards (or nothing to simulate)
        if num_sims <= 0 or len(remaining_deck) < cards_to_draw:
            return 0.0

        wins = 0.0
        completed = 0
        for _ in range(num_sims):
            try:
                samples = random.sample(remaining_deck, cards_to_draw)
                opp_hand, board_runout = samples[:2], samples[2:]
                final_board = board + board_runout

                player_rank = evaluate_cards(*hand, *final_board)
                opp_rank = evaluate_cards(*opp_hand, *final_board)
            except (ValueError, IndexError):
                continue

            completed += 1
            if player_rank < opp_rank:
                wins += 1
            elif player_rank == opp_rank:
                wins += 0.5

        return wins / completed if completed else 0.0

    def _has_straight_in_ranks(self, rank_set):
        """Return True if *rank_set* (0-based rank indices) holds five consecutive ranks.

        The wheel A-2-3-4-5 (indices 12, 0, 1, 2, 3) counts as a straight.
        """
        if len(rank_set) < 5:
            return False

        # Handle Ace-low straight
        if all(x in rank_set for x in [12, 0, 1, 2, 3]):
            return True

        # Ranks in a set are distinct, so a span of exactly 4 across five
        # sorted entries means they are consecutive.
        sorted_ranks = sorted(rank_set)
        for i in range(len(sorted_ranks) - 4):
            if sorted_ranks[i + 4] - sorted_ranks[i] == 4:
                return True
        return False

    def _calculate_straight_outs(self, hand_and_board):
        """Count the unseen cards that would complete a straight.

        A rank is an out when adding it to the known ranks yields five
        consecutive ranks (the wheel included). Only ranks absent from
        *hand_and_board* can create a new straight. No card of such a rank is
        known, so each out-rank contributes 4 cards. Open-ended draws give 8
        and gutshots give 4.

        Returns:
            0 with fewer than 3 cards or when a straight is already made.
        """
        if len(hand_and_board) < 3:
            return 0

        rank_set = {RANK_MAP[c[0]] for c in hand_and_board}
        if self._has_straight_in_ranks(rank_set):
            return 0  # Nothing left to draw to.

        out_ranks = sum(
            1 for rank in RANK_MAP.values()
            if rank not in rank_set and self._has_straight_in_ranks(rank_set | {rank})
        )
        return 4 * out_ranks

    @staticmethod
    def _hit_probability(outs, deck_size, cards_to_come):
        """Probability that at least one of *outs* shows up in the next cards.

        Exact hypergeometric complement: ``1 - C(deck-outs, k) / C(deck, k)``,
        evaluated as a running product. Returns 0.0 for non-positive inputs.
        """
        if outs <= 0 or deck_size <= 0 or cards_to_come <= 0:
            return 0.0
        miss = 1.0
        for i in range(cards_to_come):
            remaining = deck_size - i
            if remaining <= 0:
                break
            miss *= max(0, deck_size - outs - i) / remaining
        return 1.0 - miss

    def _calculate_potential(self, hand, board, remaining_deck):
        """Return ``(flush_potential, straight_potential)`` for the cards still to come.

        Each value is the probability of completing that draw by the river.
        There are two cards to come on the flop and one on the turn. Both
        values are 0.0 on the river.
        """
        if len(board) == 5:
            return 0.0, 0.0

        hand_and_board = hand + board
        cards_to_come = 5 - len(board)
        deck_size = len(remaining_deck)

        # Flush potential: exactly four cards of one suit is a flush draw.
        flush_potential = 0.0
        for suit in SUITS:
            suited_count = sum(1 for c in hand_and_board if c[1] == suit)
            if suited_count == 4:
                remaining_suited = 13 - suited_count
                flush_potential = self._hit_probability(remaining_suited, deck_size, cards_to_come)
                break

        # Straight potential, only when not already holding a straight or better.
        straight_potential = 0.0
        current_rank = evaluate_cards(*hand_and_board)
        if current_rank > self._MAX_STRAIGHT:
            straight_outs = self._calculate_straight_outs(hand_and_board)
            straight_potential = self._hit_probability(straight_outs, deck_size, cards_to_come)

        return flush_potential, straight_potential

    def _calculate_made_hand_potential(self, hand, board, remaining_deck):
        """Heuristic improvement potential by made-hand category (``remaining_deck`` unused).

        Returns 0.4 for high card, 0.3 for one pair, 0.2 for two pair or trips,
        and 0.1 for a straight or better.
        """
        player_rank = evaluate_cards(*(hand + board))

        if self._MAX_TWO_PAIR < player_rank <= self._MAX_ONE_PAIR:  # One pair
            return 0.3
        elif self._MAX_STRAIGHT < player_rank <= self._MAX_TWO_PAIR:  # Two pair or trips
            return 0.2
        elif player_rank <= self._MAX_STRAIGHT:  # Straight or better
            return 0.1
        else:  # High card, needs improvement
            return 0.4

    def _calculate_backdoor_potential(self, hand, board):
        """Coarse backdoor indicator, only on the flop.

        Returns 0.1 when three cards share a suit, or when at least three
        distinct ranks are present. Otherwise returns 0.0.
        """
        if len(board) != 3:  # Only relevant on flop
            return 0.0

        hand_and_board = hand + board

        suit_counts = {s: sum(1 for c in hand_and_board if c[1] == s) for s in SUITS}
        backdoor_flush = 0.1 if max(suit_counts.values()) == 3 else 0.0

        ranks = {RANK_MAP[c[0]] for c in hand_and_board}
        backdoor_straight = 0.1 if len(ranks) >= 3 else 0.0

        return max(backdoor_flush, backdoor_straight)

    def _calculate_board_texture(self, board):
        """Return ``(suitedness, connectivity)`` of the board.

        Suitedness is 1.0 for one suit, 0.5 for two suits and 0.0 otherwise.
        Connectivity is ``1 - rank_span / 12``, where the span is highest minus
        lowest rank index and the ace counts high.
        """
        if len(board) < 3:
            return 0.0, 0.0

        board_suits = {card[1] for card in board}
        if len(board_suits) == 1:
            board_suitedness = 1.0
        elif len(board_suits) == 2:
            board_suitedness = 0.5
        else:
            board_suitedness = 0.0

        board_ranks_idx = sorted(RANK_MAP[c[0]] for c in board)
        max_span = board_ranks_idx[-1] - board_ranks_idx[0]
        board_connectivity = max(0.0, 1.0 - max_span / 12.0)

        return board_suitedness, board_connectivity

    def _is_nut_hand(self, hand, board):
        """Heuristic relative strength from the hand-rank category.

        Straight flush scores 1.0, quads 0.9, full house 0.8, flush 0.7 and
        straight 0.6. Weaker hands scale linearly with their rank.
        """
        player_rank = evaluate_cards(*(hand + board))

        if player_rank <= self._MAX_STRAIGHT_FLUSH:
            return 1.0
        elif player_rank <= self._MAX_FOUR_OF_A_KIND:
            return 0.9
        elif player_rank <= self._MAX_FULL_HOUSE:
            return 0.8
        elif player_rank <= self._MAX_FLUSH:
            return 0.7
        elif player_rank <= self._MAX_STRAIGHT:
            return 0.6
        else:
            return min(1.0, (7463 - player_rank) / 7463)  # Normalize remaining ranks

    def calculate_feature_vector(self, hand, board):
        """Return the 8-element feature vector for *hand* on *board*.

        Results are memoised per (hole cards, board) split, and a fresh list is
        returned on every call. The method returns ``[0.5] * 8`` (neutral
        values) when fewer than two cards remain or a feature calculation fails.
        """
        # Hole cards and board are keyed separately: the same cards split
        # differently (As Kh | Qc 7d 2s vs Qc 7d | As Kh 2s) have different features.
        cache_key = (tuple(sorted(hand)), tuple(sorted(board)))
        cached = self.cache.get(cache_key)
        if cached is not None:
            return list(cached)  # copy so callers cannot corrupt the cache

        known_cards = set(hand) | set(board)
        remaining_deck = [card for card in DECK if card not in known_cards]

        if len(remaining_deck) < 2:  # Not enough cards for simulation
            return [0.5] * 8  # Return neutral values

        try:
            equity = self._calculate_current_equity(hand, board, remaining_deck)
            flush_pot, straight_pot = self._calculate_potential(hand, board, remaining_deck)
            made_hand_pot = self._calculate_made_hand_potential(hand, board, remaining_deck)
            board_suit, board_conn = self._calculate_board_texture(board)
            backdoor_pot = self._calculate_backdoor_potential(hand, board)
            nut_strength = self._is_nut_hand(hand, board)

            feature_vector = [
                equity,           # Current winning probability
                flush_pot,        # Flush potential
                straight_pot,     # Straight potential
                made_hand_pot,    # Made hand improvement potential
                board_suit,       # Board suitedness
                board_conn,       # Board connectivity
                backdoor_pot,     # Backdoor potential
                nut_strength      # Relative hand strength
            ]

            # Ensure all values are valid
            feature_vector = [max(0.0, min(1.0, float(x))) for x in feature_vector]

            self.cache[cache_key] = feature_vector
            return list(feature_vector)

        except Exception as e:
            print(f"Error calculating feature vector: {e}")
            return [0.5] * 8  # Return neutral values on error

In [None]:
import numpy as np
import random
from collections import defaultdict, Counter
import pickle
import time
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
from enum import Enum

# Assuming your abstraction system is imported
# from abstraction_engine import AbstractionManager, PotentialAwareCalculator

class Action(Enum):
    """Betting actions a player can take; each value is the lowercase action name."""
    FOLD = "fold"
    CHECK = "check"
    CALL = "call"
    BET = "bet"
    RAISE = "raise"

@dataclass
class GameState:
    """Represents the current state of a poker game.

    Cards are two-character strings such as ``'As'`` (rank, then suit) and
    chip amounts are floats. Field order defines the generated ``__init__``.
    """
    players: List[int]  # Active players
    current_player: int  # presumably the id of the player to act — confirm
    pot_size: float  # chips currently in the pot
    street: str  # 'preflop', 'flop', 'turn', 'river'
    board: List[str]  # community cards dealt so far
    hands: Dict[int, List[str]]  # Player ID -> hole cards
    bets: Dict[int, float]  # Current street bets
    total_bets: Dict[int, float]  # Total bets across all streets
    action_history: List[Tuple[int, Action, float]]  # (player id, action, amount) entries
    is_terminal: bool = False  # not read anywhere in the visible code

class AdvancedMCCFRTrainer:
    """
    Advanced Monte Carlo Counterfactual Regret Minimization trainer
    integrated with sophisticated poker abstraction system.
    """

    def __init__(self, abstraction_manager=None, initial_stack=1000.0):
        print("Initializing Advanced MCCFR Trainer...")

        # Initialize abstraction system
        if abstraction_manager is None:
            self.abstraction_manager = AbstractionManager(n_clusters_per_round=20)
            print("Training abstraction models...")
            self.abstraction_manager.train_all_postflop_models(n_samples_per_round=50000)
        else:
            self.abstraction_manager = abstraction_manager

        self.initial_stack = initial_stack

        # Core CFR data structures
        self.regret_sum = defaultdict(lambda: defaultdict(float))
        self.strategy_sum = defaultdict(lambda: defaultdict(float))
        self.policy = defaultdict(lambda: defaultdict(float))

        # Training statistics
        self.iteration_count = 0
        self.strategy_updates = 0

        # Preflop abstraction (simplified - you'd want a more sophisticated system)
        self._initialize_preflop_buckets()

        print("MCCFR Trainer initialized successfully.")

    def _initialize_preflop_buckets(self):
        """Initialize preflop hand strength buckets."""
        # This is a simplified preflop abstraction
        # In practice, you'd want a more sophisticated system
        self.preflop_buckets = {}

        # Example: Group hands by basic strength categories
        premium_hands = ['AA', 'KK', 'QQ', 'JJ', 'AKs', 'AKo']
        strong_hands = ['TT', '99', 'AQs', 'AQo', 'AJs', 'KQs']
        medium_hands = ['88', '77', 'AJo', 'KQo', 'KJs', 'QJs', 'ATs']

        # Assign bucket IDs (60-79 for preflop to avoid collision with postflop)
        bucket_id = 60
        for hand_group in [premium_hands, strong_hands, medium_hands]:
            for hand in hand_group:
                self.preflop_buckets[hand] = bucket_id
            bucket_id += 1

    def get_preflop_bucket(self, hand: List[str]) -> int:
        """Convert hole cards to preflop bucket ID."""
        # Convert cards to standard notation (e.g., ['As', 'Kh'] -> 'AKo')
        ranks = [card[0] for card in hand]
        suits = [card[1] for card in hand]

        # Sort ranks for consistent representation
        rank_order = ['2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K', 'A']
        sorted_ranks = sorted(ranks, key=lambda x: rank_order.index(x), reverse=True)

        if ranks[0] == ranks[1]:  # Pocket pair
            hand_str = ranks[0] + ranks[0]
        else:
            suited = 's' if suits[0] == suits[1] else 'o'
            hand_str = ''.join(sorted_ranks) + suited

        return self.preflop_buckets.get(hand_str, 79)  # Default to weakest bucket

    def get_info_state(self, state: GameState, player_id: int) -> str:
        """
        Convert game state to information state string using abstraction.
        This is the critical bridge between complex poker state and CFR.
        """
        hand = state.hands[player_id]
        board = state.board

        # Get bucket ID using our sophisticated abstraction
        if state.street == 'preflop':
            bucket_id = self.get_preflop_bucket(hand)
        else:
            bucket_id = self.abstraction_manager.get_postflop_bucket(hand, board)

        # Create action history string
        action_history = self._get_action_sequence(state, player_id)

        # Include position and betting context
        position = self._get_position_context(state, player_id)
        pot_odds = self._get_pot_odds_context(state, player_id)

        info_state = f"B{bucket_id}|H{action_history}|P{position}|O{pot_odds}"
        return info_state

    def _get_action_sequence(self, state: GameState, player_id: int) -> str:
        """Convert action history to compact string representation."""
        # Group actions by street and player
        street_actions = {
            'preflop': [],
            'flop': [],
            'turn': [],
            'river': []
        }

        current_street = 'preflop'
        for pid, action, amount in state.action_history:
            if action == Action.BET or action == Action.RAISE:
                street_actions[current_street].append(f"{action.value[0]}{int(amount)}")
            else:
                street_actions[current_street].append(action.value[0])

        # Create compact representation
        history_parts = []
        for street, actions in street_actions.items():
            if actions:
                history_parts.append(f"{street[0]}{''.join(actions)}")

        return '|'.join(history_parts)

    def _get_position_context(self, state: GameState, player_id: int) -> str:
        """Get simplified position context."""
        # Simplified: just early/late position
        num_players = len(state.players)
        player_pos = state.players.index(player_id)

        if player_pos < num_players // 2:
            return "early"
        else:
            return "late"

    def _get_pot_odds_context(self, state: GameState, player_id: int) -> str:
        """Get pot odds context for decision making."""
        to_call = max(state.bets.values()) - state.bets.get(player_id, 0)
        if to_call <= 0:
            return "0"

        pot_odds = state.pot_size / to_call

        # Discretize pot odds into buckets
        if pot_odds < 2:
            return "low"
        elif pot_odds < 4:
            return "med"
        else:
            return "high"

    def get_legal_actions(self, state: GameState, player_id: int) -> List[Tuple[Action, float]]:
        """Enumerate the ``(action, amount)`` pairs *player_id* may choose from.

        The result lists FOLD first (only when facing a bet), then CHECK or
        CALL. After those comes one BET (nothing to call) or RAISE per accepted
        sizing, tried in this order: half-pot, 3/4-pot, pot, 1.5x-pot and the
        whole remaining stack.

        NOTE(review): a sizing is accepted when ``size + chips already in this
        street`` lies in ``[min_bet, remaining stack]``. Comparing that street
        total with the *remaining* stack rejects the all-in sizing whenever the
        player already has chips in (e.g. a posted blind). Also, no CALL is
        offered when ``to_call`` exceeds the remaining stack. Confirm this
        against how amounts are applied elsewhere.
        """
        already_in = state.bets.get(player_id, 0)
        to_call = max(state.bets.values()) - already_in
        stack_size = self.initial_stack - state.total_bets.get(player_id, 0)

        legal = []
        if to_call > 0:
            legal.append((Action.FOLD, 0))
        if to_call == 0:
            legal.append((Action.CHECK, 0))
        elif to_call <= stack_size:
            legal.append((Action.CALL, to_call))

        # Smallest acceptable street total: double the call or half the pot.
        min_bet = max(to_call * 2, state.pot_size * 0.5)
        aggressive_action = Action.BET if to_call == 0 else Action.RAISE
        sizings = [state.pot_size * fraction for fraction in (0.5, 0.75, 1.0, 1.5)]
        sizings.append(stack_size)  # All-in

        for size in sizings:
            if min_bet <= size + already_in <= stack_size:
                legal.append((aggressive_action, size))

        return legal

    def get_strategy(self, info_state: str, legal_actions: List[Tuple[Action, float]]) -> Dict[Tuple[Action, float], float]:
        """Get current strategy for an information state using regret matching."""
        strategy = {}
        regret_sum = sum(max(0, self.regret_sum[info_state][action]) for action, _ in legal_actions)

        if regret_sum > 0:
            # Regret matching strategy
            for action, amount in legal_actions:
                strategy[(action, amount)] = max(0, self.regret_sum[info_state][(action, amount)]) / regret_sum
        else:
            # Uniform strategy if no positive regrets
            uniform_prob = 1.0 / len(legal_actions)
            for action, amount in legal_actions:
                strategy[(action, amount)] = uniform_prob

        return strategy

    def train(self, num_iterations: int = 100000):
        """Run ``num_iterations`` MCCFR iterations and report progress.

        ``self.iteration_count`` is incremented once per iteration.  A progress
        line (elapsed seconds and average milliseconds per iteration) is
        printed every 10,000 iterations.  The average strategy is refreshed
        after the first iteration and after every 100th one thereafter
        (zero-based index divisible by 100).

        Args:
            num_iterations: Number of iterations to perform.
        """
        print(f"\n=== Starting MCCFR Training for {num_iterations} iterations ===")
        started = time.time()
        report_every = 10000
        average_every = 100

        for completed in range(1, num_iterations + 1):
            self.iteration_count += 1

            if completed % report_every == 0:
                elapsed = time.time() - started
                per_iter_ms = elapsed / completed * 1000
                print(f"Iteration {completed}/{num_iterations} | "
                      f"Time: {elapsed:.1f}s | "
                      f"Avg: {per_iter_ms:.2f}ms/iter")

            self._run_mccfr_iteration()

            # ``completed - 1`` is the zero-based iteration index.
            if (completed - 1) % average_every == 0:
                self._update_average_strategy()

        duration = time.time() - started
        print(f"Training complete! Total time: {duration:.1f}s")
        print(f"Final policy contains {len(self.policy)} information states")

    def _run_mccfr_iteration(self):
        """Run one iteration of Monte Carlo CFR."""
        # Create random initial game state
        initial_state = self._create_random_game_state()

        # Run MCCFR recursion for both players
        for player_id in [0, 1]:
            self._mccfr_recursive(initial_state, player_id, 1.0, 1.0)

    def _create_random_game_state(self) -> GameState:
        """Deal a fresh two-player hand and return its preflop starting state.

        A 52-card deck of two-character strings (rank then suit) is shuffled
        with ``random.shuffle``.  Player 0 receives the first two cards and
        player 1 the next two.  Player 0 has 1.0 posted and player 1 has 2.0
        posted, so the pot starts at 3.0 with player 0 to act and an empty
        board.
        """
        # Ranks low to high, suits h/d/c/s within each rank ('2h', '2d', ..., 'As').
        deck = [rank + suit for rank in '23456789TJQKA' for suit in 'hdcs']
        random.shuffle(deck)

        # Small blind, big blind; bets and total_bets each get their own copy.
        blinds = {0: 1.0, 1: 2.0}
        return GameState(
            players=[0, 1],
            current_player=0,
            pot_size=3.0,
            street='preflop',
            board=[],
            hands={0: deck[0:2], 1: deck[2:4]},
            bets=dict(blinds),
            total_bets=dict(blinds),
            action_history=[],
        )

    def _mccfr_recursive(self, state: GameState, traversing_player: int,
                         pi_player: float, pi_opponent: float) -> float:
        """Recursive MCCFR algorithm implementation."""

        # Terminal node - return utility
        if state.is_terminal:
            return self._get_utility(state, traversing_player)

        # Get current player info state and legal actions
        current_player = state.current_player
        info_state = self.get_info_state(state, current_player)
        legal_actions = self.get_legal_actions(state, current_player)

        if not legal_actions:
            # No legal actions - fold
            return self._get_utility(state, traversing_player)

        # Get strategy for current info state
        strategy = self.get_strategy(info_state, legal_actions)

        # Calculate action utilities
        action_utilities = {}
        for action, amount in legal_actions:
            # Create new state after taking this action
            new_state = self._apply_action(state, action, amount)

            if current_player == traversing_player:
                action_utilities[(action, amount)] = self._mccfr_recursive(
                    new_state, traversing_player,
                    pi_player * strategy[(action, amount)], pi_opponent
                )
            else:
                action_utilities[(action, amount)] = self._mccfr_recursive(
                    new_state, traversing_player, pi_player,
                    pi_opponent * strategy[(action, amount)]
                )

        # Calculate node utility
        node_utility = sum(strategy[(action, amount)] * action_utilities[(action, amount)]
                          for action, amount in legal_actions)

        # Update regrets for traversing player
        if current_player == traversing_player:
            for action, amount in legal_actions:
                regret = action_utilities[(action, amount)] - node_utility
                self.regret_sum[info_state][(action, amount)] += pi_opponent * regret

        return node_utility

    def _apply_action(self, state: GameState, action: Action, amount: float) -> GameState:
        """Return the successor of ``state`` after the player to act takes ``action``.

        ``state`` is left unchanged: its list/dict fields are copied one level
        deep into a new ``GameState`` (whose ``is_terminal`` starts False).

        * FOLD marks the successor terminal and removes the actor from ``players``.
        * CHECK moves no chips.
        * CALL adds (largest current street bet - actor's street bet) to the
          actor's street bet, total bet and the pot.
        * BET / RAISE add ``amount`` to those same three quantities.

        ``(player, action, amount)`` is appended to the history.  If the hand
        continues, the turn passes to ``1 - player``.  When
        ``_is_betting_round_complete`` reports the street closed, the state is
        passed through ``_advance_street``.
        """
        actor = state.current_player
        successor = GameState(
            players=state.players.copy(),
            current_player=actor,
            pot_size=state.pot_size,
            street=state.street,
            board=state.board.copy(),
            hands=state.hands.copy(),
            bets=state.bets.copy(),
            total_bets=state.total_bets.copy(),
            action_history=state.action_history.copy(),
        )

        chips = None  # stays None for CHECK (and anything unrecognised)
        if action == Action.FOLD:
            successor.is_terminal = True
            successor.players.remove(actor)
        elif action == Action.CALL:
            chips = max(state.bets.values()) - state.bets.get(actor, 0)
        elif action in [Action.BET, Action.RAISE]:
            chips = amount

        if chips is not None:
            successor.bets[actor] += chips
            successor.total_bets[actor] += chips
            successor.pot_size += chips

        successor.action_history.append((actor, action, amount))

        if successor.is_terminal:
            return successor

        successor.current_player = 1 - actor  # heads-up: hand the turn over
        if self._is_betting_round_complete(successor):
            successor = self._advance_street(successor)
        return successor

    def _is_betting_round_complete(self, state: GameState) -> bool:
        """Return True when the betting round on the current street is closed.

        Models heads-up play in which player 0 acts first on every street
        (``_create_random_game_state`` and ``_advance_street`` both set
        ``current_player = 0``).  Equal bets alone are not enough, because both
        players must have had the chance to act:

        * a CHECK closes the round only when made by player 1 (second to act);
          player 0's check passes the action on;
        * a CALL closes the round, except when it is the very first action of
          the hand (player 0 completing the blind preflop), which leaves
          player 1 the option to check or raise;
        * any other action (BET / RAISE) leaves the round open.

        Args:
            state: State *after* the latest action was appended to
                ``action_history``.

        Returns:
            True if play should move on to the next street.
        """
        active_bets = [state.bets[p] for p in state.players if p in state.bets]
        if len(set(active_bets)) > 1:
            return False  # someone still faces a bet

        if not state.action_history:
            # Nothing recorded yet: fall back to the equal-bets rule.
            return True

        last_player, last_action, _ = state.action_history[-1]
        if last_action == Action.CHECK:
            return last_player == 1
        if last_action == Action.CALL:
            is_blind_completion = (state.street == 'preflop'
                                   and len(state.action_history) == 1)
            return not is_blind_completion
        return False

    def _advance_street(self, state: GameState) -> GameState:
        """Advance to the next street (flop -> turn -> river)."""
        street_progression = {
            'preflop': ('flop', 3),
            'flop': ('turn', 4),
            'turn': ('river', 5),
            'river': ('terminal', 5)
        }

        if state.street in street_progression:
            next_street, board_size = street_progression[state.street]

            if next_street == 'terminal':
                state.is_terminal = True
            else:
                state.street = next_street
                # Add community cards (simplified - would need proper deck tracking)
                while len(state.board) < board_size:
                    state.board.append('Xx')  # Placeholder

                # Reset street bets
                state.bets = {p: 0.0 for p in state.players}
                state.current_player = 0  # Reset position

        return state

    def _get_utility(self, state: GameState, player_id: int) -> float:
        """Calculate utility (winnings) for a player in terminal state."""
        if player_id not in state.players:
            # Player folded - lose total bets
            return -state.total_bets.get(player_id, 0)
        else:
            # Player won - get pot minus their bets
            return state.pot_size - state.total_bets.get(player_id, 0)

    def _update_average_strategy(self):
        """Update the average strategy (policy) from strategy sums."""
        for info_state in self.regret_sum:
            strategy = self.get_strategy(info_state,
                                       list(self.regret_sum[info_state].keys()))

            for action in strategy:
                self.strategy_sum[info_state][action] += strategy[action]

        self.strategy_updates += 1

    def get_final_policy(self) -> Dict[str, Dict]:
        """Normalize the accumulated strategy sums into the average policy.

        Returns:
            ``{info_state: {action: probability}}``.  A state whose sums total
            zero or less gets a uniform distribution over its recorded actions,
            or an empty dict when none are recorded.
        """
        def _normalize(sums):
            denominator = sum(sums.values())
            if denominator > 0:
                return {action: weight / denominator for action, weight in sums.items()}
            # Uniform distribution if no data.
            actions = list(sums.keys())
            share = 1.0 / len(actions) if actions else 1.0
            return {action: share for action in actions}

        return {info_state: _normalize(sums) for info_state, sums in self.strategy_sum.items()}

    def save_policy(self, filepath: str):
        """Pickle the trainer's learned state to ``filepath``.

        The pickled dict contains:
        ``final_policy`` (from ``get_final_policy``), ``regret_sum`` and
        ``strategy_sum`` (outer mapping converted to a plain ``dict``),
        ``iteration_count``, and ``abstraction_models`` (the abstraction
        manager's ``kmeans_models``).
        """
        snapshot = dict(
            final_policy=self.get_final_policy(),
            regret_sum=dict(self.regret_sum),
            strategy_sum=dict(self.strategy_sum),
            iteration_count=self.iteration_count,
            abstraction_models=self.abstraction_manager.kmeans_models,
        )

        with open(filepath, 'wb') as handle:
            pickle.dump(snapshot, handle)
        print(f"Policy saved to {filepath}")

    def load_policy(self, filepath: str):
        """Restore trainer state written by ``save_policy``.

        ``regret_sum`` and ``strategy_sum`` are rebuilt as two-level tables
        whose inner mappings are ``defaultdict(float)``.  Inner mappings are
        re-wrapped as well, so later ``+=`` updates on unseen actions work
        even if they were stored as plain dicts.  ``iteration_count`` is
        restored, and the abstraction manager's ``kmeans_models`` too when
        present in the file.

        Security: ``pickle.load`` can execute arbitrary code.  Only load files
        produced by a trusted ``save_policy`` call.
        """
        with open(filepath, 'rb') as f:
            policy_data = pickle.load(f)  # trusted input only (see docstring)

        def _as_table(saved):
            table = defaultdict(lambda: defaultdict(float))
            for info_state, action_values in saved.items():
                table[info_state] = defaultdict(float, action_values)
            return table

        self.regret_sum = _as_table(policy_data['regret_sum'])
        self.strategy_sum = _as_table(policy_data['strategy_sum'])
        self.iteration_count = policy_data['iteration_count']

        if 'abstraction_models' in policy_data:
            self.abstraction_manager.kmeans_models = policy_data['abstraction_models']

        print(f"Policy loaded from {filepath}")


# Example usage and testing
if __name__ == "__main__":
    print("=== Advanced MCCFR Poker Trainer ===")

    # Initialize trainer
    trainer = AdvancedMCCFRTrainer()

    # Train the policy
    trainer.train(num_iterations=50000)

    # Save the trained policy
    trainer.save_policy("poker_cfr_policy.pkl")

    # Get final policy stats
    final_policy = trainer.get_final_policy()
    print("\nTraining Results:")
    print(f"- Information states learned: {len(final_policy)}")
    print(f"- Total iterations: {trainer.iteration_count}")
    print(f"- Strategy updates: {trainer.strategy_updates}")

    # Example: show the strategy for the first learned information state.
    # next(iter(...)) avoids copying every key; `is not None` keeps a
    # legitimately falsy key (e.g. "") from being skipped.
    example_info_state = next(iter(final_policy), None)
    if example_info_state is not None:
        print(f"\nExample strategy for info state '{example_info_state}':")
        for action, prob in final_policy[example_info_state].items():
            print(f"  {action}: {prob:.3f}")

=== Advanced MCCFR Poker Trainer ===
Initializing Advanced MCCFR Trainer...


NameError: name 'AbstractionManager' is not defined

In [None]:
import numpy as np
import random
from collections import defaultdict, Counter
import pickle
import time
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
from enum import Enum

# Import your optimized calculator
from phevaluator import evaluate_cards

# --- Basic Game Definitions ---
SUITS = list('hdcs')  # hearts, diamonds, clubs, spades
RANKS = list('23456789TJQKA')  # deuce (lowest) through ace (highest)
# All 52 cards as two-character strings (rank, then suit), grouped by rank.
DECK = [rank + suit for rank in RANKS for suit in SUITS]
# Rank character -> 0-based index ('2' -> 0, ..., 'A' -> 12).
RANK_MAP = dict(zip(RANKS, range(len(RANKS))))

class PotentialAwareCalculator:
    """Compute an 8-element, [0, 1]-clamped feature vector for a hand and board.

    Features, in order: Monte Carlo equity against one random hand, flush-draw
    potential, straight-draw potential, made-hand improvement potential,
    board suitedness, board connectivity, backdoor potential and a coarse
    hand-strength score.

    Hand ranks come from ``phevaluator.evaluate_cards``, where a LOWER rank
    is a STRONGER hand.  The category boundaries below follow that scale:
    straight flush 1-10, four of a kind 11-166, full house 167-322,
    flush 323-1599, straight 1600-1609, three of a kind 1610-2467,
    two pair 2468-3325, one pair 3326-6185, high card 6186-7462.

    Cards are two-character strings (rank then suit, e.g. ``'As'``).
    Results of ``calculate_feature_vector`` are memoized in ``self.cache``.
    """

    # Inclusive upper rank bound of each phevaluator hand category.
    _MAX_STRAIGHT_FLUSH = 10
    _MAX_QUADS = 166
    _MAX_FULL_HOUSE = 322
    _MAX_STRAIGHT = 1609      # flushes are 323-1599, straights 1600-1609
    _MAX_TRIPS = 2467
    _MAX_TWO_PAIR = 3325
    _MAX_ONE_PAIR = 6185
    _RANK_SCALE = 7463        # one past the worst rank, used for normalisation

    def __init__(self):
        self.cache = {}
        print("PotentialAwareCalculator initialized.")

    def _calculate_current_equity(self, hand, board, remaining_deck, num_sims=200):
        """Estimate win probability against one random opponent hand by Monte Carlo.

        Each simulation draws the opponent's two cards plus the missing board
        cards from ``remaining_deck``.  Ties count as half a win.  Simulations
        whose sampling/evaluation raises ``ValueError`` or ``IndexError`` are
        skipped and excluded from the denominator.

        Returns:
            Estimated equity in [0, 1], or 0.0 when no simulation could run.
        """
        cards_to_draw = 2 + (5 - len(board))
        if num_sims <= 0 or len(remaining_deck) < cards_to_draw:
            return 0.0

        wins = 0.0
        completed = 0
        for _ in range(num_sims):
            try:
                samples = random.sample(remaining_deck, cards_to_draw)
                opp_hand, board_runout = samples[:2], samples[2:]
                final_board = board + board_runout

                player_rank = evaluate_cards(*hand, *final_board)
                opp_rank = evaluate_cards(*opp_hand, *final_board)
            except (ValueError, IndexError):
                continue

            completed += 1
            if player_rank < opp_rank:  # lower rank = stronger hand
                wins += 1
            elif player_rank == opp_rank:
                wins += 0.5
        return wins / completed if completed else 0.0

    def _has_straight_in_ranks(self, rank_set):
        """Return True if ``rank_set`` (0-based rank indices) holds five consecutive ranks.

        The ace (index 12) also plays low, so {12, 0, 1, 2, 3} (A-2-3-4-5) counts.
        """
        if len(rank_set) < 5:
            return False

        if all(x in rank_set for x in [12, 0, 1, 2, 3]):  # wheel
            return True

        sorted_ranks = sorted(rank_set, reverse=True)
        for i in range(len(sorted_ranks) - 4):
            # Five distinct ranks spanning exactly four steps are consecutive.
            if sorted_ranks[i] - sorted_ranks[i + 4] == 4:
                return True
        return False

    def _calculate_straight_outs(self, hand_and_board):
        """Count unseen cards that would complete a straight.

        Every rank not yet present is added hypothetically.  If that yields a
        straight, each card of that rank not already in ``hand_and_board``
        counts as an out.

        Returns:
            Number of distinct out cards (0 with fewer than three known cards).
        """
        if len(hand_and_board) < 3:
            return 0

        ranks_idx = [RANK_MAP[c[0]] for c in hand_and_board]
        rank_set = set(ranks_idx)
        straight_out_cards = set()

        for missing_rank_idx in range(13):
            if missing_rank_idx not in rank_set:
                test_ranks = rank_set | {missing_rank_idx}
                if self._has_straight_in_ranks(test_ranks):
                    completing_rank_str = RANKS[missing_rank_idx]
                    for suit in SUITS:
                        card = completing_rank_str + suit
                        if card not in hand_and_board:
                            straight_out_cards.add(card)
        return len(straight_out_cards)

    def _calculate_potential(self, hand, board, remaining_deck):
        """Return ``(flush_potential, straight_potential)`` for a drawing hand.

        Flush potential is non-zero only when exactly four cards share a suit.
        Straight potential is non-zero only when the current hand is weaker
        than a straight.  Each is the probability of hitting an out: within two
        cards on the flop, one card on the turn.  Both are 0.0 on the river.
        """
        if len(board) == 5:
            return 0.0, 0.0

        hand_and_board = hand + board

        # Flush potential
        suit_counts = {s: sum(1 for c in hand_and_board if c[1] == s) for s in SUITS}
        flush_potential = 0.0
        flush_draw_suit = next((s for s, count in suit_counts.items() if count == 4), None)

        if flush_draw_suit:
            flush_outs = 13 - suit_counts[flush_draw_suit]
            if len(board) == 3:
                p_miss_turn = (len(remaining_deck) - flush_outs) / len(remaining_deck)
                p_miss_river = (len(remaining_deck) - 1 - flush_outs) / (len(remaining_deck) - 1)
                flush_potential = 1 - (p_miss_turn * p_miss_river)
            elif len(board) == 4:
                flush_potential = flush_outs / len(remaining_deck)

        # Straight potential, only for hands not already a straight or better.
        straight_potential = 0.0
        current_rank = evaluate_cards(*hand_and_board)
        if current_rank > self._MAX_STRAIGHT:
            num_straight_outs = self._calculate_straight_outs(hand_and_board)
            if num_straight_outs > 0:
                if len(board) == 3:
                    p_miss_turn = (len(remaining_deck) - num_straight_outs) / len(remaining_deck)
                    p_miss_river = (len(remaining_deck) - 1 - num_straight_outs) / (len(remaining_deck) - 1)
                    straight_potential = 1 - (p_miss_turn * p_miss_river)
                elif len(board) == 4:
                    straight_potential = num_straight_outs / len(remaining_deck)

        return flush_potential, straight_potential

    def _calculate_made_hand_potential(self, hand, board, remaining_deck):
        """Approximate the chance that a made hand improves before the river.

        Heuristic out counts:
        - pocket pair: 2 set outs, or 3 per other board rank if its rank is
          already on the board;
        - other one-pair hands: 5;
        - two pair: 4 (full house);
        - trips, straight, flush or full house: 7.
        Quads or better and high-card hands get no outs.

        The flop uses the chance of hitting within two cards, otherwise one
        card.  On the river (no cards to come) the result is 0.0.
        """
        if len(board) >= 5:
            return 0.0  # river: nothing left to improve with

        player_rank = evaluate_cards(*(hand + board))

        is_one_pair = self._MAX_TWO_PAIR < player_rank <= self._MAX_ONE_PAIR
        is_two_pair = self._MAX_TRIPS < player_rank <= self._MAX_TWO_PAIR
        # Three of a kind up to a full house (quads / straight flush excluded).
        is_trips_to_full_house = self._MAX_QUADS < player_rank <= self._MAX_TRIPS

        outs = 0
        if is_one_pair:
            if hand[0][0] == hand[1][0]:  # pocket pair
                board_ranks = [c[0] for c in board]
                if hand[0][0] not in board_ranks:
                    outs = 2  # set outs
                else:
                    remaining_ranks = set(board_ranks) - {hand[0][0]}
                    outs = len(remaining_ranks) * 3
            else:
                outs = 5  # standard improvement outs
        elif is_two_pair:
            outs = 4  # full house outs
        elif is_trips_to_full_house:
            outs = 7  # boat/quads outs

        # Convert outs to a hit probability.
        if outs > 0 and len(remaining_deck) > 0:
            deck_size = len(remaining_deck)
            if len(board) == 3:  # flop: two cards to come
                prob = 1 - ((deck_size - outs) / deck_size) * \
                          ((deck_size - 1 - outs) / (deck_size - 1))
            else:  # turn: one card to come
                prob = outs / deck_size
            return min(1.0, prob)

        return 0.0

    def _calculate_backdoor_potential(self, hand, board):
        """Return a fixed backdoor-draw score on the flop (0.0 on other streets).

        0.04 when exactly three cards share a suit (runner-runner flush).
        Otherwise 0.02 when at least three distinct ranks are present.
        The larger of the two applies.
        """
        if len(board) != 3:
            return 0.0

        hand_and_board = hand + board

        # Backdoor flush
        suit_counts = {s: sum(1 for c in hand_and_board if c[1] == s) for s in SUITS}
        backdoor_flush = 0.04 if any(count == 3 for count in suit_counts.values()) else 0.0

        # Backdoor straight (simplified: 3+ distinct ranks)
        ranks = set(RANK_MAP[c[0]] for c in hand_and_board)
        backdoor_straight = 0.02 if len(ranks) >= 3 else 0.0

        return max(backdoor_flush, backdoor_straight)

    def _calculate_board_texture(self, board):
        """Return ``(suitedness, connectivity)`` of the board.

        Suitedness: 1.0 monotone, 0.5 two-tone, 0.0 three or more suits.
        Connectivity: 1 - (highest rank index - lowest rank index) / 12.
        Both are 0.0 for boards with fewer than three cards.
        """
        if len(board) < 3:
            return 0.0, 0.0

        board_suits = {card[1] for card in board}
        if len(board_suits) == 1:
            board_suitedness = 1.0  # monotone
        elif len(board_suits) == 2:
            board_suitedness = 0.5  # two-tone
        else:
            board_suitedness = 0.0  # rainbow

        board_ranks_idx = sorted([RANK_MAP[c[0]] for c in board])
        max_span = board_ranks_idx[-1] - board_ranks_idx[0]
        board_connectivity = 1.0 - max_span / 12.0

        return board_suitedness, board_connectivity

    def _is_nut_hand(self, hand, board):
        """Return a coarse absolute strength score in [0, 1].

        Fixed scores by category: straight flush 1.0, quads 0.9,
        full house 0.8, flush or straight 0.7.  Weaker hands get
        ``(7463 - rank) / 7463``.
        """
        player_rank = evaluate_cards(*(hand + board))

        if player_rank <= self._MAX_STRAIGHT_FLUSH:
            return 1.0
        if player_rank <= self._MAX_QUADS:
            return 0.9
        if player_rank <= self._MAX_FULL_HOUSE:
            return 0.8
        if player_rank <= self._MAX_STRAIGHT:
            return 0.7

        return min(1.0, (self._RANK_SCALE - player_rank) / self._RANK_SCALE)

    def calculate_feature_vector(self, hand, board):
        """Return the 8-element feature vector for ``hand`` on ``board``.

        Every element is clamped to [0, 1].  Successful results are cached per
        (hand, board).  ``[0.5] * 8`` is returned (uncached) when fewer than
        two unseen cards remain or when any feature computation raises.
        """
        # Hand and board play different roles, so they are keyed separately.
        # One sorted tuple of all cards would conflate, e.g., hand [As, Kh] on
        # board [Qc, 7d, 2s] with hand [Qc, 7d] on board [As, Kh, 2s].
        cache_key = (tuple(sorted(hand)), tuple(sorted(board)))
        if cache_key in self.cache:
            return self.cache[cache_key]

        known_cards = hand + board
        remaining_deck = [card for card in DECK if card not in known_cards]

        if len(remaining_deck) < 2:
            return [0.5] * 8

        try:
            equity = self._calculate_current_equity(hand, board, remaining_deck)
            flush_pot, straight_pot = self._calculate_potential(hand, board, remaining_deck)
            made_hand_pot = self._calculate_made_hand_potential(hand, board, remaining_deck)
            board_suit, board_conn = self._calculate_board_texture(board)
            backdoor_pot = self._calculate_backdoor_potential(hand, board)
            nut_strength = self._is_nut_hand(hand, board)

            feature_vector = [
                equity,           # current winning probability
                flush_pot,        # flush potential
                straight_pot,     # straight potential
                made_hand_pot,    # made hand improvement potential
                board_suit,       # board suitedness
                board_conn,       # board connectivity
                backdoor_pot,     # backdoor potential
                nut_strength      # relative hand strength
            ]

            # Validate and clamp values
            feature_vector = [max(0.0, min(1.0, float(x))) for x in feature_vector]

            self.cache[cache_key] = feature_vector
            return feature_vector

        except Exception as e:
            # Best effort: a neutral vector keeps data generation going.
            print(f"Error in feature calculation: {e}")
            return [0.5] * 8

class FixedAbstractionManager:
    """K-Means card abstraction for the postflop rounds.

    For each of 'flop', 'turn' and 'river', a K-Means model is fitted on
    feature vectors from ``PotentialAwareCalculator``.
    ``get_postflop_bucket`` maps a (hand, board) pair to
    ``cluster + cluster_offsets[round]``.  Offsets are multiples of
    ``n_clusters_per_round``, so bucket ids of different rounds never
    collide (0 / 20 / 40 with the default of 20 clusters).
    """

    # Deterministic deal used when no random deal passes validation;
    # all seven cards are distinct.
    _FALLBACK_HAND = ('As', 'Kh')
    _FALLBACK_BOARD = ('Qc', '7d', '2s', '9h', '4c')

    def __init__(self, n_clusters_per_round=20):
        self.calculator = PotentialAwareCalculator()
        self.n_clusters = n_clusters_per_round
        self.kmeans_models = {}
        # Offsets scale with the cluster count so rounds never share ids.
        self.cluster_offsets = {
            'flop': 0,
            'turn': n_clusters_per_round,
            'river': 2 * n_clusters_per_round,
        }

        # Card definitions
        self.SUITS = SUITS
        self.RANKS = RANKS
        self.DECK = DECK

        print("Fixed AbstractionManager initialized.")

    def _generate_valid_hand_board(self, num_board_cards, max_attempts=100):
        """Deal a random hand and board whose feature vector can be computed.

        Up to ``max_attempts`` random deals are tried.  A deal is accepted
        when ``calculator.calculate_feature_vector`` returns a non-empty list
        of non-NaN numbers.  If every attempt fails, a fixed hand with a board
        of ``num_board_cards`` cards (at most 5) is returned.

        Returns:
            ``(hand, board)`` as lists of card strings.
        """
        total_cards_needed = 2 + num_board_cards

        for _ in range(max_attempts):
            deck = self.DECK.copy()
            random.shuffle(deck)

            if total_cards_needed > len(deck):
                continue

            hand = deck[:2]
            board = deck[2:total_cards_needed]
            if len(hand) != 2 or len(board) != num_board_cards:
                continue

            try:
                test_vector = self.calculator.calculate_feature_vector(hand, board)
            except Exception:
                continue  # best effort: try another deal

            if (isinstance(test_vector, list) and
                    len(test_vector) > 0 and
                    all(isinstance(x, (int, float)) and not np.isnan(x) for x in test_vector)):
                return hand, board

        # Fallback board always has the requested number of cards (up to 5).
        return list(self._FALLBACK_HAND), list(self._FALLBACK_BOARD[:num_board_cards])

    def _generate_data_for_round(self, n_samples, round_name):
        """Collect up to ``n_samples`` valid feature vectors for a round.

        At most ``2 * n_samples`` deals are attempted.  Invalid vectors and
        deals that raise are counted as failures.

        Returns:
            ``np.ndarray`` built from the collected vectors (one row each).

        Raises:
            ValueError: If fewer than ``self.n_clusters`` valid vectors were collected.
            KeyError: If ``round_name`` is not 'flop', 'turn' or 'river'.
        """
        print(f"--- Generating {n_samples} feature vectors for {round_name} ---")

        feature_vectors = []
        num_board_cards = {'flop': 3, 'turn': 4, 'river': 5}[round_name]
        failed_generations = 0

        start_time = time.time()

        for i in range(n_samples * 2):
            if len(feature_vectors) >= n_samples:
                break

            if (i + 1) % 5000 == 0:
                print(f"  ...attempting {i+1}, collected {len(feature_vectors)}/{n_samples}...")

            try:
                hand, board = self._generate_valid_hand_board(num_board_cards)
                vector = self.calculator.calculate_feature_vector(hand, board)
            except Exception:
                failed_generations += 1  # best effort: skip this deal
                continue

            if self._validate_feature_vector(vector):
                feature_vectors.append(vector)
            else:
                failed_generations += 1

        duration = time.time() - start_time

        if len(feature_vectors) < self.n_clusters:
            raise ValueError(f"Only generated {len(feature_vectors)} valid vectors, "
                             f"need at least {self.n_clusters} for clustering")

        print(f"Generated {len(feature_vectors)} valid vectors "
              f"({failed_generations} failed) in {duration:.2f}s")

        return np.array(feature_vectors)

    def _validate_feature_vector(self, vector):
        """Return True for a non-empty sequence of finite numbers all within [0, 1]."""
        if not isinstance(vector, (list, tuple, np.ndarray)):
            return False

        if len(vector) == 0:
            return False

        for val in vector:
            if not isinstance(val, (int, float)):
                return False
            if np.isnan(val) or np.isinf(val):
                return False

        return all(0 <= val <= 1 for val in vector)

    def _remove_duplicate_vectors(self, vectors):
        """Drop vectors that are identical after rounding to 6 decimals (first kept)."""
        unique_vectors = []
        seen = set()

        for vector in vectors:
            vector_tuple = tuple(np.round(vector, 6))
            if vector_tuple not in seen:
                seen.add(vector_tuple)
                unique_vectors.append(vector)

        print(f"Removed {len(vectors) - len(unique_vectors)} duplicate vectors")
        return np.array(unique_vectors)

    def train_all_postflop_models(self, n_samples_per_round=10000, silhouette_sample_size=2000):
        """Fit one K-Means model per postflop round and store it in ``kmeans_models``.

        For each round, feature vectors are generated, de-duplicated and
        clustered into ``n_clusters`` groups (fewer if there are not enough
        unique vectors).  If data generation or fitting fails, a fallback
        model is installed via ``_create_fallback_model``.

        A silhouette score is printed as a diagnostic.  It is computed on at
        most ``silhouette_sample_size`` points; ``None`` uses all points, which
        needs memory quadratic in their number.  A failure in this diagnostic
        never discards the fitted model.

        Returns early, training nothing, when scikit-learn is not installed.
        """
        print("\n=== STARTING K-MEANS MODEL TRAINING FOR ALL ROUNDS ===")

        # Import sklearn here to handle a missing dependency gracefully.
        try:
            from sklearn.cluster import KMeans
            from sklearn.metrics import silhouette_score
        except ImportError:
            print("ERROR: sklearn not installed. Install with: pip install scikit-learn")
            return

        for round_name in ['flop', 'turn', 'river']:
            try:
                print(f"\n--- Training {round_name} model ---")

                data = self._generate_data_for_round(n_samples_per_round, round_name)
                data = self._remove_duplicate_vectors(data)

                if len(data) < self.n_clusters:
                    print(f"Warning: Only {len(data)} unique vectors for {self.n_clusters} clusters")
                    actual_clusters = min(self.n_clusters, len(data))
                else:
                    actual_clusters = self.n_clusters

                if data.ndim == 1:
                    data = data.reshape(-1, 1)

                print(f"Training K-Means on {data.shape[0]} vectors with {data.shape[1]} features")
                print(f"Using {actual_clusters} clusters")

                kmeans = KMeans(
                    n_clusters=actual_clusters,
                    random_state=42,
                    n_init=10,
                    max_iter=300
                )

                kmeans.fit(data)
                self.kmeans_models[round_name] = kmeans

            except Exception as e:
                print(f"Error training {round_name} model: {e}")
                self._create_fallback_model(round_name)
                continue

            self._report_silhouette(silhouette_score, kmeans, data,
                                    actual_clusters, silhouette_sample_size)
            print(f"{round_name} model trained successfully!")

    def _report_silhouette(self, silhouette_score, kmeans, data, n_clusters, sample_size):
        """Print the silhouette score of a fitted model (diagnostic, best effort)."""
        # silhouette_score requires 2 <= n_labels <= n_samples - 1.
        if not 2 <= n_clusters < len(data):
            return
        try:
            labels = kmeans.predict(data)
            silhouette = silhouette_score(
                data, labels,
                sample_size=min(sample_size, len(data)) if sample_size else None,
                random_state=42,
            )
            print(f"Silhouette Score: {silhouette:.3f}")
        except Exception as e:
            print(f"Could not compute silhouette score: {e}")

    def _create_fallback_model(self, round_name):
        """Install a K-Means model fitted on uniform random 8-D data for ``round_name``.

        This is only a placeholder so bucketing keeps working; its clusters
        carry no card information.
        """
        print(f"Creating fallback model for {round_name}")

        try:
            from sklearn.cluster import KMeans
            dummy_data = np.random.rand(self.n_clusters * 10, 8)
            # Explicit n_init matches the main models and avoids version-dependent defaults.
            kmeans = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
            kmeans.fit(dummy_data)
            self.kmeans_models[round_name] = kmeans
        except ImportError:
            print("Cannot create fallback model without sklearn")

    def get_postflop_bucket(self, hand, board):
        """Return the abstraction bucket id for ``hand`` on ``board``.

        Args:
            hand: List of two card strings.
            board: List of 3, 4 or 5 card strings.

        Returns:
            ``cluster + cluster_offsets[round]``, clamped to that round's
            ``n_clusters`` id range.  On any error (bad input, missing model,
            invalid features), the round's base offset is returned instead.
            The flop offset is used when the round cannot be determined.
        """
        round_map = {3: 'flop', 4: 'turn', 5: 'river'}
        try:
            if not isinstance(hand, list) or len(hand) != 2:
                raise ValueError("Hand must be a list of 2 cards")
            if not isinstance(board, list) or len(board) not in [3, 4, 5]:
                raise ValueError("Board must have 3, 4, or 5 cards")

            round_name = round_map[len(board)]

            if round_name not in self.kmeans_models:
                raise RuntimeError(f"No trained model for {round_name}")

            feature_vector = self.calculator.calculate_feature_vector(hand, board)

            if not self._validate_feature_vector(feature_vector):
                print(f"Warning: Invalid feature vector {feature_vector}, using fallback")
                return self.cluster_offsets[round_name]

            model = self.kmeans_models[round_name]
            predicted_cluster = model.predict(np.array([feature_vector]))[0]

            final_bucket_id = predicted_cluster + self.cluster_offsets[round_name]

            # Ensure bucket ID is in the round's valid range.
            min_bucket = self.cluster_offsets[round_name]
            max_bucket = min_bucket + self.n_clusters - 1
            final_bucket_id = max(min_bucket, min(max_bucket, final_bucket_id))

            return int(final_bucket_id)

        except Exception as e:
            print(f"Error in get_postflop_bucket: {e}")
            # ``board`` itself may be the cause (e.g. None), so size it defensively.
            board_size = len(board) if hasattr(board, '__len__') else None
            return self.cluster_offsets[round_map.get(board_size, 'flop')]

    def save_models(self, filepath):
        """Pickle the models, cluster count and offsets to ``filepath`` (errors are printed)."""
        try:
            model_data = {
                'kmeans_models': self.kmeans_models,
                'n_clusters': self.n_clusters,
                'cluster_offsets': self.cluster_offsets
            }

            with open(filepath, 'wb') as f:
                pickle.dump(model_data, f)

            print(f"Models saved to {filepath}")

        except Exception as e:
            print(f"Error saving models: {e}")

    def load_models(self, filepath):
        """Restore state written by ``save_models`` (errors are printed).

        Security: ``pickle.load`` can execute arbitrary code; only load
        trusted files.
        """
        try:
            with open(filepath, 'rb') as f:
                model_data = pickle.load(f)  # trusted input only

            self.kmeans_models = model_data['kmeans_models']
            self.n_clusters = model_data['n_clusters']
            self.cluster_offsets = model_data['cluster_offsets']

            print(f"Models loaded from {filepath}")

        except Exception as e:
            print(f"Error loading models: {e}")

# Your original Action and GameState classes
class Action(Enum):
    """Betting actions a player can take; each value is the lowercase name."""

    FOLD = "fold"    # give up the hand
    CHECK = "check"  # pass without adding chips
    CALL = "call"    # match the largest outstanding bet
    BET = "bet"      # put chips in when nothing is owed
    RAISE = "raise"  # put in more than the outstanding bet

@dataclass
class GameState:
    """Snapshot of a poker hand at one decision point.

    Field order is significant: it defines the positional constructor.
    """

    players: List[int]                                # seats still in the hand
    current_player: int                               # seat to act next
    pot_size: float                                   # chips in the pot
    street: str                                       # betting round name, e.g. 'preflop'
    board: List[str]                                  # community cards as strings
    hands: Dict[int, List[str]]                       # seat -> hole cards
    bets: Dict[int, float]                            # seat -> chips bet this street
    total_bets: Dict[int, float]                      # seat -> chips committed this hand
    action_history: List[Tuple[int, Action, float]]   # (seat, action, amount), in order
    is_terminal: bool = False                         # True once the hand is over

# Key fixes for the MCCFR training performance issues

class OptimizedMCCFRTrainer:
    """Monte Carlo CFR trainer for a simplified heads-up no-limit hold'em game.

    Each iteration deals a random hand. It then performs a full-width traversal
    of the heavily truncated betting tree for each player, sampling board cards
    whenever a new street is dealt. Several things keep the tree small:
    ``max_recursion_depth``, a cut-off once the action history exceeds 10
    actions, only two bet sizes in ``get_legal_actions``, and the simplified
    street advancement in ``_apply_action``.

    Attributes:
        abstraction_manager: Object exposing ``get_postflop_bucket(hand, board)``.
        initial_stack: Chips each player starts the hand with.
        regret_sum: info_state -> {(Action, amount): cumulative regret}.
        strategy_sum: info_state -> {(Action, amount): accumulated strategy}.
        policy: Not populated by this class; kept for backward compatibility.
    """

    # Single-letter action codes used in info-state keys. Check and call need
    # distinct letters; taking the first letter of ``Action.value`` maps both to 'c'.
    _ACTION_CODES = {'fold': 'f', 'check': 'k', 'call': 'c', 'bet': 'b', 'raise': 'r'}

    def __init__(self, abstraction_manager=None, initial_stack=1000.0):
        """Create a trainer.

        Args:
            abstraction_manager: Postflop bucketing object. When ``None``, a
                ``FixedAbstractionManager`` with 20 clusters per round is built
                and trained on 5000 samples per round.
            initial_stack: Chips each player starts the hand with.
        """
        print("Initializing Optimized MCCFR Trainer...")

        if abstraction_manager is None:
            self.abstraction_manager = FixedAbstractionManager(n_clusters_per_round=20)
            print("Training abstraction models...")
            self.abstraction_manager.train_all_postflop_models(n_samples_per_round=5000)
        else:
            self.abstraction_manager = abstraction_manager

        self.initial_stack = initial_stack

        # CFR tables: info_state -> {(Action, amount): value}
        self.regret_sum = defaultdict(lambda: defaultdict(float))
        self.strategy_sum = defaultdict(lambda: defaultdict(float))
        # Never written by this class; kept so external readers of ``policy`` still work.
        self.policy = defaultdict(lambda: defaultdict(float))

        self.iteration_count = 0
        self.strategy_updates = 0

        # Recursion guard used by _mccfr_recursive.
        self.max_recursion_depth = 15
        self.current_depth = 0

        # Unused by this class; kept for backward compatibility.
        self.action_sequence_cache = {}

        self._initialize_preflop_buckets()

        print("Optimized MCCFR Trainer initialized successfully.")

    def _initialize_preflop_buckets(self):
        """Build the coarse preflop bucket table.

        Premium holdings map to bucket 60 and strong holdings to 61. Every other
        hand gets the default bucket in ``get_preflop_bucket``. Unpaired entries
        carry no suited/offsuit suffix, so they match both variants.
        """
        self.preflop_buckets = {}

        premium_hands = ['AA', 'KK', 'QQ', 'AK']
        strong_hands = ['JJ', 'TT', 'AQ', 'AJ']

        for bucket_id, tier in enumerate((premium_hands, strong_hands), start=60):
            for hand in tier:
                self.preflop_buckets[hand] = bucket_id

    def get_preflop_bucket(self, hand: List[str]) -> int:
        """Return the preflop bucket for two hole cards such as ``['As', 'Kd']``.

        The hand is normalised to ``'AA'`` (pair), ``'AKs'`` (suited) or ``'AKo'``
        (offsuit). The suffixed key is looked up first. If that misses, the bare
        rank key (``'AK'``) is tried, so suffix-less table entries cover both
        suited and offsuit combos. Hands not in the table map to bucket 79.
        """
        ranks = sorted((card[0] for card in hand), key=RANKS.index, reverse=True)
        suits = [card[1] for card in hand]

        if ranks[0] == ranks[1]:
            return self.preflop_buckets.get(ranks[0] * 2, 79)

        rank_key = ''.join(ranks)
        suffix = 's' if suits[0] == suits[1] else 'o'
        return self.preflop_buckets.get(
            rank_key + suffix, self.preflop_buckets.get(rank_key, 79)
        )

    def get_info_state(self, state: GameState, player_id: int) -> str:
        """Return the abstracted information-state key for ``player_id``.

        Format: ``B<bucket>S<street initial>A<codes of last 3 actions>P<pot bucket>``,
        e.g. ``'B60SpAckP0'``. The pot is bucketed in steps of 10 chips, capped at 9.
        Only the last three actions are kept, to bound the number of distinct states.
        """
        hand = state.hands[player_id]

        if state.street == 'preflop':
            bucket_id = self.get_preflop_bucket(hand)
        else:
            bucket_id = self.abstraction_manager.get_postflop_bucket(hand, state.board)

        action_str = ''.join(
            self._ACTION_CODES.get(a.value, a.value[0])
            for _, a, _ in state.action_history[-3:]
        )
        pot_ratio = min(9, int(state.pot_size / 10))

        return f"B{bucket_id}S{state.street[0]}A{action_str}P{pot_ratio}"

    def get_legal_actions(self, state: GameState, player_id: int) -> List[Tuple[Action, float]]:
        """Return the abstracted legal actions for ``player_id`` as ``(Action, amount)``.

        The menu is built as follows:
        - Fold is offered only when facing a bet.
        - Check is offered when there is nothing to call; otherwise call is
          offered if ``to_call`` fits in the remaining stack.
        - A pot-sized bet/raise and an all-in bet/raise may be added. Each
          sizing needs at least half the pot, or twice ``to_call`` when facing
          a bet, and must fit the remaining stack.

        Returns ``[(Action.FOLD, 0)]`` if nothing else applies.
        """
        actions = []
        to_call = max(state.bets.values()) - state.bets.get(player_id, 0)
        stack_size = self.initial_stack - state.total_bets.get(player_id, 0)

        if to_call > 0:
            actions.append((Action.FOLD, 0))

        if to_call == 0:
            actions.append((Action.CHECK, 0))
        elif to_call <= stack_size:
            actions.append((Action.CALL, to_call))

        pot_bet = state.pot_size
        min_bet = max(to_call * 2, pot_bet * 0.5) if to_call > 0 else pot_bet * 0.5

        if pot_bet <= stack_size and pot_bet >= min_bet:
            actions.append((Action.BET if to_call == 0 else Action.RAISE, pot_bet))

        # All-in, only when it is strictly larger than the pot-sized option.
        if stack_size > pot_bet and stack_size >= min_bet:
            actions.append((Action.BET if to_call == 0 else Action.RAISE, stack_size))

        return actions if actions else [(Action.FOLD, 0)]

    def get_strategy(self, info_state: str, legal_actions: List[Tuple[Action, float]]) -> Dict[Tuple[Action, float], float]:
        """Regret matching: map each legal ``(Action, amount)`` to a probability.

        Probabilities are proportional to the positive cumulative regret of each
        action. The strategy is uniform when no action has positive regret, and
        ``{}`` when there are no legal actions. Regrets are keyed by the full
        ``(Action, amount)`` tuple, exactly as ``_mccfr_recursive`` stores them.
        Lookups never insert entries into ``regret_sum``.
        """
        if not legal_actions:
            return {}

        node_regrets = self.regret_sum.get(info_state, {})
        positive = [max(0.0, node_regrets.get(action_tuple, 0.0)) for action_tuple in legal_actions]
        total = sum(positive)

        if total > 0:
            return {a: r / total for a, r in zip(legal_actions, positive)}

        uniform_prob = 1.0 / len(legal_actions)
        return {a: uniform_prob for a in legal_actions}

    def train(self, num_iterations: int = 1000):
        """Run up to ``num_iterations`` MCCFR iterations.

        Training stops early once 300 seconds of wall-clock time have elapsed.
        The average strategy is accumulated every 50 iterations, starting with
        the first. An iteration that raises is reported and skipped.
        """
        print(f"\n=== Starting Optimized MCCFR Training for {num_iterations} iterations ===")
        start_time = time.time()

        for i in range(num_iterations):
            self.iteration_count += 1

            if (i + 1) % 100 == 0:
                elapsed = time.time() - start_time
                print(f"Iteration {i+1}/{num_iterations} | "
                      f"Time: {elapsed:.1f}s | "
                      f"States: {len(self.regret_sum)}")

            try:
                self.current_depth = 0
                self._run_mccfr_iteration()

                if time.time() - start_time > 300:  # 5 minute timeout
                    print(f"Training timeout after {i+1} iterations")
                    break

            except RecursionError:
                print(f"Recursion limit hit at iteration {i+1}")
                continue
            except Exception as e:
                print(f"Error in iteration {i+1}: {e}")
                continue

            if i % 50 == 0:
                self._update_average_strategy()

        total_time = time.time() - start_time
        print(f"Training complete! Total time: {total_time:.1f}s")
        # strategy_sum has exactly the info states that get_final_policy() returns.
        print(f"Final policy contains {len(self.strategy_sum)} information states")

    def _run_mccfr_iteration(self):
        """Deal one hand and run a CFR traversal from each player's perspective.

        A failing traversal is reported and skipped, so the other player's
        traversal and the training loop can continue.
        """
        initial_state = self._create_simple_game_state()

        for player_id in (0, 1):
            self.current_depth = 0
            try:
                self._mccfr_recursive(initial_state, player_id, 1.0, 1.0)
            except Exception as e:
                print(f"Traversal for player {player_id} failed: {e}")

    def _create_simple_game_state(self) -> GameState:
        """Deal a fresh preflop hand: seat 0 posts 1.0 and acts first, seat 1 posts 2.0."""
        deck = DECK.copy()
        random.shuffle(deck)

        hands = {0: deck[:2], 1: deck[2:4]}

        return GameState(
            players=[0, 1],
            current_player=0,
            pot_size=3.0,
            street='preflop',
            board=[],
            hands=hands,
            bets={0: 1.0, 1: 2.0},  # Small blind, big blind
            total_bets={0: 1.0, 1: 2.0},
            action_history=[]
        )

    def _mccfr_recursive(self, state: GameState, traversing_player: int,
                         pi_player: float, pi_opponent: float) -> float:
        """Return the expected utility of ``state`` for ``traversing_player``.

        Every legal action is expanded; board cards are sampled inside
        ``_apply_action``. At nodes owned by the traversing player, cumulative
        regrets are updated, weighted by the opponent reach ``pi_opponent``.
        The node is valued directly with ``_get_utility`` when any of these
        hold: it is terminal, ``max_recursion_depth`` is exceeded, or the action
        history exceeds 10 actions.

        Args:
            state: Node to evaluate.
            traversing_player: Player whose utility and regrets are computed.
            pi_player: Reach probability contributed by the traversing player.
            pi_opponent: Reach probability contributed by the opponent.
        """
        self.current_depth += 1
        try:
            if (self.current_depth > self.max_recursion_depth
                    or state.is_terminal
                    or len(state.action_history) > 10):
                return self._get_utility(state, traversing_player)

            current_player = state.current_player
            info_state = self.get_info_state(state, current_player)
            legal_actions = self.get_legal_actions(state, current_player)

            if not legal_actions:
                return self._get_utility(state, traversing_player)

            strategy = self.get_strategy(info_state, legal_actions)

            action_utilities = {}
            for action_tuple in legal_actions:
                action, amount = action_tuple
                try:
                    new_state = self._apply_action(state, action, amount)
                    if current_player == traversing_player:
                        reach = (pi_player * strategy[action_tuple], pi_opponent)
                    else:
                        reach = (pi_player, pi_opponent * strategy[action_tuple])
                    action_utilities[action_tuple] = self._mccfr_recursive(
                        new_state, traversing_player, *reach
                    )
                except Exception:
                    # Best effort: a failing subtree contributes zero utility.
                    action_utilities[action_tuple] = 0.0

            node_utility = sum(strategy[a] * action_utilities[a] for a in legal_actions)

            if current_player == traversing_player:
                node_regrets = self.regret_sum[info_state]
                for a in legal_actions:
                    node_regrets[a] += pi_opponent * (action_utilities[a] - node_utility)

            return node_utility
        finally:
            # Always unwind, even when a subtree raised, so sibling branches
            # are not cut off early by a stale depth counter.
            self.current_depth -= 1

    def _apply_action(self, state: GameState, action: Action, amount: float) -> GameState:
        """Return a new state with ``action`` applied by ``state.current_player``.

        ``state`` is not mutated: its containers are copied before being changed.

        NOTE(review): street advancement is approximate. ``action_history``
        spans the whole hand, so once it holds 4 or more actions, every
        non-folding action advances the street. The "last two actions" check
        can also pair an action with one from the previous street. A faithful
        betting-round rule would enlarge the traversed tree substantially;
        revisit together with the depth/history caps.
        """
        from copy import copy

        # Shallow copy, then duplicate every container that may be mutated.
        new_state = copy(state)
        new_state.bets = state.bets.copy()
        new_state.total_bets = state.total_bets.copy()
        new_state.players = state.players.copy()
        new_state.action_history = state.action_history.copy()
        new_state.hands = state.hands.copy()
        new_state.board = state.board.copy()

        player_id = state.current_player

        if action == Action.FOLD:
            new_state.is_terminal = True
            if player_id in new_state.players:
                new_state.players.remove(player_id)
        elif action == Action.CHECK:
            pass  # No bet change
        elif action == Action.CALL:
            call_amount = max(state.bets.values()) - state.bets.get(player_id, 0)
            new_state.bets[player_id] += call_amount
            new_state.total_bets[player_id] += call_amount
            new_state.pot_size += call_amount
        elif action in [Action.BET, Action.RAISE]:
            new_state.bets[player_id] += amount
            new_state.total_bets[player_id] += amount
            new_state.pot_size += amount

        new_state.action_history.append((player_id, action, amount))

        if not new_state.is_terminal:
            new_state.current_player = 1 - player_id

            if len(new_state.action_history) >= 4:
                new_state = self._try_advance_street(new_state)
            elif action in [Action.CHECK, Action.CALL] and len(new_state.action_history) >= 2:
                last_two = new_state.action_history[-2:]
                if all(act in [Action.CHECK, Action.CALL] for _, act, _ in last_two):
                    new_state = self._try_advance_street(new_state)

        return new_state

    def _try_advance_street(self, state: GameState) -> GameState:
        """Move ``state`` (mutated in place and returned) to the next street.

        Board cards are dealt from the unused part of ``DECK`` up to 3/4/5
        cards, street bets are reset, and seat 0 is to act. After the river,
        the state is marked terminal instead.

        NOTE(review): seat 0 posted the small blind and acts first on every
        postflop street here; confirm this is the intended simplification.
        """
        street_map = {
            'preflop': ('flop', 3),
            'flop': ('turn', 4),
            'turn': ('river', 5),
            'river': ('terminal', 5)
        }

        if state.street in street_map:
            next_street, target_board_size = street_map[state.street]

            if next_street == 'terminal':
                state.is_terminal = True
            else:
                cards_needed = target_board_size - len(state.board)
                if cards_needed > 0:
                    used_cards = set(state.board)
                    for hand in state.hands.values():
                        used_cards.update(hand)

                    available_cards = [c for c in DECK if c not in used_cards]
                    random.shuffle(available_cards)
                    state.board.extend(available_cards[:cards_needed])

                state.street = next_street
                state.bets = {p: 0.0 for p in state.players}
                state.current_player = 0

        return state

    def _get_utility(self, state: GameState, player_id: int) -> float:
        """Return ``player_id``'s net chip result at ``state`` (winnings minus chips invested).

        Cases:
        - Player folded: ``-invested``.
        - Player is the only one left: ``pot - invested``.
        - Board has at least 3 cards: showdown via ``evaluate_cards``. This code
          treats a lower rank as the stronger hand; ties split the pot.
        - Otherwise (preflop cut-off, or the evaluation failed): a fair coin
          flip between winning the pot and losing the investment. These payoffs
          match the fold and showdown cases.
        """
        invested = state.total_bets.get(player_id, 0)

        if player_id not in state.players:
            return -invested

        if len(state.players) == 1:
            return state.pot_size - invested

        if len(state.board) >= 3:
            try:
                player_hand = state.hands[player_id] + state.board
                if len(player_hand) >= 5:
                    player_rank = evaluate_cards(*player_hand)

                    best_opponent_rank = float('inf')
                    for opp_id in state.players:
                        if opp_id == player_id:
                            continue
                        opp_hand = state.hands[opp_id] + state.board
                        if len(opp_hand) >= 5:
                            best_opponent_rank = min(best_opponent_rank, evaluate_cards(*opp_hand))

                    # With no evaluable opponent, best_opponent_rank stays inf: a win.
                    if player_rank < best_opponent_rank:
                        return state.pot_size - invested
                    if player_rank == best_opponent_rank:
                        return state.pot_size / len(state.players) - invested
                    return -invested
            except Exception:
                pass  # Best effort: evaluation problems fall back to the coin flip below.

        if random.choice([1, -1]) == 1:
            return state.pot_size - invested
        return -invested

    def _update_average_strategy(self):
        """Add the current regret-matching strategy to ``strategy_sum``.

        Covers every info state with recorded regrets. Each call adds one
        unweighted snapshot and bumps ``strategy_updates``. Info states with no
        recorded actions are skipped.
        """
        for info_state, regrets in self.regret_sum.items():
            actions = list(regrets.keys())
            if not actions:
                continue
            for action, prob in self.get_strategy(info_state, actions).items():
                self.strategy_sum[info_state][action] += prob

        self.strategy_updates += 1

    def get_final_policy(self) -> Dict[str, Dict]:
        """Return the normalised average strategy.

        The result maps info_state -> {str(action): probability}. An info state
        whose accumulated strategy sums to zero gets a uniform distribution.
        """
        final_policy = {}

        for info_state in self.strategy_sum:
            total_sum = sum(self.strategy_sum[info_state].values())
            if total_sum > 0:
                final_policy[info_state] = {
                    str(action): prob / total_sum
                    for action, prob in self.strategy_sum[info_state].items()
                }
            else:
                actions = list(self.strategy_sum[info_state].keys())
                uniform_prob = 1.0 / len(actions) if actions else 1.0
                final_policy[info_state] = {
                    str(action): uniform_prob for action in actions
                }

        return final_policy

    def save_policy(self, filepath: str):
        """Pickle the final policy and training counters to ``filepath``."""
        policy_data = {
            'final_policy': self.get_final_policy(),
            'iteration_count': self.iteration_count,
            'regret_states': len(self.regret_sum)
        }

        with open(filepath, 'wb') as f:
            pickle.dump(policy_data, f)
        print(f"Policy saved to {filepath}")


# USAGE EXAMPLE - Replace your test section with this:
# Smoke test: train a small postflop abstraction, spot-check two buckets,
# then run a short MCCFR training and pickle the resulting policy.
# Variables bound here (manager, trainer, final_policy, ...) are module globals
# and stay available to later notebook cells.
if __name__ == "__main__":
    print("=== Testing OPTIMIZED MCCFR System ===")

    try:
        # Stage 1: build and train a small abstraction (10 clusters per round).
        print("Testing abstraction system...")
        manager = FixedAbstractionManager(n_clusters_per_round=10)
        manager.train_all_postflop_models(n_samples_per_round=1000)

        # Spot-check buckets for a flop and a turn situation: (hole cards, board).
        test_hands = [
            (['As', 'Kh'], ['Qc', '7d', '2s']),
            (['9h', '8c'], ['7d', '6s', '2h', 'Tc']),
        ]

        for hand, board in test_hands:
            bucket = manager.get_postflop_bucket(hand, board)
            print(f"Hand {hand} + Board {board} -> Bucket {bucket}")

        print("Abstraction system working! Now testing OPTIMIZED MCCFR...")

        # Stage 2: short training run reusing the abstraction built above.
        trainer = OptimizedMCCFRTrainer(manager)
        trainer.train(num_iterations=500)  # Start small

        # Stage 3: summarise the learned average strategy.
        final_policy = trainer.get_final_policy()
        print(f"\nTraining Results:")
        print(f"- Information states learned: {len(final_policy)}")
        print(f"- Total iterations: {trainer.iteration_count}")
        print(f"- Strategy updates: {trainer.strategy_updates}")

        # Show the action distribution of one (arbitrary, first-inserted) info state.
        if final_policy:
            example_info_state = list(final_policy.keys())[0]
            print(f"\nExample strategy for info state '{example_info_state}':")
            for action, prob in final_policy[example_info_state].items():
                print(f"  {action}: {prob:.3f}")

        trainer.save_policy("optimized_cfr_policy.pkl")

        print("SUCCESS: Optimized system is working!")

    except Exception as e:
        # Top-level boundary of the demo: report with a full traceback.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

# Actual work


In [None]:
@dataclass
class GameState:
    """
    Complete, mutable state of a heads-up poker hand.
    """
    players: List[int]  # seats still in the hand
    current_player: int  # seat to act
    pot_size: float  # chips in the pot
    street: str  # 'preflop', 'flop', 'turn', 'river'
    board: List[str]  # community cards dealt so far
    hands: Dict[int, List[str]]  # seat -> hole cards
    bets: Dict[int, float]  # chips committed in the current betting round
    total_bets: Dict[int, float]  # chips committed across all rounds
    action_history: List[Tuple[int, Action, float]] = field(default_factory=list)
    is_terminal: bool = False

    def copy(self) -> 'GameState':
        """Return an independent copy of this state.

        Every list and dict (including each player's hole-card list) is
        duplicated, so mutating the copy's containers never affects the
        original. The tuples inside ``action_history`` are shared.
        """
        hole_cards = {seat: cards.copy() for seat, cards in self.hands.items()}
        return GameState(
            is_terminal=self.is_terminal,
            action_history=self.action_history.copy(),
            total_bets=self.total_bets.copy(),
            bets=self.bets.copy(),
            hands=hole_cards,
            board=self.board.copy(),
            street=self.street,
            pot_size=self.pot_size,
            current_player=self.current_player,
            players=self.players.copy(),
        )

In [None]:
#++++++++++++++++++++++++++++++++++

# Key fixes for the MCCFR training performance issues

class OptimizedMCCFRTrainer:
    """Fixed version of your MCCFR trainer with performance optimizations."""

    def __init__(self, abstraction_manager=None, initial_stack=1000.0):
        print("Initializing Optimized MCCFR Trainer...")

        if abstraction_manager is None:
            self.abstraction_manager = FixedAbstractionManager(n_clusters_per_round=20)
            print("Training abstraction models...")
            self.abstraction_manager.train_all_postflop_models(n_samples_per_round=5000)  # Reduced
        else:
            self.abstraction_manager = abstraction_manager

        self.initial_stack = initial_stack

        # CFR data structures
        self.regret_sum = defaultdict(lambda: defaultdict(float))
        self.strategy_sum = defaultdict(lambda: defaultdict(float))
        self.policy = defaultdict(lambda: defaultdict(float))

        self.iteration_count = 0
        self.strategy_updates = 0

        # CRITICAL FIX 1: Add recursion depth tracking
        self.max_recursion_depth = 15
        self.current_depth = 0

        # CRITICAL FIX 2: Add action sequence tracking to prevent infinite loops
        self.action_sequence_cache = {}

        self._initialize_preflop_buckets()

        print("Optimized MCCFR Trainer initialized successfully.")

    def _initialize_preflop_buckets(self):
        """Simplified preflop bucketing for testing."""
        self.preflop_buckets = {}

        # Simplified bucketing - just a few categories
        premium_hands = ['AA', 'KK', 'QQ', 'AK']
        strong_hands = ['JJ', 'TT', 'AQ', 'AJ']

        bucket_id = 60
        for hand in premium_hands:
            self.preflop_buckets[hand] = bucket_id
        bucket_id += 1

        for hand in strong_hands:
            self.preflop_buckets[hand] = bucket_id
        bucket_id += 1

    def get_preflop_bucket(self, hand: List[str]) -> int:
        """Return the preflop bucket for two hole cards such as ``['As', 'Kd']``.

        The hand is normalised to ``'AA'`` (pair), ``'AKs'`` (suited) or ``'AKo'``
        (offsuit). The suffixed key is looked up first. If that misses, the bare
        rank key (``'AK'``) is tried, so table entries stored without a suffix
        match both suited and offsuit combos. Hands not in the table map to
        bucket 79.
        """
        ranks = sorted((card[0] for card in hand), key=RANKS.index, reverse=True)
        suits = [card[1] for card in hand]

        if ranks[0] == ranks[1]:
            return self.preflop_buckets.get(ranks[0] * 2, 79)

        rank_key = ''.join(ranks)
        suffix = 's' if suits[0] == suits[1] else 'o'
        return self.preflop_buckets.get(
            rank_key + suffix, self.preflop_buckets.get(rank_key, 79)
        )

    def get_info_state(self, state: GameState, player_id: int) -> str:
        """CRITICAL FIX 3: Simplified information state to prevent explosion."""
        hand = state.hands[player_id]
        board = state.board

        # Get bucket ID
        if state.street == 'preflop':
            bucket_id = self.get_preflop_bucket(hand)
        else:
            bucket_id = self.abstraction_manager.get_postflop_bucket(hand, board)

        # SIMPLIFIED: Only include essential info to prevent state explosion
        recent_actions = state.action_history[-3:] if len(state.action_history) > 3 else state.action_history
        action_str = ''.join([f"{a.value[0]}" for _, a, _ in recent_actions])

        pot_ratio = min(9, int(state.pot_size / 10))  # Discretize pot size

        info_state = f"B{bucket_id}S{state.street[0]}A{action_str}P{pot_ratio}"
        return info_state

    def get_legal_actions(self, state: GameState, player_id: int) -> List[Tuple[Action, float]]:
        """Return the abstracted action menu for ``player_id`` as ``(Action, amount)`` pairs.

        Order of the menu:
        1. Fold, only when facing a bet.
        2. Check when there is nothing to call; otherwise call, if the amount
           to call fits in the remaining stack.
        3. A pot-sized bet/raise.
        4. An all-in bet/raise.

        Each sizing needs at least half the pot, or twice the amount to call
        when facing a bet. The pot-sized option must fit the stack, and the
        all-in option must exceed the pot-sized one. Returns
        ``[(Action.FOLD, 0)]`` if the menu would otherwise be empty.
        """
        to_call = max(state.bets.values()) - state.bets.get(player_id, 0)
        stack_size = self.initial_stack - state.total_bets.get(player_id, 0)
        facing_bet = to_call > 0
        unopened = to_call == 0
        aggressive = Action.BET if unopened else Action.RAISE

        menu = []
        if facing_bet:
            menu.append((Action.FOLD, 0))
        if unopened:
            menu.append((Action.CHECK, 0))
        elif to_call <= stack_size:
            menu.append((Action.CALL, to_call))

        pot_bet = state.pot_size
        half_pot = pot_bet * 0.5
        min_bet = max(to_call * 2, half_pot) if facing_bet else half_pot

        sizings = (
            (pot_bet, min_bet <= pot_bet <= stack_size),
            (stack_size, stack_size > pot_bet and stack_size >= min_bet),
        )
        for size, allowed in sizings:
            if allowed:
                menu.append((aggressive, size))

        return menu or [(Action.FOLD, 0)]

    def get_strategy(self, info_state: str, legal_actions: List[Tuple[Action, float]]) -> Dict[Tuple[Action, float], float]:
        """FIXED: Consistent action tuple handling."""
        strategy = {}

        # Calculate regret sum for normalization
        regret_sum = 0.0
        for action_tuple in legal_actions:
            regret_value = self.regret_sum[info_state].get(action_tuple, 0.0)
            regret_sum += max(0, regret_value)

        # Generate strategy based on regret matching
        if regret_sum > 0:
            for action_tuple in legal_actions:
                regret_value = self.regret_sum[info_state].get(action_tuple, 0.0)
                strategy[action_tuple] = max(0, regret_value) / regret_sum
        else:
            # Uniform strategy if no positive regrets
            uniform_prob = 1.0 / len(legal_actions)
            for action_tuple in legal_actions:
                strategy[action_tuple] = uniform_prob

        return strategy

    def train(self, num_iterations: int = 1000):  # Reduced default
        """CRITICAL FIX 5: Better training loop with timeout protection."""
        print(f"\n=== Starting Optimized MCCFR Training for {num_iterations} iterations ===")
        start_time = time.time()

        for i in range(num_iterations):
            self.iteration_count += 1

            # Progress reporting
            if (i + 1) % 100 == 0:
                elapsed = time.time() - start_time
                print(f"Iteration {i+1}/{num_iterations} | "
                      f"Time: {elapsed:.1f}s | "
                      f"States: {len(self.regret_sum)}")

            try:
                # CRITICAL: Reset depth counter for each iteration
                self.current_depth = 0
                self._run_mccfr_iteration()

                # Timeout protection
                if time.time() - start_time > 300:  # 5 minute timeout
                    print(f"Training timeout after {i+1} iterations")
                    break

            except RecursionError:
                print(f"Recursion limit hit at iteration {i+1}")
                continue
            except Exception as e:
                print(f"Error in iteration {i+1}: {e}")
                continue

            if i % 50 == 0:  # More frequent strategy updates
                self._update_average_strategy()

        total_time = time.time() - start_time
        print(f"Training complete! Total time: {total_time:.1f}s")
        print(f"Final policy contains {len(self.policy)} information states")

    def _run_mccfr_iteration(self):
        """CRITICAL FIX 6: Simplified iteration with better initial state."""
        initial_state = self._create_simple_game_state()

        # Run MCCFR for both players
        for player_id in [0, 1]:
            self.current_depth = 0  # Reset depth
            try:
                self._mccfr_recursive(initial_state, player_id, 1.0, 1.0)
            except Exception as e:
                # Gracefully handle errors and continue
                continue

    def _create_simple_game_state(self) -> GameState:
        """Deal a fresh heads-up preflop hand with the blinds already posted.

        Seat 0 posts 1.0 and is first to act; seat 1 posts 2.0. Seat 0 gets the
        first two cards of a shuffled copy of ``DECK`` and seat 1 the next two.
        """
        shuffled = DECK.copy()
        random.shuffle(shuffled)

        blinds = {0: 1.0, 1: 2.0}  # small blind, big blind
        dealt = {seat: shuffled[2 * seat:2 * seat + 2] for seat in (0, 1)}

        return GameState(
            players=[0, 1],
            current_player=0,
            pot_size=3.0,
            street='preflop',
            board=[],
            hands=dealt,
            bets=dict(blinds),
            total_bets=dict(blinds),
            action_history=[],
        )

    def _mccfr_recursive(self, state: GameState, traversing_player: int,
                         pi_player: float, pi_opponent: float) -> float:
        """CRITICAL FIX 8: Protected recursive function with depth limits."""

        # CRITICAL: Depth protection
        self.current_depth += 1
        if self.current_depth > self.max_recursion_depth:
            self.current_depth -= 1
            return self._get_utility(state, traversing_player)

        # Terminal state check
        if state.is_terminal:
            self.current_depth -= 1
            return self._get_utility(state, traversing_player)

        # Action sequence protection against infinite loops
        action_sequence = tuple((p, a.value, amt) for p, a, amt in state.action_history[-5:])
        if len(state.action_history) > 10:  # Reduced from 20
            self.current_depth -= 1
            return self._get_utility(state, traversing_player)

        current_player = state.current_player
        info_state = self.get_info_state(state, current_player)
        legal_actions = self.get_legal_actions(state, current_player)

        if not legal_actions:
            self.current_depth -= 1
            return self._get_utility(state, traversing_player)

        strategy = self.get_strategy(info_state, legal_actions)

        # Calculate utilities for each action
        action_utilities = {}
        for action_tuple in legal_actions:
            try:
                new_state = self._apply_action(state, action_tuple[0], action_tuple[1])

                if current_player == traversing_player:
                    action_utilities[action_tuple] = self._mccfr_recursive(
                        new_state, traversing_player,
                        pi_player * strategy[action_tuple], pi_opponent
                    )
                else:
                    action_utilities[action_tuple] = self._mccfr_recursive(
                        new_state, traversing_player, pi_player,
                        pi_opponent * strategy[action_tuple]
                    )
            except Exception:
                # Fallback utility if recursion fails
                action_utilities[action_tuple] = 0.0

        # Node utility calculation
        node_utility = sum(strategy[action_tuple] * action_utilities[action_tuple]
                          for action_tuple in legal_actions)

        # Update regrets
        if current_player == traversing_player:
            for action_tuple in legal_actions:
                regret = action_utilities[action_tuple] - node_utility
                self.regret_sum[info_state][action_tuple] += pi_opponent * regret

        self.current_depth -= 1
        return node_utility

    def _apply_action(self, state: GameState, action: Action, amount: float) -> GameState:
        """Return a new state with ``action`` applied by ``state.current_player``.

        ``state`` itself is not mutated: its containers are copied before they
        are changed. A fold marks the hand terminal and removes the folder from
        ``players``. A call matches the largest street bet. A bet or raise adds
        ``amount`` to the player's street and hand totals and to the pot.

        NOTE(review): street advancement is approximate. ``action_history``
        spans the whole hand, so once it holds 4 or more actions, every
        non-folding action advances the street. The "last two actions" check
        can also pair an action with one from the previous street. A faithful
        betting-round rule would enlarge the traversed tree substantially;
        revisit together with the depth/history caps in ``_mccfr_recursive``.
        """
        from copy import copy

        # Shallow copy for performance
        # (every container that may be mutated below is duplicated explicitly)
        new_state = copy(state)
        new_state.bets = state.bets.copy()
        new_state.total_bets = state.total_bets.copy()
        new_state.players = state.players.copy()
        new_state.action_history = state.action_history.copy()
        new_state.hands = state.hands.copy()
        new_state.board = state.board.copy()

        player_id = state.current_player

        # Apply the action
        if action == Action.FOLD:
            new_state.is_terminal = True
            if player_id in new_state.players:
                new_state.players.remove(player_id)
        elif action == Action.CHECK:
            pass  # No bet change
        elif action == Action.CALL:
            # Recomputed from the bets rather than taken from ``amount``.
            call_amount = max(state.bets.values()) - state.bets.get(player_id, 0)
            new_state.bets[player_id] += call_amount
            new_state.total_bets[player_id] += call_amount
            new_state.pot_size += call_amount
        elif action in [Action.BET, Action.RAISE]:
            new_state.bets[player_id] += amount
            new_state.total_bets[player_id] += amount
            new_state.pot_size += amount

        # Add to history
        new_state.action_history.append((player_id, action, amount))

        # Simplified game advancement (heads-up only: the turn passes to 1 - player_id)
        if not new_state.is_terminal:
            new_state.current_player = 1 - player_id

            # Counts actions over the whole hand, not just the current street
            if len(new_state.action_history) >= 4:  # Both players acted twice
                new_state = self._try_advance_street(new_state)
            elif action in [Action.CHECK, Action.CALL] and len(new_state.action_history) >= 2:
                # Check if both players checked or one called
                last_two = new_state.action_history[-2:]
                if all(act in [Action.CHECK, Action.CALL] for _, act, _ in last_two):
                    new_state = self._try_advance_street(new_state)

        return new_state

    def _try_advance_street(self, state: GameState) -> GameState:
        """CRITICAL FIX 11: Simplified street advancement."""
        street_map = {
            'preflop': ('flop', 3),
            'flop': ('turn', 4),
            'turn': ('river', 5),
            'river': ('terminal', 5)
        }

        if state.street in street_map:
            next_street, target_board_size = street_map[state.street]

            if next_street == 'terminal':
                state.is_terminal = True
            else:
                # Deal cards to reach target board size
                cards_needed = target_board_size - len(state.board)
                if cards_needed > 0:
                    # Get available cards
                    used_cards = set()
                    for hand in state.hands.values():
                        used_cards.update(hand)
                    used_cards.update(state.board)

                    available_cards = [c for c in DECK if c not in used_cards]
                    random.shuffle(available_cards)

                    # Deal new cards
                    new_cards = available_cards[:cards_needed]
                    state.board.extend(new_cards)

                state.street = next_street
                state.bets = {p: 0.0 for p in state.players}
                state.current_player = 0

        return state

    def _get_utility(self, state: GameState, player_id: int) -> float:
        """CRITICAL FIX 12: Simplified utility calculation."""
        if player_id not in state.players:
            return -state.total_bets.get(player_id, 0)

        if len(state.players) == 1:
            return state.pot_size - state.total_bets.get(player_id, 0)

        # Simplified showdown
        if len(state.board) >= 3:  # Any postflop situation
            try:
                # Use your hand evaluator if available
                player_hand = state.hands[player_id] + state.board
                if len(player_hand) >= 5:
                    player_rank = evaluate_cards(*player_hand)

                    best_opponent_rank = float('inf')
                    for opp_id in state.players:
                        if opp_id != player_id:
                            opp_hand = state.hands[opp_id] + state.board
                            if len(opp_hand) >= 5:
                                opp_rank = evaluate_cards(*opp_hand)
                                best_opponent_rank = min(best_opponent_rank, opp_rank)

                    if best_opponent_rank == float('inf'):
                        # No valid opponent hands
                        return state.pot_size - state.total_bets.get(player_id, 0)

                    if player_rank < best_opponent_rank:
                        return state.pot_size - state.total_bets.get(player_id, 0)
                    elif player_rank == best_opponent_rank:
                        return (state.pot_size / len(state.players)) - state.total_bets.get(player_id, 0)
                    else:
                        return -state.total_bets.get(player_id, 0)
            except:
                pass

        # Fallback: random outcome weighted by pot investment
        random_outcome = random.choice([1, -1])
        return random_outcome * (state.pot_size / 2) - state.total_bets.get(player_id, 0)

    def _update_average_strategy(self):
        """FIXED: Proper handling of action tuples."""
        for info_state in self.regret_sum:
            # The keys in regret_sum are (Action, amount) tuples
            action_tuples = list(self.regret_sum[info_state].keys())

            # Convert to the format expected by get_strategy
            legal_actions = [(action, amount) for action, amount in action_tuples]

            if legal_actions:
                strategy = self.get_strategy(info_state, legal_actions)

                for action_tuple in strategy:
                    self.strategy_sum[info_state][action_tuple] += strategy[action_tuple]

        self.strategy_updates += 1

    def get_final_policy(self) -> Dict[str, Dict]:
        """Normalize the accumulated strategy sums into an average policy.

        Returns:
            ``{info_state: {str(action_tuple): probability}}``. Each info
            state's weights are divided by their total. When the total is
            not positive, every recorded action gets an equal share.
        """
        policy = {}
        for info_state, sums in self.strategy_sum.items():
            total = sum(sums.values())
            if total > 0:
                policy[info_state] = {str(key): weight / total for key, weight in sums.items()}
                continue
            share = 1.0 / len(sums) if sums else 1.0
            policy[info_state] = {str(key): share for key in sums}
        return policy

    def save_policy(self, filepath: str):
        """Pickle the averaged policy plus basic training stats to ``filepath``.

        The file holds a dict with keys ``'final_policy'``,
        ``'iteration_count'`` and ``'regret_states'`` (the number of info
        states with regret entries). Any existing file is overwritten.
        """
        payload = dict(
            final_policy=self.get_final_policy(),
            iteration_count=self.iteration_count,
            regret_states=len(self.regret_sum),
        )
        with open(filepath, 'wb') as handle:
            pickle.dump(payload, handle)
        print(f"Policy saved to {filepath}")


# Demo: fit the k-means postflop abstraction, then run a short optimized MCCFR session.
if __name__ == "__main__":
    print("=== Testing OPTIMIZED MCCFR System ===")

    try:
        # Build and fit the clustering-based postflop abstraction.
        print("Testing abstraction system...")
        manager = FixedAbstractionManager(n_clusters_per_round=10)
        manager.train_all_postflop_models(n_samples_per_round=1000)

        # Spot-check bucket assignment on a flop board and a turn board.
        test_hands = [
            (['As', 'Kh'], ['Qc', '7d', '2s']),
            (['9h', '8c'], ['7d', '6s', '2h', 'Tc']),
        ]
        for hand, board in test_hands:
            bucket = manager.get_postflop_bucket(hand, board)
            print(f"Hand {hand} + Board {board} -> Bucket {bucket}")

        print("Abstraction system working! Now testing OPTIMIZED MCCFR...")

        trainer = OptimizedMCCFRTrainer(manager)
        trainer.train(num_iterations=500)  # keep the first run small

        final_policy = trainer.get_final_policy()
        print(
            "\nTraining Results:",
            f"- Information states learned: {len(final_policy)}",
            f"- Total iterations: {trainer.iteration_count}",
            f"- Strategy updates: {trainer.strategy_updates}",
            sep="\n",
        )

        if final_policy:
            # Show the first info state's averaged action distribution.
            example_info_state = next(iter(final_policy))
            print(f"\nExample strategy for info state '{example_info_state}':")
            for action, prob in final_policy[example_info_state].items():
                print(f"  {action}: {prob:.3f}")

        trainer.save_policy("optimized_cfr_policy.pkl")

        print("SUCCESS: Optimized system is working!")

    except Exception as e:
        # Top-level boundary of a demo script: report and show the traceback.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

=== Testing OPTIMIZED MCCFR System ===
Testing abstraction system...
Optimized PotentialAwareCalculator initialized.
Fixed AbstractionManager initialized.

=== STARTING K-MEANS MODEL TRAINING FOR ALL ROUNDS ===

--- Training flop model ---
--- Generating 1000 feature vectors for flop ---
  ...attempting 1000, collected 999/1000...
Generated 1000 valid vectors (0 failed) in 0.61s
Removed 5 duplicate vectors
Training K-Means on 995 vectors with 8 features
Using 10 clusters
Silhouette Score: 0.260
flop model trained successfully!

--- Training turn model ---
--- Generating 1000 feature vectors for turn ---
  ...attempting 1000, collected 999/1000...
Generated 1000 valid vectors (0 failed) in 0.61s
Removed 10 duplicate vectors
Training K-Means on 990 vectors with 8 features
Using 10 clusters
Silhouette Score: 0.255
turn model trained successfully!

--- Training river model ---
--- Generating 1000 feature vectors for river ---
  ...attempting 1000, collected 999/1000...
Generated 1000 valid

In [None]:
from __future__ import annotations
import itertools
import random
import time
import pickle
from typing import List, Dict, Tuple, Set
from collections import defaultdict
from enum import Enum
from dataclasses import dataclass, field
from copy import deepcopy

# Suit codes: hearts, diamonds, clubs, spades.
SUITS = list('hdcs')
# Rank codes from lowest (deuce) to highest (ace).
RANKS = list('23456789TJQKA')

# All 52 cards as two-character strings, ordered by rank first, then suit.
DECK = [rank + suit for rank in RANKS for suit in SUITS]

class Action(Enum):
    """Closed set of betting decisions; each member's value is its lowercase name."""

    FOLD = "fold"    # give up the hand
    CHECK = "check"  # pass without adding chips (offered when nothing is owed)
    CALL = "call"    # match the outstanding bet
    BET = "bet"      # put chips in when nothing is owed
    RAISE = "raise"  # put chips in while facing a bet

@dataclass
class GameState:
    """Snapshot of a poker hand at one decision point.

    All containers are plain lists/dicts so that ``copy`` can cheaply produce
    an independent snapshot before an action is applied.
    """
    players: List[int]  # ids of players still in the hand
    current_player: int  # id of the player to act
    pot_size: float
    street: str  # 'preflop', 'flop', 'turn', 'river'
    board: List[str]  # community cards
    hands: Dict[int, List[str]]  # hole cards keyed by player id
    bets: Dict[int, float]  # chips put in on the current street
    total_bets: Dict[int, float]  # chips put in over the whole hand
    action_history: List[Tuple[int, Action, float]] = field(default_factory=list)
    is_terminal: bool = False

    def copy(self) -> 'GameState':
        """Return an independent copy.

        Every list and dict (including each player's hole-card list) is
        duplicated, so mutating the copy never affects ``self``.
        """
        hole_cards = {pid: cards.copy() for pid, cards in self.hands.items()}
        return GameState(
            players=list(self.players),
            current_player=self.current_player,
            pot_size=self.pot_size,
            street=self.street,
            board=list(self.board),
            hands=hole_cards,
            bets=self.bets.copy(),
            total_bets=self.total_bets.copy(),
            action_history=list(self.action_history),
            is_terminal=self.is_terminal,
        )

class Card:
    """A playing card made of a rank character and a suit character.

    Both values are validated against the module-level ``RANKS`` and
    ``SUITS`` lists when the card is built.
    """
    __slots__ = ('rank', 'suit')

    def __init__(self, rank: str, suit: str) -> None:
        rank_ok = rank in RANKS
        if not rank_ok:
            raise ValueError(f"Invalid rank: {rank}. Must be one of {RANKS}")
        suit_ok = suit in SUITS
        if not suit_ok:
            raise ValueError(f"Invalid suit: {suit}. Must be one of {SUITS}")
        self.rank = rank
        self.suit = suit

    def __repr__(self) -> str:
        # Compact form such as 'As' or 'Td'.
        return self.rank + self.suit

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Card):
            return (self.rank, self.suit) == (other.rank, other.suit)
        return NotImplemented

    def __hash__(self) -> int:
        # Must agree with __eq__: equal cards hash equally.
        return hash((self.rank, self.suit))

    def get_numeric_rank(self) -> int:
        """Numeric rank from 2 (deuce) to 14 (ace)."""
        return 2 + RANKS.index(self.rank)

def create_deck() -> List[Card]:
    """Build a fresh 52-card deck ordered by rank (low to high), then by suit."""
    return [Card(rank, suit) for rank, suit in itertools.product(RANKS, SUITS)]

def get_hand_bucket(hand: List[Card]) -> str:
    """Name the preflop class of a two-card hand.

    The higher rank is written first. Returns ``'XX'`` for a pocket pair,
    ``'XYs'`` for suited cards and ``'XYo'`` for offsuit cards.

    Raises:
        ValueError: If ``hand`` does not hold exactly two cards.
    """
    if len(hand) != 2:
        raise ValueError("Hand must consist of exactly two cards.")

    high, low = sorted(hand, key=lambda c: RANKS.index(c.rank), reverse=True)
    label = high.rank + low.rank
    if high.rank == low.rank:
        return label
    return label + ('s' if high.suit == low.suit else 'o')

# Placeholder evaluator; swap in a real hand evaluator for meaningful results.
def evaluate_hand_strength(hand: List[str], board: List[str]) -> float:
    """Crude strength score in ``[0, 1]`` (higher is better).

    Only card ranks (2..14) are considered. The score is
    ``(highest rank + mean rank) / 28``, clamped to ``[0, 1]``. Pairs,
    straights and flushes are ignored. With fewer than five known cards the
    neutral value 0.5 is returned.
    """
    cards = hand + board
    if len(cards) < 5:
        return 0.5  # not enough information yet

    rank_value = {rank: value for value, rank in enumerate('23456789TJQKA', start=2)}
    values = [rank_value[card[0]] for card in cards]

    top = max(values)
    mean = sum(values) / len(values)
    score = (top + mean) / 28.0  # rough normalization to 0..1
    return min(1.0, max(0.0, score))

# Stand-in abstraction manager so the trainer can run without trained models.
class MockAbstractionManager:
    """Buckets postflop hands by the placeholder ``evaluate_hand_strength`` score.

    Bucket ids range over ``0 .. n_clusters - 1``.
    """

    def __init__(self, n_clusters_per_round=10):
        self.n_clusters = n_clusters_per_round

    def get_postflop_bucket(self, hand: List[str], board: List[str]) -> int:
        """Scale the strength score onto the bucket range, truncating toward zero."""
        top_bucket = self.n_clusters - 1
        return int(evaluate_hand_strength(hand, board) * top_bucket)

    def train_all_postflop_models(self, n_samples_per_round=1000):
        """No-op "training"; only reports the requested sample count."""
        print(f"Mock abstraction manager trained with {n_samples_per_round} samples")

class OptimizedMCCFRTrainer:
    """Monte Carlo CFR trainer for a heavily simplified heads-up hold'em game.

    Regrets and strategy sums are stored per abstract information state (see
    ``get_info_state``) and keyed by ``(Action, amount)`` tuples. Recursion
    is bounded by ``max_recursion_depth`` and by the length of the action
    history, so each traversal stays small.
    """

    def __init__(self, abstraction_manager=None, initial_stack=1000.0):
        """Create a trainer.

        Args:
            abstraction_manager: Object exposing
                ``get_postflop_bucket(hand, board)``. When ``None``, a
                ``MockAbstractionManager`` with 20 clusters is created and its
                (mock) training method is called.
            initial_stack: Chips each player starts with; bounds bet sizes.
        """
        print("Initializing Optimized MCCFR Trainer...")

        if abstraction_manager is None:
            self.abstraction_manager = MockAbstractionManager(n_clusters_per_round=20)
            self.abstraction_manager.train_all_postflop_models(n_samples_per_round=1000)
        else:
            self.abstraction_manager = abstraction_manager

        self.initial_stack = initial_stack

        # CFR tables: info_state -> {(Action, amount): value}
        self.regret_sum = defaultdict(lambda: defaultdict(float))
        self.strategy_sum = defaultdict(lambda: defaultdict(float))
        # NOTE(review): never written by this class; kept in case callers read it.
        self.policy = defaultdict(lambda: defaultdict(float))

        self.iteration_count = 0
        self.strategy_updates = 0

        # Recursion depth tracking
        self.max_recursion_depth = 15
        self.current_depth = 0

        # NOTE(review): unused in this class; kept for backward compatibility.
        self.action_sequence_cache = {}

        self._initialize_preflop_buckets()

        print("Optimized MCCFR Trainer initialized successfully.")

    def _initialize_preflop_buckets(self):
        """Map a few strong starting-hand classes to bucket ids 60-62.

        Hands not listed here fall back to bucket 79 in ``get_preflop_bucket``.
        """
        self.preflop_buckets = {}
        tiers = [
            ['AA', 'KK', 'QQ', 'AKs', 'AKo'],          # bucket 60: premium
            ['JJ', 'TT', 'AQs', 'AQo', 'AJs', 'AJo'],  # bucket 61: strong
            ['99', '88', 'KQs', 'KQo', 'KJs', 'KJo'],  # bucket 62: medium
        ]
        for bucket_id, tier in enumerate(tiers, start=60):
            for hand in tier:
                self.preflop_buckets[hand] = bucket_id

    def get_preflop_bucket(self, hand: List[str]) -> int:
        """Bucket id for two hole cards such as ``['As', 'Kd']``; 79 if unlisted."""
        if len(hand) != 2:
            return 79  # Default bucket

        ranks = sorted([card[0] for card in hand], key=lambda x: RANKS.index(x), reverse=True)
        suits = [card[1] for card in hand]

        if ranks[0] == ranks[1]:
            hand_str = ranks[0] + ranks[0]
        else:
            suited = 's' if suits[0] == suits[1] else 'o'
            hand_str = ''.join(ranks) + suited

        return self.preflop_buckets.get(hand_str, 79)

    def get_info_state(self, state: GameState, player_id: int) -> str:
        """Abstract information-state key for ``player_id`` at ``state``.

        Format: ``B{bucket}S{street initial}A{initials of last <=3 actions}P{pot}``.
        ``pot`` is ``pot_size / 10`` truncated and capped at 9. The coarse key
        keeps the number of info states small.
        """
        hand = state.hands[player_id]
        board = state.board

        if state.street == 'preflop':
            bucket_id = self.get_preflop_bucket(hand)
        else:
            bucket_id = self.abstraction_manager.get_postflop_bucket(hand, board)

        recent_actions = state.action_history[-3:]
        action_str = ''.join([f"{a.value[0]}" for _, a, _ in recent_actions])

        pot_ratio = min(9, int(state.pot_size / 10))

        info_state = f"B{bucket_id}S{state.street[0]}A{action_str}P{pot_ratio}"
        return info_state

    def get_legal_actions(self, state: GameState, player_id: int) -> List[Tuple[Action, float]]:
        """Reduced action menu for ``player_id`` as ``(Action, chips_added)`` tuples.

        The menu is built as follows:
        - FOLD when facing a bet.
        - CHECK when nothing is owed; otherwise CALL when affordable.
        - A pot-sized BET/RAISE when affordable and at least the minimum.
        - An all-in BET/RAISE of the remaining stack when it exceeds the pot.

        Falls back to ``[(FOLD, 0)]`` if nothing else qualifies.
        """
        actions = []
        to_call = max(state.bets.values()) - state.bets.get(player_id, 0)
        stack_size = self.initial_stack - state.total_bets.get(player_id, 0)

        if to_call > 0:
            actions.append((Action.FOLD, 0))

        if to_call == 0:
            actions.append((Action.CHECK, 0))
        else:
            if to_call <= stack_size:
                actions.append((Action.CALL, to_call))

        pot_bet = state.pot_size
        min_bet = max(to_call * 2, pot_bet * 0.5) if to_call > 0 else pot_bet * 0.5

        if pot_bet <= stack_size and pot_bet >= min_bet:
            if to_call == 0:
                actions.append((Action.BET, pot_bet))
            else:
                actions.append((Action.RAISE, pot_bet))

        if stack_size > pot_bet and stack_size >= min_bet:
            if to_call == 0:
                actions.append((Action.BET, stack_size))
            else:
                actions.append((Action.RAISE, stack_size))

        return actions if actions else [(Action.FOLD, 0)]

    def get_strategy(self, info_state: str, legal_actions: List[Tuple[Action, float]]) -> Dict[Tuple[Action, float], float]:
        """Regret-matching strategy over ``legal_actions``.

        Probabilities are proportional to positive cumulative regret; the
        strategy is uniform when no action has positive regret.
        ``legal_actions`` must be non-empty.
        """
        strategy = {}

        regret_sum = 0.0
        for action_tuple in legal_actions:
            regret_value = self.regret_sum[info_state].get(action_tuple, 0.0)
            regret_sum += max(0, regret_value)

        if regret_sum > 0:
            for action_tuple in legal_actions:
                regret_value = self.regret_sum[info_state].get(action_tuple, 0.0)
                strategy[action_tuple] = max(0, regret_value) / regret_sum
        else:
            uniform_prob = 1.0 / len(legal_actions)
            for action_tuple in legal_actions:
                strategy[action_tuple] = uniform_prob

        return strategy

    def train(self, num_iterations: int = 1000, time_limit: float = 300.0):
        """Run up to ``num_iterations`` MCCFR iterations.

        Each iteration traverses the tree once per player. The average
        strategy is accumulated every 50 iterations, starting with the first.
        Training stops early once ``time_limit`` seconds have elapsed.

        Args:
            num_iterations: Maximum number of iterations to run.
            time_limit: Wall-clock budget in seconds. The default of 300 is
                the previously hard-coded five-minute limit.
        """
        print(f"\n=== Starting Optimized MCCFR Training for {num_iterations} iterations ===")
        start_time = time.time()

        for i in range(num_iterations):
            self.iteration_count += 1

            if (i + 1) % 100 == 0:
                elapsed = time.time() - start_time
                print(f"Iteration {i+1}/{num_iterations} | "
                      f"Time: {elapsed:.1f}s | "
                      f"States: {len(self.regret_sum)}")

            try:
                self.current_depth = 0
                self._run_mccfr_iteration()

                if time.time() - start_time > time_limit:
                    print(f"Training timeout after {i+1} iterations")
                    break

            except RecursionError:
                print(f"Recursion limit hit at iteration {i+1}")
                continue
            except Exception as e:
                print(f"Error in iteration {i+1}: {e}")
                continue

            if i % 50 == 0:
                self._update_average_strategy()

        total_time = time.time() - start_time
        print(f"Training complete! Total time: {total_time:.1f}s")
        # Report the states get_final_policy() will contain. The old code
        # printed len(self.policy), which is never filled and always showed 0.
        print(f"Final policy contains {len(self.strategy_sum)} information states")

    def _run_mccfr_iteration(self):
        """Deal a fresh hand and traverse it once for each player."""
        initial_state = self._create_simple_game_state()

        for player_id in [0, 1]:
            self.current_depth = 0
            try:
                self._mccfr_recursive(initial_state, player_id, 1.0, 1.0)
            except Exception:
                # Best-effort: a failed traversal for one player should not
                # prevent the other player's traversal.
                continue

    def _create_simple_game_state(self) -> GameState:
        """Random preflop state: blinds of 1/2 posted, player 0 to act."""
        deck = DECK.copy()
        random.shuffle(deck)

        hands = {0: deck[:2], 1: deck[2:4]}

        state = GameState(
            players=[0, 1],
            current_player=0,
            pot_size=3.0,
            street='preflop',
            board=[],
            hands=hands,
            bets={0: 1.0, 1: 2.0},  # Small blind, big blind
            total_bets={0: 1.0, 1: 2.0},
            action_history=[]
        )

        return state

    def _mccfr_recursive(self, state: GameState, traversing_player: int,
                         pi_player: float, pi_opponent: float) -> float:
        """Traverse the subtree at ``state`` and update the traverser's regrets.

        Args:
            state: Current node.
            traversing_player: Player whose regrets are updated on this pass.
            pi_player: Reach probability contributed by the traversing player.
            pi_opponent: Reach probability contributed by the opponent. It
                weights the regret update.

        Returns:
            Expected utility of ``state`` for ``traversing_player`` under the
            current strategy. The following nodes are scored directly with
            ``_get_utility``: nodes past the depth limit, terminal nodes, and
            nodes with more than 10 recorded actions.
        """
        self.current_depth += 1
        # try/finally keeps current_depth balanced even when a child raises.
        # Previously an exception leaked one level per failure, making sibling
        # subtrees hit the depth cap early.
        try:
            if (self.current_depth > self.max_recursion_depth
                    or state.is_terminal
                    or len(state.action_history) > 10):
                return self._get_utility(state, traversing_player)

            current_player = state.current_player
            info_state = self.get_info_state(state, current_player)
            legal_actions = self.get_legal_actions(state, current_player)

            if not legal_actions:
                return self._get_utility(state, traversing_player)

            strategy = self.get_strategy(info_state, legal_actions)

            action_utilities = {}
            for action_tuple in legal_actions:
                if current_player == traversing_player:
                    child_pi_player = pi_player * strategy[action_tuple]
                    child_pi_opponent = pi_opponent
                else:
                    child_pi_player = pi_player
                    child_pi_opponent = pi_opponent * strategy[action_tuple]
                try:
                    new_state = self._apply_action(state, action_tuple[0], action_tuple[1])
                    action_utilities[action_tuple] = self._mccfr_recursive(
                        new_state, traversing_player, child_pi_player, child_pi_opponent
                    )
                except Exception:
                    # Best-effort: a failing branch is scored as neutral (0.0)
                    # instead of aborting the whole traversal.
                    action_utilities[action_tuple] = 0.0

            node_utility = sum(strategy[action_tuple] * action_utilities[action_tuple]
                               for action_tuple in legal_actions)

            if current_player == traversing_player:
                for action_tuple in legal_actions:
                    regret = action_utilities[action_tuple] - node_utility
                    self.regret_sum[info_state][action_tuple] += pi_opponent * regret

            return node_utility
        finally:
            self.current_depth -= 1

    def _apply_action(self, state: GameState, action: Action, amount: float) -> GameState:
        """Return the successor of ``state`` after ``state.current_player`` acts.

        ``state`` is never mutated (``GameState.copy`` duplicates every
        container). A fold ends the hand. Otherwise the turn passes to the
        other player, and the street advances once the betting round is closed.

        Args:
            state: State before the action.
            action: Action taken by ``state.current_player``.
            amount: Chips added for BET/RAISE. For CALL the amount is
                recomputed from ``state.bets``, but the value passed here is
                still what gets recorded in ``action_history``.
        """
        new_state = state.copy()

        player_id = state.current_player

        if action == Action.FOLD:
            new_state.is_terminal = True
            if player_id in new_state.players:
                new_state.players.remove(player_id)
        elif action == Action.CHECK:
            pass  # No chips move.
        elif action == Action.CALL:
            call_amount = max(state.bets.values()) - state.bets.get(player_id, 0)
            new_state.bets[player_id] += call_amount
            new_state.total_bets[player_id] += call_amount
            new_state.pot_size += call_amount
        elif action in [Action.BET, Action.RAISE]:
            new_state.bets[player_id] += amount
            new_state.total_bets[player_id] += amount
            new_state.pot_size += amount

        new_state.action_history.append((player_id, action, amount))

        if not new_state.is_terminal:
            new_state.current_player = 1 - player_id

            history = new_state.action_history
            prev_action = history[-2][1] if len(history) >= 2 else None
            passive = (Action.CHECK, Action.CALL)
            round_closed = (
                # Crude cap kept from the original: once 4 actions are recorded
                # (counted over the whole hand), any matched round closes.
                # NOTE(review): history is not per-street, so later streets
                # close after a single matched action; proper tracking would
                # need a per-street marker on GameState.
                len(history) >= 4
                # Check-check, or a check after a call (limp, then option check).
                or (action in passive and prev_action in passive)
                # A call answering a bet/raise always closes the round.
                or (action == Action.CALL and prev_action in (Action.BET, Action.RAISE))
            )
            active_bets = [new_state.bets.get(p, 0.0) for p in new_state.players]
            # Never advance while a bet is unanswered. _try_advance_street
            # zeroes the street's bets, which would silently erase the
            # outstanding bet.
            bets_matched = not active_bets or max(active_bets) - min(active_bets) <= 1e-9
            if round_closed and bets_matched:
                new_state = self._try_advance_street(new_state)

        return new_state

    def _try_advance_street(self, state: GameState) -> GameState:
        """Move ``state`` to the next street, in place.

        From the river the hand becomes terminal. Otherwise:
        - The board is topped up with shuffled unused cards from ``DECK``.
        - Street bets are zeroed.
        - Player 0 is set to act.

        Returns:
            The same ``state`` object.
        """
        street_map = {
            'preflop': ('flop', 3),
            'flop': ('turn', 4),
            'turn': ('river', 5),
            'river': ('terminal', 5)
        }

        if state.street in street_map:
            next_street, target_board_size = street_map[state.street]

            if next_street == 'terminal':
                state.is_terminal = True
            else:
                cards_needed = target_board_size - len(state.board)
                if cards_needed > 0:
                    used_cards = set()
                    for hand in state.hands.values():
                        used_cards.update(hand)
                    used_cards.update(state.board)

                    available_cards = [c for c in DECK if c not in used_cards]
                    random.shuffle(available_cards)

                    new_cards = available_cards[:cards_needed]
                    state.board.extend(new_cards)

                state.street = next_street
                state.bets = {p: 0.0 for p in state.players}
                state.current_player = 0

        return state

    def _get_utility(self, state: GameState, player_id: int) -> float:
        """Net chips won by ``player_id`` if the hand were settled at ``state``.

        - A folded player loses everything they put in.
        - The last player standing wins the pot minus their own contribution.
        - With 3+ board cards, a placeholder showdown is scored with
          ``evaluate_hand_strength``. Ties split the pot evenly.
        - Otherwise (e.g. cut off preflop by a depth or length limit), the pot
          is split evenly. This is the expected value of a fair coin flip.
        """
        invested = state.total_bets.get(player_id, 0)
        if player_id not in state.players:
            return -invested

        if len(state.players) == 1:
            return state.pot_size - invested

        if len(state.board) >= 3:
            try:
                player_strength = evaluate_hand_strength(state.hands[player_id], state.board)

                best_opponent_strength = 0
                for opp_id in state.players:
                    if opp_id != player_id:
                        opp_strength = evaluate_hand_strength(state.hands[opp_id], state.board)
                        best_opponent_strength = max(best_opponent_strength, opp_strength)
            except (KeyError, IndexError, TypeError, ValueError):
                # Missing hole cards or malformed card strings: use the neutral
                # estimate below. This was a bare `except:`, which also caught
                # KeyboardInterrupt/SystemExit.
                pass
            else:
                if player_strength > best_opponent_strength:
                    return state.pot_size - invested
                if player_strength == best_opponent_strength:
                    return (state.pot_size / len(state.players)) - invested
                return -invested

        # The previous fallback, random.choice([1, -1]) * pot/2 - invested,
        # paid wrong amounts. A win gave pot/2 - invested instead of
        # pot - invested. A loss gave -pot/2 - invested instead of -invested.
        # Its EV was -invested and it was not zero-sum, so any chips added
        # before such a node looked like pure loss. Use the coin flip's
        # correct expectation instead: an even share of the pot.
        return state.pot_size / len(state.players) - invested

    def _update_average_strategy(self):
        """Add each info state's current regret-matching strategy to ``strategy_sum``."""
        for info_state in self.regret_sum:
            # The keys in regret_sum are (Action, amount) tuples
            action_tuples = list(self.regret_sum[info_state].keys())

            legal_actions = [(action, amount) for action, amount in action_tuples]

            if legal_actions:
                strategy = self.get_strategy(info_state, legal_actions)

                for action_tuple in strategy:
                    self.strategy_sum[info_state][action_tuple] += strategy[action_tuple]

        self.strategy_updates += 1

    def get_final_policy(self) -> Dict[str, Dict]:
        """Average policy: ``{info_state: {str(action_tuple): probability}}``.

        Strategy sums are normalized per info state. The split is uniform when
        the sum is not positive.
        """
        final_policy = {}

        for info_state in self.strategy_sum:
            total_sum = sum(self.strategy_sum[info_state].values())
            if total_sum > 0:
                final_policy[info_state] = {
                    str(action): prob / total_sum
                    for action, prob in self.strategy_sum[info_state].items()
                }
            else:
                actions = list(self.strategy_sum[info_state].keys())
                uniform_prob = 1.0 / len(actions) if actions else 1.0
                final_policy[info_state] = {
                    str(action): uniform_prob for action in actions
                }

        return final_policy

    def save_policy(self, filepath: str):
        """Pickle the final policy plus iteration/state counts to ``filepath``."""
        policy_data = {
            'final_policy': self.get_final_policy(),
            'iteration_count': self.iteration_count,
            'regret_states': len(self.regret_sum)
        }

        with open(filepath, 'wb') as f:
            pickle.dump(policy_data, f)
        print(f"Policy saved to {filepath}")

# Demo: mock abstraction plus a short optimized MCCFR session, no trained models needed.
if __name__ == "__main__":
    print("=== Testing Complete MCCFR System ===")

    try:
        print("Creating mock abstraction system...")
        manager = MockAbstractionManager(n_clusters_per_round=10)

        # Spot-check bucket assignment on a flop board and a turn board.
        test_hands = [
            (['As', 'Kh'], ['Qc', '7d', '2s']),
            (['9h', '8c'], ['7d', '6s', '2h', 'Tc']),
        ]
        for hand, board in test_hands:
            bucket = manager.get_postflop_bucket(hand, board)
            print(f"Hand {hand} + Board {board} -> Bucket {bucket}")

        print("Abstraction system working! Now testing MCCFR...")

        trainer = OptimizedMCCFRTrainer(manager)
        trainer.train(num_iterations=200)  # keep the first run small

        final_policy = trainer.get_final_policy()
        print(
            "\nTraining Results:",
            f"- Information states learned: {len(final_policy)}",
            f"- Total iterations: {trainer.iteration_count}",
            f"- Strategy updates: {trainer.strategy_updates}",
            sep="\n",
        )

        if final_policy:
            # Show the first info state's averaged action distribution.
            example_info_state = next(iter(final_policy))
            print(f"\nExample strategy for info state '{example_info_state}':")
            for action, prob in final_policy[example_info_state].items():
                print(f"  {action}: {prob:.3f}")

        trainer.save_policy("optimized_cfr_policy.pkl")

        print("SUCCESS: Complete system is working!")

    except Exception as e:
        # Top-level boundary of a demo script: report and show the traceback.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

=== Testing Complete MCCFR System ===
Creating mock abstraction system...
Hand ['As', 'Kh'] + Board ['Qc', '7d', '2s'] -> Bucket 7
Hand ['9h', '8c'] + Board ['7d', '6s', '2h', 'Tc'] -> Bucket 5
Abstraction system working! Now testing MCCFR...
Initializing Optimized MCCFR Trainer...
Optimized MCCFR Trainer initialized successfully.

=== Starting Optimized MCCFR Training for 200 iterations ===
Iteration 100/200 | Time: 1.9s | States: 673
Iteration 200/200 | Time: 3.9s | States: 705
Training complete! Total time: 3.9s
Final policy contains 0 information states

Training Results:
- Information states learned: 666
- Total iterations: 200
- Strategy updates: 4

Example strategy for info state 'B79SpAP0':
  (<Action.FOLD: 'fold'>, 0): 0.900
  (<Action.CALL: 'call'>, 1.0): 0.086
  (<Action.RAISE: 'raise'>, 3.0): 0.014
  (<Action.RAISE: 'raise'>, 999.0): 0.000
Policy saved to optimized_cfr_policy.pkl
SUCCESS: Complete system is working!


In [None]:
import numpy as np
import random
from collections import defaultdict
import time
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
from enum import Enum
from copy import copy

class UltraFastMCCFRTrainer:
    """Speed-oriented CFR trainer over a heavily simplified heads-up game.

    Each iteration deals random hole cards and recursively explores every
    legal action of both players (boards are dealt at random as streets
    advance), updating counterfactual regrets for one traversing player.
    Traversals are cut off by a depth limit and an action-history limit, and
    leaf values come from ``_fast_utility``, a random placeholder rather than
    a real hand evaluator.
    """

    def __init__(self, abstraction_manager, initial_stack=1000.0):
        print("Initializing Ultra-Fast MCCFR Trainer...")

        # Must provide get_postflop_bucket(hand, board) -> bucket id.
        self.abstraction_manager = abstraction_manager
        self.initial_stack = initial_stack

        # CFR tables: info_state -> {(Action, amount): accumulated value}.
        self.regret_sum = {}
        self.strategy_sum = {}

        self.iteration_count = 0  # cumulative across train() calls
        self.strategy_updates = 0

        # Traversal limits that keep each iteration cheap.
        self.max_actions_per_state = 15  # longer action histories are treated as leaves
        self.max_recursion_depth = 12
        self.current_depth = 0

        # Pre-computed street metadata.
        self.streets = ['preflop', 'flop', 'turn', 'river']
        self.street_to_board_size = {'preflop': 0, 'flop': 3, 'turn': 4, 'river': 5}

        self._init_fast_preflop_buckets()

        print("Ultra-Fast MCCFR Trainer initialized.")

    def _init_fast_preflop_buckets(self):
        """Build ``self.preflop_buckets``: hand string ('AKs', 'QQ', ...) -> bucket id.

        Only three groups are bucketed (60 = premium, 61 = strong, 62 = decent);
        every other hand falls back to 79 in ``get_preflop_bucket``.
        """
        self.preflop_buckets = {}

        premium = ['AA', 'KK', 'QQ', 'AKs', 'AKo']
        strong = ['JJ', 'TT', 'AQs', 'AQo', 'KQs']
        decent = ['99', '88', 'AJs', 'AJo', 'KJs', 'KQo']

        bucket = 60
        for group in [premium, strong, decent]:
            for hand in group:
                self.preflop_buckets[hand] = bucket
            bucket += 1

    def get_preflop_bucket(self, hand: List[str]) -> int:
        """Return the preflop bucket for two card strings such as ['As', 'Kh'].

        The hand is normalized to 'XX' / 'XYs' / 'XYo' (higher rank first) and
        looked up in ``self.preflop_buckets``; unknown hands map to 79.
        """
        ranks = [card[0] for card in hand]
        suits = [card[1] for card in hand]

        if ranks[0] == ranks[1]:
            hand_str = ranks[0] + ranks[0]
        else:
            # Put the higher rank first so 'KA' and 'AK' share a key.
            if RANKS.index(ranks[0]) > RANKS.index(ranks[1]):
                sorted_ranks = ranks
            else:
                sorted_ranks = [ranks[1], ranks[0]]

            suited = 's' if suits[0] == suits[1] else 'o'
            hand_str = ''.join(sorted_ranks) + suited

        return self.preflop_buckets.get(hand_str, 79)

    def get_info_state(self, state: "GameState", player_id: int) -> str:
        """Return the compact info-state key 'B{bucket}S{street}L{last}P{pot}'.

        Only four things are kept:
          - the card bucket;
          - the street initial;
          - the first letter of the last action's value ('x' if there is none);
          - a pot category 0-4 (pot / 20, capped).

        NOTE(review): CHECK ('check') and CALL ('call') both abbreviate to 'c',
        so those histories share keys. The format is left unchanged so that
        previously saved policies stay compatible.
        """
        hand = state.hands[player_id]
        board = state.board

        if state.street == 'preflop':
            bucket = self.get_preflop_bucket(hand)
        else:
            bucket = self.abstraction_manager.get_postflop_bucket(hand, board)

        last_action = 'x'  # no action yet
        if state.action_history:
            last_action = state.action_history[-1][1].value[0]

        pot_cat = min(4, int(state.pot_size / 20))

        return f"B{bucket}S{state.street[0]}L{last_action}P{pot_cat}"

    def get_legal_actions(self, state: "GameState", player_id: int) -> List[Tuple["Action", float]]:
        """Return the (Action, amount) choices for ``player_id``.

        Fold is offered only when facing a bet. Check is offered when nothing is
        owed; otherwise call, if affordable. A single bet/raise of 0.8 x pot is
        offered when it is affordable and at least double the amount owed.
        Never returns an empty list.
        """
        actions = []
        to_call = max(state.bets.values()) - state.bets.get(player_id, 0)
        stack = self.initial_stack - state.total_bets.get(player_id, 0)

        if stack <= 0:
            return [(Action.FOLD, 0)]

        if to_call > 0:
            actions.append((Action.FOLD, 0))

        if to_call == 0:
            actions.append((Action.CHECK, 0))
        elif to_call <= stack:
            actions.append((Action.CALL, to_call))

        # Only one bet size, for speed.
        pot_bet = max(state.pot_size * 0.8, 1.0)
        min_bet = max(to_call * 2, 1.0) if to_call > 0 else 1.0

        if pot_bet <= stack and pot_bet >= min_bet:
            if to_call == 0:
                actions.append((Action.BET, pot_bet))
            else:
                actions.append((Action.RAISE, pot_bet))

        if not actions:
            if to_call > 0:
                actions = [(Action.FOLD, 0)]
            else:
                actions = [(Action.CHECK, 0)]

        return actions

    def get_strategy(self, info_state: str, legal_actions: List[Tuple["Action", float]]) -> Dict[Tuple["Action", float], float]:
        """Regret matching: map each legal action to a probability.

        Probabilities are proportional to positive accumulated regret. They are
        uniform when no action has positive regret. The result is empty when
        ``legal_actions`` is empty.

        Side effect: registers ``info_state`` in ``regret_sum`` with an empty
        dict if it was unseen. This is what ``len(self.regret_sum)`` counts.
        """
        if info_state not in self.regret_sum:
            self.regret_sum[info_state] = {}

        strategy = {}

        if not legal_actions:
            return strategy

        regrets = self.regret_sum[info_state]
        positive_total = sum(max(0, regrets.get(a, 0.0)) for a in legal_actions)

        if positive_total > 0:
            for action_tuple in legal_actions:
                strategy[action_tuple] = max(0, regrets.get(action_tuple, 0.0)) / positive_total
        else:
            prob = 1.0 / len(legal_actions)
            for action_tuple in legal_actions:
                strategy[action_tuple] = prob

        return strategy

    def train(self, num_iterations: int = 1000, timeout_seconds: float = 120.0):
        """Run up to ``num_iterations`` traversals, alternating the traversing player.

        Args:
            num_iterations: number of traversals to attempt.
            timeout_seconds: stop early once this much wall-clock time has
                elapsed since the call started (checked after each iteration).

        Progress is printed every 50 completed iterations. The average strategy
        is accumulated every 25 iterations, starting with the first. An
        iteration that raises is reported (first 10 failures only) and training
        continues.
        """
        print(f"\n=== Ultra-Fast MCCFR Training: {num_iterations} iterations ===")
        start_time = time.time()
        iterations_run = 0  # this call only; iteration_count is cumulative
        error_count = 0

        for i in range(num_iterations):
            self.iteration_count += 1
            iterations_run += 1
            self.current_depth = 0

            try:
                self._run_iteration()
            except Exception as e:
                error_count += 1
                if error_count <= 10:  # only report the first few failures
                    print(f"Error iter {i+1}: {e}")

            elapsed = time.time() - start_time
            if (i + 1) % 50 == 0:
                rate = (i + 1) / elapsed if elapsed > 0 else 0.0
                print(f"Iter {i+1}/{num_iterations} | {elapsed:.1f}s | {rate:.1f} iter/s | States: {len(self.regret_sum)}")

            if elapsed > timeout_seconds:
                print(f"Timeout after {i+1} iterations")
                break

            if i % 25 == 0:
                self._update_strategy()

        total_time = time.time() - start_time
        # Rate over the iterations run by THIS call, not the cumulative counter.
        rate = iterations_run / total_time if total_time > 0 else 0
        print(f"Training done! {total_time:.1f}s | {rate:.1f} iter/s | {len(self.regret_sum)} states")

    def _run_iteration(self):
        """Deal a fresh hand and run one traversal.

        The traversing player alternates with ``iteration_count``. Exceptions
        propagate to the caller; ``train`` reports them and keeps going.
        """
        state = self._create_fast_state()
        traversing_player = self.iteration_count % 2
        self._mccfr(state, traversing_player, 1.0, 1.0)

    def _create_fast_state(self) -> "GameState":
        """Deal two random hands and return the preflop root (blinds 1 / 2).

        NOTE(review): passes ``players=[0, 1]``, so it requires a GameState
        that accepts a ``players`` argument.
        """
        deck_copy = DECK.copy()
        random.shuffle(deck_copy)

        return GameState(
            players=[0, 1],
            current_player=0,
            pot_size=3.0,
            street='preflop',
            board=[],
            hands={0: deck_copy[:2], 1: deck_copy[2:4]},
            bets={0: 1.0, 1: 2.0},
            total_bets={0: 1.0, 1: 2.0},
            action_history=[]
        )

    def _mccfr(self, state: "GameState", traversing_player: int, pi_p: float, pi_o: float) -> float:
        """Return the expected utility of ``state`` for ``traversing_player``.

        Args:
            pi_p: reach probability contributed by the traversing player.
            pi_o: reach probability contributed by the other player; it weights
                the regret updates.

        ``current_depth`` is restored on exit even if the traversal raises.
        """
        self.current_depth += 1
        try:
            return self._mccfr_node(state, traversing_player, pi_p, pi_o)
        finally:
            self.current_depth -= 1

    def _mccfr_node(self, state: "GameState", traversing_player: int, pi_p: float, pi_o: float) -> float:
        """Body of ``_mccfr`` for one node. The caller maintains ``current_depth``."""
        # Depth and history caps turn deep nodes into leaves.
        if self.current_depth > self.max_recursion_depth:
            return self._fast_utility(state, traversing_player)

        if state.is_terminal or len(state.action_history) > self.max_actions_per_state:
            return self._fast_utility(state, traversing_player)

        player = state.current_player
        info_state = self.get_info_state(state, player)
        legal_actions = self.get_legal_actions(state, player)

        if not legal_actions:
            return self._fast_utility(state, traversing_player)

        strategy = self.get_strategy(info_state, legal_actions)

        utilities = {}
        for action_tuple in legal_actions:
            new_state = self._fast_apply_action(state, action_tuple[0], action_tuple[1])
            prob = strategy[action_tuple]
            if player == traversing_player:
                utilities[action_tuple] = self._mccfr(new_state, traversing_player, pi_p * prob, pi_o)
            else:
                utilities[action_tuple] = self._mccfr(new_state, traversing_player, pi_p, pi_o * prob)

        node_util = sum(strategy[at] * utilities[at] for at in legal_actions)

        # Counterfactual regret update, only at the traversing player's nodes.
        if player == traversing_player:
            regrets = self.regret_sum.setdefault(info_state, {})
            for action_tuple in legal_actions:
                regret = utilities[action_tuple] - node_util
                regrets[action_tuple] = regrets.get(action_tuple, 0.0) + pi_o * regret

        return node_util

    def _fast_apply_action(self, state: "GameState", action: "Action", amount: float) -> "GameState":
        """Return a new state with ``action`` applied for the current player.

        The mutable containers are copied, except ``hands``, which is shared.
        For BET/RAISE, ``amount`` chips are added to the player's bet. The street
        may advance, dealing board cards at random.
        """
        new_state = copy(state)
        new_state.bets = state.bets.copy()
        new_state.total_bets = state.total_bets.copy()
        new_state.players = state.players.copy()
        new_state.action_history = state.action_history.copy()
        new_state.hands = state.hands  # shared reference: hands never change
        new_state.board = state.board.copy()

        player = state.current_player

        if action == Action.FOLD:
            new_state.is_terminal = True
            new_state.players.remove(player)
        elif action == Action.CHECK:
            pass
        elif action == Action.CALL:
            call_amt = max(state.bets.values()) - state.bets.get(player, 0)
            new_state.bets[player] += call_amt
            new_state.total_bets[player] += call_amt
            new_state.pot_size += call_amt
        elif action in [Action.BET, Action.RAISE]:
            new_state.bets[player] += amount
            new_state.total_bets[player] += amount
            new_state.pot_size += amount

        new_state.action_history.append((player, action, amount))

        if not new_state.is_terminal:
            new_state.current_player = 1 - player

            if self._should_advance_street(new_state):
                self._advance_street_fast(new_state)

        return new_state

    def _should_advance_street(self, state: "GameState") -> bool:
        """Crude end-of-round test used to keep the tree shallow.

        Returns True when either of the last two actions is a CHECK or CALL, or
        once at least 4 actions have been taken in the hand.

        NOTE(review): check-then-bet and call-then-raise also satisfy this, so
        a bet can go unanswered. After 4 actions, every action advances the
        street. Presumably kept this simple to bound tree depth.
        """
        if len(state.action_history) < 2:
            return False

        recent = state.action_history[-2:]
        actions = [act for _, act, _ in recent]

        return (Action.CHECK in actions or
                Action.CALL in actions or
                len(state.action_history) >= 4)

    def _advance_street_fast(self, state: "GameState"):
        """Move ``state`` (in place) to the next street, or mark it terminal after the river."""
        if state.street == 'preflop':
            state.street = 'flop'
            self._deal_board(state, 3)
        elif state.street == 'flop':
            state.street = 'turn'
            self._deal_board(state, 4)
        elif state.street == 'turn':
            state.street = 'river'
            self._deal_board(state, 5)
        else:
            state.is_terminal = True
            return

        state.bets = {p: 0.0 for p in state.players}
        state.current_player = 0

    def _deal_board(self, state: "GameState", target_size: int):
        """Deal random unused cards onto ``state.board`` until it holds ``target_size`` cards."""
        cards_needed = target_size - len(state.board)
        if cards_needed <= 0:
            return

        used = set(state.board)
        for hand in state.hands.values():
            used.update(hand)

        available = [c for c in DECK if c not in used]
        random.shuffle(available)
        state.board.extend(available[:cards_needed])

    def _fast_utility(self, state: "GameState", player_id: int) -> float:
        """Placeholder payoff for ``player_id``; deliberately random, not a hand evaluator.

        - A folded player loses its bets.
        - The last remaining player wins the pot.
        - Postflop leaves are a coin flip biased 60/40 towards seat 0.
        - Preflop leaves are +/-0.3 x pot at random.
        """
        if player_id not in state.players:
            return -state.total_bets.get(player_id, 0)

        if len(state.players) == 1:
            return state.pot_size - state.total_bets.get(player_id, 0)

        if len(state.board) >= 3:
            win_prob = 0.5 + (0.1 if player_id == 0 else -0.1)
            if random.random() < win_prob:
                return state.pot_size - state.total_bets.get(player_id, 0)
            else:
                return -state.total_bets.get(player_id, 0)

        return random.choice([1, -1]) * (state.pot_size * 0.3)

    def _update_strategy(self):
        """Add the current regret-matching strategy of every known state to ``strategy_sum``.

        NOTE(review): the sum is not weighted by reach probability. It runs over
        every action ever recorded for a state; bet sizes scale with the pot, so
        one info state can hold several sizes.
        """
        for info_state in self.regret_sum:
            actions = list(self.regret_sum[info_state].keys())
            strategy = self.get_strategy(info_state, actions)

            if info_state not in self.strategy_sum:
                self.strategy_sum[info_state] = {}

            for action_tuple in strategy:
                current = self.strategy_sum[info_state].get(action_tuple, 0.0)
                self.strategy_sum[info_state][action_tuple] = current + strategy[action_tuple]

        self.strategy_updates += 1

    def get_final_policy(self) -> Dict[str, Dict]:
        """Return the normalized average strategy as {info_state: {str(action_tuple): prob}}.

        States whose accumulated strategy sums to zero are omitted.
        """
        policy = {}

        for info_state in self.strategy_sum:
            total = sum(self.strategy_sum[info_state].values())
            if total > 0:
                policy[info_state] = {
                    str(action): prob / total
                    for action, prob in self.strategy_sum[info_state].items()
                }

        return policy

    def save_policy(self, filepath: str):
        """Pickle the policy plus iteration and state counters to ``filepath``.

        The policy is stored under the 'policy' key.
        """
        import pickle

        data = {
            'policy': self.get_final_policy(),
            'iterations': self.iteration_count,
            'states': len(self.regret_sum),
            'updates': self.strategy_updates
        }

        with open(filepath, 'wb') as f:
            pickle.dump(data, f)

        print(f"Policy saved: {len(data['policy'])} states")


# USAGE EXAMPLE
if __name__ == "__main__":
    # Smoke test: fit the postflop abstraction, train the ultra-fast trainer,
    # print a summary and pickle the policy.
    # NOTE: manager / trainer / policy are module-level names here (notebook
    # globals), so later cells may rely on them.
    print("=== ULTRA-FAST MCCFR SYSTEM TEST ===")

    try:
        # Use your existing abstraction
        print("Setting up abstraction...")
        manager = FixedAbstractionManager(n_clusters_per_round=10)
        manager.train_all_postflop_models(n_samples_per_round=1000)

        print("Testing ultra-fast trainer...")
        trainer = UltraFastMCCFRTrainer(manager)

        # This should be MUCH faster
        trainer.train(num_iterations=1000)

        # Results
        policy = trainer.get_final_policy()
        print(f"\nResults:")
        print(f"- States: {len(policy)}")
        print(f"- Iterations: {trainer.iteration_count}")
        print(f"- Updates: {trainer.strategy_updates}")

        trainer.save_policy("ultra_fast_policy.pkl")
        print("SUCCESS: Ultra-fast training complete!")

    except Exception as e:
        # Report any failure with a full traceback instead of aborting the cell.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

=== ULTRA-FAST MCCFR SYSTEM TEST ===
Setting up abstraction...
Optimized PotentialAwareCalculator initialized.
Fixed AbstractionManager initialized.

=== STARTING K-MEANS MODEL TRAINING FOR ALL ROUNDS ===

--- Training flop model ---
--- Generating 1000 feature vectors for flop ---
  ...attempting 1000, collected 999/1000...
Generated 1000 valid vectors (0 failed) in 1.81s
Removed 4 duplicate vectors
Training K-Means on 996 vectors with 8 features
Using 10 clusters
Silhouette Score: 0.287
flop model trained successfully!

--- Training turn model ---
--- Generating 1000 feature vectors for turn ---
  ...attempting 1000, collected 999/1000...
Generated 1000 valid vectors (0 failed) in 3.01s
Removed 2 duplicate vectors
Training K-Means on 998 vectors with 8 features
Using 10 clusters
Silhouette Score: 0.270
turn model trained successfully!

--- Training river model ---
--- Generating 1000 feature vectors for river ---
  ...attempting 1000, collected 999/1000...
Generated 1000 valid vector

#Evaluations


In [None]:
# (Add this class to your evaluation_engine.py file)

# You will need your AbstractionManager here
# from final_abstraction_engine import FixedAbstractionManager

class RealMCCFRBotStrategy:
    """
    Strategy that plays from a trained MCCFR policy loaded from disk.

    The abstraction manager converts game states into the info-state strings
    used as policy keys. Info states missing from the policy fall back to
    check / fold.
    """
    def __init__(self, policy_file: str, abstraction_manager):
        import pickle  # local import so this class does not rely on a module-level import

        print(f"Loading REAL trained policy from: {policy_file}")
        # SECURITY: unpickling can execute arbitrary code; only load trusted policy files.
        with open(policy_file, 'rb') as f:
            policy_data = pickle.load(f)

        # Trainers store the policy under different keys
        # (UltraFastMCCFRTrainer.save_policy uses 'policy'): accept both.
        policy = policy_data.get('final_policy')
        if policy is None:
            policy = policy_data.get('policy', {})
        self.policy = policy
        self.abstraction_manager = abstraction_manager
        print(f"Policy loaded. Agent knows {len(self.policy)} info states.")

    def get_action(self, state: GameState, player_id: int) -> Tuple[Action, float]:
        """Sample an (Action, amount) from the policy; use the fallback if unknown."""
        info_state = self._get_info_state(state, player_id)
        strategy = self.policy.get(info_state)

        if not strategy:
            return self._fallback_action(state, player_id)

        actions_list = []
        probs_list = []
        for action_key, prob in strategy.items():
            parsed = self._parse_action_key(action_key)
            if parsed is None:
                continue  # skip keys that cannot be interpreted
            actions_list.append(parsed)
            probs_list.append(float(prob))

        total = sum(probs_list)
        if not actions_list or total <= 0:
            return self._fallback_action(state, player_id)

        # Renormalize: some keys may have been skipped.
        probs = np.array(probs_list) / total
        chosen_idx = np.random.choice(len(actions_list), p=probs)
        return actions_list[chosen_idx]

    @staticmethod
    def _parse_action_key(action_key):
        """Convert a saved policy key into an ``(Action, amount)`` tuple.

        Keys may be tuples or their ``str()``, e.g. "(<Action.RAISE: 'raise'>, 3.0)".
        The enum member is looked up by NAME in the current ``Action`` class,
        which also copes with tuples pickled from an earlier ``Action``
        definition. Returns None when the key cannot be parsed.
        """
        import re

        if (isinstance(action_key, tuple) and len(action_key) == 2
                and isinstance(action_key[0], Action)):
            return (action_key[0], float(action_key[1]))

        match = re.search(r"Action\.([A-Z_]+)[^,]*,\s*([-+0-9.eE]+)", str(action_key))
        if match is None:
            return None
        name, amount_str = match.groups()
        try:
            return (Action[name], float(amount_str))
        except (KeyError, ValueError):
            return None

    def _fallback_action(self, state: GameState, player_id: int) -> Tuple[Action, float]:
        """A safe action when an info state is not in the policy: check if free, else fold."""
        to_call = max(state.bets.values()) - state.bets.get(player_id, 0)
        if to_call == 0:
            return (Action.CHECK, 0)
        else:
            return (Action.FOLD, 0)

    def _get_info_state(self, state: GameState, player_id: int) -> str:
        """
        Build the key 'B{bucket}S{street}A{last 3 action letters}P{pot / 10, max 9}'.

        This must stay IDENTICAL to the info-state logic of the trainer that
        produced the policy. The format matches keys such as 'B79SpAP0' from the
        optimized trainer's run. It does NOT match
        UltraFastMCCFRTrainer.get_info_state, which uses 'B..S..L..P..'.

        NOTE(review): preflop always uses bucket 79 here. pot_ratio is computed
        in this engine's chip units, so confirm both the bucketing and the blind
        scale match the trainer's.
        """
        hand = state.hands[player_id]
        board = state.board

        if state.street == 'preflop':
            # TODO: call the trainer's preflop bucketing instead of this fallback.
            bucket_id = 79
        else:
            bucket_id = self.abstraction_manager.get_postflop_bucket(hand, board)

        recent_actions = state.action_history[-3:]
        action_str = ''.join([f"{a.value[0]}" for _, a, _ in recent_actions])
        pot_ratio = min(9, int(state.pot_size / 10))

        return f"B{bucket_id}S{state.street[0]}A{action_str}P{pot_ratio}"




In [None]:
import numpy as np
import random
import time
import pickle
from collections import defaultdict, Counter
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from enum import Enum
import matplotlib.pyplot as plt



# --- Basic Game Definitions ---
SUITS = list("hdcs")  # hearts, diamonds, clubs, spades
RANKS = list("23456789TJQKA")  # low to high
# 52 two-character card strings, grouped by rank: '2h', '2d', ..., 'As'.
DECK = [rank + suit for rank in RANKS for suit in SUITS]
# Rank character -> 0-based strength index ('2' -> 0, 'A' -> 12).
RANK_MAP = dict(zip(RANKS, range(len(RANKS))))

class Action(Enum):
    """Betting actions available to a player.

    Values are lowercase names. CHECK and CALL share the first letter 'c', so
    abbreviating an action by ``value[0]`` cannot tell them apart.
    """
    FOLD = "fold"
    CHECK = "check"
    CALL = "call"
    BET = "bet"      # open the betting on a street
    RAISE = "raise"  # increase an existing bet

@dataclass
class GameState:
    """Snapshot of a heads-up hand, shared by the trainers and the evaluation engine.

    ``players`` lists the seats still in the hand. Callers may pass it
    explicitly (UltraFastMCCFRTrainer._create_fast_state does) or omit it; when
    omitted, it defaults to the seats present in ``hands``.
    """
    current_player: int            # seat whose turn it is
    pot_size: float                # chips committed by all seats so far
    street: str                    # 'preflop', 'flop', 'turn' or 'river'
    board: List[str]               # community cards, e.g. ['Qc', '7d', '2s']
    hands: Dict[int, List[str]]    # seat -> hole cards
    bets: Dict[int, float]         # chips committed on the current street
    total_bets: Dict[int, float]   # chips committed over the whole hand
    action_history: List[Tuple[int, "Action", float]]  # (seat, action, amount)
    is_terminal: bool = False
    players: Optional[List[int]] = None  # active seats; derived from hands when None

    def __post_init__(self):
        if self.players is None:
            self.players = list(self.hands.keys())

class OpponentType(Enum):
    """Opponent-style labels used when reporting evaluations.

    NOTE(review): presumably one per baseline strategy class (RandomStrategy,
    TightStrategy, ...); confirm the mapping where results are produced.
    """
    RANDOM = "random"
    TIGHT = "tight"
    LOOSE = "loose"
    AGGRESSIVE = "aggressive"
    STUMBOT = "stumbot"

@dataclass
class GameResult:
    """Results from a single poker game (a sequence of hands between two strategies)."""
    winner: int                     # seat with the larger final stack (PokerGameEngine gives ties to seat 1)
    final_stacks: Dict[int, float]  # seat -> chips at the end of the game
    hands_played: int
    total_pot: float                # sum of the final pot size of every hand played
    game_length: int  # Number of actions

@dataclass
class EvaluationResults:
    """Complete evaluation results.

    NOTE(review): the code that fills this in is not visible here. The field
    notes below are read from names and types only; confirm them against the
    producer.
    """
    win_rate: float
    avg_profit: float
    total_profit: float
    hands_played: int
    opponent_type: str                         # presumably an OpponentType value
    confidence_interval: Tuple[float, float]   # (low, high); which metric it bounds is unconfirmed
    game_results: List[GameResult]             # per-game results that were aggregated

class PokerGameEngine:
    """Heads-up hold'em engine for evaluating two strategies against each other.

    A strategy is any object with ``get_action(state, player_id)`` returning an
    ``(Action, amount)`` tuple. For BET/RAISE, ``amount`` is the number of chips
    added to the player's bet on the current street. Stacks are tracked across
    hands but do not cap bets inside a hand.
    """

    def __init__(self, initial_stack=20000, small_blind=50, big_blind=100):
        self.initial_stack = initial_stack
        self.small_blind = small_blind
        self.big_blind = big_blind

    def play_game(self, player1_strategy, player2_strategy, max_hands=100) -> GameResult:
        """Play up to ``max_hands`` hands between two strategies.

        ``player1_strategy`` always sits in seat 0 and ``player2_strategy`` in
        seat 1. The button alternates every hand. Play stops early once either
        stack is no larger than the big blind. A hand that raises an exception
        is reported, counted and skipped.

        Returns:
            GameResult whose ``winner`` is the seat with the larger final stack
            (seat 1 on a tie).
        """
        stacks = {0: self.initial_stack, 1: self.initial_stack}
        hands_played = 0
        total_actions = 0
        total_pot_won = 0

        while hands_played < max_hands and min(stacks.values()) > self.big_blind:
            try:
                hand_result = self._play_hand(player1_strategy, player2_strategy, stacks, hands_played)
            except Exception as e:
                print(f"Error in hand {hands_played}: {e}")
                hands_played += 1
                continue

            # Net changes already include the blinds each seat posted (they are
            # part of total_bets), so blinds must not be deducted a second time.
            stacks[0] += hand_result['player_0_change']
            stacks[1] += hand_result['player_1_change']
            total_pot_won += hand_result['pot_size']
            total_actions += hand_result['actions']
            hands_played += 1

        winner = 0 if stacks[0] > stacks[1] else 1

        return GameResult(
            winner=winner,
            final_stacks=stacks,
            hands_played=hands_played,
            total_pot=total_pot_won,
            game_length=total_actions
        )

    def _play_hand(self, p1_strategy, p2_strategy, stacks, hand_num) -> Dict:
        """Play one hand and return each seat's net chip change.

        Heads-up rules apply. The button (``hand_num % 2``) posts the small
        blind and acts first preflop. The other seat posts the big blind and
        acts first on later streets. ``stacks`` is accepted for interface
        compatibility but is not used.

        Returns:
            dict with 'player_0_change' and 'player_1_change' (net chips,
            blinds included), 'pot_size' and 'actions'.
        """
        deck = DECK.copy()
        random.shuffle(deck)
        hands = {0: deck[:2], 1: deck[2:4]}

        button = hand_num % 2
        big_blind_seat = 1 - button
        bets = {seat: (self.small_blind if seat == button else self.big_blind) for seat in (0, 1)}

        state = GameState(
            current_player=button,
            pot_size=self.small_blind + self.big_blind,
            street='preflop',
            board=[],
            hands=hands,
            bets=bets,
            total_bets=bets.copy(),
            action_history=[]
        )

        action_count = 0
        max_actions = 20  # prevent infinite loops
        street_start = 0  # index in action_history where the current street began

        while not state.is_terminal and action_count < max_actions:
            current_player = state.current_player
            strategy = p1_strategy if current_player == 0 else p2_strategy
            try:
                action, amount = strategy.get_action(state, current_player)
                state = self._apply_action(state, action, amount)
                action_count += 1

                if self._should_advance_street(state, street_start):
                    state = self._advance_to_next_street(state, deck, first_to_act=big_blind_seat)
                    street_start = len(state.action_history)

            except Exception as e:
                print(f"Error in action processing: {e}")
                # Force a terminal state to avoid an infinite loop.
                state.is_terminal = True
                break

        try:
            if state.is_terminal:
                payoffs = self._calculate_showdown(state)
            else:
                # Action cap reached: every seat gets its chips back.
                payoffs = {0: 0, 1: 0}
        except Exception as e:
            print(f"Error in showdown: {e}")
            payoffs = {0: 0, 1: 0}

        return {
            'player_0_change': payoffs[0],
            'player_1_change': payoffs[1],
            'pot_size': state.pot_size,
            'actions': action_count
        }

    def _apply_action(self, state: GameState, action: Action, amount: float) -> GameState:
        """Return a deep copy of ``state`` with ``action`` applied for the current player.

        - FOLD ends the hand and removes the player from ``players``.
        - CALL matches the largest bet.
        - BET/RAISE add ``amount`` chips, with a minimum of 1.

        The turn passes to the other seat unless the hand has ended.
        """
        from copy import deepcopy
        new_state = deepcopy(state)

        player = state.current_player

        if action == Action.FOLD:
            new_state.is_terminal = True
            if player in new_state.players:
                new_state.players.remove(player)
        elif action == Action.CHECK:
            pass
        elif action == Action.CALL:
            call_amount = max(state.bets.values()) - state.bets.get(player, 0)
            call_amount = max(0, call_amount)
            new_state.bets[player] = new_state.bets.get(player, 0) + call_amount
            new_state.total_bets[player] = new_state.total_bets.get(player, 0) + call_amount
            new_state.pot_size += call_amount
        elif action in [Action.BET, Action.RAISE]:
            amount = max(amount, 1)
            new_state.bets[player] = new_state.bets.get(player, 0) + amount
            new_state.total_bets[player] = new_state.total_bets.get(player, 0) + amount
            new_state.pot_size += amount

        new_state.action_history.append((player, action, amount))

        if not new_state.is_terminal and len(new_state.players) > 1:
            new_state.current_player = 1 - player

        return new_state

    def _should_advance_street(self, state: GameState, street_start: int = 0) -> bool:
        """Return True when the current betting round is closed.

        A round closes once all of the following hold:
          - at least two actions have been taken on the current street, so both
            seats have acted;
          - the active seats' bets are equal;
          - the last action was a CHECK or CALL.

        Args:
            street_start: index into ``state.action_history`` of the current
                street's first action. Only actions from there on count, so the
                previous street's actions cannot close the new round early.
        """
        if len(state.players) < 2:
            return False

        street_actions = state.action_history[street_start:]
        if len(street_actions) < 2:
            return False

        active_bets = {state.bets.get(p, 0) for p in state.players}
        if len(active_bets) > 1:
            return False

        return street_actions[-1][1] in (Action.CHECK, Action.CALL)

    def _advance_to_next_street(self, state: GameState, deck: List[str],
                                first_to_act: Optional[int] = None) -> GameState:
        """Advance ``state`` in place to the next street, dealing board cards.

        After the river the state is marked terminal (showdown). On a new
        street, bets reset to 0 and ``first_to_act`` moves first if it is still
        active; otherwise the lowest active seat moves first.
        """
        street_progression = {
            'preflop': ('flop', 3),
            'flop': ('turn', 4),
            'turn': ('river', 5),
            'river': ('showdown', 5)
        }

        if state.street in street_progression:
            next_street, board_size = street_progression[state.street]

            if next_street == 'showdown':
                state.is_terminal = True
            else:
                used_cards = set()
                for hand in state.hands.values():
                    used_cards.update(hand)
                used_cards.update(state.board)

                available = [c for c in deck if c not in used_cards]
                cards_needed = board_size - len(state.board)
                if len(available) >= cards_needed:
                    state.board.extend(available[:cards_needed])
                else:
                    state.is_terminal = True
                    return state

                state.street = next_street
                state.bets = {p: 0 for p in state.players}
                if first_to_act is not None and first_to_act in state.players:
                    state.current_player = first_to_act
                elif state.players:
                    state.current_player = min(state.players)

        return state

    def _calculate_showdown(self, state: GameState) -> Dict[int, float]:
        """Return each seat's net chip change for a finished hand.

        - One player left: that player wins the pot.
        - Full board: hands are compared with ``evaluate_cards`` (lower value
          wins) and ties split the pot. Evaluation errors fall back to a coin
          flip.
        - Otherwise (the hand was ended early with both players in): every seat
          gets back what it committed, i.e. a net change of 0.
        """
        if len(state.players) == 1:
            winner = state.players[0]
            loser = 1 - winner
            return {
                winner: state.pot_size - state.total_bets.get(winner, 0),
                loser: -state.total_bets.get(loser, 0)
            }

        if len(state.board) == 5:
            try:
                p0_hand = state.hands[0] + state.board
                p1_hand = state.hands[1] + state.board

                p0_rank = evaluate_cards(*p0_hand)
                p1_rank = evaluate_cards(*p1_hand)

                if p0_rank < p1_rank:  # lower rank wins
                    winner, loser = 0, 1
                elif p1_rank < p0_rank:
                    winner, loser = 1, 0
                else:
                    pot_share = state.pot_size / 2
                    return {
                        0: pot_share - state.total_bets.get(0, 0),
                        1: pot_share - state.total_bets.get(1, 0)
                    }

                return {
                    winner: state.pot_size - state.total_bets.get(winner, 0),
                    loser: -state.total_bets.get(loser, 0)
                }

            except Exception as e:
                print(f"Hand evaluation error: {e}")
                winner = random.choice([0, 1])
                loser = 1 - winner
                return {
                    winner: state.pot_size - state.total_bets.get(winner, 0),
                    loser: -state.total_bets.get(loser, 0)
                }

        # Hand ended early with both players in: refund instead of destroying chips.
        return {0: 0, 1: 0}

# Simplified strategy classes for testing
class RandomStrategy:
    """Random baseline opponent.

    Picks uniformly among a small fixed menu of actions:
      * facing a bet: fold, call the amount owed, or raise;
      * not facing a bet: check or bet.
    The bet/raise amount is 75% of the pot, but never less than 10.
    """

    def get_action(self, state: GameState, player_id: int) -> Tuple[Action, float]:
        """Return a uniformly random (Action, amount) from the menu for `player_id`."""
        options = self._get_legal_actions(state, player_id)
        return random.choice(options) if options else (Action.FOLD, 0)

    def _get_legal_actions(self, state: GameState, player_id: int) -> List[Tuple[Action, float]]:
        """Build the action menu for `player_id` from the current round's bets."""
        # Chips still owed to match the largest bet this round, clamped at zero.
        owed = max(0, max(state.bets.values()) - state.bets.get(player_id, 0))
        wager = max(state.pot_size * 0.75, 10)

        if owed > 0:
            return [(Action.FOLD, 0), (Action.CALL, owed), (Action.RAISE, wager)]
        return [(Action.CHECK, 0), (Action.BET, wager)]

class TightStrategy:
    """Tight/conservative opponent.

    Scores its hand with a simple heuristic, then:
      * strength < 0.6: fold when facing a bet, otherwise check;
      * strength > 0.8: bet (or raise) 60% of the pot;
      * otherwise: call when facing a bet, otherwise check.
    """

    def get_action(self, state: GameState, player_id: int) -> Tuple[Action, float]:
        """Choose an (Action, amount) for `player_id` from its heuristic hand strength."""
        hand = state.hands[player_id]

        # Simple hand strength evaluation
        hand_strength = self._evaluate_hand_strength(hand, state.board)

        # Chips still owed to match the largest bet this round (never negative).
        to_call = max(state.bets.values()) - state.bets.get(player_id, 0)
        to_call = max(0, to_call)

        # Tight strategy: only play strong hands
        if hand_strength < 0.6:
            if to_call > 0:
                return (Action.FOLD, 0)
            return (Action.CHECK, 0)

        if hand_strength > 0.8:
            # Strong hand: bet or raise.
            # NOTE(review): the RAISE amount is pot-based and can be smaller than
            # to_call; confirm how the game engine interprets RAISE amounts.
            bet_size = state.pot_size * 0.6
            if to_call == 0:
                return (Action.BET, bet_size)
            return (Action.RAISE, bet_size)

        # Medium hand: call or check
        if to_call > 0:
            return (Action.CALL, to_call)
        return (Action.CHECK, 0)

    def _evaluate_hand_strength(self, hand: List[str], board: List[str]) -> float:
        """Return a heuristic strength score for `hand` given `board`.

        Preflop (empty board), using rank indices 0..12 into RANKS:
          * pocket pair: 0.70 + 0.02 * rank index (0.70..0.94);
          * unpaired with a Q or better (index >= 10): 0.50 + 0.01 * highest index;
          * anything else: 0.3.
        Postflop: ``1 - rank / 7463`` clamped at 0, using ``evaluate_cards``
        (presumably lower rank = better hand, max ~7462 — TODO confirm). Falls back
        to 0.4 if evaluation raises.
        """
        if not board:
            # Preflop evaluation
            ranks = [RANKS.index(card[0]) for card in hand]
            if ranks[0] == ranks[1]:  # Pocket pair
                return 0.7 + max(ranks) * 0.02
            elif max(ranks) >= 10:  # High card
                return 0.5 + max(ranks) * 0.01
            else:
                return 0.3

        # Postflop: use hand evaluator if available
        try:
            full_hand = hand + board
            rank = evaluate_cards(*full_hand)
            # Convert rank to strength (lower rank = better hand)
            return max(0, 1.0 - rank / 7463)
        except Exception:
            # Best-effort fallback; unlike a bare except this does not swallow
            # KeyboardInterrupt/SystemExit.
            return 0.4  # Fallback

class MockMCCFRBotStrategy:
    """Mock strategy for testing when no trained model is available.

    Accepts (and ignores) ``policy_file`` and ``abstraction_manager``, presumably
    to mirror the real MCCFR bot's constructor. Decisions come from a heuristic
    hand-strength score:
      * strength > 0.7: bet (or raise) 80% of the pot;
      * strength > 0.4: call when facing a bet, otherwise check;
      * otherwise: fold when facing a bet, otherwise check.
    """

    def __init__(self, policy_file: str = None, abstraction_manager=None):
        # Both arguments are accepted for interface compatibility and ignored.
        self.policy = {}
        print("Mock MCCFR Bot Strategy initialized (no real training)")

    def get_action(self, state: GameState, player_id: int) -> Tuple[Action, float]:
        """Simple mock strategy that plays reasonably."""
        hand = state.hands[player_id]

        # Basic hand evaluation
        hand_strength = self._evaluate_hand_strength(hand, state.board)

        # Chips still owed to match the largest bet this round (never negative).
        to_call = max(state.bets.values()) - state.bets.get(player_id, 0)
        to_call = max(0, to_call)

        if hand_strength > 0.7:
            # Strong hand: bet/raise
            bet_size = state.pot_size * 0.8
            if to_call == 0:
                return (Action.BET, bet_size)
            return (Action.RAISE, bet_size)

        if hand_strength > 0.4:
            # Medium hand: call/check
            if to_call > 0:
                return (Action.CALL, to_call)
            return (Action.CHECK, 0)

        # Weak hand: fold or check
        if to_call > 0:
            return (Action.FOLD, 0)
        return (Action.CHECK, 0)

    def _evaluate_hand_strength(self, hand: List[str], board: List[str]) -> float:
        """Return a heuristic strength score for `hand` given `board`.

        Preflop (empty board): 0.3 + 0.05 * highest rank index in RANKS (0.30..0.90).
        Postflop: ``1 - rank / 7463`` clamped at 0, using ``evaluate_cards``
        (presumably lower rank = better hand — TODO confirm). If evaluation
        raises, a random value in [0.2, 0.6] is returned.
        """
        if not board:
            # Preflop
            ranks = [RANKS.index(card[0]) for card in hand]
            return 0.3 + max(ranks) * 0.05

        # Postflop
        try:
            full_hand = hand + board
            rank = evaluate_cards(*full_hand)
            return max(0, 1.0 - rank / 7463)
        except Exception:
            # Best-effort fallback; no longer swallows KeyboardInterrupt/SystemExit.
            return random.uniform(0.2, 0.6)

class PokerBotEvaluator:
    """Main evaluation system for poker bots.

    Plays many games between a bot strategy and a scripted opponent through
    ``PokerGameEngine`` and aggregates the bot's per-game profit into
    ``EvaluationResults``. Every completed evaluation is appended to
    ``results_history``.
    """

    def __init__(self):
        self.game_engine = PokerGameEngine()
        self.results_history = []

    def evaluate_bot(self, bot_strategy, opponent_type: OpponentType,
                     num_games: int = 1000, verbose: bool = True,
                     initial_stack: float = 20000, max_hands: int = 50) -> EvaluationResults:
        """Evaluate bot against specified opponent type.

        The bot alternates seats every game (seat 0 on even games, seat 1 on odd).

        Args:
            bot_strategy: strategy object passed to ``PokerGameEngine.play_game``.
            opponent_type: which scripted opponent to play against.
            num_games: number of games to attempt.
            verbose: print progress and a formatted results report.
            initial_stack: starting stack per seat; the bot's profit for a game
                is its final stack minus this value. It must match the engine's
                starting stack (the default 20000 matches the previous hard-coded
                assumption).
            max_hands: forwarded to ``play_game`` as the per-game hand cap.

        Returns:
            EvaluationResults computed over the games that completed. Games that
            raise are reported and excluded from every statistic. If no game
            completes, an all-zero result is returned and not recorded in
            ``results_history``.
        """
        if verbose:
            print(f"\n=== Evaluating Bot vs {opponent_type.value.upper()} ===")
            print(f"Playing {num_games} games...")

        opponent = self._create_opponent(opponent_type)

        game_results = []
        profits = []
        failed_games = 0

        start_time = time.time()

        for i in range(num_games):
            if verbose and (i + 1) % 100 == 0:
                elapsed = time.time() - start_time
                print(f"  Game {i+1}/{num_games} | {elapsed:.1f}s")

            try:
                # Alternate who plays first
                if i % 2 == 0:
                    result = self.game_engine.play_game(bot_strategy, opponent, max_hands=max_hands)
                    bot_profit = result.final_stacks[0] - initial_stack
                else:
                    result = self.game_engine.play_game(opponent, bot_strategy, max_hands=max_hands)
                    bot_profit = result.final_stacks[1] - initial_stack
            except Exception as e:
                # Skip the game entirely. Recording a fake profit of 0 would bias
                # the win rate, the mean and the confidence interval.
                print(f"Error in game {i+1}: {e}")
                failed_games += 1
                continue

            game_results.append(result)
            profits.append(bot_profit)

        if verbose and failed_games:
            print(f"  {failed_games} game(s) failed and were excluded from the statistics")

        if not profits:
            print("No valid games completed!")
            return EvaluationResults(0, 0, 0, 0, opponent_type.value, (0, 0), [])

        win_rate, avg_profit, total_profit, confidence_interval = self._profit_statistics(profits)

        evaluation_results = EvaluationResults(
            win_rate=win_rate,
            avg_profit=avg_profit,
            total_profit=total_profit,
            hands_played=sum(r.hands_played for r in game_results),
            opponent_type=opponent_type.value,
            confidence_interval=confidence_interval,
            game_results=game_results
        )

        if verbose:
            self._print_results(evaluation_results)

        self.results_history.append(evaluation_results)
        return evaluation_results

    @staticmethod
    def _create_opponent(opponent_type: OpponentType):
        """Instantiate the scripted opponent for `opponent_type` (RandomStrategy as fallback)."""
        if opponent_type == OpponentType.RANDOM:
            return RandomStrategy()
        if opponent_type == OpponentType.TIGHT:
            return TightStrategy()
        # NOTE(review): the member is spelled STUMBOT, but the strategy class is
        # SlumbotStrategy (an earlier note called this a typo for SLUMBOT). The
        # OpponentType enum is not visible here: confirm the member exists,
        # otherwise this line raises AttributeError for every type other than
        # RANDOM/TIGHT.
        if opponent_type == OpponentType.STUMBOT:
            return SlumbotStrategy()
        return RandomStrategy()  # Default fallback

    @staticmethod
    def _profit_statistics(profits: List[float]) -> Tuple[float, float, float, Tuple[float, float]]:
        """Summarise per-game profits.

        Args:
            profits: non-empty list of per-game bot profits.

        Returns:
            (win_rate, avg_profit, total_profit, 95% confidence interval of the mean).
            The interval uses the sample standard deviation (ddof=1). With a
            single game the standard error is undefined, so the interval collapses
            to the mean.
        """
        n = len(profits)
        win_rate = sum(1 for p in profits if p > 0) / n
        avg_profit = np.mean(profits)
        total_profit = sum(profits)

        std_error = np.std(profits, ddof=1) / np.sqrt(n) if n > 1 else 0.0
        half_width = 1.96 * std_error
        return win_rate, avg_profit, total_profit, (avg_profit - half_width, avg_profit + half_width)

    def _print_results(self, results: EvaluationResults):
        """Print formatted evaluation results."""
        print(f"\n--- Results vs {results.opponent_type.upper()} ---")
        print(f"Win Rate: {results.win_rate:.1%}")
        print(f"Average Profit: {results.avg_profit:+.2f} chips")
        print(f"Total Profit: {results.total_profit:+.2f} chips")
        print(f"Hands Played: {results.hands_played}")
        print(f"95% CI: [{results.confidence_interval[0]:+.2f}, {results.confidence_interval[1]:+.2f}]")

        if results.avg_profit > 0:
            print("✅ PROFITABLE against this opponent")
        else:
            print("❌ LOSING against this opponent")

    def run_full_evaluation(self, bot_strategy, num_games_per_opponent: int = 1000):
        """Evaluate `bot_strategy` against the RANDOM and TIGHT opponents, then print a summary."""
        print("\n🎯 === FULL POKER BOT EVALUATION ===")

        opponents = [OpponentType.RANDOM, OpponentType.TIGHT]

        for opponent_type in opponents:
            self.evaluate_bot(bot_strategy, opponent_type, num_games_per_opponent)

        self._print_summary()

    def _print_summary(self):
        """Print an aggregate summary of every evaluation in ``results_history``."""
        print(f"\n📊 === EVALUATION SUMMARY ===")

        if not self.results_history:
            print("No evaluation results available.")
            return

        total_profit = sum(r.total_profit for r in self.results_history)
        total_games = sum(len(r.game_results) for r in self.results_history)

        print(f"Overall Performance:")
        print(f"  Total Profit: {total_profit:+.2f} chips")
        print(f"  Games Played: {total_games}")
        if total_games > 0:
            print(f"  Average per Game: {total_profit/total_games:+.2f} chips")

        for result in self.results_history:
            print(f"  vs {result.opponent_type}: {result.win_rate:.1%} win rate, {result.avg_profit:+.2f} avg profit")

# USAGE EXAMPLE
if __name__ == "__main__":
    # Not referenced in this block; kept for the real-model workflow.
    POLICY_FILE = "optimized_cfr_policy.pkl"
    print("🃏 === POKER BOT EVALUATION SYSTEM ===")

    try:
        # No trained model exists yet, so smoke-test the pipeline with the mock bot.
        print("Creating mock bot strategy for testing...")
        evaluator = PokerBotEvaluator()

        # A short run (100 games) is enough to exercise the evaluator end to end.
        print("Running evaluation with mock strategy...")
        mock_bot = MockMCCFRBotStrategy()
        evaluator.evaluate_bot(mock_bot, OpponentType.TIGHT, num_games=100)

        # One print of newline-joined lines produces the same output as one print per line.
        print("\n".join((
            "✅ Evaluation system is ready for real models.",
            "✅ Evaluation system test complete!",
            "\nTo use with real trained models:",
            "1. Train your abstraction manager and MCCFR policy",
            "2. Replace MockMCCFRBotStrategy with your real MCCFRBotStrategy",
            "3. Update file paths to your trained models",
        )))

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()

🃏 === POKER BOT EVALUATION SYSTEM ===
Creating mock bot strategy for testing...
Running evaluation with mock strategy...
Mock MCCFR Bot Strategy initialized (no real training)

=== Evaluating Bot vs TIGHT ===
Playing 100 games...
  Game 100/100 | 2.8s

--- Results vs TIGHT ---
Win Rate: 22.0%
Average Profit: -1509.08 chips
Total Profit: -150908.30 chips
Hands Played: 5000
95% CI: [-1852.73, -1165.44]
❌ LOSING against this opponent
✅ Evaluation system is ready for real models.
✅ Evaluation system test complete!

To use with real trained models:
1. Train your abstraction manager and MCCFR policy
2. Replace MockMCCFRBotStrategy with your real MCCFRBotStrategy
3. Update file paths to your trained models


In [None]:
from __future__ import annotations
import itertools
import random
import time
import pickle
from typing import List, Dict, Tuple, Set
from collections import defaultdict
from enum import Enum
from dataclasses import dataclass, field
from copy import deepcopy

# Card and game constants.
# Suit letters: hearts, diamonds, clubs, spades.
SUITS = ['h', 'd', 'c', 's']
# Ranks from lowest to highest; a rank's list position is used as its ordering.
RANKS = ['2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K', 'A']

# All 52 cards as two-character strings, rank-major: '2h', '2d', '2c', '2s', '3h', ...
DECK = [rank + suit for rank in RANKS for suit in SUITS]

class Action(Enum):
    """Poker actions available to players.

    Each member's value is its lowercase name. CHECK and CALL share the first
    letter 'c', so any encoding that abbreviates actions by ``value[0]`` cannot
    tell them apart.
    """
    FOLD = "fold"    # give up the hand
    CHECK = "check"  # pass without adding chips
    CALL = "call"    # match the outstanding bet
    BET = "bet"      # open the betting on a street
    RAISE = "raise"  # increase an existing bet

@dataclass
class GameState:
    """
    Complete, mutable snapshot of a heads-up poker hand in progress.
    """
    players: List[int]  # seats still in the hand (a folded seat is removed)
    current_player: int  # seat to act next
    pot_size: float
    street: str  # 'preflop', 'flop', 'turn', 'river'
    board: List[str]  # Community cards
    hands: Dict[int, List[str]]  # Player hole cards
    bets: Dict[int, float]  # Current round bets
    total_bets: Dict[int, float]  # Total bets across all rounds
    action_history: List[Tuple[int, Action, float]] = field(default_factory=list)
    is_terminal: bool = False

    def copy(self) -> 'GameState':
        """Return an independent copy whose containers can be mutated freely.

        Every list/dict field is duplicated, and so is each hole-card list inside
        ``hands``. The elements themselves (strings, numbers, history tuples) are
        shared, which is safe because they are immutable.
        """
        copied_hands = {seat: cards.copy() for seat, cards in self.hands.items()}
        return GameState(
            players=self.players.copy(),
            current_player=self.current_player,
            pot_size=self.pot_size,
            street=self.street,
            board=self.board.copy(),
            hands=copied_hands,
            bets=self.bets.copy(),
            total_bets=self.total_bets.copy(),
            action_history=self.action_history.copy(),
            is_terminal=self.is_terminal,
        )

class Card:
    """
    A playing card made of one rank character and one suit character.

    Both parts are validated against RANKS and SUITS on construction. A card
    prints as the two characters concatenated (e.g. ``Ah``) and compares and
    hashes by its (rank, suit) pair.
    """
    __slots__ = ('rank', 'suit')

    def __init__(self, rank: str, suit: str) -> None:
        # The rank is checked first, so its error is reported when both parts are invalid.
        if rank not in RANKS:
            raise ValueError(f"Invalid rank: {rank}. Must be one of {RANKS}")
        if suit not in SUITS:
            raise ValueError(f"Invalid suit: {suit}. Must be one of {SUITS}")
        self.rank, self.suit = rank, suit

    def __repr__(self) -> str:
        return self.rank + self.suit

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Card):
            return (self.rank, self.suit) == (other.rank, other.suit)
        return NotImplemented

    def __hash__(self) -> int:
        return hash((self.rank, self.suit))

    def get_numeric_rank(self) -> int:
        """Return the rank as 2..14 (deuce = 2, ace = 14)."""
        return 2 + RANKS.index(self.rank)

def create_deck() -> List[Card]:
    """Return a fresh, ordered 52-card deck: ranks ascending, suits in SUITS order within each rank."""
    return [Card(rank, suit) for rank, suit in itertools.product(RANKS, SUITS)]

def get_hand_bucket(hand: List[Card]) -> str:
    """
    Return the canonical preflop label of a two-card hand.

    The higher-ranked card comes first. Pairs are labelled like 'AA', suited
    hands like 'AKs' and offsuit hands like 'AKo'.

    Raises:
        ValueError: if `hand` does not contain exactly two cards.
    """
    if len(hand) != 2:
        raise ValueError("Hand must consist of exactly two cards.")

    high, low = sorted(hand, key=lambda c: RANKS.index(c.rank), reverse=True)
    label = high.rank + low.rank

    if high.rank == low.rank:
        return label  # Pocket pair
    return label + ('s' if high.suit == low.suit else 'o')

# Simple hand evaluator (placeholder - you can integrate your actual evaluator)
def evaluate_hand_strength(hand: List[str], board: List[str]) -> float:
    """
    Placeholder strength score in [0, 1] (higher is better).

    With fewer than five known cards the score is a neutral 0.5. Otherwise it is
    (highest card value + mean card value) / 28, using 2..14 card values
    (ace = 14), clamped to [0, 1]. Replace with a real hand evaluator.
    """
    cards = hand + board
    if len(cards) < 5:
        return 0.5  # Unknown strength

    value_of = {rank: value for value, rank in enumerate('23456789TJQKA', start=2)}
    values = [value_of[card[0]] for card in cards]

    score = (max(values) + sum(values) / len(values)) / 28.0
    return max(0.0, min(1.0, score))

# Mock abstraction manager for testing
class MockAbstractionManager:
    """Stand-in abstraction manager that buckets postflop hands by a heuristic strength score."""

    def __init__(self, n_clusters_per_round=10):
        self.n_clusters = n_clusters_per_round

    def get_postflop_bucket(self, hand: List[str], board: List[str]) -> int:
        """Scale ``evaluate_hand_strength`` (0..1) onto bucket ids 0..n_clusters-1 (truncating)."""
        scaled = evaluate_hand_strength(hand, board) * (self.n_clusters - 1)
        return int(scaled)

    def train_all_postflop_models(self, n_samples_per_round=1000):
        """No-op training hook; only reports the requested sample count."""
        print(f"Mock abstraction manager trained with {n_samples_per_round} samples")

class OptimizedMCCFRTrainer:
    """Fixed version of MCCFR trainer with performance optimizations."""

    def __init__(self, abstraction_manager=None, initial_stack=1000.0):
        print("Initializing Optimized MCCFR Trainer...")

        if abstraction_manager is None:
            self.abstraction_manager = MockAbstractionManager(n_clusters_per_round=20)
            self.abstraction_manager.train_all_postflop_models(n_samples_per_round=1000)
        else:
            self.abstraction_manager = abstraction_manager

        self.initial_stack = initial_stack

        # CFR data structures
        self.regret_sum = defaultdict(lambda: defaultdict(float))
        self.strategy_sum = defaultdict(lambda: defaultdict(float))
        self.policy = defaultdict(lambda: defaultdict(float))

        self.iteration_count = 0
        self.strategy_updates = 0

        # Recursion depth tracking
        self.max_recursion_depth = 15
        self.current_depth = 0

        # Action sequence tracking to prevent infinite loops
        self.action_sequence_cache = {}

        self._initialize_preflop_buckets()

        print("Optimized MCCFR Trainer initialized successfully.")

    def _initialize_preflop_buckets(self):
        """Simplified preflop bucketing for testing."""
        self.preflop_buckets = {}

        # Simplified bucketing - just a few categories
        premium_hands = ['AA', 'KK', 'QQ', 'AKs', 'AKo']
        strong_hands = ['JJ', 'TT', 'AQs', 'AQo', 'AJs', 'AJo']
        medium_hands = ['99', '88', 'KQs', 'KQo', 'KJs', 'KJo']

        bucket_id = 60
        for hand in premium_hands:
            self.preflop_buckets[hand] = bucket_id
        bucket_id += 1

        for hand in strong_hands:
            self.preflop_buckets[hand] = bucket_id
        bucket_id += 1

        for hand in medium_hands:
            self.preflop_buckets[hand] = bucket_id

    def get_preflop_bucket(self, hand: List[str]) -> int:
        """Simplified preflop bucketing."""
        if len(hand) != 2:
            return 79  # Default bucket

        ranks = sorted([card[0] for card in hand], key=lambda x: RANKS.index(x), reverse=True)
        suits = [card[1] for card in hand]

        if ranks[0] == ranks[1]:
            hand_str = ranks[0] + ranks[0]
        else:
            suited = 's' if suits[0] == suits[1] else 'o'
            hand_str = ''.join(ranks) + suited

        return self.preflop_buckets.get(hand_str, 79)

    def get_info_state(self, state: GameState, player_id: int) -> str:
        """Simplified information state to prevent explosion."""
        hand = state.hands[player_id]
        board = state.board

        # Get bucket ID
        if state.street == 'preflop':
            bucket_id = self.get_preflop_bucket(hand)
        else:
            bucket_id = self.abstraction_manager.get_postflop_bucket(hand, board)

        # Only include essential info to prevent state explosion
        recent_actions = state.action_history[-3:] if len(state.action_history) > 3 else state.action_history
        action_str = ''.join([f"{a.value[0]}" for _, a, _ in recent_actions])

        pot_ratio = min(9, int(state.pot_size / 10))  # Discretize pot size

        info_state = f"B{bucket_id}S{state.street[0]}A{action_str}P{pot_ratio}"
        return info_state

    def get_legal_actions(self, state: GameState, player_id: int) -> List[Tuple[Action, float]]:
        """Simplified action space to prevent explosion."""
        actions = []
        to_call = max(state.bets.values()) - state.bets.get(player_id, 0)
        stack_size = self.initial_stack - state.total_bets.get(player_id, 0)

        # Always allow fold if facing a bet
        if to_call > 0:
            actions.append((Action.FOLD, 0))

        # Check/Call
        if to_call == 0:
            actions.append((Action.CHECK, 0))
        else:
            if to_call <= stack_size:
                actions.append((Action.CALL, to_call))

        # Simplified: Only 2 bet sizes instead of 5
        pot_bet = state.pot_size
        min_bet = max(to_call * 2, pot_bet * 0.5) if to_call > 0 else pot_bet * 0.5

        if pot_bet <= stack_size and pot_bet >= min_bet:
            if to_call == 0:
                actions.append((Action.BET, pot_bet))
            else:
                actions.append((Action.RAISE, pot_bet))

        # All-in as second option
        if stack_size > pot_bet and stack_size >= min_bet:
            if to_call == 0:
                actions.append((Action.BET, stack_size))
            else:
                actions.append((Action.RAISE, stack_size))

        return actions if actions else [(Action.FOLD, 0)]

    def get_strategy(self, info_state: str, legal_actions: List[Tuple[Action, float]]) -> Dict[Tuple[Action, float], float]:
        """Consistent action tuple handling."""
        strategy = {}

        # Calculate regret sum for normalization
        regret_sum = 0.0
        for action_tuple in legal_actions:
            regret_value = self.regret_sum[info_state].get(action_tuple, 0.0)
            regret_sum += max(0, regret_value)

        # Generate strategy based on regret matching
        if regret_sum > 0:
            for action_tuple in legal_actions:
                regret_value = self.regret_sum[info_state].get(action_tuple, 0.0)
                strategy[action_tuple] = max(0, regret_value) / regret_sum
        else:
            # Uniform strategy if no positive regrets
            uniform_prob = 1.0 / len(legal_actions)
            for action_tuple in legal_actions:
                strategy[action_tuple] = uniform_prob

        return strategy

    def train(self, num_iterations: int = 1000):
        """Better training loop with timeout protection."""
        print(f"\n=== Starting Optimized MCCFR Training for {num_iterations} iterations ===")
        start_time = time.time()

        for i in range(num_iterations):
            self.iteration_count += 1

            # Progress reporting
            if (i + 1) % 100 == 0:
                elapsed = time.time() - start_time
                print(f"Iteration {i+1}/{num_iterations} | "
                      f"Time: {elapsed:.1f}s | "
                      f"States: {len(self.regret_sum)}")

            try:
                # Reset depth counter for each iteration
                self.current_depth = 0
                self._run_mccfr_iteration()

                # Timeout protection
                if time.time() - start_time > 300:  # 5 minute timeout
                    print(f"Training timeout after {i+1} iterations")
                    break

            except RecursionError:
                print(f"Recursion limit hit at iteration {i+1}")
                continue
            except Exception as e:
                print(f"Error in iteration {i+1}: {e}")
                continue

            if i % 50 == 0:  # More frequent strategy updates
                self._update_average_strategy()

        total_time = time.time() - start_time
        print(f"Training complete! Total time: {total_time:.1f}s")
        print(f"Final policy contains {len(self.policy)} information states")

    def _run_mccfr_iteration(self):
        """Simplified iteration with better initial state."""
        initial_state = self._create_simple_game_state()

        # Run MCCFR for both players
        for player_id in [0, 1]:
            self.current_depth = 0  # Reset depth
            try:
                self._mccfr_recursive(initial_state, player_id, 1.0, 1.0)
            except Exception as e:
                # Gracefully handle errors and continue
                continue

    def _create_simple_game_state(self) -> GameState:
        """Simplified initial game state."""
        deck = DECK.copy()
        random.shuffle(deck)

        # Simple preflop state
        hands = {0: deck[:2], 1: deck[2:4]}

        state = GameState(
            players=[0, 1],
            current_player=0,
            pot_size=3.0,
            street='preflop',
            board=[],
            hands=hands,
            bets={0: 1.0, 1: 2.0},  # Small blind, big blind
            total_bets={0: 1.0, 1: 2.0},
            action_history=[]
        )

        return state

    def _mccfr_recursive(self, state: GameState, traversing_player: int,
                         pi_player: float, pi_opponent: float) -> float:
        """Protected recursive function with depth limits."""

        # Depth protection
        self.current_depth += 1
        if self.current_depth > self.max_recursion_depth:
            self.current_depth -= 1
            return self._get_utility(state, traversing_player)

        # Terminal state check
        if state.is_terminal:
            self.current_depth -= 1
            return self._get_utility(state, traversing_player)

        # Action sequence protection against infinite loops
        if len(state.action_history) > 10:  # Reduced from 20
            self.current_depth -= 1
            return self._get_utility(state, traversing_player)

        current_player = state.current_player
        info_state = self.get_info_state(state, current_player)
        legal_actions = self.get_legal_actions(state, current_player)

        if not legal_actions:
            self.current_depth -= 1
            return self._get_utility(state, traversing_player)

        strategy = self.get_strategy(info_state, legal_actions)

        # Calculate utilities for each action
        action_utilities = {}
        for action_tuple in legal_actions:
            try:
                new_state = self._apply_action(state, action_tuple[0], action_tuple[1])

                if current_player == traversing_player:
                    action_utilities[action_tuple] = self._mccfr_recursive(
                        new_state, traversing_player,
                        pi_player * strategy[action_tuple], pi_opponent
                    )
                else:
                    action_utilities[action_tuple] = self._mccfr_recursive(
                        new_state, traversing_player, pi_player,
                        pi_opponent * strategy[action_tuple]
                    )
            except Exception:
                # Fallback utility if recursion fails
                action_utilities[action_tuple] = 0.0

        # Node utility calculation
        node_utility = sum(strategy[action_tuple] * action_utilities[action_tuple]
                          for action_tuple in legal_actions)

        # Update regrets
        if current_player == traversing_player:
            for action_tuple in legal_actions:
                regret = action_utilities[action_tuple] - node_utility
                self.regret_sum[info_state][action_tuple] += pi_opponent * regret

        self.current_depth -= 1
        return node_utility

    def _apply_action(self, state: GameState, action: Action, amount: float) -> GameState:
        """Simplified action application."""
        # Create a copy of the state
        new_state = state.copy()

        player_id = state.current_player

        # Apply the action
        if action == Action.FOLD:
            new_state.is_terminal = True
            if player_id in new_state.players:
                new_state.players.remove(player_id)
        elif action == Action.CHECK:
            pass  # No bet change
        elif action == Action.CALL:
            call_amount = max(state.bets.values()) - state.bets.get(player_id, 0)
            new_state.bets[player_id] += call_amount
            new_state.total_bets[player_id] += call_amount
            new_state.pot_size += call_amount
        elif action in [Action.BET, Action.RAISE]:
            new_state.bets[player_id] += amount
            new_state.total_bets[player_id] += amount
            new_state.pot_size += amount

        # Add to history
        new_state.action_history.append((player_id, action, amount))

        # Simplified game advancement
        if not new_state.is_terminal:
            new_state.current_player = 1 - player_id

            # Simple completion check
            if len(new_state.action_history) >= 4:  # Both players acted twice
                new_state = self._try_advance_street(new_state)
            elif action in [Action.CHECK, Action.CALL] and len(new_state.action_history) >= 2:
                # Check if both players checked or one called
                last_two = new_state.action_history[-2:]
                if all(act in [Action.CHECK, Action.CALL] for _, act, _ in last_two):
                    new_state = self._try_advance_street(new_state)

        return new_state

    def _try_advance_street(self, state: GameState) -> GameState:
        """Simplified street advancement."""
        street_map = {
            'preflop': ('flop', 3),
            'flop': ('turn', 4),
            'turn': ('river', 5),
            'river': ('terminal', 5)
        }

        if state.street in street_map:
            next_street, target_board_size = street_map[state.street]

            if next_street == 'terminal':
                state.is_terminal = True
            else:
                # Deal cards to reach target board size
                cards_needed = target_board_size - len(state.board)
                if cards_needed > 0:
                    # Get available cards
                    used_cards = set()
                    for hand in state.hands.values():
                        used_cards.update(hand)
                    used_cards.update(state.board)

                    available_cards = [c for c in DECK if c not in used_cards]
                    random.shuffle(available_cards)

                    # Deal new cards
                    new_cards = available_cards[:cards_needed]
                    state.board.extend(new_cards)

                state.street = next_street
                state.bets = {p: 0.0 for p in state.players}
                state.current_player = 0

        return state

    def _get_utility(self, state: GameState, player_id: int) -> float:
        """
        Simplified utility calculation for ``player_id`` at the end of a hand.

        Rules, in order:
        - A player not in ``state.players`` (no longer in the hand) loses
          everything recorded for them in ``state.total_bets``.
        - If only one player remains, that player wins ``state.pot_size``
          minus their own bets.
        - Postflop (3+ board cards), every remaining player is scored with
          ``evaluate_hand_strength`` (a higher score wins). The pot is split
          evenly among all players tied for the best score; anyone below it
          loses their bets.
        - Otherwise (preflop, or if the showdown evaluation raises), a coin
          flip decides a +/- half-pot swing, minus the player's bets.

        Args:
            state: The game state to score.
            player_id: The player whose utility is returned.

        Returns:
            The player's net chip result.
        """
        invested = state.total_bets.get(player_id, 0)

        if player_id not in state.players:
            return -invested

        if len(state.players) == 1:
            return state.pot_size - invested

        # Simplified showdown using hand strength
        if len(state.board) >= 3:  # Any postflop situation
            try:
                strengths = {
                    pid: evaluate_hand_strength(state.hands[pid], state.board)
                    for pid in state.players
                }
                # Compare against the real best score rather than a 0 floor,
                # so the result does not depend on the strength scale.
                best_strength = max(strengths.values())
                player_is_best = not (strengths[player_id] < best_strength)
                # Count everyone sharing the best hand: a tie splits the pot
                # only among them, not among players who lost the showdown.
                num_winners = sum(1 for s in strengths.values() if s == best_strength)

                if not player_is_best:
                    return -invested
                if num_winners == 1:
                    return state.pot_size - invested
                return (state.pot_size / num_winners) - invested
            except Exception:
                # Deliberate best-effort: any evaluation failure falls through
                # to the random fallback below. A bare ``except`` would also
                # swallow KeyboardInterrupt/SystemExit during long training.
                pass

        # Fallback: random outcome weighted by pot investment
        random_outcome = random.choice([1, -1])
        return random_outcome * (state.pot_size / 2) - invested

    def _update_average_strategy(self):
        """
        Add the current strategy of every known info state into ``strategy_sum``.

        For each info state in ``regret_sum``, its ``(action, amount)`` keys
        are handed to ``get_strategy`` as the legal actions. Info states with
        no recorded actions are skipped. Every probability returned is added
        onto ``strategy_sum[info_state][action]``. ``strategy_updates`` is
        incremented exactly once per call.
        """
        for key, regrets in self.regret_sum.items():
            # Rebuild the (action, amount) pairs that get_strategy expects.
            legal_actions = [(act, amt) for act, amt in regrets.keys()]
            if not legal_actions:
                continue

            current = self.get_strategy(key, legal_actions)
            for act_key in current:
                self.strategy_sum[key][act_key] += current[act_key]

        self.strategy_updates += 1

    def get_final_policy(self) -> Dict[str, Dict]:
        """
        Normalise the accumulated strategy sums into a per-info-state policy.

        Each info state maps to ``{str(action): probability}``. When an info
        state's sums add up to a positive total, each probability is its
        share of that total. Otherwise every recorded action gets the same
        probability, ``1 / number_of_actions``. Action keys are stringified
        with ``str()``.
        """
        policy = {}

        for key, sums in self.strategy_sum.items():
            total = sum(sums.values())
            if total > 0:
                policy[key] = {str(act): weight / total for act, weight in sums.items()}
                continue

            # No mass accumulated yet: fall back to a uniform distribution.
            n_actions = len(sums)
            share = 1.0 / n_actions if n_actions else 1.0
            policy[key] = dict.fromkeys(map(str, sums), share)

        return policy

    def save_policy(self, filepath: str):
        """
        Pickle the averaged policy plus training metadata to ``filepath``.

        The stored dict holds ``'final_policy'`` (from ``get_final_policy()``),
        ``'iteration_count'`` and ``'regret_states'`` (the number of info
        states in ``regret_sum``). The file is opened in ``'wb'`` mode, so an
        existing file is overwritten. A confirmation line is printed afterwards.
        """
        payload = dict(
            final_policy=self.get_final_policy(),
            iteration_count=self.iteration_count,
            regret_states=len(self.regret_sum),
        )

        with open(filepath, 'wb') as out_file:
            pickle.dump(payload, out_file)

        print(f"Policy saved to {filepath}")

# Smoke test: mock abstraction -> short MCCFR training run -> saved policy.
if __name__ == "__main__":
    print("=== Testing Complete MCCFR System ===")

    try:
        print("Creating mock abstraction system...")
        manager = MockAbstractionManager(n_clusters_per_round=10)

        # Sanity-check bucketing on one flop spot and one turn spot.
        test_hands = [
            (['As', 'Kh'], ['Qc', '7d', '2s']),
            (['9h', '8c'], ['7d', '6s', '2h', 'Tc']),
        ]
        for hand, board in test_hands:
            bucket = manager.get_postflop_bucket(hand, board)
            print(f"Hand {hand} + Board {board} -> Bucket {bucket}")

        print("Abstraction system working! Now testing MCCFR...")

        # Keep the run short; this only checks the pipeline end to end.
        trainer = OptimizedMCCFRTrainer(manager)
        trainer.train(num_iterations=200)

        final_policy = trainer.get_final_policy()
        print("\nTraining Results:")
        print(f"- Information states learned: {len(final_policy)}")
        print(f"- Total iterations: {trainer.iteration_count}")
        print(f"- Strategy updates: {trainer.strategy_updates}")

        # Show the first info state's strategy as an example.
        if final_policy:
            example_info_state = next(iter(final_policy))
            print(f"\nExample strategy for info state '{example_info_state}':")
            for action, prob in final_policy[example_info_state].items():
                print(f"  {action}: {prob:.3f}")

        trainer.save_policy("optimized_cfr_policy.pkl")

        print("SUCCESS: Complete system is working!")

    except Exception as e:
        # Top-level boundary: report and show the traceback instead of dying.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

=== Testing Complete MCCFR System ===
Creating mock abstraction system...
Hand ['As', 'Kh'] + Board ['Qc', '7d', '2s'] -> Bucket 7
Hand ['9h', '8c'] + Board ['7d', '6s', '2h', 'Tc'] -> Bucket 5
Abstraction system working! Now testing MCCFR...
Initializing Optimized MCCFR Trainer...
Optimized MCCFR Trainer initialized successfully.

=== Starting Optimized MCCFR Training for 200 iterations ===
Iteration 100/200 | Time: 2.8s | States: 697
Iteration 200/200 | Time: 4.8s | States: 733
Training complete! Total time: 4.8s
Final policy contains 0 information states

Training Results:
- Information states learned: 698
- Total iterations: 200
- Strategy updates: 4

Example strategy for info state 'B79SpAP0':
  (<Action.FOLD: 'fold'>, 0): 0.986
  (<Action.CALL: 'call'>, 1.0): 0.000
  (<Action.RAISE: 'raise'>, 3.0): 0.014
  (<Action.RAISE: 'raise'>, 999.0): 0.000
Policy saved to optimized_cfr_policy.pkl
SUCCESS: Complete system is working!


In [None]:
ga

Choose an option:
1. Debug policy file
2. Run full bot evaluation
3. Quick test only
Enter choice (1-3): 2
🔍 Debugging policy file...
Policy data type: <class 'dict'>
Policy data keys: ['final_policy', 'iteration_count', 'regret_states']
Final policy has 1120 info states

Sample info state: B79SpAP0
Sample strategy: {"(<Action.FOLD: 'fold'>, 0)": 0.05875968242716244, "(<Action.CALL: 'call'>, 1.0)": 0.01439961759067168, "(<Action.RAISE: 'raise'>, 3.0)": 0.40851586448811095, "(<Action.RAISE: 'raise'>, 999.0)": 0.5183248354940548}

Policy looks good! Starting evaluation...
🚀 === TESTING YOUR REAL TRAINED MCCFR BOT ===
Loading abstraction manager...
Optimized PotentialAwareCalculator initialized.
Fixed AbstractionManager initialized.
Loading your trained MCCFR bot...
Loading REAL trained policy from: optimized_cfr_policy.pkl
Policy loaded. Agent knows 1120 info states.

🎯 === EVALUATION RESULTS ===

🔍 Quick test vs Random (50 games)...

=== Evaluating Bot vs RANDOM ===
Playing 50 games...


#END


In [None]:
import random
import time
from collections import defaultdict
import traceback

# Prefer the real phevaluator; when it is missing, install a random stand-in
# for evaluate_cards so the debug harness below can still run.
try:
    from phevaluator import evaluate_cards, card_to_string, string_to_card
except ImportError:
    PHEVALUATOR_AVAILABLE = False
    print("phevaluator not available, using mock functions")

    def evaluate_cards(*cards):
        """Stand-in for phevaluator: ignore *cards* and return a random rank in [1, 7462]."""
        return random.randint(1, 7462)
else:
    PHEVALUATOR_AVAILABLE = True

# Basic Game Definitions
SUITS = list('hdcs')
RANKS = list('23456789TJQKA')
# All 52 cards as two-character strings, grouped by rank: '2h', '2d', ..., 'As'.
DECK = [f"{rank}{suit}" for rank in RANKS for suit in SUITS]
# Rank character -> 0-based index ('2' -> 0, ..., 'A' -> 12).
RANK_MAP = dict(zip(RANKS, range(len(RANKS))))

class DebugPotentialAwareCalculator:
    """
    Debug version of PotentialAwareCalculator to identify issues.

    ``calculate_feature_vector`` returns ``[equity, flush_potential,
    straight_potential]`` for a two-card hand on a 3-, 4- or 5-card board. It
    prints a diagnostic trace of every step. Cards are two-character strings:
    a rank from ``RANKS`` followed by a suit from ``SUITS``, e.g. ``'As'``.
    Failures are reported and turned into ``[]`` / ``0.0`` rather than
    raised, so callers can keep looping.
    """

    def __init__(self):
        # NOTE(review): no method in this class reads or writes the cache; it
        # is kept only so code that accesses ``calculator.cache`` still works.
        self.cache = {}
        print("DebugPotentialAwareCalculator initialized.")

    def _calculate_current_equity(self, hand, board, remaining_deck, num_sims=100):
        """
        Estimate equity against one random opponent by Monte Carlo simulation.

        Each simulation samples the opponent's two hole cards, plus enough
        cards to complete the board to five, from ``remaining_deck``. Both
        seven-card hands are then scored with ``evaluate_cards``: a lower
        rank is a better hand, and a tie counts as half a win. Simulations
        whose evaluation raises are counted and reported, not silently dropped.

        Args:
            hand: The player's two hole cards.
            board: Community cards dealt so far (0 to 5 cards).
            remaining_deck: Unseen cards to sample from (a sequence).
            num_sims: Number of simulations to run.

        Returns:
            Equity in [0, 1]. Returns 0.0 when the inputs are invalid or no
            simulation could be evaluated.
        """
        try:
            if not hand or not isinstance(hand, list) or len(hand) != 2:
                print(f"Invalid hand: {hand}")
                return 0.0

            if not isinstance(board, list) or len(board) > 5:
                print(f"Invalid board: {board}")
                return 0.0

            # Opponent's two hole cards plus the cards that complete the board.
            # This is loop-invariant, so check it once instead of per simulation.
            cards_needed = 2 + (5 - len(board))
            if len(remaining_deck) < cards_needed:
                print(f"Not enough cards in remaining deck: {len(remaining_deck)}")
                return 0.0

            wins = 0
            valid_sims = 0
            failed_sims = 0
            first_error = None

            for _ in range(num_sims):
                samples = random.sample(remaining_deck, cards_needed)
                opp_hand = samples[:2]
                # Because len(board) <= 5, this is always exactly five cards.
                final_board = board + samples[2:]

                try:
                    player_rank = evaluate_cards(*(hand + final_board))
                    opp_rank = evaluate_cards(*(opp_hand + final_board))
                except Exception as e:
                    failed_sims += 1
                    if first_error is None:
                        first_error = e
                    continue

                if player_rank < opp_rank:  # Lower rank = better hand
                    wins += 1
                elif player_rank == opp_rank:
                    wins += 0.5

                valid_sims += 1

            if failed_sims:
                print(f"{failed_sims}/{num_sims} simulations failed; first error: {first_error!r}")

            if valid_sims == 0:
                print("No valid simulations completed")
                return 0.0

            equity = wins / valid_sims
            print(f"Equity calculation: {wins}/{valid_sims} = {equity:.3f}")
            return equity

        except Exception as e:
            print(f"Error in _calculate_current_equity: {e}")
            return 0.0

    def _calculate_flush_potential(self, hand, board, remaining_deck):
        """
        Probability of completing a four-card flush draw by the river.

        Only a suit with exactly four cards among hand + board counts as a
        draw. The outs are that suit's cards still present in
        ``remaining_deck``. On the flop (3 board cards) this is the chance
        that at least one of the next two cards hits. On the turn (4 board
        cards) it is the chance that the river hits. Any other situation
        returns 0.0.
        """
        try:
            suit_counts = {'h': 0, 'd': 0, 'c': 0, 's': 0}
            for card in hand + board:
                if len(card) >= 2 and card[1] in suit_counts:
                    suit_counts[card[1]] += 1

            deck_size = len(remaining_deck)

            # Check for 4-card flush draw
            for suit, count in suit_counts.items():
                if count != 4:
                    continue

                # Count the outs actually left in the deck; with a full
                # remaining deck this equals 13 - 4 = 9.
                flush_outs = sum(
                    1 for c in remaining_deck if len(c) >= 2 and c[1] == suit
                )

                if len(board) == 3:  # Flop: two cards to come
                    p_miss_turn = (deck_size - flush_outs) / deck_size
                    p_miss_river = (deck_size - 1 - flush_outs) / (deck_size - 1)
                    return 1 - (p_miss_turn * p_miss_river)
                if len(board) == 4:  # Turn: one card to come
                    return flush_outs / deck_size

            return 0.0

        except Exception as e:
            print(f"Error in _calculate_flush_potential: {e}")
            return 0.0

    def _calculate_straight_potential(self, hand, board):
        """Placeholder straight potential: a random value in [0, 0.3), ignoring the cards."""
        try:
            # For now, just return a random value between 0 and 0.3
            return random.random() * 0.3
        except Exception as e:
            print(f"Error in _calculate_straight_potential: {e}")
            return 0.0

    def calculate_feature_vector(self, hand, board):
        """
        Main method with extensive debugging.

        Args:
            hand: List of exactly two card strings.
            board: List of 3, 4 or 5 card strings.

        Returns:
            ``[equity, flush_potential, straight_potential]`` with every value
            clamped to [0, 1]. Returns ``[]`` if validation or any step fails.
        """
        print("\n=== Calculating feature vector ===")
        print(f"Hand: {hand}")
        print(f"Board: {board}")

        try:
            # Input validation
            if not isinstance(hand, list) or len(hand) != 2:
                print(f"ERROR: Invalid hand format: {hand}")
                return []

            if not isinstance(board, list) or len(board) not in (3, 4, 5):
                print(f"ERROR: Invalid board format: {board}")
                return []

            all_cards = hand + board

            # Check for valid card format
            for card in all_cards:
                if not isinstance(card, str) or len(card) != 2:
                    print(f"ERROR: Invalid card format: {card}")
                    return []
                if card[0] not in RANKS or card[1] not in SUITS:
                    print(f"ERROR: Invalid card: {card}")
                    return []

            # Check for duplicate cards
            if len(set(all_cards)) != len(all_cards):
                print(f"ERROR: Duplicate cards found: {all_cards}")
                return []

            print("Input validation passed")

            # Remaining deck: every card not already in the hand or on the board.
            known_cards = set(all_cards)
            remaining_deck = [card for card in DECK if card not in known_cards]
            print(f"Remaining deck size: {len(remaining_deck)}")

            # Calculate features
            print("Calculating equity...")
            equity = self._calculate_current_equity(hand, board, remaining_deck)

            print("Calculating flush potential...")
            flush_potential = self._calculate_flush_potential(hand, board, remaining_deck)

            print("Calculating straight potential...")
            straight_potential = self._calculate_straight_potential(hand, board)

            feature_vector = [equity, flush_potential, straight_potential]
            print(f"Raw feature vector: {feature_vector}")

            # Validate feature vector
            for i, val in enumerate(feature_vector):
                if not isinstance(val, (int, float)):
                    print(f"ERROR: Feature {i} is not numeric: {val} (type: {type(val)})")
                    return []
                if val < 0 or val > 1:
                    print(f"WARNING: Feature {i} out of range [0,1]: {val}")
                    # Clamp to valid range
                    feature_vector[i] = max(0.0, min(1.0, val))

            print(f"Final feature vector: {feature_vector}")
            return feature_vector

        except Exception as e:
            print(f"CRITICAL ERROR in calculate_feature_vector: {e}")
            traceback.print_exc()
            return []

class DebugAbstractionManager:
    """Thin debugging wrapper that owns a DebugPotentialAwareCalculator and exercises it."""

    def __init__(self):
        self.calculator = DebugPotentialAwareCalculator()
        print("DebugAbstractionManager initialized")

    def test_calculator(self):
        """
        Feed the calculator a fixed mix of well-formed and malformed inputs.

        Each case is printed with SUCCESS (non-empty feature vector), FAILED
        (empty result) or EXCEPTION. A pass-count summary follows.

        Returns:
            True if at least one case produced a non-empty feature vector.
        """
        well_formed = [
            (['As', 'Kh'], ['Qc', '7d', '2s']),  # Flop
            (['9h', '8c'], ['7d', '6s', '2h', 'Tc']),  # Turn
            (['Ah', '2h'], ['3h', '4h', '5c', 'Kd', '9s']),  # River
        ]
        malformed = [
            (['AA', 'KK'], ['QQ', '77', '22']),  # Invalid card format
            (['As', 'As'], ['Qc', '7d', '2s']),  # Duplicate cards
            (['As'], ['Qc', '7d', '2s']),  # Too few hole cards
            (['As', 'Kh'], ['Qc', '7d']),  # Too few board cards
        ]
        test_cases = well_formed + malformed

        rule = "=" * 50
        print("\n" + rule)
        print("TESTING CALCULATOR")
        print(rule)

        passed = 0
        for number, (hand, board) in enumerate(test_cases, start=1):
            print(f"\nTest {number}: Hand {hand}, Board {board}")
            try:
                vector = self.calculator.calculate_feature_vector(hand, board)
            except Exception as e:
                print(f"EXCEPTION: {e}")
                continue

            if vector:
                print(f"SUCCESS: {vector}")
                passed += 1
            else:
                print("FAILED: Empty or invalid result")

        print(f"\nSUMMARY: {passed}/{len(test_cases)} tests passed")
        return passed > 0

# Test the debug version
if __name__ == "__main__":
    print("=== DEBUGGING CALCULATOR ISSUES ===")

    manager = DebugAbstractionManager()

    # Test individual calculator calls
    success = manager.test_calculator()

    if success:
        print("\n✅ Calculator is working for some cases")
        print("The issue might be in the data generation loop")
    else:
        print("\n❌ Calculator is completely broken")
        print("Need to fix the calculate_feature_vector method")

    # Test a simple generation loop
    print("\n" + "="*50)
    print("TESTING DATA GENERATION LOOP")
    print("="*50)

    # Reuse the module-level DECK (built from SUITS and RANKS at the top of
    # this cell) rather than redefining identical constants here, which could
    # silently drift from the ones the calculator validates against.
    num_trials = 100
    valid_vectors = []
    for i in range(num_trials):
        try:
            deck = DECK.copy()
            random.shuffle(deck)

            hand = deck[:2]
            board = deck[2:5]  # Flop

            vector = manager.calculator.calculate_feature_vector(hand, board)

            if vector:
                valid_vectors.append(vector)
                if len(valid_vectors) <= 5:
                    print(f"Valid vector {len(valid_vectors)}: {vector}")
        except Exception as e:
            # Best-effort loop: keep going, but make the failure visible.
            print(f"Trial {i} raised {e!r}; skipping")
            continue

    print(f"\nGenerated {len(valid_vectors)}/{num_trials} valid vectors")

    if len(valid_vectors) > 0:
        print("✅ Data generation is working!")
    else:
        print("❌ Data generation is completely failing")

phevaluator not available, using mock functions
=== DEBUGGING CALCULATOR ISSUES ===
DebugPotentialAwareCalculator initialized.
DebugAbstractionManager initialized

TESTING CALCULATOR

Test 1: Hand ['As', 'Kh'], Board ['Qc', '7d', '2s']

=== Calculating feature vector ===
Hand: ['As', 'Kh']
Board: ['Qc', '7d', '2s']
Input validation passed
Remaining deck size: 47
Calculating equity...
Equity calculation: 46/100 = 0.460
Calculating flush potential...
Calculating straight potential...
Raw feature vector: [0.46, 0.0, 0.2758369088220143]
Final feature vector: [0.46, 0.0, 0.2758369088220143]
SUCCESS: [0.46, 0.0, 0.2758369088220143]

Test 2: Hand ['9h', '8c'], Board ['7d', '6s', '2h', 'Tc']

=== Calculating feature vector ===
Hand: ['9h', '8c']
Board: ['7d', '6s', '2h', 'Tc']
Input validation passed
Remaining deck size: 46
Calculating equity...
Equity calculation: 40/100 = 0.400
Calculating flush potential...
Calculating straight potential...
Raw feature vector: [0.4, 0.0, 0.2353888402563104]