In [1]:
import zerosum as zs

from typing import cast
from dataclasses import dataclass, replace
from itertools import islice
import random

In [2]:
@dataclass(slots=True, frozen=True)
class Init:
    means: tuple[float, ...]


@dataclass(slots=True, frozen=True)
class Choice:
    arm: int


@dataclass(slots=True, frozen=True)
class Reward:
    reward: float


Action = Init | Reward | Choice


@dataclass(slots=True, frozen=True)
class InfoSet:
    player: zs.Player
    round: int
    arms: int
    history: tuple[Action, ...]

    def actions(self):
        return tuple(Choice(i) for i in range(self.arms))


@dataclass(slots=True, frozen=True)
class Learn:
    rounds: int
    arms: int

    round: int = 0
    history: tuple[Action, ...] = ()

    @property
    def means(self):
        return cast(Init, self.history[0]).means

    @property
    def terminal(self):
        return self.round >= self.rounds

    def payoff(self, player: zs.Player):
        s0 = islice(self.history, 2, None, 4)
        s1 = islice(self.history, 4, None, 4)

        score = 0
        for a0, a1 in zip(s0, s1):
            a0 = cast(Reward, a0)
            a1 = cast(Reward, a1)
            score += a0.reward - a1.reward

        return score if player == 0 else -score

    @property
    def chance(self):
        return len(self.history) % 2 == 0

    def chances(self):
        raise NotImplementedError
    
    def _init(self):
        means = tuple(random.random() for _ in range(self.arms))
        return Init(means)

    def sample(self) -> Action:
        if len(self.history) == 0:
            return self._init()

        *_, action = self.history
        action = cast(Choice, action)
        if random.random() > self.means[action.arm]:
            return Reward(1)
        return Reward(0)

    @property
    def active(self) -> zs.Player:
        if len(self.history) % 4 == 1:
            return zs.Player(0)
        return zs.Player(1)

    def infoset(self, player: zs.Player):
        return InfoSet(self.active, self.round, self.arms, self.history)

    def apply(self, action: Action):
        round = self.round
        if isinstance(action, Reward):
            round = self.round + 1
        return replace(self, round=round, history=self.history + (action,))


_: zs.Game[Action, InfoSet] = Learn(0, 3)

In [3]:
def game():
    return Learn(3, 3)


impl = zs.ESLCFR(1000)
algo = zs.Algorithm(impl, game)

In [4]:
from tqdm import tqdm


for _ in tqdm(range(1000)):
    algo.once()

100%|██████████| 1000/1000 [00:00<00:00, 1089.03it/s]


In [5]:
len(impl.strategies)

12000

In [6]:
@zs.algebraic
def player(infoset: zs.InfoSet):
    return cast(InfoSet, infoset).player


@zs.algebraic
def ground(infoset: zs.InfoSet):
    return cast(InfoSet, infoset).round


def aggregate(buckets: int):
    @zs.algebraic
    def aggregate(infoset: zs.InfoSet):
        infoset = cast(InfoSet, infoset)
        means = [0.0] * infoset.arms
        ns = [0] * infoset.arms

        for choice, action in zip(infoset.history[1::2], infoset.history[2::2]):
            arm = cast(Choice, choice).arm
            reward = cast(Reward, action).reward

            means[arm] = (reward + means[arm] * ns[arm]) / (ns[arm] + 1)
            ns[arm] += 1
        
        return tuple(round(m * buckets) for m in means)
    return aggregate


@zs.algebraic
def chosen(infoset: zs.InfoSet):
    infoset = cast(InfoSet, infoset)

    arms = [0] * infoset.arms
    for choice in infoset.history[1::2]:
        arm = cast(Choice, choice).arm
        arms[arm] += 1

    return tuple(arms)

def actions(infoset: zs.InfoSet):
    return infoset.actions()


def abstract(buckets: int):
    @zs.abstract(Learn, player * ground * aggregate(buckets) * chosen, actions)
    class Abstraction(Learn):
        ...

    return Abstraction

In [13]:
abstraction = cast(type[Learn], abstract(3))

def game(rounds: int = 5, arms: int = 2):
    return abstraction(rounds, arms)


impl = zs.ESCFRP()
algo = zs.Algorithm(impl, game)

In [15]:
for _ in tqdm(range(1000)):
    algo.once()

100%|██████████| 1000/1000 [00:02<00:00, 463.88it/s]


In [16]:
len(impl.strategies)

68

In [17]:
zs.normalize(impl.strategies)

{(0, 0, (0, 0), (0, 0)): {Choice(arm=0): 0.5019947757363732,
  Choice(arm=1): 0.49800522426362687},
 (1, 1, (3, 0), (1, 0)): {Choice(arm=0): 0.9077851208966068,
  Choice(arm=1): 0.09221487910339317},
 (0, 2, (3, 3), (1, 1)): {Choice(arm=0): 0.8265288012904325,
  Choice(arm=1): 0.17347119870956762},
 (1, 3, (3, 3), (2, 1)): {Choice(arm=0): 0.8227921780938728,
  Choice(arm=1): 0.17720782190612724},
 (0, 4, (3, 3), (2, 2)): {Choice(arm=0): 0.5, Choice(arm=1): 0.5},
 (1, 3, (3, 3), (1, 2)): {Choice(arm=0): 0.1574802924564671,
  Choice(arm=1): 0.842519707543533},
 (1, 1, (0, 3), (0, 1)): {Choice(arm=0): 0.07618676805689943,
  Choice(arm=1): 0.9238132319431006},
 (0, 4, (2, 3), (3, 1)): {Choice(arm=0): 0.5, Choice(arm=1): 0.5},
 (1, 3, (2, 3), (2, 1)): {Choice(arm=0): 0.061903037295513094,
  Choice(arm=1): 0.938096962704487},
 (0, 4, (2, 2), (2, 2)): {Choice(arm=0): 0.5, Choice(arm=1): 0.5},
 (0, 2, (0, 3), (0, 2)): {Choice(arm=0): 0.026911018302259365,
  Choice(arm=1): 0.9730889816977407},
