In [13]:
# Lost Cities V1C
# Starting with V1A, adding:
# 1. a new legality rule: once a color's expedition is started and 6+ cards remain in the deck, drawing an X card of that color from the center is blocked
# 2. a new bonus for playing the next number card in sequence on an open expedition
# KEY:
# Adding final score tracking into replay buffer via episode transitions
# Tracking P1 and P2, with boosts for final reward (score) and for step rewards

In [1]:
import random
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict, deque
import random
import pandas as pd
from collections import defaultdict
import pprint

In [None]:
# GAME SETUP
# Suits (expedition colours) and the card faces within each suit; 'X' is the multiplier card.
COLORS = ['R', 'B', 'G']
NUMBERS = ['X', '2', '3', '4', '5', '6']

# Derived sizes of the reduced deck.
color_cnt = len(COLORS)
per_color = len(NUMBERS)
card_cnt = color_cnt * per_color

# Card string (e.g. 'R4') -> position in the one-hot hand vector, in colour-major order.
CARD_TO_IDX = {
    card: position
    for position, card in enumerate(c + n for c in COLORS for n in NUMBERS)
}

# Draw source -> index of its logit in the draw policy head.
draw_to_index = {source: slot for slot, source in enumerate(['deck', 'R', 'B', 'G'])}

# Points subtracted from every started expedition when scoring.
tgt_pts = 7

In [2]:
# BASICS & CLASSES

# Deck creation
def create_deck():
    """Build one card per (colour, face) pair and return the list shuffled in place."""
    cards = []
    for color in COLORS:
        for face in NUMBERS:
            cards.append(color + face)
    random.shuffle(cards)
    return cards

