# public coreyabshire /tron

### Subversion checkout URL

You can clone with HTTPS or Subversion.

Added AIMA games library, with associated utils.

commit 5c6d73506ea0847a2c140581be874658073ca383 1 parent 6267b35
authored

Showing 2 changed files with 1,192 additions and 0 deletions.

1. games.py
2. utils.py
288  games.py
 ... ... @@ -0,0 +1,288 @@ 1 +"""Games, or Adversarial Search. (Chapters 6) 2 + 3 +""" 4 + 5 +from utils import * 6 +import random  7 + 8 +#______________________________________________________________________________ 9 +# Minimax Search 10 + 11 +def minimax_decision(state, game): 12 + """Given a state in a game, calculate the best move by searching 13 + forward all the way to the terminal states. [Fig. 6.4]""" 14 + 15 + player = game.to_move(state) 16 + 17 + def max_value(state): 18 + if game.terminal_test(state): 19 + return game.utility(state, player) 20 + v = -infinity 21 + for (a, s) in game.successors(state): 22 + v = max(v, min_value(s)) 23 + return v 24 + 25 + def min_value(state): 26 + if game.terminal_test(state): 27 + return game.utility(state, player) 28 + v = infinity 29 + for (a, s) in game.successors(state): 30 + v = min(v, max_value(s)) 31 + return v 32 + 33 + # Body of minimax_decision starts here: 34 + action, state = argmax(game.successors(state), 35 + lambda ((a, s)): min_value(s)) 36 + return action 37 + 38 + 39 +#______________________________________________________________________________ 40 +  41 +def alphabeta_full_search(state, game): 42 + """Search game to determine best action; use alpha-beta pruning. 43 + As in [Fig. 6.7], this version searches all the way to the leaves.""" 44 + 45 + player = game.to_move(state) 46 + 47 + def max_value(state, alpha, beta): 48 + if game.terminal_test(state): 49 + return game.utility(state, player) 50 + v = -infinity 51 + for (a, s) in game.successors(state): 52 + v = max(v, min_value(s, alpha, beta)) 53 + if v >= beta: 54 + return v 55 + alpha = max(alpha, v) 56 + return v 57 + 58 + def min_value(state, alpha, beta): 59 + if game.terminal_test(state): 60 + return game.utility(state, player) 61 + v = infinity 62 + for (a, s) in game.successors(state): 63 + v = min(v, max_value(s, alpha, beta)) 64 + if v <= alpha: 65 + return v 66 + beta = min(beta, v) 67 + return v 68 + 69 + # Body of alphabeta_search starts here: 70 + action, state = argmax(game.successors(state), 71 + lambda ((a, s)): min_value(s, -infinity, infinity)) 72 + return action 73 + 74 +def alphabeta_search(state, game, d=4, cutoff_test=None, eval_fn=None): 75 + """Search game to determine best action; use alpha-beta pruning. 76 + This version cuts off search and uses an evaluation function.""" 77 + 78 + player = game.to_move(state) 79 + 80 + def max_value(state, alpha, beta, depth): 81 + if cutoff_test(state, depth): 82 + return eval_fn(state) 83 + v = -infinity 84 + for (a, s) in game.successors(state): 85 + v = max(v, min_value(s, alpha, beta, depth+1)) 86 + if v >= beta: 87 + return v 88 + alpha = max(alpha, v) 89 + return v 90 + 91 + def min_value(state, alpha, beta, depth): 92 + if cutoff_test(state, depth): 93 + return eval_fn(state) 94 + v = infinity 95 + for (a, s) in game.successors(state): 96 + v = min(v, max_value(s, alpha, beta, depth+1)) 97 + if v <= alpha: 98 + return v 99 + beta = min(beta, v) 100 + return v 101 + 102 + # Body of alphabeta_search starts here: 103 + # The default test cuts off at depth d or at a terminal state 104 + cutoff_test = (cutoff_test or 105 + (lambda state,depth: depth>d or game.terminal_test(state))) 106 + eval_fn = eval_fn or (lambda state: game.utility(state, player)) 107 + action, state = argmax(game.successors(state), 108 + lambda ((a, s)): min_value(s, -infinity, infinity, 0)) 109 + return action 110 + 111 +#______________________________________________________________________________ 112 +# Players for Games 113 + 114 +def query_player(game, state): 115 + "Make a move by querying standard input." 116 + game.display(state) 117 + return num_or_str(raw_input('Your move? ')) 118 + 119 +def random_player(game, state): 120 + "A player that chooses a legal move at random." 121 + return random.choice(game.legal_moves(state)) 122 + 123 +def alphabeta_player(game, state): 124 + return alphabeta_search(state, game) 125 + 126 +def play_game(game, *players): 127 + "Play an n-person, move-alternating game." 128 + state = game.initial 129 + while True: 130 + for player in players: 131 + move = player(game, state) 132 + state = game.make_move(move, state) 133 + game.display(state) 134 + print 135 + if game.terminal_test(state): 136 + return game.utility(state, players[0]) 137 + 138 +#______________________________________________________________________________ 139 +# Some Sample Games 140 + 141 +class Game: 142 + """A game is similar to a problem, but it has a utility for each 143 + state and a terminal test instead of a path cost and a goal 144 + test. To create a game, subclass this class and implement 145 + legal_moves, make_move, utility, and terminal_test. You may 146 + override display and successors or you can inherit their default 147 + methods. You will also need to set the .initial attribute to the 148 + initial state; this can be done in the constructor.""" 149 + 150 + def legal_moves(self, state): 151 + "Return a list of the allowable moves at this point." 152 + abstract 153 + 154 + def make_move(self, move, state): 155 + "Return the state that results from making a move from a state." 156 + abstract 157 +  158 + def utility(self, state, player): 159 + "Return the value of this final state to player." 160 + abstract 161 + 162 + def terminal_test(self, state): 163 + "Return True if this is a final state for the game." 164 + return not self.legal_moves(state) 165 + 166 + def to_move(self, state): 167 + "Return the player whose move it is in this state." 168 + return state.to_move 169 + 170 + def display(self, state): 171 + "Print or otherwise display the state." 172 + print state 173 + 174 + def successors(self, state): 175 + "Return a list of legal (move, state) pairs." 176 + return [(move, self.make_move(move, state)) 177 + for move in self.legal_moves(state)] 178 + 179 + def __repr__(self): 180 + return '<%s>' % self.__class__.__name__ 181 + 182 +class Fig62Game(Game): 183 + """The game represented in [Fig. 6.2]. Serves as a simple test case. 184 + >>> g = Fig62Game() 185 + >>> minimax_decision('A', g) 186 + 'a1' 187 + >>> alphabeta_full_search('A', g) 188 + 'a1' 189 + >>> alphabeta_search('A', g) 190 + 'a1' 191 + """ 192 + succs = {'A': [('a1', 'B'), ('a2', 'C'), ('a3', 'D')], 193 + 'B': [('b1', 'B1'), ('b2', 'B2'), ('b3', 'B3')], 194 + 'C': [('c1', 'C1'), ('c2', 'C2'), ('c3', 'C3')], 195 + 'D': [('d1', 'D1'), ('d2', 'D2'), ('d3', 'D3')]} 196 + utils = Dict(B1=3, B2=12, B3=8, C1=2, C2=4, C3=6, D1=14, D2=5, D3=2) 197 + initial = 'A' 198 +  199 + def successors(self, state): 200 + return self.succs.get(state, []) 201 +  202 + def utility(self, state, player): 203 + if player == 'MAX': 204 + return self.utils[state] 205 + else: 206 + return -self.utils[state] 207 +  208 + def terminal_test(self, state): 209 + return state not in ('A', 'B', 'C', 'D') 210 + 211 + def to_move(self, state): 212 + return if_(state in 'BCD', 'MIN', 'MAX') 213 + 214 +class TicTacToe(Game): 215 + """Play TicTacToe on an h x v board, with Max (first player) playing 'X'. 216 + A state has the player to move, a cached utility, a list of moves in 217 + the form of a list of (x, y) positions, and a board, in the form of 218 + a dict of {(x, y): Player} entries, where Player is 'X' or 'O'.""" 219 + def __init__(self, h=3, v=3, k=3): 220 + update(self, h=h, v=v, k=k) 221 + moves = [(x, y) for x in range(1, h+1) 222 + for y in range(1, v+1)] 223 + self.initial = Struct(to_move='X', utility=0, board={}, moves=moves) 224 + 225 + def legal_moves(self, state): 226 + "Legal moves are any square not yet taken." 227 + return state.moves 228 + 229 + def make_move(self, move, state): 230 + if move not in state.moves: 231 + return state # Illegal move has no effect 232 + board = state.board.copy(); board[move] = state.to_move 233 + moves = list(state.moves); moves.remove(move) 234 + return Struct(to_move=if_(state.to_move == 'X', 'O', 'X'), 235 + utility=self.compute_utility(board, move, state.to_move), 236 + board=board, moves=moves) 237 + 238 + def utility(self, state, player): 239 + "Return the value to X; 1 for win, -1 for loss, 0 otherwise." 240 + return state.utility 241 + 242 + def terminal_test(self, state): 243 + "A state is terminal if it is won or there are no empty squares." 244 + return state.utility != 0 or len(state.moves) == 0 245 + 246 + def display(self, state): 247 + board = state.board 248 + for y in range(self.v, 0, -1): 249 + for x in range(1, self.h+1): 250 + print board.get((x, y), '.'), 251 + print 252 + 253 + def compute_utility(self, board, move, player): 254 + "If X wins with this move, return 1; if O return -1; else return 0." 255 + if (self.k_in_row(board, move, player, (0, 1)) or 256 + self.k_in_row(board, move, player, (1, 0)) or 257 + self.k_in_row(board, move, player, (1, -1)) or 258 + self.k_in_row(board, move, player, (1, 1))): 259 + return if_(player == 'X', +1, -1) 260 + else: 261 + return 0 262 + 263 + def k_in_row(self, board, move, player, (delta_x, delta_y)): 264 + "Return true if there is a line through move on board for player." 265 + x, y = move 266 + n = 0 # n is number of moves in row 267 + while board.get((x, y)) == player: 268 + n += 1 269 + x, y = x + delta_x, y + delta_y 270 + x, y = move 271 + while board.get((x, y)) == player: 272 + n += 1 273 + x, y = x - delta_x, y - delta_y 274 + n -= 1 # Because we counted move itself twice 275 + return n >= self.k 276 + 277 +class ConnectFour(TicTacToe): 278 + """A TicTacToe-like game in which you can only make a move on the bottom 279 + row, or in a square directly above an occupied square. Traditionally 280 + played on a 7x6 board and requiring 4 in a row.""" 281 +  282 + def __init__(self, h=7, v=6, k=4): 283 + TicTacToe.__init__(self, h, v, k) 284 + 285 + def legal_moves(self, state): 286 + "Legal moves are any square not yet taken." 287 + return [(x, y) for (x, y) in state.moves 288 + if y == 1 or (x, y-1) in state.board]
