Skip to content
This repository has been archived by the owner on Apr 4, 2023. It is now read-only.

Commit

Permalink
Update exception scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbrinkman committed Jun 5, 2018
1 parent 4417bad commit e449984
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 29 deletions.
3 changes: 1 addition & 2 deletions egta/innerloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ async def add_restriction(rest):
return # already explored
if agame.is_pure_restriction(rest):
# Short circuit for pure restriction
await add_deviations(rest, rest.astype(float), init_role_dev)
return
return await add_deviations(rest, rest.astype(float), init_role_dev)
data = await agame.get_restricted_game(rest)
reqa = await loop.run_in_executor(executor, functools.partial(
nash.mixed_nash, data, regret_thresh=regret_thresh,
Expand Down
31 changes: 4 additions & 27 deletions test/test_innerloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ async def test_innerloop_by_role_simple(players, strats):
async def test_innerloop_failures(players, strats, count, when):
"""Test that inner loop handles exceptions during scheduling"""
game = gamegen.game(players, strats)
sched = ExceptionScheduler(game, count, when)
sgame = schedgame.schedgame(sched)
with pytest.raises(SchedulerException):
sched = gamesched.gamesched(game)
esched = tu.ExceptionScheduler(sched, count, when)
sgame = schedgame.schedgame(esched)
with pytest.raises(tu.SchedulerException):
await innerloop.inner_loop(sgame, restricted_game_size=5)


Expand Down Expand Up @@ -163,27 +164,3 @@ async def test_at_least_one(players, strats, _):
game = gamegen.game(players, strats)
eqa = await innerloop.inner_loop(asyncgame.wrap(game), at_least_one=True)
assert eqa.size


class SchedulerException(Exception):
"""Exception to be thrown by ExceptionScheduler"""
pass


class ExceptionScheduler(gamesched._RsGameScheduler): # pylint: disable=protected-access
"""Scheduler that allows triggering exeptions on command"""

def __init__(self, game, error_after, call_type):
super().__init__(game)
self._calls = 0
self._error_after = error_after
self._call_type = call_type

async def sample_payoffs(self, profile):
self._calls += 1
if self._error_after <= self._calls and self._call_type == 'pre':
raise SchedulerException
pay = await super().sample_payoffs(profile)
if self._error_after <= self._calls and self._call_type == 'post':
raise SchedulerException
return pay
20 changes: 20 additions & 0 deletions test/test_schedgame.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Tests for scheduler games"""
import asyncio
import pytest
import numpy as np
from gameanalysis import gamegen

from egta import gamesched
from egta import schedgame
from test import utils # pylint: disable=wrong-import-order


SIZES = [
Expand Down Expand Up @@ -106,3 +108,21 @@ async def test_random_normalize(players, strats, _):
ngame = rgame.normalize()
assert np.all(ngame.payoffs() >= -1e-7)
assert np.all(ngame.payoffs() <= 1 + 1e-7)


@pytest.mark.asyncio
@pytest.mark.parametrize('players,strats', SIZES)
@pytest.mark.parametrize('when', ['pre', 'post'])
@pytest.mark.parametrize('_', range(20))
async def test_exception(players, strats, when, _):
"""Test that exceptions are raised appropriately"""
game = gamegen.samplegame(players, strats)
sched = gamesched.samplegamesched(game)
esched = utils.ExceptionScheduler(sched, 10, when)
sgame = schedgame.schedgame(esched)
rests = np.concatenate([
game.random_restrictions(3),
np.ones((1, game.num_strats), bool)])
with pytest.raises(utils.SchedulerException):
await asyncio.gather(*[
sgame.get_restricted_game(rest) for rest in rests])
27 changes: 27 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module for test utilities"""
from egta import profsched


GAMES = [
Expand All @@ -10,3 +11,29 @@
([3, 2], [2, 3]),
([1, 1, 1], [2, 2, 2]),
]


class SchedulerException(Exception):
"""Exception to be thrown by ExceptionScheduler"""
pass


class ExceptionScheduler(profsched._Scheduler): # pylint: disable=protected-access
"""Scheduler that allows triggering exeptions on command"""

def __init__(self, base, error_after, call_type):
super().__init__(
base.role_names, base.strat_names, base.num_role_players)
self._base = base
self._calls = 0
self._error_after = error_after
self._call_type = call_type

async def sample_payoffs(self, profile):
self._calls += 1
if self._error_after <= self._calls and self._call_type == 'pre':
raise SchedulerException
pay = await self._base.sample_payoffs(profile)
if self._error_after <= self._calls and self._call_type == 'post':
raise SchedulerException
return pay

0 comments on commit e449984

Please sign in to comment.