# Environment Class (unchanged)
class LostCitiesEnv:
    """Two-player Lost Cities environment on the reduced deck built by `create_deck`.

    Each player holds three cards. A turn is one play (to the player's own
    expedition of that colour, or to the shared centre pile of that colour)
    followed by one draw (from the deck or from a non-empty centre pile).
    The game ends as soon as the deck is empty.

    Attributes:
        deck: remaining face-down cards; cards are drawn with ``pop()``.
        hands: ``{'P1': [...], 'P2': [...]}`` lists of card strings such as ``'R4'``.
        expeditions: per player, a mapping colour -> list of played cards.
        center_piles: colour -> discard pile (top of pile is the last element).
        players: player ids in turn order.
        current_player_idx: index into ``players`` of the player to move.
        done: True once the deck has run out.
        last_discards: per player, ``(colour, card)`` of the most recent discard,
            or None if that player's last play went to an expedition.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Shuffle a fresh deck, deal three cards to each player and return the initial state."""
        self.deck = create_deck()
        self.hands = {f'P{i+1}': [self.deck.pop() for _ in range(3)] for i in range(2)}
        self.expeditions = {player: defaultdict(list) for player in self.hands}
        self.center_piles = {color: [] for color in COLORS}
        self.players = list(self.hands.keys())
        self.current_player_idx = 0
        self.done = False
        self.last_discards = {player: None for player in self.players}  # per-player discard tracking
        return self.get_state()

    def get_state(self):
        """Return a snapshot dict of the game (copies of the piles, hands sorted).

        Keys: 'hands', 'expeditions', 'center', 'deck_size', 'current_player'.
        """
        state = {
            'hands': {p: sorted(self.hands[p]) for p in self.hands},
            'expeditions': {p: {color: list(cards) for color, cards in self.expeditions[p].items()} for p in self.expeditions},
            'center': {color: list(self.center_piles[color]) for color in COLORS},
            'deck_size': len(self.deck),
            'current_player': self.players[self.current_player_idx]
        }
        return state

    def can_play_to_expedition(self, expedition, card):
        """Return True if `card` may be added to the given expedition pile.

        Anything may open an empty pile; a multiplier ('X') may not be added to a
        non-empty pile; a number card must be at least the highest number already
        played (or any number if only multipliers have been played).
        """
        if not expedition:
            return True
        if card[1] == 'X':
            return False
        existing_values = [int(c[1]) for c in expedition if c[1] != 'X']
        last_number = max(existing_values) if existing_values else None
        if last_number is None:
            return True
        return int(card[1]) >= last_number

    def get_legal_actions(self, player):
        """Return ``(actions, draws)`` available to `player`.

        `actions` is a list of ``('expedition', card)`` entries followed by
        ``('center', card)`` entries; `draws` is a list drawn from 'deck' and the
        colour letters. Each action appears at most once.
        """
        hand = self.hands[player]
        expeditions = self.expeditions[player]
        deck_remaining = len(self.deck)

        playable_to_expedition = []
        for color in COLORS:
            # Indexing (not .get) keeps the original side effect of registering the
            # colour in the defaultdict.
            existing_numbers = [int(c[1]) for c in expeditions[color] if c[1] != 'X']
            highest_played = max(existing_numbers) if existing_numbers else 0

            color_number_cards = [card for card in hand if card[0] == color and card[1] != 'X']
            valid_cards = [card for card in color_number_cards if int(card[1]) >= highest_played]

            # Lowest-card rule: while 3+ cards remain in the deck only the lowest
            # playable number of a colour is offered.
            if valid_cards and deck_remaining >= 3:
                min_val = min(int(c[1]) for c in valid_cards)
                valid_cards = [c for c in valid_cards if int(c[1]) == min_val]
            playable_to_expedition.extend(valid_cards)

            # Multipliers backed by 4+ number points of the same colour in hand are
            # listed right after that colour's number cards.
            multiplier_cards = [card for card in hand if card[0] == color and card[1] == 'X']
            if multiplier_cards and not existing_numbers:
                number_points = sum(int(c[1]) for c in color_number_cards)
                if number_points >= 4:
                    playable_to_expedition.extend(multiplier_cards)

        # Any multiplier whose expedition has no number yet is legal. Cards already
        # listed by the gated pass above are skipped so no action is offered twice
        # (duplicates used to double that action's weight when sampling by position).
        # NOTE(review): this pass makes the 4+ points gate affect ordering only;
        # confirm whether the gate was meant to restrict legality instead.
        for card in hand:
            if card[1] != 'X' or card in playable_to_expedition:
                continue
            if not any(c[1] != 'X' for c in expeditions[card[0]]):
                playable_to_expedition.append(card)

        # Discards: a number card that could still be played on its own started
        # expedition may not be thrown to the centre.
        playable_to_center = []
        for card in hand:
            pile_numbers = [int(c[1]) for c in expeditions[card[0]] if c[1] != 'X']
            if pile_numbers and card[1] != 'X' and int(card[1]) >= max(pile_numbers):
                continue
            playable_to_center.append(card)

        actions = [("expedition", card) for card in playable_to_expedition] + [("center", card) for card in playable_to_center]

        # Draw sources: the deck plus every non-empty centre pile ...
        draws = ['deck'] + [c for c in COLORS if self.center_piles[c]]

        # ... except the pile still topped by this player's own previous discard.
        player_last_discard = self.last_discards.get(player, None)
        if player_last_discard is not None:
            discard_color, discard_card = player_last_discard
            pile = self.center_piles[discard_color]
            if pile and pile[-1] == discard_card and discard_color in draws:
                draws.remove(discard_color)

        # While 6+ cards remain in the deck, a multiplier on top of a centre pile
        # cannot be drawn once the player's expedition of that colour is started.
        filtered_draws = []
        for d in draws:
            if d == 'deck' or deck_remaining < 6:
                filtered_draws.append(d)
                continue
            center_pile = self.center_piles[d]
            if not center_pile:
                continue
            top_is_multiplier = center_pile[-1][1] == 'X'
            expedition_started = bool(self.expeditions[player][d])
            if not (top_is_multiplier and expedition_started):
                filtered_draws.append(d)

        return actions, filtered_draws

    def step(self, action, draw_choice):
        """Apply one turn (play + draw) for the current player.

        Args:
            action: ``(action_type, card)`` with action_type 'expedition' or 'center'.
            draw_choice: 'deck' or the colour of a centre pile to draw from.

        Returns:
            ``(state, reward, done)``. The reward is 0 except on the move that
            empties the deck, where it is the mover's `compute_score`.

        Raises:
            AssertionError: the card is not in the player's hand, or an internal
                invariant is violated.
            ValueError: unknown action type or unusable draw_choice. Both are
                checked before any state is modified, so a rejected step leaves
                the environment untouched.
        """
        if self.done:
            return self.get_state(), 0, True

        player = self.players[self.current_player_idx]
        action_type, card = action

        # Validate everything first so a bad call cannot half-apply a turn.
        assert card in self.hands[player], f"Player {player} does not have card {card}!"
        if action_type not in ('expedition', 'center'):
            raise ValueError(f"Unknown action type: {action_type}")
        if draw_choice != 'deck':
            # The pile is drawable if it already holds a card or if this very
            # discard is about to land on it.
            drawable = draw_choice in self.center_piles and (
                bool(self.center_piles[draw_choice])
                or (action_type == 'center' and card[0] == draw_choice)
            )
            if not drawable:
                raise ValueError(f"Invalid draw_choice: {draw_choice}")

        # Play phase.
        self.hands[player].remove(card)
        if action_type == 'expedition':
            self.expeditions[player][card[0]].append(card)
            self.last_discards[player] = None
        else:
            self.center_piles[card[0]].append(card)
            self.last_discards[player] = (card[0], card)

        # Draw phase (exactly once per turn).
        if draw_choice == 'deck':
            # An empty deck cannot occur here in normal play: the game ends on the
            # move that empties it. If it did, the hand-size assert below fires.
            if self.deck:
                self.hands[player].append(self.deck.pop())
        else:
            self.hands[player].append(self.center_piles[draw_choice].pop())

        assert len(self.hands[player]) == 3, f"{player} has {len(self.hands[player])} cards in hand!"

        # End of game: deck exhausted.
        if not self.deck:
            self.done = True
            reward = self.compute_score(player)
            return self.get_state(), reward, True

        self._assert_invariants(player)

        # Hand the turn to the next player.
        self.current_player_idx = (self.current_player_idx + 1) % len(self.players)
        return self.get_state(), 0, False

    def _assert_invariants(self, player):
        """Sanity-check hand sizes, pile sizes and total card count after a move."""
        for p in self.players:
            assert len(self.hands[p]) == 3, f"{p} has {len(self.hands[p])} cards!"

        for color, pile in self.center_piles.items():
            assert len(pile) <= per_color, f"Center pile {color} has {len(pile)} cards!"

        for player_exped in self.expeditions.values():
            for color, pile in player_exped.items():
                assert len(pile) <= per_color, f"Expedition {color} has {len(pile)} cards!"

        total_cards = (
            sum(len(h) for h in self.hands.values())
            + sum(len(p) for p in self.center_piles.values())
            + sum(len(pile) for player_piles in self.expeditions.values() for pile in player_piles.values())
            + len(self.deck)
        )
        assert total_cards == card_cnt, f"Card count mismatch! Total cards: {total_cards}"

        started_expeditions = [color for color, pile in self.expeditions[player].items() if pile]
        assert len(started_expeditions) <= color_cnt, f"Too many expeditions started: {started_expeditions}"
        assert len(started_expeditions) == len(set(started_expeditions)), f"Duplicate expedition colors: {started_expeditions}"

    def compute_score(self, player):
        """Return the player's score: for each started expedition,
        ``(1 + number of multipliers) * (sum of number cards - tgt_pts)``."""
        total = 0
        for color, cards in self.expeditions[player].items():
            if cards:
                values = [int(card[1]) for card in cards if card[1] != 'X']
                multiplier = 1 + sum(1 for card in cards if card[1] == 'X')
                expedition_score = multiplier * (sum(values) - tgt_pts)
                total += expedition_score
        return total

