In [1]:
import re
import itertools
from dataclasses import dataclass
from functools import lru_cache
from collections import Counter, defaultdict

import pandas as pd


# =========================
# Card utilities
# =========================

RANKS = "23456789TJQKA"  # rank symbols, lowest (deuce) to highest (ace)
RANK_TO_INT = dict(zip(RANKS, range(2, 2 + len(RANKS))))  # "2" -> 2 ... "A" -> 14
INT_TO_RANK = {value: symbol for symbol, value in RANK_TO_INT.items()}  # 14 -> "A"
SUITS = "cdhs"  # clubs, diamonds, hearts, spades

@dataclass(frozen=True)
class Card:
    """A single playing card; frozen, so instances are immutable and hashable."""
    rank: int   # 2..14 (numeric values from RANK_TO_INT; Ace is 14)
    suit: str   # c,d,h,s (one of SUITS)

def parse_card(token: str) -> Card:
    """
    Parse a two-character card token such as "Ac", "9s" or "Td" into a Card.

    The first character is the rank (one of RANKS) and the second is the suit
    (one of SUITS). Both are matched case-insensitively, and surrounding
    whitespace is ignored.

    Raises:
        ValueError: if the stripped token is not exactly two characters, or
            the rank or suit character is not recognized.
    """
    token = token.strip()
    if len(token) != 2:
        raise ValueError(f"Bad card token: {token}")
    # RANKS is upper-case and SUITS is lower-case; normalize both characters
    # so "ac", "AC" and "Ac" are all accepted consistently.
    r, s = token[0].upper(), token[1].lower()
    if r not in RANK_TO_INT or s not in SUITS:
        raise ValueError(f"Bad card token: {token}")
    return Card(RANK_TO_INT[r], s)

def parse_flop(flop_str: str):
    """
    Parse a whitespace-separated three-card flop string such as "Ac 9s 9d".

    Returns:
        list of three Card objects, in the order they appear in the string.

    Raises:
        ValueError: if there are not exactly three tokens, a token is not a
            valid card, or the same card appears more than once.
    """
    tokens = flop_str.split()
    if len(tokens) != 3:
        raise ValueError(f"Flop must have 3 cards like 'Ac 9s 9d', got: {flop_str}")
    cards = list(map(parse_card, tokens))
    distinct = {(card.rank, card.suit) for card in cards}
    if len(distinct) < 3:
        raise ValueError(f"Duplicate cards on flop: {flop_str}")
    return cards


# =========================
# Range parsing and combo expansion
# =========================

def _normalize_hand_token(tok: str) -> str:
    return tok.strip()

def _split_range_tokens(range_str: str):
    # comma-separated
    toks = [t.strip() for t in range_str.split(",") if t.strip()]
    return toks

def _parse_weighted_groups(range_str: str):
    """
    Split a weighted range string into (hand_token, weight) pairs.

    Format:
      - tokens are separated by commas
      - a token of the form "HAND:w" assigns weight w to HAND and to every
        token since the previous weighted token (inclusive)
      - tokens after the last weighted token default to weight 1.0
    Example:
      "AK,AQ,AJ:0.2,KQo:0.5,22"
      => AK,AQ,AJ all weight 0.2 ; KQo weight 0.5 ; 22 weight 1.0

    Returns:
        list of (hand_token, weight) tuples in input order.

    Raises:
        ValueError: if a weight is not a valid float, or a weight marker has
            no hand in front of it (e.g. ":0.5").
    """
    groups = []
    pending = []  # tokens still waiting for the next weight marker

    for tok in _split_range_tokens(range_str):
        if ":" not in tok:
            pending.append(tok)
            continue
        base, _, w_text = tok.partition(":")
        base = base.strip()
        if not base:
            raise ValueError(f"Missing hand before weight in token: {tok}")
        w = float(w_text.strip())
        # The weighted token closes the current group: flush everything pending at w.
        pending.append(base)
        groups.extend((t, w) for t in pending)
        pending = []

    # Trailing tokens with no weight marker default to full weight.
    groups.extend((t, 1.0) for t in pending)
    return groups