904  utils.py
 ... ... @@ -0,0 +1,904 @@ 1 +"""Provide some widely useful utilities. Safe for "from utils import *". 2 + 3 +""" 4 + 5 +from __future__ import generators 6 +import operator, math, random, copy, sys, os.path, bisect, re 7 + 8 +#______________________________________________________________________________ 9 +# Compatibility with Python 2.2, 2.3, and 2.4 10 + 11 +# The AIMA code is designed to run in Python 2.2 and up (at some point, 12 +# support for 2.2 may go away; 2.2 was released in 2001, and so is over 13 +# 6 years old). The first part of this file brings you up to 2.5 14 +# compatibility if you are running in Python 2.2 through 2.4: 15 + 16 +try: bool, True, False ## Introduced in 2.3 17 +except NameError: 18 + class bool(int): 19 + "Simple implementation of Booleans, as in PEP 285" 20 + def __init__(self, val): self.val = val 21 + def __int__(self): return self.val 22 + def __repr__(self): return ('False', 'True')[self.val] 23 + 24 + True, False = bool(1), bool(0) 25 + 26 +try: sum ## Introduced in 2.3 27 +except NameError: 28 + def sum(seq, start=0):  29 + """Sum the elements of seq. 30 + >>> sum([1, 2, 3]) 31 + 6 32 + """ 33 + return reduce(operator.add, seq, start) 34 + 35 +try: enumerate ## Introduced in 2.3 36 +except NameError: 37 + def enumerate(collection): 38 + """Return an iterator that enumerates pairs of (i, c[i]). PEP 279. 39 + >>> list(enumerate('abc')) 40 + [(0, 'a'), (1, 'b'), (2, 'c')] 41 + """ 42 + ## Copied from PEP 279 43 + i = 0 44 + it = iter(collection) 45 + while 1: 46 + yield (i, it.next()) 47 + i += 1 48 + 49 + 50 +try: reversed ## Introduced in 2.4 51 +except NameError: 52 + def reversed(seq): 53 + """Iterate over x in reverse order. 54 + >>> list(reversed([1,2,3])) 55 + [3, 2, 1] 56 + """ 57 + if hasattr(seq, 'keys'): 58 + raise ValueError("mappings do not support reverse iteration") 59 + i = len(seq) 60 + while i > 0: 61 + i -= 1 62 + yield seq[i] 63 + 64 + 65 +try: sorted ## Introduced in 2.4 66 +except NameError: 67 + def sorted(seq, cmp=None, key=None, reverse=False): 68 + """Copy seq and sort and return it. 69 + >>> sorted([3, 1, 2]) 70 + [1, 2, 3] 71 + """  72 + seq2 = copy.copy(seq) 73 + if key: 74 + if cmp == None: 75 + cmp = __builtins__.cmp 76 + seq2.sort(lambda x,y: cmp(key(x), key(y))) 77 + else: 78 + if cmp == None: 79 + seq2.sort() 80 + else: 81 + seq2.sort(cmp) 82 + if reverse:  83 + seq2.reverse()  84 + return seq2 85 + 86 +try:  87 + set, frozenset ## set builtin introduced in 2.4 88 +except NameError: 89 + try:  90 + import sets ## sets module introduced in 2.3 91 + set, frozenset = sets.Set, sets.ImmutableSet 92 + except (NameError, ImportError): 93 + class BaseSet: 94 + "set type (see http://docs.python.org/lib/types-set.html)" 95 + 96 +  97 + def __init__(self, elements=[]): 98 + self.dict = {} 99 + for e in elements: 100 + self.dict[e] = 1 101 +  102 + def __len__(self): 103 + return len(self.dict) 104 +  105 + def __iter__(self): 106 + for e in self.dict: 107 + yield e 108 +  109 + def __contains__(self, element): 110 + return element in self.dict 111 +  112 + def issubset(self, other): 113 + for e in self.dict.keys(): 114 + if e not in other: 115 + return False 116 + return True 117 + 118 + def issuperset(self, other): 119 + for e in other: 120 + if e not in self: 121 + return False 122 + return True 123 +  124 + 125 + def union(self, other): 126 + return type(self)(list(self) + list(other)) 127 +  128 + def intersection(self, other): 129 + return type(self)([e for e in self.dict if e in other]) 130 + 131 + def difference(self, other): 132 + return type(self)([e for e in self.dict if e not in other]) 133 + 134 + def symmetric_difference(self, other): 135 + return type(self)([e for e in self.dict if e not in other] + 136 + [e for e in other if e not in self.dict]) 137 + 138 + def copy(self): 139 + return type(self)(self.dict) 140 + 141 + def __repr__(self): 142 + elements = ", ".join(map(str, self.dict)) 143 + return "%s([%s])" % (type(self).__name__, elements) 144 + 145 + __le__ = issubset 146 + __ge__ = issuperset 147 + __or__ = union 148 + __and__ = intersection 149 + __sub__ = difference 150 + __xor__ = symmetric_difference 151 + 152 + class frozenset(BaseSet): 153 + "A frozenset is a BaseSet that has a hash value and is immutable." 154 + 155 + def __init__(self, elements=[]): 156 + BaseSet.__init__(elements) 157 + self.hash = 0 158 + for e in self: 159 + self.hash |= hash(e) 160 + 161 + def __hash__(self): 162 + return self.hash 163 + 164 + class set(BaseSet):  165 + "A set is a BaseSet that does not have a hash, but is mutable." 166 +  167 + def update(self, other): 168 + for e in other: 169 + self.add(e) 170 + return self 171 + 172 + def intersection_update(self, other): 173 + for e in self.dict.keys(): 174 + if e not in other: 175 + self.remove(e) 176 + return self 177 + 178 + def difference_update(self, other): 179 + for e in self.dict.keys(): 180 + if e in other: 181 + self.remove(e) 182 + return self 183 + 184 + def symmetric_difference_update(self, other): 185 + to_remove1 = [e for e in self.dict if e in other] 186 + to_remove2 = [e for e in other if e in self.dict]  187 + self.difference_update(to_remove1) 188 + self.difference_update(to_remove2) 189 + return self 190 + 191 + def add(self, element): 192 + self.dict[element] = 1 193 +  194 + def remove(self, element): 195 + del self.dict[element] 196 +  197 + def discard(self, element): 198 + if element in self.dict: 199 + del self.dict[element] 200 +  201 + def pop(self): 202 + key, val = self.dict.popitem() 203 + return key 204 +  205 + def clear(self): 206 + self.dict.clear() 207 +  208 + __ior__ = update 209 + __iand__ = intersection_update 210 + __isub__ = difference_update 211 + __ixor__ = symmetric_difference_update 212 +  213 +  214 + 215 + 216 +#______________________________________________________________________________ 217 +# Simple Data Structures: infinity, Dict, Struct 218 +  219 +infinity = 1.0e400 220 + 221 +def Dict(**entries):  222 + """Create a dict out of the argument=value arguments.  223 + >>> Dict(a=1, b=2, c=3) 224 + {'a': 1, 'c': 3, 'b': 2} 225 + """ 226 + return entries 227 + 228 +class DefaultDict(dict): 229 + """Dictionary with a default value for unknown keys.""" 230 + def __init__(self, default): 231 + self.default = default 232 + 233 + def __getitem__(self, key): 234 + if key in self: return self.get(key) 235 + return self.setdefault(key, copy.deepcopy(self.default)) 236 +  237 + def __copy__(self): 238 + copy = DefaultDict(self.default) 239 + copy.update(self) 240 + return copy 241 +  242 +class Struct: 243 + """Create an instance with argument=value slots. 244 + This is for making a lightweight object whose class doesn't matter.""" 245 + def __init__(self, **entries): 246 + self.__dict__.update(entries) 247 + 248 + def __cmp__(self, other): 249 + if isinstance(other, Struct): 250 + return cmp(self.__dict__, other.__dict__) 251 + else: 252 + return cmp(self.__dict__, other) 253 + 254 + def __repr__(self): 255 + args = ['%s=%s' % (k, repr(v)) for (k, v) in vars(self).items()] 256 + return 'Struct(%s)' % ', '.join(args) 257 + 258 +def update(x, **entries): 259 + """Update a dict; or an object with slots; according to entries. 260 + >>> update({'a': 1}, a=10, b=20) 261 + {'a': 10, 'b': 20} 262 + >>> update(Struct(a=1), a=10, b=20) 263 + Struct(a=10, b=20) 264 + """ 265 + if isinstance(x, dict): 266 + x.update(entries)  267 + else: 268 + x.__dict__.update(entries)  269 + return x  270 + 271 +#______________________________________________________________________________ 272 +# Functions on Sequences (mostly inspired by Common Lisp) 273 +# NOTE: Sequence functions (count_if, find_if, every, some) take function 274 +# argument first (like reduce, filter, and map). 275 + 276 +def removeall(item, seq): 277 + """Return a copy of seq (or string) with all occurences of item removed. 278 + >>> removeall(3, [1, 2, 3, 3, 2, 1, 3]) 279 + [1, 2, 2, 1] 280 + >>> removeall(4, [1, 2, 3]) 281 + [1, 2, 3] 282 + """ 283 + if isinstance(seq, str): 284 + return seq.replace(item, '') 285 + else: 286 + return [x for x in seq if x != item] 287 + 288 +def unique(seq): 289 + """Remove duplicate elements from seq. Assumes hashable elements. 290 + >>> unique([1, 2, 3, 2, 1]) 291 + [1, 2, 3] 292 + """ 293 + return list(set(seq)) 294 +  295 +def product(numbers): 296 + """Return the product of the numbers. 297 + >>> product([1,2,3,4]) 298 + 24 299 + """ 300 + return reduce(operator.mul, numbers, 1) 301 + 302 +def count_if(predicate, seq): 303 + """Count the number of elements of seq for which the predicate is true. 304 + >>> count_if(callable, [42, None, max, min]) 305 + 2 306 + """ 307 + f = lambda count, x: count + (not not predicate(x)) 308 + return reduce(f, seq, 0) 309 +  310 +def find_if(predicate, seq): 311 + """If there is an element of seq that satisfies predicate; return it. 312 + >>> find_if(callable, [3, min, max]) 313 +  314 + >>> find_if(callable, [1, 2, 3]) 315 + """ 316 + for x in seq: 317 + if predicate(x): return x 318 + return None 319 + 320 +def every(predicate, seq): 321 + """True if every element of seq satisfies predicate. 322 + >>> every(callable, [min, max]) 323 + 1 324 + >>> every(callable, [min, 3]) 325 + 0 326 + """ 327 + for x in seq: 328 + if not predicate(x): return False 329 + return True 330 + 331 +def some(predicate, seq): 332 + """If some element x of seq satisfies predicate(x), return predicate(x). 333 + >>> some(callable, [min, 3]) 334 + 1 335 + >>> some(callable, [2, 3]) 336 + 0 337 + """ 338 + for x in seq: 339 + px = predicate(x) 340 + if px: return px 341 + return False  342 + 343 +def isin(elt, seq): 344 + """Like (elt in seq), but compares with is, not ==. 345 + >>> e = []; isin(e, [1, e, 3]) 346 + True 347 + >>> isin(e, [1, [], 3]) 348 + False 349 + """ 350 + for x in seq: 351 + if elt is x: return True 352 + return False 353 + 354 +#______________________________________________________________________________ 355 +# Functions on sequences of numbers 356 +# NOTE: these take the sequence argument first, like min and max, 357 +# and like standard math notation: \sigma (i = 1..n) fn(i) 358 +# A lot of programing is finding the best value that satisfies some condition; 359 +# so there are three versions of argmin/argmax, depending on what you want to 360 +# do with ties: return the first one, return them all, or pick at random. 361 + 362 + 363 +def argmin(seq, fn): 364 + """Return an element with lowest fn(seq[i]) score; tie goes to first one. 365 + >>> argmin(['one', 'to', 'three'], len) 366 + 'to' 367 + """ 368 + best = seq[0]; best_score = fn(best) 369 + for x in seq: 370 + x_score = fn(x) 371 + if x_score < best_score: 372 + best, best_score = x, x_score 373 + return best 374 + 375 +def argmin_list(seq, fn): 376 + """Return a list of elements of seq[i] with the lowest fn(seq[i]) scores. 377 + >>> argmin_list(['one', 'to', 'three', 'or'], len) 378 + ['to', 'or'] 379 + """ 380 + best_score, best = fn(seq[0]), [] 381 + for x in seq: 382 + x_score = fn(x) 383 + if x_score < best_score: 384 + best, best_score = [x], x_score 385 + elif x_score == best_score: 386 + best.append(x) 387 + return best 388 + 389 +def argmin_random_tie(seq, fn): 390 + """Return an element with lowest fn(seq[i]) score; break ties at random. 391 + Thus, for all s,f: argmin_random_tie(s, f) in argmin_list(s, f)""" 392 + best_score = fn(seq[0]); n = 0 393 + for x in seq: 394 + x_score = fn(x) 395 + if x_score < best_score: 396 + best, best_score = x, x_score; n = 1 397 + elif x_score == best_score: 398 + n += 1 399 + if random.randrange(n) == 0: 400 + best = x 401 + return best 402 + 403 +def argmax(seq, fn): 404 + """Return an element with highest fn(seq[i]) score; tie goes to first one. 405 + >>> argmax(['one', 'to', 'three'], len) 406 + 'three' 407 + """ 408 + return argmin(seq, lambda x: -fn(x)) 409 + 410 +def argmax_list(seq, fn): 411 + """Return a list of elements of seq[i] with the highest fn(seq[i]) scores. 412 + >>> argmax_list(['one', 'three', 'seven'], len) 413 + ['three', 'seven'] 414 + """ 415 + return argmin_list(seq, lambda x: -fn(x)) 416 + 417 +def argmax_random_tie(seq, fn): 418 + "Return an element with highest fn(seq[i]) score; break ties at random." 419 + return argmin_random_tie(seq, lambda x: -fn(x)) 420 +#______________________________________________________________________________ 421 +# Statistical and mathematical functions 422 + 423 +def histogram(values, mode=0, bin_function=None): 424 + """Return a list of (value, count) pairs, summarizing the input values. 425 + Sorted by increasing value, or if mode=1, by decreasing count. 426 + If bin_function is given, map it over values first.""" 427 + if bin_function: values = map(bin_function, values) 428 + bins = {} 429 + for val in values: 430 + bins[val] = bins.get(val, 0) + 1 431 + if mode: 432 + return sorted(bins.items(), key=lambda x: (x[1],x[0]), reverse=True) 433 + else: 434 + return sorted(bins.items()) 435 + 436 +def log2(x): 437 + """Base 2 logarithm. 438 + >>> log2(1024) 439 + 10.0 440 + """ 441 + return math.log10(x) / math.log10(2) 442 + 443 +def mode(values): 444 + """Return the most common value in the list of values. 445 + >>> mode([1, 2, 3, 2]) 446 + 2 447 + """ 448 + return histogram(values, mode=1)[0][0] 449 + 450 +def median(values): 451 + """Return the middle value, when the values are sorted. 452 + If there are an odd number of elements, try to average the middle two. 453 + If they can't be averaged (e.g. they are strings), choose one at random. 454 + >>> median([10, 100, 11]) 455 + 11 456 + >>> median([1, 2, 3, 4]) 457 + 2.5 458 + """ 459 + n = len(values) 460 + values = sorted(values) 461 + if n % 2 == 1: 462 + return values[n/2] 463 + else: 464 + middle2 = values[(n/2)-1:(n/2)+1] 465 + try: 466 + return mean(middle2) 467 + except TypeError: 468 + return random.choice(middle2) 469 + 470 +def mean(values): 471 + """Return the arithmetic average of the values.""" 472 + return sum(values) / float(len(values)) 473 + 474 +def stddev(values, meanval=None): 475 + """The standard deviation of a set of values. 476 + Pass in the mean if you already know it.""" 477 + if meanval == None: meanval = mean(values) 478 + return math.sqrt(sum([(x - meanval)**2 for x in values]) / (len(values)-1)) 479 + 480 +def dotproduct(X, Y): 481 + """Return the sum of the element-wise product of vectors x and y. 482 + >>> dotproduct([1, 2, 3], [1000, 100, 10]) 483 + 1230 484 + """ 485 + return sum([x * y for x, y in zip(X, Y)]) 486 + 487 +def vector_add(a, b): 488 + """Component-wise addition of two vectors. 489 + >>> vector_add((0, 1), (8, 9)) 490 + (8, 10) 491 + """ 492 + return tuple(map(operator.add, a, b)) 493 + 494 +def probability(p): 495 + "Return true with probability p." 496 + return p > random.uniform(0.0, 1.0) 497 + 498 +def num_or_str(x): 499 + """The argument is a string; convert to a number if possible, or strip it. 500 + >>> num_or_str('42') 501 + 42 502 + >>> num_or_str(' 42x ') 503 + '42x' 504 + """ 505 + if isnumber(x): return x 506 + try: 507 + return int(x)  508 + except ValueError: 509 + try: 510 + return float(x)  511 + except ValueError: 512 + return str(x).strip()  513 + 514 +def normalize(numbers, total=1.0): 515 + """Multiply each number by a constant such that the sum is 1.0 (or total). 516 + >>> normalize([1,2,1]) 517 + [0.25, 0.5, 0.25] 518 + """ 519 + k = total / sum(numbers) 520 + return [k * n for n in numbers] 521 + 522 +## OK, the following are not as widely useful utilities as some of the other 523 +## functions here, but they do show up wherever we have 2D grids: Wumpus and 524 +## Vacuum worlds, TicTacToe and Checkers, and markov decision Processes. 525 + 526 +orientations = [(1,0), (0, 1), (-1, 0), (0, -1)] 527 + 528 +def turn_right(orientation): 529 + return orientations[orientations.index(orientation)-1] 530 + 531 +def turn_left(orientation): 532 + return orientations[(orientations.index(orientation)+1) % len(orientations)] 533 + 534 +def distance((ax, ay), (bx, by)): 535 + "The distance between two (x, y) points." 536 + return math.hypot((ax - bx), (ay - by)) 537 + 538 +def distance2((ax, ay), (bx, by)): 539 + "The square of the distance between two (x, y) points." 540 + return (ax - bx)**2 + (ay - by)**2 541 + 542 +def clip(vector, lowest, highest): 543 + """Return vector, except if any element is less than the corresponding 544 + value of lowest or more than the corresponding value of highest, clip to 545 + those values. 546 + >>> clip((-1, 10), (0, 0), (9, 9)) 547 + (0, 9) 548 + """ 549 + return type(vector)(map(min, map(max, vector, lowest), highest)) 550 +#______________________________________________________________________________ 551 +# Misc Functions 552 + 553 +def printf(format, *args):  554 + """Format args with the first argument as format string, and write. 555 + Return the last arg, or format itself if there are no args.""" 556 + sys.stdout.write(str(format) % args) 557 + return if_(args, args[-1], format) 558 + 559 +def caller(n=1): 560 + """Return the name of the calling function n levels up in the frame stack. 561 + >>> caller(0) 562 + 'caller' 563 + >>> def f():  564 + ... return caller() 565 + >>> f() 566 + 'f' 567 + """ 568 + import inspect 569 + return inspect.getouterframes(inspect.currentframe())[n][3] 570 + 571 +def memoize(fn, slot=None): 572 + """Memoize fn: make it remember the computed value for any argument list. 573 + If slot is specified, store result in that slot of first argument. 574 + If slot is false, store results in a dictionary.""" 575 + if slot: 576 + def memoized_fn(obj, *args): 577 + if hasattr(obj, slot): 578 + return getattr(obj, slot) 579 + else: 580 + val = fn(obj, *args) 581 + setattr(obj, slot, val) 582 + return val 583 + else: 584 + def memoized_fn(*args): 585 + if not memoized_fn.cache.has_key(args): 586 + memoized_fn.cache[args] = fn(*args) 587 + return memoized_fn.cache[args] 588 + memoized_fn.cache = {} 589 + return memoized_fn 590 + 591 +def if_(test, result, alternative): 592 + """Like C++ and Java's (test ? result : alternative), except 593 + both result and alternative are always evaluated. However, if 594 + either evaluates to a function, it is applied to the empty arglist, 595 + so you can delay execution by putting it in a lambda. 596 + >>> if_(2 + 2 == 4, 'ok', lambda: expensive_computation()) 597 + 'ok' 598 + """ 599 + if test: 600 + if callable(result): return result() 601 + return result 602 + else: 603 + if callable(alternative): return alternative() 604 + return alternative 605 + 606 +def name(object): 607 + "Try to find some reasonable name for the object." 608 + return (getattr(object, 'name', 0) or getattr(object, '__name__', 0) 609 + or getattr(getattr(object, '__class__', 0), '__name__', 0) 610 + or str(object)) 611 + 612 +def isnumber(x): 613 + "Is x a number? We say it is if it has a __int__ method." 614 + return hasattr(x, '__int__') 615 + 616 +def issequence(x): 617 + "Is x a sequence? We say it is if it has a __getitem__ method." 618 + return hasattr(x, '__getitem__') 619 + 620 +def print_table(table, header=None, sep=' ', numfmt='%g'): 621 + """Print a list of lists as a table, so that columns line up nicely. 622 + header, if specified, will be printed as the first row. 623 + numfmt is the format for all numbers; you might want e.g. '%6.2f'. 624 + (If you want different formats in differnt columns, don't use print_table.) 625 + sep is the separator between columns.""" 626 + justs = [if_(isnumber(x), 'rjust', 'ljust') for x in table[0]] 627 + if header: 628 + table = [header] + table 629 + table = [[if_(isnumber(x), lambda: numfmt % x, x) for x in row] 630 + for row in table]  631 + maxlen = lambda seq: max(map(len, seq)) 632 + sizes = map(maxlen, zip(*[map(str, row) for row in table])) 633 + for row in table: 634 + for (j, size, x) in zip(justs, sizes, row): 635 + print getattr(str(x), j)(size), sep, 636 + print 637 + 638 +def AIMAFile(components, mode='r'): 639 + "Open a file based at the AIMA root directory." 640 + import utils 641 + dir = os.path.dirname(utils.__file__) 642 + return open(apply(os.path.join, [dir] + components), mode) 643 + 644 +def DataFile(name, mode='r'): 645 + "Return a file in the AIMA /data directory." 646 + return AIMAFile(['..', 'data', name], mode) 647 + 648 + 649 +#______________________________________________________________________________ 650 +# Queues: Stack, FIFOQueue, PriorityQueue 651 + 652 +class Queue: 653 + """Queue is an abstract class/interface. There are three types: 654 + Stack(): A Last In First Out Queue. 655 + FIFOQueue(): A First In First Out Queue. 656 + PriorityQueue(lt): Queue where items are sorted by lt, (default <). 657 + Each type supports the following methods and functions: 658 + q.append(item) -- add an item to the queue 659 + q.extend(items) -- equivalent to: for item in items: q.append(item) 660 + q.pop() -- return the top item from the queue 661 + len(q) -- number of items in q (also q.__len()) 662 + Note that isinstance(Stack(), Queue) is false, because we implement stacks 663 + as lists. If Python ever gets interfaces, Queue will be an interface.""" 664 + 665 + def __init__(self):  666 + abstract 667 + 668 + def extend(self, items): 669 + for item in items: self.append(item) 670 + 671 +def Stack(): 672 + """Return an empty list, suitable as a Last-In-First-Out Queue.""" 673 + return [] 674 + 675 +class FIFOQueue(Queue): 676 + """A First-In-First-Out Queue.""" 677 + def __init__(self): 678 + self.A = []; self.start = 0 679 + def append(self, item): 680 + self.A.append(item) 681 + def __len__(self): 682 + return len(self.A) - self.start 683 + def extend(self, items): 684 + self.A.extend(items)  685 + def pop(self):  686 + e = self.A[self.start] 687 + self.start += 1 688 + if self.start > 5 and self.start > len(self.A)/2: 689 + self.A = self.A[self.start:] 690 + self.start = 0 691 + return e 692 + 693 +class PriorityQueue(Queue): 694 + """A queue in which the minimum (or maximum) element (as determined by f and 695 + order) is returned first. If order is min, the item with minimum f(x) is 696 + returned first; if order is max, then it is the item with maximum f(x).""" 697 + def __init__(self, order=min, f=lambda x: x): 698 + update(self, A=[], order=order, f=f) 699 + def append(self, item): 700 + bisect.insort(self.A, (self.f(item), item)) 701 + def __len__(self): 702 + return len(self.A) 703 + def pop(self): 704 + if self.order == min: 705 + return self.A.pop(0)[1] 706 + else: 707 + return self.A.pop()[1] 708 + 709 +## Fig: The idea is we can define things like Fig[3,10] later. 710 +## Alas, it is Fig[3,10] not Fig[3.10], because that would be the same as Fig[3.1] 711 +Fig = {}  712 + 713 +#______________________________________________________________________________ 714 +# Support for doctest 715 + 716 +def ignore(x): None 717 + 718 +def random_tests(text): 719 + """Some functions are stochastic. We want to be able to write a test 720 + with random output. We do that by ignoring the output.""" 721 + def fixup(test):  722 + if " = " in test: 723 + return ">>> " + test 724 + else: 725 + return ">>> ignore(" + test + ")" 726 + tests = re.findall(">>> (.*)", text) 727 + return '\n'.join(map(fixup, tests)) 728 + 729 +#______________________________________________________________________________ 730 + 731 +__doc__ += """ 732 +>>> d = DefaultDict(0)  733 +>>> d['x'] += 1 734 +>>> d['x'] 735 +1 736 + 737 +>>> d = DefaultDict([]) 738 +>>> d['x'] += [1] 739 +>>> d['y'] += [2] 740 +>>> d['x'] 741 +[1] 742 + 743 +>>> s = Struct(a=1, b=2) 744 +>>> s.a 745 +1 746 +>>> s.a = 3 747 +>>> s 748 +Struct(a=3, b=2) 749 +  750 +>>> def is_even(x):  751 +... return x % 2 == 0 752 +>>> sorted([1, 2, -3])  753 +[-3, 1, 2] 754 +>>> sorted(range(10), key=is_even) 755 +[1, 3, 5, 7, 9, 0, 2, 4, 6, 8] 756 +>>> sorted(range(10), lambda x,y: y-x)  757 +[9, 8, 7, 6, 5, 4, 3, 2, 1, 0] 758 + 759 +>>> removeall(4, [])  760 +[] 761 +>>> removeall('s', 'This is a test. Was a test.')  762 +'Thi i a tet. Wa a tet.' 763 +>>> removeall('s', 'Something')  764 +'Something' 765 +>>> removeall('s', '')  766 +'' 767 + 768 +>>> list(reversed([]))  769 +[] 770 + 771 +>>> count_if(is_even, [1, 2, 3, 4])  772 +2 773 +>>> count_if(is_even, [])  774 +0 775 + 776 +>>> argmax([1], lambda x: x*x)  777 +1 778 +>>> argmin([1], lambda x: x*x)  779 +1 780 + 781 + 782 +# Test of memoize with slots in structures 783 +>>> countries = [Struct(name='united states'), Struct(name='canada')] 784 + 785 +# Pretend that 'gnp' was some big hairy operation: 786 +>>> def gnp(country):  787 +... print 'calculating gnp ...' 788 +... return len(country.name) * 1e10 789 + 790 +>>> gnp = memoize(gnp, '_gnp') 791 +>>> map(gnp, countries) 792 +calculating gnp ... 793 +calculating gnp ... 794 +[130000000000.0, 60000000000.0] 795 +>>> countries 796 +[Struct(_gnp=130000000000.0, name='united states'), Struct(_gnp=60000000000.0, name='canada')] 797 + 798 +# This time we avoid re-doing the calculation 799 +>>> map(gnp, countries)  800 +[130000000000.0, 60000000000.0] 801 + 802 +# Test Queues: 803 +>>> nums = [1, 8, 2, 7, 5, 6, -99, 99, 4, 3, 0] 804 +>>> def qtest(q):  805 +... return [q.extend(nums), [q.pop() for i in range(len(q))]][1] 806 + 807 +>>> qtest(Stack())  808 +[0, 3, 4, 99, -99, 6, 5, 7, 2, 8, 1] 809 + 810 +>>> qtest(FIFOQueue())  811 +[1, 8, 2, 7, 5, 6, -99, 99, 4, 3, 0] 812 + 813 +>>> qtest(PriorityQueue(min))  814 +[-99, 0, 1, 2, 3, 4, 5, 6, 7, 8, 99] 815 + 816 +>>> qtest(PriorityQueue(max))  817 +[99, 8, 7, 6, 5, 4, 3, 2, 1, 0, -99] 818 + 819 +>>> qtest(PriorityQueue(min, abs))  820 +[0, 1, 2, 3, 4, 5, 6, 7, 8, -99, 99] 821 + 822 +>>> qtest(PriorityQueue(max, abs))  823 +[99, -99, 8, 7, 6, 5, 4, 3, 2, 1, 0] 824 + 825 +>>> vals = [100, 110, 160, 200, 160, 110, 200, 200, 220] 826 +>>> histogram(vals)  827 +[(100, 1), (110, 2), (160, 2), (200, 3), (220, 1)] 828 +>>> histogram(vals, 1)  829 +[(200, 3), (160, 2), (110, 2), (220, 1), (100, 1)] 830 +>>> histogram(vals, 1, lambda v: round(v, -2))  831 +[(200.0, 6), (100.0, 3)] 832 + 833 +>>> log2(1.0)  834 +0.0 835 + 836 +>>> def fib(n):  837 +... return (n<=1 and 1) or (fib(n-1) + fib(n-2)) 838 + 839 +>>> fib(9) 840 +55 841 + 842 +# Now we make it faster: 843 +>>> fib = memoize(fib) 844 +>>> fib(9)  845 +55 846 + 847 +>>> q = Stack() 848 +>>> q.append(1) 849 +>>> q.append(2) 850 +>>> q.pop(), q.pop() 851 +(2, 1) 852 + 853 +>>> q = FIFOQueue() 854 +>>> q.append(1) 855 +>>> q.append(2) 856 +>>> q.pop(), q.pop() 857 +(1, 2) 858 + 859 + 860 +>>> abc = set('abc') 861 +>>> bcd = set('bcd') 862 +>>> 'a' in abc 863 +True 864 +>>> 'a' in bcd 865 +False 866 +>>> list(abc.intersection(bcd)) 867 +['c', 'b'] 868 +>>> list(abc.union(bcd)) 869 +['a', 'c', 'b', 'd'] 870 + 871 +## From "What's new in Python 2.4", but I added calls to sl 872 + 873 +>>> def sl(x): 874 +... return sorted(list(x)) 875 + 876 + 877 +>>> a = set('abracadabra') # form a set from a string 878 +>>> 'z' in a # fast membership testing 879 +False 880 +>>> sl(a) # unique letters in a 881 +['a', 'b', 'c', 'd', 'r'] 882 + 883 +>>> b = set('alacazam') # form a second set 884 +>>> sl(a - b) # letters in a but not in b 885 +['b', 'd', 'r'] 886 +>>> sl(a | b) # letters in either a or b 887 +['a', 'b', 'c', 'd', 'l', 'm', 'r', 'z'] 888 +>>> sl(a & b) # letters in both a and b 889 +['a', 'c'] 890 +>>> sl(a ^ b) # letters in a or b but not both 891 +['b', 'd', 'l', 'm', 'r', 'z'] 892 + 893 + 894 +>>> a.add('z') # add a new element 895 +>>> a.update('wxy') # add multiple new elements 896 +>>> sl(a)  897 +['a', 'b', 'c', 'd', 'r', 'w', 'x', 'y', 'z'] 898 +>>> a.remove('x') # take one element out 899 +>>> sl(a) 900 +['a', 'b', 'c', 'd', 'r', 'w', 'y', 'z'] 901 + 902 +""" 903 + 904 +