# Enhanced Feature Extraction
def extract_features(state):
    """Encode a state dict (as returned by ``LostCitiesEnv.get_state``) as a flat vector.

    Layout, from the point of view of ``state['current_player']``:
      * one count per card in CARD_TO_IDX for the mover's hand,
      * per colour for the mover's expeditions: number total, has-multiplier flag, card count,
      * per colour for the centre piles: top number (0 for empty pile or 'X'), pile size,
      * per colour for the opponent's expeditions: same three values as above,
      * deck size divided by card_cnt.
    """
    mover = state['current_player']
    others = [p for p in state['hands'] if p != mover]
    opponent = others[0]

    def summarize_expeditions(owner):
        # [number total, multiplier present (0/1), pile length] for each colour.
        summary = []
        owned = state['expeditions'][owner]
        for color in COLORS:
            pile = owned.get(color, [])
            number_total = sum([int(c[1]) for c in pile if c[1] != 'X'])
            has_multiplier = int(any(c[1] == 'X' for c in pile))
            summary += [number_total, has_multiplier, len(pile)]
        return summary

    # Hand as per-card counts.
    hand_counts = np.zeros(len(CARD_TO_IDX))
    for held in state['hands'][mover]:
        hand_counts[CARD_TO_IDX[held]] += 1

    # Centre piles: visible top value and depth.
    discard_info = []
    for color in COLORS:
        pile = state['center'][color]
        top_face = pile[-1][1] if pile else '0'
        top_value = 0 if top_face == 'X' else int(top_face)
        discard_info += [top_value, len(pile)]

    deck_norm = state['deck_size'] / card_cnt

    return np.concatenate([
        hand_counts,
        summarize_expeditions(mover),
        discard_info,
        summarize_expeditions(opponent),
        [deck_norm],
    ])

def summarize_rule_firings():
    """Return the module-level `rule_counter` formatted as a pretty-printed plain dict."""
    counts_snapshot = dict(rule_counter)
    return pprint.pformat(counts_snapshot)

