# Learning to learn

We explore the bandit problem by representing it as an imperfect information game and solving it with a counterfactual regret minimization algorithm.

In [3]:
import zerosum as zs

from typing import cast
from dataclasses import dataclass, replace
from itertools import islice, chain
import random

## The bandit game

The bandit player chooses from a set of arms at each round, and receives a reward from the arm chosen. The reward is drawn from a distribution that depends on the arm chosen. The player's goal is to maximize their reward. The number of rounds is fixed beforehand.

Our framework can currently only handle two-player zero-sum games. We therefore study an alternative problem where two players play the game simultaneously. They do not share information. The payoff of each player is adjusted by the payoff of their opponent.

In [98]:
@dataclass(slots=True, frozen=True)
class Init:
    means: tuple[float, ...]


@dataclass(slots=True, frozen=True)
class Choice:
    arm: int


@dataclass(slots=True, frozen=True)
class Reward:
    reward: float


# Any event that can appear in a game history: the opening draw of the arm
# means, a reward handed out by chance, or a player's arm choice.
Action = Init | Reward | Choice


@dataclass(frozen=True, slots=True)
class InfoSet:
    """What a single player knows at the moment they must pull an arm."""

    # Round counter of the game state this view was taken from.
    round: int
    # Number of arms the player can choose from.
    arms: int
    # The actions visible to the player, oldest first.
    history: tuple[Action, ...]

    def actions(self):
        """Return one `Choice` per arm, ordered by arm index."""
        return tuple(map(Choice, range(self.arms)))


@dataclass(frozen=True, slots=True)
class Learn:
    """Two-player bandit game in which each player pulls arms on their own.

    The history is a flat tuple of actions. Position 0 holds the opening
    `Init` chance draw; after that every round appends four entries in order:
    player 0's `Choice`, player 0's `Reward`, player 1's `Choice` and
    player 1's `Reward`. Instances are immutable: `apply` returns the
    successor state.

    Subclasses can override `_init` or `sample` to obtain a bandit game with
    different reward distributions.
    """

    # Total number of rounds to play.
    rounds: int
    # Number of arms each player chooses from.
    arms: int

    # Number of rounds completed so far.
    round: int = 0
    # Every action applied so far, oldest first.
    history: tuple[Action, ...] = ()

    @property
    def means(self):
        """Per-arm means stored in the opening `Init` action."""
        opening = cast(Init, self.history[0])
        return opening.means

    @property
    def terminal(self):
        """Whether all `rounds` rounds have been completed."""
        return self.rounds <= self.round

    def payoff(self, player: zs.Player):
        """Return the summed rewards collected by `player`.

        Only rounds in which both players have received their reward are
        counted. The value is the player's own total; it is not offset by the
        opponent's total.
        """
        rewards0 = islice(self.history, 2, None, 4)
        rewards1 = islice(self.history, 4, None, 4)

        totals = [0, 0]
        # zip() stops at the shorter stream, so a round in which only
        # player 0 has been rewarded so far is left out for both players.
        for r0, r1 in zip(rewards0, rewards1):
            totals[0] += cast(Reward, r0).reward
            totals[1] += cast(Reward, r1).reward

        if player == 0:
            return totals[0]
        return totals[1]

    @property
    def chance(self):
        """Whether the next action is a chance event (even history length)."""
        played = len(self.history)
        return played % 2 == 0

    def chances(self):
        """Not implemented; chance actions are only available via `sample`."""
        raise NotImplementedError

    def _init(self):
        """Draw the opening action: one `random.random()` mean per arm."""
        return Init(tuple(random.random() for _ in range(self.arms)))

    def sample(self) -> Action:
        """Draw the next chance action.

        With an empty history this is the opening `Init`. Otherwise the last
        history entry is treated as a `Choice`, and a reward of 1 is returned
        with probability equal to that arm's mean, else a reward of 0.
        """
        if not self.history:
            return self._init()

        pulled = cast(Choice, self.history[-1])
        success = random.random() < self.means[pulled.arm]
        return Reward(1) if success else Reward(0)

    @property
    def active(self) -> zs.Player:
        """Player to move: 0 at history lengths 1, 5, 9, ..., otherwise 1."""
        seat = 0 if len(self.history) % 4 == 1 else 1
        return zs.Player(seat)

    def infoset(self, player: zs.Player):
        """Return `player`'s view: their own choices and rewards, interleaved."""
        # Player 0's choice/reward pairs start at history index 1, player 1's
        # at index 3; both repeat every four entries.
        first = 3 if player == 1 else 1
        choices = self.history[first::4]
        rewards = self.history[first + 1::4]

        own = tuple(chain.from_iterable(zip(choices, rewards)))
        return InfoSet(self.round, self.arms, own)

    def apply(self, action: Action):
        """Return the state reached by appending `action` to the history.

        The round counter advances when the current history length is a
        positive multiple of four, i.e. when the appended action is
        player 1's reward, which closes the round.
        """
        played = len(self.history)
        closes_round = played > 1 and played % 4 == 0
        return replace(
            self,
            round=self.round + 1 if closes_round else self.round,
            history=self.history + (action,),
        )


# Binds a throwaway instance under the zs.Game annotation, presumably so that
# a static type checker verifies Learn satisfies that interface.
_: zs.Game[Action, InfoSet] = Learn(0, 3)

We can immediately attempt to find a Nash equilibrium for this game, but we will find the number of infosets is quite large compared to the essential complexity of the game.

In [99]:
def game():
    """Build a fresh 10-round, 2-armed `Learn` game."""
    return Learn(rounds=10, arms=2)


# Solver instance and the driver that runs it on games produced by `game`.
# NOTE(review): the meaning of OSCFR's 0.1 argument is not visible here;
# presumably an exploration rate. Confirm against the zerosum documentation.
impl = zs.OSCFR(0.1)
algo = zs.Algorithm(impl, game)

