# Question 1: Card Game - Highest Card Wins

This notebook implements a simple card game where two players draw 5 cards each and compare their highest cards.
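The Question 1 game reduces to a shuffle, a deal, and a comparison of the best rank in each hand. Below is a minimal sketch under stated assumptions. Cards are modelled as hypothetical `(rank, suit)` tuples. Suits are ignored when comparing, so two hands whose best cards share a rank tie (the rules above do not say how ties are broken). The helper names `compare_highest` and `play_highest_card_game` are made up for this illustration.

```python
import random

RANKS = ['2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A']
SUITS = ['♣', '♦', '♥', '♠']
RANK_ORDER = {rank: idx for idx, rank in enumerate(RANKS)}


def compare_highest(hand1, hand2):
    """Return 1 if hand1's highest rank beats hand2's, -1 if it loses, 0 on a tie.

    Cards are (rank, suit) tuples; suits are ignored (assumption: no suit tie-break).
    """
    best1 = max(RANK_ORDER[rank] for rank, _ in hand1)
    best2 = max(RANK_ORDER[rank] for rank, _ in hand2)
    return (best1 > best2) - (best1 < best2)


def play_highest_card_game(seed=None):
    """Shuffle a 52-card deck, deal 5 cards to each of two players, and compare."""
    rng = random.Random(seed)
    deck = [(rank, suit) for suit in SUITS for rank in RANKS]
    rng.shuffle(deck)
    hand1, hand2 = deck[:5], deck[5:10]
    return hand1, hand2, compare_highest(hand1, hand2)


hand1 = [('K', '♠'), ('3', '♦'), ('9', '♣'), ('2', '♥'), ('7', '♠')]
hand2 = [('Q', '♥'), ('J', '♦'), ('10', '♣'), ('4', '♠'), ('5', '♦')]
print(compare_highest(hand1, hand2))  # 1 (K beats Q)
```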

# Question 2: Trick-Taking Card Game

## Part 2a: Game Rules

The game starts by picking one player to play first. Players then play each trick as follows:

1. The selected player plays a random card from their hand. This is the **starter**.

2. Every following player plays a card from their hand. This card **must be the same suit as the starter**. You can choose any valid card in this step. **Do not worry about player strategy**.

3. If a player does not have any cards in their hand of the same suit as the starter, **they may play any card**.

4. Print out each player and the card that they played, using the message "Player {X} played {Card}".

5. The winner is the player who played the highest-ranked card that is also the **same suit as the starter**.

6. Print the name of the winner. The **winner of this trick is now the starting player** for the next trick.

7. Repeat until no player has any cards left in their hand; then the game ends.

**For this part of the problem, simulate an entire game.**

With four players and a standard 52-card deck, each player holds 13 cards, so there should be 13 rounds (tricks) in total.
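The follow-suit rule (steps 2 and 3) and the winner rule (step 5) can be isolated into two small helpers. This is a minimal sketch, not the notebook's implementation: it models cards as hypothetical `(rank, suit)` tuples instead of the `Card` class, and the names `legal_cards` and `trick_winner` are made up for this example.

```python
RANKS = ['2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A']
RANK_ORDER = {rank: idx for idx, rank in enumerate(RANKS)}


def legal_cards(hand, lead_suit):
    """Cards a following player may play: cards of the starter's suit if they
    hold any (rule 2), otherwise any card in hand (rule 3)."""
    same_suit = [card for card in hand if card[1] == lead_suit]
    return same_suit if same_suit else list(hand)


def trick_winner(plays):
    """plays: list of (player, (rank, suit)) in play order, starter first.
    Only cards of the starter's suit can win (rule 5)."""
    lead_suit = plays[0][1][1]
    on_suit = [(player, card) for player, card in plays if card[1] == lead_suit]
    return max(on_suit, key=lambda play: RANK_ORDER[play[1][0]])[0]


plays = [(3, ('K', '♠')), (4, ('4', '♠')), (1, ('A', '♦')), (2, ('7', '♠'))]
print(trick_winner(plays))                          # 3: the A♦ is off-suit
print(legal_cards([('2', '♥'), ('9', '♠')], '♠'))   # [('9', '♠')]
print(legal_cards([('2', '♥'), ('9', '♦')], '♠'))   # [('2', '♥'), ('9', '♦')]
```

Because the starter's own card always matches the lead suit, `on_suit` is never empty, so `trick_winner` always has a winner to return.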

In [70]:
# Given classes: Card, Hand, Player

RANKS = ['2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A']
SUITS = ['♣', '♦', '♥', '♠']
RANK_ORDER = {rank:idx for idx, rank in enumerate(RANKS)}

class Card:
    def __init__(self, rank, suit):
        self.rank = rank
        self.suit = suit


class Hand:
    def __init__(self):
        self.cards = []
    
    def add_card(self, card):
        self.cards.append(card)

    def remove_card(self, card):
        self.cards.remove(card)



class Player:
    def __init__(self, name):
        self.name = name
        self.hand = Hand()

    def draw(self, deck):
        card = deck.draw_card()
        self.hand.add_card(card)
        return card


In [71]:

import random 

class Deck:
    def __init__(self):
        self.cards = []
        self._build()

    def _build(self):
        # Reuse the module-level RANKS and SUITS so the deck and RANK_ORDER stay in sync
        for suit in SUITS:
            for rank in RANKS:
                self.cards.append(Card(rank, suit))

    def shuffle(self):
        random.shuffle(self.cards)

    def draw_card(self):
        if not self.cards:
            raise ValueError("Deck is empty")
        return self.cards.pop()



In [72]:

# Play a single trick
def play_trick(players, starter_idx):
    num_players = len(players)
    played = []

    # Determine play order
    order = []
    for i in range(num_players):
        order.append((starter_idx + i) % num_players)

    # Starter plays any card
    starter_player = players[starter_idx]
    starter_card = random.choice(starter_player.hand.cards)
    starter_player.hand.remove_card(starter_card)
    starter_suit = starter_card.suit
    print(f"Player {starter_player.name} played {starter_card.rank}{starter_card.suit}")
    played.append((starter_player, starter_card))

    # Remaining players
    for k in range(1, num_players):
        idx = order[k]
        player = players[idx]

        # Find cards of the same suit
        same_suit_cards = []
        for c in player.hand.cards:
            if c.suit == starter_suit:
                same_suit_cards.append(c)

        # Pick a card to play
        if len(same_suit_cards) > 0:
            card_to_play = random.choice(same_suit_cards)
        else:
            card_to_play = random.choice(player.hand.cards)

        player.hand.remove_card(card_to_play)
        print(f"Player {player.name} played {card_to_play.rank}{card_to_play.suit}")
        played.append((player, card_to_play))

    # Determine winner
    valid_plays = []
    for p, c in played:
        if c.suit == starter_suit:
            valid_plays.append((p, c))

    winner = None
    max_value = -1
    for p, c in valid_plays:
        if RANK_ORDER[c.rank] > max_value:
            max_value = RANK_ORDER[c.rank]
            winner = p

    print(f"Winner of this trick: {winner.name}")
    return players.index(winner)


In [73]:
# Play the full game
def play_game(num_players=4):
    if num_players < 2:
        raise ValueError("Number of players must be at least 2")
    if num_players > 52:
        raise ValueError("Too many players for a 52-card deck")

    # Create deck and shuffle
    deck = Deck()
    deck.shuffle()

    # Create players
    players = []
    for i in range(num_players):
        player_name = f"Player {i + 1}"
        player = Player(player_name)
        players.append(player)

    # Deal cards evenly
    cards_per_player = len(deck.cards) // num_players
    for i in range(cards_per_player):
        for j in range(num_players):
            players[j].draw(deck)

    # Random starting player
    starter_idx = random.randint(0, num_players - 1)

    # Play each trick
    for trick_number in range(cards_per_player):
        print(f"\n--- Trick {trick_number + 1} ---")
        starter_idx = play_trick(players, starter_idx)



In [74]:
play_game(4)


--- Trick 1 ---
Player Player 3 played K♠
Player Player 4 played 4♠
Player Player 1 played 8♠
Player Player 2 played 7♠
Winner of this trick: Player 3

--- Trick 2 ---
Player Player 3 played 10♦
Player Player 4 played 2♦
Player Player 1 played A♦
Player Player 2 played 7♦
Winner of this trick: Player 1

--- Trick 3 ---
Player Player 1 played J♣
Player Player 2 played 5♣
Player Player 3 played 4♣
Player Player 4 played 2♣
Winner of this trick: Player 1

--- Trick 4 ---
Player Player 1 played 3♣
Player Player 2 played A♣
Player Player 3 played 8♣
Player Player 4 played Q♣
Winner of this trick: Player 2

--- Trick 5 ---
Player Player 2 played 2♠
Player Player 3 played Q♠
Player Player 4 played 6♠
Player Player 1 played 9♣
Winner of this trick: Player 3

--- Trick 6 ---
Player Player 3 played 4♦
Player Player 4 played J♦
Player Player 1 played 9♦
Player Player 2 played 3♦
Winner of this trick: Player 4

--- Trick 7 ---
Player Player 4 played J♥
Player Player 1 played 6♥
Player Player 2 pl

## Part 2b: Scoring

In each round, the winner of the trick collects points for every card played on the table in that round:

- **5** is worth **5 fish points**
- **10** is worth **10 fish points**
- **K** is worth **10 fish points**
- **All other cards** are worth **0 points**

### Modify your code to support the following:

1. At the end of each round, print the number of points that the winner took.
2. At the end of the game, print each player's total number of points.
3. Print the name of every player tied for the highest number of fish points.
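The scoring rule is a lookup-and-sum over the ranks on the table, and the final report needs every player tied at the maximum. A minimal sketch follows; the helper names `trick_points` and `top_scorers` are hypothetical, and cards are reduced to their rank strings for brevity. A useful sanity check follows from the table above: a deck has four 5s, four 10s, and four Ks, so the points awarded over a complete game always total 4 × (5 + 10 + 10) = 100.

```python
FISH_POINTS = {'5': 5, '10': 10, 'K': 10}
RANKS = ['2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A']


def trick_points(ranks):
    """Fish points for the ranks played in one trick; unlisted ranks score 0."""
    return sum(FISH_POINTS.get(rank, 0) for rank in ranks)


def top_scorers(totals):
    """Every player whose total equals the maximum (handles ties)."""
    best = max(totals.values())
    return [name for name, points in totals.items() if points == best]


print(trick_points(['K', '5', '9', '6']))   # 15
print(trick_points(RANKS * 4))              # 100: the whole deck
print(top_scorers({'Player 1': 15, 'Player 2': 30, 'Player 3': 30}))  # ['Player 2', 'Player 3']
```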

In [75]:
# Fish points mapping
FISH_POINTS = {'5': 5, '10': 10, 'K': 10}

def calculate_fish_points(cards):
    """Return the total fish points carried by the given cards."""
    points = 0
    for card in cards:
        if card.rank in FISH_POINTS:
            points += FISH_POINTS[card.rank]
    return points

In [76]:
# Play a single trick
def play_trick(players, starter_idx, points_dict):
    num_players = len(players)
    played = []

    # Determine play order
    order = []
    for i in range(num_players):
        order.append((starter_idx + i) % num_players)

    # Starter plays
    starter_player = players[starter_idx]
    starter_card = random.choice(starter_player.hand.cards)
    starter_player.hand.remove_card(starter_card)
    starter_suit = starter_card.suit
    print(f"Player {starter_player.name} played {starter_card.rank}{starter_card.suit}")
    played.append((starter_player, starter_card))

    # Remaining players
    for k in range(1, num_players):
        idx = order[k]
        player = players[idx]

        # Find cards of the same suit
        same_suit_cards = []
        for c in player.hand.cards:
            if c.suit == starter_suit:
                same_suit_cards.append(c)

        if len(same_suit_cards) > 0:
            card_to_play = random.choice(same_suit_cards)
        else:
            card_to_play = random.choice(player.hand.cards)

        player.hand.remove_card(card_to_play)
        print(f"Player {player.name} played {card_to_play.rank}{card_to_play.suit}")
        played.append((player, card_to_play))

    # Determine winner
    valid_plays = []
    for p, c in played:
        if c.suit == starter_suit:
            valid_plays.append((p, c))

    winner = None
    max_value = -1
    for p, c in valid_plays:
        if RANK_ORDER[c.rank] > max_value:
            max_value = RANK_ORDER[c.rank]
            winner = p

    # Calculate fish points using modular function
    trick_cards = [c for _, c in played]
    points_this_trick = calculate_fish_points(trick_cards)
    points_dict[winner.name] += points_this_trick

    print(f"Winner of this trick: {winner.name} (+{points_this_trick} points)")
    return players.index(winner)

In [77]:
# Play the full game
def play_game(num_players=4):
    if num_players < 2:
        raise ValueError("Number of players must be at least 2")
    if num_players > 52:
        raise ValueError("Too many players for a 52-card deck")

    # Create deck and shuffle
    deck = Deck()
    deck.shuffle()

    # Create players and external points dictionary
    players = []
    points_dict = {}
    for i in range(num_players):
        player_name = f"Player {i + 1}"
        player = Player(player_name)
        players.append(player)
        points_dict[player_name] = 0

    # Deal cards evenly
    cards_per_player = len(deck.cards) // num_players
    for i in range(cards_per_player):
        for j in range(num_players):
            players[j].draw(deck)

    # Random starting player
    starter_idx = random.randint(0, num_players - 1)

    # Play each trick
    for trick_number in range(cards_per_player):
        print(f"\n--- Trick {trick_number + 1} ---")
        starter_idx = play_trick(players, starter_idx, points_dict)

    # Print final points and winner(s)
    print("\nGame over! Final points:")
    max_points = -1
    winners = []
    for player_name in points_dict:
        print(f"{player_name}: {points_dict[player_name]} points")
        if points_dict[player_name] > max_points:
            max_points = points_dict[player_name]
            winners = [player_name]
        elif points_dict[player_name] == max_points:
            winners.append(player_name)

    print(f"\nPlayer(s) with highest fish points: {', '.join(winners)}")

In [78]:
play_game(num_players=4)



--- Trick 1 ---
Player Player 1 played 10♠
Player Player 2 played Q♠
Player Player 3 played A♠
Player Player 4 played 4♠
Winner of this trick: Player 3 (+10 points)

--- Trick 2 ---
Player Player 3 played K♣
Player Player 4 played 5♣
Player Player 1 played 9♣
Player Player 2 played 6♣
Winner of this trick: Player 3 (+15 points)

--- Trick 3 ---
Player Player 3 played J♦
Player Player 4 played K♦
Player Player 1 played 5♦
Player Player 2 played 9♦
Winner of this trick: Player 4 (+15 points)

--- Trick 4 ---
Player Player 4 played 2♥
Player Player 1 played 9♥
Player Player 2 played 6♥
Player Player 3 played 3♥
Winner of this trick: Player 1 (+0 points)

--- Trick 5 ---
Player Player 1 played 5♠
Player Player 2 played 2♠
Player Player 3 played K♠
Player Player 4 played 3♣
Winner of this trick: Player 3 (+15 points)

--- Trick 6 ---
Player Player 3 played A♥
Player Player 4 played J♥
Player Player 1 played 10♥
Player Player 2 played 5♥
Winner of this trick: Player 3 (+15 points)

--- Tric

# Question 3: Poker Hand Validation

Given a set of six poker hand rules (flush, straight, full house, 4-of-a-kind, straight flush, royal flush), determine whether a given hand is valid by checking if it satisfies at least one of these rules.

## Poker Hand Definitions:

- **Flush**: All 5 cards have the same suit
- **Straight**: 5 cards in sequential rank order (e.g., 5-6-7-8-9)
- **Full House**: 3 cards of one rank and 2 cards of another rank (e.g., 3-3-3-K-K)
- **4-of-a-Kind**: 4 cards of the same rank (e.g., 9-9-9-9-3)
- **Straight Flush**: 5 cards in sequential rank order, all of the same suit
- **Royal Flush**: 10-J-Q-K-A all of the same suit

**Note**: Ace can be low (A-2-3-4-5) or high (10-J-Q-K-A) in straights.
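Most of these definitions depend only on how many times each rank appears, plus a consecutive-rank check that must also accept the ace-low straight. A minimal sketch, assuming a hand is given as a list of five rank strings; the helper names `rank_counts`, `is_straight_ranks`, `is_full_house_ranks`, and `is_four_of_a_kind_ranks` are hypothetical.

```python
from collections import Counter

RANKS = ['2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A']
RANK_ORDER = {rank: idx for idx, rank in enumerate(RANKS)}


def rank_counts(ranks):
    """Multiplicity of each rank, largest first, e.g. 3-3-3-K-K -> [3, 2]."""
    return sorted(Counter(ranks).values(), reverse=True)


def is_straight_ranks(ranks):
    idx = sorted(RANK_ORDER[r] for r in ranks)
    if len(idx) != 5 or len(set(idx)) != 5:
        return False                  # a straight needs five distinct ranks
    if idx == [0, 1, 2, 3, 12]:       # A-2-3-4-5: the ace plays low
        return True
    return idx[-1] - idx[0] == 4      # five distinct ranks spanning exactly 5 values


def is_full_house_ranks(ranks):
    return rank_counts(ranks) == [3, 2]


def is_four_of_a_kind_ranks(ranks):
    return rank_counts(ranks)[0] == 4


print(is_straight_ranks(['A', '2', '3', '4', '5']))      # True
print(is_straight_ranks(['Q', 'K', 'A', '2', '3']))      # False: no wrap-around
print(is_full_house_ranks(['3', '3', '3', 'K', 'K']))    # True
```

Wrap-around runs such as Q-K-A-2-3 are rejected because only the exact A-2-3-4-5 pattern lets the ace play low.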

In [79]:
# Given classes: Card, Hand, Player

RANKS = ['2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A']
SUITS = ['♣', '♦', '♥', '♠']
RANK_ORDER = {rank:idx for idx, rank in enumerate(RANKS)}

class Card:
    def __init__(self, rank, suit):
        self.rank = rank
        self.suit = suit


class Hand:
    def __init__(self):
        self.cards = []
    
    def add_card(self, card):
        self.cards.append(card)

    def remove_card(self, card):
        self.cards.remove(card)



class Player:
    def __init__(self, name):
        self.name = name
        self.hand = Hand()

    def draw(self, deck):
        card = deck.draw_card()
        self.hand.add_card(card)
        return card

In [80]:
def is_flush(hand):
    suits = [card.suit for card in hand.cards]
    return len(set(suits)) == 1

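def is_straight(hand):
    indices = sorted([RANK_ORDER[card.rank] for card in hand.cards])
    # Ace-low straight (A-2-3-4-5): the ace sits at index 12 but plays below the 2
    if indices == [RANK_ORDER[r] for r in ['2', '3', '4', '5', 'A']]:
        return True
    for i in range(len(indices) - 1):
        if indices[i] + 1 != indices[i+1]:
            return False
    return True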

def is_four_of_a_kind(hand):
    ranks = [card.rank for card in hand.cards]
    for rank in set(ranks):
        if ranks.count(rank) == 4:
            return True
    return False

def is_full_house(hand):
    ranks = [card.rank for card in hand.cards]
    unique_ranks = set(ranks)
    return len(unique_ranks) == 2 and (ranks.count(list(unique_ranks)[0]) in [2,3])

def is_straight_flush(hand):
    return is_flush(hand) and is_straight(hand)

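def is_royal_flush(hand):
    if not is_flush(hand):
        return False
    royal_ranks = ['10','J','Q','K','A']
    # Sort by poker rank, not alphabetically (as strings, '10' < 'A' < 'J' < 'K' < 'Q')
    hand_ranks = sorted((card.rank for card in hand.cards), key=RANK_ORDER.get)
    return hand_ranks == royal_ranks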


In [81]:
def valid_hand(hand):
    """
    Returns True if the hand satisfies at least one poker rule.
    """
    if (is_flush(hand) or
        is_straight(hand) or
        is_full_house(hand) or
        is_four_of_a_kind(hand) or
        is_straight_flush(hand) or
        is_royal_flush(hand)):
        return True
    return False

In [82]:
# Create a hand (Royal Flush)
hand = Hand()
hand.add_card(Card('10','♠'))
hand.add_card(Card('J','♠'))
hand.add_card(Card('Q','♠'))
hand.add_card(Card('K','♠'))
hand.add_card(Card('A','♠'))

if valid_hand(hand):
    print("Hand is valid!")
else:
    print("Hand is invalid.")


Hand is valid!


## Follow-up 1: Wildcards (Jokers)

Modify the approach to account for the presence of wildcards (Jokers) that can represent any card.
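A simple way to validate any shortcut logic for jokers is brute force: try every possible card in place of each joker and check the ordinary rule. The sketch below does exactly that. It is a reference check rather than an efficient method, since the work grows as 52 to the power of the number of jokers. Cards are hypothetical `(rank, suit)` tuples with `('Joker', None)` as the wildcard, and a joker may duplicate a card already in the hand (assumption: "any card" includes duplicates). For example, 3♠ 4♠ 5♠ 9♥ plus a joker is not a flush, because the joker cannot change the suit of the 9♥.

```python
from itertools import product

RANKS = ['2', '3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A']
SUITS = ['♣', '♦', '♥', '♠']
RANK_ORDER = {rank: idx for idx, rank in enumerate(RANKS)}
ALL_CARDS = [(rank, suit) for suit in SUITS for rank in RANKS]


def flush(cards):
    return len({suit for _, suit in cards}) == 1


def straight(cards):
    idx = sorted(RANK_ORDER[rank] for rank, _ in cards)
    if len(set(idx)) != len(idx):
        return False                  # repeated rank: never a straight
    return idx == [0, 1, 2, 3, 12] or idx[-1] - idx[0] == len(idx) - 1


def satisfies_with_jokers(cards, rule):
    """True if some substitution of real cards for the jokers satisfies `rule`."""
    fixed = [card for card in cards if card[0] != 'Joker']
    num_jokers = len(cards) - len(fixed)
    # product(..., repeat=0) yields one empty tuple, so a joker-free hand is checked as-is
    return any(rule(fixed + list(choice))
               for choice in product(ALL_CARDS, repeat=num_jokers))


royal = [('10', '♠'), ('J', '♠'), ('Q', '♠'), ('K', '♠'), ('Joker', None)]
print(satisfies_with_jokers(royal, lambda c: flush(c) and straight(c)))   # True
```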

In [83]:
def separate_jokers(hand):
    jokers = []
    non_jokers = []
    for card in hand.cards:
        if card.rank == 'Joker':
            jokers.append(card)
        else:
            non_jokers.append(card)
    return non_jokers, jokers

def is_flush_with_joker(hand):
    non_jokers, _ = separate_jokers(hand)
    # Jokers can take any suit, but they cannot change the suit of a real card,
    # so all non-joker cards must already share a single suit (all jokers is fine too)
    return len({c.suit for c in non_jokers}) <= 1

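def is_straight_with_joker(hand):
    non_jokers, jokers = separate_jokers(hand)
    if not non_jokers:
        return True  # all jokers can form straight
    indices = sorted([RANK_ORDER[c.rank] for c in non_jokers])
    if len(set(indices)) != len(indices):
        return False  # a repeated rank can never be part of a straight
    # Try the ace as high and, if the hand holds one, also as low (A-2-3-4-5)
    candidates = [indices]
    if indices[-1] == RANK_ORDER['A']:
        candidates.append([-1] + indices[:-1])
    for candidate in candidates:
        # Each missing rank between consecutive real cards needs one joker
        gaps = sum(candidate[i+1] - candidate[i] - 1 for i in range(len(candidate) - 1))
        if gaps <= len(jokers):
            return True
    return False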

def is_four_of_a_kind_with_joker(hand):
    non_jokers, jokers = separate_jokers(hand)
    ranks = [c.rank for c in non_jokers]
    counts = [ranks.count(r) for r in set(ranks)]
    max_count = max(counts) if counts else 0
    return max_count + len(jokers) >= 4

def is_full_house_with_joker(hand):
    non_jokers, jokers = separate_jokers(hand)
    ranks = [c.rank for c in non_jokers]
    unique = set(ranks)
    if not unique:
        return len(jokers) >= 5  # all jokers
    counts = [ranks.count(r) for r in unique]
    counts.sort(reverse=True)
    # Use jokers to fill 3-of-a-kind and 2-of-a-kind
    needed_for_full_house = max(0, 3 - counts[0]) + max(0, 2 - (counts[1] if len(counts) > 1 else 0))
    return needed_for_full_house <= len(jokers)

def is_straight_flush_with_joker(hand):
    return is_flush_with_joker(hand) and is_straight_with_joker(hand)

def is_royal_flush_with_joker(hand):
    non_jokers, jokers = separate_jokers(hand)
    royal_ranks = ['10','J','Q','K','A']
    hand_ranks = [c.rank for c in non_jokers]
    missing_ranks = len(set(royal_ranks) - set(hand_ranks))
    return is_flush_with_joker(hand) and missing_ranks <= len(jokers)


In [84]:
def valid_hand_with_joker(hand):
    return (is_flush_with_joker(hand) or
            is_straight_with_joker(hand) or
            is_full_house_with_joker(hand) or
            is_four_of_a_kind_with_joker(hand) or
            is_straight_flush_with_joker(hand) or
            is_royal_flush_with_joker(hand))


In [85]:
# Hand with 1 Joker
hand = Hand()
hand.add_card(Card('10','♠'))
hand.add_card(Card('J','♠'))
hand.add_card(Card('Q','♠'))
hand.add_card(Card('K','♠'))
hand.add_card(Card('Joker', None))  # Joker

if valid_hand_with_joker(hand):
    print("Hand is valid with Joker!")
else:
    print("Hand is invalid.")


Hand is valid with Joker!


## Follow-up 2: Comparing Two Hands

Given two players' hands and an ordering of the poker hand rules, compare the two hands to determine which is better.
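Another way to read this follow-up is to rank each hand by the strongest rule it satisfies and compare those positions. A minimal sketch, assuming `rules_order` is any list of predicates ordered strongest first; the helper names `best_category` and `compare_by_category` are hypothetical.

```python
def best_category(hand, rules_order):
    """Position of the first (strongest) rule the hand satisfies; len(rules_order) if none."""
    for position, rule in enumerate(rules_order):
        if rule(hand):
            return position
    return len(rules_order)


def compare_by_category(hand1, hand2, rules_order):
    """1 if hand1 reaches a stronger category, -1 if hand2 does, 0 if they match."""
    c1 = best_category(hand1, rules_order)
    c2 = best_category(hand2, rules_order)
    return (c1 < c2) - (c1 > c2)


# Toy rules on lists of ranks: "contains an ace" outranks "contains a king"
rules = [lambda h: 'A' in h, lambda h: 'K' in h]
print(compare_by_category(['A', '2'], ['K', 'Q'], rules))   # 1
print(compare_by_category(['A', 'K'], ['A'], rules))        # 0
```

The two approaches differ when both hands satisfy the same top rule. `compare_hands_with_joker` keeps checking the weaker rules in that case. This version declares a tie, leaving any within-category tie-break (for example, comparing high cards) as a separate step.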

In [86]:
def compare_hands_with_joker(hand1, hand2, rules_order):
    """
    rules_order: list of functions in order of priority,
        e.g. [is_royal_flush_with_joker, is_straight_flush_with_joker, ...]
    Returns:
        1 if hand1 wins
       -1 if hand2 wins
        0 if tie
    """
    for rule in rules_order:
        h1_valid = rule(hand1)
        h2_valid = rule(hand2)
        if h1_valid and not h2_valid:
            return 1
        if h2_valid and not h1_valid:
            return -1
    return 0  # tie if both or neither satisfy any rule


In [87]:
# Player 1: Royal Flush with Joker acting as A♠
hand1 = Hand()
hand1.add_card(Card('10','♠'))
hand1.add_card(Card('J','♠'))
hand1.add_card(Card('Q','♠'))
hand1.add_card(Card('K','♠'))
hand1.add_card(Card('Joker', None))  # acts as A♠

# Player 2: Straight Flush without Joker
hand2 = Hand()
hand2.add_card(Card('9','♣'))
hand2.add_card(Card('10','♣'))
hand2.add_card(Card('J','♣'))
hand2.add_card(Card('Q','♣'))
hand2.add_card(Card('K','♣'))

rules_order = [
    is_royal_flush_with_joker,
    is_straight_flush_with_joker,
    is_four_of_a_kind_with_joker,
    is_full_house_with_joker,
    is_flush_with_joker,
    is_straight_with_joker
]

winner = compare_hands_with_joker(hand1, hand2, rules_order)

if winner == 1:
    print("Hand 1 wins")
elif winner == -1:
    print("Hand 2 wins")
else:
    print("Tie")


Hand 1 wins


# Question 4: Team Card Game with Lives and Skips

A team of multiple players plays a card game with specific constraints:

## Game Rules:

- **Team Lives**: The team starts with `num_players + 1` lives
- **Skips Available**: The team can skip up to `x` rounds (input parameter)
- **Player Order**: Players play in a fixed order (determined at start)
- **Hand Management**: Each player's hand is always sorted in ascending order
- **Card Playing**: Players must play their smallest available card
- **Round Success**: A round succeeds if each player's card is strictly larger than the previous player's
- **Round Failure Options**:
  - Use a skip (if available) - Players draw new cards and no one plays that round
  - Lose a life
- **Total Rounds**: The game runs for Y rounds
- **Win Condition**: Complete all Y rounds without running out of lives

## Objective:

Determine whether the team can survive all Y rounds given:
- Number of players
- Number of skips available
- Number of rounds to play
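Because every hand is sorted and each player must play their smallest card, a round with fresh hands succeeds only if the players' minimum cards are strictly increasing in seat order. The sketch below computes that probability exactly for the first round. It assumes, as the implementation below does, that each player draws 5 cards from a personal shuffled deck numbered 1 to 100. The function names are hypothetical. Without ties the answer would be 1/n! for n players, so the chance of a clean round falls quickly as the team grows.

```python
from math import comb


def min_card_distribution(deck_size=100, hand_size=5):
    """P(smallest card == k) for `hand_size` distinct cards drawn from 1..deck_size."""
    total = comb(deck_size, hand_size)
    # With minimum k, the other hand_size - 1 cards come from the deck_size - k larger ones
    return {k: comb(deck_size - k, hand_size - 1) / total
            for k in range(1, deck_size - hand_size + 2)}


def first_round_success_probability(num_players, deck_size=100, hand_size=5):
    """P(player minima are strictly increasing), players' hands independent."""
    p = min_card_distribution(deck_size, hand_size)
    ks = sorted(p)
    f = dict(p)  # f[k]: P(first i minima strictly increase and the i-th equals k)
    for _ in range(num_players - 1):
        below, g = 0.0, {}
        for k in ks:
            g[k] = p[k] * below  # next minimum is k, previous one was < k
            below += f[k]
        f = g
    return sum(f.values())


for n in (2, 3, 4, 5):
    print(n, round(first_round_success_probability(n), 4))
```

This covers only the first round. After a successful round or a lost life the remaining cards stay in hand and 5 more are drawn, so later rounds do not follow this simple formula.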

In [88]:
import random
from typing import List, Tuple

class TeamPlayer:
    """Represents a player in the team card game."""
    
    def __init__(self, name: str, deck: List[int]):
        self.name = name
        self.deck = deck  # Personal deck to draw from
        self.hand = []
    
    def draw_cards(self, num_cards: int = 5):
        """Draw cards from personal deck and sort hand."""
        for _ in range(num_cards):
            if self.deck:
                self.hand.append(self.deck.pop(0))
        self.hand.sort()
    
    def play_smallest_card(self) -> int:
        """Play and remove the smallest card from hand."""
        if not self.hand:
            raise ValueError(f"{self.name} has no cards to play!")
        return self.hand.pop(0)
    
    def has_cards(self) -> bool:
        """Check if player has cards in hand or deck."""
        return len(self.hand) > 0 or len(self.deck) > 0
    
    def discard_hand(self):
        """Discard current hand."""
        self.hand = []

In [89]:
class TeamCardGame:
    """Manages the team card game with lives and skips."""
    
    def __init__(self, num_players: int, num_skips: int, num_rounds: int, 
                 cards_per_hand: int = 5, verbose: bool = True):
        """
        Initialize the team card game.
        
        Args:
            num_players: Number of players on the team
            num_skips: Number of skips available
            num_rounds: Total rounds to play
            cards_per_hand: Cards dealt per round
            verbose: Whether to print game progress
        """
        self.num_players = num_players
        self.lives = num_players + 1
        self.skips_remaining = num_skips
        self.num_rounds = num_rounds
        self.cards_per_hand = cards_per_hand
        self.verbose = verbose
        
        # Create players with individual decks
        self.players = []
        for i in range(num_players):
            # Each player gets their own shuffled deck of cards (1-100)
            deck = list(range(1, 101))
            random.shuffle(deck)
            self.players.append(TeamPlayer(f"Player {i+1}", deck))
    
    def play_round(self, round_num: int) -> bool:
        """
        Play one round of the game.
        
        Returns:
            True if round succeeded, False if round failed
        """
        if self.verbose:
            print(f"\n{'='*60}")
            print(f"ROUND {round_num}")
            print(f"Lives: {self.lives} | Skips: {self.skips_remaining}")
            print(f"{'='*60}")
        
        # Each player draws cards
        for player in self.players:
            player.draw_cards(self.cards_per_hand)
        
        if self.verbose:
            for player in self.players:
                print(f"{player.name}'s hand: {player.hand}")
        
        # Players play in order
        cards_played = []
        previous_card = -1  # Start with -1 so first card is always valid
        round_success = True
        
        if self.verbose:
            print(f"\nPlaying cards:")
        
        for player in self.players:
            card = player.play_smallest_card()
            cards_played.append((player.name, card))
            
            if card <= previous_card:
                round_success = False
                if self.verbose:
                    print(f"{player.name} played {card} ✗ (not greater than {previous_card})")
                break
            else:
                if self.verbose:
                    print(f"{player.name} played {card} ✓")
                previous_card = card
        
        return round_success
    
    def handle_failed_round(self, round_num: int) -> str:
        """
        Handle a failed round - use skip or lose life.
        
        Returns:
            Action taken: "skip" or "life"
        """
        # Strategy: Use skip if available, otherwise lose a life
        if self.skips_remaining > 0:
            self.skips_remaining -= 1
            if self.verbose:
                print(f"\n→ Round {round_num} FAILED! Using a skip.")
                print(f"→ Skips remaining: {self.skips_remaining}")
            # Discard hands and draw new cards
            for player in self.players:
                player.discard_hand()
            return "skip"
        else:
            self.lives -= 1
            if self.verbose:
                print(f"\n→ Round {round_num} FAILED! Losing a life.")
                print(f"→ Lives remaining: {self.lives}")
            return "life"
    
    def play_game(self) -> bool:
        """
        Play the complete game.
        
        Returns:
            True if team survives all rounds, False otherwise
        """
        if self.verbose:
            print(f"\n{'#'*60}")
            print(f"TEAM CARD GAME")
            print(f"Players: {self.num_players} | Lives: {self.lives} | Skips: {self.skips_remaining}")
            print(f"Total Rounds: {self.num_rounds}")
            print(f"{'#'*60}")
        
        for round_num in range(1, self.num_rounds + 1):
            # Check if team is still alive
            if self.lives <= 0:
                if self.verbose:
                    print(f"\n{'='*60}")
                    print(f"GAME OVER - Team ran out of lives at round {round_num}")
                    print(f"{'='*60}")
                return False
            
            # Play the round
            round_success = self.play_round(round_num)
            
            if round_success:
                if self.verbose:
                    print(f"\n→ Round {round_num} SUCCEEDED! ✓")
            else:
                self.handle_failed_round(round_num)
        
        # Check final result
        if self.lives > 0:
            if self.verbose:
                print(f"\n{'='*60}")
                print(f"VICTORY! Team completed all {self.num_rounds} rounds!")
                print(f"Final lives: {self.lives} | Final skips: {self.skips_remaining}")
                print(f"{'='*60}")
            return True
        else:
            if self.verbose:
                print(f"\n{'='*60}")
                print(f"DEFEAT! Team ran out of lives.")
                print(f"{'='*60}")
            return False

In [90]:
# Example 1: Small game with generous resources
print("Example 1: 3 players, 2 skips, 5 rounds")
game1 = TeamCardGame(num_players=3, num_skips=2, num_rounds=5)
result1 = game1.play_game()
print(f"\nResult: {'SUCCESS' if result1 else 'FAILURE'}")

Example 1: 3 players, 2 skips, 5 rounds

############################################################
TEAM CARD GAME
Players: 3 | Lives: 4 | Skips: 2
Total Rounds: 5
############################################################

ROUND 1
Lives: 4 | Skips: 2
Player 1's hand: [28, 64, 92, 96, 98]
Player 2's hand: [21, 32, 60, 83, 96]
Player 3's hand: [14, 24, 49, 55, 98]

Playing cards:
Player 1 played 28 ✓
Player 2 played 21 ✗ (not greater than 28)

→ Round 1 FAILED! Using a skip.
→ Skips remaining: 1

ROUND 2
Lives: 4 | Skips: 1
Player 1's hand: [8, 14, 47, 69, 93]
Player 2's hand: [4, 26, 63, 76, 85]
Player 3's hand: [4, 12, 28, 39, 69]

Playing cards:
Player 1 played 8 ✓
Player 2 played 4 ✗ (not greater than 8)

→ Round 2 FAILED! Using a skip.
→ Skips remaining: 0

ROUND 3
Lives: 4 | Skips: 0
Player 1's hand: [23, 73, 75, 85, 86]
Player 2's hand: [6, 34, 48, 89, 94]
Player 3's hand: [34, 38, 46, 47, 62]

Playing cards:
Player 1 played 23 ✓
Player 2 played 6 ✗ (not greater than 23)

→ R

In [91]:
# Example 2: Harder game with limited resources
print("\n" + "="*70)
print("Example 2: 4 players, 1 skip, 8 rounds")
game2 = TeamCardGame(num_players=4, num_skips=1, num_rounds=8)
result2 = game2.play_game()
print(f"\nResult: {'SUCCESS' if result2 else 'FAILURE'}")


Example 2: 4 players, 1 skip, 8 rounds

############################################################
TEAM CARD GAME
Players: 4 | Lives: 5 | Skips: 1
Total Rounds: 8
############################################################

ROUND 1
Lives: 5 | Skips: 1
Player 1's hand: [13, 14, 82, 87, 98]
Player 2's hand: [12, 26, 78, 88, 98]
Player 3's hand: [17, 74, 80, 90, 97]
Player 4's hand: [17, 19, 35, 43, 54]

Playing cards:
Player 1 played 13 ✓
Player 2 played 12 ✗ (not greater than 13)

→ Round 1 FAILED! Using a skip.
→ Skips remaining: 0

ROUND 2
Lives: 5 | Skips: 0
Player 1's hand: [27, 37, 61, 74, 78]
Player 2's hand: [51, 54, 77, 90, 91]
Player 3's hand: [4, 11, 54, 63, 87]
Player 4's hand: [10, 12, 15, 89, 93]

Playing cards:
Player 1 played 27 ✓
Player 2 played 51 ✓
Player 3 played 4 ✗ (not greater than 51)

→ Round 2 FAILED! Losing a life.
→ Lives remaining: 4

ROUND 3
Lives: 4 | Skips: 0
Player 1's hand: [23, 37, 59, 61, 66, 74, 78, 79, 99]
Player 2's hand: [3, 46, 50, 54, 59, 75,

In [92]:
# Example 3: Simulation to determine win probability
print("\n" + "="*70)
print("Example 3: Running 100 simulations (3 players, 2 skips, 10 rounds)")
print("="*70)

num_simulations = 100
wins = 0

for i in range(num_simulations):
    game = TeamCardGame(num_players=3, num_skips=2, num_rounds=10, verbose=False)
    if game.play_game():
        wins += 1

win_rate = (wins / num_simulations) * 100
print(f"\nSimulation Results:")
print(f"Wins: {wins}/{num_simulations}")
print(f"Win Rate: {win_rate:.1f}%")

if win_rate >= 50:
    print(f"\n✓ The team CAN likely survive with these parameters ({win_rate:.1f}% success rate)")
else:
    print(f"\n✗ The team will likely FAIL with these parameters ({win_rate:.1f}% success rate)")


Example 3: Running 100 simulations (3 players, 2 skips, 10 rounds)

Simulation Results:
Wins: 0/100
Win Rate: 0.0%

✗ The team will likely FAIL with these parameters (0.0% success rate)


## Analysis Function

Let's create a function to analyze whether a team can survive based on simulation.
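Every success rate in this section comes from a finite number of simulated games, so it carries sampling error. One standard way to quantify that, which the notebook's own code does not use, is the Wilson score interval for a binomial proportion. The sketch assumes roughly 95% coverage (z = 1.96), and the helper name `wilson_interval` is hypothetical.

```python
from math import sqrt


def wilson_interval(successes, trials, z=1.96):
    """Wilson score interval for a binomial proportion (z=1.96 gives roughly 95%)."""
    if trials <= 0:
        raise ValueError("trials must be positive")
    p_hat = successes / trials
    denom = 1 + z * z / trials
    centre = (p_hat + z * z / (2 * trials)) / denom
    half_width = z * sqrt(p_hat * (1 - p_hat) / trials
                          + z * z / (4 * trials * trials)) / denom
    return max(0.0, centre - half_width), min(1.0, centre + half_width)


low, high = wilson_interval(0, 100)
print(f"0/100 wins -> [{low:.3f}, {high:.3f}]")  # [0.000, 0.037]
```

With 0 wins in 100 games, as in Example 3, the true win rate could still plausibly be a few percent. Unlike the plain normal approximation, the interval does not collapse to zero width at 0% or 100%.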

In [93]:
def can_team_survive_theoretical(num_players: int, num_skips: int, num_rounds: int) -> dict:
    """
    Estimate whether a team can survive, using Monte Carlo simulation.
    
    Lives = num_players + 1. The game is lost as soon as lives reach 0, so the
    team can absorb at most num_skips + (lives - 1) = num_skips + num_players
    failed rounds.
    
    Whether that failure budget is enough depends on how the cards fall, so
    the success rate is estimated by simulating 1000 games rather than
    derived analytically.
    
    Args:
        num_players: Number of players
        num_skips: Number of skips available
        num_rounds: Total rounds to play
    
    Returns:
        Dictionary with analysis results
    """
    lives = num_players + 1
    total_resources = lives + num_skips
    # Losing the last life ends the game, so only lives - 1 failures can be absorbed by lives
    max_allowed_failures = num_skips + lives - 1
    
    # Run a simulation to estimate the success rate
    num_test_simulations = 1000
    failures = 0
    
    for _ in range(num_test_simulations):
        game = TeamCardGame(num_players, num_skips, num_rounds, verbose=False)
        if not game.play_game():
            failures += 1
    
    failure_rate = failures / num_test_simulations
    success_rate = 1 - failure_rate
    
    return {
        'num_players': num_players,
        'lives': lives,
        'skips': num_skips,
        'total_resources': total_resources,
        'num_rounds': num_rounds,
        'max_allowed_failures': max_allowed_failures,
        'success_rate': success_rate,
        'can_survive': success_rate >= 0.5,
        'confidence': 'high' if abs(success_rate - 0.5) > 0.2 else 'medium' if abs(success_rate - 0.5) > 0.1 else 'low'
    }

In [94]:
# Test different scenarios
test_scenarios = [
    (3, 2, 5),   # Easy
    (3, 2, 10),  # Medium
    (4, 1, 8),   # Hard
    (5, 3, 15),  # Very Hard
    (2, 5, 10),  # Easy with many skips
]

print("="*80)
print("TEAM SURVIVAL ANALYSIS")
print("="*80)

for num_players, num_skips, num_rounds in test_scenarios:
    result = can_team_survive_theoretical(num_players, num_skips, num_rounds)
    
    print(f"\nScenario: {num_players} players, {num_skips} skips, {num_rounds} rounds")
    print(f"  Lives: {result['lives']}")
    print(f"  Total Resources (lives + skips): {result['total_resources']}")
    print(f"  Success Rate: {result['success_rate']*100:.1f}%")
    print(f"  Verdict: {'✓ CAN SURVIVE' if result['can_survive'] else '✗ WILL LIKELY FAIL'}")
    print(f"  Confidence: {result['confidence'].upper()}")

print("\n" + "="*80)

TEAM SURVIVAL ANALYSIS

Scenario: 3 players, 2 skips, 5 rounds
  Lives: 4
  Total Resources (lives + skips): 6
  Success Rate: 100.0%
  Verdict: ✓ CAN SURVIVE
  Confidence: HIGH

Scenario: 3 players, 2 skips, 10 rounds
  Lives: 4
  Total Resources (lives + skips): 6
  Success Rate: 0.3%
  Verdict: ✗ WILL LIKELY FAIL
  Confidence: HIGH

Scenario: 4 players, 1 skips, 8 rounds
  Lives: 5
  Total Resources (lives + skips): 6
  Success Rate: 0.1%
  Verdict: ✗ WILL LIKELY FAIL
  Confidence: HIGH

Scenario: 5 players, 3 skips, 15 rounds
  Lives: 6
  Total Resources (lives + skips): 9
  Success Rate: 0.0%
  Verdict: ✗ WILL LIKELY FAIL
  Confidence: HIGH

Scenario: 2 players, 5 skips, 10 rounds
  Lives: 3
  Total Resources (lives + skips): 8
  Success Rate: 96.0%
  Verdict: ✓ CAN SURVIVE
  Confidence: HIGH
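Each success rate above is estimated from `num_test_simulations = 1000` simulated games, so it carries sampling error. The following is a minimal sketch, not part of the original analysis, of a Wilson score interval around such an estimate (z = 1.96 for roughly 95% coverage). The helper `wilson_interval` is hypothetical. It takes the success count, which could be recovered as `round(success_rate * num_test_simulations)`.

```python
import math
from typing import Tuple


def wilson_interval(successes: int, trials: int, z: float = 1.96) -> Tuple[float, float]:
    """Approximate Wilson score interval for a binomial success rate."""
    if trials <= 0:
        raise ValueError("trials must be positive")
    p = successes / trials
    denom = 1 + z * z / trials
    center = (p + z * z / (2 * trials)) / denom
    half = z * math.sqrt(p * (1 - p) / trials + z * z / (4 * trials * trials)) / denom
    # Clamp to [0, 1] to absorb floating-point error at the extremes
    return max(0.0, center - half), min(1.0, center + half)


# Example: 960 successes out of 1000 simulated games (the 96.0% scenario above)
low, high = wilson_interval(960, 1000)
print(f"95% interval: [{low:.3f}, {high:.3f}]")
```

For 960 successes out of 1000, this gives an interval of roughly 94.6% to 97.0%. The Wilson form is used instead of the plain normal approximation because it stays inside [0, 1] for rates such as 0.0% or 100.0%.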



# Question 5: Neuron Matrix State Transition

You have a 2D matrix of numbers representing neurons. Each neuron has a state (firing or not firing) and transitions to a new state based on its neighbors.

## Rules:

### Neuron States:
- **Firing neuron**: value > 0
- **Non-firing neuron**: value = 0

### State Transition Rules:
1. **Firing neuron** (value > 0):
   - If exactly 3 neighbors are firing → set to 6
   - Otherwise → keep current value

2. **Non-firing neuron** (value = 0):
   - If 0 or 1 neighbors are firing → decrement by 2 (cannot go below 0)
   - If more than 3 neighbors are firing → decrement by 1 (cannot go below 0)
   - Otherwise → keep current value
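A cell's next value depends only on its own value and its number of firing neighbors, so the rules above can be written as a small pure function. The sketch below is one way to do it; the name `next_value` is hypothetical and does not come from the notebook. As the rules are stated, every non-firing branch leaves the value at 0, because a decrement of 0 is clamped back to 0.

```python
def next_value(value: int, firing_neighbors: int) -> int:
    """Next value of one neuron under the transition rules (hypothetical helper)."""
    if value > 0:
        # Firing neuron: exactly 3 firing neighbors resets it to 6
        return 6 if firing_neighbors == 3 else value
    # Non-firing neuron (value == 0)
    if firing_neighbors <= 1:
        return max(0, value - 2)  # clamped, so 0 stays 0
    if firing_neighbors > 3:
        return max(0, value - 1)  # clamped, so 0 stays 0
    return value  # 2 or 3 firing neighbors: unchanged


for value, n in [(4, 3), (4, 2), (0, 0), (0, 3), (0, 5)]:
    print(f"value={value}, firing neighbors={n} -> {next_value(value, n)}")
```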

### Neighbors:
- A neuron's neighbors are the up to 8 surrounding cells (horizontal, vertical, and diagonal)
- Edge and corner cells have fewer neighbors
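One way to make the neighbor definition concrete is to generate the eight offsets and keep only the in-bounds positions. In this sketch, `OFFSETS` and `neighbor_coords` are illustrative names. A corner of a 3x3 grid gets 3 neighbors, an edge cell gets 5, and the center gets 8.

```python
from typing import Iterator, Tuple

# The eight (row, col) offsets around a cell, excluding the cell itself
OFFSETS = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)]


def neighbor_coords(row: int, col: int, rows: int, cols: int) -> Iterator[Tuple[int, int]]:
    """Yield the in-bounds neighbor coordinates of (row, col) in a rows x cols grid."""
    for dr, dc in OFFSETS:
        r, c = row + dr, col + dc
        if 0 <= r < rows and 0 <= c < cols:
            yield r, c


# Corner, edge, and interior cells of a 3x3 grid
for cell in [(0, 0), (0, 1), (1, 1)]:
    print(cell, len(list(neighbor_coords(*cell, 3, 3))))
```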

## Task:
Given an `input_state` matrix, compute and return the `next_state` matrix.

In [95]:
import numpy as np
from typing import List

def count_firing_neighbors(matrix: List[List[int]], row: int, col: int) -> int:
    """
    Count the number of firing neighbors (value > 0) for a given cell.
    
    Args:
        matrix: The 2D matrix of neuron states
        row: Row index of the cell
        col: Column index of the cell
    
    Returns:
        Number of firing neighbors
    """
    rows = len(matrix)
    cols = len(matrix[0])
    
    # All 8 possible neighbor directions (including diagonals)
    directions = [
        (-1, -1), (-1, 0), (-1, 1),  # top-left, top, top-right
        (0, -1),           (0, 1),   # left, right
        (1, -1),  (1, 0),  (1, 1)    # bottom-left, bottom, bottom-right
    ]
    
    firing_count = 0
    
    for dr, dc in directions:
        new_row = row + dr
        new_col = col + dc
        
        # Check if neighbor is within bounds
        if 0 <= new_row < rows and 0 <= new_col < cols:
            if matrix[new_row][new_col] > 0:
                firing_count += 1
    
    return firing_count

In [96]:
def compute_next_state(input_state: List[List[int]]) -> List[List[int]]:
    """
    Compute the next state of the neuron matrix based on transition rules.
    
    Args:
        input_state: 2D matrix of current neuron states
    
    Returns:
        2D matrix of next neuron states
    """
    rows = len(input_state)
    cols = len(input_state[0])
    
    # Create a new matrix for the next state
    next_state = [[0 for _ in range(cols)] for _ in range(rows)]
    
    for row in range(rows):
        for col in range(cols):
            current_value = input_state[row][col]
            firing_neighbors = count_firing_neighbors(input_state, row, col)
            
            if current_value > 0:
                # Firing neuron
                if firing_neighbors == 3:
                    next_state[row][col] = 6
                else:
                    next_state[row][col] = current_value
            else:
                # Non-firing neuron (value = 0)
                if firing_neighbors <= 1:
                    # Decrement by 2 (but cannot go below 0)
                    next_state[row][col] = max(0, current_value - 2)
                elif firing_neighbors > 3:
                    # Decrement by 1 (but cannot go below 0)
                    next_state[row][col] = max(0, current_value - 1)
                else:
                    # 2 or 3 neighbors firing - keep current value
                    next_state[row][col] = current_value
    
    return next_state

In [97]:
def print_matrix(matrix: List[List[int]], title: str = "Matrix"):
    """
    Pretty print a matrix.
    
    Args:
        matrix: 2D matrix to print
        title: Title for the matrix
    """
    print(f"\n{title}:")
    print("-" * 40)
    for row in matrix:
        print("  ", end="")
        for val in row:
            print(f"{val:3}", end=" ")
        print()
    print()

# Test Example 1: Simple 3x3 matrix
print("="*60)
print("TEST EXAMPLE 1: 3x3 Matrix")
print("="*60)

input_state_1 = [
    [0, 1, 0],
    [1, 2, 1],
    [0, 1, 0]
]

print_matrix(input_state_1, "Input State")

# Count neighbors for each cell
print("Firing neighbor counts:")
for i in range(3):
    print(f"  Row {i}: ", end="")
    for j in range(3):
        count = count_firing_neighbors(input_state_1, i, j)
        print(f"{count} ", end="")
    print()

next_state_1 = compute_next_state(input_state_1)
print_matrix(next_state_1, "Next State")

print("\nExplanation:")
print("  - Center cell (2, firing): has 4 firing neighbors → stays 2")
print("  - Edge-center cells (1, firing): have exactly 3 firing neighbors → become 6")
print("  - Corners (0, non-firing): have 3 firing neighbors → stay 0")

TEST EXAMPLE 1: 3x3 Matrix

Input State:
----------------------------------------
    0   1   0 
    1   2   1 
    0   1   0 

Firing neighbor counts:
  Row 0: 3 3 3 
  Row 1: 3 4 3 
  Row 2: 3 3 3 

Next State:
----------------------------------------
    0   6   0 
    6   2   6 
    0   6   0 


Explanation:
  - Center cell (2, firing): has 4 firing neighbors → stays 2
  - Edge-center cells (1, firing): have exactly 3 firing neighbors → become 6
  - Corners (0, non-firing): have 3 firing neighbors → stay 0


In [98]:
# Test Example 2: Matrix with firing neurons that have exactly 3 neighbors
print("="*60)
print("TEST EXAMPLE 2: Firing neuron with exactly 3 firing neighbors")
print("="*60)

input_state_2 = [
    [5, 3, 0, 0],
    [2, 4, 1, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 0]
]

print_matrix(input_state_2, "Input State")

# Analyze specific cell
target_row, target_col = 1, 1  # Cell with value 4
neighbors = count_firing_neighbors(input_state_2, target_row, target_col)
print(f"Cell [{target_row}][{target_col}] (value={input_state_2[target_row][target_col]}) has {neighbors} firing neighbors")

next_state_2 = compute_next_state(input_state_2)
print_matrix(next_state_2, "Next State")

print("\nExplanation:")
print("  - Cell [1][1] (value 4): has 5 firing neighbors → stays 4")
print("  - Cell [0][0] (value 5): has 3 firing neighbors → becomes 6 ✓")
print("  - Cell [1][2] (value 1): has 3 firing neighbors → becomes 6 ✓")
print("  - Cell [2][1] (value 1): has 3 firing neighbors → becomes 6 ✓")
print("  - Non-firing cells with 0-1 firing neighbors would decrement by 2, but are clamped at 0")

TEST EXAMPLE 2: Firing neuron with exactly 3 firing neighbors

Input State:
----------------------------------------
    5   3   0   0 
    2   4   1   0 
    0   1   0   0 
    0   0   0   0 

Cell [1][1] (value=4) has 5 firing neighbors

Next State:
----------------------------------------
    6   3   0   0 
    2   4   6   0 
    0   6   0   0 
    0   0   0   0 


Explanation:
  - Cell [1][1] (value 4): has 5 firing neighbors → stays 4
  - Cell [0][0] (value 5): has 3 firing neighbors → becomes 6 ✓
  - Cell [1][2] (value 1): has 3 firing neighbors → becomes 6 ✓
  - Cell [2][1] (value 1): has 3 firing neighbors → becomes 6 ✓
  - Non-firing cells with 0-1 firing neighbors would decrement by 2, but are clamped at 0


In [99]:
# Test Example 3: Edge cases with non-firing neurons
print("="*60)
print("TEST EXAMPLE 3: Non-firing neuron behavior")
print("="*60)

# Create a scenario where non-firing neurons have different neighbor counts
input_state_3 = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0]
]

print_matrix(input_state_3, "Input State")

# Analyze the center non-firing cell
center_neighbors = count_firing_neighbors(input_state_3, 2, 2)
print(f"Center cell [2][2] (non-firing, value=0) has {center_neighbors} firing neighbors")

# Analyze corner non-firing cell
corner_neighbors = count_firing_neighbors(input_state_3, 0, 0)
print(f"Corner cell [0][0] (non-firing, value=0) has {corner_neighbors} firing neighbors")

next_state_3 = compute_next_state(input_state_3)
print_matrix(next_state_3, "Next State")

print("\nExplanation:")
print("  - Center [2][2] (non-firing): has 8 firing neighbors (>3) → max(0, 0-1) = 0")
print("  - Corner [0][0] (non-firing): has 1 firing neighbor (≤1) → max(0, 0-2) = 0")
print("  - Edge cell [0][2] (non-firing): has 3 firing neighbors (2-3) → stays 0")
print("  - Firing neurons: none has exactly 3 firing neighbors (ring corners have 2, ring edges have 4), so all stay the same")

TEST EXAMPLE 3: Non-firing neuron behavior

Input State:
----------------------------------------
    0   0   0   0   0 
    0   1   1   1   0 
    0   1   0   1   0 
    0   1   1   1   0 
    0   0   0   0   0 

Center cell [2][2] (non-firing, value=0) has 8 firing neighbors
Corner cell [0][0] (non-firing, value=0) has 1 firing neighbors

Next State:
----------------------------------------
    0   0   0   0   0 
    0   1   1   1   0 
    0   1   0   1   0 
    0   1   1   1   0 
    0   0   0   0   0 


Explanation:
  - Center [2][2] (non-firing): has 8 firing neighbors (>3) → max(0, 0-1) = 0
  - Corner [0][0] (non-firing): has 1 firing neighbor (≤1) → max(0, 0-2) = 0
  - Edge cell [0][2] (non-firing): has 3 firing neighbors (2-3) → stays 0
  - Firing neurons: none has exactly 3 firing neighbors (ring corners have 2, ring edges have 4), so all stay the same


In [100]:
# Test Example 4: Complex scenario with multiple state changes
print("="*60)
print("TEST EXAMPLE 4: Complex state transition")
print("="*60)

input_state_4 = [
    [3, 4, 1, 0],
    [2, 0, 5, 1],
    [1, 2, 0, 0],
    [0, 0, 1, 2]
]

print_matrix(input_state_4, "Input State")

print("Detailed neighbor analysis:")
for i in range(4):
    for j in range(4):
        val = input_state_4[i][j]
        neighbors = count_firing_neighbors(input_state_4, i, j)
        status = "firing" if val > 0 else "non-firing"
        print(f"  [{i}][{j}] val={val} ({status}): {neighbors} firing neighbors", end="")
        
        # Determine what happens
        if val > 0:
            if neighbors == 3:
                print(f" → becomes 6")
            else:
                print(f" → stays {val}")
        else:
            if neighbors <= 1:
                print(f" → max(0, {val}-2) = 0")
            elif neighbors > 3:
                print(f" → max(0, {val}-1) = 0")
            else:
                print(f" → stays {val}")

next_state_4 = compute_next_state(input_state_4)
print_matrix(next_state_4, "Next State")

TEST EXAMPLE 4: Complex state transition

Input State:
----------------------------------------
    3   4   1   0 
    2   0   5   1 
    1   2   0   0 
    0   0   1   2 

Detailed neighbor analysis:
  [0][0] val=3 (firing): 2 firing neighbors → stays 3
  [0][1] val=4 (firing): 4 firing neighbors → stays 4
  [0][2] val=1 (firing): 3 firing neighbors → becomes 6
  [0][3] val=0 (non-firing): 3 firing neighbors → stays 0
  [1][0] val=2 (firing): 4 firing neighbors → stays 2
  [1][1] val=0 (non-firing): 7 firing neighbors → max(0, 0-1) = 0
  [1][2] val=5 (firing): 4 firing neighbors → stays 5
  [1][3] val=1 (firing): 2 firing neighbors → stays 1
  [2][0] val=1 (firing): 2 firing neighbors → stays 1
  [2][1] val=2 (firing): 4 firing neighbors → stays 2
  [2][2] val=0 (non-firing): 5 firing neighbors → max(0, 0-1) = 0
  [2][3] val=0 (non-firing): 4 firing neighbors → max(0, 0-1) = 0
  [3][0] val=0 (non-firing): 2 firing neighbors → stays 0
  [3][1] val=0 (non-firing): 3 firing neighbors → s

In [101]:
# Test Example 5: Multiple iterations to see evolution
print("="*60)
print("TEST EXAMPLE 5: Multi-step evolution")
print("="*60)

# Start with a simple pattern
current_state = [
    [0, 0, 0, 0, 0],
    [0, 1, 2, 1, 0],
    [0, 2, 3, 2, 0],
    [0, 1, 2, 1, 0],
    [0, 0, 0, 0, 0]
]

print_matrix(current_state, "Initial State (Step 0)")

# Run 3 iterations
for step in range(1, 4):
    current_state = compute_next_state(current_state)
    print_matrix(current_state, f"State after Step {step}")
    
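print("Observation: after step 1 the four corners of the inner 3x3 block (exactly 3 firing")
print("neighbors each) become 6, and every later step reproduces the same grid. No neuron ever")
print("switches between firing and non-firing under these rules, so neighbor counts never change.")

TEST EXAMPLE 5: Multi-step evolution

Initial State (Step 0):
----------------------------------------
    0   0   0   0   0 
    0   1   2   1   0 
    0   2   3   2   0 
    0   1   2   1   0 
    0   0   0   0   0 


State after Step 1:
----------------------------------------
    0   0   0   0   0 
    0   6   2   6   0 
    0   2   3   2   0 
    0   6   2   6   0 
    0   0   0   0   0 


State after Step 2:
----------------------------------------
    0   0   0   0   0 
    0   6   2   6   0 
    0   2   3   2   0 
    0   6   2   6   0 
    0   0   0   0   0 


State after Step 3:
----------------------------------------
    0   0   0   0   0 
    0   6   2   6   0 
    0   2   3   2   0 
    0   6   2   6   0 
    0   0   0   0   0 

Observation: after step 1 the four corners of the inner 3x3 block (exactly 3 firing
neighbors each) become 6, and every later step reproduces the same grid. No neuron ever
switches between firing and non-firing under these rules, so neighbor counts never change.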


## Summary of Rules

Let's verify our implementation matches all the rules:

### Firing Neuron (value > 0):
- ✓ If exactly 3 neighbors are firing → set to 6
- ✓ Otherwise → keep current value

### Non-Firing Neuron (value = 0):
- ✓ If 0 or 1 neighbors are firing → decrement by 2 (cannot go below 0)
- ✓ If more than 3 neighbors are firing → decrement by 1 (cannot go below 0)
- ✓ If 2 or 3 neighbors are firing → keep current value (stays 0)

The implementation correctly handles all edge cases including:
- Cells at corners (3 neighbors)
- Cells at edges (5 neighbors)
- Interior cells (8 neighbors)
- All neurons update simultaneously based on the current state
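The simultaneous update has a useful consequence under these rules. A positive value becomes either 6 or itself, and a 0 can only be clamped back to 0, so the set of firing cells never changes and neither does any neighbor count. The minimal sketch below checks this by iterating until a fixed point. It carries its own self-contained copy of the update so it runs on its own. The names `step` and `run_until_stable` are hypothetical. By the argument above, every grid should stabilize after at most one step.

```python
import random
from typing import List, Tuple

Grid = List[List[int]]


def step(grid: Grid) -> Grid:
    """One simultaneous update: every cell reads only the old grid."""
    rows, cols = len(grid), len(grid[0])
    new = [row[:] for row in grid]  # separate buffer; the old grid is never modified
    for r in range(rows):
        for c in range(cols):
            firing = sum(
                1
                for dr in (-1, 0, 1)
                for dc in (-1, 0, 1)
                if (dr, dc) != (0, 0)
                and 0 <= r + dr < rows
                and 0 <= c + dc < cols
                and grid[r + dr][c + dc] > 0
            )
            v = grid[r][c]
            if v > 0:
                new[r][c] = 6 if firing == 3 else v
            elif firing <= 1:
                new[r][c] = max(0, v - 2)
            elif firing > 3:
                new[r][c] = max(0, v - 1)
            # otherwise (2 or 3 firing neighbors) the copied value is kept
    return new


def run_until_stable(grid: Grid, max_steps: int = 100) -> Tuple[Grid, int]:
    """Apply step() until the grid stops changing; return (grid, steps that changed it)."""
    for steps in range(max_steps):
        nxt = step(grid)
        if nxt == grid:
            return grid, steps
        grid = nxt
    raise RuntimeError(f"no fixed point within {max_steps} steps")


rng = random.Random(0)
for _ in range(20):
    grid = [[rng.choice([0, 0, 1, 2, 7]) for _ in range(5)] for _ in range(4)]
    _, steps = run_until_stable(grid)
    print(steps)  # 0 or 1
```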

# Question 6: Shortest Distance in a Tree

Given a tree represented as a dictionary and two nodes in the tree, find the shortest distance between the two nodes.

## Tree Representation:
The tree is represented as a dictionary where:
- Keys are node names
- Values are lists of children nodes

## Task:
Implement a function to find the shortest distance (number of edges) between two nodes in the tree.

## Follow-up Questions:
1. **Two sets of length 2**: Given `a = [a1, a2]` and `b = [b1, b2]`, find the shortest distance between any pair `(x, y)` where `x ∈ a` and `y ∈ b`
2. **Arbitrary-length sets**: Extend to handle arbitrary-length sets `a` and `b`

In [102]:
from typing import Dict, List, Optional
from collections import deque

# Example tree represented as a dictionary (parent -> children mapping)
tree = {
    'A': ['B', 'C'],
    'B': ['D', 'E'],
    'C': ['F'],
    'D': [],
    'E': ['G', 'H'],
    'F': [],
    'G': [],
    'H': []
}

print("Tree Structure:")
print("="*50)
print("         A")
print("        / \\")
print("       B   C")
print("      / \\   \\")
print("     D   E   F")
print("        / \\")
print("       G   H")
print("\nTree Dictionary:")
for parent, children in tree.items():
    print(f"  {parent}: {children}")

Tree Structure:
         A
        / \
       B   C
      / \   \
     D   E   F
        / \
       G   H

Tree Dictionary:
  A: ['B', 'C']
  B: ['D', 'E']
  C: ['F']
  D: []
  E: ['G', 'H']
  F: []
  G: []
  H: []


In [103]:
def build_graph(tree: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """
    Build an undirected graph (adjacency list) from the tree dictionary.
    The tree dict maps each parent to its children; BFS needs edges in both directions.
    
    Args:
        tree: Dictionary mapping nodes to their children
    
    Returns:
        Adjacency list with bidirectional edges
    """
    graph = {}
    
    # Initialize all nodes
    for node in tree:
        if node not in graph:
            graph[node] = []
    
    # Add bidirectional edges
    for parent, children in tree.items():
        for child in children:
            if child not in graph:
                graph[child] = []
            graph[parent].append(child)
            graph[child].append(parent)
    
    return graph

def find_shortest_distance(tree: Dict[str, List[str]], node1: str, node2: str) -> Optional[int]:
    """
    Find the shortest distance between two nodes in a tree using BFS.
    
    Args:
        tree: Dictionary representing the tree (parent -> children)
        node1: First node
        node2: Second node
    
    Returns:
        Shortest distance (number of edges) between the nodes, or None if no path exists
    """
    # Build bidirectional graph
    graph = build_graph(tree)
    
    # Check if nodes exist
    if node1 not in graph or node2 not in graph:
        return None
    
    # If same node, distance is 0
    if node1 == node2:
        return 0
    
    # BFS to find shortest path
    queue = deque([(node1, 0)])  # (current_node, distance)
    visited = {node1}
    
    while queue:
        current, dist = queue.popleft()
        
        # Check neighbors
        for neighbor in graph[current]:
            if neighbor == node2:
                return dist + 1
            
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, dist + 1))
    
    # No path found
    return None
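Because the input is a rooted tree, there is an alternative to running BFS over an undirected copy. Record each node's parent and depth once, then climb from both nodes to their lowest common ancestor (LCA). The distance is depth(u) + depth(v) - 2 * depth(LCA). This is a swapped-in technique, not the notebook's method. The helper names, and the assumption that exactly one node never appears as a child, are ours. After the O(V) preprocessing, each query costs O(h) for tree height h.

```python
from typing import Dict, List, Optional, Tuple


def find_root(tree: Dict[str, List[str]]) -> str:
    """The root is the only key that never appears as someone's child."""
    children = {c for kids in tree.values() for c in kids}
    roots = [node for node in tree if node not in children]
    if len(roots) != 1:
        raise ValueError(f"expected exactly one root, found {roots}")
    return roots[0]


def parent_and_depth(tree: Dict[str, List[str]]) -> Tuple[Dict[str, Optional[str]], Dict[str, int]]:
    """One O(V) pass from the root recording each node's parent and depth."""
    root = find_root(tree)
    parent: Dict[str, Optional[str]] = {root: None}
    depth: Dict[str, int] = {root: 0}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in tree.get(node, []):
            parent[child] = node
            depth[child] = depth[node] + 1
            stack.append(child)
    return parent, depth


def lca_distance(parent, depth, u: str, v: str) -> int:
    """dist(u, v) = depth(u) + depth(v) - 2 * depth(lca(u, v))."""
    a, b = u, v
    # Lift the deeper node until both sit at the same depth
    while depth[a] > depth[b]:
        a = parent[a]
    while depth[b] > depth[a]:
        b = parent[b]
    # Climb in lockstep until the two paths meet at the lowest common ancestor
    while a != b:
        a, b = parent[a], parent[b]
    return depth[u] + depth[v] - 2 * depth[a]


example_tree = {
    'A': ['B', 'C'], 'B': ['D', 'E'], 'C': ['F'], 'D': [],
    'E': ['G', 'H'], 'F': [], 'G': [], 'H': [],
}
parent, depth = parent_and_depth(example_tree)
for u, v in [('D', 'H'), ('F', 'D'), ('A', 'A'), ('G', 'H')]:
    print(u, v, lca_distance(parent, depth, u, v))
```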

In [104]:
# Test the shortest distance function
print("="*60)
print("TEST: Finding Shortest Distances")
print("="*60)

test_cases = [
    ('D', 'H'),  # D to H
    ('A', 'G'),  # A to G
    ('F', 'D'),  # F to D
    ('B', 'C'),  # B to C
    ('A', 'A'),  # Same node
    ('G', 'H'),  # Siblings
]

for node1, node2 in test_cases:
    distance = find_shortest_distance(tree, node1, node2)
    print(f"\nDistance from {node1} to {node2}: {distance}")
    
    # Show path explanation for some cases
    if (node1, node2) == ('D', 'H'):
        print("  Path: D -> B -> E -> H")
    elif (node1, node2) == ('A', 'G'):
        print("  Path: A -> B -> E -> G")
    elif (node1, node2) == ('F', 'D'):
        print("  Path: F -> C -> A -> B -> D")
    elif (node1, node2) == ('B', 'C'):
        print("  Path: B -> A -> C")
    elif (node1, node2) == ('G', 'H'):
        print("  Path: G -> E -> H")

TEST: Finding Shortest Distances

Distance from D to H: 3
  Path: D -> B -> E -> H

Distance from A to G: 3
  Path: A -> B -> E -> G

Distance from F to D: 4
  Path: F -> C -> A -> B -> D

Distance from B to C: 2
  Path: B -> A -> C

Distance from A to A: 0

Distance from G to H: 2
  Path: G -> E -> H


## Follow-up 1: Shortest Distance Between Two Sets (Length 2)

Given two sets `a = [a1, a2]` and `b = [b1, b2]`, find the shortest distance between any pair `(x, y)` where `x ∈ a` and `y ∈ b`.

In [105]:
def shortest_distance_two_sets_length2(tree: Dict[str, List[str]], 
                                       a: List[str], 
                                       b: List[str]) -> Optional[int]:
    """
    Find shortest distance between two sets of nodes (both of length 2).
    
    Args:
        tree: Dictionary representing the tree
        a: List of 2 nodes [a1, a2]
        b: List of 2 nodes [b1, b2]
    
    Returns:
        Minimum distance among all pairs (x, y) where x in a and y in b
    """
    if len(a) != 2 or len(b) != 2:
        raise ValueError("Both sets must have exactly 2 elements")
    
    min_distance = float('inf')
    best_pair = None
    
    # Check all 4 combinations: (a1,b1), (a1,b2), (a2,b1), (a2,b2)
    for node_a in a:
        for node_b in b:
            dist = find_shortest_distance(tree, node_a, node_b)
            if dist is not None and dist < min_distance:
                min_distance = dist
                best_pair = (node_a, node_b)
    
    if min_distance == float('inf'):
        return None
    
    print(f"  Best pair: {best_pair[0]} -> {best_pair[1]}")
    return min_distance

# Test with sets of length 2
print("="*60)
print("FOLLOW-UP 1: Two Sets of Length 2")
print("="*60)

test_cases_sets = [
    (['D', 'F'], ['G', 'H']),  # Leaves D, F vs. leaves G, H (children of E)
    (['A', 'B'], ['F', 'G']),  # Top nodes vs bottom nodes
    (['D', 'E'], ['C', 'F']),  # Left side vs right side
]

for a, b in test_cases_sets:
    print(f"\na = {a}, b = {b}")
    distance = shortest_distance_two_sets_length2(tree, a, b)
    print(f"Shortest distance: {distance}")

FOLLOW-UP 1: Two Sets of Length 2

a = ['D', 'F'], b = ['G', 'H']
  Best pair: D -> G
Shortest distance: 3

a = ['A', 'B'], b = ['F', 'G']
  Best pair: A -> F
Shortest distance: 2

a = ['D', 'E'], b = ['C', 'F']
  Best pair: D -> C
Shortest distance: 3


## Follow-up 2: Shortest Distance Between Arbitrary-Length Sets

Extend to handle arbitrary-length sets `a` and `b`.

In [106]:
def shortest_distance_arbitrary_sets(tree: Dict[str, List[str]], 
                                     a: List[str], 
                                     b: List[str]) -> Optional[int]:
    """
    Find shortest distance between two arbitrary-length sets of nodes.
    
    Args:
        tree: Dictionary representing the tree
        a: List of nodes (arbitrary length)
        b: List of nodes (arbitrary length)
    
    Returns:
        Minimum distance among all pairs (x, y) where x in a and y in b,
        or None if either set is empty or no pair is connected.
        The best pair is printed, not returned.
    """
    if not a or not b:
        return None
    
    min_distance = float('inf')
    best_pair = None
    
    # Check all combinations
    for node_a in a:
        for node_b in b:
            dist = find_shortest_distance(tree, node_a, node_b)
            if dist is not None and dist < min_distance:
                min_distance = dist
                best_pair = (node_a, node_b)
    
    if min_distance == float('inf'):
        return None
    
    print(f"  Best pair: {best_pair[0]} -> {best_pair[1]}")
    return min_distance

# Test with arbitrary-length sets
print("="*60)
print("FOLLOW-UP 2: Arbitrary-Length Sets")
print("="*60)

test_cases_arbitrary = [
    (['D', 'F', 'G'], ['H']),           # 3 nodes vs 1 node
    (['A'], ['D', 'E', 'F', 'G', 'H']), # 1 node vs 5 nodes
    (['D', 'F'], ['G', 'H', 'C']),      # 2 nodes vs 3 nodes
    (['A', 'B', 'C'], ['D', 'E', 'F']), # 3 nodes vs 3 nodes
]

for a, b in test_cases_arbitrary:
    print(f"\na = {a} (size {len(a)})")
    print(f"b = {b} (size {len(b)})")
    distance = shortest_distance_arbitrary_sets(tree, a, b)
    print(f"Shortest distance: {distance}")

FOLLOW-UP 2: Arbitrary-Length Sets

a = ['D', 'F', 'G'] (size 3)
b = ['H'] (size 1)
  Best pair: G -> H
Shortest distance: 2

a = ['A'] (size 1)
b = ['D', 'E', 'F', 'G', 'H'] (size 5)
  Best pair: A -> D
Shortest distance: 2

a = ['D', 'F'] (size 2)
b = ['G', 'H', 'C'] (size 3)
  Best pair: F -> C
Shortest distance: 1

a = ['A', 'B', 'C'] (size 3)
b = ['D', 'E', 'F'] (size 3)
  Best pair: B -> D
Shortest distance: 1
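The pairwise loop above calls `find_shortest_distance` once per pair, and each call rebuilds the graph and runs a fresh BFS. A middle ground is to build the adjacency list once and run one BFS per node in `a`. A single BFS already yields that node's distance to every member of `b`. The sketch below takes this approach, with hypothetical helper names and its own copy of the example tree.

```python
from collections import deque
from typing import Dict, List, Optional


def bfs_distances(graph: Dict[str, List[str]], start: str) -> Dict[str, int]:
    """Distances from start to every reachable node of an undirected graph."""
    dist = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for nb in graph[node]:
            if nb not in dist:
                dist[nb] = dist[node] + 1
                queue.append(nb)
    return dist


def set_distance_per_source(tree: Dict[str, List[str]], a: List[str], b: List[str]) -> Optional[int]:
    # Build the undirected adjacency list once instead of once per pair
    graph: Dict[str, List[str]] = {node: [] for node in tree}
    for parent, children in tree.items():
        for child in children:
            graph.setdefault(child, [])
            graph[parent].append(child)
            graph[child].append(parent)
    best: Optional[int] = None
    for x in a:
        if x not in graph:
            continue
        dist = bfs_distances(graph, x)  # one BFS covers every y in b
        for y in b:
            if y in dist and (best is None or dist[y] < best):
                best = dist[y]
    return best


example_tree = {
    'A': ['B', 'C'], 'B': ['D', 'E'], 'C': ['F'], 'D': [],
    'E': ['G', 'H'], 'F': [], 'G': [], 'H': [],
}
print(set_distance_per_source(example_tree, ['D', 'F'], ['G', 'H', 'C']))
```

This costs O(|a| × (V + E)) instead of O(|a| × |b| × (V + E)), at the price of not recording which pair achieved the minimum.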


## Optimization: Multi-Source BFS

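For large sets, we can avoid checking every pair individually by running a single multi-source BFS. Seed the queue with every node in `a` at distance 0, and stop at the first dequeued node that belongs to `b`.

**Time Complexity Comparison:**
- Naive approach: O(|a| × |b| × (V + E)), where V is the number of vertices and E the number of edges (one graph rebuild plus one BFS per pair)
- Multi-source BFS: O(V + E + |a| + |b|), since every node is enqueued at most once no matter how many sources there are

In a tree E = V - 1, so the multi-source version is linear in the size of the tree, while the naive version grows with the product |a| × |b|.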

In [107]:
def shortest_distance_optimized(tree: Dict[str, List[str]], 
                                a: List[str], 
                                b: List[str]) -> Optional[int]:
    """
    Optimized version using multi-source BFS.
    Start BFS from all nodes in set 'a' simultaneously and find first node in set 'b'.
    
    Args:
        tree: Dictionary representing the tree
        a: List of source nodes
        b: List of target nodes
    
    Returns:
        Minimum distance from any node in a to any node in b
    """
    if not a or not b:
        return None
    
    # Build graph
    graph = build_graph(tree)
    
    # Check for nodes that exist in both sets
    overlap = set(a) & set(b)
    if overlap:
        print(f"  Found overlap: {overlap}")
        return 0
    
    # Multi-source BFS: start from all nodes in 'a'
    queue = deque()
    visited = set()
    b_set = set(b)
    
    # Initialize queue with all nodes from set 'a' at distance 0
    for node in a:
        if node in graph:
            queue.append((node, 0, node))  # (current_node, distance, source_node)
            visited.add(node)
    
    # BFS
    while queue:
        current, dist, source = queue.popleft()
        
        # Check if we reached any node in set b
        if current in b_set:
            print(f"  Best pair: {source} -> {current}")
            return dist
        
        # Explore neighbors
        for neighbor in graph[current]:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, dist + 1, source))
    
    return None

# Test optimized version
print("="*60)
print("OPTIMIZED VERSION: Multi-Source BFS")
print("="*60)

for a, b in test_cases_arbitrary:
    print(f"\na = {a} (size {len(a)})")
    print(f"b = {b} (size {len(b)})")
    distance = shortest_distance_optimized(tree, a, b)
    print(f"Shortest distance: {distance}")

OPTIMIZED VERSION: Multi-Source BFS

a = ['D', 'F', 'G'] (size 3)
b = ['H'] (size 1)
  Best pair: G -> H
Shortest distance: 2

a = ['A'] (size 1)
b = ['D', 'E', 'F', 'G', 'H'] (size 5)
  Best pair: A -> D
Shortest distance: 2

a = ['D', 'F'] (size 2)
b = ['G', 'H', 'C'] (size 3)
  Best pair: F -> C
Shortest distance: 1

a = ['A', 'B', 'C'] (size 3)
b = ['D', 'E', 'F'] (size 3)
  Best pair: B -> D
Shortest distance: 1
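The agreement above covers only four hand-picked cases. A quick randomized cross-check can compare the single multi-source traversal against the pairwise minimum on many random trees. The sketch below is self-contained. It uses a compact multi-source BFS that returns a full distance map rather than the notebook's early-exit version. The random-tree generator (each node `i > 0` attaches to a random earlier node) and all helper names are our own assumptions.

```python
import random
from collections import deque
from typing import Dict, List


def random_tree(n: int, rng: random.Random) -> Dict[str, List[str]]:
    """Random tree on nodes '0'..'n-1': each node i > 0 gets a parent among 0..i-1."""
    tree: Dict[str, List[str]] = {str(i): [] for i in range(n)}
    for i in range(1, n):
        tree[str(rng.randrange(i))].append(str(i))
    return tree


def undirected(tree: Dict[str, List[str]]) -> Dict[str, List[str]]:
    graph: Dict[str, List[str]] = {node: [] for node in tree}
    for p, children in tree.items():
        for c in children:
            graph[p].append(c)
            graph[c].append(p)
    return graph


def bfs_dist(graph: Dict[str, List[str]], sources: List[str]) -> Dict[str, int]:
    """Multi-source BFS: distance from the nearest source to every reachable node."""
    dist = {s: 0 for s in sources}
    queue = deque(sources)
    while queue:
        node = queue.popleft()
        for nb in graph[node]:
            if nb not in dist:
                dist[nb] = dist[node] + 1
                queue.append(nb)
    return dist


def pairwise_min(graph: Dict[str, List[str]], a: List[str], b: List[str]) -> int:
    # Reference answer: a single-source BFS from each x, then min over all (x, y)
    best = None
    for x in a:
        dist = bfs_dist(graph, [x])
        for y in b:
            if best is None or dist[y] < best:
                best = dist[y]
    return best


def multi_source_min(graph: Dict[str, List[str]], a: List[str], b: List[str]) -> int:
    dist = bfs_dist(graph, a)  # one traversal seeded with every node in a
    return min(dist[y] for y in b)


rng = random.Random(42)
for _ in range(200):
    n = rng.randint(1, 30)
    graph = undirected(random_tree(n, rng))
    nodes = list(graph)
    a = rng.sample(nodes, rng.randint(1, n))
    b = rng.sample(nodes, rng.randint(1, n))
    assert pairwise_min(graph, a, b) == multi_source_min(graph, a, b)
print("multi-source BFS matched the pairwise minimum on 200 random trees")
```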


## Summary

We've implemented three approaches to find shortest distances in a tree:

1. **Basic Approach**: Find shortest distance between two individual nodes
   - Uses BFS from one node to another
   - Time: O(V + E)

2. **Naive Set Approach**: Find shortest distance between two sets
   - Checks all pairs (x, y) where x ∈ a, y ∈ b
   - Time: O(|a| × |b| × (V + E))

3. **Optimized Multi-Source BFS**: Efficient version for large sets
   - Starts BFS from all nodes in set 'a' simultaneously
   - Stops when any node in set 'b' is reached
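   - Time: O(V + E + |a| + |b|)
   
**When to use each:**
- Use the basic approach for single node-to-node queries
- The naive set approach is the simplest and is adequate when both sets are small
- Prefer the multi-source BFS for larger sets: it is a single traversal, so it is never asymptotically slower than checking every pair
- For many queries against the same tree, build the adjacency list once and reuse it, since every function above rebuilds it on each call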

