# Learning to learn

We explore the bandit problem by representing it as an imperfect information game and solving it with a counterfactual regret minimization algorithm.

In [199]:
import zerosum as zs

from typing import cast
from dataclasses import dataclass, replace
from itertools import islice, chain
import random

## The bandit game

The bandit player chooses from a set of arms at each round, and receives a reward from the arm chosen. The reward is drawn from a distribution that depends on the arm chosen. The player's goal is to maximize their reward. The number of rounds is fixed before-hand.

Our framework can currently only handle zero-sum two player games. We therefore study an alternative problem where two players play the game simultaneously. They do not share information. The payoff of each player is adjusted by the payoff of their opponent.

In [200]:
@dataclass(slots=True, frozen=True)
class Init:
    means: tuple[float, ...]


@dataclass(slots=True, frozen=True)
class Choice:
    arm: int


@dataclass(slots=True, frozen=True)
class Reward:
    reward: float


Action = Init | Reward | Choice


@dataclass(slots=True, frozen=True)
class InfoSet:
    player: zs.Player
    round: int
    rounds: int
    arms: int
    history: tuple[Action, ...]

    def actions(self):
        return tuple(Choice(i) for i in range(self.arms))


@dataclass(slots=True, frozen=True)
class Learn:
    rounds: int
    arms: int

    round: int = 0
    history: tuple[Action, ...] = ()

    @property
    def means(self):
        return cast(Init, self.history[0]).means

    @property
    def terminal(self):
        return self.round >= self.rounds

    def payoff(self, player: zs.Player):
        s0 = islice(self.history, 2, None, 4)
        s1 = islice(self.history, 4, None, 4)

        sco = 0
        for a0, a1 in zip(s0, s1):
            a0 = cast(Reward, a0)
            a1 = cast(Reward, a1)
            sco += a0.reward - a1.reward
        
        return sco if player == 0 else -sco

    @property
    def chance(self):
        return len(self.history) % 2 == 0

    def chances(self):
        raise NotImplementedError
    
    def _init(self):
        means = tuple(random.random() for _ in range(self.arms))
        return Init(means)
    
    # overriding this method in a subclass is an easy method to create
    # a bandit game with different reward distributions

    def sample(self) -> Action:
        if len(self.history) == 0:
            return self._init()

        *_, action = self.history
        action = cast(Choice, action)

        if random.random() < self.means[action.arm]:
            return Reward(1)

        return Reward(0)

    @property
    def active(self) -> zs.Player:
        if (len(self.history) - 1) % 4 == 0:
            return zs.Player(0)
        return zs.Player(1)

    def infoset(self, player: zs.Player):
        choices, rewards = self.history[1::4], self.history[2::4]
        if player == 1:
            choices, rewards = self.history[3::4], self.history[4::4]

        history = tuple(chain.from_iterable(zip(choices, rewards)))
        return InfoSet(player, self.round, self.rounds, self.arms, history)

    def apply(self, action: Action):
        round = self.round
        if len(self.history) > 1 and len(self.history) % 4 == 0:
            round += 1

        return replace(self, round=round, history=self.history + (action,))


_: zs.Game[Action, InfoSet] = Learn(0, 3)

We can immediately attempt to find a Nash equilibrium for this game, but we will find the number of infosets is quite large compared to the essential complexity of the game.

In [201]:
def game():
    return Learn(5, 2)


impl = zs.ESCFR()
algo = zs.Algorithm(impl, game)

In [202]:
g = game()

g = g.apply(g.sample())
g = g.apply(Choice(0))
g = g.apply(g.sample())
g = g.apply(Choice(0))
g = g.apply(g.sample())
g

Learn(rounds=5, arms=2, round=1, history=(Init(means=(0.30513534175980606, 0.8990895618169512)), Choice(arm=0), Reward(reward=1), Choice(arm=0), Reward(reward=1)))

In [204]:
from tqdm import tqdm


for _ in tqdm(range(500)):
    algo.once()

100%|██████████| 500/500 [00:02<00:00, 231.70it/s]


In [205]:
len(impl.strategies)

682

In [253]:
def evaluate(n, game, strategies):
    payoff = 0
    for _ in tqdm(range(n)):
        g = game()

        bef = False
        while not g.terminal:
            if g.chance:
                action = g.sample()
                if bef and isinstance(action, Reward):
                    payoff += action.reward
                    bef = False

                g = g.apply(action)
                continue

            infoset = g.infoset(g.active)

            action = infoset.actions()[0]
            if infoset in impl.strategies:
                s = impl.strategies[infoset]

                try:
                    # action, = random.choices(tuple(s), weights=tuple(s.values()))
                    action = max(s, key=s.get)
                except ValueError:
                    action = random.choice(infoset.actions())

            bef = (g.active == 0)
            g = g.apply(action)

    return payoff / n