In [100]:
from tqdm import tqdm


# Run 1000 iterations of the algorithm, showing a progress bar.
for _ in tqdm(range(1000)):
    algo.once()

100%|██████████| 1000/1000 [00:01<00:00, 543.77it/s]


In [101]:
# Number of entries held in the solver's strategy table.
len(impl.strategies)

7730

We can construct an abstraction of the game by capturing only the essential information. The player only keeps track of how many times each arm was pulled and of each arm's average payoff. The latter value is bucketed.

In [102]:
from fractions import Fraction as F


@dataclass(slots=True, frozen=True)
class Scores:
    scores: tuple[F, ...]


def aggregate(buckets: int):
    """Build the abstraction component that keeps bucketed per-arm mean rewards.

    The returned function reads an infoset's history as alternating
    choice/reward pairs, maintains a running mean reward for every arm
    (0.0 for arms never pulled) and snaps each mean to a fraction with
    denominator `buckets` using `round()`.
    """

    @zs.algebraic
    def aggregate(infoset: zs.InfoSet):
        view = cast(InfoSet, infoset)
        averages = [0.0] * view.arms
        pulls = [0] * view.arms

        # Walk the history two entries at a time: (choice, reward) pairs.
        events = iter(view.history)
        for choice, outcome in zip(events, events):
            arm = cast(Choice, choice).arm
            reward = cast(Reward, outcome).reward

            # Incremental update of the arm's running mean.
            seen = pulls[arm]
            averages[arm] = (reward + averages[arm] * seen) / (seen + 1)
            pulls[arm] = seen + 1

        return Scores(
            tuple(F(round(avg * buckets), buckets) for avg in averages)
        )

    return aggregate


@dataclass(slots=True, frozen=True)
class Chosen:
    arms: tuple[F, ...]


def chosen(buckets: int):
    """Build the abstraction component that keeps bucketed per-arm pull shares.

    The returned function adds `1 / infoset.round` to an arm's share for
    every time that arm appears as a choice in the infoset's history, then
    snaps each share to a fraction with denominator `buckets` using
    `round()`.
    """

    @zs.algebraic
    def chosen(infoset: zs.InfoSet):
        view = cast(InfoSet, infoset)
        shares = [0.0] * view.arms

        # Even positions of the history hold the player's choices.
        picks = view.history[::2]
        for pick in picks:
            # The division is deliberately kept inside the loop: when the
            # history is empty it never runs, so a round of 0 cannot raise.
            shares[cast(Choice, pick).arm] += 1 / view.round

        return Chosen(
            tuple(F(round(share * buckets), buckets) for share in shares)
        )

    return chosen

def actions(infoset: zs.InfoSet):
    """Return the actions available at `infoset` by calling its `actions()` method."""
    return infoset.actions()


def abstract(buckets: int):
    """Return a `Learn` subclass wrapped by `zs.abstract`.

    The abstraction passed to `zs.abstract` is the product of the
    `aggregate` and `chosen` components, both built with the same number of
    buckets; `actions` is passed as the final argument.
    """
    projection = aggregate(buckets) * chosen(buckets)

    @zs.abstract(Learn, projection, actions)
    class Abstraction(Learn):
        # No overrides: all game logic is inherited from Learn.
        pass

    return Abstraction

When the game tree is deep, External Sampling MCCFR does not perform well, even if the abstraction is significantly smaller. Indeed, each iteration runs in $O(b^{d / 2})$ time, where $b$ is the branching factor and $d$ is the depth of the game tree, because all of the considered player's actions are explored. Outcome Sampling MCCFR, by contrast, runs each iteration in $O(d)$ time. If the abstraction buckets together infosets at different depths, this can lead to much faster convergence.

In [103]:
# Game class produced by `abstract` with 10 buckets; the cast makes type
# checkers treat it as a Learn subclass.
abstraction = cast(type[Learn], abstract(10))

def game(rounds: int = 10, arms: int = 2):
    """Build a fresh abstracted game with the given number of rounds and arms."""
    return abstraction(rounds, arms)


# Fresh solver and driver, this time running on the abstracted game.
impl = zs.OSCFR(0.1)
algo = zs.Algorithm(impl, game)

In [104]:
# Train for up to 5000 iterations. Interrupting the kernel ends the loop
# early without raising, so the following cells can still be run.
try:
    for _ in tqdm(range(5000)):
        algo.once()
except KeyboardInterrupt:
    pass

100%|██████████| 1000/1000 [00:08<00:00, 113.01it/s]


Notice there are significantly fewer information sets! The algorithm will run faster but still takes some time to explore all information sets adequately. MCCFR learns to play relevant situations better than irrelevant ones, as it won't often explore actions with low value.

In [105]:
# Estimate player 0's average payoff over n simulated games in which both
# players pick the highest-valued action of the learned strategy.
n = 1000
payoff = 0

for _ in range(n):
    g = game()
    # Fix the arm means by hand instead of sampling them: 0 for arm 0 and
    # 0.1 for arm 1.
    g = g.apply(Init((0, 0.1)))

    while not g.terminal:
        # Chance nodes (reward draws) are resolved by sampling.
        if g.chance:
            g = g.apply(g.sample())
            continue

        infoset = g.infoset(g.active)

        # Default to the first action when the solver holds no strategy for
        # this infoset.
        action = infoset.actions()[0]
        if infoset in impl.strategies:
            s = impl.strategies[infoset]
            # Pick the action whose entry in the stored strategy is largest.
            action = max(s, key=s.__getitem__)

        g = g.apply(action)

    payoff += g.payoff(0)

print(payoff / n)

0.347