def _hand_token_to_combos(hand: str):
    """
    Expand a hand token into explicit 2-card combos (rank,suit) x 2.

    Tokens:
      - Pair: "AA", "TT" -> 6 combos (a trailing "o", e.g. "AAo", is accepted
        and treated as the plain pair)
      - Suited: "AJs" -> 4 combos
      - Offsuit: "QJo" -> 12 combos
      - Both suited+offsuit: "AK", "KQ", "AJ" (no s/o) -> 16 combos

    Returns:
        list of (Card, Card) tuples.

    Raises:
        ValueError: for a wrong token length, an unknown suffix, an unknown
            rank character, or a suited pair such as "AAs" (no such hand).
    """
    hand = hand.strip()
    if len(hand) not in (2, 3) or (len(hand) == 3 and hand[2] not in ("s", "o")):
        raise ValueError(f"Unrecognized hand token: {hand}")
    # Report unknown ranks as ValueError, consistent with parse_card,
    # instead of letting RANK_TO_INT raise a bare KeyError.
    if hand[0] not in RANK_TO_INT or hand[1] not in RANK_TO_INT:
        raise ValueError(f"Unrecognized hand token: {hand}")

    r1 = RANK_TO_INT[hand[0]]
    r2 = RANK_TO_INT[hand[1]]
    suited_flag = hand[2] if len(hand) == 3 else None

    if r1 == r2:
        if suited_flag == "s":
            raise ValueError(f"Unrecognized hand token: {hand}")
        # Choose 2 suits among 4 => 6 unordered combos. Expanding a pair as
        # "offsuit" would list every combo twice and double its weight.
        return [(Card(r1, s1), Card(r1, s2)) for s1, s2 in itertools.combinations(SUITS, 2)]

    if suited_flag == "s":
        return [(Card(r1, s), Card(r2, s)) for s in SUITS]
    if suited_flag == "o":
        return [(Card(r1, s1), Card(r2, s2)) for s1 in SUITS for s2 in SUITS if s1 != s2]
    # No suffix: every suit pairing, suited and offsuit (4 + 12 = 16 combos).
    return [(Card(r1, s1), Card(r2, s2)) for s1 in SUITS for s2 in SUITS]


@lru_cache(maxsize=5000)
def expand_range_to_weighted_combos(range_str: str):
    """
    Expand a weighted range string into explicit hole-card combos.

    Returns:
        list of ((card1, card2), weight) entries, one per combo. Every combo
        produced by a hand token inherits that token's group weight.

    Results are memoized on range_str and the same list object is returned on
    repeat calls, so callers must not mutate it.
    """
    return [
        ((first, second), weight)
        for hand, weight in _parse_weighted_groups(range_str)
        for first, second in _hand_token_to_combos(hand)
    ]


def filter_blocked_combos(weighted_combos, board_cards):
    """
    Drop combos that cannot exist given the board.

    A combo is removed when either hole card is also on the board, or when both
    hole cards are the same card. Cards are compared by (rank, suit).

    Returns:
        a new list of ((card1, card2), weight) entries, in input order.
    """
    def _key(card):
        return (card.rank, card.suit)

    dead = {_key(card) for card in board_cards}
    return [
        ((first, second), weight)
        for (first, second), weight in weighted_combos
        if _key(first) not in dead
        and _key(second) not in dead
        and _key(first) != _key(second)
    ]


# =========================
# Hand feature helpers (made / draw logic)
# =========================

def _is_flush(cards):
    suit_counts = Counter([c.suit for c in cards])
    return max(suit_counts.values()) >= 5

def _flush_suits(cards):
    suit_counts = Counter([c.suit for c in cards])
    return [s for s, ct in suit_counts.items() if ct >= 5]

def _unique_ranks(cards):
    return sorted(set([c.rank for c in cards]), reverse=True)

def _is_straight_from_ranks(ranks):
    # ranks: iterable of ints
    rs = sorted(set(ranks))
    # wheel support
    if 14 in rs:
        rs = rs + [1]
    # scan for 5 consecutive
    streak = 1
    for i in range(1, len(rs)):
        if rs[i] == rs[i-1] + 1:
            streak += 1
            if streak >= 5:
                return True
        elif rs[i] != rs[i-1]:
            streak = 1
    return False

def _is_straight(cards):
    """Return True if the ranks of `cards` form a straight (delegates to _is_straight_from_ranks)."""
    ranks = [card.rank for card in cards]
    return _is_straight_from_ranks(ranks)

