In [11]:
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from torch import distributions as D

SEED = None
rng = np.random.RandomState(SEED)

In [36]:
Player = str

class Game:
    def __init__(self, 
                 player_hit_probs: Dict[Player, float],  # name and likelihood of hit for each player
                 seed: Optional[int] = None,  # seed for the rng for this game
                ):
        
        self.player_hit_probs = player_hit_probs
        self.n_players = len(player_hit_probs)
        self.players = set(player_hit_probs.keys())
        
        # init rng for the game
        self._seed = seed or np.random.randint(int(1e8))
        self.rng = np.random.RandomState(self._seed)
        
        # will all be init'd in self.reset, which we won't call in init to mirror patterns in gym.Env
        self.step_number = None
        self.order: List[Player] = None  # order players go in
        self.player_shots: Dict[Player, Dict[str, int]] = None  # how many hits and misses shot by each player
        self.alive: Dict[Player, bool] = None  # is the player still alive?
        self.num_shot_at: Dict[player, int] = None
        self._curr_player_idx = None  # for keeping track of turn
        
    def inc_player(self) -> Player:
        """move the game to the next player in order that's still alive"""
        prev_idx = self._curr_player_idx
        done = False
        while not done:
            self._curr_player_idx = (self._curr_player_idx + 1) % self.n_players  # inc
            player = self.order[self._curr_player_idx]
            if self.alive[player]:  # stop inc'ing if next player in order is still alive
                done = True
        return player
        
    def reset(self):
        """reset a new game with the same players and true probs of a hit"""
        print(f"player probs for this game:")
        for player, prob in self.player_hit_probs.items():
            print(f"\t{player}: {prob:0.2f}")
        
        # set number of player hits and misses to 0
        # set alive to True for all players
        # set num times shot at to 0 for all players
        self.player_shots = dict()
        self.alive = dict()
        self.num_shot_at = dict()
        for player in self.players:
            self.player_shots[player] = dict(hits=0, misses=0)
            self.alive[player] = True
            self.num_shot_at[player] = 0
            
        # set the order
        self.order = self.rng.permutation(list(self.players))
        print(f"player order for this game: {self.order}")
        self._curr_player_idx = -1
        self.step_number = 0
        
    def player_beta_distr(self, player: Player) -> D.Beta:
        """get the beta distribution for a player given the number 
        of shots they've hit and missed so far"""
        shots = self.player_shots[player]
        output = D.Beta(
            concentration0=shots["hits"] + 1., 
            concentration1=shots["misses"] + 1.)
        return output
        
    def step(self):
        """have the next player take their turn, update game state, and return if game over"""
        self.step_number += 1
        # move to next player
        shooter = self.inc_player()
        
        # determine who to shoot at
        target = self.choose_target(shooter=shooter)
        
        # shoot and see if hit
        hit = self.rng.uniform() < self.player_hit_probs[shooter]
        
        self.num_shot_at[target] += 1
        
        if hit:
            self.player_shots[shooter]["hits"] += 1.
            self.alive[target] = False
        else:
            self.player_shots[shooter]["misses"] += 1.
            
        outcome = "hit" if hit else "missed"
        print(f"{shooter} shot at {target} with prob {self.player_hit_probs[shooter]:0.2f} and {outcome} on step {self.step_number}")
        # determine if the game is done
        done = sum(self.alive.values()) == 1
        if done:
            print(f"player {shooter} wins on turn {self.step_number}!")
        return done
        
        
    def choose_target(self, shooter: Player) -> Player:
        """determine which player a shooter will shoot at"""

        # this is one of many strategies you can put here, including
        # going all the way to trained neural networks implementing learned
        # RL policies.
        #
        # for now, we'll determine who one shoots simply as whoever's inferred
        # to have the best aim.  AKA people greedily always try to take the assumed
        # "best" player out
        # 
        # best aim here will be whoever has the highest probabiliy of hit drawn from a beta
        # distribution after 1000 samples from each beta
        
        # get beta disctibution of p(hit) for each player from their observed shots
        player_betas = {player: self.player_beta_distr(player=player)
                        for player in self.players 
                        if (player != shooter and self.alive[player])}
        # sample probs from the beta distrs
        inferred_probs = {player: beta.sample((1000,)) 
                           for player, beta in player_betas.items()}
        # determine who had the highest prob most often in the 1000 samples from the betas
        df = pd.DataFrame(inferred_probs)
        maxes = df.values.argmax(axis=1)
        idxs, idx_counts = np.unique(maxes, return_counts=True)
        max_idx = idxs[idx_counts.argmax()]
        # shoot at the person who had the highest prob in the most samples
        to_shoot = df.columns[max_idx]
        
        return to_shoot

In [37]:
N = 5
SEED = None

probs = {"player_" + str(k): v for k, v in enumerate(np.linspace(0, 1, N+2)[1:-1])}
game = Game(player_hit_probs=probs, seed=SEED)

game.reset()

done = False
while not done:
    done = game.step()
    if game.step_number > 100:
        break

player probs for this game:
	player_0: 0.17
	player_1: 0.33
	player_2: 0.50
	player_3: 0.67
	player_4: 0.83
player order for this game: ['player_1' 'player_3' 'player_4' 'player_2' 'player_0']
player_1 shot at player_2 with prob 0.33 and missed on step 1
player_3 shot at player_1 with prob 0.67 and missed on step 2
player_4 shot at player_1 with prob 0.83 and hit on step 3
player_2 shot at player_3 with prob 0.50 and missed on step 4
player_0 shot at player_3 with prob 0.17 and missed on step 5
player_3 shot at player_2 with prob 0.67 and missed on step 6
player_4 shot at player_3 with prob 0.83 and hit on step 7
player_2 shot at player_0 with prob 0.50 and hit on step 8
player_4 shot at player_2 with prob 0.83 and missed on step 9
player_2 shot at player_4 with prob 0.50 and hit on step 10
player player_2 wins on turn 10!


In [34]:
np.linspace(0, 1, 7)

array([0.        , 0.16666667, 0.33333333, 0.5       , 0.66666667,
       0.83333333, 1.        ])