Skip to content

Commit

Permalink
Use a BaseSampleSpace class.
Browse files Browse the repository at this point in the history
  • Loading branch information
chebee7i committed Apr 14, 2014
1 parent de7be43 commit fc1952a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
2 changes: 1 addition & 1 deletion dit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .npdist import Distribution

# Order does not matter for these
from .samplespace import SampleSpace, CartesianProduct
from .samplespace import ScalarSampleSpace, SampleSpace, CartesianProduct
from .distconst import *
from .helpers import copypmf

Expand Down
58 changes: 45 additions & 13 deletions dit/samplespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
experience the penalty discussed above.
"""
from .helpers import parse_rvs, get_outcome_ctor, construct_alphabets
from .helpers import (
parse_rvs, get_outcome_ctor, construct_alphabets, get_product_func
)
from .utils import OrderedDict

try:
Expand All @@ -49,20 +51,17 @@

import numpy as np

class SampleSpace(Set):
class BaseSampleSpace(Set):
"""
An abstract representation of a sample space.
A sized, iterable, container.
"""
def __init__(self, samplespace, product=product):
_meta = {}
def __init__(self, samplespace):
self._samplespace = list(samplespace)
self._length = len(samplespace)
self._product = product
self._outcome_length = len(samplespace[0])
self._outcome_class = samplespace[0].__class__
self._outcome_ctor = get_outcome_ctor(self._outcome_class)

# Store a set for O(1) lookup.
self._set = set(samplespace)
Expand All @@ -83,6 +82,37 @@ def index(self, item):
"""
return self._samplespace.index(item)

def sort(self):
self._samplespace.sort()

class ScalarSampleSpace(BaseSampleSpace):
_meta = {
'is_joint': False,
}

class SampleSpace(ScalarSampleSpace):
"""
An abstract representation of a sample space.
A sized, iterable, container.
"""
_meta = {
'is_joint': True,
}

def __init__(self, samplespace, product=None):
super(SampleSpace, self).__init__(samplespace)

self._outcome_length = len(samplespace[0])
self._outcome_class = samplespace[0].__class__
self._outcome_ctor = get_outcome_ctor(self._outcome_class)
# Since we have access to an outcome, we determine a product from it.
if product is None:
self._product = get_product_func(self._outcome_class)
else:
self._product = product

def coalesce(self, rvs, extract=False):
"""
Returns a new sample space after coalescing the specified indexes.
Expand Down Expand Up @@ -214,21 +244,20 @@ def marginalize(self, rvs):
def outcome_length(self):
return self._outcome_length

def sort(self):
self._samplespace.sort()


class CartesianProduct(SampleSpace):
"""
An abstract representation of a Cartesian product sample space.
"""
def __init__(self, alphabets, product=product):
self.alphabets = list(alphabet for alphabet in alphabets)
self.alphabets = tuple(alphabet if isinstance(alphabet, SampleSpace)
else tuple(alphabet) for alphabet in alphabets)
self._alphabet_sets = [alphabet if isinstance(alphabet, SampleSpace)
else set(alphabet) for alphabet in alphabets]

self.alphabet_sizes = tuple(len(alphabet) for alphabet in alphabets)
# Here, the user MUST specify how we take products.
# We infer the class from the specified product.
self._product = product
self._length = reduce(mul, self.alphabet_sizes)
self._outcome_length = len(self.alphabet_sizes)
Expand Down Expand Up @@ -334,8 +363,11 @@ def coalesce(self, rvs, extract=False):
return ss

def sort(self):
alphabets = []
for i, alphabet in enumerate(self.alphabets):
if isinstance(alphabet, SampleSpace):
alphabet.sort()
else:
self.alphabets[i] = tuple(sorted(alphabet))
alphabet = tuple(sorted(alphabet))
alphabets.append(alphabet)
self.alphabets = tuple(alphabets)

0 comments on commit fc1952a

Please sign in to comment.