def _is_straight_flush(cards):
    """Return True if the cards of some single suit (five or more) form a straight."""
    by_suit = defaultdict(list)
    for card in cards:
        by_suit[card.suit].append(card)
    return any(len(suited) >= 5 and _is_straight(suited) for suited in by_suit.values())


def board_distinct_ranks(flop):
    """Return the distinct ranks on the board, highest first."""
    unique = {card.rank for card in flop}
    return sorted(unique, reverse=True)

def made_hand_bucket(hole, flop):
    """
    Return one made-hand bucket label for (hole cards, flop).

    Buckets are mutually exclusive and are checked from strongest to weakest,
    so they sum to 1.0 when aggregated over a range. The one-pair labels
    ("overpair", "top_pair", "second_pair", ...) require the player's hole
    cards to make the pair. A pair that exists only on the board (e.g. AK on
    9-9-2) is not the player's pair and falls through to "ace_high" /
    "no_made_hand".

    Args:
        hole: (card1, card2) hole cards.
        flop: list of board cards.

    Returns:
        One of the labels in MADE_BUCKETS.
    """
    c1, c2 = hole
    board = flop
    cards = [c1, c2] + board

    rank_counts = Counter(c.rank for c in cards)
    counts_sorted = sorted(rank_counts.values(), reverse=True)
    pocket_pair = c1.rank == c2.rank

    # Big categories first
    if _is_straight_flush(cards):
        return "straight_flush"
    if 4 in counts_sorted:
        return "four_of_a_kind"
    if 3 in counts_sorted and (2 in counts_sorted or counts_sorted.count(3) >= 2):
        return "full_house"
    if _is_flush(cards):
        return "flush"
    if _is_straight(cards):
        return "straight"

    if 3 in counts_sorted:
        # "set" = pocket pair matching exactly one board card. Any other
        # three-of-a-kind (a hole card pairing a paired board, or trips on the
        # board itself) is labelled "trips".
        if pocket_pair and sum(1 for b in board if b.rank == c1.rank) == 1:
            return "set"
        return "trips"

    if counts_sorted.count(2) >= 2:
        return "two_pair"

    if 2 in counts_sorted:
        board_ranks = board_distinct_ranks(board)  # distinct ranks, high -> low

        if pocket_pair:
            # The board is unpaired here (a pocket pair plus a board pair is two pair).
            if c1.rank > board_ranks[0]:
                return "overpair"
            if c1.rank < board_ranks[-1]:
                return "underpair"
            return "low_pair"  # pocket pair between board ranks

        paired_rank = next(r for r, ct in rank_counts.items() if ct == 2)
        if paired_rank in (c1.rank, c2.rank):
            # A hole card pairs a board card: label by that rank's position
            # among the distinct board ranks.
            position = board_ranks.index(paired_rank)
            labels = ("top_pair", "second_pair", "third_pair")
            return labels[position] if position < len(labels) else "low_pair"
        # Otherwise the only pair is on the board; the player holds no pair of
        # their own, so classify by high card below.

    # High-card buckets
    if max(c1.rank, c2.rank) == 14:
        return "ace_high"
    return "no_made_hand"


def straight_draw_type(hole, flop):
    """
    Classify the straight draw formed by the two hole cards plus the flop.

    Returns "oesd", "gutshot", or None, based on how many distinct ranks
    would complete a five-rank straight:
      - two or more completing ranks => "oesd" (double gutshots included)
      - exactly one completing rank  => "gutshot"
    A hand that already contains a straight returns None. Aces count both
    high (14) and low (1, for the wheel).
    """
    present = {card.rank for card in [hole[0], hole[1]] + flop}
    if 14 in present:
        present.add(1)

    # Every 5-rank straight window, from A-high (14..10) down to the wheel (5..1).
    windows = [set(range(high, high - 5, -1)) for high in range(14, 4, -1)]

    if any(window <= present for window in windows):
        return None  # straight already made

    outs = set()
    for window in windows:
        gap = window - present
        if len(gap) == 1:
            outs |= gap

    if len(outs) >= 2:
        return "oesd"
    return "gutshot" if outs else None