In [207]:
def subg():
    return game().apply(Init((0.2, 0.8)))

In [208]:
evaluate(1000, subg, impl.strategies)

100%|██████████| 1000/1000 [00:00<00:00, 6579.60it/s]


2.932

We can construct an abstraction of the game by capturing only the essential information. The player how many times each arm was pulled and their average payoff. The latter value is bucketed.

In [209]:
@dataclass(slots=True, frozen=True)
class Player:
    _: int


@zs.algebraic
def player(infoset: zs.InfoSet):
    infoset = cast(InfoSet, infoset)
    return Player(infoset.player)


@dataclass(slots=True, frozen=True)
class Round:
    _: int


@zs.algebraic
def whichround(infoset: zs.InfoSet):
    infoset = cast(InfoSet, infoset)
    return Round(infoset.round)


@dataclass(slots=True, frozen=True)
class Scores:
    scores: tuple[int, ...]


@zs.algebraic
def aggregate(infoset: zs.InfoSet):
    infoset = cast(InfoSet, infoset)
    ns = [0] * infoset.arms

    for choice, action in zip(infoset.history[::2], infoset.history[1::2]):
        arm = cast(Choice, choice).arm
        reward = cast(Reward, action).reward
        ns[arm] += int(reward)
    
    return Scores(tuple(ns))


@dataclass(slots=True, frozen=True)
class Chosen:
    arms: tuple[int, ...]


@zs.algebraic
def chosen(infoset: zs.InfoSet):
    infoset = cast(InfoSet, infoset)

    arms = [0] * infoset.arms
    for choice in infoset.history[::2]:
        arm = cast(Choice, choice).arm
        arms[arm] += 1
    
    return Chosen(tuple(arms))


def actions(infoset: zs.InfoSet):
    return infoset.actions()


def abstract(buckets: int):
    @zs.abstract(Learn, player * aggregate * chosen * whichround, actions)
    class Abstraction(Learn):
        ...

    return Abstraction

When the game tree is deep, even if the abstraction is significantly smaller, External Sampling MCCFR does not perform well. Indeed, each iteration is performed in $O(b^{d / 2})$ where $b$ is the branching factor and $d$ is the depth of the game tree. This is because all of the considered player's actions are explored. Outcome Sampling CFR has $O(d)$ iterations. If the abstraction used buckets together infosets at different depths, this can lead to much faster convergence.

In [216]:
abstraction = cast(type[Learn], abstract(4))

def game(rounds: int = 10, arms: int = 2):
    return abstraction(rounds, arms)


impl = zs.ESLCFR(1000)
algo = zs.Algorithm(impl, game)

In [217]:
g = game()

g = g.apply(g.sample())
g = g.apply(Choice(0))
g = g.apply(g.sample())
g = g.apply(Choice(0))
g = g.apply(g.sample())
g.infoset(0)

(Player(_=0), Scores(scores=(1, 0)), Chosen(arms=(1, 0)), Round(_=1))

In [255]:
try:
    for _ in tqdm(range(1000)):
        algo.once()
except KeyboardInterrupt:
    pass

  7%|▋         | 69/1000 [00:40<09:05,  1.71it/s]


Notice there are significantly fewer information sets !

In [235]:
len(impl.strategies)

1430

In [256]:
def subg():
    return game().apply(Init((0.2, 0.8)))

evaluate(1000, game, impl.strategies)

100%|██████████| 1000/1000 [00:01<00:00, 959.83it/s]


5.941

Thomspson sampling remains stronger.

In [239]:
import numpy as np


def thompson(means, n):
    k = len(means)
    priors = np.ones((k, 2))

    sampled = np.empty(k)
    payoff = 0

    for _ in range(n):
        for i in range(k):
            sampled[i] = np.random.beta(*priors[i])
        
        arm = np.argmax(sampled)
        r = np.random.choice((0, 1), p=(1 - means[arm], means[arm]))
        priors[arm] += (r, 1 - r)
        payoff += r
    
    return payoff

In [252]:
p = 0
for _ in range(1000):
    p += thompson((0.2, 0.8), 10)

p / 1000

6.609