Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Added AIMA games library, with associated utils.

  • Loading branch information...
commit 5c6d73506ea0847a2c140581be874658073ca383 1 parent 6267b35
Corey Abshire authored
Showing with 1,192 additions and 0 deletions.
  1. +288 −0 games.py
  2. +904 −0 utils.py
288 games.py
View
@@ -0,0 +1,288 @@
+"""Games, or Adversarial Search. (Chapters 6)
+
+"""
+
+from utils import *
+import random
+
+#______________________________________________________________________________
+# Minimax Search
+
+def minimax_decision(state, game):
+ """Given a state in a game, calculate the best move by searching
+ forward all the way to the terminal states. [Fig. 6.4]"""
+
+ player = game.to_move(state)
+
+ def max_value(state):
+ if game.terminal_test(state):
+ return game.utility(state, player)
+ v = -infinity
+ for (a, s) in game.successors(state):
+ v = max(v, min_value(s))
+ return v
+
+ def min_value(state):
+ if game.terminal_test(state):
+ return game.utility(state, player)
+ v = infinity
+ for (a, s) in game.successors(state):
+ v = min(v, max_value(s))
+ return v
+
+ # Body of minimax_decision starts here:
+ action, state = argmax(game.successors(state),
+ lambda ((a, s)): min_value(s))
+ return action
+
+
+#______________________________________________________________________________
+
+def alphabeta_full_search(state, game):
+ """Search game to determine best action; use alpha-beta pruning.
+ As in [Fig. 6.7], this version searches all the way to the leaves."""
+
+ player = game.to_move(state)
+
+ def max_value(state, alpha, beta):
+ if game.terminal_test(state):
+ return game.utility(state, player)
+ v = -infinity
+ for (a, s) in game.successors(state):
+ v = max(v, min_value(s, alpha, beta))
+ if v >= beta:
+ return v
+ alpha = max(alpha, v)
+ return v
+
+ def min_value(state, alpha, beta):
+ if game.terminal_test(state):
+ return game.utility(state, player)
+ v = infinity
+ for (a, s) in game.successors(state):
+ v = min(v, max_value(s, alpha, beta))
+ if v <= alpha:
+ return v
+ beta = min(beta, v)
+ return v
+
+ # Body of alphabeta_search starts here:
+ action, state = argmax(game.successors(state),
+ lambda ((a, s)): min_value(s, -infinity, infinity))
+ return action
+
+def alphabeta_search(state, game, d=4, cutoff_test=None, eval_fn=None):
+ """Search game to determine best action; use alpha-beta pruning.
+ This version cuts off search and uses an evaluation function."""
+
+ player = game.to_move(state)
+
+ def max_value(state, alpha, beta, depth):
+ if cutoff_test(state, depth):
+ return eval_fn(state)
+ v = -infinity
+ for (a, s) in game.successors(state):
+ v = max(v, min_value(s, alpha, beta, depth+1))
+ if v >= beta:
+ return v
+ alpha = max(alpha, v)
+ return v
+
+ def min_value(state, alpha, beta, depth):
+ if cutoff_test(state, depth):
+ return eval_fn(state)
+ v = infinity
+ for (a, s) in game.successors(state):
+ v = min(v, max_value(s, alpha, beta, depth+1))
+ if v <= alpha:
+ return v
+ beta = min(beta, v)
+ return v
+
+ # Body of alphabeta_search starts here:
+ # The default test cuts off at depth d or at a terminal state
+ cutoff_test = (cutoff_test or
+ (lambda state,depth: depth>d or game.terminal_test(state)))
+ eval_fn = eval_fn or (lambda state: game.utility(state, player))
+ action, state = argmax(game.successors(state),
+ lambda ((a, s)): min_value(s, -infinity, infinity, 0))
+ return action
+
+#______________________________________________________________________________
+# Players for Games
+
+def query_player(game, state):
+ "Make a move by querying standard input."
+ game.display(state)
+ return num_or_str(raw_input('Your move? '))
+
+def random_player(game, state):
+ "A player that chooses a legal move at random."
+ return random.choice(game.legal_moves(state))
+
+def alphabeta_player(game, state):
+ return alphabeta_search(state, game)
+
+def play_game(game, *players):
+ "Play an n-person, move-alternating game."
+ state = game.initial
+ while True:
+ for player in players:
+ move = player(game, state)
+ state = game.make_move(move, state)
+ game.display(state)
+ print
+ if game.terminal_test(state):
+ return game.utility(state, players[0])
+
+#______________________________________________________________________________
+# Some Sample Games
+
+class Game:
+ """A game is similar to a problem, but it has a utility for each
+ state and a terminal test instead of a path cost and a goal
+ test. To create a game, subclass this class and implement
+ legal_moves, make_move, utility, and terminal_test. You may
+ override display and successors or you can inherit their default
+ methods. You will also need to set the .initial attribute to the
+ initial state; this can be done in the constructor."""
+
+ def legal_moves(self, state):
+ "Return a list of the allowable moves at this point."
+ abstract
+
+ def make_move(self, move, state):
+ "Return the state that results from making a move from a state."
+ abstract
+
+ def utility(self, state, player):
+ "Return the value of this final state to player."
+ abstract
+
+ def terminal_test(self, state):
+ "Return True if this is a final state for the game."
+ return not self.legal_moves(state)
+
+ def to_move(self, state):
+ "Return the player whose move it is in this state."
+ return state.to_move
+
+ def display(self, state):
+ "Print or otherwise display the state."
+ print state
+
+ def successors(self, state):
+ "Return a list of legal (move, state) pairs."
+ return [(move, self.make_move(move, state))
+ for move in self.legal_moves(state)]
+
+ def __repr__(self):
+ return '<%s>' % self.__class__.__name__
+
+class Fig62Game(Game):
+ """The game represented in [Fig. 6.2]. Serves as a simple test case.
+ >>> g = Fig62Game()
+ >>> minimax_decision('A', g)
+ 'a1'
+ >>> alphabeta_full_search('A', g)
+ 'a1'
+ >>> alphabeta_search('A', g)
+ 'a1'
+ """
+ succs = {'A': [('a1', 'B'), ('a2', 'C'), ('a3', 'D')],
+ 'B': [('b1', 'B1'), ('b2', 'B2'), ('b3', 'B3')],
+ 'C': [('c1', 'C1'), ('c2', 'C2'), ('c3', 'C3')],
+ 'D': [('d1', 'D1'), ('d2', 'D2'), ('d3', 'D3')]}
+ utils = Dict(B1=3, B2=12, B3=8, C1=2, C2=4, C3=6, D1=14, D2=5, D3=2)
+ initial = 'A'
+
+ def successors(self, state):
+ return self.succs.get(state, [])
+
+ def utility(self, state, player):
+ if player == 'MAX':
+ return self.utils[state]
+ else:
+ return -self.utils[state]
+
+ def terminal_test(self, state):
+ return state not in ('A', 'B', 'C', 'D')
+
+ def to_move(self, state):
+ return if_(state in 'BCD', 'MIN', 'MAX')
+
+class TicTacToe(Game):
+ """Play TicTacToe on an h x v board, with Max (first player) playing 'X'.
+ A state has the player to move, a cached utility, a list of moves in
+ the form of a list of (x, y) positions, and a board, in the form of
+ a dict of {(x, y): Player} entries, where Player is 'X' or 'O'."""
+ def __init__(self, h=3, v=3, k=3):
+ update(self, h=h, v=v, k=k)
+ moves = [(x, y) for x in range(1, h+1)
+ for y in range(1, v+1)]
+ self.initial = Struct(to_move='X', utility=0, board={}, moves=moves)
+
+ def legal_moves(self, state):
+ "Legal moves are any square not yet taken."
+ return state.moves
+
+ def make_move(self, move, state):
+ if move not in state.moves:
+ return state # Illegal move has no effect
+ board = state.board.copy(); board[move] = state.to_move
+ moves = list(state.moves); moves.remove(move)
+ return Struct(to_move=if_(state.to_move == 'X', 'O', 'X'),
+ utility=self.compute_utility(board, move, state.to_move),
+ board=board, moves=moves)
+
+ def utility(self, state, player):
+ "Return the value to X; 1 for win, -1 for loss, 0 otherwise."
+ return state.utility
+
+ def terminal_test(self, state):
+ "A state is terminal if it is won or there are no empty squares."
+ return state.utility != 0 or len(state.moves) == 0
+
+ def display(self, state):
+ board = state.board
+ for y in range(self.v, 0, -1):
+ for x in range(1, self.h+1):
+ print board.get((x, y), '.'),
+ print
+
+ def compute_utility(self, board, move, player):
+ "If X wins with this move, return 1; if O return -1; else return 0."
+ if (self.k_in_row(board, move, player, (0, 1)) or
+ self.k_in_row(board, move, player, (1, 0)) or
+ self.k_in_row(board, move, player, (1, -1)) or
+ self.k_in_row(board, move, player, (1, 1))):
+ return if_(player == 'X', +1, -1)
+ else:
+ return 0
+
+ def k_in_row(self, board, move, player, (delta_x, delta_y)):
+ "Return true if there is a line through move on board for player."
+ x, y = move
+ n = 0 # n is number of moves in row
+ while board.get((x, y)) == player:
+ n += 1
+ x, y = x + delta_x, y + delta_y
+ x, y = move
+ while board.get((x, y)) == player:
+ n += 1
+ x, y = x - delta_x, y - delta_y
+ n -= 1 # Because we counted move itself twice
+ return n >= self.k
+
+class ConnectFour(TicTacToe):
+ """A TicTacToe-like game in which you can only make a move on the bottom
+ row, or in a square directly above an occupied square. Traditionally
+ played on a 7x6 board and requiring 4 in a row."""
+
+ def __init__(self, h=7, v=6, k=4):
+ TicTacToe.__init__(self, h, v, k)
+
+ def legal_moves(self, state):
+ "Legal moves are any square not yet taken."
+ return [(x, y) for (x, y) in state.moves
+ if y == 1 or (x, y-1) in state.board]
904 utils.py
View
@@ -0,0 +1,904 @@
+"""Provide some widely useful utilities. Safe for "from utils import *".
+
+"""
+
+from __future__ import generators
+import operator, math, random, copy, sys, os.path, bisect, re
+
+#______________________________________________________________________________
+# Compatibility with Python 2.2, 2.3, and 2.4
+
+# The AIMA code is designed to run in Python 2.2 and up (at some point,
+# support for 2.2 may go away; 2.2 was released in 2001, and so is over
+# 6 years old). The first part of this file brings you up to 2.5
+# compatibility if you are running in Python 2.2 through 2.4:
+
+try: bool, True, False ## Introduced in 2.3
+except NameError:
+ class bool(int):
+ "Simple implementation of Booleans, as in PEP 285"
+ def __init__(self, val): self.val = val
+ def __int__(self): return self.val
+ def __repr__(self): return ('False', 'True')[self.val]
+
+ True, False = bool(1), bool(0)
+
+try: sum ## Introduced in 2.3
+except NameError:
+ def sum(seq, start=0):
+ """Sum the elements of seq.
+ >>> sum([1, 2, 3])
+ 6
+ """
+ return reduce(operator.add, seq, start)
+
+try: enumerate ## Introduced in 2.3
+except NameError:
+ def enumerate(collection):
+ """Return an iterator that enumerates pairs of (i, c[i]). PEP 279.
+ >>> list(enumerate('abc'))
+ [(0, 'a'), (1, 'b'), (2, 'c')]
+ """
+ ## Copied from PEP 279
+ i = 0
+ it = iter(collection)
+ while 1:
+ yield (i, it.next())
+ i += 1
+
+
+try: reversed ## Introduced in 2.4
+except NameError:
+ def reversed(seq):
+ """Iterate over x in reverse order.
+ >>> list(reversed([1,2,3]))
+ [3, 2, 1]
+ """
+ if hasattr(seq, 'keys'):
+ raise ValueError("mappings do not support reverse iteration")
+ i = len(seq)
+ while i > 0:
+ i -= 1
+ yield seq[i]
+
+
+try: sorted ## Introduced in 2.4
+except NameError:
+ def sorted(seq, cmp=None, key=None, reverse=False):
+ """Copy seq and sort and return it.
+ >>> sorted([3, 1, 2])
+ [1, 2, 3]
+ """
+ seq2 = copy.copy(seq)
+ if key:
+ if cmp == None:
+ cmp = __builtins__.cmp
+ seq2.sort(lambda x,y: cmp(key(x), key(y)))
+ else:
+ if cmp == None:
+ seq2.sort()
+ else:
+ seq2.sort(cmp)
+ if reverse:
+ seq2.reverse()
+ return seq2
+
+try:
+ set, frozenset ## set builtin introduced in 2.4
+except NameError:
+ try:
+ import sets ## sets module introduced in 2.3
+ set, frozenset = sets.Set, sets.ImmutableSet
+ except (NameError, ImportError):
+ class BaseSet:
+ "set type (see http://docs.python.org/lib/types-set.html)"
+
+
+ def __init__(self, elements=[]):
+ self.dict = {}
+ for e in elements:
+ self.dict[e] = 1
+
+ def __len__(self):
+ return len(self.dict)
+
+ def __iter__(self):
+ for e in self.dict:
+ yield e
+
+ def __contains__(self, element):
+ return element in self.dict
+
+ def issubset(self, other):
+ for e in self.dict.keys():
+ if e not in other:
+ return False
+ return True
+
+ def issuperset(self, other):
+ for e in other:
+ if e not in self:
+ return False
+ return True
+
+
+ def union(self, other):
+ return type(self)(list(self) + list(other))
+
+ def intersection(self, other):
+ return type(self)([e for e in self.dict if e in other])
+
+ def difference(self, other):
+ return type(self)([e for e in self.dict if e not in other])
+
+ def symmetric_difference(self, other):
+ return type(self)([e for e in self.dict if e not in other] +
+ [e for e in other if e not in self.dict])
+
+ def copy(self):
+ return type(self)(self.dict)
+
+ def __repr__(self):
+ elements = ", ".join(map(str, self.dict))
+ return "%s([%s])" % (type(self).__name__, elements)
+
+ __le__ = issubset
+ __ge__ = issuperset
+ __or__ = union
+ __and__ = intersection
+ __sub__ = difference
+ __xor__ = symmetric_difference
+
+ class frozenset(BaseSet):
+ "A frozenset is a BaseSet that has a hash value and is immutable."
+
+ def __init__(self, elements=[]):
+ BaseSet.__init__(elements)
+ self.hash = 0
+ for e in self:
+ self.hash |= hash(e)
+
+ def __hash__(self):
+ return self.hash
+
+ class set(BaseSet):
+ "A set is a BaseSet that does not have a hash, but is mutable."
+
+ def update(self, other):
+ for e in other:
+ self.add(e)
+ return self
+
+ def intersection_update(self, other):
+ for e in self.dict.keys():
+ if e not in other:
+ self.remove(e)
+ return self
+
+ def difference_update(self, other):
+ for e in self.dict.keys():
+ if e in other:
+ self.remove(e)
+ return self
+
+ def symmetric_difference_update(self, other):
+ to_remove1 = [e for e in self.dict if e in other]
+ to_remove2 = [e for e in other if e in self.dict]
+ self.difference_update(to_remove1)
+ self.difference_update(to_remove2)
+ return self
+
+ def add(self, element):
+ self.dict[element] = 1
+
+ def remove(self, element):
+ del self.dict[element]
+
+ def discard(self, element):
+ if element in self.dict:
+ del self.dict[element]
+
+ def pop(self):
+ key, val = self.dict.popitem()
+ return key
+
+ def clear(self):
+ self.dict.clear()
+
+ __ior__ = update
+ __iand__ = intersection_update
+ __isub__ = difference_update
+ __ixor__ = symmetric_difference_update
+
+
+
+
+#______________________________________________________________________________
+# Simple Data Structures: infinity, Dict, Struct
+
+infinity = 1.0e400
+
+def Dict(**entries):
+ """Create a dict out of the argument=value arguments.
+ >>> Dict(a=1, b=2, c=3)
+ {'a': 1, 'c': 3, 'b': 2}
+ """
+ return entries
+
+class DefaultDict(dict):
+ """Dictionary with a default value for unknown keys."""
+ def __init__(self, default):
+ self.default = default
+
+ def __getitem__(self, key):
+ if key in self: return self.get(key)
+ return self.setdefault(key, copy.deepcopy(self.default))
+
+ def __copy__(self):
+ copy = DefaultDict(self.default)
+ copy.update(self)
+ return copy
+
+class Struct:
+ """Create an instance with argument=value slots.
+ This is for making a lightweight object whose class doesn't matter."""
+ def __init__(self, **entries):
+ self.__dict__.update(entries)
+
+ def __cmp__(self, other):
+ if isinstance(other, Struct):
+ return cmp(self.__dict__, other.__dict__)
+ else:
+ return cmp(self.__dict__, other)
+
+ def __repr__(self):
+ args = ['%s=%s' % (k, repr(v)) for (k, v) in vars(self).items()]
+ return 'Struct(%s)' % ', '.join(args)
+
+def update(x, **entries):
+ """Update a dict; or an object with slots; according to entries.
+ >>> update({'a': 1}, a=10, b=20)
+ {'a': 10, 'b': 20}
+ >>> update(Struct(a=1), a=10, b=20)
+ Struct(a=10, b=20)
+ """
+ if isinstance(x, dict):
+ x.update(entries)
+ else:
+ x.__dict__.update(entries)
+ return x
+
+#______________________________________________________________________________
+# Functions on Sequences (mostly inspired by Common Lisp)
+# NOTE: Sequence functions (count_if, find_if, every, some) take function
+# argument first (like reduce, filter, and map).
+
+def removeall(item, seq):
+ """Return a copy of seq (or string) with all occurences of item removed.
+ >>> removeall(3, [1, 2, 3, 3, 2, 1, 3])
+ [1, 2, 2, 1]
+ >>> removeall(4, [1, 2, 3])
+ [1, 2, 3]
+ """
+ if isinstance(seq, str):
+ return seq.replace(item, '')
+ else:
+ return [x for x in seq if x != item]
+
+def unique(seq):
+ """Remove duplicate elements from seq. Assumes hashable elements.
+ >>> unique([1, 2, 3, 2, 1])
+ [1, 2, 3]
+ """
+ return list(set(seq))
+
+def product(numbers):
+ """Return the product of the numbers.
+ >>> product([1,2,3,4])
+ 24
+ """
+ return reduce(operator.mul, numbers, 1)
+
+def count_if(predicate, seq):
+ """Count the number of elements of seq for which the predicate is true.
+ >>> count_if(callable, [42, None, max, min])
+ 2
+ """
+ f = lambda count, x: count + (not not predicate(x))
+ return reduce(f, seq, 0)
+
+def find_if(predicate, seq):
+ """If there is an element of seq that satisfies predicate; return it.
+ >>> find_if(callable, [3, min, max])
+ <built-in function min>
+ >>> find_if(callable, [1, 2, 3])
+ """
+ for x in seq:
+ if predicate(x): return x
+ return None
+
+def every(predicate, seq):
+ """True if every element of seq satisfies predicate.
+ >>> every(callable, [min, max])
+ 1
+ >>> every(callable, [min, 3])
+ 0
+ """
+ for x in seq:
+ if not predicate(x): return False
+ return True
+
+def some(predicate, seq):
+ """If some element x of seq satisfies predicate(x), return predicate(x).
+ >>> some(callable, [min, 3])
+ 1
+ >>> some(callable, [2, 3])
+ 0
+ """
+ for x in seq:
+ px = predicate(x)
+ if px: return px
+ return False
+
+def isin(elt, seq):
+ """Like (elt in seq), but compares with is, not ==.
+ >>> e = []; isin(e, [1, e, 3])
+ True
+ >>> isin(e, [1, [], 3])
+ False
+ """
+ for x in seq:
+ if elt is x: return True
+ return False
+
+#______________________________________________________________________________
+# Functions on sequences of numbers
+# NOTE: these take the sequence argument first, like min and max,
+# and like standard math notation: \sigma (i = 1..n) fn(i)
+# A lot of programing is finding the best value that satisfies some condition;
+# so there are three versions of argmin/argmax, depending on what you want to
+# do with ties: return the first one, return them all, or pick at random.
+
+
+def argmin(seq, fn):
+ """Return an element with lowest fn(seq[i]) score; tie goes to first one.
+ >>> argmin(['one', 'to', 'three'], len)
+ 'to'
+ """
+ best = seq[0]; best_score = fn(best)
+ for x in seq:
+ x_score = fn(x)
+ if x_score < best_score:
+ best, best_score = x, x_score
+ return best
+
+def argmin_list(seq, fn):
+ """Return a list of elements of seq[i] with the lowest fn(seq[i]) scores.
+ >>> argmin_list(['one', 'to', 'three', 'or'], len)
+ ['to', 'or']
+ """
+ best_score, best = fn(seq[0]), []
+ for x in seq:
+ x_score = fn(x)
+ if x_score < best_score:
+ best, best_score = [x], x_score
+ elif x_score == best_score:
+ best.append(x)
+ return best
+
+def argmin_random_tie(seq, fn):
+ """Return an element with lowest fn(seq[i]) score; break ties at random.
+ Thus, for all s,f: argmin_random_tie(s, f) in argmin_list(s, f)"""
+ best_score = fn(seq[0]); n = 0
+ for x in seq:
+ x_score = fn(x)
+ if x_score < best_score:
+ best, best_score = x, x_score; n = 1
+ elif x_score == best_score:
+ n += 1
+ if random.randrange(n) == 0:
+ best = x
+ return best
+
+def argmax(seq, fn):
+ """Return an element with highest fn(seq[i]) score; tie goes to first one.
+ >>> argmax(['one', 'to', 'three'], len)
+ 'three'
+ """
+ return argmin(seq, lambda x: -fn(x))
+
+def argmax_list(seq, fn):
+ """Return a list of elements of seq[i] with the highest fn(seq[i]) scores.
+ >>> argmax_list(['one', 'three', 'seven'], len)
+ ['three', 'seven']
+ """
+ return argmin_list(seq, lambda x: -fn(x))
+
+def argmax_random_tie(seq, fn):
+ "Return an element with highest fn(seq[i]) score; break ties at random."
+ return argmin_random_tie(seq, lambda x: -fn(x))
+#______________________________________________________________________________
+# Statistical and mathematical functions
+
+def histogram(values, mode=0, bin_function=None):
+ """Return a list of (value, count) pairs, summarizing the input values.
+ Sorted by increasing value, or if mode=1, by decreasing count.
+ If bin_function is given, map it over values first."""
+ if bin_function: values = map(bin_function, values)
+ bins = {}
+ for val in values:
+ bins[val] = bins.get(val, 0) + 1
+ if mode:
+ return sorted(bins.items(), key=lambda x: (x[1],x[0]), reverse=True)
+ else:
+ return sorted(bins.items())
+
+def log2(x):
+ """Base 2 logarithm.
+ >>> log2(1024)
+ 10.0
+ """
+ return math.log10(x) / math.log10(2)
+
+def mode(values):
+ """Return the most common value in the list of values.
+ >>> mode([1, 2, 3, 2])
+ 2
+ """
+ return histogram(values, mode=1)[0][0]
+
+def median(values):
+ """Return the middle value, when the values are sorted.
+ If there are an odd number of elements, try to average the middle two.
+ If they can't be averaged (e.g. they are strings), choose one at random.
+ >>> median([10, 100, 11])
+ 11
+ >>> median([1, 2, 3, 4])
+ 2.5
+ """
+ n = len(values)
+ values = sorted(values)
+ if n % 2 == 1:
+ return values[n/2]
+ else:
+ middle2 = values[(n/2)-1:(n/2)+1]
+ try:
+ return mean(middle2)
+ except TypeError:
+ return random.choice(middle2)
+
+def mean(values):
+ """Return the arithmetic average of the values."""
+ return sum(values) / float(len(values))
+
+def stddev(values, meanval=None):
+ """The standard deviation of a set of values.
+ Pass in the mean if you already know it."""
+ if meanval == None: meanval = mean(values)
+ return math.sqrt(sum([(x - meanval)**2 for x in values]) / (len(values)-1))
+
+def dotproduct(X, Y):
+ """Return the sum of the element-wise product of vectors x and y.
+ >>> dotproduct([1, 2, 3], [1000, 100, 10])
+ 1230
+ """
+ return sum([x * y for x, y in zip(X, Y)])
+
+def vector_add(a, b):
+ """Component-wise addition of two vectors.
+ >>> vector_add((0, 1), (8, 9))
+ (8, 10)
+ """
+ return tuple(map(operator.add, a, b))
+
+def probability(p):
+ "Return true with probability p."
+ return p > random.uniform(0.0, 1.0)
+
+def num_or_str(x):
+ """The argument is a string; convert to a number if possible, or strip it.
+ >>> num_or_str('42')
+ 42
+ >>> num_or_str(' 42x ')
+ '42x'
+ """
+ if isnumber(x): return x
+ try:
+ return int(x)
+ except ValueError:
+ try:
+ return float(x)
+ except ValueError:
+ return str(x).strip()
+
+def normalize(numbers, total=1.0):
+ """Multiply each number by a constant such that the sum is 1.0 (or total).
+ >>> normalize([1,2,1])
+ [0.25, 0.5, 0.25]
+ """
+ k = total / sum(numbers)
+ return [k * n for n in numbers]
+
+## OK, the following are not as widely useful utilities as some of the other
+## functions here, but they do show up wherever we have 2D grids: Wumpus and
+## Vacuum worlds, TicTacToe and Checkers, and markov decision Processes.
+
+orientations = [(1,0), (0, 1), (-1, 0), (0, -1)]
+
+def turn_right(orientation):
+ return orientations[orientations.index(orientation)-1]
+
+def turn_left(orientation):
+ return orientations[(orientations.index(orientation)+1) % len(orientations)]
+
+def distance((ax, ay), (bx, by)):
+ "The distance between two (x, y) points."
+ return math.hypot((ax - bx), (ay - by))
+
+def distance2((ax, ay), (bx, by)):
+ "The square of the distance between two (x, y) points."
+ return (ax - bx)**2 + (ay - by)**2
+
+def clip(vector, lowest, highest):
+ """Return vector, except if any element is less than the corresponding
+ value of lowest or more than the corresponding value of highest, clip to
+ those values.
+ >>> clip((-1, 10), (0, 0), (9, 9))
+ (0, 9)
+ """
+ return type(vector)(map(min, map(max, vector, lowest), highest))
+#______________________________________________________________________________
+# Misc Functions
+
+def printf(format, *args):
+ """Format args with the first argument as format string, and write.
+ Return the last arg, or format itself if there are no args."""
+ sys.stdout.write(str(format) % args)
+ return if_(args, args[-1], format)
+
+def caller(n=1):
+ """Return the name of the calling function n levels up in the frame stack.
+ >>> caller(0)
+ 'caller'
+ >>> def f():
+ ... return caller()
+ >>> f()
+ 'f'
+ """
+ import inspect
+ return inspect.getouterframes(inspect.currentframe())[n][3]
+
+def memoize(fn, slot=None):
+ """Memoize fn: make it remember the computed value for any argument list.
+ If slot is specified, store result in that slot of first argument.
+ If slot is false, store results in a dictionary."""
+ if slot:
+ def memoized_fn(obj, *args):
+ if hasattr(obj, slot):
+ return getattr(obj, slot)
+ else:
+ val = fn(obj, *args)
+ setattr(obj, slot, val)
+ return val
+ else:
+ def memoized_fn(*args):
+ if not memoized_fn.cache.has_key(args):
+ memoized_fn.cache[args] = fn(*args)
+ return memoized_fn.cache[args]
+ memoized_fn.cache = {}
+ return memoized_fn
+
+def if_(test, result, alternative):
+ """Like C++ and Java's (test ? result : alternative), except
+ both result and alternative are always evaluated. However, if
+ either evaluates to a function, it is applied to the empty arglist,
+ so you can delay execution by putting it in a lambda.
+ >>> if_(2 + 2 == 4, 'ok', lambda: expensive_computation())
+ 'ok'
+ """
+ if test:
+ if callable(result): return result()
+ return result
+ else:
+ if callable(alternative): return alternative()
+ return alternative
+
+def name(object):
+ "Try to find some reasonable name for the object."
+ return (getattr(object, 'name', 0) or getattr(object, '__name__', 0)
+ or getattr(getattr(object, '__class__', 0), '__name__', 0)
+ or str(object))
+
+def isnumber(x):
+ "Is x a number? We say it is if it has a __int__ method."
+ return hasattr(x, '__int__')
+
+def issequence(x):
+ "Is x a sequence? We say it is if it has a __getitem__ method."
+ return hasattr(x, '__getitem__')
+
+def print_table(table, header=None, sep=' ', numfmt='%g'):
+ """Print a list of lists as a table, so that columns line up nicely.
+ header, if specified, will be printed as the first row.
+ numfmt is the format for all numbers; you might want e.g. '%6.2f'.
+ (If you want different formats in differnt columns, don't use print_table.)
+ sep is the separator between columns."""
+ justs = [if_(isnumber(x), 'rjust', 'ljust') for x in table[0]]
+ if header:
+ table = [header] + table
+ table = [[if_(isnumber(x), lambda: numfmt % x, x) for x in row]
+ for row in table]
+ maxlen = lambda seq: max(map(len, seq))
+ sizes = map(maxlen, zip(*[map(str, row) for row in table]))
+ for row in table:
+ for (j, size, x) in zip(justs, sizes, row):
+ print getattr(str(x), j)(size), sep,
+ print
+
+def AIMAFile(components, mode='r'):
+ "Open a file based at the AIMA root directory."
+ import utils
+ dir = os.path.dirname(utils.__file__)
+ return open(apply(os.path.join, [dir] + components), mode)
+
+def DataFile(name, mode='r'):
+ "Return a file in the AIMA /data directory."
+ return AIMAFile(['..', 'data', name], mode)
+
+
+#______________________________________________________________________________
+# Queues: Stack, FIFOQueue, PriorityQueue
+
+class Queue:
+ """Queue is an abstract class/interface. There are three types:
+ Stack(): A Last In First Out Queue.
+ FIFOQueue(): A First In First Out Queue.
+ PriorityQueue(lt): Queue where items are sorted by lt, (default <).
+ Each type supports the following methods and functions:
+ q.append(item) -- add an item to the queue
+ q.extend(items) -- equivalent to: for item in items: q.append(item)
+ q.pop() -- return the top item from the queue
+ len(q) -- number of items in q (also q.__len())
+ Note that isinstance(Stack(), Queue) is false, because we implement stacks
+ as lists. If Python ever gets interfaces, Queue will be an interface."""
+
+ def __init__(self):
+ abstract
+
+ def extend(self, items):
+ for item in items: self.append(item)
+
+def Stack():
+ """Return an empty list, suitable as a Last-In-First-Out Queue."""
+ return []
+
+class FIFOQueue(Queue):
+ """A First-In-First-Out Queue."""
+ def __init__(self):
+ self.A = []; self.start = 0
+ def append(self, item):
+ self.A.append(item)
+ def __len__(self):
+ return len(self.A) - self.start
+ def extend(self, items):
+ self.A.extend(items)
+ def pop(self):
+ e = self.A[self.start]
+ self.start += 1
+ if self.start > 5 and self.start > len(self.A)/2:
+ self.A = self.A[self.start:]
+ self.start = 0
+ return e
+
+class PriorityQueue(Queue):
+ """A queue in which the minimum (or maximum) element (as determined by f and
+ order) is returned first. If order is min, the item with minimum f(x) is
+ returned first; if order is max, then it is the item with maximum f(x)."""
+ def __init__(self, order=min, f=lambda x: x):
+ update(self, A=[], order=order, f=f)
+ def append(self, item):
+ bisect.insort(self.A, (self.f(item), item))
+ def __len__(self):
+ return len(self.A)
+ def pop(self):
+ if self.order == min:
+ return self.A.pop(0)[1]
+ else:
+ return self.A.pop()[1]
+
+## Fig: The idea is we can define things like Fig[3,10] later.
+## Alas, it is Fig[3,10] not Fig[3.10], because that would be the same as Fig[3.1]
+Fig = {}
+
+#______________________________________________________________________________
+# Support for doctest
+
+def ignore(x): None
+
+def random_tests(text):
+ """Some functions are stochastic. We want to be able to write a test
+ with random output. We do that by ignoring the output."""
+ def fixup(test):
+ if " = " in test:
+ return ">>> " + test
+ else:
+ return ">>> ignore(" + test + ")"
+ tests = re.findall(">>> (.*)", text)
+ return '\n'.join(map(fixup, tests))
+
+#______________________________________________________________________________
+
+__doc__ += """
+>>> d = DefaultDict(0)
+>>> d['x'] += 1
+>>> d['x']
+1
+
+>>> d = DefaultDict([])
+>>> d['x'] += [1]
+>>> d['y'] += [2]
+>>> d['x']
+[1]
+
+>>> s = Struct(a=1, b=2)
+>>> s.a
+1
+>>> s.a = 3
+>>> s
+Struct(a=3, b=2)
+
+>>> def is_even(x):
+... return x % 2 == 0
+>>> sorted([1, 2, -3])
+[-3, 1, 2]
+>>> sorted(range(10), key=is_even)
+[1, 3, 5, 7, 9, 0, 2, 4, 6, 8]
+>>> sorted(range(10), lambda x,y: y-x)
+[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
+
+>>> removeall(4, [])
+[]
+>>> removeall('s', 'This is a test. Was a test.')
+'Thi i a tet. Wa a tet.'
+>>> removeall('s', 'Something')
+'Something'
+>>> removeall('s', '')
+''
+
+>>> list(reversed([]))
+[]
+
+>>> count_if(is_even, [1, 2, 3, 4])
+2
+>>> count_if(is_even, [])
+0
+
+>>> argmax([1], lambda x: x*x)
+1
+>>> argmin([1], lambda x: x*x)
+1
+
+
+# Test of memoize with slots in structures
+>>> countries = [Struct(name='united states'), Struct(name='canada')]
+
+# Pretend that 'gnp' was some big hairy operation:
+>>> def gnp(country):
+... print 'calculating gnp ...'
+... return len(country.name) * 1e10
+
+>>> gnp = memoize(gnp, '_gnp')
+>>> map(gnp, countries)
+calculating gnp ...
+calculating gnp ...
+[130000000000.0, 60000000000.0]
+>>> countries
+[Struct(_gnp=130000000000.0, name='united states'), Struct(_gnp=60000000000.0, name='canada')]
+
+# This time we avoid re-doing the calculation
+>>> map(gnp, countries)
+[130000000000.0, 60000000000.0]
+
+# Test Queues:
+>>> nums = [1, 8, 2, 7, 5, 6, -99, 99, 4, 3, 0]
+>>> def qtest(q):
+... return [q.extend(nums), [q.pop() for i in range(len(q))]][1]
+
+>>> qtest(Stack())
+[0, 3, 4, 99, -99, 6, 5, 7, 2, 8, 1]
+
+>>> qtest(FIFOQueue())
+[1, 8, 2, 7, 5, 6, -99, 99, 4, 3, 0]
+
+>>> qtest(PriorityQueue(min))
+[-99, 0, 1, 2, 3, 4, 5, 6, 7, 8, 99]
+
+>>> qtest(PriorityQueue(max))
+[99, 8, 7, 6, 5, 4, 3, 2, 1, 0, -99]
+
+>>> qtest(PriorityQueue(min, abs))
+[0, 1, 2, 3, 4, 5, 6, 7, 8, -99, 99]
+
+>>> qtest(PriorityQueue(max, abs))
+[99, -99, 8, 7, 6, 5, 4, 3, 2, 1, 0]
+
+>>> vals = [100, 110, 160, 200, 160, 110, 200, 200, 220]
+>>> histogram(vals)
+[(100, 1), (110, 2), (160, 2), (200, 3), (220, 1)]
+>>> histogram(vals, 1)
+[(200, 3), (160, 2), (110, 2), (220, 1), (100, 1)]
+>>> histogram(vals, 1, lambda v: round(v, -2))
+[(200.0, 6), (100.0, 3)]
+
+>>> log2(1.0)
+0.0
+
+>>> def fib(n):
+... return (n<=1 and 1) or (fib(n-1) + fib(n-2))
+
+>>> fib(9)
+55
+
+# Now we make it faster:
+>>> fib = memoize(fib)
+>>> fib(9)
+55
+
+>>> q = Stack()
+>>> q.append(1)
+>>> q.append(2)
+>>> q.pop(), q.pop()
+(2, 1)
+
+>>> q = FIFOQueue()
+>>> q.append(1)
+>>> q.append(2)
+>>> q.pop(), q.pop()
+(1, 2)
+
+
+>>> abc = set('abc')
+>>> bcd = set('bcd')
+>>> 'a' in abc
+True
+>>> 'a' in bcd
+False
+>>> list(abc.intersection(bcd))
+['c', 'b']
+>>> list(abc.union(bcd))
+['a', 'c', 'b', 'd']
+
+## From "What's new in Python 2.4", but I added calls to sl
+
+>>> def sl(x):
+... return sorted(list(x))
+
+
+>>> a = set('abracadabra') # form a set from a string
+>>> 'z' in a # fast membership testing
+False
+>>> sl(a) # unique letters in a
+['a', 'b', 'c', 'd', 'r']
+
+>>> b = set('alacazam') # form a second set
+>>> sl(a - b) # letters in a but not in b
+['b', 'd', 'r']
+>>> sl(a | b) # letters in either a or b
+['a', 'b', 'c', 'd', 'l', 'm', 'r', 'z']
+>>> sl(a & b) # letters in both a and b
+['a', 'c']
+>>> sl(a ^ b) # letters in a or b but not both
+['b', 'd', 'l', 'm', 'r', 'z']
+
+
+>>> a.add('z') # add a new element
+>>> a.update('wxy') # add multiple new elements
+>>> sl(a)
+['a', 'b', 'c', 'd', 'r', 'w', 'x', 'y', 'z']
+>>> a.remove('x') # take one element out
+>>> sl(a)
+['a', 'b', 'c', 'd', 'r', 'w', 'y', 'z']
+
+"""
+
+
Please sign in to comment.
Something went wrong with that request. Please try again.