def flush_draw_type(hole, flop):
    """
    Classify the flush draw formed by the two hole cards plus the flop.

    Returns:
      - "flush_draw_nuts": four to a flush, and the player holds the highest
        card of that suit not already on the board (the Ace, or the King when
        the Ace of that suit is on the flop, and so on)
      - "flush_draw": any other four-to-a-flush
      - "bdfd_2cards": both hole cards suited, with exactly one board card of
        that suit
      - "bdfd_1card": one hole card plus two board cards of the same suit
      - None: anything else, including a completed flush (five of a suit) and
        a monotone flop the hole cards do not contribute to
    """
    cards5 = [hole[0], hole[1]] + flop
    suit_counts = Counter(c.suit for c in cards5)
    max_suit, max_ct = max(suit_counts.items(), key=lambda x: x[1])

    # Flush draw: 4 to a flush (the flop holds at most 3, so a hole card is involved).
    if max_ct == 4:
        board_ranks_in_suit = {c.rank for c in flop if c.suit == max_suit}
        # The nut flush card is the highest rank of the suit not visible on the board.
        nut_rank = next(r for r in range(14, 1, -1) if r not in board_ranks_in_suit)
        if any(c.suit == max_suit and c.rank == nut_rank for c in hole):
            return "flush_draw_nuts"
        return "flush_draw"

    # Backdoor flush draw: exactly 3 to a suit.
    if max_ct == 3:
        hole_suits = [hole[0].suit, hole[1].suit]
        board_suits = [c.suit for c in flop]

        if hole_suits[0] == hole_suits[1] and board_suits.count(hole_suits[0]) == 1:
            return "bdfd_2cards"

        # A suit with two board cards must appear in board_suits, so scanning
        # the board's own suits is sufficient.
        if any(board_suits.count(s) == 2 and hole_suits.count(s) == 1 for s in set(board_suits)):
            return "bdfd_1card"

    return None


def draw_bucket(hole, flop):
    """
    Assign (hole, flop) to exactly one draw bucket, so buckets sum to 1.0 over a range.

    Priority (first match wins):
      combo_draw > flush_draw_nuts > flush_draw > oesd > gutshot
      > bdfd_2cards > bdfd_1card > no_draw
    "combo_draw" means a four-card flush draw together with an OESD or gutshot.
    """
    flush = flush_draw_type(hole, flop)
    straight = straight_draw_type(hole, flop)

    if flush in ("flush_draw_nuts", "flush_draw") and straight in ("oesd", "gutshot"):
        return "combo_draw"

    # flush_draw_type and straight_draw_type return disjoint label sets, so
    # testing membership in (flush, straight) checks whichever one applies.
    for label in ("flush_draw_nuts", "flush_draw", "oesd", "gutshot", "bdfd_2cards", "bdfd_1card"):
        if label in (flush, straight):
            return label
    return "no_draw"


# =========================
# Aggregation over a range
# =========================

# Made-hand labels produced by made_hand_bucket, strongest first.
MADE_BUCKETS = [
    # better than one pair
    "straight_flush", "four_of_a_kind", "full_house", "flush", "straight",
    "set", "trips", "two_pair",
    # one-pair categories
    "overpair", "top_pair", "underpair", "second_pair", "third_pair", "low_pair",
    # no pair
    "ace_high", "no_made_hand",
]

# Draw labels produced by draw_bucket, in priority order.
DRAW_BUCKETS = [
    "combo_draw", "flush_draw_nuts", "flush_draw",
    "oesd", "gutshot",
    "bdfd_2cards", "bdfd_1card",
    "no_draw",
]


def summarize_range(range_str: str, flop_str: str):
    """
    Compute the weighted made-hand and draw distributions of a range on a flop.

    The range is expanded into explicit combos, and combos that use a flop card
    are removed. Each remaining combo's weight is added to exactly one
    made-hand bucket and one draw bucket.

    Returns:
        (made_pct, draw_pct): dicts keyed by MADE_BUCKETS / DRAW_BUCKETS whose
        values are each bucket's share of the total weight. If the total
        surviving weight is not positive, every share is 0.0.
    """
    board = parse_flop(flop_str)
    live_combos = filter_blocked_combos(expand_range_to_weighted_combos(range_str), board)

    made_weight = dict.fromkeys(MADE_BUCKETS, 0.0)
    draw_weight = dict.fromkeys(DRAW_BUCKETS, 0.0)
    total = 0.0

    for hole, weight in live_combos:
        total += weight
        made_weight[made_hand_bucket(hole, board)] += weight
        draw_weight[draw_bucket(hole, board)] += weight

    if not total > 0:
        # Degenerate: no combos left after blockers (should be rare).
        return dict.fromkeys(MADE_BUCKETS, 0.0), dict.fromkeys(DRAW_BUCKETS, 0.0)

    made_pct = {bucket: w / total for bucket, w in made_weight.items()}
    draw_pct = {bucket: w / total for bucket, w in draw_weight.items()}
    return made_pct, draw_pct