In [3]:
# Actor-Critic Network
class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network with two policy heads and one value head.

    Trunk: Linear(state_size, 96) -> ReLU -> Linear(96, 32) -> ReLU -> Dropout(0.15).
    Heads (all read the 32-unit trunk output):
      * policy_action_head: `action_size` logits for the card play,
      * policy_draw_head: `draw_size` logits for the draw source,
      * value_head: a single state-value estimate.
    """

    def __init__(self, state_size, action_size, draw_size):
        super().__init__()
        wide_units = 96
        trunk_units = 32

        # Shared trunk.
        self.fc1 = nn.Linear(state_size, wide_units)
        self.fc2 = nn.Linear(wide_units, trunk_units)
        self.dropout = nn.Dropout(p=0.15)

        # Output heads.
        self.policy_action_head = nn.Linear(trunk_units, action_size)
        self.policy_draw_head = nn.Linear(trunk_units, draw_size)
        self.value_head = nn.Linear(trunk_units, 1)

    def forward(self, x):
        """Return ``(card_logits, draw_logits, value)`` for input features `x`."""
        trunk = self.fc1(x).relu()
        trunk = self.fc2(trunk).relu()
        trunk = self.dropout(trunk)
        return (
            self.policy_action_head(trunk),
            self.policy_draw_head(trunk),
            self.value_head(trunk),
        )


In [4]:
def compute_step_reward(state, action, draw_choice, env):
    """Compute a heuristic shaping reward for one (play, draw) decision.

    The rules below read the expedition piles from `env` as they are *before*
    the card is placed (e.g. "expedition is empty" means the play would start
    it); the training loop calls this function before `env.step`.

    Args:
        state: state dict from `env.get_state()`; the keys read here are
            'hands', 'current_player', 'deck_size' and 'center'.
        action: ``(action_type, card)`` where action_type is 'expedition' or
            'center' and card is a two-character string such as 'R4' or 'RX'.
        draw_choice: 'deck' or a colour letter naming the centre pile drawn from.
        env: the environment; `env.expeditions` and `env.players` are read.

    Returns:
        The sum of all rule adjustments that fired, capped above at 2.0
        (there is no lower cap).

    Side effects:
        Increments the module-level `rule_counter` for every rule that fires,
        calls `random.random()` inside some rules, and occasionally prints a
        debug line.
    """
    # Setup
    step_reward = 0.0
    max_reward = 2.0  # upper cap applied at the very end
    player_hand = state['hands'][state['current_player']]
    is_expedition = action[0] == 'expedition'
    is_number_card = action[1][1] != 'X'
    played_color = action[1][0]
    played_value = int(action[1][1]) if is_number_card else None
    expedition_pile = env.expeditions[state['current_player']][played_color]
    existing_numbers = [int(c[1]) for c in expedition_pile if c[1] != 'X']
    deck_remaining = state['deck_size']
   
    # Bad Move 1: Playing a number card while holding a lower number of the same colour.
    # NOTE(review): "Penalty 4" further down fires on the same condition, so such a
    # play is penalised twice (-1.0 here and -0.8 there) - confirm this is intended.
    same_color_cards = [c for c in player_hand if c[0] == action[1][0] and c[1] != 'X']
    if same_color_cards and action[0] == 'expedition' and is_number_card:
        min_in_hand = min([int(c[1]) for c in same_color_cards])
        if int(action[1][1]) > min_in_hand:
            step_reward -= 1.0
            rule_counter["lower_val_avail"] += 1

    # Bad Move 2: Starting an expedition with fewer than tgt_pts number points of that colour in hand
    if is_expedition and len(env.expeditions[state['current_player']][action[1][0]]) == 0:
        color_sum = sum([int(c[1]) for c in player_hand if c[0] == action[1][0] and c[1] != 'X'])
        if color_sum < tgt_pts:
            step_reward -= 0.5
            rule_counter["too_few_pts"] += 1

    # Bad Move 3: Starting an expedition when opponent's played points plus own hand points
    # in that colour are below tgt_pts.
    # NOTE(review): the test adds the opponent's played points to the hand total; confirm
    # this is the intended "blocked" condition.
    opponent = [p for p in env.players if p != state['current_player']][0]
    opp_cards = env.expeditions[opponent].get(action[1][0], [])
    opp_sum = sum([int(c[1]) for c in opp_cards if c[1] != 'X'])
    hand_sum = sum([int(c[1]) for c in player_hand if c[0] == action[1][0] and c[1] != 'X'])
    if action[0] == 'expedition' and len(env.expeditions[state['current_player']][action[1][0]]) == 0:
        if opp_sum + hand_sum < tgt_pts:
            step_reward -= 0.8
            rule_counter["blocked_7"] += 1

    # Bad Move 4: Starting an expedition with <=3 cards left in the deck
    if is_expedition and len(env.expeditions[state['current_player']][action[1][0]]) == 0:
        if state['deck_size'] <= 3:
            step_reward -= 0.5
            rule_counter["exp_small_deck"] += 1

    # Bad Move 5: Discarding a number card that was playable on the player's started expedition
    # (an expedition holding only a multiplier counts as top value 0).
    expedition_pile = env.expeditions[state['current_player']][action[1][0]]
    if action[0] == 'center' and expedition_pile:
        top_val = max([int(c[1]) for c in expedition_pile if c[1] != 'X'], default=0)
        card_val = int(action[1][1]) if action[1][1] != 'X' else None
        if card_val is not None and card_val >= top_val:
            step_reward -= 1.5
            rule_counter["exp_was_live"] += 1

    # Good Move: Expedition play backed by a strong hand in that colour
    # (+0.3 at tgt_pts-1 points or more, and a further +1.0 at tgt_pts or more).
    if is_expedition:
        color_sum = sum([int(c[1]) for c in player_hand if c[0] == action[1][0] and c[1] != 'X'])
        if color_sum >= tgt_pts-1:
            step_reward += 0.3
            rule_counter["good_exp_1"] += 1
        if color_sum >= tgt_pts:
            step_reward += 1.0
            rule_counter["good_exp"] += 1

    # Penalty 3: Playing RX/BX/GX with no number cards of that colour in hand
    if is_expedition and action[1][1] == 'X':
        same_color_numbers = [c for c in player_hand if c[0] == action[1][0] and c[1] != 'X']
        if not same_color_numbers:
            step_reward -= 1.2
            rule_counter["bad_X"] += 1

    # Penalty 4: Playing e.g. R5 while holding R2 or R3 (same trigger as Bad Move 1 above)
    if is_expedition and action[1][1] != 'X':
        played_value = int(action[1][1])
        lower_cards = [int(c[1]) for c in player_hand if c[0] == action[1][0] and c[1] != 'X' and int(c[1]) < played_value]
        if lower_cards:
            step_reward -= 0.8
            rule_counter["bad_bigger_val"] += 1

    # Bonus: Playing the lowest number of a colour while holding 2+ numbers of it
    # worth at least tgt_pts-1 points in total
    if is_expedition and is_number_card:
        played_color = action[1][0]
        played_value = int(action[1][1])
        same_color_cards = [int(c[1]) for c in player_hand if c[0] == played_color and c[1] != 'X']
        total_points_in_hand = sum(same_color_cards)
        num_cards_in_hand = len(same_color_cards)
        if num_cards_in_hand >= 2 and total_points_in_hand >= tgt_pts-1 and played_value <= min(same_color_cards):
            step_reward += 1.0
            rule_counter["good_low_val"] += 1

    # Good Draw Move: Drawing a centre number card that brings 1-2 held numbers of that
    # colour to >=8 points. The pile top is read from `state`, i.e. before this turn's play.
    if draw_choice in COLORS:
        center_pile = state['center'][draw_choice]
        if center_pile:
            center_card = center_pile[-1]
            if center_card[1] != 'X':
                center_val = int(center_card[1])
                color_cards = [int(c[1]) for c in player_hand if c[0] == draw_choice and c[1] != 'X']
                total_points = sum(color_cards)
                num_cards = len(color_cards)
                if num_cards in [1, 2] and (total_points + center_val) >= 8:
                    step_reward += 1.5
                    rule_counter["draw_to_8"] += 1

    # Bad Move 7: Playing a number card onto an empty expedition while holding that colour's
    # multiplier, when the hand's number points in the colour reach tgt_pts and 5+ cards
    # remain in the deck (the multiplier should have gone down first).
    if is_expedition and is_number_card:
        played_color = action[1][0]
        expedition_pile = env.expeditions[state['current_player']][played_color]
        
        # Check if expedition is empty
        if not expedition_pile:
            # Get all cards in hand for this color
            color_cards_in_hand = [c for c in player_hand if c[0] == played_color]
            number_points = sum(int(c[1]) for c in color_cards_in_hand if c[1] != 'X')
            has_multiplier = any(c[1] == 'X' for c in color_cards_in_hand)
            
            if has_multiplier and number_points >= tgt_pts and state['deck_size'] >= 5:
                # Player should have played the multiplier first!
                step_reward -= 2.0  # penalty can be adjusted
                rule_counter["had_X"] += 1

    # (A separate "start with fewer than 7 points" penalty used to live here; it is
    # covered by the too_few_pts rule above.)
                
    # Bonus: Playing the immediate next card in sequence (no gaps)
    if is_expedition and is_number_card:
        played_color = action[1][0]
        played_value = int(action[1][1])
    
        expedition_pile = env.expeditions[state['current_player']][played_color]
        existing_numbers = [int(c[1]) for c in expedition_pile if c[1] != 'X']
    
        if existing_numbers:
            top_value = max(existing_numbers)
            if played_value == top_value + 1:
                step_reward += 0.3
                rule_counter["next_value"] += 1
                # Very rare debug print; note it still consumes one random number.
                if random.random() < 1e-8:
                    print(f"Reward playing next value card {action} on {played_color}{top_value}")

    # Bad Move: Discarding a card of a colour whose expedition is not started while the hand
    # holds tgt_pts or more number points in it and 5+ cards remain in the deck
    if action[0] == 'center' and deck_remaining>=5:
        color = action[1][0]
        if not env.expeditions[state['current_player']][color]:  # expedition not started
            color_values = [int(c[1]) for c in player_hand if c[0] == color and c[1] != 'X']
            color_sum = sum(color_values)
            if color_sum >= tgt_pts:
                step_reward -= 1.25  # Adjust weight if needed
                rule_counter["bad_center"] += 1
                if random.random() < 1e-4:
                    print(f"Bad center {action} holding {player_hand}")

    # Reward a smart discard: a number card lower than some number already on the
    # opponent's expedition of that colour
    if action[0] == 'center':
        color = action[1][0]
        card_val = int(action[1][1]) if action[1][1] != 'X' else None
        opp_pile = env.expeditions[opponent].get(color, [])
        opp_vals = [int(c[1]) for c in opp_pile if c[1] != 'X']
        if card_val is not None and any(v > card_val for v in opp_vals):
            step_reward += 0.5
            rule_counter["smart_opp_center"] += 1
            if random.random() < 1e-4:
                print(f"Smart center play {action} with opp exp {opp_pile}")

    # (A "close out expedition" bonus was prototyped here and disabled as flawed.)

    # Cap the total from above only; penalties are not clipped.
    if step_reward>max_reward:
        step_reward=max_reward
        
    return step_reward

In [25]:
# Modified Training Loop with Action + Draw Selection from Policy

# Output file for the per-episode reward log.
file_name='all_rewards.V1C.3.csv'

# V1B.1
num_episodes = 60_000
batch_size = 64          # transitions per minibatch
batch_cnt = 3            # minibatches per training round
train_every = 2          # run a training round every N episodes
step_booster = 5.0       # scale applied to the shaped step reward in the running mean
episode_booster = 0.0 # 0.5 for 1, 1.0 for 2, 0.0 for 3  (weight of the final score in stored rewards)
all_rewards = []
mean_rewards = []

# Epsilon-greedy exploration schedule (epsilon is decayed once per play).
epsilon = 0.35
epsilon_min = 0.030 # was 0.35 for .2.
epsilon_decay = 0.999999  # adjust this rate as needed - 0.99995 is too low

env = LostCitiesEnv()

# Length of the vector produced by extract_features, derived from the game
# constants so it cannot drift from the encoder:
#   card_cnt        hand counts
# + 3 * color_cnt   own expeditions (total, multiplier flag, count)
# + 2 * color_cnt   centre piles (top value, count)
# + 3 * color_cnt   opponent expeditions
# + 1               normalised deck size
# = 43 with the current 3 colours x 6 faces.
state_size = card_cnt + 8 * color_cnt + 1
num_card_actions = card_cnt
num_draw_choices = color_cnt+1  # the deck plus one centre pile per colour
model = ActorCritic(state_size=state_size, action_size=num_card_actions, draw_size=num_draw_choices)
optimizer = optim.Adam(model.parameters(), lr=0.001)
replay_buffer = deque(maxlen=10000) # was 12000, then 10000 .3.
rule_counter = defaultdict(int)  # rule name -> number of times compute_step_reward fired it

for episode in range(1, num_episodes + 1):
    state = env.reset()
    done = False
    # episode_reward = 0
    mean_reward = 0.0
    play_cnt=0
    plays_p1 = []
    plays_p2 = []

    while not done:
        play_cnt+=1
        features_np = extract_features(state)
        features = torch.tensor(features_np, dtype=torch.float32)

        # Forward pass through the model
        card_logits, draw_logits, value = model(features)

        # Get legal actions and legal draws
        # actions, draws = env.get_legal_actions(state['current_player'])
        actions, draws = env.get_legal_actions(state['current_player'])
        legal_action_indices = list(range(len(actions)))

        valid_draws = [d for d in draws if d == 'deck' or (d in env.center_piles and env.center_piles[d])]
        if not valid_draws:
            print("No valid draws—forcing episode end.")
            done = True
            break

        # legal_draw_indices = list(range(len(valid_draws)))  # typically ['deck', 'R', 'B', 'G']

        # Stop this game is no legal actions - though this is of course not true, but...
        # Print discard and draw for plays 1001 to 1020
        if 10001 <= play_cnt <= 10020:
            print(f"Legal actions: {actions}")
            print(f"Action indices: {[i for i in range(len(actions))]}")
            # Not valid at this point
            # print(f"Play {play_cnt}: Discard action = {chosen_action}, Draw choice = {chosen_draw}")
        if play_cnt>=10020:
            print(play_cnt, actions, draws, valid_draws)
            print(f"\n--- STUCK STATE at play {play_cnt} ---")
            print(f"Deck size: {state['deck_size']}")
            print(f"Player hand: {state['hands'][state['current_player']]}")
            print(f"Expeditions:")
            for color in env.expeditions[state['current_player']]:
                print(f"  {color}: {env.expeditions[state['current_player']][color]}")
            print(f"Center piles:")
            for color in env.center_piles:
                print(f"  {color}: {env.center_piles[color]}")
            print(f"Available actions: {actions}")
            print(f"Available draws: {draws}")
            print(f"----------------------\n")
            raise SystemExit(f"STOP")
        
        if not actions:
            print(f"No legal actions for player {state['current_player']}. Ending episode early.")
            done = True
            break

        # Sample card action with epsilon-greedy
        if random.random() < epsilon:
            # Random action
            card_idx = random.randint(0, len(actions) - 1)
        else:
            # Model-based action
            card_probs = torch.softmax(card_logits[:len(actions)], dim=0)
            card_dist = torch.distributions.Categorical(card_probs)
            card_idx = card_dist.sample().item()
        
        chosen_action = actions[card_idx]

        # Filter valid draws based on chosen_action (if it's a center discard)
        discard_color = None
        if chosen_action[0] == 'center':
            discard_color = chosen_action[1][0]
        
        filtered_draws = [
            d for d in valid_draws if d != discard_color
        ]
        if not filtered_draws:
            # Failsafe: fallback to deck
            filtered_draws = ['deck']

        valid_draws=filtered_draws

        # Sample draw choice (FIXED)
        if random.random() < epsilon:
            chosen_draw = random.choice(valid_draws)
        else:
            # Correct mapping: get logits only for valid draws
            draw_indices_in_logits = [draw_to_index[d] for d in valid_draws]
            draw_logits_filtered = draw_logits[draw_indices_in_logits]
            draw_probs = torch.softmax(draw_logits_filtered, dim=0)
            draw_dist = torch.distributions.Categorical(draw_probs)
            draw_idx = draw_dist.sample().item()
            chosen_draw = valid_draws[draw_idx]

        # Compute shaped intermediate reward
        step_reward = compute_step_reward(state, chosen_action, chosen_draw, env)

        # Map draw_choice to its index for policy update
        chosen_draw_idx = draw_to_index[chosen_draw]

        # Save the current player before doing env.step
        current_player=state['current_player']
        
        # Take action and draw based on policies
        next_state, reward, done = env.step(chosen_action, chosen_draw)

        # Combine shaped reward + final score (if any)
        # total_reward = reward + booster * step_reward
        total_reward = step_booster * step_reward

        # Store full experience (must include both action idx and draw idx!)
        # replay_buffer.append((features_np, card_idx, chosen_draw_idx, total_reward))
        # Now, do it all at end of game
        if current_player=='P1':
            plays_p1.append((features_np, card_idx, chosen_draw_idx, step_reward))
        else:
            plays_p2.append((features_np, card_idx, chosen_draw_idx, step_reward))
            
        # Advance state
        state = next_state
        mean_reward += total_reward

        # Annealing
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        ddebug = random.random()<0.0005
        if done:
            reward_p1 = env.compute_score('P1')
            reward_p2 = env.compute_score('P2')
            p1cnt=0
            for (features_np, card_idx, draw_idx, step_reward) in plays_p1:
                total_reward = episode_booster * reward_p1 + step_reward
                replay_buffer.append((features_np, card_idx, draw_idx, total_reward))
                if ddebug:
                    p1cnt+=1
                    print(f"P1 {p1cnt} - {episode_booster} * {reward_p1} + {step_reward} = {total_reward}")
            for (features_np, card_idx, draw_idx, step_reward) in plays_p2:
                total_reward = episode_booster * reward_p2 + step_reward
                replay_buffer.append((features_np, card_idx, draw_idx, total_reward))

    # Final mean reward is the average over plays - approximate over P1 and P2
    mean_reward=1.0*mean_reward/play_cnt
    
    if play_cnt>200:
        print(f"Plays: {play_cnt} in episode {episode}")

    # Train
    if episode % train_every == 0 and len(replay_buffer) >= batch_size:
        for _ in range(batch_cnt):
            minibatch = random.sample(replay_buffer, batch_size)
    
            # Unpack minibatch into separate lists
            states_b, actions_b, draws_b, rewards_b = zip(*minibatch)

            # Convert lists to tensors
            states_np = np.array(states_b)  # Convert list of arrays → single array
            states_t = torch.tensor(states_np, dtype=torch.float32)
    
            # Convert to tensors in batch
            # states_t = torch.tensor(states_b, dtype=torch.float32)  # Shape: [batch_size, state_size]
            actions_t = torch.tensor(actions_b, dtype=torch.long)   # Shape: [batch_size]
            draws_t = torch.tensor(draws_b, dtype=torch.long)       # Shape: [batch_size]
            rewards_t = torch.tensor(rewards_b, dtype=torch.float32)  # Shape: [batch_size]
    
            # Forward pass in batch
            card_logits_b, draw_logits_b, values_b = model(states_t)  # Each output shape: [batch_size, num_actions/draws]
    
            # Compute log probs for card actions
            card_probs_b = torch.softmax(card_logits_b, dim=1)
            log_card_probs_b = torch.log(card_probs_b + 1e-8)
            selected_log_card_probs = log_card_probs_b[range(batch_size), actions_t]
    
            # Compute log probs for draws
            draw_probs_b = torch.softmax(draw_logits_b, dim=1)
            log_draw_probs_b = torch.log(draw_probs_b + 1e-8)
            selected_log_draw_probs = log_draw_probs_b[range(batch_size), draws_t]
    
            # Compute advantage
            advantages = rewards_t - values_b.squeeze(1)  # Shape: [batch_size]
    
            # Losses
            critic_loss = advantages.pow(2).mean()
            actor_loss_card = -(selected_log_card_probs * advantages).mean()
            actor_loss_draw = -(selected_log_draw_probs * advantages).mean()
    
            total_loss = critic_loss + actor_loss_card + actor_loss_draw
    
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

    all_rewards.append(reward_p1)
    all_rewards.append(reward_p2)    
    mean_rewards.append(mean_reward)

    if episode % 1000 == 0:
        print("\n=== Step Rule Firing Counts ===")
        for rule, count in sorted(rule_counter.items(), key=lambda x: -x[1]):
            print(f"{rule:<30}: {count}")    
    
    if episode % 200 == 0:
        avg_score = np.mean(all_rewards[-2000:]) if len(all_rewards) >= 2000 else np.mean(all_rewards)
        print(f"Episode {episode}, Average Reward Last {min(len(all_rewards), 1000)}: {avg_score:.2f}, eps={epsilon:.4f}")
        pd.Series(all_rewards).to_csv(file_name, index=False, header=False)

Episode 200, Average Reward Last 400: -2.39, eps=0.3486
Episode 400, Average Reward Last 800: -2.08, eps=0.3472
Episode 600, Average Reward Last 1000: -2.11, eps=0.3459
Episode 800, Average Reward Last 1000: -2.19, eps=0.3445

=== Step Rule Firing Counts ===
good_exp_1                    : 5192
too_few_pts                   : 4310
good_exp                      : 3120
blocked_7                     : 2848
good_low_val                  : 2489
smart_opp_center              : 2319
next_value                    : 2016
draw_to_8                     : 1816
bad_center                    : 1231
lower_val_avail               : 793
bad_bigger_val                : 793
exp_small_deck                : 612
bad_X                         : 603
exp_was_live                  : 523
had_X                         : 33
Episode 1000, Average Reward Last 1000: -2.24, eps=0.3432
Episode 1200, Average Reward Last 1000: -2.13, eps=0.3419
Episode 1400, Average Reward Last 1000: -2.13, eps=0.3406
Episode 1600, Avera

In [80]:
# Persist only the network weights (the state_dict) of the current `model`
# to disk so they can later be restored with `model.load_state_dict(...)`.
torch.save(model.state_dict(), 'lc_model_v1C_1.pt')

In [93]:
def _draws_after_action(chosen_action, valid_draws):
    """Return the draws still allowed once `chosen_action` has been picked.

    A card discarded to a center pile may not be drawn back in the same turn,
    so the pile matching the discarded card's color is removed. If nothing is
    left, fall back to drawing from the deck.
    """
    discard_color = chosen_action[1][0] if chosen_action[0] == 'center' else None
    remaining = [d for d in valid_draws if d != discard_color]
    return remaining or ['deck']


def _prompt_human_move(state, player, actions, valid_draws):
    """Show the board to the human and prompt until a fully legal move is typed.

    Returns:
        (chosen_action, chosen_draw) where chosen_action is one of `actions`
        and chosen_draw is 'deck' or a color whose center pile can be drawn.
    """
    hand = state['hands'][player]
    exped = env.expeditions[player]
    center = env.center_piles
    opp_exped = env.expeditions['P1']

    exped_str = " | ".join(f"{color}:{exped.get(color, [])}" for color in COLORS)
    opp_exped_str = " | ".join(f"{color}:{opp_exped.get(color, [])}" for color in COLORS)
    center_str = " | ".join(f"{color}:{center[color]}" for color in COLORS)

    print(f"Your hand: {hand}")
    print(f"Your expeditions: {exped_str}")
    print(f"Center piles: {center_str}")
    print(f"Opponent expeditions: {opp_exped_str}")
    print(f"Valid draws: {valid_draws}")

    # Re-prompt locally instead of restarting the turn, so the caller's turn
    # counter is not advanced by a typo.
    while True:
        user_input = input("Enter move as (E/C) CARD DRAW (e.g., E B3 D): ").strip().upper()
        parts = user_input.split()

        # Explicit checks rather than `assert`: asserts are removed under `python -O`.
        if len(parts) != 3:
            print("Invalid input: expected three tokens. Try again.")
            continue
        move_type, card_str, draw_choice = parts

        if move_type not in ('E', 'C'):
            print("Invalid input: move must be E or C. Try again.")
            continue

        chosen_action = ('expedition', card_str) if move_type == 'E' else ('center', card_str)
        # Validate the (kind, card) pair, not just the card: a card may be
        # legal to discard but illegal to play on an expedition.
        if not any((kind, card) == chosen_action for kind, card in actions):
            print(f"Invalid input: {chosen_action} is not a legal action. Try again.")
            continue

        chosen_draw = 'deck' if draw_choice == 'D' else draw_choice
        allowed_draws = _draws_after_action(chosen_action, valid_draws)
        if chosen_draw not in allowed_draws:
            print(f"Invalid input: draw must be one of {allowed_draws}. Try again.")
            continue

        return chosen_action, chosen_draw


def play_bot(epsilon=0.0):
    """Play one interactive console game: the model plays P1, the human plays P2.

    Uses the notebook globals `env`, `model`, `extract_features`, `COLORS`
    and `draw_to_index`.

    Args:
        epsilon: Probability that the model picks its draw uniformly at random
            from the valid draws instead of sampling from its draw policy.
            The card choice is always sampled from the card policy.

    Returns:
        None. Each turn and the final scores are printed.
    """
    state = env.reset()
    done = False
    play_cnt = 0

    while not done:
        play_cnt += 1
        current_player = state['current_player']

        print(f"\n=== Turn {play_cnt} — {current_player} ===")

        actions, draws = env.get_legal_actions(current_player)
        # A center pile is only drawable when it currently holds a card.
        valid_draws = [d for d in draws if d == 'deck' or (d in env.center_piles and env.center_piles[d])]

        if not actions or not valid_draws:
            print("No legal actions or draws. Ending game.")
            break

        if current_player == 'P2':
            # ===== HUMAN TURN =====
            chosen_action, chosen_draw = _prompt_human_move(state, current_player, actions, valid_draws)

        else:
            # ===== MODEL TURN =====
            features_np = extract_features(state)
            features = torch.tensor(features_np, dtype=torch.float32)
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                card_logits, draw_logits, _ = model(features)

            # Card action selection.
            # NOTE(review): logit i is paired with legal action i (only the first
            # len(actions) logits are used) — confirm this matches how card_idx
            # is recorded during training.
            card_probs = torch.softmax(card_logits[:len(actions)], dim=0)
            card_dist = torch.distributions.Categorical(card_probs)
            card_idx = card_dist.sample().item()
            chosen_action = actions[card_idx]

            # Drop the pile the model is discarding to (cannot redraw that card).
            valid_draws = _draws_after_action(chosen_action, valid_draws)

            # Draw choice selection
            if random.random() < epsilon:
                chosen_draw = random.choice(valid_draws)
            else:
                # Restrict the draw policy to the logits of the valid draws.
                draw_indices_in_logits = [draw_to_index[d] for d in valid_draws]
                draw_logits_filtered = draw_logits[draw_indices_in_logits]
                draw_probs = torch.softmax(draw_logits_filtered, dim=0)
                draw_dist = torch.distributions.Categorical(draw_probs)
                draw_idx = draw_dist.sample().item()
                chosen_draw = valid_draws[draw_idx]

            opp_hand = state['hands'][current_player]
            print(f"P1 plays: {chosen_action} | Draws: {chosen_draw} -- holding {opp_hand}")

        # Step and advance
        next_state, reward, done = env.step(chosen_action, chosen_draw)
        state = next_state

    # Report the result whether the game finished normally or was cut short
    # because a player had no legal move.
    p1_score = env.compute_score('P1')
    p2_score = env.compute_score('P2')
    print("\n=== Game Over ===")
    print(f"Final Score — P1: {p1_score} | P2: {p2_score}")

In [87]:
# Restore saved weights into the already-constructed `model`.
# NOTE(review): this loads 'lc_model_v1B_2.pt', whereas the save cell in this
# notebook writes 'lc_model_v1C_1.pt' — confirm this is the intended checkpoint.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load('lc_model_v1B_2.pt'))
model.eval()  # Important: sets model to evaluation mode (no dropout etc.)

ActorCritic(
  (fc1): Linear(in_features=43, out_features=96, bias=True)
  (fc2): Linear(in_features=96, out_features=32, bias=True)
  (dropout): Dropout(p=0.15, inplace=False)
  (policy_action_head): Linear(in_features=32, out_features=18, bias=True)
  (policy_draw_head): Linear(in_features=32, out_features=4, bias=True)
  (value_head): Linear(in_features=32, out_features=1, bias=True)
)

In [95]:
# Start one interactive game against the model; epsilon=0 means the model's
# draw choice is always sampled from its policy (no uniform-random draws).
play_bot(0.00)


=== Turn 1 — P1 ===
P1 plays: ('center', 'B3') | Draws: deck -- holding ['B3', 'B4', 'B6']

=== Turn 2 — P2 ===
Your hand: ['R2', 'R3', 'R6']
Your expeditions: R:[] | B:[] | G:[]
Center piles: R:[] | B:['B3'] | G:[]
Opponent expeditions: R:[] | B:[] | G:[]
Valid draws: ['deck', 'B']


KeyboardInterrupt: Interrupted by user