# Learning to learn

We explore the bandit problem by representing it as an imperfect information game and solving it with a counterfactual regret minimization algorithm.

In [19]:
import zerosum as zs

from typing import cast
from dataclasses import dataclass, replace
from itertools import islice, chain
import random

## The bandit game

The bandit player chooses from a set of arms at each round, and receives a reward from the arm chosen. The reward is drawn from a distribution that depends on the arm chosen. The player's goal is to maximize their reward. The number of rounds is fixed before-hand.

Our framework can currently only handle zero-sum two player games. We therefore study an alternative problem where two players play the game simultaneously. They do not share information. The payoff of each player is adjusted by the payoff of their opponent.

In [75]:
@dataclass(slots=True, frozen=True)
class Init:
    means: tuple[float, ...]


@dataclass(slots=True, frozen=True)
class Choice:
    arm: int


@dataclass(slots=True, frozen=True)
class Reward:
    reward: float


Action = Init | Reward | Choice


@dataclass(slots=True, frozen=True)
class InfoSet:
    round: int
    arms: int
    history: tuple[Action, ...]

    def actions(self):
        return tuple(Choice(i) for i in range(self.arms))


@dataclass(slots=True, frozen=True)
class Learn:
    rounds: int
    arms: int

    round: int = 0
    history: tuple[Action, ...] = ()

    @property
    def means(self):
        return cast(Init, self.history[0]).means

    @property
    def terminal(self):
        return self.round >= self.rounds

    def payoff(self, player: zs.Player):
        s0 = islice(self.history, 2, None, 4)
        s1 = islice(self.history, 4, None, 4)

        score = 0
        for a0, a1 in zip(s0, s1):
            a0 = cast(Reward, a0)
            a1 = cast(Reward, a1)
            score += a0.reward - a1.reward

        return score if player == 0 else -score

    @property
    def chance(self):
        return len(self.history) % 2 == 0

    def chances(self):
        raise NotImplementedError
    
    def _init(self):
        means = tuple(random.random() for _ in range(self.arms))
        return Init(means)
    
    # overriding this method in a subclass is an easy method to create
    # a bandit game with different reward distributions

    def sample(self) -> Action:
        if len(self.history) == 0:
            return self._init()

        *_, action = self.history
        action = cast(Choice, action)
        if random.random() > self.means[action.arm]:
            return Reward(1)
        return Reward(0)

    @property
    def active(self) -> zs.Player:
        if len(self.history) % 4 == 1:
            return zs.Player(0)
        return zs.Player(1)

    def infoset(self, player: zs.Player):
        choices, rewards = self.history[1::4], self.history[2::4]
        if player == 1:
            choices, rewards = self.history[3::4], self.history[4::4]

        history = tuple(chain.from_iterable(zip(choices, rewards)))
        return InfoSet(self.round, self.arms, history)

    def apply(self, action: Action):
        round = self.round
        if len(self.history) > 1 and (len(self.history) - 1) % 4 == 0:
            round += 1

        return replace(self, round=round, history=self.history + (action,))


_: zs.Game[Action, InfoSet] = Learn(0, 3)

We can immediately attempt to find a Nash equilibrium for this game, but we will find the number of infosets is quite large compared to the essential complexity of the game.

In [76]:
def game():
    return Learn(5, 2)


impl = zs.ESLCFR(1000)
algo = zs.Algorithm(impl, game)

In [77]:
from tqdm import tqdm


for _ in tqdm(range(100)):
    algo.once()

100%|██████████| 100/100 [00:00<00:00, 145.60it/s]


In [78]:
len(impl.strategies)

1624

We can construct an abstraction of the game by capturing only the essential information. The player how many times each arm was pulled and their average payoff. The latter value is bucketed. The player also knows which round they are playing (it is implicit in the last information).

In [89]:
def aggregate(buckets: int):
    @zs.algebraic
    def aggregate(infoset: zs.InfoSet):
        infoset = cast(InfoSet, infoset)
        means = [0.0] * infoset.arms
        ns = [0] * infoset.arms

        for choice, action in zip(infoset.history[::2], infoset.history[1::2]):
            arm = cast(Choice, choice).arm
            reward = cast(Reward, action).reward

            means[arm] = (reward + means[arm] * ns[arm]) / (ns[arm] + 1)
            ns[arm] += 1
        
        return tuple(round(m * buckets) for m in means)
    return aggregate


@zs.algebraic
def chosen(infoset: zs.InfoSet):
    infoset = cast(InfoSet, infoset)

    arms = [0] * infoset.arms
    for choice in infoset.history[::2]:
        arm = cast(Choice, choice).arm
        arms[arm] += 1

    return tuple(arms)

def actions(infoset: zs.InfoSet):
    return infoset.actions()


def abstract(buckets: int):
    @zs.abstract(Learn, aggregate(buckets) * chosen, actions)
    class Abstraction(Learn):
        ...

    return Abstraction

In [90]:
abstraction = cast(type[Learn], abstract(3))

def game(rounds: int = 3, arms: int = 2):
    return abstraction(rounds, arms)


impl = zs.ESCFRP()
algo = zs.Algorithm(impl, game)

In [94]:
for _ in tqdm(range(5000)):
    algo.once()

100%|██████████| 5000/5000 [00:16<00:00, 306.53it/s]


Notice there are significantly fewer information sets !

In [95]:
len(impl.strategies)

35

In [97]:
zs.normalize(impl.strategies)

{((0, 0), (0, 0)): {Choice(arm=0): 0.5671479099346806,
  Choice(arm=1): 0.4328520900653195},
 ((0, 0), (1, 0)): {Choice(arm=0): 0.04877598623379356,
  Choice(arm=1): 0.9512240137662065},
 ((3, 0), (1, 0)): {Choice(arm=0): 0.9632993137802633,
  Choice(arm=1): 0.03670068621973664},
 ((0, 0), (2, 0)): {Choice(arm=0): 0.02939603456318818,
  Choice(arm=1): 0.9706039654368118},
 ((2, 0), (2, 0)): {Choice(arm=0): 0.3052337896421444,
  Choice(arm=1): 0.6947662103578556},
 ((0, 0), (3, 0)): {Choice(arm=0): 0.5, Choice(arm=1): 0.5},
 ((0, 0), (2, 1)): {Choice(arm=0): 0.5, Choice(arm=1): 0.5},
 ((0, 0), (1, 1)): {Choice(arm=0): 0.34235200386273057,
  Choice(arm=1): 0.6576479961372695},
 ((3, 0), (1, 1)): {Choice(arm=0): 0.9850118269599332,
  Choice(arm=1): 0.014988173040066853},
 ((0, 2), (1, 2)): {Choice(arm=0): 0.5, Choice(arm=1): 0.5},
 ((0, 0), (0, 1)): {Choice(arm=0): 0.9536096268503257,
  Choice(arm=1): 0.04639037314967437},
 ((0, 0), (1, 2)): {Choice(arm=0): 0.5, Choice(arm=1): 0.5},
 ((0,