# =========================
# Apply to your dataset
# =========================

def add_range_features(df: pd.DataFrame,
                       hero_col="Hero Range",
                       vil_col="Villain Range",
                       flop_col="Flop"):
    """
    Add per-row made-hand and draw distribution columns for hero and villain.

    For each row, summarize_range is evaluated on (hero range, flop) and on
    (villain range, flop). Each bucket share is written to a column named
    HeroMade_<bucket>, VillainMade_<bucket>, HeroDraw_<bucket> or
    VillainDraw_<bucket> (float64).

    The frame is modified in place and also returned.
    """
    hero_made_cols = [f"HeroMade_{k}" for k in MADE_BUCKETS]
    vil_made_cols  = [f"VillainMade_{k}" for k in MADE_BUCKETS]
    hero_draw_cols = [f"HeroDraw_{k}" for k in DRAW_BUCKETS]
    vil_draw_cols  = [f"VillainDraw_{k}" for k in DRAW_BUCKETS]

    # Collect results positionally and assign whole columns at the end. Writing
    # with df.at[label, ...] inside iterrows is wrong when df.index has
    # duplicate labels: each write would hit every row sharing that label.
    values = {c: [] for c in hero_made_cols + vil_made_cols + hero_draw_cols + vil_draw_cols}

    # Many rows share the same (range, flop) pair; compute each pair once per call.
    memo = {}

    def _summary(range_str, flop):
        key = (range_str, flop)
        if key not in memo:
            memo[key] = summarize_range(range_str, flop)
        return memo[key]

    for flop, hero_range, villain_range in zip(df[flop_col], df[hero_col], df[vil_col]):
        h_made, h_draw = _summary(hero_range, flop)
        v_made, v_draw = _summary(villain_range, flop)

        for k in MADE_BUCKETS:
            values[f"HeroMade_{k}"].append(h_made[k])
            values[f"VillainMade_{k}"].append(v_made[k])
        for k in DRAW_BUCKETS:
            values[f"HeroDraw_{k}"].append(h_draw[k])
            values[f"VillainDraw_{k}"].append(v_draw[k])

    # A Series built on df.index assigns row-for-row, even with duplicate labels,
    # and keeps float64 dtype for an empty frame.
    for col, col_values in values.items():
        df[col] = pd.Series(col_values, index=df.index, dtype="float64")

    return df


# =========================
# Run
# =========================

if __name__ == "__main__":
    df = pd.read_csv("PokerData.csv")

    # Fail fast when a column the feature builder reads is absent.
    required = ["Flop", "Hero Range", "Villain Range"]
    missing = [name for name in required if name not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    df2 = add_range_features(df)  # returns the same frame, with feature columns added

    # Optional: each bucket family should sum to ~1.0 per row (floating tolerance).
    sum_cols = []
    for prefix, buckets in (
        ("HeroMade", MADE_BUCKETS),
        ("VillainMade", MADE_BUCKETS),
        ("HeroDraw", DRAW_BUCKETS),
        ("VillainDraw", DRAW_BUCKETS),
    ):
        sum_col = f"{prefix}_sum"
        df2[sum_col] = df2[[f"{prefix}_{bucket}" for bucket in buckets]].sum(axis=1)
        sum_cols.append(sum_col)

    print("Sanity check (first 5 rows):")
    print(df2[sum_cols].head())

    # Save
    df2.to_csv("PokerData_with_range_features.csv", index=False)
    print("\nSaved: PokerData_with_range_features.csv")


Sanity check (first 5 rows):
   HeroMade_sum  VillainMade_sum  HeroDraw_sum  VillainDraw_sum
0           1.0              1.0           1.0              1.0
1           1.0              1.0           1.0              1.0
2           1.0              1.0           1.0              1.0
3           1.0              1.0           1.0              1.0
4           1.0              1.0           1.0              1.0

Saved: PokerData_with_range_features.csv
